// Arup Guha
// 11/7/07
// Simulation of the Knapsack Cipher for CIS 3362.

import java.util.*;
import java.math.BigInteger;

public class Knapsack {

	// Components of the object.
	// privateKeys: the superincreasing ("easy") knapsack set, kept secret.
	// publicKeys:  privateKeys[i] * w mod u, published for encryption.
	// u:           the modulus; it is chosen larger than the sum of all private keys.
	// w:           the multiplier; 1 < w < u and gcd(w, u) = 1, so w is invertible mod u.
	BigInteger[] privateKeys;
	BigInteger[] publicKeys;
	BigInteger u;
	BigInteger w;

	// Pre-condition: setSize >= 1 and r is non-null.
	// Post-condition: A fresh private/public key pair with setSize elements is
	//                 generated using r as the source of randomness.
	// Throws IllegalArgumentException if setSize < 1.
	public Knapsack(int setSize, Random r) {

		if (setSize < 1)
			throw new IllegalArgumentException("setSize must be at least 1, got " + setSize);

		// Allocating space for the set of private keys.
		privateKeys = new BigInteger[setSize];

		// Pick the first value in the set to be pretty small (2-10).
		// r.nextInt(bound) is used rather than Math.abs(r.nextInt()) % bound:
		// Math.abs(Integer.MIN_VALUE) is negative, which could produce a zero
		// or negative key and break the superincreasing property.
		privateKeys[0] = BigInteger.valueOf(r.nextInt(9) + 2);
		BigInteger sum = privateKeys[0];

		// Each later key is the sum of all previous keys plus a small offset
		// (1-20), which makes the set superincreasing.
		for (int i = 1; i < setSize; i++) {
			privateKeys[i] = sum.add(BigInteger.valueOf(r.nextInt(20) + 1));
			sum = sum.add(privateKeys[i]);
		}

		// Pick u to be a little bit bigger than twice the last private key.
		// The last key exceeds the sum of all earlier keys, so u exceeds the
		// sum of the entire private set, as decryption requires.
		u = privateKeys[setSize - 1].multiply(BigInteger.valueOf(2))
				.add(BigInteger.valueOf(r.nextInt(20) + 1));

		// Find a value w with 1 < w < u that is relatively prime to u.
		// Candidates are drawn uniformly from [0, 2^bitLength(u)), a range that
		// covers all of [0, u). bitLength is required here: bitCount counts only
		// the 1 bits, and could give a range of just {0} (an infinite loop) or
		// {0, 1} (forcing w = 1). Since u >= 5, u - 1 always qualifies, so the
		// loop terminates with probability 1.
		int bits = u.bitLength();
		BigInteger candidate;
		do {
			candidate = new BigInteger(bits, r);
		} while (candidate.compareTo(BigInteger.ONE) <= 0
				|| candidate.compareTo(u) >= 0
				|| !u.gcd(candidate).equals(BigInteger.ONE));
		w = candidate;

		// Now we can set up the public keys!
		publicKeys = new BigInteger[setSize];
		for (int i = 0; i < setSize; i++)
			publicKeys[i] = w.multiply(privateKeys[i]).mod(u);
	}

	// Pre-condition: binary is the same length as the set size for this object,
	//                and binary only contains the characters '0' and '1'.
	// Post-condition: The encryption of binary is returned, or null if the
	//                 length of binary does not match the set size.
	//                 Any character other than '1' is treated as a 0 bit.
	public BigInteger encrypt(String binary) {

		// We can only encrypt if the # of input bits is equal to the size
		// of our encryption set.
		if (binary.length() != publicKeys.length)
			return null;

		BigInteger answer = BigInteger.ZERO;

		// Add the corresponding public key ONLY for the bits set to 1.
		for (int i = 0; i < binary.length(); i++) {
			if (binary.charAt(i) == '1')
				answer = answer.add(publicKeys[i]);
		}

		// Here's our ciphertext!
		return answer;
	}

	// Pre-condition: ciphertext is a valid ciphertext for this object.
	// Post-condition: The corresponding plaintext bit string is returned. The
	//                 intermediate decryption steps are printed to standard output.
	//                 For an invalid ciphertext the returned bits are meaningless.
	public String decrypt(BigInteger ciphertext) {

		// Not necessary, but this scales down the value of the ciphertext.
		ciphertext = ciphertext.mod(u);

		// w-inverse mod u is the private multiplier that undoes the public scrambling.
		BigInteger wInverse = w.modInverse(u);
		System.out.println("We must multiply our ciphertext by "+wInverse);

		// First recover our ADJUSTED sum for our easy set of numbers.
		BigInteger easy = ciphertext.multiply(wInverse).mod(u);

		System.out.println("When we do that, we obtain "+easy+"\n");

		// Print out the private keys so we can verify the decryption process.
		System.out.println("With this and the private set:");
		printPrivateKeys();
		System.out.println();

		System.out.println("We can easily determine which values were chosen and reconstruct the plaintext.");

		// Because the private set is superincreasing, a greedy pass from the
		// largest key to the smallest recovers the subset uniquely. Bits are
		// produced last-to-first, so the builder is reversed at the end.
		StringBuilder plain = new StringBuilder(privateKeys.length);
		for (int i = privateKeys.length - 1; i >= 0; i--) {
			if (easy.compareTo(privateKeys[i]) < 0) {
				// This key was NOT part of the subset: a 0 bit of the plaintext.
				plain.append('0');
			} else {
				// This key was part of the subset: a 1 bit of the plaintext.
				plain.append('1');
				easy = easy.subtract(privateKeys[i]);
			}
		}

		return plain.reverse().toString();
	}

	// Prints out the public key set on one line, each key followed by a space.
	public void printPublicKeys() {

		for (BigInteger key : publicKeys)
			System.out.print(key+" ");
		System.out.println();
	}

	// Prints out the private key set on one line, each key followed by a space.
	public void printPrivateKeys() {

		for (BigInteger key : privateKeys)
			System.out.print(key+" ");
		System.out.println();
	}

	// Interactive demo: reads a block size and a string from standard input,
	// then encrypts and decrypts the string, printing every step.
	public static void main(String[] args) throws Exception {

		Scanner stdin = new Scanner(System.in);
		Random r = new Random();

		System.out.println("Please enter your desired block size in bits.");
		int blocksize = stdin.nextInt();

		// The constructor needs at least one element in the key set.
		if (blocksize < 1) {
			System.out.println("The block size must be a positive integer.");
			return;
		}

		// Create a new Knapsack object.
		Knapsack test = new Knapsack(blocksize, r);

		// Get the input.
		System.out.println("Please enter your string of "+blocksize/8+" characters.");
		String plaintext = stdin.next();

		// Here is the input converted to bits.
		String convPlain = convertToBits(plaintext);
		System.out.println("Your text has been converted to the bitstring "+convPlain);

		// Compute the ciphertext. encrypt returns null on a length mismatch;
		// stop here rather than crash inside decrypt.
		BigInteger ciphertext = test.encrypt(convPlain);
		if (ciphertext == null) {
			System.out.println("Your text must be exactly " + blocksize
					+ " bits long, but it was " + convPlain.length() + " bits.");
			return;
		}
		System.out.println("\nYour ciphertext is "+ciphertext);

		// Print out the public keys, so we can verify encryption.
		System.out.println("It was computed using the public key set:");
		test.printPublicKeys();
		System.out.println();

		// Now decrypt and print out.
		String plainback = test.decrypt(ciphertext);

		System.out.println("We have recovered bitstring plaintext "+plainback);

		// Convert the bits back to regular chars.
		String convPlainBack = convertToCharString(plainback);
		System.out.println("Here is the corresponding text: "+convPlainBack);

	}

	// Returns the 8-bit binary representation of c's low 8 bits, most
	// significant bit first (e.g. 'A' -> "01000001").
	public static String convertToBits(char c) {

		StringBuilder ans = new StringBuilder(8);
		for (int bit = 7; bit >= 0; bit--)
			ans.append(((c >> bit) & 1) == 0 ? '0' : '1');
		return ans.toString();
	}

	// Returns a bitstring representation of s: the 8-bit representation of
	// each character, concatenated in order.
	public static String convertToBits(String s) {

		StringBuilder ans = new StringBuilder(8 * s.length());
		for (int i = 0; i < s.length(); i++)
			ans.append(convertToBits(s.charAt(i)));
		return ans.toString();
	}

	// Pre-condition: s MUST be of length at least 8 and should contain only
	//                the characters '0' and '1'. Only the first 8 characters are
	//                read, and any character other than '0' counts as a 1 bit.
	// Post-condition: The character with the ascii value denoted by s will be
	//                 returned.
	public static char convertToChar(String s) {

		int val = 0;

		// Reverse the process: fold the bits into a value using Horner's method.
		for (int i = 0; i < 8; i++) {
			if (s.charAt(i) == '0')
				val = 2*val;
			else
				val = 2*val + 1;
		}

		return (char)val;
	}

	// Pre-condition: the length of s MUST BE a multiple of 8 and s must ONLY
	//                contain the characters '0' and '1'.
	// Post-condition: A string with the characters that correspond to the
	//                 ascii values denoted in s is returned. Any trailing bits
	//                 beyond the last full group of 8 are ignored.
	public static String convertToCharString(String s) {

		char[] ans = new char[s.length()/8];

		for (int i = 0; i < ans.length; i++)
			ans[i] = convertToChar(s.substring(8*i, 8*i+8));

		return new String(ans);
	}

}
		