File size: 8,354 Bytes
a834908
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213e464
a834908
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
from datasets import load_dataset
import pandas as pd
import duckdb
import matplotlib.pyplot as plt
import seaborn as sns # Import Seaborn
import plotly.express as px # Added for Plotly
import plotly.graph_objects as go # Added for Plotly error figure
import gradio as gr
import os
from huggingface_hub import login
from datetime import datetime, timedelta
import sys # Added for error logging

# Read the Hugging Face access token from the HF_TOKEN environment variable.
# Startup is aborted when the variable is unset or empty.
HF_TOKEN = os.getenv('HF_TOKEN')
if not HF_TOKEN:
    raise ValueError("Please set the HF_TOKEN environment variable")

# Authenticate this process with the Hugging Face Hub using that token.
login(token=HF_TOKEN)

# Apply the Seaborn theme and context globally.
# NOTE(review): the chart in this file is built with Plotly, so these settings
# may have no visible effect - confirm whether they are still needed.
sns.set_theme(style="whitegrid")
sns.set_context("notebook")

# Load the "models" split of the trending-repos dataset once at startup and
# keep it in memory as a pandas DataFrame (`df`).
try:
    dataset = load_dataset("reach-vb/trending-repos", split="models")
    df = dataset.to_pandas()
    # Register the pandas DataFrame as a DuckDB table named 'models'
    # so that SQL queries can refer to it with 'FROM models'.
    duckdb.register('models', df) 
except Exception as e:
    # Report the failure and re-raise: the app cannot run without the data.
    print(f"Error loading dataset: {e}")
    raise

def _normalize_iso_date(value: str, label: str) -> str:
    """Return *value* as a canonical ``YYYY-MM-DD`` string.

    Surrounding whitespace is ignored. *label* names the field in the error
    message (e.g. "start date").

    Raises:
        ValueError: If *value* is not a valid calendar date in YYYY-MM-DD form.
    """
    try:
        parsed = datetime.strptime(value.strip(), "%Y-%m-%d")
    except (AttributeError, TypeError, ValueError) as err:
        raise ValueError(
            f"Invalid {label} {value!r}: expected a date in YYYY-MM-DD format"
        ) from err
    return parsed.date().isoformat()


def get_retention_data(start_date: str, end_date: str) -> pd.DataFrame:
    """Compute day-over-day model retention for an inclusive date range.

    For every collection day between *start_date* and *end_date*, the query
    counts the distinct models present that day and how many of them were also
    present on the previous calendar day. It reads the ``models`` table
    registered with DuckDB.

    Args:
        start_date: First day of the range, as a ``YYYY-MM-DD`` string.
        end_date: Last day of the range (inclusive), as a ``YYYY-MM-DD`` string.

    Returns:
        A DataFrame with the columns ``collection_day``, ``total_models_today``,
        ``carried_over_models`` and ``percent_retained``, ordered by day.
        On any failure (invalid date, query error) nothing is raised; instead a
        DataFrame with a single ``Error`` column holding the message is
        returned, and the error is logged to standard error.
    """
    try:
        # Validate up front so a typo in the UI yields a readable message
        # rather than an opaque database cast error.
        range_start = _normalize_iso_date(start_date, "start date")
        range_end = _normalize_iso_date(end_date, "end date")

        # DISTINCT in model_presence guarantees one row per (model, day).
        # Without it, a model collected several times on one day would be
        # counted repeatedly and the self-join below would multiply matches,
        # inflating both the totals and the retained counts.
        query = """
        WITH model_presence AS (
          SELECT DISTINCT
            id AS model_id,
            collected_at::DATE AS collection_day
          FROM models
        ),
        daily_model_counts AS (
          SELECT 
            collection_day,
            COUNT(*) AS total_models_today
          FROM model_presence
          GROUP BY collection_day
        ),
        retained_models AS (
          SELECT 
            a.collection_day,
            COUNT(*) AS previously_existed_count
          FROM model_presence a
          JOIN model_presence b 
            ON a.model_id = b.model_id 
            AND a.collection_day = b.collection_day + INTERVAL '1 day'
          GROUP BY a.collection_day
        )
        SELECT 
          d.collection_day,
          d.total_models_today,
          COALESCE(r.previously_existed_count, 0) AS carried_over_models,
          CASE 
            WHEN d.total_models_today = 0 THEN NULL
            ELSE ROUND(COALESCE(r.previously_existed_count, 0) * 100.0 / d.total_models_today, 2)
          END AS percent_retained
        FROM daily_model_counts d
        LEFT JOIN retained_models r ON d.collection_day = r.collection_day
        WHERE d.collection_day BETWEEN ? AND ?
        ORDER BY d.collection_day
        """
        # The dates are bound as parameters, never interpolated into the SQL.
        result = duckdb.query(query, params=[range_start, range_end]).to_df()
        print("SQL Query Result:")  # Log the result
        print(result)  # Log the result
        return result
    except Exception as e:
        # Log the error to standard error
        print(f"Error in get_retention_data: {e}", file=sys.stderr)
        # Signal the failure to the caller through a one-column DataFrame.
        return pd.DataFrame({"Error": [str(e)]})

def _message_figure(message: str) -> "go.Figure":
    """Return an otherwise empty Plotly figure with *message* centred in it."""
    fig = go.Figure()
    fig.add_annotation(
        text=message,
        xref="paper", yref="paper",
        x=0.5, y=0.5, showarrow=False,
        font=dict(size=16)
    )
    return fig


