Skip to content

Instantly share code, notes, and snippets.

@cwindolf
Last active February 3, 2023 16:20
Show Gist options
  • Save cwindolf/737abf69c2251b9b733a168e92449a6b to your computer and use it in GitHub Desktop.
Save cwindolf/737abf69c2251b9b733a168e92449a6b to your computer and use it in GitHub Desktop.
1D optionally normalized, optionally weighted, optionally centered cross-correlation in PyTorch (+ SciPy fallback), with API like F.conv1d
import numpy as np

try:
    import torch
    import torch.nn.functional as F

    HAVE_TORCH = True
except ImportError:
    HAVE_TORCH = False
def normxcorr1d(
    template,
    x,
    weights=None,
    centered=True,
    normalized=True,
    padding="same",
    conv_engine="torch",
):
    """normxcorr1d: Normalized cross-correlation, optionally weighted

    The API is like torch's F.conv1d, except that the positions of
    input/weights are swapped -- `template` acts like the conv weights,
    and `x` acts like the input.

    Returns the cross-correlation of `template` and `x` at the spatial
    lags determined by `padding`. Useful for estimating the location of
    `template` within `x`.

    This might not be the most efficient implementation -- ideas welcome.
    It uses a direct convolutional translation of the formula

        corr = (E[XY] - EX EY) / sqrt(var X * var Y)

    This also supports weights! In that case, the usual adaptation of
    the above formula is made to the weighted case -- and all of the
    normalizations are done per block (i.e. per lag, over the samples
    where template and x overlap) in the same way.

    Arguments
    ---------
    template : tensor, shape (num_templates, length) or (length,)
        The reference template signal(s). A 1d template is treated as a
        single template.
    x : tensor, 1d shape (x_length,) or 2d shape (num_inputs, x_length)
        The signal in which to find `template`. `x_length` does not need
        to equal the template `length`.
    weights : tensor, shape (length,), optional
        Will use weighted means, variances, covariances if supplied.
    centered : bool
        If true, means will be subtracted (per weighted patch).
    normalized : bool
        If true, normalize by the variance (per weighted patch).
    padding : "same", "valid", or int
        Passed straight to the conv1d engine; it decides which lags are
        computed. With "same" (the default) there is one lag per sample
        of `x`.
    conv_engine : string, one of "torch", "numpy"
        What library to use for computing cross-correlations.
        If numpy, falls back to the scipy correlate function.

    Returns
    -------
    corr : tensor, shape (num_inputs, num_templates, num_lags)
        Non-finite entries (e.g. from zero-variance or zero-weight
        patches) are set to 0.

    Raises
    ------
    ImportError
        If conv_engine is "torch" but torch could not be imported.
    ValueError
        If conv_engine is unknown or `weights` has the wrong shape.
    """
    if conv_engine == "torch":
        # explicit raise rather than assert: asserts vanish under -O and
        # the failure would then surface as a NameError on F below
        if not HAVE_TORCH:
            raise ImportError("conv_engine='torch' requires PyTorch")
        conv1d = F.conv1d
        npx = torch
    elif conv_engine == "numpy":
        conv1d = scipy_conv1d
        npx = np
    else:
        raise ValueError(f"Unknown conv_engine {conv_engine}")

    x = npx.atleast_2d(x)
    template = npx.atleast_2d(template)
    num_templates, length = template.shape
    x_length = x.shape[-1]

    device_kw = {} if conv_engine == "numpy" else dict(device=x.device)
    # Indicator of x's support. Correlating it with a kernel sums that
    # kernel over exactly the samples that overlap x at each lag, so every
    # windowed statistic below is taken over the same patch. It must have
    # x's length (not the template's) so all results broadcast together.
    ones_x = npx.ones((1, 1, x_length), dtype=x.dtype, **device_kw)

    # generalize over weighted / unweighted case: the kernels all have the
    # template's length
    if weights is None:
        weights = npx.ones((1, 1, length), dtype=x.dtype, **device_kw)
        wt = template[:, None, :]
        wt2 = npx.square(template)[:, None, :]
    else:
        if tuple(weights.shape) != (length,):
            raise ValueError(
                f"weights must have shape ({length},), "
                f"got {tuple(weights.shape)}"
            )
        weights = weights[None, None]
        wt = template[:, None, :] * weights
        wt2 = npx.square(template)[:, None, :] * weights

    # conv1d rule: (B, 1, L), (O, 1, K) -> (B, O, L_out)
    # N is the total weight of the overlapping points in each window.
    # Dividing the windowed sums by it turns them into (weighted) means,
    # which seems necessary for numerical stability.
    N = conv1d(ones_x, weights, padding=padding)
    if centered:
        Et = conv1d(ones_x, wt, padding=padding) / N
        Ex = conv1d(x[:, None, :], weights, padding=padding) / N

    # compute (weighted) covariance
    # important: the formula E[XY] - EX EY is well-suited here,
    # because the means are naturally subtracted correctly
    # patch-wise. you couldn't pre-subtract them!
    cov = conv1d(x[:, None, :], wt, padding=padding) / N
    if centered:
        cov -= Ex * Et

    corr = cov  # renaming for clarity
    if normalized:
        # variances for the denominator, using var X = E[X^2] - (EX)^2
        var_template = conv1d(ones_x, wt2, padding=padding) / N
        var_x = conv1d(
            npx.square(x)[:, None, :], weights, padding=padding
        ) / N
        if centered:
            var_template -= npx.square(Et)
            var_x -= npx.square(Ex)
        corr /= npx.sqrt(var_x * var_template)

    # get rid of NaNs/infs in zero-variance or zero-weight areas
    corr[~npx.isfinite(corr)] = 0
    return corr
