Skip to content

Instantly share code, notes, and snippets.

@SerpentChris
Last active September 15, 2017 19:19
Show Gist options
  • Save SerpentChris/05807669575fd4fe09a1fe0d9b7d5f07 to your computer and use it in GitHub Desktop.
Save SerpentChris/05807669575fd4fe09a1fe0d9b7d5f07 to your computer and use it in GitHub Desktop.
Calculates the integer partition function p(n) with Python
# Copyright (c) 2017 Christian Calderon
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import print_function, division
import math
import sys
if sys.version_info[0] >= 3:
    # On Python 3, recreate the Python 2 names the rest of the file relies on:
    # `xrange` is the lazy builtin range, and `range()` returns a real list.
    # `>= 3` rather than `== 3` so a later major version is not left without
    # `xrange`.
    xrange = range

    def range(*args):
        """Python 2 style range(): same arguments as the builtin, returns a list."""
        return list(xrange(*args))
# Reference partition numbers: KNOWNS[m] == p(m) for m = 0 .. 44, the number
# of ways to write m as an unordered sum of positive integers.
KNOWNS = [
    1, 1, 2, 3, 5, 7, 11, 15, 22, 30,                                    # p(0)..p(9)
    42, 56, 77, 101, 135, 176, 231, 297, 385, 490,                       # p(10)..p(19)
    627, 792, 1002, 1255, 1575, 1958, 2436, 3010, 3718, 4565,            # p(20)..p(29)
    5604, 6842, 8349, 10143, 12310, 14883, 17977, 21637, 26015, 31185,   # p(30)..p(39)
    37338, 44583, 53174, 63261, 75175,                                   # p(40)..p(44)
]
# Generalized pentagonal numbers k*(3k - 1)/2 in increasing order, for
# k = 1, -1, 2, -2, ..., 29, -29 (entry j has k = (-1)**j * (j//2 + 1));
# the k = 0 value (0) is not included. partitions3 takes its default
# pentagonal table from this list.
PENTS = [
    1, 2, 5, 7, 12, 15, 22, 26, 35, 40,                 # k = +-1  .. +-5
    51, 57, 70, 77, 92, 100, 117, 126, 145, 155,        # k = +-6  .. +-10
    176, 187, 210, 222, 247, 260, 287, 301, 330, 345,   # k = +-11 .. +-15
    376, 392, 425, 442, 477, 495, 532, 551, 590, 610,   # k = +-16 .. +-20
    651, 672, 715, 737, 782, 805, 852, 876, 925, 950,   # k = +-21 .. +-25
    1001, 1027, 1080, 1107, 1162, 1190, 1247, 1276,     # k = +-26 .. +-29
]
# (unit name, length in seconds), largest unit first -- print_time walks the
# table in this order. A "year" is the mean Gregorian year: a 400-year cycle
# has 146097 days, i.e. 365.2425 days per year.
TIME_UNITS = (
    ('years', 146097 * 86400 // 400),
    ('weeks', 7 * 86400),
    ('days', 24 * 60 * 60),
    ('hours', 60 * 60),
    ('minutes', 60),
)
def format_time(t, units=None):
    """Return a readable rendering of a duration of t seconds.

    :param t: duration in seconds (int or float; presumably non-negative).
    :param units: sequence of (name, seconds) pairs, largest first. When
        omitted, the module-level TIME_UNITS table is used.
    :returns: a string such as 'Time: 1 hours, 2 minutes, 3.5 seconds.'.
        Units with a zero count are omitted; seconds are always shown.
    """
    if units is None:
        units = TIME_UNITS
    parts = []
    for name, amount in units:
        count, t = divmod(t, amount)
        if count:
            # Number before unit, consistent with the trailing "<x> seconds".
            parts.append('%d %s' % (count, name))
    parts.append('%3.1f seconds' % t)
    return 'Time: ' + ', '.join(parts) + '.'


def print_time(t, units=None):
    """Print the duration t (in seconds) as rendered by format_time(t, units)."""
    print(format_time(t, units))
def partitions1(n, VALS={0: 1}):
    """Partition count via the restricted-partition recurrence, w/ simulated recursion.

    Let P(m, k) be the number of partitions of m into exactly k parts. Then
    p(n) = sum(P(n, k) for k in 1..n) and P(m, k) = P(m - k, k) + P(m - 1, k - 1).
    The recursion is driven by an explicit stack instead of Python calls.

    :param n: integer to partition; n == 0 gives 1, negative n gives 0.
    :param VALS: memo mapping the Cantor-paired index of (m, k),
        (m + k)*(m + k + 1)//2 + k, to P(m, k). The default dict is shared
        between calls, so it acts as a persistent cache; pass a fresh dict to
        compute from scratch.
    :returns: p(n).
    """
    if n == 0:
        return 1
    # Seed the stack with every P(n, k) the final sum needs; (n, 1) on top.
    stack = [(n, k) for k in range(n, 0, -1)]
    push = stack.append
    pop = stack.pop
    while stack:
        # Use m (not n) so the argument n is never overwritten by the loop.
        m, k = stack[-1]
        idx = (m + k)*(m + k + 1)//2 + k
        if idx in VALS:
            pop()
            continue
        if m == 0 or k == 0 or m < k:
            VALS[idx] = 0
            pop()
            continue
        if k == 1 or k == m or k + 1 == m:
            # all ones; all ones; one 2 and the rest ones
            VALS[idx] = 1
            pop()
            continue
        if k == 2:
            VALS[idx] = m//2
            pop()
            continue
        t1_idx = m*(m + 1)//2 + k                     # index of (m - k, k)
        t2_idx = (m + k - 2)*(m + k - 1)//2 + k - 1   # index of (m - 1, k - 1)
        t1_known = t1_idx in VALS
        t2_known = t2_idx in VALS
        if t1_known and t2_known:
            VALS[idx] = VALS[t1_idx] + VALS[t2_idx]
            pop()
            continue
        # Leave (m, k) on the stack; it is revisited once its inputs exist.
        if not t1_known:
            push((m - k, k))
        if not t2_known:
            push((m - 1, k - 1))
    return sum(VALS[(n + k)*(n + k + 1)//2 + k] for k in range(1, n + 1))
def partitions2(n, VALS=KNOWNS[:]):
    """Partition count via the pentagonal number theorem, w/ simulated recursion.

    p(m) = p(m-1) + p(m-2) - p(m-5) - p(m-7) + p(m-12) + p(m-15) - ...
    (generalized pentagonal offsets, signs + + - - ...), evaluated top-down
    from n with an explicit stack of still-unknown m values.

    :param n: integer to partition; negative n gives 0.
    :param VALS: list where VALS[m] == p(m) once computed and 0 means "not
        computed yet" (safe because p(m) >= 1 for m >= 0). It is extended in
        place; the default list is shared between calls, so it acts as a
        persistent cache.
    :returns: p(n).
    """
    if n < 0:
        # VALS[n] would silently index from the end of the list.
        return 0
    # p(0) = 1 is the recurrence's base case; if it were 0 ("unknown") the
    # stack loop below would never terminate.
    if not VALS:
        VALS.append(1)
    elif not VALS[0]:
        VALS[0] = 1
    if n >= len(VALS):
        VALS.extend((n - len(VALS) + 1)*[0])
    elif VALS[n]:
        return VALS[n]
    # A 0 left at VALS[n] (e.g. by an interrupted call) falls through and is
    # recomputed below instead of being returned as the answer.

    # Generalized pentagonal numbers 1, 2, 5, 7, 12, ... in increasing order,
    # ending with the first one greater than n so the inner scan always stops
    # before running off the end of the list.
    pents = []
    i = 1
    p = 1
    while p <= n:
        pents.append(p)
        i += 1
        j = ((i - 1) >> 1) + 1  # i = 2, 3, 4, 5, ... -> j = -1, 2, -2, 3, ...
        if not i & 1:
            j = -j
        p = (3*j*j - j) >> 1
    pents.append(p)

    stack = [n]
    push = stack.append
    pop = stack.pop
    while stack:
        m = stack[-1]
        if VALS[m]:
            pop()
            continue
        # Walk k = m - 1, m - 2, m - 5, ... summing known terms with signs
        # + + - - + + ...; push every unknown term so it is computed first.
        i = 0
        sign = 1
        k = m - 1
        result = 0
        keep_adding = True
        while k >= 0:
            if keep_adding and VALS[k]:
                result += sign*VALS[k]
            elif not VALS[k]:
                push(k)
                keep_adding = False
            i += 1
            if not i & 1:
                sign = -sign
            k = m - pents[i]
        if keep_adding:
            # Every term was known, so p(m) is final.
            VALS[m] = result
            pop()
    return VALS[n]
def partitions3(n, vals=KNOWNS[:], pents=PENTS):
    """Partition count via the pentagonal number theorem, filled bottom-up.

    p(m) = p(m-1) + p(m-2) - p(m-5) - p(m-7) + p(m-12) + p(m-15) - ...
    where the offsets are the generalized pentagonal numbers k(3k - 1)/2 for
    k = 1, -1, 2, -2, ... and the signs run + + - - + + ...

    :param n: integer to partition; negative n gives 0.
    :param vals: list with vals[m] == p(m) for every m < len(vals). It is
        grown in place; the default list is shared between calls, so it acts
        as a persistent cache.
    :param pents: generalized pentagonal numbers in increasing order, where
        entry j corresponds to k = (-1)**j * (j//2 + 1). It is extended in
        place when it does not reach past n; the default is the module-level
        PENTS list, which therefore grows too.
    :returns: p(n).
    """
    if n < 0:
        # vals[n] would silently index from the end of the list.
        return 0
    if not vals:
        vals.append(1)  # base case p(0) = 1
    if n < len(vals):
        return vals[n]
    # Extend the pentagonal table (exact integer arithmetic) until it holds a
    # value greater than n, so every p(m) below sees all offsets <= m.
    while not pents or pents[-1] <= n:
        j = len(pents)
        k = j//2 + 1
        if j & 1:
            k = -k
        pents.append(k*(3*k - 1)//2)
    # Append each p(m) only once it is complete, so an interrupted call can
    # never leave placeholder values in the shared cache.
    for m in range(len(vals), n + 1):
        total = 0
        for i, p in enumerate(pents):
            if p > m:
                break
            # i = 0, 1 -> +; 2, 3 -> -; 4, 5 -> +; ...
            if i & 2:
                total -= vals[m - p]
            else:
                total += vals[m - p]
        vals.append(total)
    return vals[n]
def test():
    """Check partitions1, partitions2 and partitions3 against KNOWNS.

    partitions2 and partitions3 get fresh caches seeded only with p(0) = 1.
    With their default caches (copies of KNOWNS) they would just echo the
    table back and their recurrences would go untested. Mismatches raise
    AssertionError explicitly, so the check still runs under `python -O`.
    Prints 'TEST PASSED' on success.
    """
    checks = (
        ('partitions1', lambda m: partitions1(m)),
        ('partitions2', lambda m: partitions2(m, [1])),
        ('partitions3', lambda m: partitions3(m, [1])),
    )
    for i, expected in enumerate(KNOWNS):
        for name, func in checks:
            got = func(i)
            if got != expected:
                raise AssertionError('%s(%d) returned %r, expected %r'
                                     % (name, i, got, expected))
    print('TEST PASSED')
def timeit(n, f):
    """Run f(n) once, print the elapsed time and the result, and return it.

    :param n: argument passed to f.
    :param f: the function to time (e.g. partitions1).
    :returns: whatever f(n) returns.
    """
    # Imported here because the module imports `time` only under __main__;
    # without this, calling timeit() from an importing module is a NameError.
    import time
    # perf_counter (Python 3.3+) is the right clock for intervals; fall back
    # to time.time on older interpreters.
    clock = getattr(time, 'perf_counter', time.time)
    print('Running', f.__name__, '...')
    start = clock()
    result = f(n)
    elapsed = clock() - start
    print_time(elapsed)
    print('Result:', result)
    return result
def compare(n):
    """Time partitions1, partitions2 and partitions3 (in that order) on n.

    Prints 'PASSED' when all three results are equal, otherwise 'FAILED'.
    """
    results = [timeit(n, func) for func in (partitions1, partitions2, partitions3)]
    all_agree = results[0] == results[1] == results[2]
    print('PASSED' if all_agree else 'FAILED')
if __name__ == '__main__':
    import cProfile
    import time  # timeit() may look this name up as a module global; keep it.

    def usage():
        """Exit with a usage message on stderr (status 1)."""
        sys.exit('usage: {0} calc1|calc2|calc3|compare N\n'
                 '       {0} profile 1|2|3 N\n'
                 '       {0} test'.format(sys.argv[0]))

    calcs = {'calc1': partitions1, 'calc2': partitions2, 'calc3': partitions3}
    args = sys.argv[1:]
    command = args[0] if args else None

    # Parse and validate arguments up front so errors raised by the
    # computations themselves are never mistaken for bad usage.
    try:
        if command in calcs or command == 'compare':
            n = int(args[1])
        elif command == 'profile':
            variant, n = args[1], int(args[2])
            if variant not in ('1', '2', '3'):
                usage()
        elif command != 'test':
            usage()
    except (IndexError, ValueError):
        usage()

    if command in calcs:
        timeit(n, calcs[command])
    elif command == 'compare':
        compare(n)
    elif command == 'test':
        test()
    else:
        # `variant` and `n` were validated above, so only a known function
        # name and an int literal are spliced into the code cProfile runs.
        cProfile.run('partitions%s(%d)' % (variant, n))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment