Created
May 8, 2020 08:11
-
-
Save kennyyu/5fe34db4c039ff5d04b5659fd1bc7735 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python3 | |
import base64
import random
import secrets
def exp(n, e, mod=None):
    """
    Return n**e, reduced modulo ``mod`` when ``mod`` is given.

    Uses binary (square-and-multiply) exponentiation, reducing every
    intermediate product so the numbers stay below ``mod``. A ``mod`` of
    None (or any other falsy value, such as 0) means "no reduction".

    Raises:
        ValueError: if e is negative (the bit loop would never terminate,
            because -1 >> 1 == -1).

    >>> exp(5, 0), exp(5, 1), exp(5, 2), exp(5, 3)
    (1, 5, 25, 125)
    >>> exp(3, 200, mod=1000003) == pow(3, 200, 1000003)
    True
    """
    if e < 0:
        raise ValueError(f"exponent must be non-negative, got {e}")
    # 1 % mod rather than a bare 1, so that exp(x, 0, mod=1) == 0 like pow(x, 0, 1).
    result = 1 % mod if mod else 1
    base = n % mod if mod else n
    while e:
        if e & 1:
            result *= base
            if mod:
                result %= mod
        e >>= 1
        # Skip the squaring after the top bit has been consumed: it is never
        # used, and without a modulus it doubles the size of a huge number.
        if e:
            base *= base
            if mod:
                base %= mod
    return result
def extended_euclid(a, b):
    """
    Return (d, x, y) where d = gcd(a, b) and a*x + b*y == d.

    Iterative extended Euclidean algorithm. It walks the same quotient
    sequence as the textbook recursion on (b, a % b) and returns the same
    coefficients, but uses no stack frame per division step. Inputs that
    need more steps than Python's default recursion limit (e.g. consecutive
    Fibonacci numbers with a few hundred digits) therefore do not raise
    RecursionError.

    >>> extended_euclid(10, 6)
    (2, -1, 2)
    >>> extended_euclid(20, 17)
    (1, 6, -7)
    >>> extended_euclid(24, 16)
    (8, 1, -1)
    """
    # Invariants: a*old_x + b*old_y == old_r and a*x + b*y == r.
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r != 0:
        k = old_r // r
        old_r, r = r, old_r - k * r
        old_x, x = x, old_x - k * x
        old_y, y = y, old_y - k * y
    return (old_r, old_x, old_y)
def is_probably_prime(n, num_iter=50, *, verbose=False):
    """
    Miller-Rabin probabilistic primality test.

    Returns False if n is definitely composite and True if n is probably
    prime. Each of the num_iter rounds uses a fresh random witness
    a in [2, n - 2]. A composite n survives all rounds with probability at
    most 4^(-num_iter), which is below 1 / (2^num_iter).

    Args:
        n: the integer to test; anything below 2 is reported as not prime.
        num_iter: number of independent rounds.
        verbose: keyword-only; when True, print a one-line trace per round
            and for the final verdict. Off by default so callers that test
            many candidates (such as prime generation) stay quiet.

    >>> is_probably_prime(5), is_probably_prime(109)
    (True, True)
    >>> is_probably_prime(221), is_probably_prime(222), is_probably_prime(65)
    (False, False, False)
    """
    # randint(2, n - 2) below needs n >= 4; decide tiny and even n directly.
    if n < 2:
        return False
    if n < 4:
        return True
    if n % 2 == 0:
        return False

    def get_t_and_u(n):
        """
        Returns (t, u) where n = 2^t * u + 1 and u is odd
        """
        n_1 = n - 1
        t = 0
        while n_1 % 2 == 0:
            t += 1
            n_1 >>= 1
        return (t, n_1)

    t, u = get_t_and_u(n)
    for _ in range(num_iter):
        a = random.randint(2, n - 2)
        # powers = [a^u, a^(2u), a^(4u), ..., a^(2^t * u)] mod n;
        # the last entry is a^(n-1).
        powers = [pow(a, u, n)]
        for _ in range(t):
            powers.append(powers[-1] * powers[-1] % n)
        # For prime n, a^(n-1) must be 1 (Fermat). a^(n-1) == n - 1 therefore
        # proves n composite; it is not the harmless -1 square root. Given
        # a^(n-1) == 1, the first non-1 value walking backwards is a square
        # root of 1 mod n. For prime n the only such roots are 1 and n - 1,
        # so that value must be n - 1 (or absent, when every power is 1).
        last_non_one = next((x for x in reversed(powers) if x != 1), None)
        if powers[-1] == 1 and last_non_one in (None, n - 1):
            # inconclusive, try another a
            if verbose:
                print(
                    f"Inconclusive. n: {n}, t:{t}, u: {u}, n == 2^t * u + 1, a: {a}, powers: {powers}"
                )
            continue
        # a is a witness: Fermat failure or a non-trivial square root of 1
        if verbose:
            print(
                f"Found composite! n: {n}, t:{t}, u: {u}, n == 2^t * u + 1, a: {a}, powers: {powers}"
            )
        return False
    if verbose:
        print(f"Found probable prime! n: {n}, t:{t}, u: {u}, n == 2^t * u + 1")
    return True
