// Arup Guha
// 10/21/2010
// Used for Solution to CIS 3362 Homework #4 Fall 2010

// Teaching implementation of the AES-128 forward (encryption) building blocks:
// key expansion plus the AddRoundKey, SubBytes, ShiftRows and MixColumns steps.
// The state and the key are both expected to be 4x4 byte matrices; nothing in
// this class validates that, so other shapes lead to array index errors or to
// a partially zero-filled key.
// Bytes are Java signed bytes, so values 0x80..0xff appear as -128..-1.
public class AES {

	// Current 4x4 state matrix, indexed [row][column].
	private byte[][] state;

	// Expanded key schedule: 4 rows by 44 columns, one 4-byte word per column.
	// The round key for round r occupies columns 4r .. 4r+3.
	private byte[][] keys;

	// Forward substitution table, indexed by the unsigned value of a byte.
	private static final int[] s = {
		0x63, 0x7c, 0x77, 0x7b, 0xf2, 0x6b, 0x6f, 0xc5, 0x30, 0x01, 0x67, 0x2b, 0xfe, 0xd7, 0xab, 0x76,
		0xca, 0x82, 0xc9, 0x7d, 0xfa, 0x59, 0x47, 0xf0, 0xad, 0xd4, 0xa2, 0xaf, 0x9c, 0xa4, 0x72, 0xc0,
		0xb7, 0xfd, 0x93, 0x26, 0x36, 0x3f, 0xf7, 0xcc, 0x34, 0xa5, 0xe5, 0xf1, 0x71, 0xd8, 0x31, 0x15,
		0x04, 0xc7, 0x23, 0xc3, 0x18, 0x96, 0x05, 0x9a, 0x07, 0x12, 0x80, 0xe2, 0xeb, 0x27, 0xb2, 0x75,
		0x09, 0x83, 0x2c, 0x1a, 0x1b, 0x6e, 0x5a, 0xa0, 0x52, 0x3b, 0xd6, 0xb3, 0x29, 0xe3, 0x2f, 0x84,
		0x53, 0xd1, 0x00, 0xed, 0x20, 0xfc, 0xb1, 0x5b, 0x6a, 0xcb, 0xbe, 0x39, 0x4a, 0x4c, 0x58, 0xcf,
		0xd0, 0xef, 0xaa, 0xfb, 0x43, 0x4d, 0x33, 0x85, 0x45, 0xf9, 0x02, 0x7f, 0x50, 0x3c, 0x9f, 0xa8,
		0x51, 0xa3, 0x40, 0x8f, 0x92, 0x9d, 0x38, 0xf5, 0xbc, 0xb6, 0xda, 0x21, 0x10, 0xff, 0xf3, 0xd2,
		0xcd, 0x0c, 0x13, 0xec, 0x5f, 0x97, 0x44, 0x17, 0xc4, 0xa7, 0x7e, 0x3d, 0x64, 0x5d, 0x19, 0x73,
		0x60, 0x81, 0x4f, 0xdc, 0x22, 0x2a, 0x90, 0x88, 0x46, 0xee, 0xb8, 0x14, 0xde, 0x5e, 0x0b, 0xdb,
		0xe0, 0x32, 0x3a, 0x0a, 0x49, 0x06, 0x24, 0x5c, 0xc2, 0xd3, 0xac, 0x62, 0x91, 0x95, 0xe4, 0x79,
		0xe7, 0xc8, 0x37, 0x6d, 0x8d, 0xd5, 0x4e, 0xa9, 0x6c, 0x56, 0xf4, 0xea, 0x65, 0x7a, 0xae, 0x08,
		0xba, 0x78, 0x25, 0x2e, 0x1c, 0xa6, 0xb4, 0xc6, 0xe8, 0xdd, 0x74, 0x1f, 0x4b, 0xbd, 0x8b, 0x8a,
		0x70, 0x3e, 0xb5, 0x66, 0x48, 0x03, 0xf6, 0x0e, 0x61, 0x35, 0x57, 0xb9, 0x86, 0xc1, 0x1d, 0x9e,
		0xe1, 0xf8, 0x98, 0x11, 0x69, 0xd9, 0x8e, 0x94, 0x9b, 0x1e, 0x87, 0xe9, 0xce, 0x55, 0x28, 0xdf,
		0x8c, 0xa1, 0x89, 0x0d, 0xbf, 0xe6, 0x42, 0x68, 0x41, 0x99, 0x2d, 0x0f, 0xb0, 0x54, 0xbb, 0x16
	};

	// Inverse substitution table. Not referenced by any method in this class;
	// kept so a decryption path can be added without retyping it.
	private static final int[] inv_s = {
		0x52, 0x09, 0x6a, 0xd5, 0x30, 0x36, 0xa5, 0x38, 0xbf, 0x40, 0xa3, 0x9e, 0x81, 0xf3, 0xd7, 0xfb,
		0x7c, 0xe3, 0x39, 0x82, 0x9b, 0x2f, 0xff, 0x87, 0x34, 0x8e, 0x43, 0x44, 0xc4, 0xde, 0xe9, 0xcb,
		0x54, 0x7b, 0x94, 0x32, 0xa6, 0xc2, 0x23, 0x3d, 0xee, 0x4c, 0x95, 0x0b, 0x42, 0xfa, 0xc3, 0x4e,
		0x08, 0x2e, 0xa1, 0x66, 0x28, 0xd9, 0x24, 0xb2, 0x76, 0x5b, 0xa2, 0x49, 0x6d, 0x8b, 0xd1, 0x25,
		0x72, 0xf8, 0xf6, 0x64, 0x86, 0x68, 0x98, 0x16, 0xd4, 0xa4, 0x5c, 0xcc, 0x5d, 0x65, 0xb6, 0x92,
		0x6c, 0x70, 0x48, 0x50, 0xfd, 0xed, 0xb9, 0xda, 0x5e, 0x15, 0x46, 0x57, 0xa7, 0x8d, 0x9d, 0x84,
		0x90, 0xd8, 0xab, 0x00, 0x8c, 0xbc, 0xd3, 0x0a, 0xf7, 0xe4, 0x58, 0x05, 0xb8, 0xb3, 0x45, 0x06,
		0xd0, 0x2c, 0x1e, 0x8f, 0xca, 0x3f, 0x0f, 0x02, 0xc1, 0xaf, 0xbd, 0x03, 0x01, 0x13, 0x8a, 0x6b,
		0x3a, 0x91, 0x11, 0x41, 0x4f, 0x67, 0xdc, 0xea, 0x97, 0xf2, 0xcf, 0xce, 0xf0, 0xb4, 0xe6, 0x73,
		0x96, 0xac, 0x74, 0x22, 0xe7, 0xad, 0x35, 0x85, 0xe2, 0xf9, 0x37, 0xe8, 0x1c, 0x75, 0xdf, 0x6e,
		0x47, 0xf1, 0x1a, 0x71, 0x1d, 0x29, 0xc5, 0x89, 0x6f, 0xb7, 0x62, 0x0e, 0xaa, 0x18, 0xbe, 0x1b,
		0xfc, 0x56, 0x3e, 0x4b, 0xc6, 0xd2, 0x79, 0x20, 0x9a, 0xdb, 0xc0, 0xfe, 0x78, 0xcd, 0x5a, 0xf4,
		0x1f, 0xdd, 0xa8, 0x33, 0x88, 0x07, 0xc7, 0x31, 0xb1, 0x12, 0x10, 0x59, 0x27, 0x80, 0xec, 0x5f,
		0x60, 0x51, 0x7f, 0xa9, 0x19, 0xb5, 0x4a, 0x0d, 0x2d, 0xe5, 0x7a, 0x9f, 0x93, 0xc9, 0x9c, 0xef,
		0xa0, 0xe0, 0x3b, 0x4d, 0xae, 0x2a, 0xf5, 0xb0, 0xc8, 0xeb, 0xbb, 0x3c, 0x83, 0x53, 0x99, 0x61,
		0x17, 0x2b, 0x04, 0x7e, 0xba, 0x77, 0xd6, 0x26, 0xe1, 0x69, 0x14, 0x63, 0x55, 0x21, 0x0c, 0x7d
	};

