#!/usr/bin/env python3
import tkinter as tk
import math
import numpy as np
from dataclasses import dataclass
from typing import Tuple, List
import sys

@dataclass
class Vector2D:
    """Mutable 2-D vector offering the small set of operations the simulation uses.

    Arithmetic operators always build a fresh ``Vector2D``; only direct
    attribute assignment mutates an instance.
    """

    x: float
    y: float

    def __add__(self, other):
        """Return the component-wise sum ``self + other``."""
        sum_x, sum_y = self.x + other.x, self.y + other.y
        return Vector2D(sum_x, sum_y)

    def __sub__(self, other):
        """Return the component-wise difference ``self - other``."""
        diff_x, diff_y = self.x - other.x, self.y - other.y
        return Vector2D(diff_x, diff_y)

    def __mul__(self, scalar):
        """Return this vector scaled by *scalar* (vector * scalar only)."""
        scaled_x, scaled_y = self.x * scalar, self.y * scalar
        return Vector2D(scaled_x, scaled_y)

    def magnitude(self):
        """Return the Euclidean length of the vector."""
        squared_length = self.x**2 + self.y**2
        return math.sqrt(squared_length)

    def normalize(self):
        """Return a unit vector in the same direction, or (0, 0) for a zero vector."""
        length = self.magnitude()
        if length != 0:
            return Vector2D(self.x / length, self.y / length)
        return Vector2D(0, 0)

    def dot(self, other):
        """Return the scalar (dot) product with *other*."""
        along_x = self.x * other.x
        along_y = self.y * other.y
        return along_x + along_y

    def rotate(self, angle):
        """Return this vector rotated by *angle* radians (standard rotation matrix)."""
        c, s = math.cos(angle), math.sin(angle)
        new_x = self.x * c - self.y * s
        new_y = self.x * s + self.y * c
        return Vector2D(new_x, new_y)

class Ball:
    """A circular body moved once per frame under constant gravity and drag."""

    def __init__(self, pos: Vector2D, velocity: Vector2D, radius: float, color: str):
        """Create a ball.

        Args:
            pos: Centre position in canvas coordinates.
            velocity: Displacement added to ``pos`` on each :meth:`update`.
            radius: Radius used for drawing and collisions.
            color: Tk fill colour string.
        """
        self.pos, self.velocity = pos, velocity
        self.radius, self.color = radius, color
        # Multiplier applied to the whole velocity on every update (drag).
        self.friction = 0.98
        # Added to velocity.y on every update.
        self.gravity = 0.2

    def update(self):
        """Advance one step: add gravity, apply drag, then move by the velocity."""
        vel = self.velocity
        vel.y += self.gravity  # in-place change to the current velocity object
        self.velocity = vel * self.friction
        self.pos = self.pos + self.velocity

    def draw(self, canvas):
        """Draw the ball as a filled, black-outlined circle on *canvas*."""
        r = self.radius
        cx, cy = self.pos.x, self.pos.y
        canvas.create_oval(cx - r, cy - r, cx + r, cy + r, fill=self.color, outline="#000000")

class Heptagon:
    """A regular seven-sided container that rotates about its centre.

    ``vertices`` is recomputed from ``angle`` on every :meth:`update`, so the
    edges and normals returned by the other methods always reflect the current
    orientation.
    """

    # Number of sides; the geometry below assumes a regular polygon.
    _SIDES = 7

    def __init__(self, center: Vector2D, radius: float, rotation_speed: float):
        """Create the heptagon.

        Args:
            center: Centre of the polygon and of its rotation.
            radius: Circumradius (centre-to-vertex distance).
            rotation_speed: Radians added to ``angle`` on each :meth:`update`.
        """
        self.center = center
        self.radius = radius
        self.rotation_speed = rotation_speed
        self.angle = 0
        self.vertices = self._compute_vertices()

    def _compute_vertices(self):
        """Return the vertices for the current ``angle`` as a list of Vector2D."""
        step = 2 * math.pi / self._SIDES
        vertices = []
        for i in range(self._SIDES):
            angle = self.angle + i * step
            vertices.append(
                Vector2D(
                    self.center.x + self.radius * math.cos(angle),
                    self.center.y + self.radius * math.sin(angle),
                )
            )
        return vertices

    def update(self):
        """Advance the rotation by one step and refresh ``vertices``."""
        self.angle += self.rotation_speed
        self.vertices = self._compute_vertices()

    def draw(self, canvas):
        """Draw the unfilled outline on a Tk canvas."""
        coords = [c for vertex in self.vertices for c in (vertex.x, vertex.y)]
        canvas.create_polygon(coords, outline="#000000", fill="", width=2)

    def get_edges(self):
        """Return the edges as (start, end) vertex pairs, closing the loop."""
        vertices = self.vertices
        return list(zip(vertices, vertices[1:] + vertices[:1]))

    def contains(self, point: Vector2D):
        """Return True if *point* lies inside the heptagon.

        Uses the even-odd ray-casting rule. It counts how many edges a ray from
        *point* towards +x crosses, and an odd count means inside. The rule is
        valid for any simple polygon, convex or not.
        """
        inside = False
        for a, b in self.get_edges():
            # Only edges straddling the ray's y can be crossed; this also
            # guarantees a.y != b.y, so the division is safe.
            if (a.y > point.y) != (b.y > point.y):
                x_cross = a.x + (point.y - a.y) * (b.x - a.x) / (b.y - a.y)
                if point.x < x_cross:
                    inside = not inside
        return inside

    def get_edge_normal(self, edge_start: Vector2D, edge_end: Vector2D):
        """Return the unit normal of an edge, oriented towards ``center`` (inward)."""
        edge = edge_end - edge_start
        normal = Vector2D(-edge.y, edge.x).normalize()
        if normal.dot(self.center - edge_start) < 0:
            normal = normal * -1
        return normal

    @staticmethod
    def _closest_point_on_edge(point, edge_start, edge_end):
        """Return the point of segment edge_start->edge_end nearest *point*, or None if the edge has zero length."""
        edge = edge_end - edge_start
        edge_len_sq = edge.dot(edge)
        if edge_len_sq == 0:
            return None
        t = max(0, min(1, (point - edge_start).dot(edge) / edge_len_sq))
        return edge_start + edge * t

    def get_collision_response(self, ball: Ball, edge_start: Vector2D, edge_end: Vector2D):
        """Resolve contact between *ball* and one edge segment.

        If the ball touches the segment, it is pushed out along the inward
        normal. Its velocity is reflected only when it is moving into the wall;
        reflecting a ball that is already moving away would send it back into
        the wall.

        Returns:
            The inward normal on contact, otherwise ``Vector2D(0, 0)``.
        """
        closest_point = self._closest_point_on_edge(ball.pos, edge_start, edge_end)
        if closest_point is None:
            return Vector2D(0, 0)
        distance = (ball.pos - closest_point).magnitude()
        if distance > ball.radius:
            return Vector2D(0, 0)

        normal = self.get_edge_normal(edge_start, edge_end)
        penetration = ball.radius - distance
        if penetration > 0:
            ball.pos = ball.pos + normal * penetration

        normal_speed = ball.velocity.dot(normal)
        if normal_speed < 0:
            ball.velocity = ball.velocity - normal * (2 * normal_speed)
        return normal

    def collide_with_walls(self, ball: Ball):
        """Keep *ball* inside the heptagon, bouncing it off any wall it touches.

        For each edge, the signed distance from the ball centre to the edge's
        line is measured along the inward normal. Because the heptagon is
        convex, the ball is fully inside iff that distance is at least
        ``ball.radius`` for every edge. Unlike a segment-distance test, this also
        catches a ball whose centre has already crossed the edge (negative
        distance).

        Velocity is reflected only when the ball is moving into the wall.
        """
        for edge_start, edge_end in self.get_edges():
            if edge_start == edge_end:
                continue  # degenerate edge: no defined normal
            normal = self.get_edge_normal(edge_start, edge_end)
            signed_distance = (ball.pos - edge_start).dot(normal)
            penetration = ball.radius - signed_distance
            if penetration <= 0:
                continue
            ball.pos = ball.pos + normal * penetration
            normal_speed = ball.velocity.dot(normal)
            if normal_speed < 0:
                ball.velocity = ball.velocity - normal * (2 * normal_speed)

    def get_collision_info(self, ball: Ball):
        """Find the nearest edge segment that *ball* is touching.

        Returns:
            ``(inward_normal, contact_point)`` for the closest touching edge, or
            ``(None, None)`` if the ball touches no edge.
        """
        min_distance = float('inf')
        closest_edge = None
        collision_point = Vector2D(0, 0)

        for edge_start, edge_end in self.get_edges():
            closest_point = self._closest_point_on_edge(ball.pos, edge_start, edge_end)
            if closest_point is None:
                continue
            distance = (ball.pos - closest_point).magnitude()
            if distance <= ball.radius and distance < min_distance:
                min_distance = distance
                closest_edge = (edge_start, edge_end)
                collision_point = closest_point

        if closest_edge:
            return self.get_edge_normal(*closest_edge), collision_point
        return None, None

def check_collision(ball1: Ball, ball2: Ball):
    """Detect and resolve an overlap between two equal-mass balls.

    If the balls overlap, they are pushed apart symmetrically along the line
    joining their centres. When they are also moving towards each other, their
    velocity components along that line are exchanged, which is a perfectly
    elastic collision between equal masses. Balls that are already separating
    keep their velocities, so they are not pulled back together.

    Args:
        ball1: First ball; its ``pos`` and ``velocity`` may be replaced.
        ball2: Second ball; its ``pos`` and ``velocity`` may be replaced.

    Returns:
        True if the balls overlapped (and were resolved), otherwise False.
    """
    diff = ball1.pos - ball2.pos
    distance = diff.magnitude()
    min_distance = ball1.radius + ball2.radius
    if distance >= min_distance:
        return False

    if distance == 0:
        # Coincident centres have no contact direction; choose one arbitrarily
        # so the balls can still be separated (normalize() would give (0, 0)).
        normal = Vector2D(1.0, 0.0)
    else:
        normal = diff.normalize()

    # Push each ball half the overlap so they end up just touching.
    separation = normal * ((min_distance - distance) / 2)
    ball1.pos = ball1.pos + separation
    ball2.pos = ball2.pos - separation

    # normal points from ball2 towards ball1, so a negative relative normal
    # speed means the balls are approaching each other.
    approach_speed = (ball1.velocity - ball2.velocity).dot(normal)
    if approach_speed < 0:
        # Equal masses: swapping normal components == v1 -= n*(v1n - v2n), v2 += n*(v1n - v2n).
        impulse = normal * approach_speed
        ball1.velocity = ball1.velocity - impulse
        ball2.velocity = ball2.velocity + impulse
    return True

class PhysicsSimulation:
    """Tk application animating balls bouncing inside a spinning heptagon."""

    def __init__(self):
        """Build the window and canvas, create the scene, and schedule the first frame."""
        self.root = tk.Tk()
        self.root.title("Bouncing Balls in Spinning Heptagon")
        self.root.geometry("800x600")
        self.canvas = tk.Canvas(self.root, width=800, height=600, bg="#ffffff")
        self.canvas.pack()

        self.ball_radius = 15
        self.ball_colors = [
            "#f8b862", "#f6ad49", "#f39800", "#f08300", "#ec6d51",
            "#ee7948", "#ed6d3d", "#ec6800", "#ec6800", "#ee7800",
            "#eb6238", "#ea5506", "#ea5506", "#eb6101", "#e49e61",
            "#e45e32", "#e17b34", "#dd7a56", "#db8449", "#d66a35"
        ]

        self.center = Vector2D(400, 300)
        self.heptagon_radius = 200
        self.heptagon = Heptagon(self.center, self.heptagon_radius, math.pi / 80)

        self.balls = []
        self.init_balls()

        self.root.after(0, self.animate)

    def init_balls(self):
        """(Re)create 20 balls at the heptagon centre, each launched in its own direction.

        Starting every ball at the same point with zero velocity would keep
        them coincident forever: they all receive identical forces and the
        pairwise contact normal is undefined. A small launch velocity spread
        evenly around the circle lets them separate.
        """
        num_balls = 20
        launch_speed = 2.0
        self.balls = []
        for i in range(num_balls):
            angle = 2 * math.pi * i / num_balls
            pos = self.center + Vector2D(0, 0)  # fresh copy, not an alias of self.center
            velocity = Vector2D(math.cos(angle), math.sin(angle)) * launch_speed
            color = self.ball_colors[i % len(self.ball_colors)]
            self.balls.append(Ball(pos, velocity, self.ball_radius, color))

    def animate(self):
        """Advance one frame, redraw, and schedule the next frame in 30 ms."""
        self._step()
        self._draw()
        self.root.after(30, self.animate)

    def _step(self):
        """Advance the heptagon and all balls by one physics step."""
        self.heptagon.update()

        # Ball.update already applies gravity and friction; applying them
        # again here would double both.
        for ball in self.balls:
            ball.update()

        # Resolve each unordered pair exactly once.
        for i, ball in enumerate(self.balls):
            for other in self.balls[i + 1:]:
                check_collision(ball, other)

        # Walls are handled last so ball-ball separation cannot leave a ball
        # poking through a wall at the end of the frame.
        inner_radius = self.heptagon_radius * math.cos(math.pi / 7)
        for ball in self.balls:
            self.heptagon.collide_with_walls(ball)
            if not self.heptagon.contains(ball.pos):
                self._pull_back_inside(ball, inner_radius)

    def _pull_back_inside(self, ball, inner_radius):
        """Fallback for a ball whose centre escaped the heptagon.

        Places the ball on the circle of radius ``inner_radius - ball.radius``
        around the centre. ``inner_radius`` is the heptagon's inscribed-circle
        radius, so that circle lies wholly inside the polygon at any rotation;
        the circumradius used previously could leave the ball outside near
        edge midpoints. Any outward radial velocity is then reflected.
        """
        direction = (ball.pos - self.center).normalize()
        ball.pos = self.center + direction * max(inner_radius - ball.radius, 0)
        outward_speed = ball.velocity.dot(direction)
        if outward_speed > 0:
            ball.velocity = ball.velocity - direction * (2 * outward_speed)

    def _draw(self):
        """Clear the canvas and draw the heptagon and every ball."""
        self.canvas.delete("all")
        self.heptagon.draw(self.canvas)
        for ball in self.balls:
            ball.draw(self.canvas)

if __name__ == "__main__":
    # Build the window and hand control to Tk's event loop (blocks until closed).
    PhysicsSimulation().root.mainloop()



