import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import json
import csv

# Models kept on the plot even when they are not on the Pareto frontier
# (matched as substrings of 'Model Path').
_DEFAULT_WHITELIST = (
    'Meta-Llama-3.1-70B-Instruct',
)

# Model paths containing any of these substrings are labelled with the full
# path instead of only its last '/'-separated component.
_FULL_NAME_PATTERNS = ('gpt-3', 'gpt-4')

# x position (billions of parameters) of the consumer-GPU marker: 24 GB of
# VRAM at 2 bytes per parameter (half precision) holds about 12B parameters.
_CONSUMER_GPU_LIMIT_B = 12


def _read_benchmark_csv(csv_path):
    """Load the benchmark results CSV as a DataFrame of whitespace-stripped strings.

    Rows whose field count differs from the header's are skipped, so a
    malformed line cannot break DataFrame construction.

    Raises:
        ValueError: if the file is empty (has no header row).
    """
    # The csv module requires newline='' so quoted fields spanning lines parse correctly.
    with open(csv_path, 'r', newline='', encoding='utf-8') as f:
        reader = csv.reader(f)
        header = next(reader, None)
        if header is None:
            raise ValueError(f"{csv_path} is empty; expected a header row")
        header = [name.strip() for name in header]
        # Derive the expected width from the header instead of hard-coding it,
        # so every kept row always matches the DataFrame's column count.
        rows = [
            [value.strip() for value in row]
            for row in reader
            if len(row) == len(header)
        ]
    return pd.DataFrame(rows, columns=header)


def _latest_successful_runs(df, num_questions):
    """Return one row per 'Model Path': its most recent run with a usable score.

    'Benchmark Score' becomes a float multiplied by
    'Num Questions Parseable' / num_questions. Runs marked 'FAILED', or whose
    score or parseable count is not numeric, are discarded *before* the latest
    run is chosen, so an unusable newer run cannot hide a usable older one.
    """
    df = df[df['Benchmark Score'] != 'FAILED']
    raw_score = pd.to_numeric(df['Benchmark Score'], errors='coerce')
    parseable = pd.to_numeric(df['Num Questions Parseable'], errors='coerce')
    # assign() returns a new frame, avoiding chained assignment on a filtered slice.
    df = df.assign(**{
        'Benchmark Score': raw_score * (parseable / num_questions),
        'Num Questions Parseable': parseable,
    })
    df = df[df['Benchmark Score'].notna()]
    df = df.assign(timestamp=pd.to_datetime(df['Benchmark Completed']))
    # Stable sort with undated runs (NaT) first: keep='last' then selects the
    # newest dated run, and equal timestamps fall back to file order (later row wins).
    df = df.sort_values('timestamp', kind='stable', na_position='first')
    return df.drop_duplicates(subset=['Model Path'], keep='last')


def _lookup_model_size(metadata, model_path):
    """Return metadata's size entry for model_path, or None if it has none.

    The exact path is tried first, then the path with a ',max_length=4096' suffix.
    """
    for key in (model_path, f"{model_path},max_length=4096"):
        if key in metadata:
            return metadata[key]
    return None


def _drop_iqr_outliers(df, column):
    """Keep rows whose column lies within [Q1 - 1.5*IQR, Q3 + 1.5*IQR] (inclusive)."""
    q1, q3 = df[column].quantile([0.25, 0.75])
    iqr = q3 - q1
    return df[df[column].between(q1 - 1.5 * iqr, q3 + 1.5 * iqr)]


def _pareto_frontier(df):
    """Return (points, model_paths) for the size-vs-score Pareto frontier.

    Walking sizes in ascending order, a size joins the frontier when its best
    model strictly beats every smaller model seen so far. Among models tied for
    that best score, the first one in df's row order is recorded.
    """
    points = []
    models = set()
    best_so_far = float('-inf')
    # groupby yields sizes in ascending order and keeps row order within each
    # group: one pass instead of re-filtering the frame once per size.
    for size, group in df.groupby('Model Size', sort=True):
        best_idx = group['Benchmark Score'].idxmax()
        best_score = group.at[best_idx, 'Benchmark Score']
        if best_score > best_so_far:
            best_so_far = best_score
            points.append((size, best_score))
            models.add(group.at[best_idx, 'Model Path'])
    return points, models