def rsa_make_keys(
    prime_min=100000000000000000000000000000000000000000,
    prime_max=1000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000,
):
    """
    Generate a textbook-RSA key pair.

    Returns ((n, e), (p, q, d)) representing (public key, private key) where:
    - p and q are distinct probable primes drawn uniformly from
      [prime_min, prime_max] using the ``secrets`` CSPRNG (key material must
      not come from the predictable ``random`` generator)
    - n = p * q
    - e is the smallest integer >= 65537 with gcd((p - 1)(q - 1), e) == 1
    - d = e^(-1) mod (p - 1)(q - 1)

    The range must contain at least two primes; otherwise this never returns.

    Raises:
        ValueError: if prime_min > prime_max.

    NOTE: unpadded (textbook) RSA is deterministic and malleable; suitable
    for demonstrations only.
    """
    def generate_prime(prime_min, prime_max):
        """
        Returns a probable prime drawn uniformly from [prime_min, prime_max]
        """
        span = prime_max - prime_min + 1
        while True:
            candidate = prime_min + secrets.randbelow(span)
            if is_probably_prime(candidate):
                return candidate

    def generate_e_d(p, q):
        """
        Finds the smallest e >= 65537 with gcd((p - 1)(q - 1), e) == 1,
        and d such that d = e^(-1) mod (p - 1)(q - 1).
        Returns (e, d).
        """
        phi = (p - 1) * (q - 1)
        # Start at 65537 rather than 3. With e = 3, an unpadded 16-byte chunk
        # m (< 2**128) has m**3 < n whenever n > 2**384, the usual case for
        # the default prime range. The "ciphertext" is then just m**3, and an
        # integer cube root recovers m.
        e = 65537
        while True:
            # d * e + _ * phi = 1  =>  d = e^(-1) mod phi
            (gcd, d, _) = extended_euclid(e, phi)
            if gcd == 1:
                # d might be negative, return the mod of it
                return (e, d % phi)
            e += 1

    p = generate_prime(prime_min, prime_max)
    q = generate_prime(prime_min, prime_max)
    # p == q would make (p - 1)(q - 1) the wrong totient for n = p**2,
    # so decryption would not invert encryption.
    while q == p:
        q = generate_prime(prime_min, prime_max)
    n = p * q
    e, d = generate_e_d(p, q)
    return ((n, e), (p, q, d))
def encode_with_public_key(n, e, num):
    """
    Encrypt the integer ``num`` with the public key (n, e).

    Returns num^e mod n. Textbook RSA can only round-trip values in [0, n):
    anything outside that range would decrypt to ``num % n`` instead of
    ``num``, so such values are rejected rather than silently corrupted.

    Raises:
        ValueError: if num < 0 or num >= n.
    """
    if not 0 <= num < n:
        raise ValueError(
            f"message integer must satisfy 0 <= num < n (n is {n.bit_length()} bits)"
        )
    return exp(num, e, mod=n)
def decode_with_private_key(p, q, d, num_encrypted):
    """
    Decrypt an RSA ciphertext integer with the private key (p, q, d).

    Returns num_encrypted^d mod (p * q).
    """
    modulus = p * q
    plain = exp(num_encrypted, d, mod=modulus)
    return plain
