Last active
August 5, 2022 18:04
-
-
Save itzmeanjan/d4853347dfdfa853993f5ea059824de6 to your computer and use it in GitHub Desktop.
Montgomery Modular Arithmetic for the 256-bit `secp256k1` Prime Field
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/python3 | |
from math import ceil | |
from typing import List, Tuple | |
from random import randint | |
def bit_count(num: int) -> int:
    '''
    Number of bits needed to represent `num` in binary, without sign or
    leading zeros.

    For every num > 0 this is the same as len(bin(num)[2:]); for num = 0 the
    result is 0. Delegates to the built-in int.bit_length(), which for a
    negative value reports the bit length of its absolute value.
    '''
    return num.bit_length()
def calculate_mu() -> int:
    '''
    Computes the Montgomery constant derived from PRIME and RADIX.

    An inverse of PRIME modulo 2^RADIX_BIT_LEN is built one bit at a time:
    at every step the candidate is checked modulo the next power of two and,
    if the product with PRIME is not 1, the next higher bit is switched on.
    The value returned is RADIX minus that inverse.

    See algorithm 3 of https://eprint.iacr.org/2017/1057.pdf
    '''
    inverse = 1
    for width in range(2, RADIX_BIT_LEN + 1):
        modulus = 1 << width
        if (PRIME * inverse) % modulus != 1:
            # set bit (width - 1) of the candidate inverse
            inverse += modulus >> 1

    return RADIX - inverse
# Field prime p of secp256k1; see https://en.bitcoin.it/wiki/Secp256k1
PRIME: int = 0x_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_FFFFFC2F
PRIME_BIT_LEN: int = bit_count(PRIME)

# One limb holds 32 bits, so the routines can be rewritten using the
# uint32_t data type of C
RADIX_BIT_LEN: int = 32
RADIX: int = 2 ** RADIX_BIT_LEN

# Limbs needed to hold PRIME_BIT_LEN bits ( = 8 )
LIMB_COUNT: int = ceil(PRIME_BIT_LEN / RADIX_BIT_LEN)

# R = (2 ^ 32) ^ 8 mod p = 2 ^ 256 mod p
R: int = pow(RADIX, LIMB_COUNT, PRIME)
# R2 = R ^ 2 mod p
R2: int = pow(R, 2, PRIME)

MU: int = calculate_mu()

# Number of random rounds each test case runs
TEST_CNT: int = 2 ** 10
def to_radix_r(num: int) -> List[int]:
    '''
    Splits a large integer ( 256 -bit ) into LIMB_COUNT radix-r limbs,
    least significant limb first | r = 2^32

    Limbs not reached by the value stay 0. A value needing more than
    LIMB_COUNT limbs runs past the end of the list and raises IndexError.
    '''
    limbs = [0] * LIMB_COUNT

    pos = 0
    remaining = num
    while remaining > 0:
        # peel off the lowest limb, keep the rest for the next round
        remaining, limbs[pos] = divmod(remaining, RADIX)
        pos += 1

    return limbs
def from_radix_r(limbs: List[int]) -> int:
    '''
    Recombines radix-r limbs ( least significant limb first ) into one large
    integer ( 256 -bit ) | r = 2^32

    Each limb is weighted by RADIX raised to its position and the weighted
    limbs are summed up; an empty list gives 0.
    '''
    return sum(limb * RADIX ** pos for pos, limb in enumerate(limbs))
def adc(a: int, b: int, carry: int) -> Tuple[int, int]:
    '''
    Add with carry: computes a + b + carry and returns the pair
    ( low 32 bits of the sum, everything above the low 32 bits ).

    See https://github.com/dusk-network/bls12_381/blob/ed4d87c6756c0020629edb5d8912a41e338ac85a/src/util.rs#L1-L6
    '''
    total = a + b + carry
    high, low = divmod(total, 1 << 32)
    return low, high
def mac(a: int, b: int, c: int, carry: int) -> Tuple[int, int]:
    '''
    Multiply and accumulate: computes a + b * c + carry and returns the pair
    ( low 32 bits of the result, everything above the low 32 bits ).

    See https://github.com/dusk-network/bls12_381/blob/ed4d87c6756c0020629edb5d8912a41e338ac85a/src/util.rs#L15-L20
    '''
    total = b * c + a + carry
    high, low = divmod(total, 1 << 32)
    return low, high
def bitwise_not(a: int) -> int:
    '''
    Returns RADIX - 1 - a, i.e. for 0 <= a < RADIX every one of the
    RADIX_BIT_LEN bits of `a` flipped ( what `~a` gives for a uint32_t in C )
    '''
    # ~a is -a - 1 for Python integers, so this is RADIX - 1 - a
    return RADIX + ~a
