Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import plotly.express as px | |
| import plotly.graph_objs as go | |
| from collections import defaultdict | |
| import json, math, gdown | |
| import numpy as np | |
| import pandas as pd | |
| from Config import * | |
# Render floats with two decimals whenever pandas prints a frame or series.
pd.set_option('display.float_format', '{:.2f}'.format)

# 100 evenly spaced values spanning [0, 100].
# NOTE(review): not referenced by any code visible in this section — confirm
# whether it is still needed.
battles = np.linspace(start=0, stop=100, num=100)

# Meta-topic names; generate_plot() indexes this list with its meta_index.
meta_topics = ['mmlu']
def generate_plot(meta_index, topic_index):
    """
    Build a grouped bar chart of per-model accuracy for one sub-topic.

    ``meta_index`` selects an entry of the module-level ``meta_topics`` list
    and ``topic_index`` selects a sub-topic from ``TOPICS[meta_topic]``
    (``TOPICS`` is not defined in this block; presumably it is provided by
    ``from Config import *`` -- confirm).

    Rows of ``data/<meta_topic>/response_rec.csv`` whose ``sub_topic`` equals
    the chosen topic are aggregated per ``model_name``: the mean gives the bar
    height and the standard deviation gives the error bar, for each of the
    oracle, human and LLM accuracies.

    Returns the resulting plotly ``Figure``.
    """
    meta_topic = meta_topics[meta_index]
    print(meta_topic)
    topic = TOPICS[meta_topic][topic_index]

    records = pd.read_csv(f"data/{meta_topic}/response_rec.csv", sep=",")
    # Work on a copy so the derived columns below do not touch the full frame.
    subset = records.loc[records['sub_topic'] == topic].copy()

    # Derived accuracy = correct / responses. A zero response count is turned
    # into NaN first, so the ratio is NaN instead of a division by zero.
    ratio_specs = (
        ('human_acc', 'no_correct_human', 'no_responses_human'),
        ('llm_acc', 'no_correct_llm', 'no_responses_llm'),
    )
    for target_col, correct_col, total_col in ratio_specs:
        denominator = subset[total_col].replace(0, np.nan)
        subset[target_col] = subset[correct_col] / denominator

    # Restrict aggregation to the numeric columns.
    numeric_cols = [
        'no_responses_human',
        'no_correct_human',
        'no_responses_llm',
        'no_correct_llm',
        'oracle_acc',
        'human_acc',
        'llm_acc',
    ]
    per_model = subset.groupby('model_name')[numeric_cols]
    means = per_model.mean().reset_index()
    spreads = per_model.std().reset_index()

    # One colour per accuracy series; insertion order fixes the bar order.
    series_colors = {
        'oracle_acc': '#FFA07A',  # Light Salmon
        'human_acc': '#20B2AA',   # Light Sea Green
        'llm_acc': '#778899',     # Light Slate Gray
    }

    # One bar trace per accuracy series; the legend label is the part of the
    # column name before the underscore, capitalised (e.g. "Oracle").
    traces = [
        go.Bar(
            x=means['model_name'],
            y=means[metric],
            error_y={
                'type': 'data',
                'array': spreads[metric],
                'visible': True,
            },
            name=metric.split('_')[0].capitalize(),
            marker={'color': shade},
        )
        for metric, shade in series_colors.items()
    ]

    layout = go.Layout(
        title=f"Accuracy for {meta_topic} ({topic})",
        xaxis={'title': 'Model Name'},
        yaxis={'title': 'Accuracy'},
        showlegend=True,
        legend={'title': 'Accuracy Type'},
        barmode='group',
    )
    return go.Figure(data=traces, layout=layout)
# Gradio interface: a grid of three rows with two plots each.
# NOTE(review): every cell currently renders generate_plot(0, 0); earlier
# commented-out code hinted that each cell was meant to get its own
# (meta_index, topic_index) pair -- confirm the intended indices.
with gr.Blocks() as interface:
    with gr.Row():  # Row 1
        plot1, plot2 = [gr.Plot(generate_plot(0, 0)) for _ in range(2)]
    with gr.Row():  # Row 2
        plot3, plot4 = [gr.Plot(generate_plot(0, 0)) for _ in range(2)]
    with gr.Row():  # Row 3
        plot5, plot6 = [gr.Plot(generate_plot(0, 0)) for _ in range(2)]

# Start the app as soon as the module is executed.
interface.launch()