Skip to content

Instantly share code, notes, and snippets.

@alexguirre
Last active August 20, 2024 20:21
Show Gist options
  • Save alexguirre/fb1fad0a4a91e443111746cc2dfe8f01 to your computer and use it in GitHub Desktop.
Save alexguirre/fb1fad0a4a91e443111746cc2dfe8f01 to your computer and use it in GitHub Desktop.
import numpy as np
import matplotlib.pyplot as plt
from functools import cache
def normalize_vecs(vecs):
    """Scale every row of ``vecs`` to unit Euclidean length.

    Rows run along axis 0 and vector components along axis 1 (typically an
    array of shape ``(n, d)``). The norm is taken over axis 1 and broadcast
    back across each row. Zero-length rows are not special-cased: they divide
    by zero and come back as NaN.
    """
    row_lengths = np.linalg.norm(vecs, axis=1, keepdims=True)
    return vecs / row_lengths
@cache
def generate_vectors_structured(num_samples):
    """Return ``num_samples`` unit vectors spread evenly over the sphere.

    Points lie on a Fibonacci lattice.
    - The cosine of the polar angle is evenly spaced in (-1, 1).
    - The azimuth advances by ``2 * pi / golden_ratio`` radians per point.
    - Y is the polar axis.

    References:
        https://arxiv.org/pdf/0912.4540.pdf
        https://extremelearning.com.au/how-to-evenly-distribute-points-on-a-sphere-more-effectively-than-the-canonical-fibonacci-lattice/

    Results are memoized with ``functools.cache``. Repeated calls with the same
    ``num_samples`` return the same array object, so do not modify it in place.

    Returns:
        ``np.ndarray`` of shape ``(num_samples, 3)``.
    """
    golden_ratio = (1 + 5 ** 0.5) / 2.0
    idx = np.arange(0, num_samples)
    # The +0.5 offset keeps the first and last points off the poles.
    polar = np.arccos(1 - (2 * (idx + 0.5)) / num_samples)
    azimuth = 2 * np.pi * idx / golden_ratio
    sin_polar = np.sin(polar)
    return np.column_stack(
        (sin_polar * np.cos(azimuth), np.cos(polar), sin_polar * np.sin(azimuth))
    )
@cache
def generate_vectors_pr(num_samples_per_axis):
    """Return unit vectors obtained by normalizing a regular cubic grid.

    The grid has ``num_samples_per_axis`` evenly spaced values in ``[-1, 1]``
    along each of X, Y and Z, and every grid point is projected onto the unit
    sphere. Distinct grid points on the same ray (e.g. ``(0.5, 0, 0)`` and
    ``(1, 0, 0)``) give the same direction. The result therefore contains
    repeated vectors and is not evenly distributed over the sphere.

    Grid points that are exactly the zero vector are dropped, because they
    have no direction and would otherwise normalize to a row of NaNs. For
    example, with ``num_samples_per_axis=5`` the centre point ``(0, 0, 0)``
    is dropped.

    Results are memoized with ``functools.cache``. Repeated calls with the same
    argument return the same array object, so do not modify it in place.

    Returns:
        ``np.ndarray`` of shape ``(m, 3)``. ``m`` is
        ``num_samples_per_axis ** 3`` minus the number of exact-zero grid
        points (at most one).
    """
    axis = np.linspace(-1.0, 1.0, num=num_samples_per_axis)
    # The same axis is used for all three dimensions. meshgrid keeps its
    # default 'xy' indexing so the row order matches the original version.
    xx, yy, zz = np.meshgrid(axis, axis, axis)
    points = np.column_stack((xx.ravel(), yy.ravel(), zz.ravel()))
    # Remove the zero vector (if present) so normalization cannot yield NaN.
    points = points[np.any(points != 0.0, axis=1)]
    return normalize_vecs(points)
def plot_vectors(vectors, *, show=True):
    """Draw 3D vectors as arrows from the origin in a new Matplotlib figure.

    Reference arrows of length 3 are drawn along +X (red), +Y (green) and
    +Z (blue). These are longer than the ±1.5 limits applied to every axis.
    The box aspect is set to 1:1:1.

    Args:
        vectors: Sequence of 3-component vectors, or an array of shape
            ``(n, 3)``. An empty sequence draws only the reference arrows.
        show: When true (the default), call ``plt.show()`` before returning.
            Pass ``False`` to keep working with the figure, e.g. to save it.

    Returns:
        The created ``matplotlib.figure.Figure``.

    Raises:
        ValueError: If ``vectors`` is non-empty and not of shape ``(n, 3)``.
    """
    vecs = np.asarray(vectors, dtype=float)
    if vecs.size == 0:
        # Unpacking zip(*vectors) fails on empty input; give it a (0, 3) shape
        # so the column slices below still work.
        vecs = vecs.reshape(0, 3)
    if vecs.ndim != 2 or vecs.shape[1] != 3:
        raise ValueError(f"expected vectors of shape (n, 3), got {vecs.shape}")

    fig = plt.figure()
    ax = fig.add_subplot(111, projection="3d")

    # Main axes, for orientation.
    ax.quiver([0], [0], [0], [3], [0], [0]).set_color("red")
    ax.quiver([0], [0], [0], [0], [3], [0]).set_color("green")
    ax.quiver([0], [0], [0], [0], [0], [3]).set_color("blue")

    # Every arrow starts at the origin; the array is read-only here.
    origin = np.zeros(len(vecs))
    ax.quiver(origin, origin, origin, vecs[:, 0], vecs[:, 1], vecs[:, 2])

    ax.set_xlim([-1.5, 1.5])
    ax.set_ylim([-1.5, 1.5])
    ax.set_zlim([-1.5, 1.5])
    ax.set_box_aspect([1.0, 1.0, 1.0])

    if show:
        plt.show()
    return fig
def main():
    """Plot the normalized cube-grid vectors, then a Fibonacci-lattice set.

    The cube grid uses 5 samples per axis. The lattice uses ``5 ** 3`` points.
    Each plot is shown in turn.
    """
    per_axis = 5
    plot_vectors(generate_vectors_pr(per_axis))
    plot_vectors(generate_vectors_structured(per_axis ** 3))
# Run the demo only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment