# Arup Guha
# 10/28/2020
# Example of the Diffie-Hellman Key Exchange

# Edited on 11/6/2020 to use a prime for p.

import random

# Runs an interactive demonstration of the Diffie-Hellman key exchange.
#
# Reads a rough size for the prime and both secret keys from standard input,
# then prints the public values, what each side sends, and what each side
# computes. The two values printed at the end should be the same number.
#
# Raises ValueError if an entry is not an integer or a secret key is negative.
def main():

    # Generate public keys. The base is drawn from [2, p-2], which is empty
    # unless p is at least 5, so smaller requests are raised to 5 before the
    # prime search starts.
    rough = int(input("Enter a rough value for your prime number.\n"))
    p = getnextprime(max(rough, 5))
    print(p,"was the prime chosen.")
    a = random.randint(2,p-2)
    print(a,"was the base a chosen.")

    # Get private keys.
    alice = int(input("Enter Alice's secret key in between 1 and "+str(p)+"\n"))
    bob = int(input("Enter Bob's secret key in between 1 and "+str(p)+"\n"))

    # The keys are used as exponents, and the modular exponentiation below
    # only handles exponents of 0 or more, so stop here with a clear message.
    if alice < 0 or bob < 0:
        raise ValueError("Secret keys must not be negative.")

    # Show what Alice sends: the base raised to her secret key.
    C1 = fastmodexpo(a, alice, p)
    print(C1,"is what Alice sends to Bob.")

    # Show what Bob sends: the base raised to his secret key.
    C2 = fastmodexpo(a, bob, p)
    print(C2,"is what Bob sends to Alice.")

    # Show what Alice computes from Bob's message and her own secret key.
    aliceAns = fastmodexpo(C2, alice, p)
    print(aliceAns,"is what Alice calculates.")

    # And what Bob computes from Alice's message and his own secret key.
    bobAns = fastmodexpo(C1, bob, p)
    print(bobAns,"is what Bob calculates.")

# Returns base to the power power mod mod.
#
# base  - integer being raised to a power.
# power - integer exponent, 0 or more.
# mod   - non-zero integer modulus.
#
# Raises ValueError if power is negative.
def fastmodexpo(base,power,mod):

    # A negative exponent would never count down to 0, so reject it up front.
    if power < 0:
        raise ValueError("power must be non-negative")

    # Anything to the 0 power is 1; the %mod turns that into 0 when mod is 1.
    result = 1%mod
    square = base%mod
    remaining = power

    # Square and multiply, walking the bits of the exponent from least to
    # most significant. This is a loop rather than recursion so that
    # exponents with hundreds of bits stay clear of Python's recursion limit.
    while remaining > 0:

        # This bit is set, so the current square belongs in the answer.
        if remaining%2 == 1:
            result = (result*square)%mod

        # Move on to the next bit.
        square = (square*square)%mod
        remaining //= 2

    return result

# Returns the next prime greater than or equal to n.
#
# n - integer starting point of the search; anything below 2 yields 2.
#
# Results of 5 or more are only probably prime: each one is accepted after
# passing 100 rounds of isprobablyprime.
def getnextprime(n):

    # 2 and 3 are answered directly. Bumping an even n up to the next odd
    # number would skip over 2, and the Miller Rabin test picks its base from
    # [2, n-2], which is empty for anything less than 4.
    if n <= 2:
        return 2
    if n == 3:
        return 3

    # From here on only odd numbers can be prime, so start on an odd one.
    candidate = n+1 if n%2 == 0 else n

    # Now, keep on trying until we find one.
    while not isprobablyprime(candidate, 100):
        candidate += 2

    # Ta da!
    return candidate

# Returns the gcd of a and b.
#
# a, b - integers. gcd(a, 0) is a itself. The sign of the result follows
#        Python's % operator, so it can be negative when an input is.
def gcd(a,b):

    # Euclid's algorithm, written as a loop rather than recursion so that
    # very large inputs cannot run into Python's recursion limit.
    while b != 0:
        a, b = b, a%b

    return a

# Runs one round of the Miller Rabin primality test on n with a random base.
#
# n - integer to test.
#
# Returns False when n is definitely not prime. Returns True when n is prime,
# or when a composite n happened to pass this one round; running more rounds
# makes that less likely.
def millerrabin(n):

    # Numbers below 4 leave no room to pick a base from [2, n-2], so they are
    # answered directly: 2 and 3 are prime, anything smaller is not.
    if n < 4:
        return n == 2 or n == 3

    # Every other even number is composite. Without this check an even n
    # could pass on an unlucky base (9 to the 27th power is 1 mod 28).
    if n%2 == 0:
        return False

    # Choose random base for Miller Rabin.
    a = random.randint(2,n-2)

    # Divide out 2 as many times as possible from n-1, leaving
    # n-1 = baseexp * 2^k with baseexp odd.
    baseexp = n-1
    k = 0
    while baseexp%2 == 0:
        baseexp //= 2
        k += 1

    # Calculate first exponentiation. The built-in three argument pow does
    # the modular exponentiation iteratively, so large n is not a problem.
    curValue = pow(a, baseexp, n)

    # Here we say it's probably prime.
    if curValue == 1:
        return True

    for _ in range(k):

        # Must happen for all primes, and more rarely for composites.
        if curValue == n-1:
            return True

        # We just square it to get to the next number in the sequence.
        curValue = (curValue*curValue)%n

    # If we get here, it must be composite.
    return False

# Runs up to numTimes rounds of Miller Rabin on n.
#
# Stops and returns False at the first round that fails. Returns True if
# every round passes, which is also the answer when numTimes is 0 or less
# because no rounds are run at all.
def isprobablyprime(n, numTimes):

    # all() quits at the first failed round, so no extra rounds are run once
    # the answer is known.
    return all(millerrabin(n) for _ in range(numTimes))

# Run it! Only when this file is executed as a script, so that importing it
# for its helper functions does not start the interactive demo.
if __name__ == "__main__":
    main()
    
