# Arup Guha
# Example of the ElGamal digital signature scheme
# 11/29/2021

import random

def main():
    """Interactive ElGamal signature demo.

    Reads an approximate prime from stdin, generates a key pair, signs a
    random stand-in "message hash", then verifies the signature, printing
    every intermediate value along the way.
    """

    # Get public keys. Clamp to at least 5: for q in {2, 3} the ranges
    # randint(2, q-2) used below (and inside getRndGen) are empty and crash.
    q = int(input("Enter the approximate prime number you want.\n"))
    q = getnextprime(max(q, 5))

    # Pick a random generator instead of asking the user for one.
    alpha = getRndGen(q)

    # Private key xA in [2, q-2]; matching public key yA = alpha^xA mod q.
    xA = random.randint(2, q-2)
    yA = fastmodexpo(alpha, xA, q)

    # Print them.
    print("Public key q =",q)
    print("Public key a =", alpha)
    print("Public key yA =", yA)
    print("Private key xA = ", xA)

    # Message hash value (a random stand-in for a real hash).
    mHash = random.randint(0, q-1)

    # Per-signature nonce k; it must be invertible mod q-1.
    k = random.randint(2, q-2)
    while gcd(k, q-1) != 1:
        k = random.randint(2, q-2)

    # Get modular inverse of k mod q-1 (exponents live mod q-1).
    kInv = modInv(k, q-1)

    # See these.
    print("k and kinverse are", k, kInv)

    # Signature: S1 = alpha^k mod q, S2 = kInv * (mHash - xA*S1) mod (q-1).
    S1 = fastmodexpo(alpha, k, q)

    tmp = (xA*S1)%(q-1)
    S2 = (mHash-tmp+q-1)%(q-1)
    S2 = (S2*kInv)%(q-1)

    # Print the signature.
    print("The signature is")
    print("S1 =", S1)
    print("S2 =", S2)

    # Verification: alpha^mHash must equal yA^S1 * S1^S2 (mod q), since
    # yA^S1 * S1^S2 = alpha^(xA*S1 + k*S2) = alpha^mHash.
    V1 = fastmodexpo(alpha, mHash, q)
    print("V1 is ", V1)

    term1 = fastmodexpo(yA, S1, q)
    term2 = fastmodexpo(S1, S2, q)
    print("Starting verifying, yA^S1 =",term1)
    print("And S1^S2 =", term2)
    V2 = (term1*term2)%q
    print("V2 is",V2)

    if V1 == V2:
        print("Equal, signature verified.")
    else:
        print("Something went wrong.")

    
# Returns the next prime greater than or equal to n.
def getnextprime(n):

    # Make n odd.
    if n%2 == 0:
        n += 1

    # Now, keep on trying until we find one.
    while not isprobablyprime(n, 100):
        n += 2

    # Ta da!
    return n

def fastmodexpo(base, exp, mod):
    """Return (base ** exp) % mod using iterative square-and-multiply.

    The loop runs once per bit of exp, so cryptographic-size exponents do
    not hit Python's recursion limit (the earlier recursive version needed
    about 2 frames per bit).

    Args:
        base: the base; any integer (reduced mod `mod` first).
        exp: a non-negative integer exponent.
        mod: a positive modulus.

    Returns:
        An integer in [0, mod). Note that x**0 % 1 == 0.

    Raises:
        ValueError: if exp is negative.
    """
    if exp < 0:
        raise ValueError("exponent must be non-negative")

    result = 1 % mod
    base %= mod
    while exp > 0:
        # Multiply in the current power of base when this bit of exp is set.
        if exp & 1:
            result = (result * base) % mod
        # Move on to base^(2^(i+1)) for the next bit.
        base = (base * base) % mod
        exp >>= 1
    return result

def gcd(a, b):
    """Return the greatest common divisor of a and b (Euclid's algorithm).

    Iterative, so operands with thousands of Euclid steps (e.g. large
    consecutive Fibonacci numbers or big cryptographic moduli) do not
    exceed Python's recursion limit. Results are identical to the
    classic recursive formulation, including for negative inputs.
    """
    while b != 0:
        a, b = b, a % b
    return a

def millerrabin(n):
    """Run one round of the Miller-Rabin test on n with a random base.

    Returns:
        False if n is definitely composite (or < 2); True if n is probably
        prime. Small and even n are answered directly: the random base is
        drawn from [2, n-2], which is empty for n < 4, and the strong-pseudo-
        prime test below is only valid for odd n.
    """
    if n < 2:
        return False
    if n < 4:
        return True
    if n % 2 == 0:
        return False

    # Choose random base for Miller Rabin.
    a = random.randint(2, n - 2)

    # Write n-1 = baseexp * 2^k with baseexp odd.
    baseexp = n - 1
    k = 0
    while baseexp % 2 == 0:
        baseexp //= 2
        k += 1

    # First term a^baseexp mod n. The builtin three-argument pow is used
    # because it has no recursion-depth limit on large exponents.
    curValue = pow(a, baseexp, n)

    # Here we say it's probably prime.
    if curValue == 1:
        return True

    # Check a^(baseexp * 2^i) for i = 0 .. k-1; hitting n-1 is required
    # for every prime and happens rarely for composites.
    for _ in range(k):
        if curValue == n - 1:
            return True
        curValue = (curValue * curValue) % n

    # If we get here, it must be composite.
    return False

def isprobablyprime(n, numTimes):
    """Return True if n passes numTimes rounds of Miller-Rabin.

    Args:
        n: the integer to test.
        numTimes: number of independent random-base rounds; more rounds
            lower the chance of accepting a composite.

    Returns:
        False as soon as any round proves n composite; True otherwise.
        Values below 4 and even numbers are decided directly, since a
        Miller-Rabin round cannot pick a random base in [2, n-2] for n < 4.
    """
    if n < 2:
        return False
    if n < 4:
        return True
    if n % 2 == 0:
        return False

    # all() short-circuits on the first round that finds a witness.
    return all(millerrabin(n) for _ in range(numTimes))

def getRndGen(p):
    """Return a random generator (primitive root) modulo the prime p.

    Candidates are drawn from [2, p-2]: 1 never generates, and p-1 has
    order 2, so it only generates when p == 3. The tiny primes 2 and 3
    are answered directly because that range is empty for them.

    Raises:
        ValueError: if p < 2 (no such modulus has a generator here).
    """
    if p < 2:
        raise ValueError("p must be a prime >= 2")
    if p == 2:
        return 1
    if p == 3:
        return 2

    # Rejection sampling: keep drawing until isgen accepts one.
    while True:
        x = random.randint(2, p - 2)
        if isgen(p, x):
            return x

def isgen(p, x):
    """Return True if x is a generator (primitive root) modulo the prime p.

    x generates the multiplicative group mod p exactly when
    x^((p-1)/r) != 1 (mod p) for every prime factor r of p-1. The prime
    factors are found by trial division. A single prime factor larger
    than the square root of what remains is left over after the loop, and
    it must be tested too.
    """
    # 0 (mod p) is not in the multiplicative group at all.
    if x % p == 0:
        return False

    order = p - 1
    remaining = order
    div = 2
    while div * div <= remaining:
        if remaining % div == 0:
            # Strip every copy of this prime factor, then test it once.
            while remaining % div == 0:
                remaining //= div
            if pow(x, order // div, p) == 1:
                return False
        div += 1

    # Leftover prime factor of p-1 (e.g. 3 for p = 13, 7 for p = 29).
    if remaining > 1 and pow(x, order // remaining, p) == 1:
        return False

    return True

def EEA(a, b):
    """Extended Euclidean algorithm.

    Returns:
        A list [x, y, g] with a*x + b*y == g, where |g| == gcd(a, b)
        (g is non-negative when a and b are non-negative).

    Iterative forward version, so large operands do not exceed Python's
    recursion limit. It yields exactly the same coefficients as the
    classic recursive back-substitution: both compute the same product of
    quotient matrices.
    """
    # Invariants: a*old_x + b*old_y == old_r and a*x + b*y == r.
    old_r, r = a, b
    old_x, x = 1, 0
    old_y, y = 0, 1
    while r != 0:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_x, x = x, old_x - q * x
        old_y, y = y, old_y - q * y
    return [old_x, old_y, old_r]

def modInv(x, n):
    """Return the modular inverse of x mod n, or 0 if none exists.

    x is first reduced mod n, so negative x and x >= n behave like their
    residue (e.g. modInv(-3, 7) == 2). Mod 1 every value is 0, so 0 is
    returned.

    Raises:
        ValueError: if n < 1.
    """
    if n < 1:
        raise ValueError("modulus must be positive")

    x %= n

    # Iterative extended Euclid on (n, x), tracking only x's coefficient:
    # invariant x*t == r (mod n). Being iterative, it handles operands of
    # any size without hitting the recursion limit.
    prev_r, r = n, x
    prev_t, t = 0, 1
    while r != 0:
        q = prev_r // r
        prev_r, r = r, prev_r - q * r
        prev_t, t = t, prev_t - q * t

    # Indicates that there is no solution.
    if prev_r != 1:
        return 0

    # Wrap the Bezout coefficient into [0, n).
    return prev_t % n

# Start it! Only when run as a script, so the helpers above can be
# imported (e.g. for testing) without triggering the interactive prompt.
if __name__ == "__main__":
    main()
