Gist justanhduc/adbcc06dfd72e3a80026a30c9bd45f37, last active November 24, 2020 13:07
import torch as T


def batch_pairwise_sqdist(x: T.Tensor, y: T.Tensor) -> T.Tensor:
    """
    Calculates the pair-wise squared distance between two sets of points.
    To get the Euclidean distance, apply an explicit square root to the output.

    :param x:
        a tensor of shape ``(m, nx, d)`` or ``(nx, d)``.
        If the tensor is 2-dimensional, it is broadcast along the batch dimension.
    :param y:
        a tensor of shape ``(m, ny, d)`` or ``(ny, d)``.
        If the tensor is 2-dimensional, it is broadcast along the batch dimension.
    :return:
        a tensor of shape ``(m, nx, ny)`` (``(nx, ny)`` if both inputs are 2-dimensional)
        containing the squared distance between every pair of points
        in `x` and `y` from the same batch.
    """
    # Uses ||x_i - y_j||^2 = ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>.
    xx = T.sum(x ** 2, -1)  # (..., nx): squared norms of the points in x
    yy = T.sum(y ** 2, -1)  # (..., ny): squared norms of the points in y
    zz = T.matmul(x, y.transpose(-1, -2))  # (..., nx, ny): inner products
    # Broadcast xx along the columns and yy along the rows of zz.
    P = xx.unsqueeze(-1) + yy.unsqueeze(-2) - 2. * zz
    # Floating-point cancellation can leave tiny negative values; clamp them
    # to zero so that a subsequent square root does not return NaN.
    return P.clamp(min=0.)
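The vectorized code relies on expanding the squared distance into squared norms (`xx`, `yy`) and inner products (`zz`) instead of subtracting every pair of points directly. The dependency-free sketch below, for a single batch of 2-D inputs only, checks that identity against a brute-force loop. `naive_pairwise_sqdist` and `expanded_pairwise_sqdist` are hypothetical helpers written for illustration; they are not part of the gist.

```python
from typing import List

Matrix = List[List[float]]


def naive_pairwise_sqdist(x: Matrix, y: Matrix) -> Matrix:
    """Hypothetical reference: P[i][j] = sum_k (x[i][k] - y[j][k]) ** 2."""
    return [[sum((a - b) ** 2 for a, b in zip(xi, yj)) for yj in y] for xi in x]


def expanded_pairwise_sqdist(x: Matrix, y: Matrix) -> Matrix:
    """Hypothetical helper: ||x_i||^2 + ||y_j||^2 - 2 <x_i, y_j>, the same
    expansion that batch_pairwise_sqdist vectorizes with xx, yy and zz."""
    xx = [sum(a * a for a in xi) for xi in x]  # squared norms of the rows of x
    yy = [sum(b * b for b in yj) for yj in y]  # squared norms of the rows of y
    zz = [[sum(a * b for a, b in zip(xi, yj)) for yj in y] for xi in x]  # x @ y^T
    # Guard against tiny negative values caused by floating-point cancellation.
    return [[max(xx[i] + yy[j] - 2.0 * zz[i][j], 0.0) for j in range(len(y))]
            for i in range(len(x))]


x = [[0.0, 0.0], [1.0, 2.0]]
y = [[3.0, 4.0], [1.0, 0.0], [-1.0, 2.0]]
print(naive_pairwise_sqdist(x, y))     # [[25.0, 1.0, 5.0], [8.0, 4.0, 4.0]]
print(expanded_pairwise_sqdist(x, y))  # [[25.0, 1.0, 5.0], [8.0, 4.0, 4.0]]
```

The loop version costs a subtraction per coordinate for every pair, while the expansion reduces the pairwise work to one matrix product, which is what makes the `T.matmul` formulation fast on batched tensors. An analogous check in PyTorch could compare the function's output with `torch.cdist(x, y) ** 2`, up to floating-point tolerance.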