Last active
January 6, 2022 03:20
-
-
Save Levi-Lesches/37f8f9fc4e7f792fcecff2e298b7aa2d to your computer and use it in GitHub Desktop.
An RSA Python script
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import math | |
import secrets | |
import string | |
from argparse import ArgumentParser | |
from base64 import b64encode, b64decode | |
# Maximum number of message characters packed into one RSA block.
SUPPORTED_CHARS = 10

# Both primes are drawn from [MIN_PRIME, MAX_PRIME), so n = p * q >= 10**22.
# A padded block of SUPPORTED_CHARS characters at two digits each is below
# 10**(2 * SUPPORTED_CHARS), which keeps every block smaller than n.
MIN_PRIME = 10 ** (SUPPORTED_CHARS + 1)
MAX_PRIME = 10 ** (SUPPORTED_CHARS + 2)
# ========= Utils ========= | |
def split_message(message):
    """Break ``message`` into consecutive chunks of at most SUPPORTED_CHARS characters.

    The last chunk may be shorter. An empty message yields an empty list.
    """
    step = SUPPORTED_CHARS
    starts = range(0, len(message), step)
    return [message[start:start + step] for start in starts]
def encode(x):
    """Return the Base64 encoding of the text ``x`` (UTF-8 encoded) as a str."""
    return b64encode(x.encode()).decode()


def decode(x):
    """Inverse of ``encode``: Base64-decode ``x`` and return the UTF-8 text."""
    return b64decode(x).decode()
def is_prime(num):
    """Return True if ``num`` is prime, using deterministic trial division.

    Even numbers are rejected immediately (except 2). Odd divisors are tried
    up to and including floor(sqrt(num)). Runtime is O(sqrt(num)).
    """
    if num < 2:
        return False
    if num % 2 == 0:
        return num == 2
    # The "+ 1" makes the square root itself a candidate divisor. Without it,
    # squares of primes (9, 25) and products like 15 = 3 * 5 pass as prime.
    limit = int(math.sqrt(num)) + 1
    return all(num % divisor for divisor in range(3, limit, 2))
def generate_random_prime():
    """Return a random prime p with MIN_PRIME <= p < MAX_PRIME.

    Candidates are drawn uniformly from that range with the ``secrets``
    module and redrawn until one passes ``is_prime``.
    """
    span = MAX_PRIME - MIN_PRIME
    while True:
        candidate = MIN_PRIME + secrets.randbelow(span)
        if is_prime(candidate):
            return candidate
def modular_multiplicative_inverse(number, modulus):
    """Return x in [0, modulus) such that (number * x) % modulus == 1.

    Uses the iterative extended Euclidean algorithm.

    Raises:
        ValueError: if ``modulus`` is not positive, or if ``number`` has no
            inverse modulo ``modulus`` (i.e. gcd(number, modulus) != 1).
    """
    if modulus < 1:
        raise ValueError("modulus must be a positive integer")
    if modulus == 1:
        return 0  # every integer is congruent to 0 modulo 1
    # Invariants: old_r == old_s * number (mod modulus),
    #             r     == s     * number (mod modulus).
    old_r, r = number % modulus, modulus
    old_s, s = 1, 0
    while r:
        quotient = old_r // r
        old_r, r = r, old_r - quotient * r
        old_s, s = s, old_s - quotient * s
    # old_r is now gcd(number, modulus); an inverse exists only when it is 1.
    if old_r != 1:
        raise ValueError(f"{number} has no inverse modulo {modulus}")
    return old_s % modulus
# ========= Initialization ========= | |
def generate_keys():
    """
    Generates a private key and a public key, in that order.

    Steps:
    1. Choose two distinct, large, random prime numbers, p and q
    2. n = p * q
    3. t = (p - 1)(q - 1)
    4. Choose an e such that (1 < e < t) AND e is co-prime with t
    5. Calculate d, the modular multiplicative inverse of e modulo t

    Private key = (n, d), Public key = (n, e).
    Each key is returned as the Base64 encoding of "n_d" / "n_e".
    """
    e = 65537  # the conventional RSA public exponent (a prime)
    while True:
        p = generate_random_prime()
        q = generate_random_prime()
        t = (p - 1) * (q - 1)
        # If p == q, t is not the totient of n = p**2, so decryption breaks.
        # If e shares a factor with t, it has no inverse modulo t.
        # Both cases are rare; just redraw.
        if p != q and math.gcd(e, t) == 1:
            break
    n = p * q
    d = modular_multiplicative_inverse(e, t)
    return encode(f"{n}_{d}"), encode(f"{n}_{e}")  # private, public
# Every character is written as a fixed-width decimal index into SYMBOLS.
SYMBOLS_DIGITS = 2
# Index 0 is a placeholder so no real character encodes to "00". The padded
# digits are turned into an int, and only a single leading zero (from indices
# 1-9) can be restored when unpadding.
# string.printable has 100 characters, which would put the last one ("\x0c")
# at index 100, i.e. three digits. Keep only the characters whose index fits
# in SYMBOLS_DIGITS digits.
SYMBOLS = [None] + list(string.printable)[:10 ** SYMBOLS_DIGITS - 1]
# ========= Encrypt/Decrypt ========= | |
def pad(message):
    """Convert ``message`` into an int, one fixed-width number per character.

    Each character is replaced by its index in SYMBOLS, zero-filled to
    SYMBOLS_DIGITS digits. The digit strings are concatenated and parsed as
    an int. Leading zeros are dropped by that conversion; ``unpad`` restores
    them.

    Raises:
        ValueError: if a character is not in SYMBOLS, or its index does not
            fit in SYMBOLS_DIGITS digits (it would corrupt the encoding), or
            if ``message`` is empty (there are no digits to parse).
    """
    digits = []
    for symbol in message:
        try:
            index = SYMBOLS.index(symbol)
        except ValueError:
            raise ValueError(f"Unsupported character: {symbol!r}") from None
        encoded = str(index).zfill(SYMBOLS_DIGITS)
        if len(encoded) > SYMBOLS_DIGITS:
            # A wider index would shift every following character on unpad.
            raise ValueError(f"Unsupported character: {symbol!r}")
        digits.append(encoded)
    return int("".join(digits))
def unpad(padded):
    """Inverse of ``pad``: convert a string of decimal digits back into characters.

    Leading zeros were lost when ``pad`` produced an int, so ``padded`` is
    left-filled with zeros to a multiple of SYMBOLS_DIGITS. It is then read
    SYMBOLS_DIGITS digits at a time, each group being an index into SYMBOLS.

    Raises:
        IndexError: a digit group is past the end of SYMBOLS.
        TypeError: a digit group is 0, which maps to the None placeholder.
    ``decrypt`` catches these two types to report an invalid key.
    """
    width = -(-len(padded) // SYMBOLS_DIGITS) * SYMBOLS_DIGITS  # round up
    padded = padded.zfill(width)
    return "".join(
        SYMBOLS[int(padded[i:i + SYMBOLS_DIGITS])]
        for i in range(0, len(padded), SYMBOLS_DIGITS)
    )
def encrypt(message, public_key):
    """Encrypt ``message`` with a public key produced by ``generate_keys``.

    The message is split into SUPPORTED_CHARS-sized chunks. Each chunk is
    converted to an int with ``pad`` and raised to e modulo n. The resulting
    numbers are joined with "_" and Base64-encoded.

    Raises:
        ValueError: if the key is malformed, a character is unsupported, or a
            padded chunk is not smaller than n (RSA cannot round-trip it).
    """
    n, e = map(int, decode(public_key).split("_"))
    blocks = []
    for message_part in split_message(message):
        plain = pad(message_part)
        if plain >= n:
            # m**e mod n only decrypts back to m when 0 <= m < n.
            raise ValueError("Public key modulus is too small for this message block")
        blocks.append(str(pow(plain, e, n)))
    return encode("_".join(blocks))
def decrypt(cipher, private_key):
    """Decrypt ``cipher`` (as produced by ``encrypt``) with a private key.

    The cipher is Base64-decoded and split on "_". Each number is raised to d
    modulo n, and the result is converted back into text with ``unpad``.

    Raises:
        ValueError: "Invalid private key" when the decrypted digits do not
            map onto SYMBOLS. ValueError subclasses Exception, so existing
            ``except Exception`` handlers still catch it.
            NOTE: a wrong key is not guaranteed to be detected; it may
            simply decode to garbage text.
    """
    n, d = map(int, decode(private_key).split("_"))
    try:
        return "".join(
            unpad(str(pow(int(cipher_part), d, n)))
            for cipher_part in decode(cipher).split("_")
        )
    except (TypeError, IndexError):
        raise ValueError("Invalid private key") from None
if __name__ == "__main__":
    def print_new_keys():
        """Generate a fresh key pair, print both keys, and return (private, public)."""
        private_key, public_key = generate_keys()
        print(f"Private key: {private_key}")
        print(f"Public key: {public_key}")
        return private_key, public_key

    def test_args(args):
        """Round-trip ``args.message`` through a freshly generated key pair."""
        private_key, public_key = print_new_keys()
        print()
        # Encrypt and decrypt message
        cipher = encrypt(args.message, public_key)
        print(f"Encrypted message: {cipher}")
        message = decrypt(cipher, private_key)
        print(f"Decrypted message: {message}")

    def generate_args(args):
        """Print a newly generated private/public key pair."""
        print_new_keys()

    def encrypt_args(args):
        """Print ``args.message`` encrypted with the public key ``args.key``."""
        print(encrypt(args.message, args.key))

    def decrypt_args(args):
        """Print the cipher ``args.message`` decrypted with the private key ``args.key``."""
        print(decrypt(args.message, args.key))

    parser = ArgumentParser()
    subparsers = parser.add_subparsers(required=True, dest="command")

    test_parser = subparsers.add_parser("test", help="Test the RSA encryption")
    test_parser.add_argument("message", help="The test message", nargs="?", default="Hello, World!")
    test_parser.set_defaults(func=test_args)

    gen_parser = subparsers.add_parser("generate", help="Generates public and private keys")
    gen_parser.set_defaults(func=generate_args)

    encrypt_parser = subparsers.add_parser("encrypt", help="Encrypt a message using a public key")
    encrypt_parser.add_argument("message", help="The message to encrypt")
    encrypt_parser.add_argument("-k", "--key", help="The public key", required=True)
    encrypt_parser.set_defaults(func=encrypt_args)

    # Decryption needs the private key; the help text previously said "public key".
    decrypt_parser = subparsers.add_parser("decrypt", help="Decrypt a message using a private key")
    decrypt_parser.add_argument("message", help="The message to decrypt")
    decrypt_parser.add_argument("-k", "--key", help="The private key", required=True)
    decrypt_parser.set_defaults(func=decrypt_args)

    args = parser.parse_args()
    args.func(args)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment