import tempfile
import time
import subprocess
import os
import json
from pathlib import Path
import concurrent.futures
from dotenv import load_dotenv
from datetime import datetime
import yaml
import argparse
from typing import Dict, Any
from tqdm import tqdm
from tools.lighteval.get_model_providers import get_model_providers

def _log_time() -> str:
    """Return the current local time formatted as ``HH:MM:SS`` for log prefixes."""
    return datetime.now().strftime('%H:%M:%S')


def _evaluation_record(model_name: str, provider: str, accuracy: float,
                       execution_time: float, status: str) -> dict:
    """Build the summary dictionary returned by ``run_lighteval``."""
    return {
        "model": model_name,
        "provider": provider,
        "accuracy": accuracy,
        "execution_time": execution_time,
        "status": status,
    }


def _newest_results_file(results_dir: Path, since: float) -> Path:
    """Return the newest ``results_*.json`` in *results_dir* modified at or after *since*.

    Files older than *since* come from earlier runs and are ignored, so a run that
    produced no output is not reported with stale numbers.

    Raises:
        FileNotFoundError: if no matching file was written at or after *since*
            (including when *results_dir* does not exist).
    """
    fresh = [
        path for path in results_dir.glob("results_*.json")
        if path.stat().st_mtime >= since
    ]
    if not fresh:
        raise FileNotFoundError(
            f"no results_*.json written since this run started in {results_dir}"
        )
    return max(fresh, key=lambda path: path.stat().st_mtime)


def run_lighteval(model_name: str, provider: str, timeout: float = 60.0) -> dict:
    """Run one lighteval "yourbench" evaluation of *model_name* through *provider*.

    A temporary custom-task module is written and passed to
    ``lighteval endpoint inference-providers`` (``--max-samples 3``). The accuracy is
    then read from the newest ``results_*.json`` written during this run under
    ``data/lighteval_results/results/<model_name>/``.

    Args:
        model_name: Model identifier passed to lighteval, e.g. ``"Qwen/QwQ-32B"``.
        provider: Inference provider name passed to lighteval.
        timeout: Seconds to wait for the lighteval subprocess before giving up.

    Returns:
        A dict with these keys:

        - ``model``
        - ``provider``
        - ``accuracy``: 0.0 on any failure.
        - ``execution_time``: in seconds.
        - ``status``: one of ``"success"``, ``"timeout"``, ``"error"`` (lighteval
          could not be launched) or ``"parse_error"`` (no usable results file from
          this run).
    """
    start_time = time.time()
    print(f"[{_log_time()}] Starting evaluation with {provider} provider for {model_name}")

    output_dir = "data/lighteval_results"

    # Custom-task module handed to lighteval via --custom-tasks.
    # NamedTemporaryFile creates the file safely; tempfile.mktemp is deprecated and
    # race-prone. delete=False keeps the file on disk for the subprocess, and the
    # `finally` below removes it on every exit path.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".py", prefix="yourbench_task_", delete=False
    ) as temp_file:
        temp_file.write("""
from lighteval_task.lighteval_task import create_yourbench_task

# Create yourbench task
yourbench = create_yourbench_task("yourbench/yourbench_test", "single_shot_questions")

# Define TASKS_TABLE needed by lighteval
TASKS_TABLE = [yourbench]
""")
        temp_file_path = temp_file.name

    cmd_args = [
        "lighteval",
        "endpoint",
        "inference-providers",
        f"model_name={model_name},provider={provider}",
        "custom|yourbench|0|0",
        "--custom-tasks",
        temp_file_path,
        "--max-samples", "3",
        "--output-dir", output_dir,
        "--no-push-to-hub",
    ]

    try:
        completed = subprocess.run(cmd_args, env=os.environ, timeout=timeout)
    except subprocess.TimeoutExpired:
        # subprocess.run kills and reaps the child before re-raising TimeoutExpired.
        elapsed = time.time() - start_time
        print(f"[{_log_time()}] Evaluation timed out for {model_name} after {elapsed:.2f}s")
        return _evaluation_record(model_name, provider, 0.0, elapsed, "timeout")
    except OSError as e:
        # For example, the lighteval executable is not on PATH.
        elapsed = time.time() - start_time
        print(f"[{_log_time()}] Could not launch lighteval for {model_name}: {e}")
        return _evaluation_record(model_name, provider, 0.0, elapsed, "error")
    finally:
        os.unlink(temp_file_path)

    execution_time = time.time() - start_time
    print(f"[{_log_time()}] Finished evaluation for {model_name} in {execution_time:.2f}s")
    if completed.returncode != 0:
        print(f"[{_log_time()}] lighteval exited with code {completed.returncode} for {model_name}")

    # An "org/model" name yields a nested org/model directory under results/.
    results_dir = Path(output_dir) / "results" / model_name
    try:
        results_file = _newest_results_file(results_dir, since=start_time)
        with open(results_file, encoding="utf-8") as f:
            results = json.load(f)
        accuracy = results["results"]["all"]["accuracy"]
    except (OSError, ValueError, KeyError, TypeError) as e:
        # OSError: no fresh file / unreadable; ValueError: bad JSON;
        # KeyError/TypeError: unexpected results schema.
        print(f"[{_log_time()}] Failed to parse results for {model_name} after {execution_time:.2f}s: {str(e)}")
        return _evaluation_record(model_name, provider, 0.0, execution_time, "parse_error")

    return _evaluation_record(model_name, provider, accuracy, execution_time, "success")

def main():
    """Evaluate a fixed list of models in parallel with lighteval and print a summary.

    Steps:

    1. Load environment variables from a ``.env`` file (``load_dotenv``).
    2. For every model that ``get_model_providers`` reports at least one provider
       for, run ``run_lighteval`` with the first listed provider.
    3. Skip models without a provider, with a message.

    Results are printed in submission order, not completion order. A crashing
    evaluation is recorded with status ``"error"`` instead of aborting the run.
    """
    # Start global timer
    script_start_time = time.time()

    # Load environment variables from a .env file
    load_dotenv()

    # Models to evaluate
    models = [
        "Qwen/QwQ-32B",
        "Qwen/Qwen2.5-72B-Instruct",
        "deepseek-ai/DeepSeek-V3-0324",
        "deepseek-ai/DeepSeek-R1-Distill-Llama-70B",
    ]

    # Consumed below as (model_name, providers) pairs. NOTE(review): presumably the
    # pairs follow the order of `models`; confirm against get_model_providers.
    model_providers = get_model_providers(models)

    jobs = []
    for model_name, providers in model_providers:
        if providers:
            jobs.append((model_name, providers[0]))
        else:
            print(f"[{datetime.now().strftime('%H:%M:%S')}] Skipping {model_name}: no provider available")

    if not jobs:
        print(f"[{datetime.now().strftime('%H:%M:%S')}] No models with an available provider; nothing to evaluate")
        return

    print(f"[{datetime.now().strftime('%H:%M:%S')}] Starting parallel evaluations")

    # Each job mostly waits on a lighteval subprocess, so threads suffice.
    # Using one worker per job lets all evaluations run concurrently instead of queueing.
    with concurrent.futures.ThreadPoolExecutor(max_workers=len(jobs)) as executor:
        futures = [
            executor.submit(run_lighteval, model_name, provider)
            for model_name, provider in jobs
        ]
        # Collect in submission order (not completion order) so the summary is stable.
        results = []
        for (model_name, provider), future in zip(jobs, futures):
            try:
                results.append(future.result())
            except Exception as e:
                # Top-level boundary: keep the other models' results if one crashes.
                print(f"[{datetime.now().strftime('%H:%M:%S')}] Evaluation crashed for {model_name}: {e}")
                results.append({
                    "model": model_name,
                    "provider": provider,
                    "accuracy": 0.0,
                    "execution_time": 0.0,
                    "status": "error",
                })

    # Calculate total script execution time
    total_time = time.time() - script_start_time
    print(f"[{datetime.now().strftime('%H:%M:%S')}] All evaluations completed in {total_time:.2f}s")

    # Print results in submission order. The status distinguishes a real 0.00
    # accuracy from a failed run.
    print("\nResults:")
    print("-" * 80)
    for result in results:
        print(f"Model: {result['model']}")
        print(f"Provider: {result['provider']}")
        print(f"Status: {result['status']}")
        print(f"Accuracy: {result['accuracy']:.2f}")
        print(f"Execution time: {result['execution_time']:.2f}s")
        print("-" * 80)


if __name__ == "__main__":
    main()