// Arup Guha
// 10/21/2010, partially written for 2010 AES assignment.
// Completed on 9/27/2013 for class example.

/*** Comment added on 10/4/2024

	 This implementation works, but it is important to understand how it expects its input.
	 The input is stored in column-major order: byte k of the input goes in plain[k % 4][k / 4]. So the first byte of input
	 is in plain[0][0], the second byte in plain[1][0], the third byte in plain[2][0], the fourth byte in plain[3][0], the
	 fifth byte in plain[0][1], and so on. The key uses the same layout. When we print, we print the round keys as they
	 appear in the AES paper, with the second byte directly below the first byte, and a blank line between consecutive
	 round keys. At the very end, the ciphertext is printed, followed by the recovered plaintext.
	 
	 When the state matrix is printed, the second byte is in the second row, right below the first byte, and so on.
	 The file AES-Example.pdf (uploaded alongside this program) shows a step-by-step trace of encrypting a particular
	 plaintext with a particular key. That plaintext and key are hard-coded in main in the format the program expects.

***/

/**
 * A teaching implementation of AES-128 (10 rounds) that encrypts or decrypts a single 16-byte block.
 *
 * <p>Both the state and the cipher key are 4x4 byte matrices indexed as {@code [row][column]}.
 * Input bytes are laid out column by column, so byte {@code k} of a block lives at {@code [k % 4][k / 4]}.
 *
 * <p>Note: constructing an instance (through {@link #keyExpansion(byte[][])}) prints every round key
 * to standard output.
 */
public class AES {

	public static final int SIZE = 4;
	public static final int ROUNDS = 10;

	// The 4x4 state matrix, [row][column]. encrypt()/decrypt() transform it in place.
	private byte[][] state;

	// Expanded key schedule: SIZE rows by SIZE*(ROUNDS+1) columns. Column w holds key word w,
	// so the round key for round r occupies columns SIZE*r .. SIZE*r + SIZE-1.
	private byte[][] keys;

	// Forward S-box, indexed by the unsigned value (0..255) of the input byte.
	private static final int[] SBOX =

		{0x63 ,0x7c ,0x77 ,0x7b ,0xf2 ,0x6b ,0x6f ,0xc5 ,0x30 ,0x01 ,0x67 ,0x2b ,0xfe ,0xd7 ,0xab ,0x76
		,0xca ,0x82 ,0xc9 ,0x7d ,0xfa ,0x59 ,0x47 ,0xf0 ,0xad ,0xd4 ,0xa2 ,0xaf ,0x9c ,0xa4 ,0x72 ,0xc0
		,0xb7 ,0xfd ,0x93 ,0x26 ,0x36 ,0x3f ,0xf7 ,0xcc ,0x34 ,0xa5 ,0xe5 ,0xf1 ,0x71 ,0xd8 ,0x31 ,0x15
		,0x04 ,0xc7 ,0x23 ,0xc3 ,0x18 ,0x96 ,0x05 ,0x9a ,0x07 ,0x12 ,0x80 ,0xe2 ,0xeb ,0x27 ,0xb2 ,0x75
		,0x09 ,0x83 ,0x2c ,0x1a ,0x1b ,0x6e ,0x5a ,0xa0 ,0x52 ,0x3b ,0xd6 ,0xb3 ,0x29 ,0xe3 ,0x2f ,0x84
		,0x53 ,0xd1 ,0x00 ,0xed ,0x20 ,0xfc ,0xb1 ,0x5b ,0x6a ,0xcb ,0xbe ,0x39 ,0x4a ,0x4c ,0x58 ,0xcf
		,0xd0 ,0xef ,0xaa ,0xfb ,0x43 ,0x4d ,0x33 ,0x85 ,0x45 ,0xf9 ,0x02 ,0x7f ,0x50 ,0x3c ,0x9f ,0xa8
		,0x51 ,0xa3 ,0x40 ,0x8f ,0x92 ,0x9d ,0x38 ,0xf5 ,0xbc ,0xb6 ,0xda ,0x21 ,0x10 ,0xff ,0xf3 ,0xd2
		,0xcd ,0x0c ,0x13 ,0xec ,0x5f ,0x97 ,0x44 ,0x17 ,0xc4 ,0xa7 ,0x7e ,0x3d ,0x64 ,0x5d ,0x19 ,0x73
		,0x60 ,0x81 ,0x4f ,0xdc ,0x22 ,0x2a ,0x90 ,0x88 ,0x46 ,0xee ,0xb8 ,0x14 ,0xde ,0x5e ,0x0b ,0xdb
		,0xe0 ,0x32 ,0x3a ,0x0a ,0x49 ,0x06 ,0x24 ,0x5c ,0xc2 ,0xd3 ,0xac ,0x62 ,0x91 ,0x95 ,0xe4 ,0x79
		,0xe7 ,0xc8 ,0x37 ,0x6d ,0x8d ,0xd5 ,0x4e ,0xa9 ,0x6c ,0x56 ,0xf4 ,0xea ,0x65 ,0x7a ,0xae ,0x08
		,0xba ,0x78 ,0x25 ,0x2e ,0x1c ,0xa6 ,0xb4 ,0xc6 ,0xe8 ,0xdd ,0x74 ,0x1f ,0x4b ,0xbd ,0x8b ,0x8a
		,0x70 ,0x3e ,0xb5 ,0x66 ,0x48 ,0x03 ,0xf6 ,0x0e ,0x61 ,0x35 ,0x57 ,0xb9 ,0x86 ,0xc1 ,0x1d ,0x9e
		,0xe1 ,0xf8 ,0x98 ,0x11 ,0x69 ,0xd9 ,0x8e ,0x94 ,0x9b ,0x1e ,0x87 ,0xe9 ,0xce ,0x55 ,0x28 ,0xdf
		,0x8c ,0xa1 ,0x89 ,0x0d ,0xbf ,0xe6 ,0x42 ,0x68 ,0x41 ,0x99 ,0x2d ,0x0f ,0xb0 ,0x54 ,0xbb ,0x16};

	// Inverse S-box (undoes SBOX), indexed by the unsigned value (0..255) of the input byte.
	private static final int[] SBOX_INV =

