"""ABSA web application.

Serves a FastAPI app that runs SetFit aspect-based sentiment analysis
(aspect extraction followed by polarity classification) over submitted text.
"""

from fastapi import FastAPI, HTTPException, UploadFile, File
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from setfit import AbsaModel
import logging
import threading
from typing import List, Dict, Any
import uvicorn
import os

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="ABSA Web Application",
    description="Aspect-Based Sentiment Analysis using SetFit models",
)

# Add CORS middleware
# NOTE(review): a wildcard origin combined with allow_credentials=True means
# Starlette reflects the caller's Origin for credentialed requests, i.e. any
# site may call this API with cookies. Restrict allow_origins before exposing
# the service publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variable to store the model; populated by load_model() at startup.
absa_model = None

# Serializes calls into the model. Inference runs in worker threads (see
# analyze_text); this lock keeps it one-at-a-time, as it was when it ran
# directly on the event loop, without assuming the model is thread-safe.
_inference_lock = threading.Lock()


class TextInput(BaseModel):
    """Request body for the /analyze endpoint."""

    text: str


class ABSAResponse(BaseModel):
    """Response body for the /analyze endpoint."""

    text: str  # stripped input; the raw input is echoed if analysis fails
    predictions: List[Dict[str, Any]]  # one {"span", "polarity"} dict per aspect
    success: bool
    message: str  # human-readable status or error description


async def load_model():
    """Load the aspect and polarity SetFit models into ``absa_model``.

    Raises:
        Exception: whatever ``AbsaModel.from_pretrained`` raised is logged
            with its traceback and re-raised, so it propagates out of the
            startup handler.
    """
    global absa_model
    try:
        logger.info("Loading ABSA models...")
        absa_model = AbsaModel.from_pretrained(
            "ronalhung/setfit-absa-restaurants-aspect",
            "ronalhung/setfit-absa-restaurants-polarity",
        )
        logger.info("Models loaded successfully!")
    except Exception as e:
        # logger.exception keeps the traceback; bare `raise` re-raises unchanged.
        logger.exception(f"Failed to load models: {str(e)}")
        raise


# NOTE(review): `on_event` is deprecated in recent FastAPI releases in favour
# of a `lifespan` handler; kept here because mixing the two is not supported.
@app.on_event("startup")
async def startup_event():
    """Load model when the application starts"""
    await load_model()


@app.get("/", response_class=HTMLResponse)
async def get_home():
    """Serve the main HTML page"""
    # NOTE(review): this literal appears to have lost its markup (only the
    # page title text survives) - restore the full HTML document.
    html_content = """ ABSA - Aspect-Based Sentiment Analysis
"""
    return html_content


def _format_predictions(raw_predictions) -> List[Dict[str, Any]]:
    """Normalize model output into a flat list of {"span", "polarity"} dicts.

    For a single string, ``AbsaModel`` normally returns a list of dicts, but
    its no-aspects early return yields the batch-shaped ``[[]]`` instead.
    Nested lists are therefore flattened one level; non-dict entries are
    skipped rather than crashing on ``.get``.
    """
    flat = []
    for entry in raw_predictions or []:
        if isinstance(entry, dict):
            flat.append(entry)
        elif isinstance(entry, (list, tuple)):
            flat.extend(item for item in entry if isinstance(item, dict))
    return [
        {
            "span": pred.get("span", ""),
            "polarity": pred.get("polarity", "neutral"),
        }
        for pred in flat
    ]


def _run_inference(model, text: str):
    """Call *model* on *text* while holding the inference lock (blocking)."""
    with _inference_lock:
        return model(text)


@app.post("/analyze", response_model=ABSAResponse)
async def analyze_text(input_data: TextInput):
    """Analyze text for aspects and sentiment.

    Empty input and analysis errors are reported in the body with
    ``success=False`` rather than as HTTP errors.

    Raises:
        HTTPException: 503 if the model has not been loaded yet.
    """
    model = absa_model
    if model is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded yet. Please try again later.",
        )

    try:
        text = input_data.text.strip()
        if not text:
            return ABSAResponse(
                text=text,
                predictions=[],
                success=False,
                message="Empty text provided",
            )

        logger.info("Analyzing text: %s...", text[:100])

        # Inference is synchronous and CPU-heavy: run it in a worker thread so
        # it does not stall the event loop (and with it /health and every
        # other in-flight request).
        predictions = await run_in_threadpool(_run_inference, model, text)

        return ABSAResponse(
            text=text,
            predictions=_format_predictions(predictions),
            success=True,
            message="Analysis completed successfully",
        )
    except Exception as e:
        logger.exception(f"Error during analysis: {str(e)}")
        return ABSAResponse(
            text=input_data.text,
            predictions=[],
            success=False,
            message=f"Analysis failed: {str(e)}",
        )


@app.get("/health")
async def health_check():
    """Health check endpoint; ``model_loaded`` reports whether startup loading finished."""
    return {
        "status": "healthy",
        "model_loaded": absa_model is not None,
        "message": "ABSA service is running",
    }


if __name__ == "__main__":
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=8000,
        reload=True,
        log_level="info",
    )