# api/index.py
import os
import logging
import time
from datetime import datetime, timedelta
from datetime import date as datetime_date
from typing import List, Dict, Any, Optional, AsyncGenerator
import asyncio
from contextlib import asynccontextmanager
import yaml
import importlib.metadata
import pytz
from fastapi import FastAPI, HTTPException, BackgroundTasks, Depends, Security, Request, Response
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
from sqlalchemy import String, Integer, DateTime, select, delete, Float, Index
from sqlalchemy.types import Date as SQLAlchemyDate
from dotenv import load_dotenv, find_dotenv
from sqlalchemy.pool import NullPool
import requests
import pandas as pd
from io import StringIO
import ssl
import certifi
import aiohttp
import platform
import yfinance as yf
from fastapi.responses import JSONResponse
# --- MODELS --- | |
class Base(DeclarativeBase): | |
pass | |
class Ticker(Base): | |
__tablename__ = "tickers" | |
ticker: Mapped[str] = mapped_column(String(10), primary_key=True) | |
name: Mapped[str] = mapped_column(String(255), nullable=False) | |
sector: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) | |
subindustry: Mapped[Optional[str]] = mapped_column(String(128), nullable=True) | |
is_sp500: Mapped[int] = mapped_column(Integer, default=0) | |
is_nasdaq100: Mapped[int] = mapped_column(Integer, default=0) | |
last_updated: Mapped[datetime] = mapped_column(DateTime) | |
class TickerData(Base): | |
__tablename__ = "ticker_data" | |
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) | |
ticker: Mapped[str] = mapped_column(String(10), nullable=False) | |
date: Mapped[datetime_date] = mapped_column(SQLAlchemyDate, nullable=False) | |
open: Mapped[float] = mapped_column(Float, nullable=False) | |
high: Mapped[float] = mapped_column(Float, nullable=False) | |
low: Mapped[float] = mapped_column(Float, nullable=False) | |
close: Mapped[float] = mapped_column(Float, nullable=False) | |
volume: Mapped[int] = mapped_column(Integer, nullable=False) | |
sma_fast: Mapped[Optional[float]] = mapped_column(Float, nullable=True) # SMA 10 | |
sma_med: Mapped[Optional[float]] = mapped_column(Float, nullable=True) # SMA 20 | |
sma_slow: Mapped[Optional[float]] = mapped_column(Float, nullable=True) # SMA 50 | |
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) | |
__table_args__ = ( | |
Index('idx_ticker_date', 'ticker', 'date', unique=True), | |
Index('idx_ticker', 'ticker'), | |
Index('idx_date', 'date'), | |
) | |
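# Illustrative row construction for the model above (sketch only; the field names come
# from the model, the values are made up for the example):
#
#     TickerData(ticker="AAPL", date=datetime_date(2025, 7, 30), open=150.25, high=152.8,
#                low=149.5, close=151.75, volume=45_123_000, created_at=datetime.utcnow())
#
# The unique (ticker, date) index in __table_args__ rejects a second row for the same
# ticker on the same trading day.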
# --- PYDANTIC MODELS --- | |
class TickerResponse(BaseModel): | |
ticker: str | |
name: str | |
sector: Optional[str] | |
subindustry: Optional[str] | |
is_sp500: bool | |
is_nasdaq100: bool | |
last_updated: datetime | |
class UpdateTickersRequest(BaseModel): | |
force_refresh: bool = Field(default=False, description="Force refresh even if data is recent") | |
class UpdateTickersResponse(BaseModel): | |
success: bool | |
message: str | |
total_tickers: int | |
sp500_count: int | |
nasdaq100_count: int | |
updated_at: datetime | |
class TaskStatus(BaseModel): | |
task_id: str | |
status: str # pending, running, completed, failed | |
message: Optional[str] = None | |
result: Optional[Dict[str, Any]] = None | |
created_at: datetime | |
class TickerDataResponse(BaseModel): | |
ticker: str | |
date: datetime_date | |
open: float | |
high: float | |
low: float | |
close: float | |
volume: int | |
sma_fast: Optional[float] = None # SMA 10 | |
sma_med: Optional[float] = None # SMA 20 | |
sma_slow: Optional[float] = None # SMA 50 | |
created_at: datetime | |
class DownloadDataRequest(BaseModel): | |
tickers: Optional[List[str]] = Field(default=None, description="Specific tickers to download. If not provided, downloads all available tickers") | |
force_refresh: bool = Field(default=False, description="Force refresh even if data exists") | |
force_indicators: bool = Field(default=False, description="Force calculation of technical indicators even if data is fresh") | |
class DownloadDataResponse(BaseModel): | |
success: bool | |
message: str | |
tickers_processed: int | |
records_created: int | |
records_updated: int | |
date_range: Dict[str, str] # start_date, end_date | |
updated_at: datetime | |
class FinancialDataRequest(BaseModel): | |
tickers: List[str] = Field(..., description="Stock ticker symbols (e.g., ['AAPL', 'MSFT', 'GOOGL'])") | |
period: str = Field(default="3mo", description="Data period: 1d,5d,1mo,3mo,6mo,1y,2y,5y,10y,ytd,max") | |
intraday: bool = Field(default=False, description="Enable intraday data (pre-market to post-market)") | |
interval: str = Field(default="1d", description="Data interval: 1m,2m,5m,15m,30m,60m,90m,1h,4h,1d,5d,1wk,1mo,3mo") | |
class TechnicalIndicatorData(BaseModel): | |
ticker: str | |
datetime: str # Changed from 'date' to 'datetime' for intraday support | |
open: float | |
high: float | |
low: float | |
close: float | |
volume: int | |
sma_fast: Optional[float] = None # SMA 10 | |
sma_med: Optional[float] = None # SMA 20 | |
sma_slow: Optional[float] = None # SMA 50 | |
class MarketStatus(BaseModel): | |
is_open: bool | |
market_state: str # REGULAR, PREPRE, PRE, POST, POSTPOST, CLOSED | |
timezone: str | |
class FinancialDataResponse(BaseModel): | |
success: bool | |
tickers: List[str] | |
period: str | |
interval: str | |
intraday: bool | |
total_data_points: int | |
date_range: Dict[str, str] # start_date, end_date | |
market_status: Optional[MarketStatus] = None | |
data: List[TechnicalIndicatorData] | |
calculated_at: datetime | |
# --- AUTHENTICATION --- | |
security = HTTPBearer() | |
async def verify_api_key(credentials: HTTPAuthorizationCredentials = Security(security)): | |
""" | |
Verify API key from Authorization header. | |
Expected format: Authorization: Bearer <api_key> | |
""" | |
api_key = os.getenv("API_KEY") | |
if not api_key: | |
raise HTTPException( | |
status_code=500, | |
detail="API key not configured on server" | |
) | |
if credentials.credentials != api_key: | |
raise HTTPException( | |
status_code=401, | |
detail="Invalid API key" | |
) | |
return credentials.credentials | |
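# Illustrative sketch (not part of the original module): how a route opts in to this
# check. The wiring mirrors the endpoints below, which use
# `api_key: str = Depends(verify_api_key)`; the route path here is hypothetical.
#
#     @app.get("/protected")
#     async def protected_route(api_key: str = Depends(verify_api_key)):
#         return {"ok": True}
#
# Matching client call (API_KEY is the value read from the environment above):
#     curl -H "Authorization: Bearer $API_KEY" http://localhost:7860/protected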
# --- RATE LIMITING --- | |
class RateLimiter: | |
def __init__(self): | |
self.requests = {} # {ip_address: {endpoint: [(timestamp, count), ...]}} | |
self.limits = { | |
"/data/analyze": {"requests": 20, "window": 60}, # 20 requests per minute | |
"default": {"requests": 100, "window": 60} # 100 requests per minute default | |
} | |
def is_allowed(self, client_ip: str, endpoint: str) -> tuple[bool, dict]: | |
""" | |
Check if request is within rate limits. | |
Returns (is_allowed, rate_info) | |
""" | |
current_time = time.time() | |
# Get limits for this endpoint | |
limit_config = self.limits.get(endpoint, self.limits["default"]) | |
max_requests = limit_config["requests"] | |
window_seconds = limit_config["window"] | |
# Initialize tracking for this IP if needed | |
if client_ip not in self.requests: | |
self.requests[client_ip] = {} | |
if endpoint not in self.requests[client_ip]: | |
self.requests[client_ip][endpoint] = [] | |
# Clean old requests outside the window | |
cutoff_time = current_time - window_seconds | |
self.requests[client_ip][endpoint] = [ | |
(timestamp, count) for timestamp, count in self.requests[client_ip][endpoint] | |
if timestamp > cutoff_time | |
] | |
# Count current requests in window | |
current_count = sum(count for _, count in self.requests[client_ip][endpoint]) | |
# Check if limit exceeded | |
if current_count >= max_requests: | |
return False, { | |
"allowed": False, | |
"current_count": current_count, | |
"limit": max_requests, | |
"window_seconds": window_seconds, | |
"reset_time": max(timestamp for timestamp, _ in self.requests[client_ip][endpoint]) + window_seconds | |
} | |
# Allow request and record it | |
self.requests[client_ip][endpoint].append((current_time, 1)) | |
return True, { | |
"allowed": True, | |
"current_count": current_count + 1, | |
"limit": max_requests, | |
"window_seconds": window_seconds, | |
"remaining": max_requests - current_count - 1 | |
} | |
# Global rate limiter instance | |
rate_limiter = RateLimiter() | |
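# Quick illustration of the sliding-window accounting above (hypothetical IP and values,
# not executed at import time):
#
#     allowed, info = rate_limiter.is_allowed("203.0.113.7", "/data/analyze")
#     # First call in a 60s window -> allowed=True, info["remaining"] == 19.
#     # The 21st call inside the same window -> allowed=False; info["reset_time"] is the
#     # newest recorded request's timestamp plus the 60-second window.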
async def check_rate_limit(request: Request):
"""
Dependency to check rate limits for endpoints.
The limit bucket is derived from the request path, so clients cannot switch
themselves to a looser limit via a query parameter.
"""
endpoint = request.url.path
# Get client IP (handle proxies)
client_ip = request.headers.get("x-forwarded-for", "").split(",")[0].strip() | |
if not client_ip: | |
client_ip = request.headers.get("x-real-ip", "") | |
if not client_ip: | |
client_ip = getattr(request.client, "host", "unknown") | |
is_allowed, rate_info = rate_limiter.is_allowed(client_ip, endpoint) | |
if not is_allowed: | |
reset_time = int(rate_info["reset_time"]) | |
logger = logging.getLogger(__name__) | |
logger.warning(f"rate_limit_exceeded client_ip={client_ip} endpoint={endpoint} count={rate_info['current_count']} limit={rate_info['limit']}") | |
raise HTTPException( | |
status_code=429, | |
detail={ | |
"error": "Rate limit exceeded", | |
"limit": rate_info["limit"], | |
"window_seconds": rate_info["window_seconds"], | |
"reset_time": reset_time, | |
"current_count": rate_info["current_count"] | |
}, | |
headers={ | |
"X-RateLimit-Limit": str(rate_info["limit"]), | |
"X-RateLimit-Remaining": "0", | |
"X-RateLimit-Reset": str(reset_time), | |
"Retry-After": str(int(rate_info["reset_time"] - time.time())) | |
} | |
) | |
return rate_info | |
async def add_security_headers(response: Response, rate_info: Optional[dict] = None):
""" | |
Add security headers to response. | |
""" | |
response.headers["X-Content-Type-Options"] = "nosniff" | |
response.headers["X-Frame-Options"] = "DENY" | |
response.headers["X-XSS-Protection"] = "1; mode=block" | |
if rate_info: | |
response.headers["X-RateLimit-Limit"] = str(rate_info["limit"]) | |
response.headers["X-RateLimit-Remaining"] = str(rate_info.get("remaining", 0)) | |
return response | |
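# Sketch of how an endpoint could combine the two helpers above (illustrative only; the
# path and handler name are assumptions, not routes defined in this module):
#
#     @app.post("/data/analyze")
#     async def analyze(response: Response, rate_info: dict = Depends(check_rate_limit)):
#         await add_security_headers(response, rate_info)
#         ...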
# --- CONFIGURATION --- | |
class Config: | |
def __init__(self): | |
load_dotenv(find_dotenv()) | |
self.config = self._load_yaml_config() | |
self._setup_logging() | |
def _load_yaml_config(self, config_path='config.yaml'): | |
try: | |
with open(config_path, 'r') as f: | |
return yaml.safe_load(f) | |
except FileNotFoundError: | |
logger = logging.getLogger(__name__) | |
logger.warning(f"config_file_not_found path={config_path} using_defaults=true") | |
return self._get_default_config() | |
def _get_default_config(self): | |
return { | |
'logging': {'level': 'INFO', 'log_file': 'logs/api.log'}, | |
'data_sources': { | |
'sp500': { | |
'url': 'https://en.wikipedia.org/wiki/List_of_S%26P_500_companies', | |
'ticker_column': 'Symbol', | |
'name_column': 'Security' | |
}, | |
'nasdaq100': { | |
'url': 'https://en.wikipedia.org/wiki/Nasdaq-100', | |
'ticker_column': 'Ticker', | |
'name_column': 'Company' | |
} | |
}, | |
'database': {'pool_size': 5, 'max_overflow': 10} | |
} | |
def _setup_logging(self): | |
log_config = self.config.get('logging', {}) | |
log_file = log_config.get('log_file', 'logs/api.log') | |
# Detect whether we are running on Hugging Face Spaces via the SPACE_ID env var
is_hf_spaces = os.getenv("SPACE_ID") is not None | |
handlers = [logging.StreamHandler()] | |
if not is_hf_spaces: | |
os.makedirs(os.path.dirname(log_file), exist_ok=True) | |
handlers.insert(0, logging.FileHandler(log_file)) | |
logging.basicConfig( | |
level=getattr(logging, log_config.get('level', 'INFO').upper()), | |
format='%(asctime)s [%(levelname)s] %(name)s: %(message)s', | |
handlers=handlers, | |
datefmt='%Y-%m-%d %H:%M:%S' | |
) | |
self.logger = logging.getLogger(__name__) | |
self.logger.info(f"logging_configured level={log_config.get('level', 'INFO')} hf_spaces={os.getenv('SPACE_ID') is not None}") | |
@property
def database_url(self) -> str:
user = os.getenv("MYSQL_USER")
password = os.getenv("MYSQL_PASSWORD")
host = os.getenv("MYSQL_HOST")
port = os.getenv("MYSQL_PORT")
db = os.getenv("MYSQL_DB")
if not all([user, password]):
raise ValueError("MySQL credentials not found in environment variables")
return f"mysql+aiomysql://{user}:{password}@{host}:{port}/{db}"
# --- SERVICES --- | |
class TickerService: | |
def __init__(self, config: Config): | |
self.config = config | |
self.logger = logging.getLogger(__name__) | |
async def get_tickers_from_wikipedia( | |
self, url: str, ticker_column: str, name_column: str, | |
sector_column: Optional[str] = None, subindustry_column: Optional[str] = None | |
) -> List[tuple[str, str, Optional[str], Optional[str]]]: | |
"""Async version fetching ticker, name, sector, and subindustry from Wikipedia.""" | |
try: | |
ssl_context = ssl.create_default_context(cafile=certifi.where()) | |
connector = aiohttp.TCPConnector(ssl=ssl_context) | |
async with aiohttp.ClientSession(connector=connector) as session: | |
headers = {'User-Agent': 'Mozilla/5.0 (compatible; MarketDataAPI/1.0)'} | |
async with session.get(url, headers=headers) as response: | |
response.raise_for_status() | |
html_content = await response.text() | |
tables = pd.read_html(StringIO(html_content)) | |
columns_needed = [ticker_column, name_column] | |
if sector_column: | |
columns_needed.append(sector_column) | |
if subindustry_column: | |
columns_needed.append(subindustry_column) | |
df = next((table for table in tables if all(col in table.columns for col in columns_needed)), None) | |
if df is None: | |
self.logger.error(f"wikipedia_parsing_failed url={url} required_columns={columns_needed}") | |
return [] | |
entries = df[columns_needed].dropna(subset=[ticker_column]) | |
self.logger.info(f"wikipedia_data_fetched url={url} rows={len(entries)}") | |
results: List[tuple[str, str, Optional[str], Optional[str]]] = [] | |
for _, row in entries.iterrows(): | |
ticker = str(row[ticker_column]).strip() | |
name = str(row[name_column]).strip() | |
sector = str(row[sector_column]).strip() if sector_column and sector_column in row and pd.notna(row[sector_column]) else None | |
subindustry = str(row[subindustry_column]).strip() if subindustry_column and subindustry_column in row and pd.notna(row[subindustry_column]) else None | |
results.append((ticker, name, sector, subindustry)) | |
return results | |
except Exception as e: | |
self.logger.error(f"wikipedia_fetch_failed url={url} error={str(e)}") | |
return [] | |
async def get_sp500_tickers(self) -> List[tuple[str, str, Optional[str], Optional[str]]]: | |
cfg = self.config.config.get('data_sources', {}).get('sp500', {}) | |
return await self.get_tickers_from_wikipedia( | |
cfg.get('url'), | |
cfg.get('ticker_column'), | |
cfg.get('name_column'), | |
cfg.get('sector_column'), | |
cfg.get('subindustry_column') | |
) | |
async def get_nasdaq100_tickers(self) -> List[tuple[str, str, Optional[str], Optional[str]]]: | |
cfg = self.config.config.get('data_sources', {}).get('nasdaq100', {}) | |
return await self.get_tickers_from_wikipedia( | |
cfg.get('url'), | |
cfg.get('ticker_column'), | |
cfg.get('name_column'), | |
cfg.get('sector_column'), | |
cfg.get('subindustry_column') | |
) | |
async def update_tickers_in_db(self, session: AsyncSession, force_refresh: bool = False) -> Dict[str, Any]: | |
""" | |
Updates the tickers table with the latest data from the Wikipedia sources, skipping the refresh when the existing data is less than 1 day old (unless force_refresh is set).
""" | |
try: | |
# Check if tickers were updated in the last 24h | |
now = datetime.now(pytz.UTC) | |
result = await session.execute(select(Ticker.last_updated).order_by(Ticker.last_updated.desc()).limit(1)) | |
last = result.scalar() | |
if last and not force_refresh: | |
# Ensure 'last' is timezone aware | |
if last.tzinfo is None: | |
last = pytz.UTC.localize(last) | |
delta = now - last | |
if delta.total_seconds() < 86400: | |
self.logger.info(f"tickers_update_skipped last_update={last.isoformat()} reason=fresh_data") | |
from sqlalchemy import func | |
total_tickers = await session.scalar(select(func.count()).select_from(Ticker)) | |
sp500_count = await session.scalar(select(func.count()).select_from(Ticker).where(Ticker.is_sp500 == 1)) | |
nasdaq100_count = await session.scalar(select(func.count()).select_from(Ticker).where(Ticker.is_nasdaq100 == 1)) | |
return { | |
"total_tickers": total_tickers, | |
"sp500_count": sp500_count, | |
"nasdaq100_count": nasdaq100_count, | |
"updated_at": last, | |
"not_updated_reason": "Tickers not updated: last update was less than 1 day ago. Use force_refresh to override." | |
} | |
sp500_list = await self.get_sp500_tickers() | |
nasdaq_list = await self.get_nasdaq100_tickers() | |
combined = sp500_list + nasdaq_list
sp500_symbols = {x[0] for x in sp500_list}
nasdaq_symbols = {x[0] for x in nasdaq_list}
ticker_dict = {}
for t, n, s, sub in combined:
ticker_dict[t] = {
"name": n,
"sector": s,
"subindustry": sub,
"is_sp500": 1 if t in sp500_symbols else 0,
"is_nasdaq100": 1 if t in nasdaq_symbols else 0
}
all_tickers = sorted(ticker_dict.keys()) | |
current_time = now | |
await session.execute(delete(Ticker)) | |
ticker_objects = [ | |
Ticker( | |
ticker=t, | |
name=ticker_dict[t]["name"], | |
sector=ticker_dict[t]["sector"], | |
subindustry=ticker_dict[t]["subindustry"], | |
is_sp500=ticker_dict[t]["is_sp500"], | |
is_nasdaq100=ticker_dict[t]["is_nasdaq100"], | |
last_updated=current_time | |
) | |
for t in all_tickers | |
] | |
session.add_all(ticker_objects) | |
await session.commit() | |
result = { | |
"total_tickers": len(all_tickers), | |
"sp500_count": len(sp500_list), | |
"nasdaq100_count": len(nasdaq_list), | |
"updated_at": current_time | |
} | |
self.logger.info(f"tickers_updated total={result['total_tickers']} sp500={result['sp500_count']} nasdaq100={result['nasdaq100_count']} timestamp={result['updated_at'].isoformat()}") | |
return result | |
except Exception as e: | |
await session.rollback() | |
self.logger.error(f"tickers_update_failed error={str(e)}") | |
raise | |
class YFinanceService: | |
def __init__(self, config: Config): | |
self.config = config | |
self.logger = logging.getLogger(__name__) | |
def get_market_status(self, ticker: str) -> MarketStatus: | |
""" | |
Get current market status using yfinance's most reliable endpoints. | |
Uses multiple methods for accuracy: recent 1-minute history first, then ticker info as a fallback.
""" | |
try: | |
ticker_obj = yf.Ticker(ticker) | |
# Method 1: Try to get current data with 1-minute interval | |
# This is the most reliable way to check if market is currently active | |
current_data = None | |
market_state = 'UNKNOWN' | |
timezone_name = 'America/New_York' | |
try: | |
# Get very recent data to check market activity | |
current_data = ticker_obj.history(period="1d", interval="1m", prepost=True) | |
if not current_data.empty: | |
last_timestamp = current_data.index[-1] | |
now = datetime.now(last_timestamp.tz) | |
time_diff = (now - last_timestamp).total_seconds() | |
# If last data is within 5 minutes, market is likely active | |
if time_diff <= 300: # 5 minutes | |
# Check if it's during regular hours, pre-market, or post-market | |
hour = last_timestamp.hour | |
if 9 <= hour < 16: # Regular hours (9:30 AM - 4:00 PM ET, roughly) | |
market_state = 'REGULAR' | |
elif 4 <= hour < 9: # Pre-market (4:00 AM - 9:30 AM ET) | |
market_state = 'PRE' | |
elif 16 <= hour <= 20: # Post-market (4:00 PM - 8:00 PM ET) | |
market_state = 'POST' | |
else: | |
market_state = 'CLOSED' | |
else: | |
market_state = 'CLOSED' | |
except Exception as hist_error: | |
self.logger.debug(f"history_method_failed ticker={ticker} error={str(hist_error)}") | |
# Method 2: Use ticker.info as backup/validation | |
try: | |
info = ticker_obj.info | |
info_market_state = info.get('marketState', 'UNKNOWN') | |
timezone_name = info.get('exchangeTimezoneName', 'America/New_York') | |
# If history method failed, use info method | |
if market_state == 'UNKNOWN' and info_market_state != 'UNKNOWN': | |
market_state = info_market_state | |
except Exception as info_error: | |
self.logger.debug(f"info_method_failed ticker={ticker} error={str(info_error)}") | |
# Determine if market is open | |
is_open = market_state in ['REGULAR', 'PRE', 'POST'] | |
self.logger.info(f"market_status_determined ticker={ticker} state={market_state} is_open={is_open} timezone={timezone_name}") | |
return MarketStatus( | |
is_open=is_open, | |
market_state=market_state, | |
timezone=timezone_name | |
) | |
except Exception as e: | |
self.logger.warning(f"market_status_check_failed ticker={ticker} error={str(e)}") | |
# Return conservative default status | |
return MarketStatus( | |
is_open=False, | |
market_state='UNKNOWN', | |
timezone='America/New_York' | |
) | |
def calculate_technical_indicators(self, df: pd.DataFrame) -> pd.DataFrame: | |
""" | |
Calculate technical indicators for a ticker's data. | |
Adds SMA columns: sma_fast (10), sma_med (20), sma_slow (50) | |
""" | |
if df.empty or 'Close' not in df.columns: | |
return df | |
start_time = time.perf_counter() | |
records_count = len(df) | |
# Sort by date to ensure proper calculation | |
df = df.sort_index() | |
# Calculate Simple Moving Averages | |
df['sma_fast'] = df['Close'].rolling(window=10, min_periods=10).mean() | |
df['sma_med'] = df['Close'].rolling(window=20, min_periods=20).mean() | |
df['sma_slow'] = df['Close'].rolling(window=50, min_periods=50).mean() | |
end_time = time.perf_counter() | |
duration = end_time - start_time | |
self.logger.info(f"technical_indicators_calculated records={records_count} duration_ms={duration*1000:.2f}") | |
return df | |
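# Minimal sketch of what the rolling means above produce (assumes a DataFrame shaped
# like yfinance output with a 'Close' column; commented out, not executed at import time):
#
#     import pandas as pd
#     closes = pd.DataFrame({"Close": range(1, 61)})  # 60 synthetic closing prices
#     out = YFinanceService(Config()).calculate_technical_indicators(closes)
#     # out['sma_fast'] is NaN for the first 9 rows (min_periods=10), then the 10-day mean;
#     # out['sma_med'] and out['sma_slow'] fill in after 20 and 50 rows respectively.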
async def check_tickers_freshness(self, session: AsyncSession) -> bool: | |
""" | |
Check if tickers were updated within the last week (7 days). | |
Returns True if fresh, False if an update is needed.
""" | |
try: | |
now = datetime.now(pytz.UTC) | |
result = await session.execute( | |
select(Ticker.last_updated).order_by(Ticker.last_updated.desc()).limit(1) | |
) | |
last_update = result.scalar() | |
if not last_update: | |
self.logger.info("ticker_freshness_check result=no_tickers_found") | |
return False | |
# Ensure timezone awareness | |
if last_update.tzinfo is None: | |
last_update = pytz.UTC.localize(last_update) | |
delta = now - last_update | |
is_fresh = delta.total_seconds() < (7 * 24 * 3600) # 7 days | |
self.logger.info(f"ticker_freshness_check last_update={last_update.isoformat()} is_fresh={is_fresh}") | |
return is_fresh | |
except Exception as e: | |
self.logger.error(f"ticker_freshness_check_failed error={str(e)}") | |
return False | |
async def check_ticker_data_freshness(self, session: AsyncSession) -> bool: | |
""" | |
Check if ticker data was updated within the last day (24 hours). | |
Returns True if fresh, False if an update is needed.
""" | |
try: | |
now = datetime.now(pytz.UTC) | |
result = await session.execute( | |
select(TickerData.created_at).order_by(TickerData.created_at.desc()).limit(1) | |
) | |
last_update = result.scalar() | |
if not last_update: | |
self.logger.info("ticker_data_freshness_check result=no_data_found") | |
return False | |
# Ensure timezone awareness | |
if last_update.tzinfo is None: | |
last_update = pytz.UTC.localize(last_update) | |
delta = now - last_update | |
is_fresh = delta.total_seconds() < (24 * 3600) # 24 hours | |
self.logger.info(f"ticker_data_freshness_check last_update={last_update.isoformat()} is_fresh={is_fresh}") | |
return is_fresh | |
except Exception as e: | |
self.logger.error(f"ticker_data_freshness_check_failed error={str(e)}") | |
return False | |
async def clear_and_bulk_insert_ticker_data(self, session: AsyncSession, ticker_list: List[str]) -> Dict[str, Any]: | |
""" | |
Clear all ticker data and insert new data in bulk, chunked for better performance.
Truncates the ticker_data table, then bulk-inserts the fresh rows in chunks of 1000 records.
""" | |
try: | |
# Start timing for total end-to-end process | |
total_start_time = time.perf_counter() | |
self.logger.info(f"bulk_refresh_started tickers_count={len(ticker_list)} operation=clear_and_insert") | |
# Start timing for data download | |
download_start_time = time.perf_counter() | |
# Download data for all tickers at once using period | |
data = yf.download(ticker_list, period='3mo', group_by='ticker', progress=True, auto_adjust=True) | |
download_end_time = time.perf_counter() | |
download_duration = download_end_time - download_start_time | |
self.logger.info(f"data_download_completed tickers_count={len(ticker_list)} duration_ms={download_duration*1000:.2f}") | |
if data.empty: | |
self.logger.warning("data_download_empty reason=no_data_for_any_tickers") | |
return { | |
"created": 0, | |
"updated": 0, | |
"date_range": {"start_date": "", "end_date": ""} | |
} | |
# Start timing for database operations | |
db_start_time = time.perf_counter() | |
# Clear all existing ticker data using TRUNCATE for speed | |
self.logger.info("database_clear_started operation=truncate_ticker_data") | |
clear_start = time.perf_counter() | |
from sqlalchemy import text | |
# Use TRUNCATE for faster clearing and avoid long-running DELETE | |
await session.execute(text("TRUNCATE TABLE ticker_data")) | |
await session.commit() # commit immediately to reset the connection | |
clear_end = time.perf_counter() | |
self.logger.info(f"database_truncate_completed duration_ms={(clear_end - clear_start)*1000:.2f}") | |
# Prepare data for bulk insert | |
current_time = datetime.now(pytz.UTC) | |
all_records = [] | |
# Get actual date range from the data | |
all_dates = data.index.tolist() | |
start_date = min(all_dates).date() if all_dates else datetime.now().date() | |
end_date = max(all_dates).date() if all_dates else datetime.now().date() | |
# Handle both single ticker and multi-ticker cases | |
if len(ticker_list) == 1: | |
# Single ticker case - data is not grouped | |
ticker = ticker_list[0] | |
# Calculate technical indicators | |
data_with_indicators = self.calculate_technical_indicators(data) | |
for date_idx, row in data_with_indicators.iterrows(): | |
if pd.isna(row['Close']): | |
continue | |
trade_date = date_idx.date() | |
record = { | |
'ticker': ticker, | |
'date': trade_date, | |
'open': float(row['Open']), | |
'high': float(row['High']), | |
'low': float(row['Low']), | |
'close': float(row['Close']), | |
'volume': int(row['Volume']), | |
'sma_fast': float(row['sma_fast']) if pd.notna(row['sma_fast']) else None, | |
'sma_med': float(row['sma_med']) if pd.notna(row['sma_med']) else None, | |
'sma_slow': float(row['sma_slow']) if pd.notna(row['sma_slow']) else None, | |
'created_at': current_time | |
} | |
all_records.append(record) | |
else: | |
# Multiple tickers case - data is grouped by ticker | |
for ticker in ticker_list: | |
if ticker not in data.columns.get_level_values(0): | |
self.logger.warning(f"ticker_data_missing ticker={ticker} reason=not_in_downloaded_data") | |
continue | |
ticker_data = data[ticker] | |
if ticker_data.empty: | |
continue | |
# Calculate technical indicators for this ticker | |
ticker_data_with_indicators = self.calculate_technical_indicators(ticker_data) | |
for date_idx, row in ticker_data_with_indicators.iterrows(): | |
if pd.isna(row['Close']): | |
continue | |
trade_date = date_idx.date() | |
record = { | |
'ticker': ticker, | |
'date': trade_date, | |
'open': float(row['Open']), | |
'high': float(row['High']), | |
'low': float(row['Low']), | |
'close': float(row['Close']), | |
'volume': int(row['Volume']), | |
'sma_fast': float(row['sma_fast']) if pd.notna(row['sma_fast']) else None, | |
'sma_med': float(row['sma_med']) if pd.notna(row['sma_med']) else None, | |
'sma_slow': float(row['sma_slow']) if pd.notna(row['sma_slow']) else None, | |
'created_at': current_time | |
} | |
all_records.append(record) | |
# Bulk insert in chunks of 1000 (optimized for MySQL performance) | |
chunk_size = 1000 | |
total_records = len(all_records) | |
inserted_count = 0 | |
self.logger.info(f"database_insert_started total_records={total_records} chunk_size={chunk_size}") | |
for i in range(0, total_records, chunk_size): | |
chunk = all_records[i:i + chunk_size] | |
chunk_start = time.perf_counter() | |
# Create TickerData objects for bulk insert | |
ticker_objects = [TickerData(**record) for record in chunk] | |
session.add_all(ticker_objects) | |
chunk_end = time.perf_counter() | |
inserted_count += len(chunk) | |
self.logger.info(f"database_chunk_inserted chunk={i//chunk_size + 1}/{(total_records + chunk_size - 1)//chunk_size} records={len(chunk)} duration_ms={(chunk_end - chunk_start)*1000:.2f}") | |
# Commit all changes | |
commit_start = time.perf_counter() | |
await session.commit() | |
commit_end = time.perf_counter() | |
self.logger.info(f"database_commit_completed duration_ms={(commit_end - commit_start)*1000:.2f}") | |
db_end_time = time.perf_counter() | |
db_duration = db_end_time - db_start_time | |
self.logger.info(f"database_operations_completed records_inserted={inserted_count} duration_ms={db_duration*1000:.2f}") | |
# Calculate total end-to-end duration | |
total_end_time = time.perf_counter() | |
total_duration = total_end_time - total_start_time | |
self.logger.info(f"bulk_refresh_completed total_duration_ms={total_duration*1000:.2f} download_ms={download_duration*1000:.2f} database_ms={db_duration*1000:.2f}") | |
self.logger.info(f"bulk_refresh_summary records_inserted={inserted_count} operation=completed") | |
return { | |
"created": inserted_count, | |
"updated": 0, | |
"date_range": { | |
"start_date": start_date.isoformat(), | |
"end_date": end_date.isoformat() | |
} | |
} | |
except Exception as e: | |
await session.rollback() | |
self.logger.error(f"bulk_refresh_failed error={str(e)}") | |
raise | |
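# Shape note for the branching above (based on typical yfinance behaviour; worth
# re-verifying against the installed version): with multiple tickers and
# group_by='ticker', yf.download returns MultiIndex columns, e.g.
#
#     data["AAPL"]["Close"]   # per-ticker OHLCV sub-frame, as iterated in the loop above
#
# while a single ticker yields flat columns ('Open', 'High', ...), which is why the code
# splits on len(ticker_list) == 1.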
async def download_all_tickers_data(self, session: AsyncSession, ticker_list: Optional[List[str]] = None, force_refresh: bool = False, force_indicators: bool = False) -> Dict[str, Any]: | |
""" | |
Download data for all or specified tickers for the last 3 months. | |
Uses a smart strategy: checks data freshness and, if the stored data is older than 24h, clears the table and bulk-inserts new data.
Calculates technical indicators (SMA 10, 20, 50) for all data. | |
""" | |
try: | |
# Check ticker freshness and update if needed | |
if not await self.check_tickers_freshness(session): | |
self.logger.info("tickers_update_required reason=stale_data") | |
ticker_service = TickerService(self.config) | |
await ticker_service.update_tickers_in_db(session, force_refresh=True) | |
# Get tickers to process | |
if ticker_list: | |
# Validate provided tickers exist in database | |
result = await session.execute( | |
select(Ticker.ticker).where(Ticker.ticker.in_(ticker_list)) | |
) | |
valid_tickers = [row[0] for row in result.fetchall()] | |
invalid_tickers = set(ticker_list) - set(valid_tickers) | |
if invalid_tickers: | |
self.logger.warning(f"invalid_tickers_ignored count={len(invalid_tickers)} tickers={list(invalid_tickers)}") | |
tickers_to_process = valid_tickers | |
else: | |
# Get all tickers from database | |
result = await session.execute(select(Ticker.ticker)) | |
tickers_to_process = [row[0] for row in result.fetchall()] | |
if not tickers_to_process: | |
return { | |
"tickers_processed": 0, | |
"records_created": 0, | |
"records_updated": 0, | |
"date_range": {"start_date": "", "end_date": ""}, | |
"message": "No valid tickers found to process" | |
} | |
# Check if ticker data is fresh (less than 24h old) unless force_refresh | |
if not force_refresh and await self.check_ticker_data_freshness(session): | |
if not force_indicators: | |
self.logger.info("data_download_skipped reason=fresh_data age_limit=24h") | |
return { | |
"tickers_processed": len(tickers_to_process), | |
"records_created": 0, | |
"records_updated": 0, | |
"date_range": {"start_date": "", "end_date": ""}, | |
"message": f"Data is fresh, no update needed for {len(tickers_to_process)} tickers" | |
} | |
else: | |
# Data is fresh but force_indicators is True - only recalculate indicators | |
self.logger.info("indicators_recalculation_requested reason=force_indicators_flag data_age=fresh") | |
# TODO: Implement indicators-only recalculation | |
return { | |
"tickers_processed": len(tickers_to_process), | |
"records_created": 0, | |
"records_updated": 0, | |
"date_range": {"start_date": "", "end_date": ""}, | |
"message": f"Indicators recalculation for {len(tickers_to_process)} tickers (not implemented yet)" | |
} | |
# Data is stale - use bulk refresh strategy | |
self.logger.info("bulk_refresh_strategy_selected reason=stale_data age_limit=24h") | |
result = await self.clear_and_bulk_insert_ticker_data(session, tickers_to_process) | |
total_created = result["created"] | |
total_updated = result["updated"] | |
successful_tickers = len(tickers_to_process) | |
return { | |
"tickers_processed": successful_tickers, | |
"records_created": total_created, | |
"records_updated": total_updated, | |
"date_range": result["date_range"], | |
"message": f"Successfully processed {successful_tickers} tickers using bulk refresh" | |
} | |
except Exception as e: | |
self.logger.error(f"download_all_tickers_failed error={str(e)}") | |
raise | |
# --- DATABASE --- | |
class Database: | |
def __init__(self, config: Config): | |
self.config = config | |
# Filter out pool params not supported by NullPool | |
db_opts = self.config.config.get('database', {}).copy() | |
db_opts.pop('pool_size', None) | |
db_opts.pop('max_overflow', None) | |
self.engine = create_async_engine( | |
config.database_url, | |
pool_pre_ping=True, | |
poolclass=NullPool, | |
**db_opts | |
) | |
self.async_session = async_sessionmaker( | |
self.engine, | |
class_=AsyncSession, | |
expire_on_commit=False | |
) | |
async def create_tables(self): | |
async with self.engine.begin() as conn: | |
await conn.run_sync(Base.metadata.create_all) | |
# --- TASK MANAGER --- | |
from sqlalchemy import JSON as SQLAlchemyJSON | |
class Task(Base): | |
__tablename__ = "tasks" | |
task_id: Mapped[str] = mapped_column(String(64), primary_key=True) | |
status: Mapped[str] = mapped_column(String(32), nullable=False) | |
message: Mapped[Optional[str]] = mapped_column(String(255), nullable=True) | |
result: Mapped[Optional[dict]] = mapped_column(SQLAlchemyJSON, nullable=True) | |
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False) | |
class TaskManager: | |
def __init__(self, database: Database): | |
self.database = database | |
async def create_table_if_not_exists(self): | |
async with self.database.engine.begin() as conn: | |
await conn.run_sync(Base.metadata.create_all) | |
async def create_task(self, task_id: str) -> TaskStatus: | |
async with self.database.async_session() as session: | |
now = datetime.utcnow() | |
db_task = Task( | |
task_id=task_id, | |
status="pending", | |
message=None, | |
result=None, | |
created_at=now | |
) | |
session.add(db_task) | |
await session.commit() | |
return TaskStatus( | |
task_id=task_id, | |
status="pending", | |
message=None, | |
result=None, | |
created_at=now | |
) | |
async def update_task(self, task_id: str, status: str, message: Optional[str] = None, result: Optional[Dict[str, Any]] = None):
def serialize_datetimes(obj): | |
if isinstance(obj, dict): | |
return {k: serialize_datetimes(v) for k, v in obj.items()} | |
elif isinstance(obj, list): | |
return [serialize_datetimes(v) for v in obj] | |
elif isinstance(obj, datetime): | |
return obj.isoformat() | |
else: | |
return obj | |
async with self.database.async_session() as session: | |
db_task = await session.get(Task, task_id) | |
if db_task: | |
db_task.status = status | |
db_task.message = message | |
db_task.result = serialize_datetimes(result) if result is not None else None | |
await session.commit() | |
async def get_task(self, task_id: str) -> Optional[TaskStatus]: | |
async with self.database.async_session() as session: | |
db_task = await session.get(Task, task_id) | |
if db_task: | |
return TaskStatus( | |
task_id=db_task.task_id, | |
status=db_task.status, | |
message=db_task.message, | |
result=db_task.result, | |
created_at=db_task.created_at | |
) | |
return None | |
async def list_tasks(self) -> list[TaskStatus]: | |
async with self.database.async_session() as session: | |
result = await session.execute(select(Task)) | |
tasks = result.scalars().all() | |
return [ | |
TaskStatus( | |
task_id=t.task_id, | |
status=t.status, | |
message=t.message, | |
result=t.result, | |
created_at=t.created_at | |
) for t in tasks | |
] | |
async def delete_old_tasks(self, older_than_seconds: int = 3600) -> int: | |
cutoff = datetime.utcnow() - timedelta(seconds=older_than_seconds) | |
async with self.database.async_session() as session: | |
result = await session.execute(select(Task).where(Task.created_at < cutoff)) | |
old_tasks = result.scalars().all() | |
count = len(old_tasks) | |
for t in old_tasks: | |
await session.delete(t) | |
await session.commit() | |
return count | |
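# Typical task lifecycle with the manager above (sketch; awaited from async code, with
# task_id generation mirroring the async update endpoint further down):
#
#     task_id = str(uuid.uuid4())
#     await task_manager.create_task(task_id)                # status: "pending"
#     await task_manager.update_task(task_id, "running", "Updating tickers...")
#     await task_manager.update_task(task_id, "completed", "Update successful", {"total_tickers": 517})
#     status = await task_manager.get_task(task_id)          # -> TaskStatus or None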
# --- APP SETUP --- | |
# Global instances | |
config = Config() | |
database = Database(config) | |
ticker_service = TickerService(config) | |
yfinance_service = YFinanceService(config) | |
task_manager = TaskManager(database) | |
# Dependency function | |
async def get_db_session() -> AsyncGenerator[AsyncSession, None]: | |
async with database.async_session() as session: | |
yield session | |
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup | |
await database.create_tables() | |
await task_manager.create_table_if_not_exists() | |
logger = logging.getLogger(__name__) | |
logger.info("database_lifecycle event=tables_created_verified") | |
yield | |
# Shutdown | |
await database.engine.dispose() | |
logger.info("database_lifecycle event=connections_closed") | |
# -------------------------- | |
# --- Create FastAPI app --- | |
app = FastAPI( | |
title="Stock Monitoring API", | |
description="API for managing S&P 500 and Nasdaq 100 ticker data", | |
version="0.2.0", | |
lifespan=lifespan, | |
) | |
# --- API ENDPOINTS --- | |
async def root_info(): | |
""" | |
Get API health status, current timestamp, versions, and DB/tables check. | |
**Logic**: | |
- Returns a JSON object with: | |
- **status**: Health status of the API | |
- **timestamp**: Current time in UTC timezone | |
- **versions**: Dictionary with Python and main library versions | |
- **database**: Connection status and existence of 'tickers' and 'tasks' tables | |
**Args**: None | |
**Example response:** | |
```json | |
{ | |
"status": "healthy", | |
"timestamp": "2025-07-19T19:38:26+02:00", | |
"versions": { ... }, | |
"database": { | |
"connected": true, | |
"tickers_table": true, | |
"tasks_table": true | |
} | |
} | |
``` | |
""" | |
now_utc = datetime.now(pytz.UTC) | |
versions = {} | |
versions["python"] = platform.python_version() | |
packages = ["uvicorn", "fastapi", "sqlalchemy", "pandas"] | |
for pkg in packages: | |
try: | |
versions[pkg] = importlib.metadata.version(pkg) | |
except Exception: | |
versions[pkg] = None | |
db_status = { | |
"connected": False, | |
"tickers_table": False, | |
"tasks_table": False | |
} | |
db_check_time = None | |
start = time.perf_counter() | |
try: | |
async with database.engine.connect() as conn: | |
db_status["connected"] = True | |
from sqlalchemy import inspect
table_names = await conn.run_sync(lambda sync_conn: inspect(sync_conn).get_table_names())
db_status["tickers_table"] = "tickers" in table_names
db_status["tasks_table"] = "tasks" in table_names
except Exception:
db_status["connected"] = False | |
finally: | |
db_check_time = time.perf_counter() - start | |
return { | |
"status": "healthy" if db_status["connected"] and db_status["tickers_table"] and db_status["tasks_table"] else "degraded", | |
"timestamp": now_utc.isoformat(), | |
"versions": versions, | |
"database": db_status, | |
"db_check_seconds": round(db_check_time, 4) if db_check_time is not None else None | |
} | |
async def get_tickers( | |
is_sp500: Optional[bool] = None, | |
is_nasdaq: Optional[bool] = None, | |
limit: int = 1000, | |
session: AsyncSession = Depends(get_db_session) | |
): | |
""" | |
Get all tickers from database with optional filtering. | |
**Logic**: | |
- No parameters: Return all tickers | |
- is_sp500=true: Only S&P 500 tickers | |
- is_sp500=false: Only NON-S&P 500 tickers | |
- is_nasdaq=true: Only Nasdaq 100 tickers | |
- is_nasdaq=false: Only NON-Nasdaq 100 tickers | |
- Both parameters: Apply AND logic (intersection of conditions) | |
**Args (all optional)**: | |
- **is_sp500** (optional): Filter for S&P 500 membership (true/false/None) | |
- **is_nasdaq** (optional): Filter for Nasdaq 100 membership (true/false/None) | |
- **limit** (optional): Maximum number of results to return | |
**Examples:** | |
- `GET /tickers` - All tickers | |
- `GET /tickers?is_sp500=true` - Only S&P 500 | |
- `GET /tickers?is_nasdaq=true&is_sp500=false` - Only Nasdaq 100 but not S&P 500 | |
- `GET /tickers?is_sp500=true&is_nasdaq=false` - S&P 500 but not Nasdaq 100 | |
""" | |
try: | |
query = select(Ticker) | |
# Build conditions based on explicit flag values | |
conditions = [] | |
if is_sp500 is not None: | |
if is_sp500: | |
conditions.append(Ticker.is_sp500 == 1) | |
else: | |
conditions.append(Ticker.is_sp500 == 0) | |
if is_nasdaq is not None: | |
if is_nasdaq: | |
conditions.append(Ticker.is_nasdaq100 == 1) | |
else: | |
conditions.append(Ticker.is_nasdaq100 == 0) | |
# Apply filtering if we have conditions | |
if conditions: | |
from sqlalchemy import and_ | |
query = query.where(and_(*conditions)) | |
query = query.order_by(Ticker.ticker).limit(limit)
result = await session.execute(query) | |
tickers = result.scalars().all() | |
return [ | |
TickerResponse( | |
ticker=t.ticker, | |
name=t.name, | |
sector=t.sector, | |
subindustry=t.subindustry, | |
is_sp500=bool(t.is_sp500), | |
is_nasdaq100=bool(t.is_nasdaq100), | |
last_updated=t.last_updated | |
) | |
for t in tickers | |
] | |
except Exception as e: | |
logger = logging.getLogger(__name__) | |
logger.error(f"endpoint_error endpoint=get_tickers error={str(e)}") | |
raise HTTPException(status_code=500, detail="Failed to fetch tickers") | |
async def update_tickers( | |
request: UpdateTickersRequest, | |
background_tasks: BackgroundTasks, | |
session: AsyncSession = Depends(get_db_session), | |
api_key: str = Depends(verify_api_key) | |
): | |
""" | |
Update tickers from Wikipedia sources (S&P 500 and Nasdaq 100). | |
**Logic**: | |
- Fetches latest tickers from Wikipedia (S&P 500 and Nasdaq 100). | |
- Updates the database with the new tickers. | |
- Returns summary of update (counts, timestamp). | |
**Args**: | |
- **request**: UpdateTickersRequest (force_refresh: bool) | |
- **background_tasks**: FastAPI BackgroundTasks (unused) | |
- **session**: AsyncSession (DB session, injected) | |
**Example request:** | |
```json | |
{ "force_refresh": false } | |
{ "force_refresh": true } | |
``` | |
**Example response:** | |
```json | |
{ | |
"success": true, | |
"message": "Tickers updated successfully", | |
"total_tickers": 517, | |
"sp500_count": 500, | |
"nasdaq100_count": 100, | |
"updated_at": "2025-07-19T19:38:26+02:00" | |
} | |
``` | |
""" | |
try: | |
result = await ticker_service.update_tickers_in_db(session, force_refresh=request.force_refresh) | |
message = result.pop("not_updated_reason", None) | |
if message: | |
return UpdateTickersResponse( | |
success=True, | |
message=message, | |
**result | |
) | |
return UpdateTickersResponse( | |
success=True, | |
message="Tickers updated successfully", | |
**result | |
) | |
except Exception as e: | |
logger = logging.getLogger(__name__) | |
logger.error(f"endpoint_error endpoint=update_tickers error={str(e)}") | |
raise HTTPException(status_code=500, detail=f"Failed to update tickers: {str(e)}") | |
async def update_tickers_async( | |
request: UpdateTickersRequest, | |
background_tasks: BackgroundTasks, | |
api_key: str = Depends(verify_api_key) | |
): | |
""" | |
Start async ticker update task (background). | |
**Logic**: | |
- Launches a background task to update tickers from Wikipedia. | |
- Returns a task_id and status for tracking. | |
**Args**: | |
- **request**: UpdateTickersRequest (force_refresh: bool) | |
**Example request:** | |
```json | |
{ "force_refresh": false } | |
{ "force_refresh": true } | |
``` | |
**Example response:** | |
```json | |
{ | |
"task_id": "c1a2b3d4-5678-90ab-cdef-1234567890ab", | |
"status": "started" | |
} | |
``` | |
""" | |
import uuid | |
task_id = str(uuid.uuid4()) | |
await task_manager.create_task(task_id) | |
async def update_task(): | |
try: | |
await task_manager.update_task(task_id, "running", "Updating tickers...") | |
async with database.async_session() as session: | |
result = await ticker_service.update_tickers_in_db(session, force_refresh=request.force_refresh) | |
message = result.pop("not_updated_reason", None) | |
if message: | |
await task_manager.update_task(task_id, "completed", message, result) | |
else: | |
await task_manager.update_task(task_id, "completed", "Update successful", result) | |
except Exception as e: | |
await task_manager.update_task(task_id, "failed", str(e)) | |
background_tasks.add_task(update_task) | |
return {"task_id": task_id, "status": "started"} | |
async def list_all_tasks(api_key: str = Depends(verify_api_key)): | |
""" | |
List all background tasks and their status. | |
**Logic**: | |
- Returns a list of all tasks created via async update endpoint, with their status and result. | |
**Args**: None | |
**Example response:** | |
```json | |
[ | |
{ | |
"task_id": "c1a2b3d4-5678-90ab-cdef-1234567890ab", | |
"status": "completed", | |
"message": "Tickers updated successfully", | |
"result": { | |
"total_tickers": 517, | |
"sp500_count": 500, | |
"nasdaq100_count": 100, | |
"updated_at": "2025-07-19T19:38:26+02:00" | |
}, | |
"created_at": "2025-07-19T19:38:26+02:00" | |
}, | |
... | |
] | |
``` | |
""" | |
return await task_manager.list_tasks() | |
async def get_task_status(task_id: str, api_key: str = Depends(verify_api_key)): | |
""" | |
Get status and result of a background update task by task_id. | |
**Logic**: | |
- Returns the status and result of a background update task by task_id. | |
- If not found, returns 404. | |
**Args**: | |
- **task_id**: str (UUID of the task) | |
**Example response:** | |
```json | |
{ | |
"task_id": "c1a2b3d4-5678-90ab-cdef-1234567890ab", | |
"status": "completed", | |
"message": "Tickers updated successfully", | |
"result": { | |
"total_tickers": 517, | |
"sp500_count": 500, | |
"nasdaq100_count": 100, | |
"updated_at": "2025-07-19T19:38:26+02:00" | |
}, | |
"created_at": "2025-07-19T19:38:26+02:00" | |
} | |
``` | |
""" | |
task = await task_manager.get_task(task_id) | |
if not task: | |
raise HTTPException(status_code=404, detail="Task not found") | |
return task | |
# Endpoint to delete tasks older than 1 hour | |
async def delete_old_tasks(api_key: str = Depends(verify_api_key)): | |
""" | |
Delete tasks older than 1 hour (3600 seconds). | |
**Logic**: | |
- Deletes all tasks in the database older than 1 hour. | |
- Returns the number of deleted tasks. | |
**Args**: None | |
**Example response:** | |
```json | |
{ "deleted": 5 } | |
``` | |
""" | |
deleted_count = await task_manager.delete_old_tasks(older_than_seconds=3600) | |
return {"deleted": deleted_count} | |
async def download_all_tickers_data( | |
request: DownloadDataRequest = DownloadDataRequest(), | |
session: AsyncSession = Depends(get_db_session), | |
api_key: str = Depends(verify_api_key) | |
): | |
""" | |
Download daily ticker data for the last 3 months for ALL tickers in database. | |
**Logic**: | |
- Automatically downloads data for all tickers stored in the tickers table | |
- Checks if tickers were updated within the last week, updates if needed | |
- Only downloads if ticker data is older than 24 hours (unless force_refresh=true) | |
- Downloads daily data for the last 3 months for all available tickers | |
- Calculates technical indicators: SMA 10 (fast), SMA 20 (med), SMA 50 (slow) | |
- Uses bulk delete and insert strategy for optimal performance | |
- Returns summary with counts and date range | |
**Args**: | |
- **request**: DownloadDataRequest (force_refresh, force_indicators flags) | |
- **session**: AsyncSession (DB session, injected) | |
- **api_key**: str (API key for authentication, injected) | |
**Example request:** | |
```bash | |
curl -X POST "http://localhost:${PORT}/data/download-all" \ | |
-H "Authorization: Bearer your_api_key" \ | |
-H "Content-Type: application/json" \ | |
-d '{"force_refresh": false, "force_indicators": true}' | |
``` | |
**Example response:** | |
```json | |
{ | |
"success": true, | |
"message": "Successfully processed 503 tickers using bulk refresh", | |
"tickers_processed": 503, | |
"records_created": 12075, | |
"records_updated": 0, | |
"date_range": { | |
"start_date": "2025-06-30", | |
"end_date": "2025-07-30" | |
}, | |
"updated_at": "2025-07-30T14:15:26+00:00" | |
} | |
``` | |
""" | |
try: | |
# Use existing service without specifying ticker list (downloads all) | |
result = await yfinance_service.download_all_tickers_data( | |
session, | |
ticker_list=request.tickers, # None means download all tickers | |
force_refresh=request.force_refresh, | |
force_indicators=request.force_indicators | |
) | |
return DownloadDataResponse( | |
success=True, | |
message=result["message"], | |
tickers_processed=result["tickers_processed"], | |
records_created=result["records_created"], | |
records_updated=result["records_updated"], | |
date_range=result["date_range"], | |
updated_at=datetime.now(pytz.UTC) | |
) | |
except Exception as e: | |
logger = logging.getLogger(__name__) | |
logger.error(f"endpoint_error endpoint=download_all_tickers error={str(e)}") | |
raise HTTPException(status_code=500, detail=f"Failed to download all ticker data: {str(e)}") | |
async def get_ticker_data( | |
ticker: str, | |
days: int = 30, | |
session: AsyncSession = Depends(get_db_session) | |
): | |
""" | |
Get historical data for a specific ticker. | |
**Logic**: | |
- Returns historical data for the specified ticker | |
- Defaults to last 30 days if no days parameter provided | |
- Data is ordered by date descending (most recent first) | |
**Args**: | |
- **ticker**: str (Ticker symbol, e.g., "AAPL") | |
- **days**: int (Number of days to retrieve, default 30) | |
- **session**: AsyncSession (DB session, injected) | |
**Example response:** | |
```json | |
[ | |
{ | |
"ticker": "AAPL", | |
"date": "2025-07-30", | |
"open": 150.25, | |
"high": 152.80, | |
"low": 149.50, | |
"close": 151.75, | |
"volume": 45123000, | |
"created_at": "2025-07-30T13:45:26+00:00" | |
} | |
] | |
``` | |
""" | |
try: | |
# Calculate date range | |
end_date = datetime_date.today() | |
start_date = end_date - timedelta(days=days) | |
# Query ticker data | |
query = select(TickerData).where( | |
TickerData.ticker == ticker.upper(), | |
TickerData.date >= start_date, | |
TickerData.date <= end_date | |
).order_by(TickerData.date.desc()) | |
result = await session.execute(query) | |
ticker_data = result.scalars().all() | |
if not ticker_data: | |
raise HTTPException( | |
status_code=404, | |
detail=f"No data found for ticker {ticker.upper()} in the last {days} days" | |
) | |
return [ | |
TickerDataResponse( | |
ticker=data.ticker, | |
date=data.date, | |
open=data.open, | |
high=data.high, | |
low=data.low, | |
close=data.close, | |
volume=data.volume, | |
sma_fast=data.sma_fast, | |
sma_med=data.sma_med, | |
sma_slow=data.sma_slow, | |
created_at=data.created_at | |
) | |
for data in ticker_data | |
] | |
except HTTPException: | |
raise | |
except Exception as e: | |
logger = logging.getLogger(__name__) | |
logger.error(f"endpoint_error endpoint=get_ticker_data ticker={ticker} error={str(e)}") | |
raise HTTPException(status_code=500, detail="Failed to fetch ticker data") | |
async def analyze_financial_data( | |
request: FinancialDataRequest, | |
http_request: Request, | |
api_key: str = Depends(verify_api_key), | |
rate_info: dict = Depends(check_rate_limit) | |
): | |
""" | |
Download financial data for multiple tickers and calculate technical indicators without database storage. | |
**Security Features**: | |
- **API Key Required**: Must provide valid API key in Authorization header | |
- **Rate Limited**: Maximum 20 requests per minute per IP address | |
- **Input Validation**: Comprehensive validation of ticker symbols and parameters | |
- **Request Logging**: All requests are logged with IP address and timing | |
**Logic**: | |
- Downloads real-time data from Yahoo Finance for the specified tickers and period | |
- Optimized for multiple tickers by downloading them in a single batch request | |
- Calculates technical indicators: SMA 10 (fast), SMA 20 (med), SMA 50 (slow) | |
- Returns the data with technical indicators without storing in database | |
- Useful for real-time analysis and testing without persisting data | |
**Authentication**: | |
- **Header**: `Authorization: Bearer <your_api_key>` | |
**Rate Limits**: | |
- **Limit**: 20 requests per minute per IP address | |
- **Headers**: Response includes rate limit headers (X-RateLimit-*) | |
**Args**: | |
- **request**: FinancialDataRequest (list of ticker symbols and period) | |
- **http_request**: Request object (auto-injected for IP tracking) | |
- **api_key**: API key for authentication (auto-injected) | |
- **rate_info**: Rate limiting info (auto-injected) | |
**Supported periods**: | |
- 1d, 5d, 1mo, 3mo, 6mo, 1y, 2y, 5y, 10y, ytd, max | |
**Supported intervals**: | |
- 1m, 2m, 5m, 15m, 30m, 60m, 90m, 1h, 4h, 1d, 5d, 1wk, 1mo, 3mo | |
**Intraday Features**: | |
- **Pre/Post Market**: Include extended hours data when `intraday=true` | |
- **Market Status**: Real-time market status checking | |
- **High Frequency**: Support for 5m to 4h intervals | |
- **Restrictions**: Intraday limited to 1d/5d/1mo periods with 5m-4h intervals | |
**Example requests:** | |
```bash | |
# Daily data (default) | |
curl -X POST "http://localhost:7860/data/analyze" \\ | |
-H "Authorization: Bearer your_api_key" \\ | |
-H "Content-Type: application/json" \\ | |
-d '{"tickers": ["AAPL", "MSFT"], "period": "3mo"}' | |
# Intraday 15-minute data with pre/post market | |
curl -X POST "http://localhost:7860/data/analyze" \\ | |
-H "Authorization: Bearer your_api_key" \\ | |
-H "Content-Type: application/json" \\ | |
-d '{"tickers": ["AAPL"], "period": "1d", "interval": "15m", "intraday": true}' | |
# Hourly data for current week | |
curl -X POST "http://localhost:7860/data/analyze" \\ | |
-H "Authorization: Bearer your_api_key" \\ | |
-H "Content-Type: application/json" \\ | |
-d '{"tickers": ["TSLA", "NVDA"], "period": "5d", "interval": "1h", "intraday": true}' | |
``` | |
**Example request body**: | |
```json | |
{ | |
"tickers": ["TSLA", "NVDA"], | |
"period": "5d", | |
"interval": "1h", | |
"intraday": true | |
} | |
``` | |
**Example response:** | |
```json | |
{ | |
"success": true, | |
"tickers": ["AAPL", "MSFT", "GOOGL"], | |
"period": "3mo", | |
"total_data_points": 195, | |
"date_range": { | |
"start_date": "2025-04-30", | |
"end_date": "2025-07-31" | |
}, | |
"data": [ | |
{ | |
"ticker": "AAPL", | |
"date": "2025-07-31", | |
"open": 150.25, | |
"high": 152.80, | |
"low": 149.50, | |
"close": 151.75, | |
"volume": 45123000, | |
"sma_fast": 150.85, | |
"sma_med": 149.92, | |
"sma_slow": 148.15 | |
} | |
], | |
"calculated_at": "2025-07-31T14:15:26+00:00" | |
} | |
``` | |
**Error Responses**: | |
- **401**: Invalid or missing API key | |
- **429**: Rate limit exceeded (includes Retry-After header) | |
- **400**: Invalid input parameters | |
- **404**: No data found for requested tickers | |
- **500**: Internal server error | |
""" | |
try: | |
logger = logging.getLogger(__name__) | |
start_time = time.perf_counter() | |
# Security logging - get client IP for audit trail | |
client_ip = http_request.headers.get("x-forwarded-for", "").split(",")[0].strip() | |
if not client_ip: | |
client_ip = http_request.headers.get("x-real-ip", "") | |
if not client_ip: | |
client_ip = getattr(http_request.client, "host", "unknown") | |
user_agent = http_request.headers.get("user-agent", "unknown") | |
# Enhanced input validation and security checks | |
if not request.tickers or len(request.tickers) == 0: | |
logger.warning(f"security_validation_failed client_ip={client_ip} reason=empty_tickers_list user_agent={user_agent}") | |
raise HTTPException( | |
status_code=400, | |
detail="At least one ticker symbol is required." | |
) | |
if len(request.tickers) > 50: | |
logger.warning(f"security_validation_failed client_ip={client_ip} reason=too_many_tickers count={len(request.tickers)} user_agent={user_agent}") | |
raise HTTPException( | |
status_code=400, | |
detail="Maximum 50 tickers allowed per request." | |
) | |
# Clean and validate ticker symbols with enhanced security
import re
ticker_pattern = re.compile(r'^[A-Z0-9\.\-\^]+$')
ticker_symbols = []
for ticker in request.tickers:
ticker_clean = str(ticker).upper().strip() | |
# Security: Check for malicious patterns | |
if not ticker_clean or len(ticker_clean) > 10: | |
logger.warning(f"security_validation_failed client_ip={client_ip} reason=invalid_ticker_length ticker={ticker} user_agent={user_agent}") | |
raise HTTPException( | |
status_code=400, | |
detail=f"Invalid ticker symbol '{ticker}'. Must be 1-10 characters." | |
) | |
# Security: only allow alphanumeric characters, dots, hyphens, and carets
if not ticker_pattern.match(ticker_clean):
logger.warning(f"security_validation_failed client_ip={client_ip} reason=invalid_ticker_chars ticker={ticker} user_agent={user_agent}") | |
raise HTTPException( | |
status_code=400, | |
detail=f"Invalid ticker symbol '{ticker}'. Only alphanumeric characters, dots, hyphens, and carets allowed." | |
) | |
ticker_symbols.append(ticker_clean) | |
# Remove duplicates while preserving order | |
seen = set() | |
ticker_symbols = [x for x in ticker_symbols if not (x in seen or seen.add(x))] | |
# Validate period and interval with security logging | |
valid_periods = ['1d', '5d', '1mo', '3mo', '6mo', '1y', '2y', '5y', '10y', 'ytd', 'max'] | |
valid_intervals = ['1m', '2m', '5m', '15m', '30m', '60m', '90m', '1h', '4h', '1d', '5d', '1wk', '1mo', '3mo'] | |
if request.period not in valid_periods: | |
logger.warning(f"security_validation_failed client_ip={client_ip} reason=invalid_period period={request.period} user_agent={user_agent}") | |
raise HTTPException( | |
status_code=400, | |
detail=f"Invalid period. Must be one of: {', '.join(valid_periods)}" | |
) | |
if request.interval not in valid_intervals: | |
logger.warning(f"security_validation_failed client_ip={client_ip} reason=invalid_interval interval={request.interval} user_agent={user_agent}") | |
raise HTTPException( | |
status_code=400, | |
detail=f"Invalid interval. Must be one of: {', '.join(valid_intervals)}" | |
) | |
# Validate intraday configuration | |
if request.intraday: | |
# For intraday data, restrict to shorter periods and specific intervals | |
intraday_periods = ['1d', '5d', '1mo'] | |
intraday_intervals = ['5m', '15m', '30m', '60m', '90m', '1h', '4h'] | |
if request.period not in intraday_periods: | |
logger.warning(f"security_validation_failed client_ip={client_ip} reason=invalid_intraday_period period={request.period} user_agent={user_agent}") | |
raise HTTPException( | |
status_code=400, | |
detail=f"For intraday data, period must be one of: {', '.join(intraday_periods)}" | |
) | |
if request.interval not in intraday_intervals: | |
logger.warning(f"security_validation_failed client_ip={client_ip} reason=invalid_intraday_interval interval={request.interval} user_agent={user_agent}") | |
raise HTTPException( | |
status_code=400, | |
detail=f"For intraday data, interval must be one of: {', '.join(intraday_intervals)}" | |
) | |
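# (Rationale, as an assumption about the upstream source: Yahoo only serves a limited
#  window of intraday history, so longer periods are rejected here rather than failing later.)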
# Security audit log for successful request start | |
logger.info(f"financial_data_analysis_started client_ip={client_ip} tickers={ticker_symbols} period={request.period} interval={request.interval} intraday={request.intraday} count={len(ticker_symbols)} api_key_valid=true user_agent={user_agent}") | |
logger.info(f"rate_limit_info client_ip={client_ip} current_count={rate_info['current_count']} limit={rate_info['limit']} remaining={rate_info.get('remaining', 0)}") | |
# Shared service for market-status checks and indicator calculations
yfinance_svc = YFinanceService(config)
# Get market status for the first ticker (representative) when intraday-style data is requested
market_status = None
if request.intraday or request.interval in ['1m', '2m', '5m', '15m', '30m', '60m', '90m', '1h', '4h']:
    market_status = yfinance_svc.get_market_status(ticker_symbols[0])
logger.info(f"market_status_check ticker={ticker_symbols[0]} state={market_status.market_state} is_open={market_status.is_open}") | |
# Download data from Yahoo Finance - optimized for multiple tickers with interval support | |
download_start = time.perf_counter() | |
# Configure download parameters | |
download_params = { | |
'period': request.period, | |
'progress': False, | |
'auto_adjust': True | |
} | |
# Only group by ticker if we have multiple tickers | |
if len(ticker_symbols) > 1: | |
download_params['group_by'] = 'ticker' | |
# Add interval if different from default | |
if request.interval != '1d': | |
download_params['interval'] = request.interval | |
# For intraday data, include pre/post market data | |
if request.intraday: | |
download_params['prepost'] = True | |
logger.info(f"intraday_download_enabled prepost=true interval={request.interval}") | |
data = yf.download(ticker_symbols, **download_params) | |
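# Shape note: with multiple tickers and group_by='ticker', `data` has MultiIndex columns
# keyed by ticker symbol; with a single ticker the columns are plain OHLCV fields, although
# newer yfinance versions may still return a MultiIndex, which is flattened below.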
download_end = time.perf_counter() | |
if data.empty: | |
logger.warning(f"no_data_found tickers={ticker_symbols} period={request.period}") | |
raise HTTPException( | |
status_code=404, | |
detail=f"No financial data found for tickers {ticker_symbols} with period {request.period}" | |
) | |
logger.info(f"data_downloaded tickers_count={len(ticker_symbols)} rows={len(data)} duration_ms={(download_end-download_start)*1000:.2f}") | |
# Calculate technical indicators and convert to response format | |
calc_start = time.perf_counter() | |
result_data = [] | |
all_dates = [] | |
# Handle both single ticker and multi-ticker cases | |
if len(ticker_symbols) == 1: | |
# Single ticker case - flatten multi-level columns if they exist | |
ticker = ticker_symbols[0] | |
# Check if we have multi-level columns and flatten them | |
if isinstance(data.columns, pd.MultiIndex): | |
# Flatten the multi-level columns by taking the first level (the actual column names) | |
data.columns = data.columns.get_level_values(0) | |
data_with_indicators = yfinance_svc.calculate_technical_indicators(data) | |
all_dates.extend(data_with_indicators.index.tolist()) | |
for date_idx, row in data_with_indicators.iterrows(): | |
try: | |
close_val = row['Close'] | |
if pd.isna(close_val): | |
continue | |
except (KeyError, ValueError): | |
continue | |
# Format datetime based on data type (intraday vs daily) | |
if request.intraday or request.interval in ['1m', '2m', '5m', '15m', '30m', '60m', '90m', '1h', '4h']: | |
datetime_str = date_idx.isoformat() | |
else: | |
datetime_str = date_idx.date().isoformat() | |
result_data.append(TechnicalIndicatorData( | |
ticker=ticker, | |
datetime=datetime_str, | |
open=float(row['Open']), | |
high=float(row['High']), | |
low=float(row['Low']), | |
close=float(row['Close']), | |
volume=int(row['Volume']), | |
sma_fast=float(row['sma_fast']) if pd.notna(row['sma_fast']) else None, | |
sma_med=float(row['sma_med']) if pd.notna(row['sma_med']) else None, | |
sma_slow=float(row['sma_slow']) if pd.notna(row['sma_slow']) else None | |
)) | |
else: | |
# Multiple tickers case - data is grouped by ticker | |
processed_tickers = [] | |
for ticker in ticker_symbols: | |
if ticker not in data.columns.get_level_values(0): | |
logger.warning(f"ticker_data_missing ticker={ticker} reason=not_in_downloaded_data") | |
continue | |
ticker_data = data[ticker] | |
if ticker_data.empty: | |
logger.warning(f"ticker_data_empty ticker={ticker}") | |
continue | |
# Calculate technical indicators for this ticker | |
ticker_data_with_indicators = yfinance_svc.calculate_technical_indicators(ticker_data) | |
all_dates.extend(ticker_data_with_indicators.index.tolist()) | |
processed_tickers.append(ticker) | |
for date_idx, row in ticker_data_with_indicators.iterrows(): | |
try: | |
close_val = row['Close'] | |
if pd.isna(close_val): | |
continue | |
except (KeyError, ValueError): | |
continue | |
# Format datetime based on data type (intraday vs daily) | |
if request.intraday or request.interval in ['1m', '2m', '5m', '15m', '30m', '60m', '90m', '1h', '4h']: | |
datetime_str = date_idx.isoformat() | |
else: | |
datetime_str = date_idx.date().isoformat() | |
result_data.append(TechnicalIndicatorData( | |
ticker=ticker, | |
datetime=datetime_str, | |
open=float(row['Open']), | |
high=float(row['High']), | |
low=float(row['Low']), | |
close=float(row['Close']), | |
volume=int(row['Volume']), | |
sma_fast=float(row['sma_fast']) if pd.notna(row['sma_fast']) else None, | |
sma_med=float(row['sma_med']) if pd.notna(row['sma_med']) else None, | |
sma_slow=float(row['sma_slow']) if pd.notna(row['sma_slow']) else None | |
)) | |
if not processed_tickers: | |
raise HTTPException( | |
status_code=404, | |
detail=f"No valid data found for any of the requested tickers: {ticker_symbols}" | |
) | |
calc_end = time.perf_counter() | |
logger.info(f"indicators_calculated tickers_count={len(ticker_symbols)} duration_ms={(calc_end-calc_start)*1000:.2f}") | |
# Calculate date range | |
start_date = min(all_dates).date() if all_dates else datetime.now().date() | |
end_date = max(all_dates).date() if all_dates else datetime.now().date() | |
# Sort in descending (ticker, datetime) order so the most recent rows appear first
result_data.sort(key=lambda x: (x.ticker, x.datetime), reverse=True)
end_time = time.perf_counter() | |
total_duration = end_time - start_time | |
# Security audit log for successful completion | |
logger.info(f"financial_data_analysis_completed client_ip={client_ip} tickers={ticker_symbols} data_points={len(result_data)} total_duration_ms={total_duration*1000:.2f} status=success") | |
# Build the response payload; JSONResponse (imported at module level) lets us attach security headers
response_data = { | |
"success": True, | |
"tickers": ticker_symbols, | |
"period": request.period, | |
"interval": request.interval, | |
"intraday": request.intraday, | |
"total_data_points": len(result_data), | |
"date_range": { | |
"start_date": start_date.isoformat(), | |
"end_date": end_date.isoformat() | |
}, | |
"market_status": { | |
"is_open": market_status.is_open, | |
"market_state": market_status.market_state, | |
"timezone": market_status.timezone | |
} if market_status else None, | |
"data": [ | |
{ | |
"ticker": item.ticker, | |
"datetime": item.datetime, | |
"open": item.open, | |
"high": item.high, | |
"low": item.low, | |
"close": item.close, | |
"volume": item.volume, | |
"sma_fast": item.sma_fast, | |
"sma_med": item.sma_med, | |
"sma_slow": item.sma_slow | |
} | |
for item in result_data | |
], | |
"calculated_at": datetime.now(pytz.UTC).isoformat() | |
} | |
# Return JSONResponse with security headers | |
return JSONResponse( | |
content=response_data, | |
headers={ | |
"X-RateLimit-Limit": str(rate_info["limit"]), | |
"X-RateLimit-Remaining": str(rate_info.get("remaining", 0)), | |
"X-Content-Type-Options": "nosniff", | |
"X-Frame-Options": "DENY", | |
"X-XSS-Protection": "1; mode=block" | |
} | |
) | |
except HTTPException: | |
raise | |
except Exception as e: | |
logger = logging.getLogger(__name__) | |
# Security audit log for errors | |
client_ip = http_request.headers.get("x-forwarded-for", "").split(",")[0].strip() | |
if not client_ip: | |
client_ip = getattr(http_request.client, "host", "unknown") | |
logger.error(f"financial_data_analysis_failed client_ip={client_ip} tickers={request.tickers} error={str(e)} status=error") | |
raise HTTPException( | |
status_code=500, | |
detail=f"Failed to analyze financial data for tickers {request.tickers}: {str(e)}" | |
) | |
# Local execution configuration | |
if __name__ == "__main__": | |
import uvicorn | |
HOST = os.getenv("HOST", "0.0.0.0") | |
PORT = int(os.getenv("PORT", 7860)) | |
# Enable auto-reload only when running locally (SPACE_ID is set on HF Spaces)
RELOAD = os.getenv("SPACE_ID") is None | |
# Start the Uvicorn server | |
uvicorn.run("index:app", host=HOST, port=PORT, reload=RELOAD) | |
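# Example local run (assumes you start from the api/ directory so that "index:app" resolves):
#   HOST=127.0.0.1 PORT=7860 python index.py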