def _draw_plot(df, frontier_points):
    """Scatter score vs size for df, label every point and overlay the frontier line."""
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.scatter(df['Model Size'], df['Benchmark Score'], alpha=0.6)

    for path, size, score in zip(df['Model Path'], df['Model Size'], df['Benchmark Score']):
        if any(pattern in path for pattern in _FULL_NAME_PATTERNS):
            label = path
        else:
            label = path.split('/')[-1]
        ax.annotate(label, (size, score),
                    xytext=(5, 5), textcoords='offset points', fontsize=8,
                    bbox=dict(facecolor='white', alpha=0.7, edgecolor='none', pad=0.5))

    if frontier_points:
        frontier_x, frontier_y = zip(*frontier_points)
        ax.plot(frontier_x, frontier_y, 'r--', label='Pareto frontier')

    # The marker line extends below the axes (ymin < 0, clip_on=False) down to
    # its caption, which is placed in data-x / axes-fraction-y coordinates.
    ax.axvline(x=_CONSUMER_GPU_LIMIT_B, color='gray', linestyle=':',
               label='Consumer-budget GPU limit', ymin=-0.15, clip_on=False)
    ax.text(_CONSUMER_GPU_LIMIT_B, -0.15, 'Consumer-budget\nGPU (24GB) limit\nin half precision',
            horizontalalignment='center', verticalalignment='top',
            transform=ax.get_xaxis_transform())

    ax.grid(True, linestyle='--', alpha=0.7)
    ax.set_xlabel('Model Size (billions of parameters)')
    ax.set_ylabel('Benchmark Score')
    ax.set_title('Model Performance vs Size (Pareto Frontier)')
    ax.legend()
    # Adjust layout to prevent label cutoff.
    fig.tight_layout()
    return fig


def create_performance_plot(csv_path='benchmark_results.csv', metadata_path='metadata.json',
                            whitelist=None, num_questions=171):
    """Plot benchmark score against model size and highlight the Pareto frontier.

    Only the latest usable run of each model is considered, IQR outliers are
    removed, and the plot keeps models on the frontier plus whitelisted ones.

    Args:
        csv_path: Benchmark results CSV. Must contain the columns 'Model Path',
            'Benchmark Score', 'Num Questions Parseable' and 'Benchmark Completed'.
        metadata_path: JSON object mapping model path to model size (billions of
            parameters). Models without an entry are printed and left out.
        whitelist: Substrings of 'Model Path'; matching models are plotted even
            when not on the frontier. Defaults to _DEFAULT_WHITELIST.
        num_questions: Divisor for 'Num Questions Parseable' when scaling the
            score (presumably the benchmark's total question count).

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: if the CSV file is empty.
    """
    if whitelist is None:
        whitelist = _DEFAULT_WHITELIST

    df = _latest_successful_runs(_read_benchmark_csv(csv_path), num_questions)

    with open(metadata_path, 'r', encoding='utf-8') as f:
        metadata = json.load(f)

    # Look each size up once and reuse it for both the report and the filter.
    sizes = df['Model Path'].apply(lambda path: _lookup_model_size(metadata, path))
    print("\nModels without size assigned:")
    for model in df.loc[sizes.isna(), 'Model Path']:
        print(f"- {model}")
    df = df.assign(**{'Model Size': sizes})
    df = df[df['Model Size'].notna()]

    # Remove extreme outliers (scores that are clearly errors).
    df = _drop_iqr_outliers(df, 'Benchmark Score')

    frontier_points, frontier_models = _pareto_frontier(df)

    # Keep models on the frontier plus whitelisted ones. astype(bool) keeps the
    # mask boolean even when df is empty (an empty object mask would be read
    # as a column selection).
    keep = df['Model Path'].isin(frontier_models) | df['Model Path'].apply(
        lambda path: any(pattern in path for pattern in whitelist))
    df = df[keep.astype(bool)]

    return _draw_plot(df, frontier_points)

if __name__ == "__main__":
    # Script entry point: build the plot from the default input files and
    # write it next to them as a high-resolution PNG.
    figure = create_performance_plot()
    figure.savefig('model_performance.png', bbox_inches='tight', dpi=300)