Spaces:
Sleeping
Sleeping
from datasets import load_dataset | |
import pandas as pd | |
import duckdb | |
import matplotlib.pyplot as plt | |
import seaborn as sns # Import Seaborn | |
import plotly.express as px # Added for Plotly | |
import plotly.graph_objects as go # Added for Plotly error figure | |
import gradio as gr | |
import os | |
from huggingface_hub import login | |
from datetime import datetime, timedelta | |
import sys # Added for error logging | |
# Authenticate with the Hugging Face Hub using a token read from the
# environment; refuse to start at all when no (or an empty) token is set.
HF_TOKEN = os.environ.get('HF_TOKEN')
if HF_TOKEN:
    login(token=HF_TOKEN)
else:
    raise ValueError("Please set the HF_TOKEN environment variable")
# Global seaborn/matplotlib look ("whitegrid" style, "notebook" sizing).
# Note: the retention chart in this app is built with Plotly, so this theme
# does not style that chart.
sns.set_theme(style="whitegrid", context="notebook")
# Load the trending-models snapshot once at startup and expose the resulting
# pandas DataFrame to DuckDB under the name ``models`` so the retention SQL
# can select ``FROM models``.
try:
    dataset = load_dataset("reach-vb/trending-repos", split="models")
    df = dataset.to_pandas()
    duckdb.register('models', df)
except Exception as e:
    # Log to stderr like the rest of the app, then re-raise: nothing below
    # can work without the dataset.
    print(f"Error loading dataset: {e}", file=sys.stderr)
    raise
def get_retention_data(start_date: str, end_date: str) -> pd.DataFrame:
    """Compute day-over-day model retention from the DuckDB table ``models``.

    For every collection day in ``[start_date, end_date]`` (inclusive) this
    counts the distinct models collected that day and how many of them were
    also collected on the immediately preceding calendar day.

    Args:
        start_date: Inclusive lower bound, a ``YYYY-MM-DD`` string.
        end_date: Inclusive upper bound, a ``YYYY-MM-DD`` string.

    Returns:
        A DataFrame ordered by ``collection_day`` with columns
        ``collection_day``, ``total_models_today``, ``carried_over_models``
        and ``percent_retained`` (rounded to 2 decimals).
        ``percent_retained`` is NULL (NaN in pandas) for a day whose previous
        calendar day has no snapshot, e.g. the first day of the dataset,
        because there is nothing to compare against. On any failure a
        single-column DataFrame ``{"Error": [message]}`` is returned instead
        of raising; ``plot_retention_data`` recognises that shape.
    """
    try:
        # Tolerate stray whitespace typed into the Gradio textboxes.
        start = start_date.strip() if isinstance(start_date, str) else start_date
        end = end_date.strip() if isinstance(end_date, str) else end_date

        # DISTINCT guards against several snapshots of the same model on the
        # same day, which would otherwise inflate the daily totals and
        # multiply the rows produced by the self-join (retention > 100%).
        query = """
            WITH model_presence AS (
                SELECT DISTINCT
                    id AS model_id,
                    collected_at::DATE AS collection_day
                FROM models
            ),
            daily_model_counts AS (
                SELECT
                    collection_day,
                    COUNT(*) AS total_models_today
                FROM model_presence
                GROUP BY collection_day
            ),
            retained_models AS (
                SELECT
                    a.collection_day,
                    COUNT(*) AS previously_existed_count
                FROM model_presence a
                JOIN model_presence b
                    ON a.model_id = b.model_id
                    AND a.collection_day = b.collection_day + INTERVAL '1 day'
                GROUP BY a.collection_day
            )
            SELECT
                d.collection_day,
                d.total_models_today,
                COALESCE(r.previously_existed_count, 0) AS carried_over_models,
                CASE
                    -- No snapshot on the previous day: retention is undefined,
                    -- not 0%.
                    WHEN prev.collection_day IS NULL THEN NULL
                    ELSE ROUND(
                        COALESCE(r.previously_existed_count, 0) * 100.0
                            / NULLIF(d.total_models_today, 0),
                        2
                    )
                END AS percent_retained
            FROM daily_model_counts d
            LEFT JOIN retained_models r
                ON d.collection_day = r.collection_day
            LEFT JOIN daily_model_counts prev
                ON d.collection_day = prev.collection_day + INTERVAL '1 day'
            WHERE d.collection_day BETWEEN ? AND ?
            ORDER BY d.collection_day
        """
        # Dates are bound as parameters, never interpolated into the SQL text.
        result = duckdb.query(query, params=[start, end]).to_df()
        print("SQL Query Result:")
        print(result)
        return result
    except Exception as e:
        print(f"Error in get_retention_data: {e}", file=sys.stderr)
        # Signal the failure to the plotting step as an "Error" frame.
        return pd.DataFrame({"Error": [str(e)]})