def u256xu32(a: List[int], b: int, c: List[int]) -> List[int]:
    '''
    Adds the product of the single limb `b` and the 8-limb number `c` onto
    the low 8 limbs of the 9-limb accumulator `a`, chaining the carry through
    `mac`. The carry left over after the last limb is stored in a[8],
    replacing whatever was there.

    `a` is updated in place and also returned.

    Inspired by https://github.com/dusk-network/bls12_381/blob/ed4d87c6756c0020629edb5d8912a41e338ac85a/src/fp.rs#L517-L522
    '''
    assert len(a) == 9 and len(c) == 8

    carry = 0
    for idx in range(8):
        a[idx], carry = mac(a[idx], b, c[idx], carry)
    a[8] = carry

    return a
def mont_mult(a: List[int], b: List[int]) -> List[int]:
    '''
    Montgomery multiplication of two field elements given as radix-2^32
    limbs ( least significant limb first ); returns limbs of a value
    congruent to a * b * 2^-256 modulo PRIME.

    Multiplication and reduction are interleaved, one limb of `a` per round:
      1. a[i] * b is added onto the running result at limb offset i
      2. q is chosen so that adding q * PRIME clears limb i, which is then
         dropped
      3. the carry leaving the top limb is kept in `pc` for the next round

    After the last round any carry out of the upper half ( weight 2^256 ) is
    folded back using 2^256 = 2^32 + 977 ( mod PRIME ), with the carry
    propagated through all limbs so that every returned limb is below 2^32.

    Both operands must have the same length; `u256xu32` additionally asserts
    that they are 8 limbs long.

    NOTE(review): no final comparison against PRIME is made, so the returned
    value is not guaranteed to be the smallest non-negative representative.

    Inspired by https://github.com/dusk-network/bls12_381/blob/ed4d87c6756c0020629edb5d8912a41e338ac85a/src/fp.rs#L437-L560
    and algorithm 2 of https://eprint.iacr.org/2017/1057.pdf
    '''
    assert len(a) == len(b)

    prime = to_radix_r(PRIME)
    cnt = len(a)
    c = [0] * (cnt << 1)
    pc = 0  # carry out of the top limb of the previous round

    for i in range(cnt):
        # multiplication step: c[i .. i + cnt] += a[i] * b
        c[i:i + cnt + 1] = u256xu32(c[i:i + cnt + 1], a[i], b)

        # reduction step: c[i] + q * prime[0] has its low 32 bits cleared,
        # so only the carry of the lowest limb is kept
        q = (MU * c[i]) % RADIX
        _, carry = mac(c[i], q, prime[0], 0)
        for j in range(1, cnt):
            c[i + j], carry = mac(c[i + j], q, prime[j], carry)
        c[i + cnt], pc = adc(c[i + cnt], pc, carry)

    res = c[cnt:]

    # fold the overflow back: pc * 2^256 = pc * ( 2^32 + 977 ) ( mod PRIME );
    # repeated in case the fold itself carries out of the top limb
    while pc:
        res[0], carry = adc(res[0], pc * 977, 0)
        res[1], carry = adc(res[1], pc, carry)
        for k in range(2, cnt):
            res[k], carry = adc(res[k], 0, carry)
        pc = carry

    return res
def mont_add(a: List[int], b: List[int]) -> List[int]:
    '''
    Adds two field elements given as radix-2^32 limbs ( least significant
    limb first ); returns limbs of a value congruent to a + b modulo PRIME.

    The first LIMB_COUNT limbs are added with a carry chain. A carry out of
    the top limb ( weight 2^256 ) is folded back using
    2^256 = 2^32 + 977 ( mod PRIME ), with the carry propagated through all
    limbs so that every one of those limbs in the result is below 2^32.

    NOTE(review): no final comparison against PRIME is made, so the returned
    value is not guaranteed to be the smallest non-negative representative.

    Collects some inspiration from https://github.com/dusk-network/bls12_381/blob/2c679a284c008475b543a67ee2300ee58ffe5d11/src/fp.rs#L394-L405
    '''
    assert len(a) == len(b)

    c = [0] * len(a)

    carry = 0
    for i in range(LIMB_COUNT):
        c[i], carry = adc(a[i], b[i], carry)

    # fold the overflow back: 2^256 = 2^32 + 977 ( mod PRIME ); repeated in
    # case the fold itself carries out of the top limb
    while carry:
        overflow = carry
        c[0], carry = adc(c[0], overflow * 977, 0)
        c[1], carry = adc(c[1], overflow, carry)
        for k in range(2, LIMB_COUNT):
            c[k], carry = adc(c[k], 0, carry)

    return c
