Skip to content

Instantly share code, notes, and snippets.

@Yu212
Last active January 15, 2024 17:01
Show Gist options
  • Save Yu212/4000b235e118d965897035c55920701b to your computer and use it in GitHub Desktop.
Save Yu212/4000b235e118d965897035c55920701b to your computer and use it in GitHub Desktop.
Small Secret Exponent Attack
from math import sqrt, ceil, gcd, comb, log, exp
import subprocess
import time
def gen(size, delta_list):
    """Generate an RSA prime pair and one small private exponent per requested ratio.

    Args:
        size: target bit length of n = p*q; each prime is drawn from
            [2**(size//2 - 1), 2**(size//2)].
        delta_list: ratios delta; for each one, d starts at int(exp(delta * ln n))
            (roughly n**delta) and is decreased until it is coprime to (p-1)(q-1)/2.

    Returns:
        (p, q, d_list) with d_list[i] corresponding to delta_list[i].

    Note:
        math.exp raises OverflowError once delta * ln(n) exceeds about 709.
    """
    half = size // 2
    # `**` rather than `^`: `^` is only exponentiation under the Sage preparser (XOR in Python).
    while True:
        p = random_prime(2 ** half, lbound=2 ** (half - 1))
        q = random_prime(2 ** half, lbound=2 ** (half - 1))
        # gcd(p-1, q-1) == 2 makes (p-1)(q-1)/2 equal to lcm(p-1, q-1).
        if p != q and gcd(p - 1, q - 1) == 2:
            break
    half_phi = (p - 1) * (q - 1) // 2
    ln_n = log(p * q)
    d_list = []
    for delta in delta_list:
        d = int(exp(ln_n * delta))
        # Step down to the nearest value invertible modulo (p-1)(q-1)/2.
        while gcd(d, half_phi) != 1:
            d -= 1
        d_list.append(d)
    return p, q, d_list
def decrypt(n, e, c, y):
    """Factor n from y = -(p+q)/2 and decrypt c with the ordinary RSA private key.

    With s = p + q = -2*y, p and q are the roots of z^2 - s*z + n, so
    p = (s + sqrt(s^2 - 4n)) / 2.

    Args:
        n: RSA modulus.
        e: public exponent (must be invertible modulo (p-1)(q-1)).
        c: ciphertext.
        y: recovered value -(p+q)/2.

    Returns:
        The plaintext pow(c, d, n) with d = e^-1 mod (p-1)(q-1).

    Raises:
        ValueError: if y does not yield a factorization of n.
    """
    from math import isqrt  # works for Python ints and Sage Integers (via __index__)

    s = -2 * y
    disc = s * s - 4 * n
    if disc < 0:
        raise ValueError("y does not correspond to a factorization of n")
    p = (s + isqrt(disc)) // 2
    q = n // p
    if p <= 1 or p * q != n:
        raise ValueError("y does not correspond to a factorization of n")
    phi = (p - 1) * (q - 1)
    d = pow(e, -1, phi)
    return pow(c, d, n)
