import json
import pandas as pd
import gradio as gr
from gradio_leaderboard import Leaderboard, ColumnFilter, SelectColumns
from css_html_js import custom_css, trigger_plot
from parse import read_json, read_data
from utils import model_hyperlink, filter_RTLRepo, filter_bench, filter_bench_all, handle_special_cases
from typing import Union
from about import CITATION_BUTTON_LABEL, CITATION_BUTTON_TEXT
import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from gradio.themes.utils import colors
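
# Leaderboard filtering helper. `df` is a module-level DataFrame that is populated
# inside the gr.Blocks context below via read_data(); the UI only calls this helper
# after that assignment has run.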
def filter_leaderboard(benchmark, model_type, search_query, max_params):
    subset = df.copy()
    if benchmark != 'All':
        subset = df[df['Benchmark'] == benchmark]
    if model_type != 'All':
        subset = subset[subset['Model Type'] == model_type]
    if search_query:
        subset = subset[subset['Model'].str.contains(search_query, case=False, na=False)]
    max_params = float(max_params)
    subset = subset[subset['Params'] <= max_params]
    if benchmark == 'All':
        return filter_bench_all(subset)
    elif benchmark == 'RTL-Repo':
        return filter_RTLRepo(subset)
    else:
        return filter_bench(subset)
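
# Bubble-plot helper. RTL-Repo only reports Exact Matching (EM), so its scores are
# averaged per model directly; the other benchmarks are pivoted so each metric becomes
# a column and the average of the six reported metrics is taken per model.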
def generate_scatter_plot(benchmark, metric):
    benchmark, metric = handle_special_cases(benchmark, metric)
    subset = df[df['Benchmark'] == benchmark]
    if benchmark == "RTL-Repo":
        subset = subset[subset['Metric'].str.contains('EM', case=False, na=False)]
        detailed_scores = subset.groupby('Model', as_index=False)['Score'].mean()
        detailed_scores.rename(columns={'Score': 'Exact Matching (EM)'}, inplace=True)
        detailed_scores['Average ⬆️'] = detailed_scores['Exact Matching (EM)']
    else:
        detailed_scores = subset.pivot_table(index='Model', columns='Metric', values='Score').reset_index()
        detailed_scores['Average ⬆️'] = detailed_scores[['Syntax (STX)', 'Functionality (FNC)', 'Synthesis (SYN)', 'Power', 'Performance', 'Area']].mean(axis=1)
    details = df[['Model', 'Params', 'Model Type']].drop_duplicates('Model')
    scatter_data = pd.merge(detailed_scores, details, on='Model', how='left').dropna(subset=['Params', metric])
    scatter_data['x'] = scatter_data['Params']
    scatter_data['y'] = scatter_data[metric]
    scatter_data['size'] = (scatter_data['x'] ** 0.3) * 40
    type_colors = {"General": "green", "Coding": "yellow", "RTL-Specific": "blue"}
    scatter_data['color'] = scatter_data['Model Type'].map(type_colors).fillna('gray')
    y_axis_limits = {
        'Functionality (FNC)': [5, 90], 'Syntax (STX)': [20, 100], 'Synthesis (SYN)': [5, 90],
        'Power': [0, 50], 'Performance': [0, 50], 'Area': [0, 50], 'Exact Matching (EM)': [0, 50],
        'Average ⬆️': [0, 80]
    }
    y_range = y_axis_limits.get(metric, [0, 80])
    fig = px.scatter(
        scatter_data, x='x', y='y', log_x=True, size='size', color='Model Type', text='Model',
        hover_data={metric: ':.2f'}, title=f'Params vs. {metric} for {benchmark}',
        labels={'x': '# Params (Log Scale)', 'y': metric}, template="plotly_white",
        # color_discrete_map={"General": "#A8D5BA", "Coding": "#F7DC6F", "RTL-Specific": "#87CEFA"},
        height=600, width=1200
    )
    fig.update_traces(
        textposition='top center', textfont_size=10,
        marker=dict(opacity=0.8, line=dict(width=0.5, color='black'))
    )
    fig.update_layout(
        xaxis=dict(
            showgrid=True, type='log', tickmode='array',
            tickvals=[8, 14, 32, 72, 200, 700],
            ticktext=['8', '14', '32', '72', '200', '700']
        ),
        showlegend=False, yaxis=dict(range=y_range),
        margin=dict(l=50, r=50, t=50, b=50), plot_bgcolor='white'
    )
    return fig
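
# Small JS snippet run on page load: Gradio selects its theme via the `__theme`
# query parameter, so if it is not already `light` the page reloads with it set.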
js_func = """ | |
function refresh() { | |
const url = new URL(window.location); | |
if (url.searchParams.get('__theme') !== 'light') { | |
url.searchParams.set('__theme', 'light'); | |
window.location.href = url.href; | |
} | |
} | |
""" | |
with gr.Blocks(css=custom_css, js=js_func, theme=gr.themes.Default(primary_hue=colors.emerald)) as app:
    df, benchmarks, metrics, default_metric = read_data()
    # gr.Markdown("""# TuRTLe 🐢 Model Leaderboard""")
    gr.HTML("""
        <p align="center" style="margin-bottom: -10px;">
            <img src='/gradio_api/file=logo.png' alt='TuRTLe Logo' width='220'/> <br/>
        </p>
    """)
    gr.HTML("""
        <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/css/all.min.css">
        <script defer src="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.0.0/js/all.min.js"></script>
        <div style="text-align: center; margin-bottom: 15px;">
            <p style="margin-bottom: 15px;">Welcome to the TuRTLe Model Leaderboard! Use the filters below to explore different RTL benchmarks and models.</p>
            <a href="https://github.com/HPAI-BSC" target="_blank" style="text-decoration: none; margin-right: 10px;">
                <button style="background: #333; color: white; padding: 10px 14px; border-radius: 8px; border: none; font-size: 16px; cursor: pointer;">
                    GitHub Repo
                </button>
            </a>
            <a href="https://arxiv.org/" target="_blank" style="text-decoration: none; margin-right: 10px;">
                <button style="background: #b31b1b; color: white; padding: 10px 14px; border-radius: 8px; border: none; font-size: 16px; cursor: pointer;">
                    arXiv Preprint
                </button>
            </a>
            <a href="https://github.com/HPAI-BSC" style="text-decoration: none;">
                <button style="background: #00674F; color: white; padding: 10px 14px; border-radius: 8px; border: none; font-size: 16px; cursor: pointer;">
                    How to submit
                </button>
            </a>
            <p style="margin-top: 15px;">If you have any inquiries or wish to collaborate:
                <a href="mailto:hpai@bsc.es">hpai@bsc.es</a>
            </p>
        </div>
    """)
    with gr.Tabs():
        with gr.Tab("Leaderboard"):
            with gr.Row():
                benchmark_radio = gr.Radio(choices=["All"] + benchmarks, label="Select Benchmark", value='VerilogEval S2R', scale=7)
                model_type_radio = gr.Radio(choices=['All', 'General', 'Coding', 'RTL-Specific'], label="Select Model Type", value='All', scale=4)
            with gr.Row():
                search_box = gr.Textbox(label="Search Model", placeholder="Type model name...")
                params_slider = gr.Slider(
                    minimum=df['Params'].min(),
                    maximum=700,
                    value=700,
                    label="Max Params",
                    step=1
                )
            leaderboard = gr.DataFrame(
                value=filter_leaderboard('VerilogEval S2R', 'All', "", 700),
                headers="first row",
                show_row_numbers=True,
                wrap=False,
                datatype=["markdown", "html"],
                interactive=True,
                column_widths=["5%", "28%", "10%", "15%", "6.5%", "6.5%", "6.5%", "6.5%", "6.5%", "6.5%", "6.5%"],
            )
with gr.Tab("Interactive Bubble Plot"): | |
with gr.Row(): | |
bubble_benchmark = gr.Radio(choices=benchmarks, label="Select Benchmark", value='VerilogEval S2R') | |
bubble_metric = gr.Radio(choices=metrics, label="Select Metric", value=default_metric) | |
gr.Markdown("We show in π’ General Models, in π΅ Coding Models and in π΄ RTL-Specific Models. Detailed information is shown when hovering over each model in the plot.") | |
            scatter_plot = gr.Plot(value=generate_scatter_plot('VerilogEval S2R', default_metric), label="Bubble Chart", elem_id="full-width-plot")
with gr.Tab("About Us"): | |
gr.HTML( | |
""" | |
<div style="max-width: 800px; margin: auto; padding: 20px; border: 1px solid #ccc; border-radius: 10px;"> | |
<h1 style="text-align: center; font-size: 28px; margin-top: -7px;">HPAI-BSC</h1> | |
<p style="font-size: 16px; text-align: start;"> | |
The <b>High-Performance Artificial Intelligence (HPAI)</b> group is part of the | |
<a href="https://bsc.es/" target="_blank">Barcelona Supercomputing Center (BSC)</a>. | |
This leaderboard is maintained by HPAI as part of our commitment to <b>open science</b>. | |
</p> | |
<ul style="font-size: 16px; margin-bottom: 20px; margin-top: 20px;"> | |
<li><a href="https://hpai.bsc.es/" target="_blank">Official Website</a></li> | |
<li><a href="https://github.com/HPAI-BSC/" target="_blank">GitHub Organization Page</a></li> | |
<li><a href="https://huggingface.co/HPAI-BSC/" target="_blank">Hugging Face Organization Page</a></li> | |
<li><a href="https://hpai.bsc.es/publications" target="_blank">Publications</a></li> | |
</ul> | |
<p style="font-size: 16px; margin-top: 15px;"> | |
Feel free to contact us: | |
</p> | |
<p style="font-size: 16px;">Email: <a href="mailto:hpai@bsc.es"><b>hpai@bsc.es</b></a></p> | |
</div> | |
""" | |
) | |
    with gr.Row():
        with gr.Accordion("📙 Citation", open=False):
            citation_button = gr.Textbox(
                value=CITATION_BUTTON_TEXT,
                label=CITATION_BUTTON_LABEL,
                lines=20,
                elem_id="citation-button",
                show_copy_button=True,
            )
    # event handlers, ugly way but it works
    benchmark_radio.change(fn=filter_leaderboard, inputs=[benchmark_radio, model_type_radio, search_box, params_slider], outputs=leaderboard)
    model_type_radio.change(fn=filter_leaderboard, inputs=[benchmark_radio, model_type_radio, search_box, params_slider], outputs=leaderboard)
    search_box.change(fn=filter_leaderboard, inputs=[benchmark_radio, model_type_radio, search_box, params_slider], outputs=leaderboard)
    params_slider.change(fn=filter_leaderboard, inputs=[benchmark_radio, model_type_radio, search_box, params_slider], outputs=leaderboard)
    # RTL-Repo bubble plot handlers
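    # Both callbacks route through handle_special_cases, which (presumably) keeps the
    # benchmark/metric pair consistent (e.g. RTL-Repo only supports Exact Matching);
    # each callback then updates the *other* radio button together with the plot.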
    def on_benchmark_change(benchmark, metric):
        benchmark, metric = handle_special_cases(benchmark, metric)
        fig = generate_scatter_plot(benchmark, metric)
        return gr.update(value=metric), fig

    def on_metric_change(benchmark, metric):
        benchmark, metric = handle_special_cases(benchmark, metric)
        fig = generate_scatter_plot(benchmark, metric)
        return gr.update(value=benchmark), fig

    bubble_benchmark.change(
        fn=on_benchmark_change,
        inputs=[bubble_benchmark, bubble_metric],
        outputs=[bubble_metric, scatter_plot],
        js="""  // this is to avoid resetting user scroll each time a plot is re-generated
        (benchmark, metric) => {
            let scrollY = window.scrollY;
            const observer = new MutationObserver(() => {
                window.scrollTo(0, scrollY);
                observer.disconnect();
            });
            observer.observe(document.getElementById('full-width-plot'), { childList: true });
            return [benchmark, metric];
        }
        """)
    bubble_metric.change(
        fn=on_metric_change,
        inputs=[bubble_benchmark, bubble_metric],
        outputs=[bubble_benchmark, scatter_plot],
        js="""  // this is to avoid resetting user scroll each time a plot is re-generated
        (benchmark, metric) => {
            let scrollY = window.scrollY;
            const observer = new MutationObserver(() => {
                window.scrollTo(0, scrollY);
                observer.disconnect();
            });
            observer.observe(document.getElementById('full-width-plot'), { childList: true });
            return [benchmark, metric];
        }
        """)
app.launch(allowed_paths=["logo.png"])