	// Coefficient matrix applied to each state column by mixColumns().
	private static final byte[][] mixcol = { {2, 3, 1, 1}, {1, 2, 3, 1}, {1, 1, 2, 3}, {3, 1, 1, 2} };

	// Round constants XORed into the first byte of every fourth key word.
	// -128 is the signed-byte spelling of 0x80; 27 and 54 are 0x1b and 0x36.
	private static final byte[] rcon = {1, 2, 4, 8, 16, 32, 64, -128, 27, 54};

	// Builds a cipher object from a starting state and a key, then expands the
	// key into the full 44-word schedule.
	//   initState: starting state matrix; it is copied, so the caller's array
	//              is never modified.
	//   initKey:   cipher key, copied into the first columns of the schedule.
	public AES(byte[][] initState, byte[][] initKey) {

		// Make a copy of the state.
		state = copy(initState);

		// Allocate space for all the keys.
		keys = new byte[4][44];

		// Copy in the first key.
		for (int i = 0; i < initKey.length; i++)
			for (int j = 0; j < initKey[0].length; j++)
				keys[i][j] = initKey[i][j];

		// Fill in the rest of the keys.
		keyExpansion();
	}

	// Returns a deep copy of matrix. Each row is cloned on its own, so an empty
	// matrix and rows of differing lengths are handled.
	public static byte[][] copy(byte[][] matrix) {

		byte[][] mine = new byte[matrix.length][];

		for (int i = 0; i < matrix.length; i++)
			mine[i] = matrix[i].clone();

		return mine;
	}

	// XORs the round key for the given round into the state. Round r reads
	// columns 4r .. 4r+3 of the schedule, so with 44 columns and a 4-column
	// state the usable rounds are 0 through 10.
	public void addRoundKey(int round) {

		for (int i = 0; i < state.length; i++)
			for (int j = 0; j < state[0].length; j++)
				state[i][j] = add(state[i][j], keys[i][4*round + j]);
	}

	// Addition in the AES field, which is a bitwise XOR.
	public static byte add(byte a, byte b) {
		return (byte)(a ^ b);
	}

	// Returns b * 2, in the field for AES.
	public static byte mult2(byte b) {

		// A negative byte has its top bit set: drop that bit, shift, and
		// reduce by XORing in 0x1b.
		if (b < 0)
			return (byte)( ((0x7f & b) << 1) ^ 0x1b );
		return (byte)(b << 1);
	}

	// Returns the product of a and b, in the field for AES.
	public static byte mult(byte a, byte b) {

		// Walk the bits of a as an unsigned value. Shifting the signed byte
		// right would sign-extend, so any a with its top bit set would settle
		// at -1 and the loop would never finish.
		int bits = a & 0xff;
		byte pow2 = b;
		byte ans = 0;

		// Loop until we've shifted all bits of a over.
		while (bits != 0) {

			// Add in the contribution for this bit of a, if necessary.
			if ((bits & 1) != 0)
				ans = add(ans, pow2);

			// Move over to the next bit of a.
			bits >>>= 1;

			// Multiply pow2 by 2, in the field, accordingly.
			pow2 = mult2(pow2);
		}

		return ans;
	}

	// Transforms the state matrix using the Shift Rows transformation:
	// row i is rotated left by i positions.
	public void shiftRows() {

		// Go to row i.
		for (int i = 0; i < state.length; i++) {

			// Left rotate i times.
			for (int j = 0; j < i; j++)
				leftRotate(state[i]);
		}
	}

	// Rotates array left by one position in place; the first element wraps
	// around to the end. The array must not be empty.
	public static void leftRotate(byte[] array) {

		// Save the leftmost one.
		byte temp = array[0];

		// Move everyone over by one byte.
		for (int i = 0; i < array.length - 1; i++)
			array[i] = array[i+1];

		// Copy the saved one into the last slot.
		array[array.length - 1] = temp;
	}

	// Applies the byte substitution to every entry of the state.
	public void subBytes() {
		for (int i = 0; i < state.length; i++)
			for (int j = 0; j < state[0].length; j++)
				subBytes(i, j);
	}

	// Substitutes the byte at row i, column j in the state matrix.
	public void subBytes(int i, int j) {
		state[i][j] = subBytes(state[i][j]);
	}

	// Returns the byte substitution for b, directly.
	public static byte subBytes(byte b) {

		// Mask to 0..255 so negative byte values index the table correctly.
		return convert(s[b & 0xff]);
	}

	// Converts an int value in between 0 and 255 to the corresponding byte value
	// in between -128 and 127. The narrowing cast keeps the low eight bits,
	// which is exactly that mapping.
	public static byte convert(int value) {
		return (byte) value;
	}

	// For debugging purposes, prints out the state matrix in a grid, one row
	// per line, followed by a blank line.
	public void print() {

		for (int i = 0; i < state.length; i++) {
			for (int j = 0; j < state[i].length; j++)
				System.out.printf("%3x", state[i][j]);
			System.out.printf("\n");
		}
		System.out.println();
	}

	// Executes the Mix Columns step and stores the answer back in the state.
	public void mixColumns() {

		byte[][] temp = new byte[4][4];

		// Go through each row and column in the answer.
		for (int i = 0; i < temp.length; i++) {
			for (int j = 0; j < temp[i].length; j++) {

				// Field dot product of mixcol row i with state column j.
				byte ans = 0;
				for (int k = 0; k < temp.length; k++)
					ans = add(ans, mult(mixcol[i][k], state[k][j]));
				temp[i][j] = ans;
			}
		}

		// Copy the answer back into state only after every entry has been
		// computed, because each entry reads a whole column of the old state.
		for (int i = 0; i < temp.length; i++)
			for (int j = 0; j < temp[i].length; j++)
				state[i][j] = temp[i][j];
	}

	// Fills in words 4 through 43 of the key schedule from the first four.
	public void keyExpansion() {

		// Go through filling in each new word for all the keys.
		// Note: The initial key has already been filled in.
		for (int i = 4; i < keys[0].length; i++) {

			// Copy the previous word.
			byte[] temp = new byte[4];
			for (int j = 0; j < temp.length; j++)
				temp[j] = keys[j][i-1];

			// Every fourth word gets the extra transformation.
			if (i%4 == 0) {

				// rotWord
				leftRotate(temp);

				// subBytes
				for (int j = 0; j < temp.length; j++)
					temp[j] = subBytes(temp[j]);

				// XOR Rcon
				temp[0] = add(temp[0], rcon[(i-4)/4]);
			}

			// XOR between w[i-4] and temp.
			for (int j = 0; j < temp.length; j++)
				keys[j][i] = add(keys[j][i-4], temp[j]);
		}
	}

	// Prints the first numWords words of the key schedule as four lines, one
	// per row of the schedule. numWords must not exceed 44.
	public void printKey(int numWords) {

		for (int i = 0; i < 4; i++) {

			for (int j = 0; j < numWords; j++) {
				System.out.printf("%3x", keys[i][j]);
			}
			System.out.printf("\n");
		}
	}

	/* Prints out answer to question #1, calculating words w[4] through w[11] in
	 * the key schedule. */

	public static void main(String[] args) {

		byte[][] plain = { {0, 4, 8, 12}, {1, 5, 9, 13}, {2, 6, 10, 14}, {3, 7, 11, 15} };
		byte[][] key = { {-1, -1, -1, -1}, {-1, -1, -1, -1}, {-1, -1, -1, -1}, {-1, -1, -1, -1}};
		AES box = new AES(plain, key);

		box.printKey(12);

	}


	/* Prints out answer to question #2, running one round of AES. 
	public static void main(String[] args) {
		
		byte[][] plain = { {0, 4, 8, 12}, {1, 5, 9, 13}, {2, 6, 10, 14}, {3, 7, 11, 15} };
		byte[][] key = { {16, 16, 16, 16}, {16, 16, 16, 16}, {16, 16, 16, 16}, {16, 16, 16, 16}};
		
		AES box = new AES(plain, key);
		
		box.print();
		box.addRoundKey(0);
		box.print();
		box.subBytes();
		box.print();
		box.shiftRows();
		box.print();
		box.mixColumns();
		box.print();
		
	}	*/


}