Skip to content

Instantly share code, notes, and snippets.

@Sonictherocketman
Last active September 23, 2024 04:51
Show Gist options
  • Save Sonictherocketman/2a1c9ec1cd5bb2e80df6e22d7428f2bb to your computer and use it in GitHub Desktop.
Save Sonictherocketman/2a1c9ec1cd5bb2e80df6e22d7428f2bb to your computer and use it in GitHub Desktop.
A very crude galaxy simulator.
"""
Code by:
Brian Schrader
09-21-2024
"""
from datetime import datetime
import logging
import os.path
import json
import time
import numpy as np
from numpy.linalg import inv
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib import cm
from matplotlib import colors
from scipy.spatial.distance import euclidean
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Newtonian gravitational constant in SI units (m^3 kg^-1 s^-2).
# Written as a plain e-notation literal: ``10e-11`` means 1e-10, not 1e-11,
# so ``6.6743 * 10e-11`` would be ten times the real value.
# It is stored negative on purpose: the acceleration code multiplies by
# ``G * M`` directly, so the sign of this constant sets the direction of the pull.
G = -6.6743e-11

# Cache for the all-ones matrix that A() may build on its first call; it is
# sized from that first input and never resized afterwards.
R_ones = None
def R(P_n):
    """Return the matrix of pairwise Euclidean distances between bodies.

    ``P_n`` is laid out as in ``main()``: shape ``(3, n)``, one body per
    column. Transposing gives one body per row. Entry ``[i, j]`` of the
    result is the distance between body ``i`` and body ``j``, and the
    diagonal is zero.
    """
    # Broadcasting trick from https://stackoverflow.com/a/46700369
    rows = P_n.T
    # pair_deltas[i, j] == rows[j] - rows[i]
    pair_deltas = rows[np.newaxis] - rows[:, np.newaxis]
    return np.linalg.norm(pair_deltas, axis=-1)
def A(P_n, GM):
    """Return the Newtonian gravitational acceleration acting on every body.

    Parameters
    ----------
    P_n : array_like, shape (3, n)
        Positions, one body per column (as built by ``main``).
    GM : array_like, shape (n,) or scalar
        ``G * M`` for each body. This module stores ``G`` as a *negative*
        number, so the sum below uses ``P_j - P_i``. With a negative ``GM``,
        each body is therefore pulled toward the others.

    Returns
    -------
    numpy.ndarray, shape (3, n)
        ``a[:, j] = sum_i GM[i] * (P_j - P_i) / |P_j - P_i|**3``.
        Pairs at zero separation, including ``i == j``, contribute nothing.

    Notes
    -----
    The acceleration of body ``j`` depends on the *other* bodies' masses and
    relative positions only, never on its own mass or its distance from the
    origin. Memory use is ``O(n**2)``: a ``(3, n, n)`` displacement array is
    built once and reused for the distances.
    """
    P = np.asarray(P_n, dtype=float)
    n = P.shape[1]
    # Accept a scalar GM as well as one value per body.
    gm = np.broadcast_to(np.asarray(GM, dtype=float), (n,))
    # disp[:, j, i] = P_j - P_i
    disp = P[:, :, None] - P[:, None, :]
    dist_cu = np.sqrt(np.einsum('dji,dji->ji', disp, disp)) ** 3
    # 1 / r**3, with coincident pairs (incl. the diagonal) mapped to 0 instead of inf.
    # A fresh buffer each call, so a different body count never hits a stale cache.
    inv_cu = np.divide(1.0, dist_cu, out=np.zeros_like(dist_cu), where=dist_cu != 0)
    # Weight each source body i by its own GM_i (not the attracted body's mass).
    weight = inv_cu * gm[None, :]
    return np.einsum('dji,ji->dj', disp, weight)
def P_V(P_n1, V_n1, GM, dt=1):
    """Advance the system by one time step and return ``(positions, velocities)``.

    The acceleration is evaluated once, via ``A``, at the starting positions.
    Positions advance with the constant-acceleration formula
    ``p + v*dt + a*dt**2/2``, which uses the *starting* velocity.
    Velocities advance with ``v + a*dt``.
    """
    accel = A(P_n1, GM)
    next_positions = P_n1 + V_n1 * dt + 0.5 * accel * dt**2
    next_velocities = V_n1 + accel * dt
    return next_positions, next_velocities
