Created
March 14, 2019 07:42
-
-
Save DIYer22/0bb621cea2d817b3ce85a610b73a5a55 to your computer and use it in GitHub Desktop.
赵振宇的 numpy 加速求每一对向量间的 L2 (Zhao Zhenyu's NumPy-accelerated computation of the L2 distance between every pair of vectors)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from boxx import * | |
import numpy as np | |
def zzy(M):
    """Return the matrix of pairwise Euclidean (L2) distances between rows of M.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, so the whole
    computation is one row-wise reduction and one matrix product.

    Parameters
    ----------
    M : array_like, 2-D, shape (h, w)
        One vector per row. Lists are accepted and converted with np.asarray.

    Returns
    -------
    numpy.ndarray, shape (h, h)
        res[i, j] is the L2 distance between rows i and j. The diagonal is
        exactly 0.
    """
    M = np.asarray(M)
    a2 = np.sum(M * M, axis=1).reshape(1, -1)  # squared row norms, as a row
    b2 = a2.T                                   # ... and as a column
    ab = M.dot(M.T)                             # all pairwise dot products
    res = a2 - 2 * ab + b2                      # broadcasts to (h, h)
    # Cancellation in the expanded form can leave tiny negative values. This
    # is most visible on the diagonal, where np.sum and the dot product may
    # round differently, and taking their square root would give NaN. Clamp
    # them, and pin the self-distances to exactly 0.
    res = np.maximum(res, 0)
    np.fill_diagonal(res, 0)
    return res ** 0.5
# Benchmark input. ``randomm`` is not a NumPy name, so it presumably comes
# from the ``boxx`` star import; the tuple argument is taken to be the
# output shape -- TODO confirm against the boxx docs.
ma = randomm((1000, 128), 10)
ma = ma / 15
# For a tiny hand-checkable case, substitute e.g.
#   ma = np.array([[0, 1], [0, 2], [0, 1]])
(h, w) = ma.shape  # h: number of row vectors, w: their dimension
# Target distance for a pair of rows a, b: ((a - b) ** 2).sum() ** .5
def old(ma):
    """Reference implementation: pairwise L2 distances with explicit Python loops.

    Deliberately naive, so that the vectorised implementations can be timed
    and checked against it.

    Parameters
    ----------
    ma : array_like, shape (h, w)
        One vector per row. Lists are accepted and converted with np.asarray.

    Returns
    -------
    numpy.ndarray, shape (h, h), float64
        l2[i, j] = ((ma[i] - ma[j]) ** 2).sum() ** .5. The diagonal stays 0.
    """
    ma = np.asarray(ma)
    # Size the output from the argument itself. The original read a
    # module-level ``h``, which is only correct for one particular input.
    n = ma.shape[0]
    l2 = np.zeros((n, n))
    for i in range(n - 1):
        for j in range(i + 1, n):
            # Distance is symmetric: compute each pair once and mirror it.
            l2[j, i] = l2[i, j] = ((ma[i] - ma[j]) ** 2).sum() ** .5
    return l2
# (a-b)**2 = a**2 - 2ab + b**2
def mynew(ma):
    """Pairwise L2 distances via per-coordinate outer products.

    Applies (a - b)**2 = a**2 - 2ab + b**2 separately to every coordinate k,
    building (w, h, h) intermediates, and then sums over k. Memory is
    O(w * h**2); for h=1000 and w=128 each float64 intermediate is about 1 GB.

    Parameters
    ----------
    ma : array_like, 2-D, shape (h, w)
        One vector per row. Lists are accepted and converted with np.asarray.

    Returns
    -------
    numpy.ndarray, shape (h, h)
        l2[i, j] is the L2 distance between rows i and j.
    """
    ma = np.asarray(ma)
    # ab[k, i, j] = ma[i, k] * ma[j, k]:  (w, h, 1) @ (w, 1, h) -> (w, h, h)
    ab = np.matmul(ma.T[..., None], ma.T[..., None, :])
    poww = ma ** 2
    # a_b2[k, i, j] = ma[i, k]**2 + ma[j, k]**2 - 2 * ab[k, i, j]
    #              = (ma[i, k] - ma[j, k])**2
    a_b2 = poww.T[..., None] + poww.T[..., None, :] - 2 * ab
    # Rounding in the expanded form can push the sum slightly below zero for
    # nearly-equal rows. Clamp it so the square root never yields NaN.
    l2 = np.maximum(a_b2.sum(0), 0) ** .5
    return l2
# Time each implementation on the same input. ``timeit`` is used as a
# context manager here, so it presumably comes from the boxx star import
# rather than the stdlib module of the same name.
with timeit('old'):
    l21 = old(ma)
with timeit('new'):
    l22 = mynew(ma)
with timeit('zzy'):
    l23 = zzy(ma)

# Total absolute deviation of each vectorised result from the loop reference.
print('diff:', np.abs(l21 - l22).sum())
print('diff2:', np.sum(np.abs(l21 - l23)))
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment