# Arup Guha
# 9/21/2022
# Program to encrypt using column permutation.
# Edited to add decryption on 9/20/2023
# Added Double Transposition also on 9/20/2023

def getPerm(key):
    """Convert a keyword into the column-read order for a transposition.

    Columns are ordered by their key character. Ties (repeated letters) are
    broken left to right. perm[r] is the index of the column read r-th.

    For an all-uppercase key this gives the same result as scanning 'A'..'Z'
    and collecting matching positions. Any other character (lowercase,
    digits, spaces) is ordered by its code point. This keeps the result a
    valid permutation of range(len(key)) instead of leaving unfilled zero
    entries behind.

    Example: getPerm("ZEBRA") == [4, 2, 1, 3, 0]
    """
    # sorted() is stable, so equal letters keep their left-to-right order.
    return sorted(range(len(key)), key=lambda i: key[i])

# Encrypts msg using the permutation stored in numKey.
def encrypt(msg, numKey):
    """Encrypt msg with a columnar transposition.

    msg is written row by row into len(numKey) columns. The columns are then
    read top to bottom in the order listed in numKey (numKey[0] first).

    Returns the ciphertext string. Each index in numKey contributes one
    column of msg.
    """
    width = len(numKey)

    # msg[col::width] is exactly column `col` read downward.
    return "".join(msg[col::width] for col in numKey)

# Not used by the encrypt/decrypt routines; computes the inverse of numKey.
def getinverse(numKey):
    """Return the inverse permutation of numKey.

    If numKey[p] == v, then the result has result[v] == p, i.e. it maps
    each output position back to the input position it came from.
    """
    inverse = [0] * len(numKey)
    for position, value in enumerate(numKey):
        inverse[value] = position
    return inverse
    
# Decrypt.
def decrypt(cipher, numKey):
    """Invert encrypt(): recover the plaintext of a columnar transposition.

    cipher is consumed in the order encrypt() produced it: the letters of
    column numKey[0], then column numKey[1], and so on. Each run is written
    back down its column of the plaintext grid.

    numKey must be non-empty (its length is used as a divisor).
    Returns the recovered plaintext string.
    """
    width = len(numKey)

    # Every column holds shortCol letters. The first numLong columns (by
    # column index, not by read order) hold one extra letter.
    shortCol = len(cipher) // width
    numLong = len(cipher) % width

    res = [' '] * len(cipher)

    j = 0
    for startC in numKey:
        # How many letters are in this column's group.
        mylen = shortCol + 1 if startC < numLong else shortCol

        # Scatter this run of ciphertext back down column startC.
        for k in range(mylen):
            res[startC + width * k] = cipher[j]
            j += 1

    return "".join(res)

# Double transposition: encrypt under key1, then encrypt that result under key2.
def doubletrans(msg,key1,key2):
    """Apply two columnar transpositions, first with key1, then with key2.

    The intermediate (single-transposition) text is printed with an "e1:"
    prefix before the final ciphertext is returned.
    """
    firstPerm, secondPerm = getPerm(key1), getPerm(key2)
    intermediate = encrypt(msg, firstPerm)
    print("e1:", intermediate)
    return encrypt(intermediate, secondPerm)

# Undo doubletrans by decrypting in reverse key order.
def decryptdoubletrans(cipher,key1,key2):
    """Invert doubletrans(): undo the key2 transposition, then the key1 one."""
    perms = [getPerm(key) for key in (key1, key2)]

    # The last transposition applied must be undone first.
    for perm in reversed(perms):
        cipher = decrypt(cipher, perm)
    return cipher

def main():
    """Interactive demo: double-transposition encrypt a message, then decrypt it."""

    # Get the keys and plaintext. str.strip() returns a new string, so the
    # stripped value must be assigned back to be used.
    key1 = input("Please enter the first key, all uppercase letters.").strip()
    key2 = input("Please enter the second key, all uppercase letters.").strip()
    plain = input("Enter the message all uppercase letters.").strip()

    # Here is the ciphertext.
    cipher = doubletrans(plain, key1, key2)
    print("Encrypted")
    print(cipher)

    # Test recovery: decrypting should reproduce the plaintext.
    recover = decryptdoubletrans(cipher, key1, key2)
    print(recover)

def oldmain():
    """Interactive demo of a single columnar transposition and its inverse."""

    # Get the key. str.strip() returns a new string, so assign it back.
    key = input("Please enter the key, all uppercase letters.").strip()

    # Convert the key to a permutation (computed once).
    numKey = getPerm(key)
    print(numKey)

    plain = input("Enter the message all uppercase letters.").strip()

    # This is the cipher text.
    cipher = encrypt(plain, numKey)
    print(cipher)

    # Test recovery: decrypting should reproduce the plaintext.
    recover = decrypt(cipher, numKey)
    print(recover)

# Run the interactive demo only when executed as a script, not on import.
if __name__ == "__main__":
    main()
