|
import gradio as gr |
|
import pandas as pd |
|
import numpy as np |
|
import torch |
|
from chronos import ChronosPipeline |
|
import plotly.express as px |
|
|
|
|
|
|
|
|
|
|
|
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
_device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Chronos-T5-Large checkpoint once at import time so that every
# forecast request reuses the same pipeline instead of reloading weights.
chronos_pipeline = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-large",
    device_map=_device,
    torch_dtype=torch.bfloat16,
)
|
|
|
def _prepare_series(df: pd.DataFrame) -> pd.DataFrame:
    """Reduce raw CSV rows to one numeric sentiment value per date, in date order.

    Dates are parsed strictly (unparseable strings raise, as before). Missing
    dates and non-numeric sentiment values are dropped. Rows that share a date
    are averaged, so each row of the result is exactly one time step for the
    model and lines up with one generated forecast date.
    """
    cleaned = pd.DataFrame({
        "date": pd.to_datetime(df["date"]),
        "sentiment": pd.to_numeric(df["sentiment"], errors="coerce"),
    }).dropna(subset=["date", "sentiment"])
    if cleaned.empty:
        return cleaned
    # groupby sorts by the key, so the result is already chronological.
    return cleaned.groupby("date", as_index=False)["sentiment"].mean()


def _forecast_dates(history: pd.Series, periods: int) -> pd.DatetimeIndex:
    """Return `periods` timestamps that follow the last date in `history`.

    The step is the frequency pandas infers from the (sorted, unique) history,
    e.g. daily, weekly or month-end. Short or irregular histories fall back to
    one day per step.
    """
    freq = None
    if len(history) >= 3:  # pd.infer_freq needs at least three dates
        try:
            freq = pd.infer_freq(history)
        except (TypeError, ValueError):
            freq = None
    freq = freq or "D"
    # Anchor on the last observed date (on-offset for the inferred frequency)
    # and drop it, leaving exactly `periods` future timestamps.
    return pd.date_range(start=history.iloc[-1], periods=periods + 1, freq=freq)[1:]


def _make_forecast_figure(forecast_df: pd.DataFrame):
    """Plot the median forecast with its low/high quantile band as dashed lines."""
    fig = px.line(forecast_df, x="date", y=["median", "low", "high"], title="Sentiment Forecast")
    fig.update_traces(line=dict(color="blue", width=3), selector=dict(name="median"))
    fig.update_traces(line=dict(color="red", dash="dash"), selector=dict(name="low"))
    fig.update_traces(line=dict(color="green", dash="dash"), selector=dict(name="high"))
    fig.update_layout(hovermode="x unified", title_x=0.5)
    return fig


def run_chronos_forecast(
    csv_file: "str | None",
    prediction_length: int = 30,
) -> tuple[pd.DataFrame, "plotly.graph_objects.Figure | None", str]:
    """
    Runs time series forecasting using the Chronos-T5-Large model.

    Args:
        csv_file: The uploaded CSV as delivered by ``gr.File``. This is either a
            path string or a file wrapper exposing the path as ``.name``. The
            CSV must have 'date' and 'sentiment' columns.
        prediction_length (int): Number of future time steps to forecast
            (coerced to int; must be >= 1). The step is the frequency inferred
            from the dates: one day for daily data, and also one day as the
            fallback for irregular data.

    Returns:
        tuple: A tuple containing:
            - pd.DataFrame: The forecast results (date, low, median, high),
              or an empty DataFrame on error.
            - plotly.graph_objects.Figure | None: A Plotly figure of the
              forecast, or None on error.
            - str: A status message (success text or an error description).
    """
    if csv_file is None:
        return pd.DataFrame(), None, "Error: Please upload a CSV file."

    try:
        # Gradio sliders may deliver numbers as floats (e.g. 30.0); Chronos and
        # pd.date_range both need an integer step count.
        prediction_length = int(prediction_length)
        if prediction_length < 1:
            return pd.DataFrame(), None, "Error: Prediction length must be at least 1."

        # Accept both a plain path string and a file object carrying `.name`.
        csv_path = getattr(csv_file, "name", csv_file)
        df = pd.read_csv(csv_path)

        if "date" not in df.columns or "sentiment" not in df.columns:
            return pd.DataFrame(), None, "Error: CSV must contain 'date' and 'sentiment' columns."

        history = _prepare_series(df)
        if history.empty:
            return pd.DataFrame(), None, "Error: No valid sentiment data found in the CSV."

        context = torch.tensor(history["sentiment"].to_numpy(), dtype=torch.float32)
        forecast_tensor = chronos_pipeline.predict(context, prediction_length)

        # Assumes predict returns (series, samples, steps), per the Chronos docs.
        # Index 0 selects our single series. Cast to float32 on CPU so .numpy()
        # also works if the pipeline hands back bf16 or CUDA tensors.
        samples = forecast_tensor[0].float().cpu().numpy()
        # Quantiles across the sample axis give the median and a 10%-90% band.
        low, median, high = np.quantile(samples, [0.1, 0.5, 0.9], axis=0)

        forecast_df = pd.DataFrame({
            "date": _forecast_dates(history["date"], prediction_length),
            "low": low,
            "median": median,
            "high": high,
        })

        return forecast_df, _make_forecast_figure(forecast_df), "Forecast generated successfully!"

    except Exception as e:
        # UI boundary: report any failure in the status box instead of
        # crashing the Gradio request.
        return pd.DataFrame(), None, f"An error occurred: {str(e)}"
|
|
|
|
|
# Build the Gradio UI: a CSV upload plus a horizon slider feed
# run_chronos_forecast, whose (table, figure, status) result fills the tabs
# and the status box below.
with gr.Blocks() as demo:
    gr.Markdown("# Chronos Time Series Forecasting")
    gr.Markdown("Upload a CSV file containing historical data with 'date' and 'sentiment' columns to get a sentiment forecast.")

    with gr.Row():
        # Restrict the picker to .csv; the handler parses the upload with pd.read_csv.
        csv_input = gr.File(label="Upload Historical Data (CSV)", file_types=[".csv"])
        prediction_length_slider = gr.Slider(
            1, 60, value=30, step=1, label="Prediction Length (days)"
        )

    run_button = gr.Button("Generate Forecast")

    with gr.Tab("Forecast Plot"):
        forecast_plot_output = gr.Plot(label="Sentiment Forecast Plot")
    with gr.Tab("Forecast Data"):
        # Despite the "json" in its name, this component renders a table.
        forecast_json_output = gr.DataFrame(label="Raw Forecast Data")

    status_message_output = gr.Textbox(label="Status", interactive=False)

    # Output order must match the (DataFrame, figure, status) tuple returned
    # by run_chronos_forecast.
    run_button.click(
        fn=run_chronos_forecast,
        inputs=[csv_input, prediction_length_slider],
        outputs=[forecast_json_output, forecast_plot_output, status_message_output],
    )


# Launch only when run as a script. Importing the module (tests, `gradio`
# reload mode) then builds `demo` without starting a server.
if __name__ == "__main__":
    demo.launch()
|
|