# Rohan Sangani
# 11/3/2025
# Simple Example Illustrating Elliptic Curve Cryptography
# Adapted from https://www.cs.ucf.edu/~dmarino/ucf/cis3362/progs/ECC.java
import random
import secrets

from EllipticCurve import EllipticCurve
from Point import Point

class ECC:
    """Textbook EC-ElGamal over a caller-supplied elliptic curve.

    Attributes:
        curve: the curve object; its ``p`` is read to size the ephemeral key
            and to bound the message-embedding search.
        generator: the base Point built from (x, y) on ``curve``.
        private_key: the secret scalar n_a.
        public_key: ``generator * private_key``.

    Intended as a teaching example: intermediate values are printed.
    """

    def __init__(self, c, x, y, n_a):
        self.curve = c
        self.generator = Point(c, x, y)
        self.private_key = n_a
        self.public_key = self.generator * self.private_key

    def encryptBlock(self, plainText):
        """Encrypt one text block and return the ciphertext pair [c1, c2].

        Assumes plainText already has the proper block size. The block is
        packed into an integer by ``convert`` and embedded in the low
        ``8 * len(plainText)`` bits of a curve point's x coordinate.
        """
        xVal = convert(plainText)
        plainPt = self.convertMsgToPoint(xVal, 8 * len(plainText))
        print("Plain point is")
        print(plainPt)
        return self.encryptPoint(plainPt)

    def encryptPoint(self, plain):
        """Encrypt a single point, returning the ciphertext list [c1, c2].

        c1 = k*G and c2 = plain + k*Q, where G is the generator, Q the public
        key and k a fresh ephemeral scalar.
        """
        # k has the same bit length as the curve prime (mirroring
        # BigInteger(bits, rnd) in the Java original), but it is drawn from a
        # CSPRNG and is never 0. With k = 0, c2 = plain + 0*Q = plain, which
        # leaks the message. Anyone who can predict k can likewise compute
        # plain = c2 - k*Q.
        bits = self.curve.p.bit_length()
        k = secrets.randbelow((1 << bits) - 1) + 1
        print(f"Picked {k=} for encryption")

        # Return both parts of the ciphertext.
        return [self.generator * k, plain + (self.public_key * k)]

    def decryptPoint(self, cipher):
        """Recover the plaintext point from [c1, c2] as c2 - private_key*c1 (El-Gamal style)."""
        sub = cipher[0] * self.private_key

        print(f"Decryption result is: {cipher[1]} - {sub}")
        res = cipher[1] - sub
        print("As a single point its")
        print(res)
        print()
        return res

    def decryptPointToMsg(self, plainPoint, numbits):
        """Return the numeric message stored in the low numbits bits of plainPoint.x."""
        return plainPoint.x & ((1 << numbits) - 1)

    def decryptBlock(self, cipherTextPt, numbits):
        """Decrypt one ciphertext pair back into a string of numbits // 8 characters.

        numbits must be a multiple of 8.
        """
        plainPt = self.decryptPoint(cipherTextPt)
        plainVal = self.decryptPointToMsg(plainPt, numbits)
        return convertBack(plainVal, numbits // 8)

    def convertMsgToPoint(self, numMsg, numbits):
        """Return a curve point whose x coordinate encodes numMsg in its low numbits bits.

        The curve prime must be larger than numbits bits; at least 4 to 8
        extra bits are recommended (8 for big numbers) so that a valid x is
        found quickly.

        Raises:
            ValueError: if no i in range(curve.p) yields an x = i*2**numbits
                + numMsg with a nonzero matching y.
        """
        for i in range(self.curve.p):
            # i supplies extra most-significant bits above the message, so
            # the message survives in the low numbits bits of x.
            tempx = (i << numbits) + numMsg

            # Get a matching y value (None when there is none).
            tempy = self.curve.getMatchingY(tempx)

            # Points of the form (x, 0) are not used.
            if tempy is not None and tempy != 0:
                return Point(self.curve, tempx, tempy)

        # Previously this fell through and built a Point with y = None or 0.
        raise ValueError(
            f"no curve point encodes message {numMsg} in {numbits} bits"
        )

    def __str__(self):
        """Return the generator and key pair, one per line."""
        return (f"Generator: {self.generator}\n"
                f"Private Key: {self.private_key}\n"
                f"Public Key: {self.public_key}")

# Treating each byte as 0 to 256. Converts string to integer value.
def convert(s):
    res = 0
    for i in range(len(s)):
        res = 256*res + ord(s[i])
    return res

# Converts an integer storing numChars characters back to a string.
def convertBack(num, numChars):
    res = ""
    for i in range(numChars):
        res = chr(num%256) + res
        num = num//256
    return res

def newtest1(plainTextStr):

    # I created these from another program. Should have room for 256 bits + 8bits.
    x = 49702535146193300246464588416305760568645130274018063689696125944109251949023191
    a = 19627337416210408032378349494918631544154182870797998756759845167366967474775874
    b = 8471128260126450325184330403682286008517339285690516006313148515457442024816137

    # Make the curve and find some random point on it.
    myC = EllipticCurve(x, a, b)

    # Note: This is a tuple...
    myG = myC.getRandomPoint()

    # I'll let this be the full range, almost. Alice's private key.
    nA = random.randint(2, myC.p-2)

    # Creates the ECC object.
    aliceECC = ECC(myC, myG[0], myG[1], nA)

    for i in range(0, len(plainTextStr), 32):

        # current block.
        block = plainTextStr[i:min(i+32, len(plainTextStr))]

        # Padding...
        while len(block) < 32:
            block = block+"-"

        # Show plain.
        print(block)

        # Encrypt.
        cipherPt = aliceECC.encryptBlock(block)

        # Show cipher.
        print("Encrypted pair is")
        print(cipherPt[0])
        print(cipherPt[1])

        msgBack = aliceECC.decryptBlock(cipherPt, 32*8)
        print(msgBack)

        # Add blank line after.
        print()
    
    

def oldtest():
    curve = EllipticCurve(23, 1, 1) # 23 is prime, a = b = 1
    x = 6
    y = 19
    n_a = 10
    alice = ECC(curve, x, y, n_a)

    # Points starts with just the origin, then repeatedly adds (3, 13)
    points = [Point(curve, 0, 0)]
    temp = Point(curve, 3, 13)
    for i in range (1, 28):
        points.append(points[i-1] + temp)

    # For each of those points, we can encrypt and decrypt it
    for plain in points:
        print(f"Encrypting {plain}")

        cipher = alice.encrypt(plain)
        print(f"Cipher first part: {cipher[0]}")
        print(f"Cipher second part: {cipher[1]}")

        recover = alice.decrypt(cipher)
        print(f"Recovered: {recover}")

# Give it a shot!
newtest1("This is a new test of the refactored code. This is definitely better and more object oriented.")