def mont_inv(a: List[int]) -> List[int]:
    '''
    Raises `a` ( limbs of a value in Montgomery representation ) to the power
    PRIME - 2 using square-and-multiply, scanning the exponent from its most
    significant bit down to its least significant one.

    The accumulator starts from the limbs of R and every squaring and
    multiplication goes through `mont_mult`.

    Collects inspiration from https://github.com/dusk-network/bls12_381/blob/2c679a284c008475b543a67ee2300ee58ffe5d11/src/fp.rs#L355-L370
    '''
    exponent = to_radix_r(PRIME-2)
    acc = to_radix_r(R)

    # most significant limb first, and inside a limb most significant bit first
    for limb in exponent[::-1]:
        for shift in range(RADIX_BIT_LEN - 1, -1, -1):
            acc = mont_mult(acc, acc)
            if (limb >> shift) & 1:
                acc = mont_mult(acc, a)

    return acc
def to_mont(a: List[int]) -> List[int]:
    '''
    Converts limbs of a field element into Montgomery representation by
    Montgomery-multiplying them with the limbs of R2.

    Just like https://github.com/dusk-network/bls12_381/blob/ed4d87c6756c0020629edb5d8912a41e338ac85a/src/fp.rs#L251-L253;
    for better understanding read section 2.2 of https://eprint.iacr.org/2017/1057.pdf
    '''
    r_squared = to_radix_r(R2)
    return mont_mult(a, r_squared)
def from_mont(a: List[int]) -> List[int]:
    '''
    Converts limbs in Montgomery representation back to the plain
    representation by Montgomery-multiplying them with the limbs of 1.

    Read section 2.2 of https://eprint.iacr.org/2017/1057.pdf
    '''
    one = to_radix_r(1)
    return mont_mult(a, one)
# --- Testing --- | |
def test_to_and_from_mont_repr():
    '''
    Test whether conversion from radix-r limbs to Montgomery representation
    and back returns the starting value, using the boundary values 0, 1 and
    PRIME - 1 followed by TEST_CNT random secp256k1 field elements
    '''
    # fixed boundary values first, so they are covered on every run
    samples = [0, 1, PRIME-1]
    samples.extend(randint(0, PRIME-1) for _ in range(TEST_CNT))

    for a in samples:
        b = from_radix_r(from_mont(to_mont(to_radix_r(a))))
        assert a == b, f'expected {a}, found {b}'
def test_mont_mult():
    '''
    For TEST_CNT pairs of random secp256k1 prime field elements, check that
    multiplying them in Montgomery representation and converting the product
    back matches ( a * b ) % PRIME computed on plain integers
    '''
    for _ in range(TEST_CNT):
        a = randint(0, PRIME-1)
        b = randint(0, PRIME-1)
        c = (a * b) % PRIME

        a_mont = to_mont(to_radix_r(a))
        b_mont = to_mont(to_radix_r(b))
        product = mont_mult(a_mont, b_mont)
        d = from_radix_r(from_mont(product))

        assert c == d, f'expected {c}, found {d}'
def test_mont_add():
    '''
    For TEST_CNT pairs of random secp256k1 prime field elements, check that
    adding them in Montgomery representation and converting the sum back
    matches ( a + b ) % PRIME computed on plain integers
    '''
    for _ in range(TEST_CNT):
        a = randint(0, PRIME-1)
        b = randint(0, PRIME-1)
        c = (a + b) % PRIME

        a_mont = to_mont(to_radix_r(a))
        b_mont = to_mont(to_radix_r(b))
        total = mont_add(a_mont, b_mont)
        d = from_radix_r(from_mont(total))

        assert c == d, f'expected {c}, found {d}'
def test_mont_inv():
    '''
    Test modular multiplicative inverse of randomly generated non-zero
    secp256k1 prime field elements, in Montgomery representation.

    Two properties are checked for every element a:
      - a multiplied by its inverse is 1
      - the inverse of the inverse is a again
    '''
    for _ in range(TEST_CNT):
        a = randint(1, PRIME-1)
        a_mont = to_mont(to_radix_r(a))
        b = mont_inv(a_mont)

        # inverting twice alone would also pass for a function that returns
        # its input unchanged, so check the defining property too
        one = from_radix_r(from_mont(mont_mult(a_mont, b)))
        assert one == 1, f'expected 1, found {one}'

        c = from_radix_r(from_mont(mont_inv(b)))
        assert a == c, f'expected {a}, found {c}'
# Running the file directly only prints a hint; the test cases are meant to
# be collected and executed by pytest
if __name__ == "__main__":
    print("Use `pytest` to run test cases !")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This is a reference implementation of Montgomery Modular Arithmetic for prime field
F_p | p = 0x_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFF_FFFFFFFE_FFFFFC2F
For running the tests
pytest