# Number of plaintext bytes per encrypted chunk. Each decrypted chunk is
# converted back into exactly this many big-endian bytes, so plaintext blocks
# are integers below 256**MESSAGE_SIZE_BYTES. The RSA modulus n must
# therefore exceed 256**MESSAGE_SIZE_BYTES (2**128 here), not merely
# MESSAGE_SIZE_BYTES.
MESSAGE_SIZE_BYTES: int = 16
def encode_message(n, e, message_str):
    """
    Encrypt ``message_str`` with the public key (n, e), splitting it into chunks.

    Each chunk holds as many whole characters as fit in MESSAGE_SIZE_BYTES
    bytes of UTF-8. Consequently a multi-byte character is never split across
    chunks, and no chunk's encoding exceeds the fixed block size that
    decryption converts back into. For pure-ASCII text this is identical to
    slicing every MESSAGE_SIZE_BYTES characters.

    Returns a list with one encrypted chunk per piece, in order; an empty
    message gives an empty list.
    """
    chunks = []
    current_chars = []
    current_size = 0
    for char in message_str:
        char_size = len(char.encode())
        # Close the current chunk when this character would overflow the block.
        if current_chars and current_size + char_size > MESSAGE_SIZE_BYTES:
            chunks.append("".join(current_chars))
            current_chars, current_size = [], 0
        current_chars.append(char)
        current_size += char_size
    if current_chars:
        chunks.append("".join(current_chars))
    return [encode_message_chunk(n, e, chunk) for chunk in chunks]
def decode_message(p, q, d, encrypted_chunks):
    """
    Decrypt a sequence of encrypted chunks with the private key (p, q, d).

    Returns the concatenation of the decrypted chunk texts, in order.
    """
    return "".join(
        decode_message_chunk(p, q, d, encrypted_chunk)
        for encrypted_chunk in encrypted_chunks
    )
def encode_message_chunk(n, e, message_str):
    """
    Encrypt one chunk of text with the public key (n, e).

    The chunk is UTF-8 encoded, right-padded with NUL bytes to exactly
    MESSAGE_SIZE_BYTES bytes, read as a big-endian unsigned integer and
    encrypted. Returns the ciphertext integer's decimal digits,
    base64-encoded, as ``bytes``.

    Raises:
        ValueError: if the UTF-8 encoding of message_str is longer than
            MESSAGE_SIZE_BYTES bytes; such a chunk cannot be turned back into
            a fixed-size block on decryption.
    """
    message_bytes = message_str.encode()
    if len(message_bytes) > MESSAGE_SIZE_BYTES:
        raise ValueError(
            f"chunk is {len(message_bytes)} bytes in UTF-8; "
            f"at most {MESSAGE_SIZE_BYTES} bytes fit in one block"
        )
    # Pad by byte count, not character count: padding by characters lets a
    # multi-byte character push the block past MESSAGE_SIZE_BYTES. Trailing
    # NULs are what the decoder strips (everything from the first NUL on).
    message_bytes = message_bytes.ljust(MESSAGE_SIZE_BYTES, b"\0")
    message_int = int.from_bytes(message_bytes, byteorder="big", signed=False)
    encrypted_int = encode_with_public_key(n, e, message_int)
    return base64.b64encode(str(encrypted_int).encode())
def decode_message_chunk(p, q, d, encrypted_message_str):
    """
    Decrypt one chunk with the private key (p, q, d) and return its text.

    The input is base64 of the ciphertext's decimal digits. The decrypted
    integer is laid out as MESSAGE_SIZE_BYTES big-endian bytes and decoded as
    UTF-8. Everything from the first NUL onward is treated as padding.
    """
    ciphertext = int(base64.b64decode(encrypted_message_str).decode())
    plaintext_block = decode_with_private_key(p, q, d, ciphertext).to_bytes(
        length=MESSAGE_SIZE_BYTES, byteorder="big", signed=False
    )
    # partition() returns the whole string when no NUL is present.
    text, _, _ = plaintext_block.decode().partition("\0")
    return text
if __name__ == "__main__":
    # Demo: only runs when executed as a script, so importing this module
    # does not generate keys or print private key material.
    ((n, e), (p, q, d)) = rsa_make_keys()
    print(f"n:{n}")
    print(f"e:{e}")
    print(f"p:{p}")
    print(f"q:{q}")
    print(f"d:{d}")
    num = 171717
    num_enc = encode_with_public_key(n, e, num)
    num_dec = decode_with_private_key(p, q, d, num_enc)
    print(f"num: {num}, num_enc: {num_enc}, num_dec: {num_dec}")
    message = "Hello World Everyone! Hello World Everyone! Hello World Everyone! Hello World Everyone!"
    encrypted_message = encode_message(n, e, message)
    decrypted_message = decode_message(p, q, d, encrypted_message)
    print(f" message: {message}")
    print(f"encrypted: {encrypted_message}")
    print(f"decrypted: {decrypted_message}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment