# utils/visualize_pca.py
import os
import tempfile
import logging
from functools import lru_cache
from typing import Tuple, Optional, Union, List

import torch
from optipfair.bias import visualize_pca, visualize_mean_differences, visualize_heatmap
from transformers import AutoModelForCausalLM, AutoTokenizer

import matplotlib
# Select the non-interactive 'Agg' backend so figures can be rendered without a GUI/display.
# NOTE(review): this runs after `optipfair.bias` has been imported above; if that
# import already initialises a matplotlib backend, consider moving this earlier — confirm.
matplotlib.use('Agg')

# Module-level logger; the level is forced to INFO so the progress messages
# emitted by the functions below are not filtered out at this logger.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@lru_cache(maxsize=None)
def load_model_tokenizer(model_name: str):
    """
    Load a causal language model and its tokenizer, caching the pair per model name.

    The model is moved to the first available device in this order:
    CUDA, then MPS (Apple Silicon), then CPU. Because of ``lru_cache``, a
    second call with the same ``model_name`` returns the already-loaded objects.

    Args:
        model_name: Identifier passed to ``from_pretrained`` for both the
            model and the tokenizer.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # NOTE: the cache is unbounded, so every distinct model_name requested
    # stays referenced (and resident in memory) for the life of the process.
    logger.info("Loading model and tokenizer for '%s'", model_name)

    # Optional Hugging Face token (needed for gated models); None if unset.
    hf_token = os.getenv("HF_TOKEN")

    # Device selection: CUDA > MPS (Apple Silicon) > CPU.
    # `torch.backends.mps.is_available()` is the documented MPS probe; the
    # getattr guard keeps this working on torch builds without that backend.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = torch.device("cuda")
    elif mps_backend is not None and mps_backend.is_available():
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    logger.info("Using device: %s", device)

    # The token is forwarded to both loaders so gated repositories can be read.
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        token=hf_token,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_name,
        token=hf_token,
    )

    model = model.to(device)

    # Report the device the parameters actually ended up on.
    logger.info("Model loaded on device: %s", next(model.parameters()).device)

    return model, tokenizer

def run_visualize_pca(
    model_name: str,
    prompt_pair: Tuple[str, str],
    layer_key: str,
    highlight_diff: bool = True,
    output_dir: Optional[str] = None,
    figure_format: str = "png",
    pair_index: int = 0,
) -> str:
    """
    Run the PCA visualization for one prompt pair and return the image path.

    Loads (or reuses) the model/tokenizer for ``model_name``, calls
    ``visualize_pca`` with the given arguments, then checks that the file
    named by ``build_visualization_filename`` exists in the output directory.

    Args:
        model_name: Model identifier handed to ``load_model_tokenizer``.
        prompt_pair: The two prompts forwarded to ``visualize_pca``.
        layer_key: Layer identifier; it is split at its last underscore
            into a layer type and a layer number to build the filename.
        highlight_diff: Forwarded unchanged to ``visualize_pca``.
        output_dir: Directory for the image. When ``None``, a new temporary
            directory with the prefix ``optipfair_pca_`` is created.
        figure_format: File extension / format forwarded to ``visualize_pca``.
        pair_index: Index forwarded to ``visualize_pca`` and used in the filename.

    Returns:
        Path of the image file inside the output directory.

    Raises:
        FileNotFoundError: If the expected image file is not present after
            the visualization call.
    """
    # No directory given -> use a fresh temporary one.
    target_dir = (
        tempfile.mkdtemp(prefix="optipfair_pca_")
        if output_dir is None
        else output_dir
    )
    os.makedirs(target_dir, exist_ok=True)

    model, tokenizer = load_model_tokenizer(model_name)

    render_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        prompt_pair=prompt_pair,
        layer_key=layer_key,
        highlight_diff=highlight_diff,
        output_dir=target_dir,
        figure_format=figure_format,
        pair_index=pair_index,
    )
    visualize_pca(**render_kwargs)

    # Everything before the last underscore is the layer type, the rest the number.
    kind, _, number = layer_key.rpartition("_")
    image_name = build_visualization_filename(
        vis_type="pca",
        layer_type=kind,
        layer_num=number,
        pair_index=pair_index,
        figure_format=figure_format,
    )
    image_path = os.path.join(target_dir, image_name)

    if os.path.isfile(image_path):
        logger.info(f"PCA image saved at {image_path}")
        return image_path

    raise FileNotFoundError(f"Expected image file not found: {image_path}")

def run_visualize_mean_diff(
    model_name: str,
    prompt_pair: Tuple[str, str],
    layer_type: str,
    figure_format: str = "png",
    output_dir: Optional[str] = None,
    pair_index: int = 0,
) -> str:
    """
    Run the mean-differences visualization and return the image path.

    Loads (or reuses) the model/tokenizer for ``model_name``, calls
    ``visualize_mean_differences`` with ``layers="all"``, then checks that
    the file named by ``build_visualization_filename`` exists.

    Args:
        model_name: Model identifier handed to ``load_model_tokenizer``.
        prompt_pair: The two prompts forwarded to the visualization.
        layer_type: Layer type forwarded to the visualization and used in
            the filename (this function takes a type, not a full layer key).
        figure_format: File extension / format forwarded to the visualization.
        output_dir: Directory for the image. When ``None``, a new temporary
            directory with the prefix ``optipfair_mean_diff_`` is created.
        pair_index: Index forwarded to the visualization and used in the filename.

    Returns:
        Path of the image file inside the output directory.

    Raises:
        FileNotFoundError: If the expected image file is not present after
            the visualization call.
    """
    # No directory given -> use a fresh temporary one.
    target_dir = (
        tempfile.mkdtemp(prefix="optipfair_mean_diff_")
        if output_dir is None
        else output_dir
    )
    os.makedirs(target_dir, exist_ok=True)

    model, tokenizer = load_model_tokenizer(model_name)

    render_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        prompt_pair=prompt_pair,
        layer_type=layer_type,
        layers="all",  # always plot every layer
        output_dir=target_dir,
        figure_format=figure_format,
        pair_index=pair_index,
    )
    visualize_mean_differences(**render_kwargs)

    image_name = build_visualization_filename(
        vis_type="mean_diff",
        layer_type=layer_type,
        pair_index=pair_index,
        figure_format=figure_format,
    )
    image_path = os.path.join(target_dir, image_name)

    if os.path.isfile(image_path):
        logger.info(f"Mean-diff image saved at {image_path}")
        return image_path

    raise FileNotFoundError(f"Expected image file not found: {image_path}")

def run_visualize_heatmap(
    model_name: str,
    prompt_pair: Tuple[str, str],
    layer_key: str,
    figure_format: str = "png",
    output_dir: Optional[str] = None,
    pair_index: int = 0,
) -> str:
    """
    Run the heatmap visualization for one prompt pair and return the image path.

    Loads (or reuses) the model/tokenizer for ``model_name``, calls
    ``visualize_heatmap`` with the given arguments, then checks that the
    file named by ``build_visualization_filename`` exists.

    Args:
        model_name: Model identifier handed to ``load_model_tokenizer``.
        prompt_pair: The two prompts forwarded to ``visualize_heatmap``.
        layer_key: Layer identifier; it is split at its last underscore
            into a layer type and a layer number to build the filename.
        figure_format: File extension / format forwarded to ``visualize_heatmap``.
        output_dir: Directory for the image. When ``None``, a new temporary
            directory with the prefix ``optipfair_heatmap_`` is created.
        pair_index: Index forwarded to ``visualize_heatmap`` and used in the filename.

    Returns:
        Path of the image file inside the output directory.

    Raises:
        FileNotFoundError: If the expected image file is not present after
            the visualization call.
    """
    # No directory given -> use a fresh temporary one.
    target_dir = (
        tempfile.mkdtemp(prefix="optipfair_heatmap_")
        if output_dir is None
        else output_dir
    )
    os.makedirs(target_dir, exist_ok=True)

    model, tokenizer = load_model_tokenizer(model_name)

    render_kwargs = dict(
        model=model,
        tokenizer=tokenizer,
        prompt_pair=prompt_pair,
        layer_key=layer_key,
        output_dir=target_dir,
        figure_format=figure_format,
        pair_index=pair_index,
    )
    visualize_heatmap(**render_kwargs)

    # Everything before the last underscore is the layer type, the rest the number.
    kind, _, number = layer_key.rpartition("_")
    image_name = build_visualization_filename(
        vis_type="heatmap",
        layer_type=kind,
        layer_num=number,
        pair_index=pair_index,
        figure_format=figure_format,
    )
    image_path = os.path.join(target_dir, image_name)

    if os.path.isfile(image_path):
        logger.info(f"Heatmap image saved at {image_path}")
        return image_path

    raise FileNotFoundError(f"Expected image file not found: {image_path}")

def build_visualization_filename(
    vis_type: str,
    layer_type: str,
    layer_num: Optional[str] = None,
    layers: Optional[Union[str, List[int]]] = None,
    pair_index: int = 0,
    figure_format: str = "png"
) -> str:
    """
    Build the image filename for a visualization.

    Args:
        vis_type: One of ``"mean_diff"``, ``"pca"`` or ``"heatmap"``.
        layer_type: Layer type component of the filename.
        layer_num: Layer number component; required for ``"pca"`` and
            ``"heatmap"``, ignored for ``"mean_diff"``.
        layers: Currently unused; kept so existing callers that pass it
            keep working.
        pair_index: Prompt-pair index embedded as ``pair{pair_index}``.
        figure_format: File extension (without the dot).

    Returns:
        ``mean_diff_{layer_type}_pair{pair_index}.{figure_format}`` for
        ``"mean_diff"``, otherwise
        ``{vis_type}_{layer_type}_{layer_num}_pair{pair_index}.{figure_format}``.

    Raises:
        ValueError: If ``vis_type`` is unknown, or if ``layer_num`` is
            missing for a visualization type that needs it.
    """
    if vis_type == "mean_diff":
        # The visualize_mean_differences function does not include the layer number in the filename
        return f"mean_diff_{layer_type}_pair{pair_index}.{figure_format}"

    if vis_type in ("pca", "heatmap"):
        # Without this check the name would silently contain the text "None".
        if layer_num is None:
            raise ValueError(
                f"layer_num is required for visualization type: {vis_type}"
            )
        return f"{vis_type}_{layer_type}_{layer_num}_pair{pair_index}.{figure_format}"

    raise ValueError(f"Unknown visualization type: {vis_type}")