import logging
import re

import torch
from PIL import Image
from sentence_transformers import SentenceTransformer

from llm_engineering.domain.queries import EmbeddedQuery, Query

# Make transformers optional
try:
    from transformers import CLIPProcessor, CLIPModel
    TRANSFORMERS_AVAILABLE = True
except ImportError:
    TRANSFORMERS_AVAILABLE = False
    print("Transformers library not available, using fallback text-only embeddings")


class TextEmbedder:
    """Sentence-transformer text embedder, pinned to the CPU."""

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        # Force CPU usage for text embedding
        self.device = "cpu"
        self.model = SentenceTransformer(model_name, device=self.device)
    
    def encode(self, text: str) -> list[float]:
        with torch.no_grad():
            return self.model.encode(text, device="cpu", convert_to_tensor=False).tolist()
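
# Example usage (sketch; assumes sentence-transformers can download
# all-MiniLM-L6-v2 on first use):
#
#   text_embedder = TextEmbedder()
#   vector = text_embedder.encode("wireless noise-cancelling headphones")
#   # all-MiniLM-L6-v2 produces 384-dim sentence embeddings, so len(vector) == 384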


class MultimodalEmbeddedQuery:
    """Query embedded as the concatenation of a text vector and an image vector."""

    def __init__(self, text_embed: list[float], image_embed: list[float]):
        # Concatenate the two vectors into a single flat embedding
        self.embedding = torch.cat([
            torch.tensor(text_embed),
            torch.tensor(image_embed)
        ]).tolist()
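
# Example (sketch): concatenating a 384-dim MiniLM text vector with a 512-dim
# CLIP image vector yields a single 896-dim multimodal embedding:
#
#   combined = MultimodalEmbeddedQuery(text_embed=[0.0] * 384, image_embed=[0.0] * 512)
#   assert len(combined.embedding) == 896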


class MultimodalEmbeddingDispatcher:
    @staticmethod
    def dispatch(query: Query) -> EmbeddedQuery:
        # ImageEmbedder falls back to the text embedder internally when CLIP is
        # unavailable and always returns a 512-dim vector, so the query embedding
        # stays dimensionally consistent with the stored CLIP image embeddings.
        embedder = ImageEmbedder()
        embedding = embedder.encode_text(query.content)

        return EmbeddedQuery(
            embedding=embedding,
            content=query.content,
            metadata=query.metadata,
        )
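
# Example usage (sketch; assumes Query and EmbeddedQuery from
# llm_engineering.domain.queries can be constructed with `content` and
# `metadata` keyword arguments):
#
#   query = Query(content="red leather office chair", metadata={})
#   embedded = MultimodalEmbeddingDispatcher.dispatch(query)
#   # embedded.embedding is a 512-dim vector that can be searched against the
#   # CLIP image collection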


class ImageEmbedder:
    """CLIP-based image/text embedder producing 512-dim vectors, with a
    sentence-transformer text fallback when CLIP is unavailable."""

    def __init__(self, model_name: str = "openai/clip-vit-base-patch32"):
        # Always initialize the fallback embedder first so it exists even if CLIP loading fails
        print("Initializing fallback TextEmbedder")
        self.fallback_embedder = TextEmbedder()
        self.device = "cpu"

        if not TRANSFORMERS_AVAILABLE:
            print("Transformers not available - using fallback text embedder")
            self.model = None
            self.processor = None
            return

        try:
            print("Loading CLIP model: {}".format(model_name))
            self.model = CLIPModel.from_pretrained(model_name).to(self.device)
            self.processor = CLIPProcessor.from_pretrained(model_name)
            print("CLIP model loaded successfully")
        except Exception as e:
            logging.warning("Failed to load CLIP model: {}".format(e))
            self.model = None
            self.processor = None
            print("Creating fallback text embedder due to CLIP load failure: {}".format(e))

    def encode(self, image_path: str) -> list[float]:
        """Image embedding (512-dim)"""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("Using placeholder embedding (512-dim) due to missing CLIP model")
            # Return a placeholder embedding of the right size (512)
            return [0.0] * 512
            
        try:
            print("Loading image from: {}".format(image_path))
            image = Image.open(image_path).convert("RGB")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)
            with torch.no_grad():
                output = self.model.get_image_features(**inputs)[0].cpu().numpy().tolist()
                if len(output) != 512:
                    print("Warning: CLIP model output has {} dimensions, normalizing to 512".format(len(output)))
                    if len(output) < 512:
                        output = output + [0.0] * (512 - len(output))
                    else:
                        output = output[:512]
                return output
        except Exception as e:
            logging.warning("Failed to encode image: {}".format(e))
            print("Returning zero embedding (512-dim) due to encoding error: {}".format(e))
            return [0.0] * 512

    def encode_text(self, text: str) -> list[float]:
        """Text embedding using CLIP's text encoder (512-dim)"""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("CLIP not available, using fallback text embedder")
            return self._get_normalized_text_embedding(text)
            
        try:
            # Clean and preprocess the text for CLIP
            try:
                # Clean the text - remove special characters that might cause problems
                # Remove excessive whitespace, newlines, etc.
                text = re.sub(r'\s+', ' ', text).strip()
                # Remove or replace problematic characters
                text = re.sub(r'[^\w\s.,!?\'"-]', '', text)
                
                # Limit text length aggressively to avoid tokenization issues
                if len(text) > 300:  # CLIP has limited context window
                    print("Text too long for CLIP ({}), truncating to 300 chars".format(len(text)))
                    text = text[:300]  # Truncate to avoid tensor size issues
                
                print("Cleaned text for CLIP: {}...".format(text[:50] if len(text) > 50 else text))
            except Exception as text_clean_error:
                print("Error cleaning text: {}. Using fallback.".format(text_clean_error))
                # Just truncate if cleaning fails
                if len(text) > 300:
                    text = text[:300]
            
            # Try to encode with CLIP with explicit max length
            try:
                # Use explicit max_length to avoid tensor size mismatches
                inputs = self.processor(
                    text=text,
                    return_tensors="pt",
                    padding="max_length",
                    max_length=77,  # CLIP's standard context length
                    truncation=True
                ).to(self.device)
                
                with torch.no_grad():
                    output = self.model.get_text_features(**inputs)[0].cpu().numpy().tolist()
                    if len(output) != 512:
                        print("Normalizing CLIP output from {} to 512 dimensions".format(len(output)))
                        if len(output) < 512:
                            output = output + [0.0] * (512 - len(output))
                        else:
                            output = output[:512]
                    return output
            except RuntimeError as e:
                print("CLIP encoding error: {}".format(e))
                if "size mismatch" in str(e) or "dimension" in str(e).lower():
                    print("Tensor size mismatch in CLIP, using fallback")
                    return self._get_normalized_text_embedding(text)
                raise
        except Exception as e:
            logging.warning("Failed to encode text with CLIP: {}".format(e))
            print("Using fallback text embedder due to error: {}".format(e))
            return self._get_normalized_text_embedding(text)
    
    def _get_normalized_text_embedding(self, text: str) -> list[float]:
        """Helper to get normalized text embeddings from the fallback embedder"""
        try:
            if self.fallback_embedder is None:
                print("Fallback embedder is None, initializing...")
                self.fallback_embedder = TextEmbedder()
                
            embed = self.fallback_embedder.encode(text)
            # Ensure 512 dimensions for compatibility
            if len(embed) < 512:
                print("Padding fallback embedding from {} to 512 dimensions".format(len(embed)))
                embed = embed + [0.0] * (512 - len(embed))
            elif len(embed) > 512:
                print("Truncating fallback embedding from {} to 512 dimensions".format(len(embed)))
                embed = embed[:512]
            return embed
        except Exception as e:
            print("Error in fallback embedding: {}".format(e))
            # Last resort: return zeros
            return [0.0] * 512
    
    def encode_batch(self, image_paths: list[str]) -> list[list[float]]:
        """Batch image embedding (512-dim per image, placeholders on failure)"""
        if not TRANSFORMERS_AVAILABLE or self.model is None:
            print("CLIP not available for batch encoding, returning placeholders")
            # Return placeholder embeddings
            return [[0.0] * 512 for _ in range(len(image_paths))]
            
        try:
            print("Batch encoding {} images with CLIP".format(len(image_paths)))
            with torch.inference_mode():
                images = []
                for path in image_paths:
                    try:
                        img = Image.open(path).convert("RGB")
                        images.append(img)
                    except Exception as e:
                        print("Error opening image {}: {}".format(path, e))
                        # Add a black image as placeholder
                        images.append(Image.new('RGB', (224, 224), color='black'))
                
                if not images:
                    print("No valid images to process")
                    return [[0.0] * 512 for _ in image_paths]
                    
                inputs = self.processor(images=images, return_tensors="pt").to(self.device)
                outputs = self.model.get_image_features(**inputs).cpu().numpy().tolist()
                
                # Ensure each output has 512 dimensions
                normalized_outputs = []
                for output in outputs:
                    if len(output) != 512:
                        if len(output) < 512:
                            output = output + [0.0] * (512 - len(output))
                        else:
                            output = output[:512]
                    normalized_outputs.append(output)
                
                return normalized_outputs
        except Exception as e:
            logging.warning("Failed to batch encode images: {}".format(e))
            print("Returning placeholder embeddings due to batch encoding error: {}".format(e))
            return [[0.0] * 512 for _ in range(len(image_paths))]
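

# Minimal smoke test (sketch; assumes a local image at the hypothetical path
# "sample.jpg" and that the CLIP weights can be downloaded on first run):
if __name__ == "__main__":
    image_embedder = ImageEmbedder()

    # Text query embedding via CLIP's text encoder (falls back to the
    # sentence-transformer embedder, padded to 512 dims, if CLIP is missing)
    text_vector = image_embedder.encode_text("a pair of running shoes")
    print("Text embedding dimensions: {}".format(len(text_vector)))  # expected: 512

    # Image embedding; a zero vector is returned if the file cannot be read
    image_vector = image_embedder.encode("sample.jpg")
    print("Image embedding dimensions: {}".format(len(image_vector)))  # expected: 512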