def plot(Ps, limit, anim_step=10):
    """Render the recorded positions as a 3-D scatter animation and save it.

    Parameters
    ----------
    Ps : iterable of arrays
        Snapshots as returned by ``simulate``. Each is indexed as ``P[0]``,
        ``P[1]``, ``P[2]`` for the x, y and z coordinates.
    limit : float
        Half-width of the displayed cube; every axis spans ``[-limit, limit]``.
    anim_step : int
        Only every ``anim_step``-th snapshot becomes an animation frame.

    The animation is written by ``FuncAnimation.save`` to
    ``simulation-<YYYY-mm-dd-HH-MM>.m4v`` in the current working directory.
    """
    logger.info('Plotting...')
    fig, ax = plt.subplots(subplot_kw={"projection": "3d"})

    def _set_axes(title):
        # ax.clear() resets limits and title, so they are reapplied every frame.
        ax.set_xlim([-limit, limit])
        ax.set_ylim([-limit, limit])
        ax.set_zlim([-limit, limit])
        ax.set_title(title)

    # Same limits as every animated frame (previously [0, limit] here).
    _set_axes('P: Position t=0')
    ax.grid(True)

    def update(args):
        t, P = args
        ax.clear()
        _set_axes(f'P: Position {t=}')
        return ax.scatter(P[0], P[1], P[2], s=1)

    ani = animation.FuncAnimation(
        fig=fig,
        func=update,
        frames=list(enumerate(Ps))[::anim_step],
        interval=100,
    )
    # %m is the month; %M is minutes and belongs only in the time part.
    ts = datetime.now().strftime('%Y-%m-%d-%H-%M')
    try:
        ani.save(f'simulation-{ts}.m4v')
    finally:
        # Close this specific figure even if saving fails.
        plt.close(fig)
def simulate(
    P_0,
    V_0,
    M,
    steps,
    args=None,
    update_interval=100,
    record_every=1,
):
    """Integrate the system for ``steps`` time steps and return the snapshots.

    Parameters
    ----------
    P_0, V_0 : arrays
        Initial positions and velocities, passed to ``P_V``.
    M : array
        Body masses; multiplied by the module constant ``G`` once up front.
    steps : int
        Number of calls to ``P_V``.
    args : object, optional
        Unused; kept for backward compatibility.
    update_interval : int
        Log progress every ``update_interval`` steps. ``0`` or ``None``
        disables progress logging.
    record_every : int
        Keep every ``record_every``-th state in the returned list. The initial
        state is always kept. The default of 1 keeps every step; larger values
        cut memory for long runs.

    Returns
    -------
    list
        ``[P_0, ...]``. With ``record_every=1`` it has ``steps + 1`` entries.

    Raises
    ------
    ValueError
        If ``record_every`` is less than 1.
    """
    if record_every < 1:
        raise ValueError(f'record_every must be >= 1, got {record_every}')
    logger.info('Running...')
    Ps = [P_0]
    P_t, V_t = P_0, V_0
    GM = G * M
    start = t0 = time.time()
    for t in range(steps):
        P_t, V_t = P_V(P_t, V_t, GM)
        if (t + 1) % record_every == 0:
            Ps.append(P_t)
        if update_interval and t % update_interval == 0:
            logger.info(f'{t=}: d[{update_interval}] = {(time.time()-t0):.2f}s')
            t0 = time.time()
    # Report the number of steps taken; the loop index stops at steps - 1.
    logger.info(
        f'Simulation complete! ({steps} steps) - total time = {(time.time()-start):.2f}s'
    )
    return Ps
def main(
    n_nodes=1_500,
    p_scale=2e11,
    steps=100_000,
    v_scale=1e5,
    m_scale=2e30,
    seed=None,
):
    """Build random initial conditions, run the simulation and save the animation.

    Positions are drawn uniformly in ``[-p_scale, p_scale)``, velocities in
    ``[-v_scale, v_scale)``, and masses in ``m_scale * [0.7, 10)``.
    ``seed`` is passed to ``numpy.random.default_rng`` for reproducible runs;
    ``None`` draws fresh entropy.
    """
    logger.info(f'Settings: {n_nodes=}, {p_scale=}, {v_scale=}, {m_scale=}, {steps=}')
    rng = np.random.default_rng(seed)
    # Float draws: np.random.randint with bounds of +/-2e11 overflows when
    # NumPy's default integer is 32-bit (e.g. Windows with NumPy < 2).
    P_0 = rng.uniform(-p_scale, p_scale, size=(3, n_nodes))
    V_0 = rng.uniform(-v_scale, v_scale, size=(3, n_nodes))
    M = m_scale * rng.uniform(0.7, 10, size=n_nodes)
    Ps = simulate(P_0, V_0, M, steps)
    plot(Ps, p_scale * 1.2)
# Script entry point: run main() with its default settings.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment