import cv2
import numpy as np
import traceback, base64
from datetime import datetime
from pytz import timezone

class Proctoring:
    """
    Vision-based proctoring helpers: face counting, head-pose estimation and
    a coarse emotion label for a single image frame.

    The heavy models (YOLO face detector and the face_alignment landmark
    model) are loaded lazily on first use, so constructing the service is
    cheap.
    """

    # Generic 3D face model (arbitrary units) for the six landmarks selected by
    # _LANDMARK_IDS, in the same order. The model is y-up with the nose tip at
    # the origin and the eyes/mouth behind it (negative z).
    _MODEL_POINTS = np.array([
        [0.0, 0.0, 0.0],          # Nose tip
        [0.0, -63.6, -12.5],      # Chin
        [-43.3, 32.7, -26.0],     # Left eye outer corner
        [43.3, 32.7, -26.0],      # Right eye outer corner
        [-28.9, -28.9, -24.1],    # Left mouth corner
        [28.9, -28.9, -24.1],     # Right mouth corner
    ], dtype=np.float64)

    # Landmark indices (68-point layout of face_alignment's 2D model) that
    # correspond row-for-row to _MODEL_POINTS. Kept as an ndarray so it can be
    # used directly for fancy indexing.
    _LANDMARK_IDS = np.array([30, 8, 36, 45, 48, 54])

    # Camera/image coordinates are y-down and the camera looks along +z,
    # whereas the model is y-up with the nose pointing towards +z. A face
    # squarely facing the camera therefore has rotation diag(1, -1, -1).
    # Right-multiplying by this flip removes that offset so a frontal face
    # decomposes to (0, 0, 0).
    _AXIS_FLIP = np.diag([1.0, -1.0, -1.0])

    def __init__(self):
        """
        Initialize without loading heavy models immediately
        """
        self.faceDetector = None     # ultralytics YOLO face model, set by _load_models
        self.fa = None               # face_alignment.FaceAlignment, set by _load_models
        self._models_loaded = False
        print("Proctoring service initialized")

    def _load_models(self):
        """
        Lazily load the YOLO face detector and the face_alignment landmark
        model. Does nothing if they are already loaded.

        Raises:
            Exception: whatever the model libraries raise while loading. The
                error is printed and re-raised; ``_models_loaded`` stays False
                so the next call retries.
        """
        if self._models_loaded:
            return
        try:
            # Imported here to avoid circular imports and to defer the heavy
            # dependencies until the models are first needed.
            from ultralytics import YOLO
            import face_alignment

            # NOTE(review): relative path, resolved against the process CWD.
            self.faceDetector = YOLO("services/proctor/models/yolov8n-face.pt")
            self.faceDetector.model.fuse()
            self.fa = face_alignment.FaceAlignment(
                face_alignment.LandmarksType.TWO_D,
                flip_input=False,
                device='cpu'
            )
            self._models_loaded = True
        except Exception as e:
            print(f"Error loading models: {e}")
            raise

    def analyzeFrame(self, frame):
        """
        Detect faces in a frame and estimate the head orientation of one face.

        Every valid face box counts toward ``people_count``. Head pose is
        estimated for the first valid face whose landmarks and pose can be
        computed; later faces are only tried if earlier ones fail.

        Args:
            frame (np.ndarray): BGR colour image (as produced by
                ``cv2.imdecode(..., cv2.IMREAD_COLOR)``).

        Returns:
            dict with keys:
                face_detected (bool), people_count (int),
                head_pose (dict with "yaw", "pitch", "roll" in degrees, or None).
        """
        if not self._models_loaded:
            self._load_models()

        h, w = frame.shape[:2]
        preds = self.faceDetector.predict(source=frame, imgsz=640, verbose=False)[0]
        boxes = self._validBoxes(preds.boxes, w, h)

        headPose = None
        if boxes:
            # face_alignment expects RGB input; OpenCV/YOLO work in BGR.
            rgbFrame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            # Pinhole approximation: focal length ~ image width, principal
            # point at the image centre, no lens distortion.
            cameraMatrix = np.array([
                [w, 0, w / 2],
                [0, w, h / 2],
                [0, 0, 1]
            ], dtype=np.float64)
            for box in boxes:
                headPose = self._estimateHeadPose(rgbFrame, box, cameraMatrix)
                if headPose is not None:
                    break

        return {
            "face_detected": bool(boxes),
            "people_count": len(boxes),
            "head_pose": headPose
        }

    @staticmethod
    def _validBoxes(boxes, w, h):
        """
        Return detector boxes as integer (x1, y1, x2, y2) tuples, keeping only
        those that lie inside a w x h frame and have positive area.
        """
        if boxes is None:
            return []
        valid = []
        for box in boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0].tolist())
            if 0 <= x1 < x2 <= w and 0 <= y1 < y2 <= h:
                valid.append((x1, y1, x2, y2))
        return valid

    def _estimateHeadPose(self, rgbFrame, box, cameraMatrix):
        """
        Estimate yaw/pitch/roll for one face box.

        Returns the head-pose dict, or None if landmark detection or
        ``cv2.solvePnP`` fails for this face.
        """
        try:
            # Hand YOLO's box to face_alignment so it skips its own face
            # detector; the landmarks come back in full-frame coordinates.
            landmarks = self.fa.get_landmarks_from_image(
                rgbFrame, detected_faces=[[float(v) for v in box]]
            )
        except Exception as e:
            print(f"Landmark detection failed: {e}")
            return None

        if landmarks is None or len(landmarks) == 0:
            return None

        imagePoints = np.asarray(landmarks[0], dtype=np.float64)[self._LANDMARK_IDS]
        distCoeffs = np.zeros((4, 1))

        success, rot, _ = cv2.solvePnP(self._MODEL_POINTS, imagePoints, cameraMatrix, distCoeffs)
        if not success:
            return None

        R, _ = cv2.Rodrigues(rot)
        return self._headPoseFromRotation(R)

    @classmethod
    def _headPoseFromRotation(cls, R):
        """
        Convert a solvePnP rotation matrix into head-pose angles in degrees.

        The raw rotation contains the model-to-camera axis flip described on
        ``_AXIS_FLIP``. Without removing it, pitch reads about +/-180 degrees
        for a frontal face. Removing it only shifts the pitch angle; yaw and
        roll are unaffected.

        The decomposition is R = Rz(roll) @ Ry(yaw) @ Rx(pitch). With the flip
        removed, positive pitch moves the nose towards +y (down in the image).

        Returns:
            dict with "yaw", "pitch", "roll" as floats rounded to 2 decimals.
        """
        R = np.asarray(R, dtype=np.float64) @ cls._AXIS_FLIP
        sy = np.hypot(R[0, 0], R[1, 0])

        if sy < 1e-6:
            # Gimbal lock (yaw ~ +/-90 deg): roll is not separable from pitch.
            pitch = np.arctan2(-R[1, 2], R[1, 1])
            roll = 0.0
        else:
            pitch = np.arctan2(R[2, 1], R[2, 2])
            roll = np.arctan2(R[1, 0], R[0, 0])
        yaw = np.arctan2(-R[2, 0], sy)

        return {
            "yaw": round(float(np.degrees(yaw)), 2),
            "pitch": round(float(np.degrees(pitch)), 2),
            "roll": round(float(np.degrees(roll)), 2)
        }

    def comprehensiveAnalysis(self, image_b64):
        """
        Comprehensive proctoring analysis: face detection, head pose estimation
        and emotion analysis for one encoded image.

        Args:
            image_b64 (str): Base64 encoded image string. A data URL
                ("data:image/...;base64,<payload>") is also accepted; the
                header is stripped before decoding.

        Returns:
            dict: face_detected, people_count, head_pose, dominant_emotion,
            flags, timestamp (ISO-8601, Asia/Kolkata) and
            head_orientation_readable.

        Raises:
            ValueError: If the input cannot be decoded into an image.
            Exception: Errors from face detection / head pose propagate;
                emotion analysis failures are printed and ignored.
        """
        try:
            print("Starting comprehensive analysis")
            if isinstance(image_b64, str) and image_b64.startswith("data:"):
                image_b64 = image_b64.partition(",")[2]
            imageData = base64.b64decode(image_b64)
            frame = cv2.imdecode(np.frombuffer(imageData, np.uint8), cv2.IMREAD_COLOR)
            if frame is None:
                raise ValueError("Decoded image is None")
        except Exception as e:
            raise ValueError("Invalid base64 image") from e

        visionResults = self.analyzeFrame(frame)

        headPoseData = None
        headOrientationReadable = "Not detected"
        if visionResults["head_pose"]:
            headPoseData = visionResults["head_pose"]
            headOrientationReadable = self.describeHeadPose(
                headPoseData["yaw"],
                headPoseData["pitch"],
                headPoseData["roll"]
            )

        # Emotion analysis is best-effort: on failure the traceback is printed
        # and dominant_emotion stays None.
        dominantEmotion = None
        try:
            # Import DeepFace locally to avoid circular imports
            from deepface import DeepFace

            analysis = DeepFace.analyze(
                img_path=frame,
                actions=['emotion'],
                enforce_detection=False,
                detector_backend='opencv'
            )
            # Some DeepFace versions return a single dict instead of a list.
            if isinstance(analysis, dict):
                analysis = [analysis]
            if isinstance(analysis, list) and analysis:
                dominantEmotion = self.fuseEmotions(analysis[0]['emotion'])
        except Exception:
            traceback.print_exc()

        istTime = datetime.now(timezone('Asia/Kolkata')).isoformat()

        return {
            "face_detected": visionResults["face_detected"],
            "people_count": visionResults["people_count"],
            "head_pose": headPoseData,
            "dominant_emotion": dominantEmotion,
            "flags": [],  # Add flags logic here if needed
            "timestamp": istTime,
            "head_orientation_readable": headOrientationReadable
        }

    @staticmethod
    def fuseEmotions(emotions: dict) -> str:
        """
        Collapse per-emotion scores into one composite label.

        Args:
            emotions: scores keyed by 'fear', 'angry', 'disgust', 'happy',
                'neutral' and 'sad' (all six keys are required).

        Returns:
            The highest-scoring of "nervous", "relaxed", "happy", "sad",
            "fear". Ties go to the earliest label in that order.
        """
        composite = {
            "nervous": (emotions['fear'] + emotions['angry'] + emotions['disgust']) / 3,
            "relaxed": (emotions['happy'] + emotions['neutral']) / 2,
            "happy": emotions['happy'],
            "sad": emotions['sad'],
            "fear": emotions['fear'],
        }
        return max(composite, key=composite.get)

    @staticmethod
    def describeHeadPose(yaw, pitch, roll):
        """
        Turn head-pose angles (degrees) into a human-readable description.

        The yaw and pitch thresholds are +/-20 degrees; the roll threshold is
        +/-15 degrees. Returns "facing forward" when no threshold is exceeded.
        """
        orientation = []
        if yaw > 20: orientation.append("looking right")
        elif yaw < -20: orientation.append("looking left")
        if pitch > 20: orientation.append("looking down")    # Positive = down
        elif pitch < -20: orientation.append("looking up")   # Negative = up
        if roll > 15: orientation.append("head tilted right")
        elif roll < -15: orientation.append("head tilted left")
        return ", ".join(orientation) if orientation else "facing forward"