def _message_figure(message: str) -> go.Figure:
    """Return an empty Plotly figure showing ``message`` centred on the canvas."""
    fig = go.Figure()
    fig.add_annotation(
        text=message,
        xref="paper", yref="paper",
        x=0.5, y=0.5, showarrow=False,
        font=dict(size=16),
    )
    return fig


def plot_retention_data(dataframe: pd.DataFrame):
    """Render per-day retention percentages as a Plotly bar chart.

    Args:
        dataframe: Output of ``get_retention_data``; needs the columns
            ``collection_day`` and ``percent_retained``. The caller's
            DataFrame is not modified.

    Returns:
        A Plotly figure: the bar chart with a dashed average-retention line,
        or a figure containing only a centred message when the input is an
        ``"Error"`` frame, when nothing is left to plot, or when plotting
        fails.
    """
    print("DataFrame received by plot_retention_data (first 5 rows):")
    print(dataframe.head())
    print("\nData types in plot_retention_data before any conversion:")
    print(dataframe.dtypes)

    # get_retention_data reports failures as a one-column "Error" frame.
    if "Error" in dataframe.columns and not dataframe.empty:
        error_message = dataframe['Error'].iloc[0]
        print(f"Error DataFrame received: {error_message}", file=sys.stderr)
        return _message_figure(f"Error from data generation: {error_message}")

    try:
        for column in ('percent_retained', 'collection_day'):
            if column not in dataframe.columns:
                raise ValueError(f"'{column}' column is missing from the DataFrame.")

        # Work on a copy so the caller's DataFrame is left untouched.
        data = dataframe.copy()
        # Coerce unparseable values to NaN/NaT so they are dropped below
        # instead of aborting the whole plot. Days with undefined retention
        # (NULL from SQL) are dropped here as well.
        data['percent_retained'] = pd.to_numeric(data['percent_retained'], errors='coerce')
        data['collection_day'] = pd.to_datetime(data['collection_day'], errors='coerce')
        data = data.dropna(subset=['percent_retained', 'collection_day'])

        print("\n'percent_retained' column after pd.to_numeric (first 5 values):")
        print(data['percent_retained'].head())
        print("'percent_retained' dtype after pd.to_numeric:", data['percent_retained'].dtype)
        print("\n'collection_day' column after pd.to_datetime (first 5 values):")
        print(data['collection_day'].head())
        print("'collection_day' dtype after pd.to_datetime:", data['collection_day'].dtype)

        if data.empty:
            return _message_figure("No data available to plot after processing.")

        fig = px.bar(
            data,
            x='collection_day',
            y='percent_retained',
            title='Previous Day Top 200 Trending Model Retention %',
            labels={'collection_day': 'Date', 'percent_retained': 'Retention Rate (%)'},
            text='percent_retained',
        )
        fig.update_traces(
            texttemplate='%{text:.2f}%',
            textposition='inside',
            insidetextanchor='middle',
            textfont_color='white',
            textfont_size=10,
            hovertemplate='<b>Date</b>: %{x|%Y-%m-%d}<br>'
                          '<b>Retention</b>: %{y:.2f}%<extra></extra>',
        )

        # data is non-empty here, so the mean is well defined.
        average_retention = data['percent_retained'].mean()
        fig.add_hline(
            y=average_retention,
            line_dash="dash",
            line_color="red",
            annotation_text=f"Average: {average_retention:.2f}%",
            annotation_position="bottom right",
        )

        fig.update_xaxes(tickangle=45)
        fig.update_layout(
            title_x=0.5,
            xaxis_title="Date",
            yaxis_title="Retention Rate (%)",
            plot_bgcolor='white',
            bargap=0.2,
        )
        return fig
    except Exception as e:
        print(f"Error during plot_retention_data: {e}", file=sys.stderr)
        return _message_figure(f"Plotting Error: {str(e)}")
def interface_fn(start_date, end_date):
    """Gradio callback: fetch retention rows for the date range and chart them."""
    return plot_retention_data(get_retention_data(start_date, end_date))
# Default the date inputs to the earliest/latest snapshot in the dataset.
# pd.to_datetime accepts both ISO-8601 strings and native datetime columns,
# unlike datetime.fromisoformat, which needs a str and (before Python 3.11)
# rejects common variants such as a trailing "Z".
collected_at = pd.to_datetime(df['collected_at'])
min_date = collected_at.min().date()
max_date = collected_at.max().date()

iface = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Textbox(label="Start Date (YYYY-MM-DD)", value=min_date.strftime("%Y-%m-%d")),
        gr.Textbox(label="End Date (YYYY-MM-DD)", value=max_date.strftime("%Y-%m-%d")),
    ],
    outputs=gr.Plot(label="Model Retention Visualization"),
    title="Model Retention Analysis",
    description="Visualize model retention rates over time. Enter dates in YYYY-MM-DD format.",
)
iface.launch()