# Rohan Sangani
# 11/3/2025
# Stores a single point on a specific Elliptic Curve
# Adapted from https://www.cs.ucf.edu/~dmarino/ucf/cis3362/progs/Point.java

# I wanted to use type annotations in all of these files,
# but I believe the version of Python Professor Guha uses doesn't support them

from EllipticCurve import EllipticCurve

class Point:
    """A single point on a specific elliptic curve y^2 = x^3 + ax + b (mod p).

    The curve object is expected to expose ``p``, ``a`` and ``b`` attributes and
    to support ``==`` / ``!=`` comparison.

    The point at infinity (the identity of the curve's group) is represented by
    the sentinel coordinates (0, 0), as in the Java original.  NOTE: on a curve
    with b == 0 (mod p) the genuine affine point (0, 0) is indistinguishable
    from this sentinel.
    """

    def __init__(self, c, x, y):
        """Create the point (x, y) on curve ``c``.

        The coordinates are validated exactly as given and then reduced modulo
        ``c.p``, so equality and mirror checks always compare canonical values
        in [0, p).

        Raises:
            ValueError: if (x, y) does not satisfy the curve equation and is not
                the origin sentinel (0, 0).
        """
        self.curve = c
        self.x = x
        self.y = y

        # ensure the point is actually on the curve (or is the origin)
        if not self.is_on_curve(self.curve):
            raise ValueError("Given x and y coordinates do not lie on the given curve!")

        # Store canonical representatives: without this, (26, 10) and (3, 10) on a
        # mod-23 curve would compare unequal and adding them would try to invert 0.
        self.x %= c.p
        self.y %= c.p

    def is_on_curve(self, curve):
        """Return True if this point satisfies ``curve``'s equation (mod curve.p).

        The origin sentinel (0, 0) is always accepted: it stands for the point at
        infinity, which is a legal point even though it is not an affine solution.
        """
        if self.is_origin():
            return True

        # y^2 == x^3 + ax + b (mod p)
        left_side = (self.y * self.y) % curve.p
        right_side = (pow(self.x, 3, curve.p) + curve.a * self.x + curve.b) % curve.p
        return left_side == right_side

    def __copy__(self):
        """Return a new Point with the same curve and coordinates."""
        return Point(self.curve, self.x, self.y)

    # this was originally an overload of Point() when only a curve was passed in
    @staticmethod
    def origin(curve):
        """Return the point at infinity (the (0, 0) sentinel) on ``curve``."""
        return Point(curve, 0, 0)

    def __eq__(self, other):
        """Points are equal when they share a curve and both coordinates.

        Non-Point operands yield NotImplemented, so Python falls back to its
        default comparison instead of raising AttributeError.
        """
        if not isinstance(other, Point):
            return NotImplemented
        # Use == rather than calling curve.__eq__ directly: a curve class without
        # its own __eq__ would otherwise return the (truthy) NotImplemented.
        return (self.x == other.x and
                self.y == other.y and
                self.curve == other.curve)

    def is_origin(self):
        """Return True if this is the (0, 0) point-at-infinity sentinel."""
        return self.x == 0 and self.y == 0

    def is_mirror(self, other):
        """Return True iff ``other`` is this point's additive inverse.

        That holds when both are the origin, or when they lie on the same curve,
        share x, and have y-coordinates summing to 0 (mod p), i.e. one is the
        reflection of the other over y = p/2.  A point with y == 0 is its own
        mirror.
        """
        if self.is_origin() and other.is_origin():
            return True
        return (self.x == other.x and
                self.curve == other.curve and
                (self.y + other.y) % self.curve.p == 0)

    def __neg__(self):
        """Return the mirror point (x, -y mod p); is_mirror(P, -P) is always True."""
        if self.is_origin():
            return self
        # % p (instead of p - y) keeps y == 0 at 0 rather than producing y == p
        return Point(self.curve, self.x, (-self.y) % self.curve.p)

    def __add__(self, other):
        """Return self + other under the elliptic-curve group law.

        Raises:
            ValueError: if the two points lie on different curves.
        """
        if not isinstance(other, Point):
            return NotImplemented
        if self.curve != other.curve:
            raise ValueError("Adding two points requires the curves to be the same!")

        p = self.curve.p

        # (0, 0) is the identity: O + Q = Q and P + O = P
        if self.is_origin():
            return other.__copy__()
        if other.is_origin():
            return self.__copy__()

        # P + (-P) = O.  This also covers doubling a point with y == 0: such a
        # point is its own inverse (its tangent line is vertical), so 2P = O.
        if self.is_mirror(other):
            return Point.origin(self.curve)

        if self == other:
            # doubling: slope of the tangent line, (3x^2 + a) / (2y)
            numerator = 3 * self.x * self.x + self.curve.a
            denominator = 2 * self.y
        else:
            # distinct points: slope of the chord, (y2 - y1) / (x2 - x1)
            numerator = other.y - self.y
            denominator = other.x - self.x

        # pow(d, -1, p) is d's modular inverse mod p (d^-1 mod p)
        slope = numerator * pow(denominator, -1, p) % p

        new_x = (slope * slope - self.x - other.x) % p
        new_y = (slope * (self.x - new_x) - self.y) % p
        return Point(self.curve, new_x, new_y)

    def __sub__(self, other):
        """Return self - other, computed as self + (-other)."""
        return self + -other  # -other calls Point.__neg__

    def __mul__(self, factor):
        """Return factor * self (scalar multiplication) via iterative double-and-add.

        ``factor`` must be an integer; integral floats such as 2.0 are also
        accepted for backward compatibility.  A negative factor multiplies the
        negated point, and 0 yields the origin.  Uses O(log |factor|) point
        additions and no recursion.
        """
        import operator  # local import keeps the module's import block unchanged

        try:
            n = operator.index(factor)
        except TypeError:
            if isinstance(factor, float) and factor.is_integer():
                n = int(factor)
            else:
                return NotImplemented

        if n < 0:
            # (-k) * P == k * (-P)
            return (-self) * (-n)

        result = Point.origin(self.curve)
        addend = self  # equals self * 2**i on the i-th iteration
        while n:
            if n & 1:
                result = result + addend
            n >>= 1
            if n:
                addend = addend + addend
        return result

    def __rmul__(self, factor):
        """Support ``factor * point`` as well as ``point * factor``."""
        return self.__mul__(factor)

    def __str__(self):
        """Format as "(x, y)" (the equivalent of Java's toString())."""
        return f"({self.x}, {self.y})"

    def __repr__(self):
        """Unambiguous debugging form, e.g. Point(x=3, y=10)."""
        return f"Point(x={self.x}, y={self.y})"

def main():
    """Demonstrate point addition, doubling and repeated addition on a small curve."""
    # Arguments presumably (p, a, b) -- the sample points below satisfy
    # y^2 = x^3 + x + 1 (mod 23). TODO confirm against EllipticCurve.
    demo_curve = EllipticCurve(23, 1, 1)
    first = Point(demo_curve, 3, 10)
    second = Point(demo_curve, 9, 7)

    # Named "total" rather than "sum" to avoid shadowing the built-in sum()
    total = first + second
    print(f"{first} + {second} = {total}")

    doubled = first + first
    print(f"{first} + {first} = {doubled}")

    # Print the running multiples 1*first, 2*first, ..., 28*first
    running = first
    for step in range(1, 29):
        print(f"{step}. {running}")
        running = running + first


if __name__ == "__main__":
    main()