# api/index.py
import os
import logging
import time
from datetime import datetime, timedelta
from datetime import date as datetime_date
from typing import List, Dict, Any, Optional, AsyncGenerator
import asyncio
from contextlib import asynccontextmanager
import yaml
import importlib.metadata
import pytz
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String, Integer, DateTime, select, delete, Float, Index
from sqlalchemy.types import Date as SQLAlchemyDate
from dotenv import load_dotenv, find_dotenv
from sqlalchemy.pool import NullPool
import requests
import pandas as pd
from io import StringIO
import ssl
import certifi
import aiohttp
import platform
import yfinance as yf
# --- Favicon/Static imports ---
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
# --- MODELS ---
class Base(DeclarativeBase):
    pass


class Ticker(Base):
    __tablename__ = "tickers"

    ticker: Mapped[str] = mapped_column(String(10), primary_key=True)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    sector: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    subindustry: Mapped[Optional[str]] = mapped_column(String(128), nullable=True)
    is_sp500: Mapped[int] = mapped_column(Integer, default=0)
    is_nasdaq100: Mapped[int] = mapped_column(Integer, default=0)
    last_updated: Mapped[datetime] = mapped_column(DateTime)


class TickerData(Base):
    __tablename__ = "ticker_data"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    ticker: Mapped[str] = mapped_column(String(10), nullable=False)
    date: Mapped[datetime_date] = mapped_column(SQLAlchemyDate, nullable=False)
    open: Mapped[float] = mapped_column(Float, nullable=False)
    high: Mapped[float] = mapped_column(Float, nullable=False)
    low: Mapped[float] = mapped_column(Float, nullable=False)
    close: Mapped[float] = mapped_column(Float, nullable=False)
    volume: Mapped[int] = mapped_column(Integer, nullable=False)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)

    __table_args__ = (
        Index('idx_ticker_date', 'ticker', 'date', unique=True),
        Index('idx_ticker', 'ticker'),
        Index('idx_date', 'date'),
    )
# --- PYDANTIC MODELS ---
class TickerResponse(BaseModel):
    ticker: str
    name: str
    sector: Optional[str]
    subindustry: Optional[str]
    is_sp500: bool
    is_nasdaq100: bool
    last_updated: datetime


class UpdateTickersRequest(BaseModel):
    force_refresh: bool = Field(default=False, description="Force refresh even if data is recent")


class UpdateTickersResponse(BaseModel):
    success: bool
    message: str
    total_tickers: int
    sp500_count: int
    nasdaq100_count: int
    updated_at: datetime


class TaskStatus(BaseModel):
    task_id: str
    status: str  # pending, running, completed, failed
    message: Optional[str] = None
    result: Optional[Dict[str, Any]] = None
    created_at: datetime


class TickerDataResponse(BaseModel):
    ticker: str
    date: datetime_date
    open: float
    high: float
    low: float
    close: float
    volume: int
    created_at: datetime


class DownloadDataRequest(BaseModel):
    tickers: Optional[List[str]] = Field(default=None, description="Specific tickers to download. If not provided, downloads all available tickers")
    force_refresh: bool = Field(default=False, description="Force refresh even if data exists")


class DownloadDataResponse(BaseModel):
    success: bool
    message: str
    tickers_processed: int
    records_created: int
    records_updated: int
    date_range: Dict[str, str]  # start_date, end_date
    updated_at: datetime
# --- AUTHENTICATION ---
security = HTTPBearer()


async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)):
    """
    Verify API key from Authorization header.
    Expected format: Authorization: Bearer <api_key>
    """
    api_key = os.getenv("API_KEY")
    if not api_key:
        raise HTTPException(
            status_code=500,
            detail="API key not configured on server"
        )
    if credentials.credentials != api_key:
        raise HTTPException(
            status_code=401,
            detail="Invalid API key"
        )
    return credentials.credentials
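# Example call against a protected endpoint (a sketch: host and port are assumptions,
# the header format comes from verify_api_key above):
#   curl -X POST "http://localhost:8000/data/download-all" \
#        -H "Authorization: Bearer $API_KEY"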
# --- CONFIGURATION ---
class Config:
    def __init__(self):
        load_dotenv(find_dotenv())
        self.config = self._load_yaml_config()
        self._setup_logging()

    def _load_yaml_config(self, config_path='config.yaml'):
        try:
            with open(config_path, 'r') as f:
                return yaml.safe_load(f)
        except FileNotFoundError:
            logging.warning(f"Config file '{config_path}' not found. Using defaults.")
            return self._get_default_config()

    def _get_default_config(self):
        return {
            'logging': {'level': 'INFO', 'log_file': 'data_cache/api.log'},
            'data_sources': {
                'sp500': {
                    'url': 'https://en.wikipedia.org/wiki/List_of_S%26P_500_companies',
                    'ticker_column': 'Symbol',
                    'name_column': 'Security'
                },
                'nasdaq100': {
                    'url': 'https://en.wikipedia.org/wiki/Nasdaq-100',
                    'ticker_column': 'Ticker',
                    'name_column': 'Company'
                }
            },
            'database': {'pool_size': 5, 'max_overflow': 10}
        }

    def _setup_logging(self):
        log_config = self.config.get('logging', {})
        log_file = log_config.get('log_file', 'data_cache/api.log')
        prod_mode = os.getenv("PROD", "False") == "True"
        handlers = [logging.StreamHandler()]
        if not prod_mode:
            os.makedirs(os.path.dirname(log_file), exist_ok=True)
            handlers.insert(0, logging.FileHandler(log_file))
        logging.basicConfig(
            level=getattr(logging, log_config.get('level', 'INFO').upper()),
            format='%(asctime)s [%(levelname)s] %(name)s: %(message)s',
            handlers=handlers,
            datefmt='%Y-%m-%d %H:%M:%S'
        )

    @property  # exposed as a property so Database can read config.database_url as an attribute
    def database_url(self) -> str:
        user = os.getenv("MYSQL_USER")
        password = os.getenv("MYSQL_PASSWORD")
        host = os.getenv("MYSQL_HOST")
        port = os.getenv("MYSQL_PORT")
        db = os.getenv("MYSQL_DB")
        if not all([user, password]):
            raise ValueError("MySQL credentials not found in environment variables")
        return f"mysql+aiomysql://{user}:{password}@{host}:{port}/{db}"
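# Environment variables read by this module (names taken from the code above and below;
# values shown are placeholders, not real credentials):
#   API_KEY=<bearer token required by protected endpoints>
#   MYSQL_USER=...  MYSQL_PASSWORD=...  MYSQL_HOST=...  MYSQL_PORT=3306  MYSQL_DB=...
#   PROD=True|False          # True disables file logging and auto-reload
#   HOST=0.0.0.0  PORT=8000  # only used by the __main__ block at the bottom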
# --- SERVICES ---
class TickerService:
    def __init__(self, config: Config):
        self.config = config
        self.logger = logging.getLogger(__name__)

    async def get_tickers_from_wikipedia(
        self, url: str, ticker_column: str, name_column: str,
        sector_column: Optional[str] = None, subindustry_column: Optional[str] = None
    ) -> List[tuple[str, str, Optional[str], Optional[str]]]:
        """Async version fetching ticker, name, sector, and subindustry from Wikipedia."""
        try:
            ssl_context = ssl.create_default_context(cafile=certifi.where())
            connector = aiohttp.TCPConnector(ssl=ssl_context)
            async with aiohttp.ClientSession(connector=connector) as session:
                headers = {'User-Agent': 'Mozilla/5.0 (compatible; MarketDataAPI/1.0)'}
                async with session.get(url, headers=headers) as response:
                    response.raise_for_status()
                    html_content = await response.text()
            tables = pd.read_html(StringIO(html_content))
            columns_needed = [ticker_column, name_column]
            if sector_column:
                columns_needed.append(sector_column)
            if subindustry_column:
                columns_needed.append(subindustry_column)
            df = next((table for table in tables if all(col in table.columns for col in columns_needed)), None)
            if df is None:
                self.logger.error(f"Could not find columns {columns_needed} on {url}")
                return []
            entries = df[columns_needed].dropna(subset=[ticker_column])
            self.logger.info(f"Fetched {len(entries)} rows from {url}")
            results: List[tuple[str, str, Optional[str], Optional[str]]] = []
            for _, row in entries.iterrows():
                ticker = str(row[ticker_column]).strip()
                name = str(row[name_column]).strip()
                sector = str(row[sector_column]).strip() if sector_column and sector_column in row and pd.notna(row[sector_column]) else None
                subindustry = str(row[subindustry_column]).strip() if subindustry_column and subindustry_column in row and pd.notna(row[subindustry_column]) else None
                results.append((ticker, name, sector, subindustry))
            return results
        except Exception as e:
            self.logger.error(f"Failed to fetch tickers and names from {url}: {e}")
            return []

    async def get_sp500_tickers(self) -> List[tuple[str, str, Optional[str], Optional[str]]]:
        cfg = self.config.config.get('data_sources', {}).get('sp500', {})
        return await self.get_tickers_from_wikipedia(
            cfg.get('url'),
            cfg.get('ticker_column'),
            cfg.get('name_column'),
            cfg.get('sector_column'),
            cfg.get('subindustry_column')
        )

    async def get_nasdaq100_tickers(self) -> List[tuple[str, str, Optional[str], Optional[str]]]:
        cfg = self.config.config.get('data_sources', {}).get('nasdaq100', {})
        return await self.get_tickers_from_wikipedia(
            cfg.get('url'),
            cfg.get('ticker_column'),
            cfg.get('name_column'),
            cfg.get('sector_column'),
            cfg.get('subindustry_column')
        )
    async def update_tickers_in_db(self, session: AsyncSession, force_refresh: bool = False) -> Dict[str, Any]:
        """
        Updates the tickers table with the latest data from Wikipedia sources, skipping the
        refresh when the stored data is less than 1 day old (unless force_refresh is set).
        """
        try:
            # Check if tickers were updated in the last 24h
            now = datetime.now(pytz.UTC)
            result = await session.execute(select(Ticker.last_updated).order_by(Ticker.last_updated.desc()).limit(1))
            last = result.scalar()
            if last and not force_refresh:
                # Ensure 'last' is timezone aware
                if last.tzinfo is None:
                    last = pytz.UTC.localize(last)
                delta = now - last
                if delta.total_seconds() < 86400:
                    self.logger.info(f"Tickers not updated: last update {last.isoformat()} < 1 day ago.")
                    from sqlalchemy import func
                    total_tickers = await session.scalar(select(func.count()).select_from(Ticker))
                    sp500_count = await session.scalar(select(func.count()).select_from(Ticker).where(Ticker.is_sp500 == 1))
                    nasdaq100_count = await session.scalar(select(func.count()).select_from(Ticker).where(Ticker.is_nasdaq100 == 1))
                    return {
                        "total_tickers": total_tickers,
                        "sp500_count": sp500_count,
                        "nasdaq100_count": nasdaq100_count,
                        "updated_at": last,
                        "not_updated_reason": "Tickers not updated: last update was less than 1 day ago. Use force_refresh to override."
                    }
            sp500_list = await self.get_sp500_tickers()
            nasdaq_list = await self.get_nasdaq100_tickers()
            combined = sp500_list + nasdaq_list
            ticker_dict = {}
            for t, n, s, sub in combined:
                ticker_dict[t] = {
                    "name": n,
                    "sector": s,
                    "subindustry": sub,
                    "is_sp500": 1 if t in [x[0] for x in sp500_list] else 0,
                    "is_nasdaq100": 1 if t in [x[0] for x in nasdaq_list] else 0
                }
            all_tickers = sorted(ticker_dict.keys())
            current_time = now
            await session.execute(delete(Ticker))
            ticker_objects = [
                Ticker(
                    ticker=t,
                    name=ticker_dict[t]["name"],
                    sector=ticker_dict[t]["sector"],
                    subindustry=ticker_dict[t]["subindustry"],
                    is_sp500=ticker_dict[t]["is_sp500"],
                    is_nasdaq100=ticker_dict[t]["is_nasdaq100"],
                    last_updated=current_time
                )
                for t in all_tickers
            ]
            session.add_all(ticker_objects)
            await session.commit()
            result = {
                "total_tickers": len(all_tickers),
                "sp500_count": len(sp500_list),
                "nasdaq100_count": len(nasdaq_list),
                "updated_at": current_time
            }
            self.logger.info(
                "Tickers table updated: total=%d, sp500=%d, nasdaq100=%d at %s",
                result["total_tickers"],
                result["sp500_count"],
                result["nasdaq100_count"],
                result["updated_at"].isoformat()
            )
            return result
        except Exception as e:
            await session.rollback()
            self.logger.error(f"Failed to update tickers: {e}")
            raise
class YFinanceService:
    def __init__(self, config: Config):
        self.config = config
        self.logger = logging.getLogger(__name__)

    async def check_tickers_freshness(self, session: AsyncSession) -> bool:
        """
        Check if tickers were updated within the last week (7 days).
        Returns True if fresh, False if an update is needed.
        """
        try:
            now = datetime.now(pytz.UTC)
            result = await session.execute(
                select(Ticker.last_updated).order_by(Ticker.last_updated.desc()).limit(1)
            )
            last_update = result.scalar()
            if not last_update:
                self.logger.info("No tickers found in database")
                return False
            # Ensure timezone awareness
            if last_update.tzinfo is None:
                last_update = pytz.UTC.localize(last_update)
            delta = now - last_update
            is_fresh = delta.total_seconds() < (7 * 24 * 3600)  # 7 days
            self.logger.info(f"Tickers last updated: {last_update.isoformat()}, Fresh: {is_fresh}")
            return is_fresh
        except Exception as e:
            self.logger.error(f"Error checking ticker freshness: {e}")
            return False

    async def check_ticker_data_freshness(self, session: AsyncSession) -> bool:
        """
        Check if ticker data was updated within the last day (24 hours).
        Returns True if fresh, False if an update is needed.
        """
        try:
            now = datetime.now(pytz.UTC)
            result = await session.execute(
                select(TickerData.created_at).order_by(TickerData.created_at.desc()).limit(1)
            )
            last_update = result.scalar()
            if not last_update:
                self.logger.info("No ticker data found in database")
                return False
            # Ensure timezone awareness
            if last_update.tzinfo is None:
                last_update = pytz.UTC.localize(last_update)
            delta = now - last_update
            is_fresh = delta.total_seconds() < (24 * 3600)  # 24 hours
            self.logger.info(f"Ticker data last updated: {last_update.isoformat()}, Fresh: {is_fresh}")
            return is_fresh
        except Exception as e:
            self.logger.error(f"Error checking ticker data freshness: {e}")
            return False
    async def clear_and_bulk_insert_ticker_data(self, session: AsyncSession, ticker_list: List[str]) -> Dict[str, Any]:
        """
        Clear all ticker data and insert new data in bulk, chunked for better performance.
        Uses a bulk delete followed by bulk inserts in chunks of 1000 records.
        """
        try:
            # Start timing for total end-to-end process
            total_start_time = time.perf_counter()
            self.logger.info(f"Starting bulk data refresh for {len(ticker_list)} tickers (clear and insert)")
            # Start timing for data download
            download_start_time = time.perf_counter()
            # Download data for all tickers at once using period
            data = yf.download(ticker_list, period='1mo', group_by='ticker', progress=True, auto_adjust=True)
            download_end_time = time.perf_counter()
            download_duration = download_end_time - download_start_time
            self.logger.info(f"DEBUG: Data download completed in {download_duration:.2f} seconds for {len(ticker_list)} tickers")
            if data.empty:
                self.logger.warning("No data found for any tickers")
                return {
                    "created": 0,
                    "updated": 0,
                    "date_range": {"start_date": "", "end_date": ""}
                }
            # Start timing for database operations
            db_start_time = time.perf_counter()
            # Clear all existing ticker data
            self.logger.info("Clearing all existing ticker data...")
            clear_start = time.perf_counter()
            await session.execute(delete(TickerData))
            clear_end = time.perf_counter()
            self.logger.info(f"DEBUG: Data cleared in {clear_end - clear_start:.2f} seconds")
            # Prepare data for bulk insert
            current_time = datetime.now(pytz.UTC)
            all_records = []
            # Get actual date range from the data
            all_dates = data.index.tolist()
            start_date = min(all_dates).date() if all_dates else datetime.now().date()
            end_date = max(all_dates).date() if all_dates else datetime.now().date()
            # Handle both single ticker and multi-ticker cases
            if len(ticker_list) == 1:
                # Single ticker case - data is not grouped
                ticker = ticker_list[0]
                for date_idx, row in data.iterrows():
                    if pd.isna(row['Close']):
                        continue
                    trade_date = date_idx.date()
                    record = {
                        'ticker': ticker,
                        'date': trade_date,
                        'open': float(row['Open']),
                        'high': float(row['High']),
                        'low': float(row['Low']),
                        'close': float(row['Close']),
                        'volume': int(row['Volume']),
                        'created_at': current_time
                    }
                    all_records.append(record)
            else:
                # Multiple tickers case - data is grouped by ticker
                for ticker in ticker_list:
                    if ticker not in data.columns.get_level_values(0):
                        self.logger.warning(f"No data found for ticker {ticker}")
                        continue
                    ticker_data = data[ticker]
                    if ticker_data.empty:
                        continue
                    for date_idx, row in ticker_data.iterrows():
                        if pd.isna(row['Close']):
                            continue
                        trade_date = date_idx.date()
                        record = {
                            'ticker': ticker,
                            'date': trade_date,
                            'open': float(row['Open']),
                            'high': float(row['High']),
                            'low': float(row['Low']),
                            'close': float(row['Close']),
                            'volume': int(row['Volume']),
                            'created_at': current_time
                        }
                        all_records.append(record)
            # Bulk insert in chunks of 1000 (optimized for MySQL performance)
            chunk_size = 1000
            total_records = len(all_records)
            inserted_count = 0
            self.logger.info(f"Inserting {total_records} records in chunks of {chunk_size}")
            for i in range(0, total_records, chunk_size):
                chunk = all_records[i:i + chunk_size]
                chunk_start = time.perf_counter()
                # Create TickerData objects for bulk insert
                ticker_objects = [TickerData(**record) for record in chunk]
                session.add_all(ticker_objects)
                chunk_end = time.perf_counter()
                inserted_count += len(chunk)
                self.logger.info(f"DEBUG: Inserted chunk {i//chunk_size + 1}/{(total_records + chunk_size - 1)//chunk_size} ({len(chunk)} records) in {chunk_end - chunk_start:.2f} seconds")
            # Commit all changes
            commit_start = time.perf_counter()
            await session.commit()
            commit_end = time.perf_counter()
            self.logger.info(f"DEBUG: Database commit completed in {commit_end - commit_start:.2f} seconds")
            db_end_time = time.perf_counter()
            db_duration = db_end_time - db_start_time
            self.logger.info(f"DEBUG: Database operations completed in {db_duration:.2f} seconds for {inserted_count} records")
            # Calculate total end-to-end duration
            total_end_time = time.perf_counter()
            total_duration = total_end_time - total_start_time
            self.logger.info(f"DEBUG: Total bulk refresh completed in {total_duration:.2f} seconds (download: {download_duration:.2f}s, database: {db_duration:.2f}s)")
            self.logger.info(f"Bulk refresh: inserted {inserted_count} records")
            return {
                "created": inserted_count,
                "updated": 0,
                "date_range": {
                    "start_date": start_date.isoformat(),
                    "end_date": end_date.isoformat()
                }
            }
        except Exception as e:
            await session.rollback()
            self.logger.error(f"Error in bulk refresh: {e}")
            raise
    async def download_all_tickers_data(self, session: AsyncSession, ticker_list: Optional[List[str]] = None) -> Dict[str, Any]:
        """
        Download data for all or specified tickers for the last month.
        Strategy: check data freshness; if the data is older than 24h, clear the table and bulk insert new data.
        """
        try:
            # Check ticker freshness and update if needed
            if not await self.check_tickers_freshness(session):
                self.logger.info("Tickers are stale, updating...")
                ticker_service = TickerService(self.config)
                await ticker_service.update_tickers_in_db(session, force_refresh=True)
            # Get tickers to process
            if ticker_list:
                # Validate provided tickers exist in database
                result = await session.execute(
                    select(Ticker.ticker).where(Ticker.ticker.in_(ticker_list))
                )
                valid_tickers = [row[0] for row in result.fetchall()]
                invalid_tickers = set(ticker_list) - set(valid_tickers)
                if invalid_tickers:
                    self.logger.warning(f"Invalid tickers ignored: {invalid_tickers}")
                tickers_to_process = valid_tickers
            else:
                # Get all tickers from database
                result = await session.execute(select(Ticker.ticker))
                tickers_to_process = [row[0] for row in result.fetchall()]
            if not tickers_to_process:
                return {
                    "tickers_processed": 0,
                    "records_created": 0,
                    "records_updated": 0,
                    "date_range": {"start_date": "", "end_date": ""},
                    "message": "No valid tickers found to process"
                }
            # Check if ticker data is fresh (less than 24h old)
            if await self.check_ticker_data_freshness(session):
                self.logger.info("Ticker data is fresh (less than 24h old), skipping update")
                return {
                    "tickers_processed": len(tickers_to_process),
                    "records_created": 0,
                    "records_updated": 0,
                    "date_range": {"start_date": "", "end_date": ""},
                    "message": f"Data is fresh, no update needed for {len(tickers_to_process)} tickers"
                }
            # Data is stale - use bulk refresh strategy
            self.logger.info("Data is stale (>24h old), using bulk refresh strategy")
            result = await self.clear_and_bulk_insert_ticker_data(session, tickers_to_process)
            total_created = result["created"]
            total_updated = result["updated"]
            successful_tickers = len(tickers_to_process)
            return {
                "tickers_processed": successful_tickers,
                "records_created": total_created,
                "records_updated": total_updated,
                "date_range": result["date_range"],
                "message": f"Successfully processed {successful_tickers} tickers using bulk refresh"
            }
        except Exception as e:
            self.logger.error(f"Error in download_all_tickers_data: {e}")
            raise
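# Optional sketch (not wired into YFinanceService): SQLAlchemy 2.0 can also run
# "executemany"-style bulk INSERTs straight from lists of dicts, skipping the
# per-row ORM objects built in clear_and_bulk_insert_ticker_data(). The helper
# name, signature, and chunk size below are illustrative assumptions.
from sqlalchemy import insert as sa_insert

async def bulk_insert_ticker_records(session: AsyncSession, records: List[Dict[str, Any]], chunk_size: int = 1000) -> int:
    """Insert pre-built ticker_data row dicts in chunks; returns the number of rows inserted."""
    inserted = 0
    for i in range(0, len(records), chunk_size):
        chunk = records[i:i + chunk_size]
        # One multi-row INSERT per chunk instead of one ORM object per record
        await session.execute(sa_insert(TickerData), chunk)
        inserted += len(chunk)
    await session.commit()
    return inserted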
# --- DATABASE ---
class Database:
    def __init__(self, config: Config):
        self.config = config
        # Filter out pool params not supported by NullPool
        db_opts = self.config.config.get('database', {}).copy()
        db_opts.pop('pool_size', None)
        db_opts.pop('max_overflow', None)
        self.engine = create_async_engine(
            config.database_url,
            pool_pre_ping=True,
            poolclass=NullPool,
            **db_opts
        )
        self.async_session = async_sessionmaker(
            self.engine,
            class_=AsyncSession,
            expire_on_commit=False
        )

    async def create_tables(self):
        async with self.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
# --- TASK MANAGER ---
from sqlalchemy import JSON as SQLAlchemyJSON


class Task(Base):
    __tablename__ = "tasks"

    task_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    status: Mapped[str] = mapped_column(String(32), nullable=False)
    message: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    result: Mapped[Optional[dict]] = mapped_column(SQLAlchemyJSON, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False)


class TaskManager:
    def __init__(self, database: Database):
        self.database = database

    async def create_table_if_not_exists(self):
        async with self.database.engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

    async def create_task(self, task_id: str) -> TaskStatus:
        async with self.database.async_session() as session:
            now = datetime.utcnow()
            db_task = Task(
                task_id=task_id,
                status="pending",
                message=None,
                result=None,
                created_at=now
            )
            session.add(db_task)
            await session.commit()
            return TaskStatus(
                task_id=task_id,
                status="pending",
                message=None,
                result=None,
                created_at=now
            )

    async def update_task(self, task_id: str, status: str, message: str = None, result: Dict = None):
        def serialize_datetimes(obj):
            if isinstance(obj, dict):
                return {k: serialize_datetimes(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [serialize_datetimes(v) for v in obj]
            elif isinstance(obj, datetime):
                return obj.isoformat()
            else:
                return obj

        async with self.database.async_session() as session:
            db_task = await session.get(Task, task_id)
            if db_task:
                db_task.status = status
                db_task.message = message
                db_task.result = serialize_datetimes(result) if result is not None else None
                await session.commit()

    async def get_task(self, task_id: str) -> Optional[TaskStatus]:
        async with self.database.async_session() as session:
            db_task = await session.get(Task, task_id)
            if db_task:
                return TaskStatus(
                    task_id=db_task.task_id,
                    status=db_task.status,
                    message=db_task.message,
                    result=db_task.result,
                    created_at=db_task.created_at
                )
            return None

    async def list_tasks(self) -> list[TaskStatus]:
        async with self.database.async_session() as session:
            result = await session.execute(select(Task))
            tasks = result.scalars().all()
            return [
                TaskStatus(
                    task_id=t.task_id,
                    status=t.status,
                    message=t.message,
                    result=t.result,
                    created_at=t.created_at
                ) for t in tasks
            ]

    async def delete_old_tasks(self, older_than_seconds: int = 3600) -> int:
        cutoff = datetime.utcnow() - timedelta(seconds=older_than_seconds)
        async with self.database.async_session() as session:
            result = await session.execute(select(Task).where(Task.created_at < cutoff))
            old_tasks = result.scalars().all()
            count = len(old_tasks)
            for t in old_tasks:
                await session.delete(t)
            await session.commit()
            return count
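# Optional sketch (an assumption, not wired into the app): a background loop that
# calls TaskManager.delete_old_tasks() periodically instead of relying on the
# manual cleanup endpoint below. Interval and logger usage are illustrative.
async def _task_cleanup_loop(manager: TaskManager, interval_seconds: int = 3600) -> None:
    while True:
        deleted = await manager.delete_old_tasks(older_than_seconds=interval_seconds)
        logging.getLogger(__name__).info("Task cleanup removed %d old tasks", deleted)
        await asyncio.sleep(interval_seconds)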
# --- APP SETUP ---
# Global instances
config = Config()
database = Database(config)
ticker_service = TickerService(config)
yfinance_service = YFinanceService(config)
task_manager = TaskManager(database)


# Dependency function
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    async with database.async_session() as session:
        yield session


@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    await database.create_tables()
    await task_manager.create_table_if_not_exists()
    logging.info("Database tables created/verified")
    yield
    # Shutdown
    await database.engine.dispose()
    logging.info("Database connections closed")


# Create FastAPI app
app = FastAPI(
    title="Stock Monitoring API",
    description="API for managing S&P 500 and Nasdaq 100 ticker data",
    version="0.1.0",
    lifespan=lifespan,
    swagger_ui_parameters={"faviconUrl": "/static/favicon.ico"}
)

# Serve static files (make sure a 'static' folder with favicon.ico exists next to this file)
app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static")


# Favicon endpoint
@app.get("/favicon.ico", include_in_schema=False)  # route path assumed
async def favicon():
    return FileResponse(os.path.join(os.path.dirname(__file__), "static", "favicon.ico"))
# --- API ENDPOINTS ---
@app.get("/")  # route path assumed
async def root_info():
    """
    Get API health status, current timestamp, versions, and DB/tables check.
    **Logic**:
    - Returns a JSON object with:
        - **status**: Health status of the API
        - **timestamp**: Current time in UTC timezone
        - **versions**: Dictionary with Python and main library versions
        - **database**: Connection status and existence of 'tickers' and 'tasks' tables
    **Args**: None
    **Example response:**
    ```json
    {
        "status": "healthy",
        "timestamp": "2025-07-19T19:38:26+02:00",
        "versions": { ... },
        "database": {
            "connected": true,
            "tickers_table": true,
            "tasks_table": true
        }
    }
    ```
    """
    now_utc = datetime.now(pytz.UTC)
    versions = {}
    versions["python"] = platform.python_version()
    packages = ["uvicorn", "fastapi", "sqlalchemy", "pandas"]
    for pkg in packages:
        try:
            versions[pkg] = importlib.metadata.version(pkg)
        except Exception:
            versions[pkg] = None
    db_status = {
        "connected": False,
        "tickers_table": False,
        "tasks_table": False
    }
    db_check_time = None
    start = time.perf_counter()
    try:
        async with database.engine.connect() as conn:
            db_status["connected"] = True
            insp = await conn.run_sync(lambda c: c.dialect.get_table_names(c))
            db_status["tickers_table"] = "tickers" in insp
            db_status["tasks_table"] = "tasks" in insp
    except Exception:
        db_status["connected"] = False
    finally:
        db_check_time = time.perf_counter() - start
    return {
        "status": "healthy" if db_status["connected"] and db_status["tickers_table"] and db_status["tasks_table"] else "degraded",
        "timestamp": now_utc.isoformat(),
        "versions": versions,
        "database": db_status,
        "db_check_seconds": round(db_check_time, 4) if db_check_time is not None else None
    }
@app.get("/tickers")
async def get_tickers(
    is_sp500: Optional[bool] = None,
    is_nasdaq: Optional[bool] = None,
    limit: int = 1000,
    session: AsyncSession = Depends(get_db_session)
):
    """
    Get all tickers from the database with optional filtering.
    **Logic**:
    - No parameters: Return all tickers
    - is_sp500=true: Only S&P 500 tickers
    - is_sp500=false: Only NON-S&P 500 tickers
    - is_nasdaq=true: Only Nasdaq 100 tickers
    - is_nasdaq=false: Only NON-Nasdaq 100 tickers
    - Both parameters: Apply AND logic (intersection of conditions)
    **Args (all optional)**:
    - **is_sp500** (optional): Filter for S&P 500 membership (true/false/None)
    - **is_nasdaq** (optional): Filter for Nasdaq 100 membership (true/false/None)
    - **limit** (optional): Maximum number of results to return
    **Examples:**
    - `GET /tickers` - All tickers
    - `GET /tickers?is_sp500=true` - Only S&P 500
    - `GET /tickers?is_nasdaq=true&is_sp500=false` - Only Nasdaq 100 but not S&P 500
    - `GET /tickers?is_sp500=true&is_nasdaq=false` - S&P 500 but not Nasdaq 100
    """
    try:
        query = select(Ticker)
        # Build conditions based on explicit flag values
        conditions = []
        if is_sp500 is not None:
            if is_sp500:
                conditions.append(Ticker.is_sp500 == 1)
            else:
                conditions.append(Ticker.is_sp500 == 0)
        if is_nasdaq is not None:
            if is_nasdaq:
                conditions.append(Ticker.is_nasdaq100 == 1)
            else:
                conditions.append(Ticker.is_nasdaq100 == 0)
        # Apply filtering if we have conditions
        if conditions:
            from sqlalchemy import and_
            query = query.where(and_(*conditions))
        query = query.order_by(Ticker.ticker).limit(limit)
        result = await session.execute(query)
        tickers = result.scalars().all()
        return [
            TickerResponse(
                ticker=t.ticker,
                name=t.name,
                sector=t.sector,
                subindustry=t.subindustry,
                is_sp500=bool(t.is_sp500),
                is_nasdaq100=bool(t.is_nasdaq100),
                last_updated=t.last_updated
            )
            for t in tickers
        ]
    except Exception as e:
        logging.error(f"Error fetching tickers: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch tickers")
@app.post("/tickers/update")  # route path assumed
async def update_tickers(
    request: UpdateTickersRequest,
    background_tasks: BackgroundTasks,
    session: AsyncSession = Depends(get_db_session),
    api_key: str = Depends(verify_api_key)
):
    """
    Update tickers from Wikipedia sources (S&P 500 and Nasdaq 100).
    **Logic**:
    - Fetches latest tickers from Wikipedia (S&P 500 and Nasdaq 100).
    - Updates the database with the new tickers.
    - Returns a summary of the update (counts, timestamp).
    **Args**:
    - **request**: UpdateTickersRequest (force_refresh: bool)
    - **background_tasks**: FastAPI BackgroundTasks (unused)
    - **session**: AsyncSession (DB session, injected)
    **Example request:**
    ```json
    { "force_refresh": false }
    { "force_refresh": true }
    ```
    **Example response:**
    ```json
    {
        "success": true,
        "message": "Tickers updated successfully",
        "total_tickers": 517,
        "sp500_count": 500,
        "nasdaq100_count": 100,
        "updated_at": "2025-07-19T19:38:26+02:00"
    }
    ```
    """
    try:
        result = await ticker_service.update_tickers_in_db(session, force_refresh=request.force_refresh)
        message = result.pop("not_updated_reason", None)
        if message:
            return UpdateTickersResponse(
                success=True,
                message=message,
                **result
            )
        return UpdateTickersResponse(
            success=True,
            message="Tickers updated successfully",
            **result
        )
    except Exception as e:
        logging.error(f"Error updating tickers: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to update tickers: {str(e)}")
@app.post("/tickers/update-async")  # route path assumed
async def update_tickers_async(
    request: UpdateTickersRequest,
    background_tasks: BackgroundTasks,
    api_key: str = Depends(verify_api_key)
):
    """
    Start an async ticker update task (background).
    **Logic**:
    - Launches a background task to update tickers from Wikipedia.
    - Returns a task_id and status for tracking.
    **Args**:
    - **request**: UpdateTickersRequest (force_refresh: bool)
    **Example request:**
    ```json
    { "force_refresh": false }
    { "force_refresh": true }
    ```
    **Example response:**
    ```json
    {
        "task_id": "c1a2b3d4-5678-90ab-cdef-1234567890ab",
        "status": "started"
    }
    ```
    """
    import uuid
    task_id = str(uuid.uuid4())
    await task_manager.create_task(task_id)

    async def update_task():
        try:
            await task_manager.update_task(task_id, "running", "Updating tickers...")
            async with database.async_session() as session:
                result = await ticker_service.update_tickers_in_db(session, force_refresh=request.force_refresh)
            message = result.pop("not_updated_reason", None)
            if message:
                await task_manager.update_task(task_id, "completed", message, result)
            else:
                await task_manager.update_task(task_id, "completed", "Update successful", result)
        except Exception as e:
            await task_manager.update_task(task_id, "failed", str(e))

    background_tasks.add_task(update_task)
    return {"task_id": task_id, "status": "started"}
@app.get("/tasks")  # route path assumed
async def list_all_tasks(api_key: str = Depends(verify_api_key)):
    """
    List all background tasks and their status.
    **Logic**:
    - Returns a list of all tasks created via the async update endpoint, with their status and result.
    **Args**: None
    **Example response:**
    ```json
    [
        {
            "task_id": "c1a2b3d4-5678-90ab-cdef-1234567890ab",
            "status": "completed",
            "message": "Tickers updated successfully",
            "result": {
                "total_tickers": 517,
                "sp500_count": 500,
                "nasdaq100_count": 100,
                "updated_at": "2025-07-19T19:38:26+02:00"
            },
            "created_at": "2025-07-19T19:38:26+02:00"
        },
        ...
    ]
    ```
    """
    return await task_manager.list_tasks()
@app.get("/tasks/{task_id}")  # route path assumed
async def get_task_status(task_id: str, api_key: str = Depends(verify_api_key)):
    """
    Get status and result of a background update task by task_id.
    **Logic**:
    - Returns the status and result of a background update task by task_id.
    - If not found, returns 404.
    **Args**:
    - **task_id**: str (UUID of the task)
    **Example response:**
    ```json
    {
        "task_id": "c1a2b3d4-5678-90ab-cdef-1234567890ab",
        "status": "completed",
        "message": "Tickers updated successfully",
        "result": {
            "total_tickers": 517,
            "sp500_count": 500,
            "nasdaq100_count": 100,
            "updated_at": "2025-07-19T19:38:26+02:00"
        },
        "created_at": "2025-07-19T19:38:26+02:00"
    }
    ```
    """
    task = await task_manager.get_task(task_id)
    if not task:
        raise HTTPException(status_code=404, detail="Task not found")
    return task
# Endpoint to delete tasks older than 1 hour
@app.delete("/tasks/old")  # route path assumed
async def delete_old_tasks(api_key: str = Depends(verify_api_key)):
    """
    Delete tasks older than 1 hour (3600 seconds).
    **Logic**:
    - Deletes all tasks in the database older than 1 hour.
    - Returns the number of deleted tasks.
    **Args**: None
    **Example response:**
    ```json
    { "deleted": 5 }
    ```
    """
    deleted_count = await task_manager.delete_old_tasks(older_than_seconds=3600)
    return {"deleted": deleted_count}
@app.post("/data/download-all")
async def download_all_tickers_data(
    session: AsyncSession = Depends(get_db_session),
    api_key: str = Depends(verify_api_key)
):
    """
    Download daily ticker data for the last month for ALL tickers in the database.
    **Logic**:
    - Automatically downloads data for all tickers stored in the tickers table
    - Checks if tickers were updated within the last week, updates if needed
    - Only downloads if ticker data is older than 24 hours
    - Downloads daily data for the last 30 days for all available tickers
    - Uses bulk delete and insert strategy for optimal performance
    - Returns summary with counts and date range
    **Args**:
    - **session**: AsyncSession (DB session, injected)
    - **api_key**: str (API key for authentication, injected)
    **Example request:**
    ```bash
    curl -X POST "http://localhost:${PORT}/data/download-all" \
         -H "Authorization: Bearer your_api_key"
    ```
    **Example response:**
    ```json
    {
        "success": true,
        "message": "Successfully processed 503 tickers using bulk refresh",
        "tickers_processed": 503,
        "records_created": 12075,
        "records_updated": 0,
        "date_range": {
            "start_date": "2025-06-30",
            "end_date": "2025-07-30"
        },
        "updated_at": "2025-07-30T14:15:26+00:00"
    }
    ```
    """
    try:
        # Use existing service without specifying ticker list (downloads all)
        result = await yfinance_service.download_all_tickers_data(
            session,
            ticker_list=None  # None means download all tickers
        )
        return DownloadDataResponse(
            success=True,
            message=result["message"],
            tickers_processed=result["tickers_processed"],
            records_created=result["records_created"],
            records_updated=result["records_updated"],
            date_range=result["date_range"],
            updated_at=datetime.now(pytz.UTC)
        )
    except Exception as e:
        logging.error(f"Error downloading all ticker data: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to download all ticker data: {str(e)}")
@app.get("/data/{ticker}")  # route path assumed
async def get_ticker_data(
    ticker: str,
    days: int = 30,
    session: AsyncSession = Depends(get_db_session)
):
    """
    Get historical data for a specific ticker.
    **Logic**:
    - Returns historical data for the specified ticker
    - Defaults to the last 30 days if no days parameter is provided
    - Data is ordered by date descending (most recent first)
    **Args**:
    - **ticker**: str (Ticker symbol, e.g., "AAPL")
    - **days**: int (Number of days to retrieve, default 30)
    - **session**: AsyncSession (DB session, injected)
    **Example response:**
    ```json
    [
        {
            "ticker": "AAPL",
            "date": "2025-07-30",
            "open": 150.25,
            "high": 152.80,
            "low": 149.50,
            "close": 151.75,
            "volume": 45123000,
            "created_at": "2025-07-30T13:45:26+00:00"
        }
    ]
    ```
    """
    try:
        # Calculate date range
        end_date = datetime_date.today()
        start_date = end_date - timedelta(days=days)
        # Query ticker data
        query = select(TickerData).where(
            TickerData.ticker == ticker.upper(),
            TickerData.date >= start_date,
            TickerData.date <= end_date
        ).order_by(TickerData.date.desc())
        result = await session.execute(query)
        ticker_data = result.scalars().all()
        if not ticker_data:
            raise HTTPException(
                status_code=404,
                detail=f"No data found for ticker {ticker.upper()} in the last {days} days"
            )
        return [
            TickerDataResponse(
                ticker=data.ticker,
                date=data.date,
                open=data.open,
                high=data.high,
                low=data.low,
                close=data.close,
                volume=data.volume,
                created_at=data.created_at
            )
            for data in ticker_data
        ]
    except HTTPException:
        raise
    except Exception as e:
        logging.error(f"Error fetching ticker data for {ticker}: {e}")
        raise HTTPException(status_code=500, detail="Failed to fetch ticker data")
# Local execution configuration
if __name__ == "__main__":
    import uvicorn
    HOST = os.getenv("HOST", "0.0.0.0")
    PORT = int(os.getenv("PORT", 8000))
    # Determine the reload setting from the PROD environment variable
    RELOAD = os.getenv("PROD", "False") != "True"
    # Start the Uvicorn server
    uvicorn.run("index:app", host=HOST, port=PORT, reload=RELOAD)
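# Local usage sketch (assumes the defaults above and the routes registered in this
# module; adjust host, port, and the API key to your environment):
#   python api/index.py
#   curl http://localhost:8000/                          # health, versions, DB check
#   curl "http://localhost:8000/tickers?is_sp500=true"   # S&P 500 tickers only
#   curl -X POST http://localhost:8000/data/download-all -H "Authorization: Bearer $API_KEY"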