Skip to content

Instantly share code, notes, and snippets.

@k-ye
Created August 30, 2021 06:17
Show Gist options
  • Save k-ye/ec2dd5e0cdb2c3d034aff8f55a85a31e to your computer and use it in GitHub Desktop.
Save k-ye/ec2dd5e0cdb2c3d034aff8f55a85a31e to your computer and use it in GitHub Desktop.
import taichi as ti
# import math
# Ask Taichi for a GPU backend.
ti.init(arch=ti.gpu)

# ---- global control ----
# Pause flag stored as a 0-D i32 field.
# NOTE(review): nothing in the visible script reads or writes it.
paused = ti.field(dtype=ti.i32, shape=())

# ---- simulation parameters ----
# Gravitational constant. The physical value is 6.67408e-11; 1 is used to
# keep the numbers simple.
G = 1
# Number of planets.
N = 8
# Mass shared by every planet.
m = 100
# Scale of the initial planet distribution around the center.
galaxy_size = 0.2
# Radius passed to the GUI when drawing each planet (rendering only).
planet_radius = 10
# Multiplier applied to the initial tangential velocity.
init_vel = 40
# Time-step size; each substep advances the state by h / substepping.
h = 1e-6
# Number of physics substeps per rendered frame.
substepping = 10
PI = 3.141592653

# ---- simulation state ----
# Center of the screen (a single 2-D vector).
center = ti.Vector.field(n=2, dtype=ti.f32, shape=())
# Position, velocity and force of the planets: N 2-D vectors each.
pos = ti.Vector.field(n=2, dtype=ti.f32, shape=N)
vel = ti.Vector.field(n=2, dtype=ti.f32, shape=N)
force = ti.Vector.field(n=2, dtype=ti.f32, shape=N)
# Set up the initial state: write `center`, then give each of the N planets
# a random position around the center and a velocity perpendicular to its
# offset from the center.
#
# - The angle is drawn over one full turn, [0, 2*PI).
# - The radius lies between 0.3 * galaxy_size and galaxy_size.
# - The velocity is the offset mapped (x, y) -> (-y, x), i.e. rotated by
#   90 degrees, and scaled by init_vel.
@ti.kernel
def initialize():
    center[None] = [0.5, 0.5]
    for i in range(N):
        # One full turn already covers every direction uniformly.
        theta = ti.random() * 2 * PI
        r = (ti.sqrt(ti.random()) * 0.7 + 0.3) * galaxy_size
        offset = r * ti.Vector([ti.cos(theta), ti.sin(theta)])
        pos[i] = center[None] + offset
        # Tangential direction, so the planets start out circling the center.
        vel[i] = ti.Vector([-offset.y, offset.x]) * init_vel
# Compute the net gravitational force acting on every planet and store it
# in `force`.
#
# Every ordered pair (i, j) with i != j is evaluated, so each interaction is
# computed twice. Visiting only i > j and updating both force[i] and
# force[j] would halve the arithmetic, but at the cost of a worse memory
# footprint and load balance, so the doubled computation is deliberate.
#
# Since force[i] is written only by iteration i, the sum is built in a local
# variable and stored once; no separate clearing pass over `force` is needed.
@ti.kernel
def compute_force():
    for i in range(N):
        p = pos[i]
        net = ti.Vector([0.0, 0.0])
        for j in range(N):
            if i != j:
                diff = p - pos[j]
                # The 1e-5 argument is the eps of Taichi's norm(); it keeps
                # r away from zero when two planets (nearly) coincide.
                r = diff.norm(1e-5)
                # Gravitational force on i: -(G*m*m / r^2) * (diff / r)
                net += -G * m * m * (1.0 / r) ** 3 * diff
        force[i] = net
# Advance every planet by one substep of length h / substepping.
#
# Symplectic Euler: the velocity is updated first from the current force,
# and the position is then moved using that freshly updated velocity.
@ti.kernel
def update():
    step = h / substepping
    for k in range(N):
        new_vel = vel[k] + step * force[k] / m
        vel[k] = new_vel
        pos[k] += step * new_vel
# Open the window, build the initial state, then simulate and draw until
# the window is closed.
gui = ti.GUI('N-body problem', res=(512, 512))
initialize()
while gui.running:
    # Run `substepping` physics substeps for each rendered frame.
    for _ in range(substepping):
        compute_force()
        update()
    # Redraw: clear to the background color, then one white circle per planet.
    gui.clear(0x112F41)
    gui.circles(pos.to_numpy(), radius=planet_radius, color=0xffffff)
    gui.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment