Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save cre-mer/7ae562caff81009d4d3377e7572881b8 to your computer and use it in GitHub Desktop.
Save cre-mer/7ae562caff81009d4d3377e7572881b8 to your computer and use it in GitHub Desktop.
Circuit to prove that a list is sorted in ascending order using only addition, multiplication, and equality checks
# Sample inputs: one list in ascending order, and the same values with the
# last two entries swapped so that it is no longer sorted.
sorted_list = [
    3, 9, 10, 404, 1337,
]
unsorted_list = [
    3, 9, 10, 1337, 404,
]

# Default bit width allowed by `decimal_to_binary` (values up to 2**4 - 1).
MAX = 4
def decimal_to_binary(dec, max=MAX):
    """Convert a non-negative integer into its binary digits, most significant bit first.

    Args:
        dec: the non-negative integer to convert.
        max: the maximum number of bits the representation may use, i.e. `dec`
            must satisfy ``dec < 2 ** max``. (The name shadows the builtin but
            is kept for backward compatibility with keyword callers.)

    Returns:
        A list of ints, each 0 or 1, e.g. ``decimal_to_binary(5) == [1, 0, 1]``.

    Raises:
        AssertionError: if `dec` is negative or needs more than `max` bits.
    """
    # to simplify the code, we only allow non-negative numbers
    assert dec >= 0, "dec value MUST be non-negative"
    # bin() yields e.g. '0b101'; drop the '0b' prefix
    dec_as_bin = bin(dec)[2:]
    # `max` bounds the number of bits, not the value itself
    assert len(dec_as_bin) <= max, (
        f'dec value MUST fit in `max` bits. Expected: dec < {2 ** max}, is {dec}'
    )
    # Each binary digit becomes one list entry; without this return the caller
    # would receive None.
    return [int(bit) for bit in dec_as_bin]
def validate_binaries(binaries, expected_value):
    """Constrain `binaries` to be a valid MSB-first bit decomposition of `expected_value`.

    Every entry must be 0 or 1. This is checked with the multiplication-only
    constraint ``b * (b - 1) == 0``. The entries, weighted by descending powers
    of two, must then add up to `expected_value`.

    Args:
        binaries: a sequence of bits (ints or digit strings), most significant first.
        expected_value: the integer the bits are supposed to encode.

    Raises:
        AssertionError: if an entry is not a bit, or the recomposed value differs.
    """
    width = len(binaries)
    total = 0
    for position, raw_bit in enumerate(binaries):
        bit = int(raw_bit)
        # b * (b - 1) vanishes only for b in {0, 1}
        assert bit * (bit - 1) == 0, 'invalid binary'
        if bit == 1:
            # the first entry carries weight 2**(width - 1), the last carries 2**0
            total += 2 ** (width - 1 - position)
    assert expected_value == total, f'values mismatch, expected {expected_value}, got {total}'
def define_midpoint(num_zeros=MAX + 1):
    """Return the integer whose binary form is a single 1 followed by `num_zeros` zeros.

    Its most significant bit is 1 and every lower bit is 0. The result is used
    as the midpoint argument of `compute_diff_relative_to_midpoint`.
    """
    # pad '1' on the right with '0' characters up to num_zeros + 1 digits
    digits = '1'.ljust(num_zeros + 1, '0')
    return int(digits, 2)
def compute_diff_relative_to_midpoint(midpoint, u, v):
    """Return 1 if ``u >= v`` and 0 otherwise, using one addition and a bit-length check.

    Computes ``midpoint + (u - v)`` and reports whether the result still has as
    many bits as `midpoint`. This is the same as reading the bit at the
    midpoint's top position.

    Two conditions make the trick correct:
      * `midpoint` is a power of two (a single 1 followed by zeros);
      * its binary representation uses at least one more bit than `u` and `v`.

    Under these conditions ``midpoint + (u - v)`` stays in ``[1, 2 * midpoint)``.
    Its top bit is therefore set exactly when ``u - v >= 0``. Without the
    power-of-two condition the answer can be silently wrong. For example,
    midpoint=12, u=0, v=1 gives 11, which still has 4 bits.

    Args:
        midpoint: a power of two, wider in bits than both `u` and `v`.
        u: non-negative integer (assumed; negative values are not handled).
        v: non-negative integer (assumed; negative values are not handled).

    Returns:
        1 if ``u >= v``, else 0.

    Raises:
        AssertionError: if `midpoint` is not a power of two, or is not wider than `u` or `v`.
    """
    # A power of two has exactly one set bit, so clearing the lowest set bit leaves 0.
    assert midpoint > 0 and (midpoint & (midpoint - 1)) == 0, 'midpoint must be a power of two'
    midpoint_as_bin = bin(midpoint)[2:]
    u_as_bin = bin(u)[2:]
    v_as_bin = bin(v)[2:]
    # The extra bit of headroom prevents a range error when u - v is added.
    assert len(midpoint_as_bin) > len(u_as_bin), "midpoint's binary representation must use at least 1 bit more than u"
    assert len(midpoint_as_bin) > len(v_as_bin), "midpoint's binary representation must use at least 1 bit more than v"
    delta = u - v
    mid_plus_delta = bin(midpoint + delta)[2:]
    # Losing a bit relative to the midpoint means the top bit was borrowed, i.e. u - v < 0.
    if len(mid_plus_delta) < len(midpoint_as_bin):
        return 0
    return 1
def is_list_sorted(list):
    """Check that `list` is in ascending (non-decreasing) order using the midpoint trick.

    Each value is first decomposed into bits and constrained to be a valid
    11-bit number, so values must lie in 0..2047. Each adjacent pair is then
    compared with `compute_diff_relative_to_midpoint(midpoint, later, earlier)`,
    which yields 1 exactly when ``later >= earlier``.

    A verdict line is printed before returning.

    Args:
        list: the integers to check. The name shadows the builtin inside this
            function but is kept for backward compatibility with keyword callers.

    Returns:
        True if the list is non-decreasing (including empty and single-element
        lists), False otherwise.

    Raises:
        AssertionError: if a value is negative or does not fit in 11 bits.
    """
    max_bits = 11  # hardcoded to fit at max 2047
    # The midpoint must be wider than every value; 2**12 leaves extra headroom.
    midpoint = define_midpoint(max_bits + 1)
    prev_value = None
    for value in list:
        # 1. convert decimal to binary
        value_as_bin = decimal_to_binary(value, max_bits)
        # 2. make sure the binaries are valid
        validate_binaries(value_as_bin, value)
        # 3. compare with the previous element (the first element has none)
        if prev_value is not None:
            # midpoint + (value - prev_value) keeps its top bit iff value >= prev_value.
            # The argument order matters: (prev_value, value) would test descending order.
            if compute_diff_relative_to_midpoint(midpoint, value, prev_value) == 0:
                print(f'list: {list} is not sorted\n')
                return False
        prev_value = value
    print(f'list: {list} is sorted\n')
    return True
# Demonstration: the ascending sample must be accepted and the swapped one rejected.
assert is_list_sorted(sorted_list), 'list should be sorted'
assert not is_list_sorted(unsorted_list), 'list should not be sorted'
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment