from typing import List, Dict, Union, Optional
import tiktoken
from ..embedding_provider import EmbeddingProvider
import numpy as np

class OpenAIEmbedding(EmbeddingProvider):
    """Embedding provider backed by the OpenAI embeddings API.

    Every text is truncated to ``max_tokens`` tokens with a tiktoken
    encoding before it is sent to the API.
    """

    # Batch size used by ``embed_documents`` when the caller passes none;
    # also the value reported by ``get_embedding_info``.
    DEFAULT_BATCH_SIZE = 100

    # Encoding used when tiktoken cannot map the model name to a tokenizer.
    _FALLBACK_ENCODING = "cl100k_base"

    # USD per 1000 tokens. Pricing as of 2024 (subject to change).
    _PRICE_PER_1K_TOKENS = {
        "text-embedding-ada-002": 0.0001,
        "text-embedding-3-small": 0.00002,
        "text-embedding-3-large": 0.00013,
    }

    def __init__(
        self,
        api_key: Optional[str] = None,
        model: str = "text-embedding-3-small",
        max_tokens: int = 8191
    ) -> None:
        """Initialize OpenAI embedding provider

        Args:
            api_key (str, optional): API key for OpenAI, passed straight to
                the ``OpenAI`` client constructor.
            model (str, optional): Name of the embedding model. Defaults to
                "text-embedding-3-small".
                more info: https://platform.openai.com/docs/models#embeddings
            max_tokens (int, optional): Maximum number of tokens kept per
                text; longer texts are truncated. Defaults to 8191.

        Raises:
            ValueError: If ``max_tokens`` is smaller than 1.
        """
        # Imported lazily so the module can be imported without the SDK.
        from openai import OpenAI

        if max_tokens < 1:
            # A negative limit would make ``tokens[:max_tokens]`` drop tokens
            # from the END of the text instead of capping its length.
            raise ValueError("max_tokens must be a positive integer")

        self.client = OpenAI(api_key=api_key)
        self.model = model
        self.max_tokens = max_tokens
        try:
            self.tokenizer = tiktoken.encoding_for_model(model)
        except KeyError:
            # tiktoken raises KeyError for model names it has no mapping for;
            # fall back to a fixed encoding instead of failing construction.
            self.tokenizer = tiktoken.get_encoding(self._FALLBACK_ENCODING)

    def _trancated_text(self, text: str) -> str:
        """Truncate text into maximum token length

        The (misspelled) method name is kept so existing callers and
        subclasses keep working.

        Args:
            text (str): Input text

        Returns:
            str: ``text`` unchanged when it has at most ``max_tokens`` tokens,
                otherwise the decoded first ``max_tokens`` tokens.
        """
        # disallowed_special=() makes text such as "<|endoftext|>" encode as
        # ordinary text; with tiktoken's default it raises ValueError.
        tokens = self.tokenizer.encode(text, disallowed_special=())
        if len(tokens) <= self.max_tokens:
            # Nothing to cut: skip the decode round trip.
            return text
        return self.tokenizer.decode(tokens[:self.max_tokens])

    def embed_documents(
        self,
        documents: List[str],
        batch_size: int = DEFAULT_BATCH_SIZE
    ) -> np.ndarray:
        """Embed a list of documents

        Documents are truncated, then sent to the API in consecutive
        batches of ``batch_size``; results are concatenated in request order.

        Args:
            documents (List[str]): List of documents to embed
            batch_size (int, optional): Number of documents per API request.
                Defaults to ``DEFAULT_BATCH_SIZE``.

        Returns:
            np.ndarray: embeddings of documents, one row per document

        Raises:
            ValueError: If ``batch_size`` is smaller than 1.
        """
        if batch_size < 1:
            # range() rejects a zero step and yields nothing for a negative
            # one, which would silently return an empty result.
            raise ValueError("batch_size must be a positive integer")

        truncated_docs = [self._trancated_text(doc) for doc in documents]

        embeddings: List[List[float]] = []
        for start in range(0, len(truncated_docs), batch_size):
            response = self.client.embeddings.create(
                input=truncated_docs[start:start + batch_size],
                model=self.model
            )
            embeddings.extend(item.embedding for item in response.data)

        return np.array(embeddings)

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query string

        Args:
            query (str): Query text; truncated to ``max_tokens`` tokens first

        Returns:
            np.ndarray: embedding of the query
        """
        response = self.client.embeddings.create(
            input=[self._trancated_text(query)],
            model=self.model
        )
        return np.array(response.data[0].embedding)

    def get_embedding_info(self) -> Dict[str, Union[str, int]]:
        """
        Get information about the current embedding configuration

        Returns:
            Dict: Embedding configuration details with the keys "model",
                "max_tokens" and "batch_size" (the default batch size)
        """
        return {
            "model": self.model,
            "max_tokens": self.max_tokens,
            "batch_size": self.DEFAULT_BATCH_SIZE,  # Default batch size
        }

    def list_available_models(self) -> List[str]:
        """
        List available OpenAI embedding models

        The list is static; it is not fetched from the API.

        Returns:
            List[str]: Available embedding model names
        """
        return [
            "text-embedding-ada-002",  # Most common
            "text-embedding-3-small",  # Newer, more efficient
            "text-embedding-3-large"   # Highest quality
        ]

    def estimate_cost(
        self,
        num_documents: int,
        avg_tokens_per_document: int = 100
    ) -> float:
        """
        Estimate embedding cost

        Args:
            num_documents (int): Number of documents to embed
            avg_tokens_per_document (int, optional): Assumed average number
                of tokens per document. Defaults to 100.

        Returns:
            float: Estimated cost in USD
        """
        prices = self._PRICE_PER_1K_TOKENS
        # Models missing from the table are priced like text-embedding-ada-002.
        price_per_token = prices.get(
            self.model, prices["text-embedding-ada-002"]
        ) / 1000

        total_tokens = num_documents * avg_tokens_per_document
        return total_tokens * price_per_token