def scipy_conv1d(input, weights, padding="valid"):
    """SciPy translation of torch F.conv1d (no bias, stride 1, groups 1).

    Like F.conv1d this computes a cross-correlation (the kernel is not
    flipped), and it reproduces torch's output alignment. That includes
    padding="same" with even-length kernels, where torch pads
    (kernel_size - 1) // 2 zeros on the left and the remainder on the
    right.

    Arguments
    ---------
    input : ndarray, shape (n, 1, length)
    weights : ndarray, shape (c_out, 1, kernel_size)
    padding : "same", "valid", or int
        "same" keeps the output length equal to `length`; "valid" keeps
        only lags where the kernel lies fully inside `input`; an int
        zero-pads both ends of `input` by that many samples, then
        behaves like "valid".

    Returns
    -------
    output : ndarray, shape (n, c_out, length_out)
        dtype is the common result type of `input` and `weights`, so
        integer inputs with float kernels are not truncated.

    Raises
    ------
    ValueError
        If there is not exactly one input channel, or `padding` is not
        recognized.
    """
    from scipy.signal import correlate

    n, c_in, length = input.shape
    c_out, in_by_groups, kernel_size = weights.shape
    if not in_by_groups == c_in == 1:
        raise ValueError(
            "only single-input-channel conv1d is supported, got "
            f"input channels {c_in} and weight in-channels {in_by_groups}"
        )

    if padding == "same":
        # scipy's mode="same" centers the other way from torch for even
        # kernels, so slice torch's alignment out of the full correlation:
        # output[k] = full[k + kernel_size // 2]
        mode = "full"
        start = kernel_size // 2
        length_out = length
    elif padding == "valid":
        mode = "valid"
        start = 0
        length_out = length - kernel_size + 1
    elif isinstance(padding, (int, np.integer)):
        mode = "valid"
        start = 0
        input = np.pad(input, [(0, 0), (0, 0), (padding, padding)])
        length_out = length + 2 * padding - kernel_size + 1
    else:
        raise ValueError(f"Unknown padding {padding}")

    output = np.zeros(
        (n, c_out, length_out), dtype=np.result_type(input, weights)
    )
    for m in range(n):
        for c in range(c_out):
            full = correlate(input[m, 0], weights[c, 0], mode=mode)
            output[m, c] = full[start:start + length_out]
    return output
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment