zaki / app.py
Abdullah Zaki
Add plotly t
fd00e59
import gradio as gr
import pandas as pd
import numpy as np
import torch
from chronos import ChronosPipeline
import plotly.express as px
# Load the Chronos-T5-Large forecasting pipeline a single time, at import, so
# that every Gradio request reuses the same model instead of reloading it.
# The model is placed on CUDA when torch reports an available GPU and on the
# CPU otherwise; weights are requested in bfloat16 in both cases.
_device = "cuda" if torch.cuda.is_available() else "cpu"
chronos_pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map=_device,
    torch_dtype=torch.bfloat16,
)
def run_chronos_forecast(
    csv_file: object,
    prediction_length: int = 30,
) -> tuple[pd.DataFrame, object, str]:
    """Forecast the 'sentiment' series of an uploaded CSV with Chronos.

    The CSV is read, cleaned (non-numeric sentiment values and rows without a
    date are dropped), sorted by date and fed to the module-level
    ``chronos_pipeline``. The 10%, 50% and 90% quantiles of the sampled
    forecasts are returned as a table and as a Plotly line chart, with one
    forecast row per day following the last historical date.

    Args:
        csv_file: The uploaded file as delivered by the Gradio File component.
            Either a path string or an object exposing the path as ``.name``
            (which of the two arrives depends on the Gradio version and the
            component's ``type`` setting). ``None`` means nothing was uploaded.
            The CSV must contain 'date' and 'sentiment' columns.
        prediction_length: Number of future daily periods to forecast. Float
            values (as a slider may deliver) are truncated to ``int``; values
            below 1 are rejected.

    Returns:
        A ``(forecast_df, figure, status)`` tuple:
            - forecast_df: DataFrame with columns date, low, median, high;
              empty on failure.
            - figure: the Plotly figure visualizing the forecast, or ``None``
              on failure.
            - status: a success message or a description of what went wrong.

    This function never raises: it is a UI callback, so every failure is
    reported through the status string instead.
    """
    if csv_file is None:
        return pd.DataFrame(), None, "Error: Please upload a CSV file."

    try:
        # The slider value may arrive as a float; pd.date_range needs an int.
        horizon = int(prediction_length)
        if horizon < 1:
            return pd.DataFrame(), None, "Error: Prediction length must be at least 1."

        # Accept both a plain path string and a temp-file-like wrapper whose
        # `.name` holds the path.
        csv_path = csv_file if isinstance(csv_file, str) else csv_file.name
        df = pd.read_csv(csv_path)

        # Validate required columns
        if "date" not in df.columns or "sentiment" not in df.columns:
            return pd.DataFrame(), None, "Error: CSV must contain 'date' and 'sentiment' columns."

        df["date"] = pd.to_datetime(df["date"])
        # Unparseable sentiment values become NaN so they can be dropped below.
        df["sentiment"] = pd.to_numeric(df["sentiment"], errors="coerce")

        # Rows lacking a date must go too: a NaT would sort last and make the
        # forecast start date undefined.
        df = df.dropna(subset=["date", "sentiment"])
        if df.empty:
            return pd.DataFrame(), None, "Error: No valid sentiment data found in the CSV."

        # Chronological order is what makes the values a time series.
        df = df.sort_values(by="date").reset_index(drop=True)

        # The historical values are passed to Chronos as a 1D float tensor.
        context = torch.tensor(df["sentiment"].to_numpy(), dtype=torch.float32)
        forecast_tensor = chronos_pipeline.predict(context, horizon)

        # forecast_tensor[0] selects the forecasts for the single input series.
        # Defensive normalisation: NumPy cannot read GPU memory or represent
        # bfloat16, so convert to a float32 CPU tensor first (a no-op if the
        # tensor already is one).
        samples = forecast_tensor[0].detach().cpu().float().numpy()
        low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)

        # Forecast dates start the day after the last historical observation.
        last_historical_date = df["date"].iloc[-1]
        forecast_dates = pd.date_range(
            start=last_historical_date + pd.Timedelta(days=1),
            periods=horizon,
            freq="D",
        )

        forecast_df = pd.DataFrame(
            {
                "date": forecast_dates,
                "low": low,
                "median": median,
                "high": high,
            }
        )

        # Median as a solid line, the 10%/90% band edges as dashed lines.
        fig = px.line(
            forecast_df,
            x="date",
            y=["median", "low", "high"],
            title="Sentiment Forecast",
        )
        fig.update_traces(line=dict(color="blue", width=3), selector=dict(name="median"))
        fig.update_traces(line=dict(color="red", dash="dash"), selector=dict(name="low"))
        fig.update_traces(line=dict(color="green", dash="dash"), selector=dict(name="high"))
        # Unified hover shows all three values for a date; title is centered.
        fig.update_layout(hovermode="x unified", title_x=0.5)

        return forecast_df, fig, "Forecast generated successfully!"
    except Exception as e:
        # UI boundary: surface any failure to the user instead of crashing
        # the Gradio callback.
        return pd.DataFrame(), None, f"An error occurred: {str(e)}"
# Gradio UI: file upload + horizon slider feed run_chronos_forecast, whose
# three return values populate the data table, the plot and the status box.
# NOTE(review): the source lost its indentation, so the nesting below (which
# components sit inside the Row / each Tab) is reconstructed — confirm layout.
with gr.Blocks() as demo:
    gr.Markdown(value="# Chronos Time Series Forecasting")
    gr.Markdown(
        value="Upload a CSV file containing historical data with 'date' and 'sentiment' columns to get a sentiment forecast."
    )

    # Inputs share one row.
    with gr.Row():
        csv_input = gr.File(label="Upload Historical Data (CSV)")
        prediction_length_slider = gr.Slider(
            minimum=1,
            maximum=60,
            value=30,
            step=1,
            label="Prediction Length (days)",
        )

    run_button = gr.Button(value="Generate Forecast")

    # Outputs: one tab for the chart, one for the underlying numbers.
    with gr.Tab("Forecast Plot"):
        forecast_plot_output = gr.Plot(label="Sentiment Forecast Plot")
    with gr.Tab("Forecast Data"):
        # Rendered as a table; the variable keeps its historical "json" name.
        forecast_json_output = gr.DataFrame(label="Raw Forecast Data")

    status_message_output = gr.Textbox(label="Status", interactive=False)

    # Output order must match the (table, figure, status) tuple returned by
    # run_chronos_forecast.
    run_button.click(
        fn=run_chronos_forecast,
        inputs=[csv_input, prediction_length_slider],
        outputs=[forecast_json_output, forecast_plot_output, status_message_output],
    )

# Start the Gradio server.
demo.launch()