#!/usr/bin/env python
"""Demo of RSA encryption

This module illustrates the core of the RSA algorithm in pure Python. On each
run, the program follows these steps:

1. It generates fresh RSA keys.
2. It prompts the user for a message.
3. It encodes the message into fixed-size blocks.
4. It encrypts each block into a ciphertext.
5. It decrypts the ciphertext to reveal the original message.

Although the asymmetric key encryption is secure, this implementation is not
secure. It is for educational purposes only.

* The key sizes are too small to be cryptographically secure.
* Padding is not added to the blocks, so the encryption could be defeated by
  frequency analysis.

See also http://en.wikipedia.org/wiki/RSA_%28algorithm%29#A_working_example
and http://www.di-mgt.com.au/rsa_alg.html#simpleexample.

This module has been tested in both Python 2.7.3 and Python 3.2.3.

Author: Chad Parry
"""

import functools
import itertools
import logging
import math
import random
import sys

# Size of each block, in bytes
BLOCK_SIZE_BYTES = 2

if sys.version_info[0] >= 3:
    # Shims for Python 3 compatibility
    lazy_range = range
    zip_longest = itertools.zip_longest
    console_input = input

    def decode_input(arg):
        """Return console input unchanged; it is already text"""
        return arg

    def u(string):
        """Return the literal unchanged; it is already unicode"""
        return string
else:
    # Shims for Python 2 compatibility
    import codecs
    lazy_range = xrange
    zip_longest = itertools.izip_longest
    console_input = raw_input

    def decode_input(arg):
        """Decode console input using the encoding reported by stdin"""
        logging.debug(u('Console encoding is %s'), sys.stdin.encoding)
        if sys.stdin.encoding is None:
            # No encoding is known, so the raw value is passed through
            decoded = arg
        else:
            decoded = arg.decode(sys.stdin.encoding)
        return decoded

    def u(string):
        """Turn a byte-string literal with \\uXXXX escapes into unicode"""
        return codecs.unicode_escape_decode(string)[0]


def egcd(a, b):
    """Extended Euclidean algorithm

    Returns a triple (g, x, y), such that ax + by = g = gcd(a, b). The
    greatest common divisor of a and b is egcd(a, b)[0].

    The computation is iterative, so the number of Euclidean steps is not
    limited by the interpreter's recursion depth.

    See http://en.wikibooks.org/wiki/Algorithm_Implementation/Mathematics/Extended_Euclidean_algorithm#Python.
    """
    # Forward pass: run the plain Euclidean algorithm, remembering each
    # quotient so that the coefficients can be reconstructed afterwards.
    quotients = []
    while a != 0:
        quotients.append(b // a)
        a, b = b % a, a
    # At this point a == 0 and b is the gcd, for which 0 * 0 + b * 1 == b.
    x, y = 0, 1
    # Backward pass: undo the reduction steps from last to first.
    for quotient in reversed(quotients):
        x, y = y - quotient * x, x
    return (b, x, y)


def lcm(a, b):
    """Least common multiple"""
    return a * b // egcd(a, b)[0]


def is_prime(n):
    """Simple primality test

    Trial division by every integer from 2 up to the square root of n. The
    bound is computed with a floating-point square root. For very large keys,
    an efficient probabilistic test would be needed.
    """
    if n < 2:
        return False
    for factor in lazy_range(2, int(math.sqrt(n)) + 1):
        if n % factor == 0:
            logging.debug(u('Candidate %d is divisible by %d'), n, factor)
            return False
    logging.debug(u('Candidate %d is prime'), n)
    return True


def generate_prime():
    """Generate a large random prime

    The prime has one bit more than half a block. The product of two such
    primes is therefore at least 2 ** (BLOCK_SIZE_BYTES * 8), which is larger
    than any value a packed block can take.
    """
    prime_bits = BLOCK_SIZE_BYTES * 8 // 2
    prime_min = 2 ** prime_bits
    prime_max = 2 ** (prime_bits + 1) - 1
    logging.debug(u('Prime should be from %d to %d'), prime_min, prime_max)
    while True:
        candidate = random.randint(prime_min, prime_max)
        if is_prime(candidate):
            return candidate


def choose_public_exponent(p, q):
    """Choose a small exponent e for the public key

    Returns the smallest integer from 2 upwards that is coprime with the
    totient (p - 1)(q - 1).
    """
    totient = (p - 1) * (q - 1)
    logging.debug(u('Totient \u03c6 = (p - 1)(q - 1) = %d \u00d7 %d = %d'),
                  p - 1, q - 1, totient)
    for candidate in itertools.count(2):
        divisor = egcd(candidate, totient)[0]
        if divisor == 1:
            logging.debug(u('Candidate %d is coprime with %d'),
                          candidate, totient)
            return candidate
        else:
            logging.debug(u('Candidate %d shares factor %d with %d'),
                          candidate, divisor, totient)


def choose_private_exponent(p, q, e):
    """Calculate the exponent d for the private key

    Returns d in the range [0, LCM(p - 1, q - 1)) such that
    de mod LCM(p - 1, q - 1) = 1, assuming e is coprime with that modulus.
    """
    modulus = lcm(p - 1, q - 1)
    logging.debug(u('Modulus \u03bb = LCM(p - 1, q - 1) '
                    '= LCM(%d, %d) '
                    '= %d'),
                  p - 1, q - 1, modulus)
    logging.debug(u('Solving for de mod \u03bb = 1 '
                    '\u2234 %d \u00d7 d mod %d = 1'),
                  e, modulus)
    # The Bezout coefficient of e may be negative; reduce it into range.
    return egcd(e, modulus)[1] % modulus


def generate_key():
    """Generate fresh public and private keys

    Returns a triple (n, e, d) where (n, e) is the public key and (n, d) is
    the private key.
    """
    while True:
        p = generate_prime()
        q = generate_prime()
        if p != q:
            break
        logging.debug(u('Rejecting identical primes %d'), p)
    logging.info(u('Prime p = %d'), p)
    logging.info(u('Prime q = %d'), q)
    n = p * q
    logging.info(u('Modulus n = pq = %d \u00d7 %d = %d'), p, q, n)
    e = choose_public_exponent(p, q)
    logging.info(u('Public exponent e = %d'), e)
    d = choose_private_exponent(p, q, e)
    logging.info(u('Private exponent d = %d'), d)
    return (n, e, d)


def grouper(n, iterable, fillvalue=None):
    """Collect data into fixed-length chunks or blocks

    The last chunk is padded with fillvalue when the data runs out. This
    recipe appears in the itertools documentation.
    """
    # The same iterator object repeated n times: each output tuple pulls n
    # consecutive items from it.
    args = [iter(iterable)] * n
    return zip_longest(fillvalue=fillvalue, *args)


def pack_block(block):
    """Pack a list of bytes into a single unsigned integer

    The first byte becomes the most significant one.
    """
    return functools.reduce(lambda packed, byte: packed * 256 + byte,
                            block, 0)


def unpack_block_little_endian(block):
    """Generator that unpacks an unsigned integer into a list of bytes

    Exactly BLOCK_SIZE_BYTES bytes are produced. The returned list is
    reversed from what it should be, because that makes the implementation
    simpler.
    """
    for _ in lazy_range(BLOCK_SIZE_BYTES):
        yield int(block % 256)
        block //= 256


def unpack_block(block):
    """Unpack an unsigned integer into a list of bytes"""
    return reversed(list(unpack_block_little_endian(block)))


def encrypt_block(n, e, plaintext):
    """Encrypt a single block with the public key"""
    ciphertext = pow(plaintext, e, n)
    logging.debug(u('Encryption m^e mod n '
                    '= %d ^ %d mod %d '
                    '= %d'),
                  plaintext, e, n, ciphertext)
    return ciphertext


def decrypt_block(n, d, ciphertext):
    """Decrypt a single block with the private key"""
    encoded = pow(ciphertext, d, n)
    logging.debug(u('Decryption c^d mod n '
                    '= %d ^ %d mod %d '
                    '= %d'),
                  ciphertext, d, n, encoded)
    return encoded


def encrypt_message(n, e, plaintext):
    """Encrypt a message with the public key

    The message is encoded as UTF-8, split into blocks of BLOCK_SIZE_BYTES
    bytes (the last block is padded with zero bytes), and every block is
    encrypted separately. Returns the list of ciphertext integers.
    """
    encoded = bytearray(plaintext, 'utf_8')
    logging.debug(u('Encoded plaintext is: %s'), list(encoded))
    unpacked_plaintext = grouper(BLOCK_SIZE_BYTES, encoded, 0)
    packed_plaintext = list(map(pack_block, unpacked_plaintext))
    logging.debug(u('Plaintext blocks are: %s'), packed_plaintext)
    ciphertext = list(map(functools.partial(encrypt_block, n, e),
                          packed_plaintext))
    logging.debug(u('Ciphertext blocks are: %s'), ciphertext)
    return ciphertext


def decrypt_message(n, d, ciphertext):
    """Decrypt a message with the private key

    Every block is decrypted and unpacked into bytes, the trailing zero
    padding is removed, and the remaining bytes are decoded as UTF-8.

    Zero bytes inside the message are preserved. Zero bytes at the very end
    of a message cannot be told apart from padding and are dropped.
    """
    packed_plaintext = list(map(functools.partial(decrypt_block, n, d),
                                ciphertext))
    logging.debug(u('Plaintext blocks are: %s'), packed_plaintext)
    unpacked_plaintext = map(unpack_block, packed_plaintext)
    encoded = list(itertools.chain.from_iterable(unpacked_plaintext))
    # Only trailing zeros can be padding; stopping at the first zero byte
    # would cut off a message that contains a NUL character.
    while encoded and encoded[-1] == 0:
        encoded.pop()
    logging.debug(u('Encoded plaintext is: %s'), encoded)
    plaintext = bytearray(encoded).decode('utf_8')
    return plaintext


def main():
    """Generate new keys and encrypt a message

    The program retrieves the message from the command-line arguments, if they
    exist. Otherwise it prompts the user for input.
    """
    logging.basicConfig(level=logging.DEBUG)
    (n, e, d) = generate_key()
    logging.warning(u('Public key is { n=%d, e=%d }'), n, e)
    logging.warning(u('Private key is { n=%d, d=%d }'), n, d)
    args = sys.argv[1:]
    if args:
        logging.debug(u('Retrieving plaintext from the command-line arguments'))
        messages = map(decode_input, args)
    else:
        logging.debug(u('Prompting the user for plaintext'))
        message = decode_input(console_input(u('What is the message? ')))
        messages = [message]
    for plaintext in messages:
        print(u('Plaintext message: %s') % plaintext)
        ciphertext = encrypt_message(n, e, plaintext)
        print(u('Encrypted message: %s') % ciphertext)
        roundtrip = decrypt_message(n, d, ciphertext)
        logging.info(u('Decrypted message: %s'), roundtrip)


if __name__ == "__main__":
    main()