# Arup Guha
# Illustration of El-Gamal
# 11/4/2022

import random

def main():

    # Key material (private key Xa and the per-message ephemeral k) must be
    # unpredictable, so draw it from the OS CSPRNG rather than the
    # Mersenne-Twister-based random module.
    import secrets
    rng = secrets.SystemRandom()

    # Ask user for prime. Values below 5 are raised to 5 so that the range
    # [2, p-2] used for Xa and k below is non-empty.
    p = int(input("Enter the approximate value of the prime number for your El Gamal key.\n"))
    p = getnextprime(max(p, 5))
    g = getRndPrimRoot(p)
    Xa = rng.randint(2, p-2)
    Ya = pow(g, Xa, p)

    # Print out public keys.
    print("Posting public keys.")
    print("prime =",p)
    print("gen   =",g)
    print("Ya    =",Ya)

    # Get message. It must be a residue mod p; a larger value would be
    # reduced mod p and "recovered" as a different number, so re-prompt.
    while True:
        M = int(input("Please enter your message in between 0 and "+str(p-1)+"\n"))
        if 0 <= M < p:
            break
        print("Message must be in between 0 and", p-1)

    # Pick a fresh random k to send this message.
    k = rng.randint(2, p-2)

    # Ciphertexts: c1 = g^k, c2 = M * Ya^k (all mod p).
    c1 = pow(g,k,p)
    c2 = (pow(Ya,k,p)*M)%p

    # Print Ciphertexts.
    print("C1 =",c1)
    print("C2 =",c2)

    # To decrypt: c1^Xa = g^(k*Xa) = Ya^k, so multiplying c2 by its inverse
    # mod p gives back M.
    temp = pow(c1,Xa,p)
    temp = modInv(temp, p)
    print("K inv for recovery is",temp)

    # Print recovered text.
    recover = (temp*c2)%p
    print("Recovered", recover)

    
# Returns the distinct prime divisors of n in increasing order, found by
# trial division up to the square root of the (shrinking) remainder.
# Returns [] for n < 2.
def getPrimeDiv(n):
    primes = []
    d = 2
    while d * d <= n:
        if n % d == 0:
            # d must be prime: every smaller prime was already divided out.
            primes.append(d)
            while n % d == 0:
                n //= d
        d += 1

    # A remainder above 1 has no divisor <= its square root, so it is prime.
    if n > 1:
        primes.append(n)
    return primes


# Returns a random primitive root mod p. p is assumed to be prime (not
# checked here).
#
# A candidate b is primitive exactly when b^((p-1)/q) != 1 (mod p) for every
# prime q dividing p-1. Candidates are drawn from [2, p-2], because 1 and p-1
# have order 1 and 2 and are never primitive for p > 3. For p = 2 and p = 3
# that range is empty (random.randint would raise), so they are answered
# directly.
def getRndPrimRoot(p):

    # Tiny primes: the primitive root mod 2 is 1, and mod 3 it is 2.
    if p == 2:
        return 1
    if p == 3:
        return 2

    # Get all prime divisors of p-1.
    primediv = getPrimeDiv(p-1)

    # Draw candidates until one is not killed by any prime divisor of p-1.
    while True:
        b = random.randint(2,p-2)
        if all(pow(b, (p-1)//x, p) != 1 for x in primediv):
            return b

# Returns the next prime greater than or equal to n, using isprobablyprime
# with 100 Miller-Rabin rounds for each odd candidate.
#
# n <= 3 is answered directly. 2 is the only even prime, so the odd-only
# search below would skip it. Handling 3 here means the randomized test only
# ever sees candidates >= 5.
def getnextprime(n):

    if n <= 3:
        return max(n, 2)

    # Make n odd; no even number above 2 is prime.
    if n%2 == 0:
        n += 1

    # Now, keep on trying odd candidates until we find one.
    while not isprobablyprime(n, 100):
        n += 2

    return n

# Returns the gcd of a and b via Euclid's algorithm: repeatedly replace
# (a, b) with (b, a mod b) until b reaches 0, then a is the answer.
def gcd(a,b):
    while b != 0:
        a, b = b, a % b
    return a

# Runs one round of the Miller-Rabin primality test on n with a random base.
# Returns False only when n is certainly composite. True means n passed this
# round and is probably prime; composites pass only rarely.
def millerrabin(n):

    # Decide tiny and even inputs directly. The random base is drawn from
    # [2, n-2], which is empty for n < 4 (random.randint would raise), and
    # 2 is the only even prime.
    if n < 4:
        return n >= 2
    if n % 2 == 0:
        return False

    # Choose random base for Miller Rabin.
    a = random.randint(2,n-2)

    # Write n-1 = 2^k * baseexp with baseexp odd.
    baseexp = n-1
    k = 0
    while baseexp%2 == 0:
        baseexp //= 2
        k += 1

    # First term of a^baseexp, a^(2*baseexp), ..., a^(2^(k-1)*baseexp).
    curValue = pow(a, baseexp, n)

    # Here we say it's probably prime.
    if curValue == 1:
        return True

    # Must happen for all primes, and more rarely for composites: some term
    # of the squaring sequence equals -1 (mod n).
    for _ in range(k):
        if curValue == n-1:
            return True
        curValue = (curValue*curValue)%n

    # If we get here, a witnesses that n is composite.
    return False

# Returns True when n passes numTimes independent Miller-Rabin rounds
# (probably prime). Returns False as soon as one round proves n composite.
def isprobablyprime(n, numTimes):
    # all() stops at the first failing round, so no extra rounds are run.
    return all(millerrabin(n) for _ in range(numTimes))

# Extended Euclidean algorithm. Returns a list [x, y, g] where g is the gcd
# of a and b and a*x + b*y == g.
def EEA(a,b):
    # Invariants: r0 == a*s0 + b*t0 and r1 == a*s1 + b*t1.
    r0, r1 = a, b
    s0, s1 = 1, 0
    t0, t1 = 0, 1
    while r1 != 0:
        q = r0 // r1
        r0, r1 = r1, r0 % r1
        s0, s1 = s1, s0 - q*s1
        t0, t1 = t1, t0 - q*t1
    return [s0, t0, r0]

# Returns the modular inverse of x mod n, i.e. y with x*y = 1 (mod n).
# Returns 0 if there is no modular inverse (gcd(x, n) != 1).
#
# For a positive modulus, x is first reduced to its residue mod n. Without
# that step, a negative x can make EEA report a gcd of -1, and an invertible
# x would wrongly get 0 (e.g. x = -3, n = 7).
def modInv(x,n):

    if n > 0:
        x %= n

    # Call the Extended Euclidean: arr[0]*n + arr[1]*x = arr[2].
    arr = EEA(n, x)

    # Indicates that there is no solution.
    if arr[2] != 1:
        return 0

    # Do the wrap around, if necessary: the Bezout coefficient may be negative.
    if arr[1] < 0:
        arr[1] += n

    # This is the modular inverse.
    return arr[1]

# Run it only when executed as a script, so importing this module for its
# number-theory helpers does not start the interactive demo.
if __name__ == "__main__":
    main()