		{0x52 ,0x09 ,0x6a ,0xd5 ,0x30 ,0x36 ,0xa5 ,0x38 ,0xbf ,0x40 ,0xa3 ,0x9e ,0x81 ,0xf3 ,0xd7 ,0xfb
		,0x7c ,0xe3 ,0x39 ,0x82 ,0x9b ,0x2f ,0xff ,0x87 ,0x34 ,0x8e ,0x43 ,0x44 ,0xc4 ,0xde ,0xe9 ,0xcb
		,0x54 ,0x7b ,0x94 ,0x32 ,0xa6 ,0xc2 ,0x23 ,0x3d ,0xee ,0x4c ,0x95 ,0x0b ,0x42 ,0xfa ,0xc3 ,0x4e
		,0x08 ,0x2e ,0xa1 ,0x66 ,0x28 ,0xd9 ,0x24 ,0xb2 ,0x76 ,0x5b ,0xa2 ,0x49 ,0x6d ,0x8b ,0xd1 ,0x25
		,0x72 ,0xf8 ,0xf6 ,0x64 ,0x86 ,0x68 ,0x98 ,0x16 ,0xd4 ,0xa4 ,0x5c ,0xcc ,0x5d ,0x65 ,0xb6 ,0x92
		,0x6c ,0x70 ,0x48 ,0x50 ,0xfd ,0xed ,0xb9 ,0xda ,0x5e ,0x15 ,0x46 ,0x57 ,0xa7 ,0x8d ,0x9d ,0x84
		,0x90 ,0xd8 ,0xab ,0x00 ,0x8c ,0xbc ,0xd3 ,0x0a ,0xf7 ,0xe4 ,0x58 ,0x05 ,0xb8 ,0xb3 ,0x45 ,0x06
		,0xd0 ,0x2c ,0x1e ,0x8f ,0xca ,0x3f ,0x0f ,0x02 ,0xc1 ,0xaf ,0xbd ,0x03 ,0x01 ,0x13 ,0x8a ,0x6b
		,0x3a ,0x91 ,0x11 ,0x41 ,0x4f ,0x67 ,0xdc ,0xea ,0x97 ,0xf2 ,0xcf ,0xce ,0xf0 ,0xb4 ,0xe6 ,0x73
		,0x96 ,0xac ,0x74 ,0x22 ,0xe7 ,0xad ,0x35 ,0x85 ,0xe2 ,0xf9 ,0x37 ,0xe8 ,0x1c ,0x75 ,0xdf ,0x6e
		,0x47 ,0xf1 ,0x1a ,0x71 ,0x1d ,0x29 ,0xc5 ,0x89 ,0x6f ,0xb7 ,0x62 ,0x0e ,0xaa ,0x18 ,0xbe ,0x1b
		,0xfc ,0x56 ,0x3e ,0x4b ,0xc6 ,0xd2 ,0x79 ,0x20 ,0x9a ,0xdb ,0xc0 ,0xfe ,0x78 ,0xcd ,0x5a ,0xf4
		,0x1f ,0xdd ,0xa8 ,0x33 ,0x88 ,0x07 ,0xc7 ,0x31 ,0xb1 ,0x12 ,0x10 ,0x59 ,0x27 ,0x80 ,0xec ,0x5f
		,0x60 ,0x51 ,0x7f ,0xa9 ,0x19 ,0xb5 ,0x4a ,0x0d ,0x2d ,0xe5 ,0x7a ,0x9f ,0x93 ,0xc9 ,0x9c ,0xef
		,0xa0 ,0xe0 ,0x3b ,0x4d ,0xae ,0x2a ,0xf5 ,0xb0 ,0xc8 ,0xeb ,0xbb ,0x3c ,0x83 ,0x53 ,0x99 ,0x61
		,0x17 ,0x2b ,0x04 ,0x7e ,0xba ,0x77 ,0xd6 ,0x26 ,0xe1 ,0x69 ,0x14 ,0x63 ,0x55 ,0x21 ,0x0c ,0x7d};

	// MixColumns matrix; each state column is multiplied by it in GF(2^8).
	private static final byte[][] MIXCOL = { {2, 3, 1, 1}, {1, 2, 3, 1}, {1, 1, 2, 3}, {3, 1, 1, 2}};

	// Inverse MixColumns matrix, used by invMixColumns().
	private static final byte[][] INVMIXCOL = { {(byte)0x0e, (byte)0x0b, (byte)0x0d, (byte)0x09},
												{(byte)0x09, (byte)0x0e, (byte)0x0b, (byte)0x0d},
												{(byte)0x0d, (byte)0x09, (byte)0x0e, (byte)0x0b},
												{(byte)0x0b, (byte)0x0d, (byte)0x09, (byte)0x0e}
											  };

	// Round constants, indexed by word/SIZE. Entry 0 is a placeholder: keyExpansion only
	// reads RCON[i/SIZE] for i >= SIZE, i.e. indices 1..ROUNDS.
	private static final byte[] RCON = {(byte)0x00, (byte)0x01, (byte)0x02, (byte)0x04, (byte)0x08, (byte)0x10,
										(byte)0x20, (byte)0x40, (byte)0x80, (byte)0x1b, (byte)0x36
									   };

	/**
	 * Creates a cipher for one block and expands (and prints) the key schedule.
	 *
	 * @param initState the 4x4 input block (plaintext or ciphertext); it is copied, so the caller's array is never modified
	 * @param initKey   the 4x4 cipher key
	 * @throws NullPointerException     if either matrix, or one of its rows, is null
	 * @throws IllegalArgumentException if either matrix is not SIZE x SIZE
	 */
	public AES(byte[][] initState, byte[][] initKey) {

		requireBlock(initState, "initState");

		// Work on a private copy so encrypt()/decrypt() never touch the caller's array.
		state = copy(initState);

		// One column per key word: SIZE words for each of the ROUNDS+1 round keys.
		keys = new byte[SIZE][SIZE * (ROUNDS + 1)];
		keyExpansion(initKey);
	}

	// Throws unless matrix is a non-null SIZE x SIZE matrix with non-null rows.
	private static void requireBlock(byte[][] matrix, String name) {
		if (matrix == null)
			throw new NullPointerException(name + " must not be null");
		if (matrix.length != SIZE)
			throw new IllegalArgumentException(name + " must have " + SIZE + " rows, got " + matrix.length);
		for (int row = 0; row < SIZE; row++) {
			if (matrix[row] == null)
				throw new NullPointerException(name + "[" + row + "] must not be null");
			if (matrix[row].length != SIZE)
				throw new IllegalArgumentException(name + "[" + row + "] must have " + SIZE
						+ " columns, got " + matrix[row].length);
		}
	}

	/**
	 * Expands the 4x4 cipher key into the full AES-128 key schedule and then prints it.
	 *
	 * @param initKey the 4x4 cipher key, [row][column]
	 * @throws NullPointerException     if initKey, or one of its rows, is null
	 * @throws IllegalArgumentException if initKey is not SIZE x SIZE
	 */
	public void keyExpansion(byte[][] initKey) {

		requireBlock(initKey, "initKey");

		// Words 0..SIZE-1 are the cipher key itself.
		for (int row = 0; row < SIZE; row++)
			System.arraycopy(initKey[row], 0, keys[row], 0, SIZE);

		// Derive every remaining word from the previous word and the word SIZE positions back.
		// (For AES-128 the key length in words, Nk, equals SIZE.)
		int totalWords = SIZE * (ROUNDS + 1);
		for (int i = SIZE; i < totalWords; i++) {

			// Start from the previous word.
			byte[] temp = new byte[SIZE];
			for (int j = 0; j < SIZE; j++)
				temp[j] = keys[j][i - 1];

			// First word of each round key: RotWord, SubWord, then XOR the round constant.
			if (i % SIZE == 0) {
				rotWord(temp);
				for (int j = 0; j < SIZE; j++)
					temp[j] = subBytes(temp[j]);

				temp[0] = add(temp[0], RCON[i / SIZE]);
			}

			// XOR with the word SIZE positions back and store as word i.
			for (int j = 0; j < SIZE; j++)
				keys[j][i] = add(temp[j], keys[j][i - SIZE]);
		}

		printKeys();
	}

	/**
	 * Rotates b left by one byte in place, e.g. [a, b, c, d] becomes [b, c, d, a].
	 * An empty array is left unchanged.
	 */
	public static void rotWord(byte[] b) {
		if (b.length == 0)
			return;
		byte first = b[0];
		System.arraycopy(b, 1, b, 0, b.length - 1);
		b[b.length - 1] = first;
	}

	/**
	 * Returns a deep copy of matrix: a new outer array with every row cloned.
	 * Jagged and empty matrices are copied faithfully.
	 */
	public static byte[][] copy(byte[][] matrix) {

		byte[][] mine = new byte[matrix.length][];

		for (int i = 0; i < matrix.length; i++)
			mine[i] = matrix[i].clone();

		return mine;
	}

	/**
	 * XORs the round key for the given round (0..ROUNDS) into the state.
	 */
	public void addRoundKey(int round) {

		// Round key r occupies columns SIZE*r .. SIZE*r + SIZE-1 of the schedule.
		int offset = SIZE * round;
		for (int i = 0; i < state.length; i++)
			for (int j = 0; j < state[i].length; j++)
				state[i][j] = add(state[i][j], keys[i][offset + j]);
	}

	/** Adds two bytes in GF(2^8), which is plain XOR. */
	public static byte add(byte a, byte b) {
		return (byte)(a ^ b);
	}

	/**
	 * Returns b * 2 in GF(2^8). When the high bit of b is set, the shifted value is reduced
	 * by XOR with 0x1b (the AES polynomial x^8 + x^4 + x^3 + x + 1 without its x^8 term).
	 */
	public static byte mult2(byte b) {
		if (b < 0)
			return (byte)( ((0x7f & b) << 1) ^ 0x1b );
		return (byte)(b << 1);
	}

	/**
	 * Multiplies a and b in GF(2^8) with the shift-and-add method. Each set bit of a
	 * contributes the matching power-of-two multiple of b.
	 */
	public static byte mult(byte a, byte b) {

		byte pow2 = b;   // b * 2^bit for the bit of a currently being examined
		byte ans = 0;

		// Loop until every set bit of a has been consumed (at most 8 iterations).
		while (a != 0) {

			if ((a & 1) != 0)
				ans = add(ans, pow2);

			// Logical shift: treat a as unsigned so the sign bit does not smear in.
			a = (byte)((a & 0xff) >>> 1);

			pow2 = mult2(pow2);
		}

		return ans;
	}

	/** ShiftRows: cyclically shifts row i of the state left by i positions. */
	public void shiftRows() {

		for (int i = 0; i < state.length; i++) {

			byte[] newRow = new byte[SIZE];

			for (int j = 0; j < SIZE; j++)
				newRow[j] = state[i][(i + j) % SIZE];

			state[i] = newRow;
		}
	}

	/** InvShiftRows: cyclically shifts row i of the state right by i positions, undoing shiftRows(). */
	public void invShiftRows() {

		for (int i = 0; i < state.length; i++) {

			byte[] newRow = new byte[SIZE];

			for (int j = 0; j < SIZE; j++)
				newRow[(i + j) % SIZE] = state[i][j];

			state[i] = newRow;
		}
	}

	/** SubBytes: replaces every state byte with its S-box value. */
	public void subBytes() {
		for (int i = 0; i < state.length; i++)
			for (int j = 0; j < state[i].length; j++)
				state[i][j] = subBytes(i, j);
	}

	/** InvSubBytes: replaces every state byte with its inverse S-box value. */
	public void invSubBytes() {
		for (int i = 0; i < state.length; i++)
			for (int j = 0; j < state[i].length; j++)
				state[i][j] = invSubBytes(i, j);
	}

	/**
	 * Looks up the S-box value for a byte.
	 *
	 * @param index a byte value, either signed (-128..127) or unsigned (0..255)
	 */
	public byte subBytes(int index) {

		// Map signed byte values onto 0..255.
		if (index < 0)
			index += 256;

		return convert(SBOX[index]);
	}

	/**
	 * Looks up the inverse S-box value for a byte.
	 *
	 * @param index a byte value, either signed (-128..127) or unsigned (0..255)
	 */
	public byte invSubBytes(int index) {

		// Map signed byte values onto 0..255.
		if (index < 0)
			index += 256;

		return convert(SBOX_INV[index]);
	}

	/** Returns the S-box value of state[i][j] without modifying the state. */
	public byte subBytes(int i, int j) {
		return subBytes(state[i][j]);
	}

	/** Returns the inverse S-box value of state[i][j] without modifying the state. */
	public byte invSubBytes(int i, int j) {
		return invSubBytes(state[i][j]);
	}

	/**
	 * Converts an int between 0 and 255 to the byte with the same bit pattern
	 * (a value between -128 and 127).
	 */
	public static byte convert(int value) {
		if (value < 128)
			return (byte)value;
		else
			return (byte)(value-256);
	}

	/** Returns a deep copy of the current state matrix, [row][column]. */
	public byte[][] getState() {
		return copy(state);
	}

	/**
	 * Prints the state matrix as the AES authors draw it: one row per line, with
	 * two-digit hex bytes, followed by a blank line.
	 */
	public void print() {

		for (byte[] row : state) {
			for (byte value : row)
				System.out.printf("%02x ", value);
			System.out.printf("\n");
		}
		System.out.println();
	}

	/**
	 * Prints every round key as a 4x4 matrix of two-digit hex bytes, with a blank line
	 * after each round key and one more at the end.
	 */
	public void printKeys() {

		for (int start = 0; start < keys[0].length; start += SIZE) {

			// One round key: rows of the schedule restricted to this round's columns.
			for (int row = 0; row < SIZE; row++) {
				for (int col = 0; col < SIZE; col++)
					System.out.printf("%02x ", keys[row][start + col]);
				System.out.printf("\n");
			}

			// Separate consecutive round keys.
			System.out.printf("\n");
		}
		System.out.printf("\n");
	}

	/** MixColumns: multiplies every state column by MIXCOL in GF(2^8). */
	public void mixColumns() {
		multiplyState(MIXCOL);
	}

	/** InvMixColumns: multiplies every state column by INVMIXCOL, undoing mixColumns(). */
	public void invMixColumns() {
		multiplyState(INVMIXCOL);
	}

	// Replaces the state with matrix * state, computed in GF(2^8).
	private void multiplyState(byte[][] matrix) {

		byte[][] product = new byte[SIZE][SIZE];

		for (int i = 0; i < SIZE; i++) {
			for (int j = 0; j < SIZE; j++) {
				byte ans = 0;
				for (int k = 0; k < SIZE; k++)
					ans = add(ans, mult(matrix[i][k], state[k][j]));
				product[i][j] = ans;
			}
		}

		state = product;
	}

	/** Encrypts the state in place with the expanded key (FIPS-197 Cipher). */
	public void encrypt() {

		// Initial key whitening before the first round.
		addRoundKey(0);

		for (int round = 1; round <= ROUNDS; round++) {
			subBytes();
			shiftRows();

			// The final round omits MixColumns.
			if (round < ROUNDS)
				mixColumns();

			addRoundKey(round);
		}
	}

	/** Decrypts the state in place with the expanded key (FIPS-197 InvCipher). */
	public void decrypt() {

		// Undo the last round key first.
		addRoundKey(ROUNDS);

		// Walk the round keys from ROUNDS-1 down to 0.
		for (int round = ROUNDS - 1; round >= 0; round--) {

			invShiftRows();
			invSubBytes();
			addRoundKey(round);

			// Mirrors encrypt(): there is no InvMixColumns after the final round key.
			if (round > 0)
				invMixColumns();
		}
	}

	public static void main(String[] args) {
		
		/***
		
		Plaintext (in bytes) = 01 23 45 67 89 ab cd ef fe dc ba 98 76 54 32 10
		Key (in bytes)       = 0f 15 71 c9 47 d9 e8 59 0c b7 ad d6 af 7f 67 98
		Here is how to visualize it:
		Plaintext:						Key:
		
		[01 89 fe 76]					[0f 47 0c af]
		[23 ab dc 54]					[15 d9 b7 7f]
		[45 cd ba 32]					[71 e8 ad 67]
		[67 ef 98 10]					[c9 59 d6 98]
		
		***/

		byte[][] plain = { {(byte)0x01, (byte)0x89, (byte)0xfe, (byte)0x76}, {(byte)0x23, (byte)0xab, (byte)0xdc, (byte)0x54},
						   {(byte)0x45, (byte)0xcd, (byte)0xba, (byte)0x32}, {(byte)0x67, (byte)0xef, (byte)0x98, (byte)0x10} };
		byte[][] key =   { {(byte)0x0f, (byte)0x47, (byte)0x0c, (byte)0xaf}, {(byte)0x15, (byte)0xd9, (byte)0xb7, (byte)0x7f},
						   {(byte)0x71, (byte)0xe8, (byte)0xad, (byte)0x67}, {(byte)0xc9, (byte)0x59, (byte)0xd6, (byte)0x98}};

		AES box = new AES(plain, key);

		// Ciphertext.
		box.encrypt();
		box.print();
		System.out.println();

		// Recovered plaintext.
		box.decrypt();
		box.print();
	}

}