def plot_retention_data(dataframe: pd.DataFrame) -> "go.Figure":
    """Build a Plotly bar chart of daily retention percentages.

    Args:
        dataframe: Either a frame with ``collection_day`` and
            ``percent_retained`` columns, or a frame with an ``Error`` column
            whose first value is an error message to display. The frame is
            not modified; all conversions are done on a copy.

    Returns:
        A Plotly figure. It is a bar chart with a dashed line at the mean
        retention when data is available; otherwise a figure containing only
        a text annotation (upstream error, no data left after cleaning, or a
        plotting error). Exceptions raised while preparing the data or
        building the chart are logged to standard error and turned into the
        "Plotting Error" message figure.
    """
    print("DataFrame received by plot_retention_data (first 5 rows):")
    print(dataframe.head())
    print("\nData types in plot_retention_data before any conversion:")
    print(dataframe.dtypes)

    # A non-empty frame with an 'Error' column is an error signal, not data.
    if "Error" in dataframe.columns and not dataframe.empty:
        error_message = dataframe['Error'].iloc[0]
        print(f"Error DataFrame received: {error_message}", file=sys.stderr)
        return _message_figure(f"Error from data generation: {error_message}")

    try:
        for required in ('percent_retained', 'collection_day'):
            if required not in dataframe.columns:
                raise ValueError(f"'{required}' column is missing from the DataFrame.")

        # Work on a copy so the caller's DataFrame is left untouched.
        plot_df = dataframe.copy()

        # Plotly needs a numeric y column and a datetime x column.
        # Unparseable percentages become NaN and are dropped below.
        plot_df['percent_retained'] = pd.to_numeric(plot_df['percent_retained'], errors='coerce')
        plot_df['collection_day'] = pd.to_datetime(plot_df['collection_day'])

        # Drop rows with a missing percentage (NaN) or a missing day (NaT).
        plot_df = plot_df.dropna(subset=['percent_retained', 'collection_day'])

        print("\n'percent_retained' column after pd.to_numeric (first 5 values):")
        print(plot_df['percent_retained'].head())
        print("'percent_retained' dtype after pd.to_numeric:", plot_df['percent_retained'].dtype)
        print("\n'collection_day' column after pd.to_datetime (first 5 values):")
        print(plot_df['collection_day'].head())
        print("'collection_day' dtype after pd.to_datetime:", plot_df['collection_day'].dtype)

        if plot_df.empty:
            return _message_figure("No data available to plot after processing.")

        # Create Plotly bar chart
        fig = px.bar(
            plot_df,
            x='collection_day',
            y='percent_retained',
            title='Previous Day Top 200 Trending Model Retention %',
            labels={'collection_day': 'Date', 'percent_retained': 'Retention Rate (%)'},
            text='percent_retained'  # Use the column directly for hover/text
        )

        # Format the text on bars
        fig.update_traces(
            texttemplate='%{text:.2f}%',
            textposition='inside',
            insidetextanchor='middle',  # Anchor text to the middle of the bar
            textfont_color='white',
            textfont_size=10,  # Adjusted size for better fit
            hovertemplate='<b>Date</b>: %{x|%Y-%m-%d}<br>' +
                          '<b>Retention</b>: %{y:.2f}%<extra></extra>'  # Custom hover
        )

        # Mark the average retention; the frame is known to be non-empty here.
        average_retention = plot_df['percent_retained'].mean()
        fig.add_hline(
            y=average_retention,
            line_dash="dash",
            line_color="red",
            annotation_text=f"Average: {average_retention:.2f}%",
            annotation_position="bottom right"
        )

        fig.update_xaxes(tickangle=45)
        fig.update_layout(
            title_x=0.5,  # Center title
            xaxis_title="Date",
            yaxis_title="Retention Rate (%)",
            plot_bgcolor='white',  # Set plot background to white like seaborn whitegrid
            bargap=0.2  # Gap between bars of different categories
        )
        return fig
    except Exception as e:
        print(f"Error during plot_retention_data: {e}", file=sys.stderr)
        return _message_figure(f"Plotting Error: {str(e)}")

def interface_fn(start_date, end_date):
    """Gradio callback: query retention for the given range and chart it.

    Both arguments are forwarded unchanged to ``get_retention_data``; its
    result is handed to ``plot_retention_data`` and the resulting figure is
    returned.
    """
    return plot_retention_data(get_retention_data(start_date, end_date))

# Derive the default date range from the smallest and largest 'collected_at'
# values in the dataset.
# pd.to_datetime is used rather than datetime.fromisoformat because it accepts
# ISO-8601 strings (including a trailing "Z", which fromisoformat rejects
# before Python 3.11) as well as values that are already datetime/Timestamp
# objects, for which fromisoformat raises TypeError.
min_date = pd.to_datetime(df['collected_at'].min()).date()
max_date = pd.to_datetime(df['collected_at'].max()).date()

# Two free-text date inputs, pre-filled with the full available range, feeding
# a single Plotly output.
iface = gr.Interface(
    fn=interface_fn,
    inputs=[
        gr.Textbox(label="Start Date (YYYY-MM-DD)", value=min_date.strftime("%Y-%m-%d")),
        gr.Textbox(label="End Date (YYYY-MM-DD)", value=max_date.strftime("%Y-%m-%d"))
    ],
    outputs=gr.Plot(label="Model Retention Visualization"),
    title="Model Retention Analysis",
    description="Visualize model retention rates over time. Enter dates in YYYY-MM-DD format."
)

iface.launch()