def ssea(algo, m, t, plain, key):
    """Run the small-secret-exponent lattice attack on one key and print the recovered plaintext.

    The modulus ``n`` and public exponent ``e`` are stored as module globals because
    the polynomial builders ``f``, ``g`` and ``h`` read them.

    Args:
        algo: reduction backend, one of "flatter", "self" (this file's LLL) or "sage".
        m: lattice parameter passed to ``polynomials``.
        t: hybrid parameter passed to ``hybrid_param``.
        plain: plaintext integer that is encrypted and then recovered.
        key: ``(p, q, d)`` with ``d`` the small private exponent.

    Raises:
        ValueError: if ``algo`` is not a supported backend.
    """
    global n, e
    p, q, d = key
    n = p * q
    phi = (p - 1) * (q - 1)
    e = int(pow(d, -1, phi // 2))
    A = (n + 1) // 2
    x = (e * d - 1) // (phi // 2)
    y = -(p + q) // 2
    c = pow(plain, e, n)
    print(f"{p = }")
    print(f"{q = }")
    print(f"{e = }")
    print(f"{d = }")
    print(f"{A = }")
    print(f"{x = }")
    print(f"{y = }")
    # Sanity check: e*d = 1 + x*phi/2 and (for odd p, q) phi/2 = A + y,
    # so (x, y) is a root of X*(A + Y) + 1 modulo e.
    assert (x * (A + y) + 1) % e == 0
    # Column scaling factors X, Y, Z = X*Y + 1 (exponents are hard-coded).
    xx = int(e ** 0.292)
    yy = int(e ** 0.5)
    zz = xx * yy + 1
    ps = polynomials(m, *hybrid_param(t))
    lat = lattice(ps, xx, yy, zz)
    print(f"LLL start: {len(lat)}")
    start = time.time()
    if algo == "flatter":
        mat = flatter(lat)
    elif algo == "self":
        mat = LLL(lat)
    elif algo == "sage":
        mat = matrix(ZZ, lat).LLL(delta=0.999)
    else:
        raise ValueError(algo)
    # Average bit length of the squared row norms of the reduced basis.
    avg_bits = sum(int(sum(v**2 for v in row)).bit_length() for row in mat) / len(lat)
    print(f"{algo}, {avg_bits:.0f}, {time.time() - start:.3f}s")
    start = time.time()
    degrees = list_degree(ps)
    # Undo the column scaling (divide by X^a * Y^b * (X*Y + 1)^c), then eliminate Z.
    qs = [
        Polynomial({deg: coef // monomial(*deg)(xx, yy) for coef, deg in zip(vec, degrees)}).remove_z()
        for vec in mat
    ]
    root = solve(qs)
    if root is None:
        # solve() found no zero-dimensional system; report instead of crashing on unpacking.
        print(f"no root found, {time.time() - start:.3f}s")
        return
    _, y = root
    print(decrypt(n, e, c, y), f"{time.time() - start:.3f}s")
def herrmann_may_lattice(m, tau):  # 0.292
    """Return the Herrmann-May shift polynomials: ``polynomials`` with eta fixed to 1."""
    eta = 1
    return polynomials(m, tau, eta)
def blomer_may_lattice(m, t):  # 0.290
    """Return the Blomer-May shift polynomials: ``polynomials`` with tau = 1 and eta = t/m."""
    tau, eta = 1, t / m
    return polynomials(m, tau, eta)
def polynomials(m, tau, eta):
    """Build the shift polynomials for the lattice.

    For every u from ceil(m*(1-eta)) to m:
      * X-shifts g(m, u-i, i) for i = 0..u, ordered by (u, i);
      * Y-shifts h(m, i, u) for i = 1..ceil(tau*(u - m*(1-eta))), ordered by (u, i).

    Returns:
        All X-shifts followed by all Y-shifts.
    """
    first_u = ceil(m * (1 - eta))
    x_shifts = sorted(
        ((u - i, i) for u in range(first_u, m + 1) for i in range(0, u + 1)),
        key=lambda pair: (pair[0] + pair[1], pair[1]),
    )
    y_shifts = sorted(
        (
            (i, u)
            for u in range(first_u, m + 1)
            for i in range(1, ceil(tau * (u - m * (1 - eta))) + 1)
        ),
        key=lambda pair: (pair[1], pair[0]),
    )
    return [g(m, i, k) for i, k in x_shifts] + [h(m, i, u) for i, u in y_shifts]
def list_degree(polynomials):
    """Return every exponent tuple (i, j, k) used by ``polynomials``, in lattice column order.

    Columns are ordered by Y exponent, then X+Z exponent, then Z exponent.
    """
    used = set().union(*(poly.coef.keys() for poly in polynomials))
    return sorted(used, key=lambda deg: (deg[1], deg[0] + deg[2], deg[2]))
def lattice(polynomials, x, y, z):
    """Build the scaled coefficient matrix of ``polynomials``.

    Row r belongs to polynomials[r]. Column c belongs to exponent tuple (a, b, k) = degrees[c]
    from ``list_degree``. The entry is coef * x**a * y**b * z**k.

    Returns:
        A list of rows (lists of integers).
    """
    degrees = list_degree(polynomials)
    # The scale factor depends only on the column, so compute each big power once.
    scales = [x ** a * y ** b * z ** k for a, b, k in degrees]
    return [
        [poly.get_coef(*deg) * scale for deg, scale in zip(degrees, scales)]
        for poly in polynomials
    ]
def hybrid_param(t):
    """Map the hybrid parameter t to the (tau, eta) pair used by ``polynomials``.

    t = 0 gives tau = 1 and eta = 3 - sqrt(6). t = 1 gives tau = sqrt(2) - 1 and
    eta = 1 (up to float rounding).
    """
    root2 = 2**0.5
    root6 = 6**0.5
    tau = 1 - (2 - root2) * t
    eta = (root6 - 2) * t + (3 - root6)
    return tau, eta
# Base polynomial Z + A*X with A = (n + 1) // 2
def f():
    """Return Z + A*X, where A = (n + 1) // 2 is computed from the module-level modulus n."""
    a_coef = (n + 1) // 2
    return monomial(Z=1) + monomial(X=1, coef=a_coef)
# X-shift polynomial: X^i * f^k * e^(m-k)
def g(m, i, k):
    """Return X^i * f^k * e^(m-k), normalized so that no term contains both X and Y."""
    shift = monomial(X=i, coef=e ** (m - k))
    return (shift * f() ** k).normalize()
# Y-shift polynomial: Y^i * f^u * e^(m-u)
def h(m, i, u):
    """Return Y^i * f^u * e^(m-u), normalized so that no term contains both X and Y."""
    shift = monomial(Y=i, coef=e ** (m - u))
    return (shift * f() ** u).normalize()
def monomial(X=0, Y=0, Z=0, coef=1):
    """Return the single-term polynomial coef * X^X * Y^Y * Z^Z."""
    exponents = (X, Y, Z)
    return Polynomial({exponents: coef})
class Polynomial:
    """Sparse polynomial in X, Y, Z stored as {(i, j, k): coefficient}.

    The key (i, j, k) stands for X^i * Y^j * Z^k. Evaluation (``__call__``),
    ``normalize`` and ``remove_z`` all treat Z as the auxiliary variable Z = X*Y + 1.
    """

    def __init__(self, coef):
        # Mapping exponent tuple -> coefficient; stored as given (not copied).
        self.coef = coef

    def __repr__(self):
        return f"{type(self).__name__}({self.coef!r})"

    def __add__(self, other):
        """Return the coefficient-wise sum (zero coefficients are kept)."""
        coef = self.coef.copy()
        for k, v in other.coef.items():
            coef[k] = coef.get(k, 0) + v
        return Polynomial(coef)

    def __mul__(self, other):
        """Return the product, adding exponent tuples term by term."""
        coef = {}
        for (a1, b1, c1), v1 in self.coef.items():
            for (a2, b2, c2), v2 in other.coef.items():
                k = (a1 + a2, b1 + b2, c1 + c2)
                coef[k] = coef.get(k, 0) + v1 * v2
        return Polynomial(coef)

    def __pow__(self, n):
        """Return self**n for a non-negative integer n by repeated multiplication.

        Raises:
            ValueError: if n is negative (polynomials have no inverse here).
        """
        if n < 0:
            raise ValueError("Polynomial exponent must be non-negative")
        if n == 0:
            return Polynomial({(0, 0, 0): 1})
        result = Polynomial(self.coef.copy())
        for _ in range(1, n):
            result *= self
        return result

    def __call__(self, x, y):
        """Evaluate at (x, y) with Z = x*y + 1."""
        z = x * y + 1
        val = 0
        for (i, j, k), v in self.coef.items():
            val += x ** i * y ** j * z ** k * v
        return val

    def normalize(self):
        """Substitute X*Y = Z - 1 so that no term contains both X and Y.

        Each X^a Y^b Z^c becomes X^(a-s) Y^(b-s) Z^c (Z - 1)^s with s = min(a, b).
        Terms that cancel to zero are dropped.
        """
        coef = {}
        for k, v in self.coef.items():
            xy = min(k[0], k[1])
            for i in range(xy + 1):
                nk = (k[0] - xy, k[1] - xy, k[2] + i)
                # Binomial term C(xy, i) * Z^i * (-1)^(xy - i).
                sign = 1 if i % 2 == xy % 2 else -1
                coef[nk] = coef.get(nk, 0) + sign * v * comb(xy, xy - i)
                if coef[nk] == 0:
                    coef.pop(nk)
        return Polynomial(coef)

    def remove_z(self):
        """Substitute Z = X*Y + 1, giving a polynomial in X and Y only.

        Terms that cancel to zero are dropped.
        """
        coef = {}
        for k, v in self.coef.items():
            for i in range(k[2] + 1):
                nk = (k[0] + i, k[1] + i, 0)
                coef[nk] = coef.get(nk, 0) + v * comb(k[2], i)
                if coef[nk] == 0:
                    coef.pop(nk)
        return Polynomial(coef)

    def monomials(self):
        """Return (exponents, coefficient) pairs sorted by (Y exp, X+Z exp, Z exp)."""
        return sorted(self.coef.items(), key=lambda e: (e[0][1], e[0][0] + e[0][2], e[0][2]))

    def get_coef(self, X=0, Y=0, Z=0):
        """Return the coefficient of X^X * Y^Y * Z^Z (0 if absent)."""
        return self.coef.get((X, Y, Z), 0)

    def __str__(self):
        """Render terms such as ``3*X + Z``; a coefficient of 1 is omitted unless the term is constant."""
        s = []
        for k, v in self.monomials():
            t = []
            if v != 1:
                t.append(str(v))
            for c, x in zip("XYZ", k):
                if x == 1:
                    t.append(str(c))
                elif x >= 2:
                    t.append(f"{c}^{x}")
            if not t:
                t.append(str(v))
            s.append("*".join(t))
        return " + ".join(s)
def solve(qs):
    """Find a common root (x, y) of the reduced polynomials over QQ[x, y].

    First binary-searches the prefix length of ``qs`` for an ideal of dimension 0.
    A prefix whose ideal is the whole ring (dimension -1) shrinks the search; a
    positive-dimensional one grows it. If no such prefix is hit, it falls back to a
    greedy pass. That pass starts from the last consistent prefix, appends the
    remaining polynomials one at a time, and drops any that make the system
    inconsistent.

    NOTE(review): the binary search presumes dimension shrinks monotonically with prefix length.

    Returns:
        (x, y) from the first point of ``variety(ring=ZZ)``, or None if no
        zero-dimensional system was reached.
    """
    # Explicit form of the preparser syntax `R.<x,y> = PolynomialRing(QQ)`.
    R = PolynomialRing(QQ, names=("x", "y"))
    x, y = R.gens()
    ps = [poly(x, y) for poly in qs]

    def first_root(ideal):
        # NOTE(review): raises IndexError if variety() returns no integer point.
        root = ideal.variety(ring=ZZ)[0]
        return root["x"], root["y"]

    left = 0
    right = len(ps) + 1
    while right - left > 1:
        mid = (left + right) // 2
        ideal = Sequence(ps[:mid], R).ideal()
        dim = ideal.dimension()
        if dim == -1:
            right = mid
        elif dim != 0:
            left = mid
        else:
            return first_root(ideal)
    system = Sequence(ps[:left], R)
    for poly in ps[left:]:
        system.append(poly)
        ideal = system.ideal()
        dim = ideal.dimension()
        if dim == -1:
            system.pop()
        elif dim == 0:
            return first_root(ideal)
    return None
def gram_schmidt(bases):
    """Classical Gram-Schmidt orthogonalization over RR.

    Returns:
        (ortho, mu): ``ortho[i]`` is the i-th orthogonalized vector.
        ``mu[i][j] = <b_i, b*_j> / <b*_j, b*_j>`` for j < i; every other entry is 0.0.
    """
    count = len(bases)
    rows = [vector(RR, base) for base in bases]
    mu = [[0.0] * count for _ in range(count)]
    ortho = []
    for i, row in enumerate(rows):
        current = row
        for j in range(i):
            mu[i][j] = row.dot_product(ortho[j]) / ortho[j].dot_product(ortho[j])
            current -= mu[i][j] * ortho[j]
        ortho.append(current)
    return ortho, mu
def LLL(basis, delta=0.75):
    """Textbook floating-point LLL reduction of the rows of ``basis``.

    Gram-Schmidt data (coefficients mu and squared norms B) are computed once with
    ``gram_schmidt``. They are then updated incrementally on every size reduction and
    every swap, using the standard LLL swap update.

    Args:
        basis: integer row vectors.
        delta: Lovasz parameter (standard choice 0.75).

    Returns:
        The reduced rows as ``vector(ZZ, ...)`` objects.
    """
    basis = [vector(ZZ, base) for base in basis]
    n = len(basis)
    gs_basis, mu = gram_schmidt(basis)
    B = [v.dot_product(v) for v in gs_basis]
    k = 1
    while k < n:
        # Size-reduce b_k against b_{k-1}, ..., b_0.
        for j in reversed(range(k)):
            if abs(mu[k][j]) > 0.5:
                r = int(round(mu[k][j]))
                basis[k] -= r * basis[j]
                for i in range(j):
                    mu[k][i] -= r * mu[j][i]
                mu[k][j] -= r
        if B[k] >= (delta - mu[k][k - 1] ** 2) * B[k - 1]:
            k += 1
            continue
        # Lovasz condition fails: swap b_k and b_{k-1} and update mu and B.
        basis[k], basis[k - 1] = basis[k - 1], basis[k]
        for j in range(k - 1):
            mu[k - 1][j], mu[k][j] = mu[k][j], mu[k - 1][j]
        mu_old = mu[k][k - 1]
        b = B[k] + mu_old ** 2 * B[k - 1]  # squared norm of the new b*_{k-1}
        mu[k][k - 1] = mu_old * B[k - 1] / b
        B[k] = B[k] * B[k - 1] / b
        B[k - 1] = b
        for i in range(k + 1, n):
            t = mu[i][k]
            mu[i][k] = mu[i][k - 1] - mu_old * t
            mu[i][k - 1] = t + mu[k][k - 1] * mu[i][k]
        k = max(1, k - 1)
    return basis
def flatter(bs):
    """Reduce the lattice with rows ``bs`` using the external ``flatter`` program.

    The basis is written to flatter's stdin as a bracketed matrix, with one
    ``[a b c]`` row per line inside an outer ``[...]``. The reduced basis is parsed
    back from stdout, one row per non-empty line.

    Raises:
        FileNotFoundError: if no ``flatter`` executable is found.
        subprocess.CalledProcessError: if flatter exits with a non-zero status.
    """
    rows = "\n".join("[" + " ".join(map(str, b)) + "]" for b in bs)
    proc = subprocess.run(
        ["flatter"],
        input="[" + rows + "]",
        text=True,
        stdout=subprocess.PIPE,
        check=True,
    )
    body = proc.stdout.replace("[", "").replace("]", "")
    # split() tolerates repeated spaces; blank lines (e.g. a trailing "]") are skipped.
    return [list(map(int, line.split())) for line in body.splitlines() if line.strip()]
def manual():
    """Prompt for "size,E,m,t,algo", generate a matching key and run ``ssea`` once.

    Raises:
        ValueError: if the input does not have exactly five comma-separated fields,
            or a numeric field does not parse.
    """
    raw = input("input params (ex: \"64,0.292,15,1,flatter\"): ")
    # Strip spaces so that e.g. "64, 0.292, 15, 1, flatter" still selects the backend.
    fields = [field.strip() for field in raw.split(",")]
    if len(fields) != 5:
        raise ValueError(f"expected 5 comma-separated values, got {len(fields)}")
    size, E, m, t, algo = fields
    size, E, m, t = int(size), float(E), int(m), float(t)
    p, q, d_list = gen(size, [E])
    d = d_list[0]
    print(f"solve with param: E={E:.2f} real={log(d)/log(p*q):.4f}")
    ssea(algo, m, t, 998244353, (p, q, d))
def bench():
    """Run the flatter backend for delta = 0.27 down to 0.10 on a single 64-bit modulus."""
    deltas = [float(i / 100) for i in reversed(range(10, 28))]
    p, q, exponents = gen(64, deltas)
    for delta, d in zip(deltas, exponents):
        print(f"solve with param: E={delta:.2f} real={log(d)/log(p*q):.4f}")
        ssea("flatter", 20, 0, 998244353, (p, q, d))
        print()
# Script entry point: interactive single run (bench() runs the delta sweep instead).
if __name__ == "__main__":
    manual()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment