Skip to content

Instantly share code, notes, and snippets.

@mkow
Created April 4, 2020 16:26
Show Gist options
  • Save mkow/8ded3e5aa80dcd0647036e5dc7000591 to your computer and use it in GitHub Desktop.
Save mkow/8ded3e5aa80dcd0647036e5dc7000591 to your computer and use it in GitHub Desktop.
Solver for Verifier 2 from Midnight Sun CTF 2020 Quals
#!/usr/bin/env python3
from pwn import *
from hashlib import sha1, sha256
from random import randint
def h(msg: str) -> int:
    """Hash an ASCII message to the integer used in the signature arithmetic.

    The message is encoded as ASCII, hashed with SHA-1, and the digest is
    interpreted as a big-endian integer. A digest wider than 192 bits would
    be reduced to its leading 192 bits; SHA-1's 160-bit digest is used whole.

    Args:
        msg: Message text; must be ASCII-encodable.

    Returns:
        The (possibly truncated) digest as a non-negative int.

    Raises:
        UnicodeEncodeError: If ``msg`` contains non-ASCII characters.
    """
    digest = sha1(msg.encode('ascii')).digest()
    # Number of low-order bits to drop so that at most 192 bits remain.
    shift = max(len(digest) * 8 - 192, 0)
    return int.from_bytes(digest, byteorder='big') >> shift
def unpack_sig(sig: str) -> tuple:
    """Split a hex-encoded signature into its two integer halves.

    The first 48 hex characters are parsed as ``r`` and everything after
    them as ``s``.

    Args:
        sig: Hex string holding ``r`` followed by ``s``.

    Returns:
        A ``(r, s)`` tuple of ints.

    Raises:
        ValueError: If either part is empty or is not valid hexadecimal.
    """
    r_hex_len = 48  # 24 bytes, written as hex
    r_hex, s_hex = sig[:r_hex_len], sig[r_hex_len:]
    return int(r_hex, 16), int(s_hex, 16)
def pack_sig(r: int, s: int) -> str:
    """Encode a signature pair as one lowercase hex string.

    Each value is written as a 24-byte big-endian integer; the two encodings
    are concatenated (``r`` first) and hex-encoded, giving 96 characters.

    Args:
        r: First signature component.
        s: Second signature component.

    Returns:
        The 96-character hex string.

    Raises:
        OverflowError: If a value is negative or does not fit in 24 bytes.
    """
    component_size = 24  # bytes per component
    r_bytes = r.to_bytes(component_size, byteorder='big')
    s_bytes = s.to_bytes(component_size, byteorder='big')
    return (r_bytes + s_bytes).hex()
def sign(msg):
    """Request a signature for ``msg`` over the module-level ``sock``.

    Sends ``1`` followed by the message (presumably the service's "sign"
    menu option), then reads the hex signature that follows the
    ``Signature: `` prompt.

    Args:
        msg: ASCII message to sign; must not contain a newline, because the
            protocol is line-based.

    Returns:
        The ``(r, s)`` tuple parsed by ``unpack_sig``.

    Raises:
        ValueError: If ``msg`` contains a newline.
    """
    if '\n' in msg:
        raise ValueError('message must not contain a newline')
    # Send and match bytes explicitly so pwntools does not have to coerce
    # text on Python 3.
    sock.send(('1\n%s\n' % msg).encode('ascii'))
    sock.readuntil(b'Signature: ')
    sig = sock.readuntil(b'\n', drop=True).decode('ascii')
    return unpack_sig(sig)
def ver(msg, r, s):
    """Ask the service whether ``(r, s)`` is a valid signature for ``msg``.

    Sends ``2``, the message, and the packed signature over the module-level
    ``sock``, prints the raw reply line, and checks it for the text
    ``Signature valid``.

    Args:
        msg: ASCII message; must not contain a newline.
        r: First signature component.
        s: Second signature component.

    Returns:
        True if the reply line contains ``Signature valid``, else False.

    Raises:
        ValueError: If ``msg`` contains a newline.
    """
    if '\n' in msg:
        raise ValueError('message must not contain a newline')
    request = '2\n%s\n%s\n' % (msg, pack_sig(r, s))
    sock.send(request.encode('ascii'))
    sock.readuntil(b'signature> ')
    res = sock.readuntil(b'\n')
    print(res)
    # Compare as bytes so an unexpected non-ASCII reply cannot raise here.
    return b'Signature valid' in res
def win(r, s):
    """Submit the signature ``(r, s)`` via option ``3`` and return the reply.

    Sends ``3`` followed by the packed signature over the module-level
    ``sock`` and reads the line that follows the ``signature> `` prompt.

    Args:
        r: First signature component.
        s: Second signature component.

    Returns:
        The reply line as bytes, without its trailing newline.
    """
    sock.send(('3\n%s\n' % pack_sig(r, s)).encode('ascii'))
    sock.readuntil(b'signature> ')
    return sock.readuntil(b'\n', drop=True)
def egcd(a, b):
    """Return Bezout coefficients ``(x, y)`` with ``x*a + y*b == gcd(a, b)``.

    Iterative extended Euclidean algorithm. Inputs are expected to be
    non-negative integers; the loop only runs while the current remainder
    is positive.

    Args:
        a: First non-negative integer.
        b: Second non-negative integer.

    Returns:
        A tuple ``(x, y)`` of ints satisfying the identity above.
    """
    # Invariants: r0 == x0*a + y0*b and r1 == x1*a + y1*b.
    r0, r1 = a, b
    x0, y0 = 1, 0
    x1, y1 = 0, 1
    while r1 > 0:
        q = r0 // r1
        # Track the remainders directly rather than recomputing them from
        # the coefficients on every iteration.
        r0, r1 = r1, r0 - q * r1
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    return x0, y0
def gcd(a, b):
    """Return the non-negative greatest common divisor of two integers.

    Delegates to the standard library implementation.

    Args:
        a: First integer.
        b: Second integer.

    Returns:
        ``gcd(a, b)`` as a non-negative int; ``gcd(0, 0)`` is 0.
    """
    # Imported locally so this helper does not depend on the file's import
    # block.
    import math

    return math.gcd(a, b)
def inv_mod(x, mod):
    """Return the multiplicative inverse of ``x`` modulo ``mod``.

    Uses the first Bezout coefficient from ``egcd`` and verifies it, so a
    non-invertible ``x`` is reported instead of yielding a wrong value.

    Args:
        x: Integer to invert.
        mod: Modulus.

    Returns:
        An int ``y`` in ``[0, mod)`` such that ``(x * y) % mod == 1 % mod``.

    Raises:
        ValueError: If ``x`` has no inverse modulo ``mod``.
    """
    candidate = egcd(x % mod, mod)[0] % mod
    # The coefficient is an inverse only when gcd(x, mod) == 1; check it.
    if (candidate * x) % mod != 1 % mod:
        raise ValueError('%d is not invertible modulo %d' % (x, mod))
    return candidate
# Driver script: recover the group order n and a reused nonce from the
# service's signatures, then forge a signature for the flag request.
# NOTE(review): the algebra below matches an ECDSA-style scheme
# s = k^-1 * (h + r*d) mod n with a repeated nonce k - confirm against the
# challenge source.
sock = remote('verifier2-01.play.midnightsunctf.se', 31337)
flag_rq = 'please_give_me_the_flag'
n = None

# Pick four random messages whose consecutive hash differences A, B, C are
# non-negative with gcd(A, B) == gcd(B, C) == 1, then find integers xa, xb
# such that xa*A + xb*B == C.
while True:
    m0, m1, m2, m3 = ['%d' % randint(1, 10**9) for _ in range(4)]
    A = h(m0) - h(m1)
    B = h(m1) - h(m2)
    C = h(m2) - h(m3)
    if A < 0 or B < 0 or C < 0 or gcd(A, B) != 1 or gcd(B, C) != 1:
        continue
    xa, xb = egcd(A, B)
    # egcd yields xa*A + xb*B == 1; scale both coefficients by C.
    xa *= C
    xb *= C
    assert xa*A + xb*B == C
    break

# Each round produces a multiple of n; the gcd over several rounds is
# expected to shrink to n itself.
for _ in range(5):
    while True:
        (r0, s0), (r1, s1), (r2, s2), (r3, s3) = [
            sign(m) for m in (m0, m1, m2, m3)
        ]
        # Only proceed when all four signatures share the same r.
        if r0 == r1 == r2 == r3:
            break
        print('Retrying...')
    # xa * (s0 - s1) + xb * (s1 - s2) == (s2 - s3) (mod n)
    diff = abs(xa * (s0 - s1) + xb * (s1 - s2) - (s2 - s3))  # diff % n == 0
    if n is None:
        n = diff
    else:
        n = gcd(n, diff)
    print('n (%d): %d' % (n.bit_length(), n))

assert n == 6277101735386680763835789423176059013767194773182842284081  # prime, looks good
print('n found, calculating priv key...')

# (s0 - s1) == k^-1 * A (mod n), so dividing by A gives k^-1.
rev_k = ((s0 - s1) * inv_mod(A, n)) % n
assert rev_k == ((s1 - s2) * inv_mod(B, n)) % n
k = inv_mod(rev_k, n)
print('k = %d' % k)

# Swap h(m0) for h(flag_rq) inside s0 while keeping the same r and k.
s_flag = (((s0 * k % n) - h(m0) + h(flag_rq)) * inv_mod(k, n)) % n
print(win(r0, s_flag).decode())
sock.close()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment