Skip to content

Instantly share code, notes, and snippets.

@bhpfelix
Created October 15, 2019 04:42
Show Gist options
  • Save bhpfelix/ab4d0f6c169f20891cd8fbb6acfa75ab to your computer and use it in GitHub Desktop.
Save bhpfelix/ab4d0f6c169f20891cd8fbb6acfa75ab to your computer and use it in GitHub Desktop.
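Rotation preserves probability: a Monte-Carlo check that rotating isotropic Gaussian samples so the mean vector lies along the first axis leaves the nearest-centre classification rate unchanged.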
import numpy as np
from numpy.linalg import norm
from scipy.special import erf
def get_rotation_matrix(x):
    """
    Get a rotation matrix R (orthogonal, det +1) with R.dot(x / norm(x)) == y,
    where y = [1, 0, 0, 0, 0, ..., 0].

    See: https://math.stackexchange.com/questions/598750/finding-the-rotation-matrix-in-n-dimensions

    For x whose first component is >= 0, R is the rotation inside the plane
    spanned by x and y that leaves the orthogonal complement of that plane
    fixed. It is written in the numerically stable form

        R = I + K + K.K / (1 + cos(theta)),   K = y u^T - u y^T,

    which is algebraically equal to I - uu^T - vv^T + [u v] Rt [u v]^T but
    never divides by sin(theta). It therefore stays accurate when x is
    (nearly) parallel to y.

    For x whose first component is negative, u is first turned by a half-turn
    in the (e1, e2) plane (itself a rotation). The plane rotation then never
    sees a near-antiparallel pair, where it is ill-conditioned.

    Parameters
    ----------
    x : array_like, shape (n,)
        Finite, non-zero vector to align with the first coordinate axis.

    Returns
    -------
    R : ndarray, shape (n, n)

    Raises
    ------
    ValueError
        If x is not a non-empty 1-D vector, if it is zero or non-finite, or if
        n == 1 and x is negative (no rotation of the line maps -1 onto +1).
    """
    x = np.asarray(x, dtype=float)
    if x.ndim != 1 or x.shape[0] == 0:
        raise ValueError("x must be a non-empty 1-D vector")
    n = x.shape[0]
    x_norm = norm(x)
    if not np.isfinite(x_norm) or x_norm == 0.:
        raise ValueError("x must be a finite, non-zero vector")
    u = x / x_norm
    y = np.zeros(n)
    y[0] = 1.

    # Optional half-turn applied before the plane rotation (identity by default).
    pre = np.eye(n)
    if u[0] < 0.:
        if n == 1:
            raise ValueError("a negative 1-D vector cannot be rotated onto +e1")
        pre[0, 0] = pre[1, 1] = -1.
        u = pre.dot(u)

    cost = u[0]  # <u, y>, since |y| = 1; now >= 0, so 1 + cost >= 1
    K = np.outer(y, u) - np.outer(u, y)
    R = (np.eye(n) + K + K.dot(K) / (1. + cost)).dot(pre)

    # Postcondition self-checks; per-entry tolerances do not grow with n.
    assert np.allclose(R.dot(x / x_norm), y, rtol=0., atol=1e-10)
    assert np.allclose(R.dot(R.T), np.eye(n), rtol=0., atol=1e-10)
    return R
def experiment(mean, sigma, num_points, rng=None, batch_size=100000):
    """
    Monte-Carlo check that rotating the data preserves the classification rate.

    Points are drawn from N(0, sigma^2 I) and assigned to whichever centre is
    nearer: the origin or ``mean``. The fraction assigned to ``mean`` is FP.
    ``1 - FP`` is returned under the name "F1"; strictly, it is the accuracy
    on these origin-centred samples.

    The same rate is then recomputed after applying the rotation R that maps
    ``mean / norm(mean)`` onto the first axis. R preserves distances, so a
    point p is nearer ``mean`` exactly when the first coordinate of R.dot(p)
    exceeds ``norm(mean) / 2``. Both estimates therefore agree point by point.

    Parameters
    ----------
    mean : array_like, shape (n,)
        Second class centre; must be non-zero.
    sigma : float
        Per-coordinate standard deviation of the sampled points.
    num_points : int
        Number of samples to draw; must be positive.
    rng : numpy.random.Generator or numpy.random.RandomState, optional
        Source of randomness. Defaults to the global ``numpy.random`` state.
    batch_size : int, optional
        Number of samples generated and scored at a time. This bounds memory:
        all 5e6 x 100 float64 samples at once would need ~4 GB per array.

    Returns
    -------
    (F1, rotated_F1) : tuple of float
        Rate estimated in the original space and in the rotated space.

    Raises
    ------
    ValueError
        If ``num_points`` or ``batch_size`` is not positive, or if ``mean`` is
        rejected by ``get_rotation_matrix``.
    """
    mean = np.asarray(mean, dtype=float)
    if num_points <= 0:
        raise ValueError("num_points must be positive")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if rng is None:
        rng = np.random
    dim = mean.shape[0]

    R = get_rotation_matrix(mean)
    threshold = norm(mean) / 2.

    fp_count = 0
    rotated_fp_count = 0
    remaining = num_points
    while remaining > 0:
        n_batch = min(batch_size, remaining)
        # Isotropic covariance sigma^2 I: scaled standard normals suffice.
        pts = sigma * rng.standard_normal((n_batch, dim))

        ## Classify in the original space by comparing distances
        distance_to_origin = norm(pts, axis=1)
        distance_to_mean = norm(pts - mean, axis=1)
        fp_count += np.count_nonzero(distance_to_mean < distance_to_origin)

        ## Classify again after rotating every point by R
        # R maps mean onto the first axis, so a point p becomes R.dot(p).
        # For row-stacked points that is pts.dot(R.T); pts.dot(R) would
        # apply the inverse rotation instead.
        rotated_pts = pts.dot(R.T)
        # now we only need to look at the first dimension to determine the
        # classification outcome
        rotated_fp_count += np.count_nonzero(rotated_pts[:, 0] > threshold)

        remaining -= n_batch

    FP = fp_count / num_points
    F1 = 1. - FP
    rotated_FP = rotated_fp_count / num_points
    rotated_F1 = 1. - rotated_FP
    return F1, rotated_F1
def exact_soln(mean, sigma):
    """
    Closed-form value of the rate that ``experiment`` estimates.

    For p ~ N(0, sigma^2 I), p is nearer the origin than ``mean`` iff its
    projection onto mean / |mean|, an N(0, sigma^2) variable, is below
    |mean| / 2. That happens with probability

        Phi(|mean| / (2 sigma)) = 0.5 * (1 + erf(|mean| / (2 sqrt(2) sigma))).
    """
    z = norm(mean) / (2. * np.sqrt(2) * sigma)
    return 0.5 * (erf(z) + 1.)
def main():
    """
    Draw a random 100-D mean, run the Monte-Carlo experiment with sigma = 5
    and 5,000,000 samples, and print both empirical estimates next to the
    closed-form value.
    """
    dim, sigma, num_points = 100, 5., 5000000
    mean_vec = 0.5 + np.random.randn(dim)
    F1, rotated_F1 = experiment(mean_vec, sigma, num_points)
    report = (
        ("F1 score in original space: %.8f", F1),
        ("F1 score in rotated space : %.8f", rotated_F1),
        ("Exact Analytical Solution : %.8f", exact_soln(mean_vec, sigma)),
    )
    for fmt, value in report:
        print(fmt % value)


if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment