Created
April 4, 2020 16:26
-
-
Save mkow/8ded3e5aa80dcd0647036e5dc7000591 to your computer and use it in GitHub Desktop.
Solver for Verifier 2 from Midnight Sun CTF 2020 Quals
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
from pwn import * | |
from hashlib import sha1, sha256 | |
from random import randint | |
def h(msg):
    """Hash an ASCII message into the integer digest used by this solver.

    The message is encoded as ASCII, hashed with SHA-1, and the digest is
    interpreted as a big-endian integer.

    Args:
        msg: Text to hash; must be pure ASCII.

    Returns:
        The digest as a non-negative int, truncated to its leftmost 192 bits
        when the digest is longer than that.

    Raises:
        UnicodeEncodeError: If ``msg`` contains non-ASCII characters.
    """
    digest = sha1(msg.encode('ascii')).digest()
    # Keep at most the leftmost 192 bits. For SHA-1 (160 bits) the shift is
    # zero; it only matters if a longer hash such as sha256 is swapped in.
    shift = max(len(digest) * 8 - 192, 0)
    return int.from_bytes(digest, byteorder='big') >> shift
def unpack_sig(sig):
    """Split a hex-encoded signature into its ``(r, s)`` integer halves.

    The signature is expected to be 96 hex characters: 48 for ``r`` followed
    by 48 for ``s`` (the inverse of ``pack_sig``).

    Args:
        sig: Hex string; surrounding whitespace is ignored.

    Returns:
        A tuple ``(r, s)`` of ints.

    Raises:
        ValueError: If the signature is not exactly 96 characters long or
            contains non-hex characters.
    """
    sig = sig.strip()
    # A short or long string would otherwise be split at the wrong place and
    # silently yield a bogus (r, s) pair.
    if len(sig) != 96:
        raise ValueError(
            'expected a 96-character hex signature, got %d characters' % len(sig))
    return int(sig[:48], 16), int(sig[48:], 16)
def pack_sig(r, s):
    """Encode ``r`` and ``s`` as one 96-character hex string.

    Each value is serialised as a 24-byte big-endian integer; the two byte
    strings are concatenated (``r`` first) and hex-encoded.

    Raises:
        OverflowError: If either value is negative or does not fit in
            24 bytes.
    """
    halves = [value.to_bytes(24, byteorder='big') for value in (r, s)]
    return b''.join(halves).hex()
def sign(msg):
    """Request a signature for ``msg`` from the server and return ``(r, s)``.

    Sends ``1`` followed by the message over the module-level ``sock``
    connection (presumably the "sign" menu entry -- confirm against the
    service), waits for the ``Signature: `` prompt and parses the hex
    signature that follows it.

    Args:
        msg: ASCII message to sign; must be a single line.

    Returns:
        The signature as a tuple ``(r, s)`` of ints.

    Raises:
        ValueError: If ``msg`` contains a newline, which would desynchronise
            the line-based exchange with the server.
    """
    if '\n' in msg:
        raise ValueError('message must not contain a newline')
    # Tubes carry bytes; encode explicitly instead of relying on implicit
    # str-to-bytes conversion.
    sock.send(('1\n%s\n' % msg).encode('ascii'))
    sock.readuntil(b'Signature: ')
    sig = sock.readuntil(b'\n', drop=True).decode('ascii')
    return unpack_sig(sig)
def ver(msg, r, s):
    """Ask the server whether ``(r, s)`` is a valid signature for ``msg``.

    Sends ``2``, the message and the packed signature over the module-level
    ``sock`` connection, then reads the line following the ``signature> ``
    prompt. The raw response line is printed for debugging.

    Args:
        msg: ASCII message; must be a single line.
        r: First signature component (int).
        s: Second signature component (int).

    Returns:
        True if the response line contains ``Signature valid``, else False.

    Raises:
        ValueError: If ``msg`` contains a newline, which would desynchronise
            the line-based exchange with the server.
    """
    if '\n' in msg:
        raise ValueError('message must not contain a newline')
    # Tubes carry bytes; encode explicitly instead of relying on implicit
    # str-to-bytes conversion.
    sock.send(('2\n%s\n%s\n' % (msg, pack_sig(r, s))).encode('ascii'))
    sock.readuntil(b'signature> ')
    res = sock.readuntil(b'\n')
    print(res)
    return 'Signature valid' in res.decode('ascii')
def win(r, s):
    """Submit the forged signature ``(r, s)`` and return the server's reply.

    Sends ``3`` followed by the packed signature over the module-level
    ``sock`` connection and reads the line following the ``signature> ``
    prompt.

    Args:
        r: First signature component (int).
        s: Second signature component (int).

    Returns:
        The response line as bytes, without the trailing newline.
    """
    # Tubes carry bytes; encode explicitly instead of relying on implicit
    # str-to-bytes conversion.
    sock.send(('3\n%s\n' % pack_sig(r, s)).encode('ascii'))
    sock.readuntil(b'signature> ')
    return sock.readuntil(b'\n', drop=True)
def egcd(a, b):
    """Return Bezout coefficients ``(x, y)`` with ``x*a + y*b == gcd(a, b)``.

    Iterative extended Euclidean algorithm. The gcd on the right-hand side
    is always non-negative, so negative inputs are handled as well.

    Args:
        a: First integer.
        b: Second integer.

    Returns:
        A tuple ``(x, y)`` of ints satisfying ``x*a + y*b == gcd(a, b)``.
        For ``a == b == 0`` the result is ``(1, 0)``.
    """
    # Invariant: r0 == x0*a + y0*b and r1 == x1*a + y1*b. Carrying the
    # remainders explicitly avoids recomputing those products each round.
    r0, r1 = a, b
    x0, x1 = 1, 0
    y0, y1 = 0, 1
    while r1 != 0:
        q = r0 // r1
        r0, r1 = r1, r0 - q * r1
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    # With negative inputs the last non-zero remainder can be negative;
    # flip the signs so the coefficients describe the positive gcd.
    if r0 < 0:
        return -x0, -y0
    return x0, y0
def gcd(a, b):
    """Return the non-negative greatest common divisor of two integers.

    Thin wrapper around the standard library ``math.gcd``;
    ``gcd(0, 0)`` is ``0``.

    Args:
        a: First integer.
        b: Second integer.

    Returns:
        The greatest common divisor as a non-negative int.

    Raises:
        TypeError: If either argument is not an integer.
    """
    # Imported locally under an alias so this module-level name does not
    # clash with the library function.
    from math import gcd as _math_gcd
    return _math_gcd(a, b)
def inv_mod(x, mod):
    """Return the multiplicative inverse of ``x`` modulo ``mod``.

    Args:
        x: Integer to invert.
        mod: Modulus.

    Returns:
        The int ``y`` in ``[0, mod)`` with ``(x * y) % mod == 1`` (or ``0``
        when ``mod`` is 1).

    Raises:
        ValueError: If ``x`` has no inverse modulo ``mod``, i.e. the two are
            not coprime.
        ZeroDivisionError: If ``mod`` is 0.
    """
    candidate = egcd(x % mod, mod)[0] % mod
    # egcd still returns a coefficient when gcd(x, mod) != 1, so verify the
    # result instead of silently handing back a non-inverse.
    if (candidate * x) % mod != 1 % mod:
        raise ValueError('%d has no inverse modulo %d' % (x, mod))
    return candidate
# Solver driver: recover the signing group order n and the nonce k from
# signatures that share the same r, then forge a signature over the flag
# request and submit it.
sock = remote('verifier2-01.play.midnightsunctf.se', 31337)
try:
    flag_rq = 'please_give_me_the_flag'
    n = None

    # Pick four random messages whose consecutive hash differences A, B, C
    # are non-negative and pairwise coprime as checked below, then find
    # integers xa, xb with xa*A + xb*B == C.
    while True:
        m0, m1, m2, m3 = ['%d' % randint(1, 10**9) for _ in range(4)]
        A = h(m0) - h(m1)
        B = h(m1) - h(m2)
        C = h(m2) - h(m3)
        if A < 0 or B < 0 or C < 0 or gcd(A, B) != 1 or gcd(B, C) != 1:
            continue
        # gcd(A, B) == 1, so egcd yields xa*A + xb*B == 1; scale by C.
        xa, xb = egcd(A, B)
        xa *= C
        xb *= C
        assert xa*A + xb*B == C
        break

    # Each round yields an integer that is a multiple of n; the gcd over
    # several rounds narrows it down to n itself.
    for _ in range(5):
        # Re-request until all four signatures share r (presumably meaning
        # the server used the same nonce for all of them).
        while True:
            sigs = [sign(m0), sign(m1), sign(m2), sign(m3)]
            r0, r1, r2, r3 = [r for r, s in sigs]
            s0, s1, s2, s3 = [s for r, s in sigs]
            if r0 == r1 == r2 == r3:
                break
            print('Retrying...')
        # xa * (s0 - s1) + xb * (s1 - s2) == (s2 - s3) (mod n)
        diff = abs(xa * (s0 - s1) + xb * (s1 - s2) - (s2 - s3))  # diff % n == 0
        if n is None:
            n = diff
        else:
            n = gcd(n, diff)
        print('n (%d): %d' % (n.bit_length(), n))

    assert n == 6277101735386680763835789423176059013767194773182842284081  # prime, looks good
    print('n found, calculating priv key...')

    # s0 - s1 == k^-1 * (h(m0) - h(m1)) == k^-1 * A (mod n), so dividing by
    # A gives k^-1; the same must hold for the (s1 - s2, B) pair.
    rev_k = ((s0 - s1) * inv_mod(A, n)) % n
    assert rev_k == ((s1 - s2) * inv_mod(B, n)) % n
    k = inv_mod(rev_k, n)
    print('k = %d' % k)

    # Swap h(m0) for h(flag_rq) inside s0 while keeping r0 and the nonce.
    # rev_k already is k^-1 mod n, so no second inversion is needed.
    s_flag = (((s0 * k % n) - h(m0) + h(flag_rq)) * rev_k) % n
    print(win(r0, s_flag).decode())
finally:
    # Release the connection on success and on any failure above.
    sock.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment