# Source: GitHub Gist maple3142/27af0ac0d6a5e1c0b69f6454e37e3999 by @maple3142
# (last active September 14, 2024).
import hashlib

# gmpy2 is optional: when it imports, its powmod replaces the builtin pow.
try:
    import gmpy2
except ImportError:
    HAS_GMPY2 = False
else:
    HAS_GMPY2 = True

# Encoding width in bits (points and scalars are b // 8 = 32 bytes).
b = 256
# Prime modulus of the coordinate field.
q = 2**255 - 19
# Scalar modulus used to reduce S in signature().
l = 2**252 + 27742317777372353535851937790883648493

# Modular exponentiation expmod(base, exp, mod); exp may be -1 for an inverse.
expmod = gmpy2.powmod if HAS_GMPY2 else pow
def H(m):
    """Return the 64-byte SHA-512 digest of the bytes-like object m."""
    hasher = hashlib.sha512(m)
    return hasher.digest()
def inv(x):
    """Return the multiplicative inverse of x modulo the field prime q.

    Uses expmod with exponent -1. With the builtin pow this raises ValueError
    when x is a multiple of q (no inverse exists); the gmpy2 powmod path is
    presumably similar — TODO confirm its exception type.
    """
    inverse_exponent = -1
    return expmod(x, inverse_exponent, q)
# Curve constant d = -121665 / 121666 in the field, stored as its canonical
# representative in [0, q). RFC 8032 lists the same value:
# 37095705934669439343138083508754565189542113879843219016388785533085940283555.
# Every use of d (edwards, xrecover, isoncurve) already reduces modulo q, so
# reducing here changes no result; it only keeps d a small, canonical value.
d = -121665 * inv(121666) % q
# I = 2^((q-1)/4) mod q. Because q % 8 == 5, 2 is a non-residue, so I*I == -1
# (mod q); xrecover multiplies by I to fix a wrong square-root candidate.
I = expmod(2, (q - 1) // 4, q)
def xrecover(y):
    """Return the even x in [0, q) with -x^2 + y^2 = 1 + d x^2 y^2 (mod q).

    Computes u = (y^2 - 1) / (d y^2 + 1) and a candidate square root
    u^((q+3)/8). If the candidate squared is not u, it is multiplied by I
    (a square root of -1). Finally the root is negated if needed so the
    returned value is even.

    This function does not detect the case where u has no square root at
    all; the returned x is then not a valid coordinate. decodepoint() catches
    that by calling isoncurve() on the result.
    """
    y_sq = y * y
    u = (y_sq - 1) * inv(d * y_sq + 1)
    candidate = expmod(u, (q + 3) // 8, q)
    if (candidate * candidate - u) % q:
        candidate = candidate * I % q
    return q - candidate if candidate % 2 else candidate
# Base point B: y = 4/5 in the field, x is the even root from xrecover.
# By itself is left unreduced; only the coordinates stored in B are reduced.
By = 4 * inv(5)
Bx = xrecover(By)
B = [coord % q for coord in (Bx, By)]
def edwards(P, Q):
    """Add two points given as [x, y] using the twisted Edwards law.

    x3 = (x1*y2 + x2*y1) / (1 + d*x1*x2*y1*y2)
    y3 = (y1*y2 + x1*x2) / (1 - d*x1*x2*y1*y2)

    Only P[0], P[1], Q[0] and Q[1] are read. Returns a new list with both
    coordinates reduced into [0, q).
    """
    x1, y1 = P[0], P[1]
    x2, y2 = Q[0], Q[1]
    # Shared term of both denominators.
    dxxyy = d * x1 * x2 * y1 * y2
    x3 = (x1 * y2 + x2 * y1) * inv(1 + dxxyy)
    y3 = (y1 * y2 + x1 * x2) * inv(1 - dxxyy)
    return [x3 % q, y3 % q]
def scalarmult(P, e):
    """Return [e]P, the point P added to itself e times, as [x, y].

    Uses left-to-right double-and-add: for each bit of e from the most
    significant down, double the accumulator, then add P if the bit is set.
    This performs the same sequence of edwards() calls as the recursive
    formulation, but in a loop, so large scalars cannot hit Python's recursion
    limit.

    Args:
        P: point as [x, y].
        e: non-negative integer scalar.

    Returns:
        [0, 1] (the neutral element) when e == 0, otherwise the result of the
        final edwards() call.

    Raises:
        ValueError: if e is negative. The recursive version recursed forever
            on negative e, because -1 // 2 == -1.
    """
    if e < 0:
        raise ValueError("scalar must be non-negative")
    if e == 0:
        return [0, 1]
    acc = [0, 1]
    # bin(e) is '0b...' for e > 0; walk the binary digits MSB first.
    for digit in bin(e)[2:]:
        acc = edwards(acc, acc)
        if digit == "1":
            acc = edwards(acc, P)
    return acc
def encodeint(y):
    """Encode integer y as b // 8 little-endian bytes.

    Only the low b bits of y are kept (for negative y, its two's-complement
    bits).
    """
    return bytes((y >> (8 * k)) & 0xFF for k in range(b // 8))
def encodepoint(P):
    """Encode point P = [x, y] as b // 8 little-endian bytes.

    The low b - 1 bits hold y. The top bit (bit b - 1) holds the low bit of x,
    which decodepoint() uses to choose between x and q - x.
    """
    x, y = P[0], P[1]
    y_mask = (1 << (b - 1)) - 1
    packed = (y & y_mask) | ((x & 1) << (b - 1))
    return bytes((packed >> (8 * k)) & 0xFF for k in range(b // 8))
def bit(h, i):
    """Return bit i (0 or 1) of byte string h, numbering bits little-endian.

    Bit i is bit i % 8 of byte h[i // 8]; raises IndexError past the end.
    """
    byte_index, shift = divmod(i, 8)
    return (h[byte_index] >> shift) & 1
def publickey(sk):
    """Derive the encoded public key for secret seed sk.

    The seed is hashed with H. Bits 3 .. b-3 of the digest are kept, bit b-2
    is set, and the resulting scalar multiplies the base point B. Returns
    encodepoint() of that point.
    """
    digest = H(sk)
    scalar = (1 << (b - 2)) | sum(bit(digest, i) << i for i in range(3, b - 2))
    return encodepoint(scalarmult(B, scalar))
def Hint(m):
    """Hash m with H and read all 2*b digest bits as a little-endian integer."""
    digest = H(m)
    # 2*b bits == b // 4 bytes, i.e. the whole 64-byte SHA-512 digest.
    return int.from_bytes(digest[: b // 4], "little")
def signature(m, sk, pk):
    """Sign message m with secret seed sk; return encoded R followed by encoded S.

    The secret scalar a is derived from H(sk) exactly as in publickey(). The
    nonce r hashes bytes b//8 .. b//4 of H(sk) together with m, so it is
    deterministic and independent of pk.

    NOTE: pk is used as given and is not checked against sk. Because r does
    not depend on pk, signing the same m under two different pk values gives
    two S values with the same r, and their difference reveals a. Callers
    must pass the true publickey(sk).
    """
    digest = H(sk)
    a = (1 << (b - 2)) | sum(bit(digest, i) << i for i in range(3, b - 2))
    r = Hint(digest[b // 8 : b // 4] + m)
    R_bytes = encodepoint(scalarmult(B, r))
    k = Hint(R_bytes + pk + m)
    return R_bytes + encodeint((r + k * a) % l)
def isoncurve(P):
    """Return True iff P = [x, y] satisfies -x^2 + y^2 = 1 + d x^2 y^2 (mod q)."""
    x, y = P[0], P[1]
    x_sq, y_sq = x * x, y * y
    return (y_sq - x_sq - 1 - d * x_sq * y_sq) % q == 0
def decodeint(s):
    """Decode the first b bits of byte string s as a little-endian integer.

    Raises IndexError if s is shorter than b // 8 bytes.
    """
    return sum(bit(s, i) << i for i in range(b))
def decodepoint(s):
    """Decode a b // 8-byte point encoding into [x, y], rejecting bad input.

    The low b - 1 bits give y. The top bit selects which square root (x or
    q - x) is used for x. Only canonical encodings are accepted, so
    encodepoint(decodepoint(s)) == s for every accepted s.

    Raises:
        Exception: if y >= q (non-canonical y), if x == 0 but the sign bit
            is set (would otherwise yield the unreduced x = q), or if the
            recovered point is not on the curve (e.g. no square root exists).
    """
    y = sum(2**i * bit(s, i) for i in range(0, b - 1))
    # RFC 8032 §5.1.3 step 1: y must already be reduced modulo q.
    if y >= q:
        raise Exception("decoding point with non-canonical y coordinate")
    x = xrecover(y)
    sign = bit(s, b - 1)
    # RFC 8032 §5.1.3 step 4: x == 0 has no negative, so sign 1 is invalid.
    if x == 0 and sign:
        raise Exception("decoding point with invalid sign bit for x = 0")
    if x & 1 != sign:
        x = q - x
    P = [x, y]
    if not isoncurve(P):
        raise Exception("decoding point that is not on curve")
    return P
def checkvalid(s, m, pk):
    """Verify signature s on message m under encoded public key pk.

    Args:
        s: b // 4 bytes, the encoded R followed by the encoded S.
        m: message bytes.
        pk: b // 8-byte encoded public key.

    Returns:
        None when the signature is valid.

    Raises:
        Exception: on a wrong signature or key length, if R or pk fails to
            decode, if S is not reduced modulo l, or if [S]B != R + [h]A.
    """
    if len(s) != b // 4:
        raise Exception("signature length is wrong")
    if len(pk) != b // 8:
        raise Exception("public-key length is wrong")
    R = decodepoint(s[0 : b // 8])
    A = decodepoint(pk)
    S = decodeint(s[b // 8 : b // 4])
    # RFC 8032 §5.1.7: reject S >= l. Since [l]B is the neutral element,
    # (R, S + l) would otherwise verify whenever (R, S) does (malleability).
    if S >= l:
        raise Exception("signature S value is out of range")
    h = Hint(encodepoint(R) + pk + m)
    if scalarmult(B, S) != edwards(R, scalarmult(A, h)):
        raise Exception("signature does not pass verification")
if __name__ == "__main__":
    # Self-test: cross-check this implementation against the third-party
    # `cryptography` package on one random key and ten messages.
    from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey

    sk = Ed25519PrivateKey.generate()
    pk = sk.public_key()
    sk_raw = sk.private_bytes_raw()
    # publickey() depends only on sk_raw, so derive and compare it once,
    # instead of repeating a full scalar multiplication per message.
    pk_raw = publickey(sk_raw)
    assert pk_raw == pk.public_bytes_raw()
    for i in range(10):
        msg = f"peko{i}".encode()
        sig = signature(msg, sk_raw, pk_raw)
        # Our signature must verify under both verifiers...
        checkvalid(sig, msg, pk_raw)
        pk.verify(sig, msg)
        # ...and match the library's deterministic signature byte-for-byte.
        sig2 = sk.sign(msg)
        assert sig == sig2
# (Scraped GitHub page footer:) Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.