import os
import logging
import traceback
from typing import Dict, List, Any

from nemo_skills.inference.server.code_execution_model import get_code_execution_model
from nemo_skills.code_execution.sandbox import get_sandbox
from nemo_skills.prompt.utils import get_prompt

# Configure logging at import time: give the root logger a handler at INFO
# level (basicConfig does nothing if the root logger already has handlers)
# and create this module's logger, used by EndpointHandler.
logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)


class EndpointHandler:
    """Custom endpoint handler for NeMo Skills code execution inference.

    Construction only reads configuration from the environment; the sandbox,
    the code-execution model client and the prompt are created lazily on the
    first request (see ``_initialize_components``).
    """

    def __init__(self, path: str = ""):
        """
        Initialize the handler with the model and prompt configurations.

        Args:
            path: Model directory. Accepted (and ignored) so the handler can
                be constructed as ``EndpointHandler(path)``, the form used by
                Hugging Face custom inference handlers; calling with no
                arguments keeps working as before.
        """
        self.model = None
        self.prompt = None
        self.initialized = False

        # Configuration. An empty PROMPT_CONFIG_PATH disables prompt
        # formatting: inputs are then passed to the model unformatted.
        self.prompt_config_path = os.getenv("PROMPT_CONFIG_PATH", "generic/math")
        self.prompt_template_path = os.getenv("PROMPT_TEMPLATE_PATH", "openmath-instruct")

    def _initialize_components(self):
        """Initialize the sandbox, code execution model and prompt lazily.

        A no-op once initialization has succeeded. On failure the error is
        logged with its traceback and re-raised, so the caller reports the
        real cause instead of a later, misleading
        ``'NoneType' object has no attribute 'generate'``. ``initialized``
        stays False in that case, so the next call retries from scratch.

        Raises:
            Exception: whatever the sandbox, model or prompt factories raise.
        """
        if self.initialized:
            return

        try:
            logger.info("Initializing sandbox...")
            sandbox = get_sandbox(sandbox_type="local")

            logger.info("Initializing code execution model...")
            model = get_code_execution_model(
                server_type="vllm",
                sandbox=sandbox,
                host="127.0.0.1",
                port=5000,
            )

            logger.info("Initializing prompt...")
            prompt = None
            if self.prompt_config_path:
                prompt = get_prompt(
                    prompt_config=self.prompt_config_path,
                    prompt_template=self.prompt_template_path,
                )
        except Exception:
            logger.exception("Failed to initialize the model")
            raise

        # Publish the components only after every one was built, so a
        # partial failure never leaves a half-initialized handler behind.
        self.model = model
        self.prompt = prompt
        self.initialized = True
        logger.info("All components initialized successfully")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        Process inference requests.

        Args:
            data: Dictionary containing the request data
                Expected keys:
                - inputs: str or list of str - the input prompts/problems
                - parameters: dict (optional) - generation parameters; they
                  override the defaults below but are themselves overridden
                  by the prompt's code execution arguments. ``None`` is
                  treated the same as an absent key.

        Returns:
            List of dictionaries containing the generated responses, one per
            model output, each with ``generated_text`` and
            ``code_rounds_executed`` keys. An empty ``inputs`` list yields an
            empty list. Errors (including initialization failures) are not
            raised; they are returned as
            ``[{"error": <message>, "generated_text": ""}]``.
        """
        try:
            # Initialize components if not already done; raises on failure.
            self._initialize_components()

            # Extract inputs and parameters (a JSON null counts as absent).
            inputs = data.get("inputs", "")
            parameters = data.get("parameters")
            if parameters is None:
                parameters = {}
            elif not isinstance(parameters, dict):
                raise ValueError("parameters must be a dictionary")

            # Handle both single string and list of strings
            if isinstance(inputs, str):
                prompts = [inputs]
            elif isinstance(inputs, list):
                prompts = list(inputs)
            else:
                raise ValueError("inputs must be a string or list of strings")

            if not prompts:
                # Nothing to generate; don't call the model at all.
                return []

            # If a prompt is configured, format the inputs with it and pick up
            # the code execution arguments it supplies.
            extra_generate_params = {}
            if self.prompt is not None:
                prompts = [
                    self.prompt.fill({"problem": prompt_text, "total_code_executions": 8})
                    for prompt_text in prompts
                ]
                extra_generate_params = self.prompt.get_code_execution_args()

            # Precedence: defaults < request parameters < prompt code execution args.
            generation_params = {
                "tokens_to_generate": 12000,
                "temperature": 0.0,
                "top_p": 0.95,
                "top_k": 0,
                "repetition_penalty": 1.0,
                "random_seed": 0,
            }
            generation_params.update(parameters)
            generation_params.update(extra_generate_params)

            logger.info("Processing %d prompt(s)", len(prompts))

            outputs = self.model.generate(prompts=prompts, **generation_params)

            results = [
                {
                    "generated_text": output.get("generation", ""),
                    "code_rounds_executed": output.get("code_rounds_executed", 0),
                }
                for output in outputs
            ]

            logger.info("Successfully processed %d request(s)", len(results))
            return results

        except Exception as e:
            # Endpoint boundary: log with traceback, then report the error to
            # the client instead of propagating it.
            logger.exception("Error processing request: %s", e)
            return [{"error": str(e), "generated_text": ""}]