# Spaces: Running
from fastapi import FastAPI, HTTPException, UploadFile, File | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import HTMLResponse | |
from fastapi.staticfiles import StaticFiles | |
from pydantic import BaseModel | |
from setfit import AbsaModel | |
import logging | |
from typing import List, Dict, Any | |
import uvicorn | |
import os | |
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI app
app = FastAPI(
    title="ABSA Web Application",
    description="Aspect-Based Sentiment Analysis using SetFit models",
)

# Add CORS middleware.
# Credentials stay disabled because every origin is allowed: a wildcard origin
# combined with credentialed requests would let any website call this API with
# the visitor's cookies. None of the handlers visible in this file read
# cookies or authorization headers, so nothing here relies on credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variable to store the model; remains None until it has been loaded.
absa_model = None
class TextInput(BaseModel):
    """Request body carrying the text to analyse."""

    # Raw text submitted by the client.
    text: str
class ABSAResponse(BaseModel):
    """Result payload produced by the text-analysis handler."""

    # Text the analysis was run on.
    text: str
    # One dict per detected aspect, holding "span" and "polarity" keys.
    predictions: List[Dict[str, Any]]
    # True when the analysis completed; False for empty input or a failure.
    success: bool
    # Human-readable status or error description.
    message: str
async def load_model():
    """Load the ABSA model on startup.

    Stores the pretrained aspect and polarity models in the module-level
    ``absa_model`` variable.

    Raises:
        Exception: Whatever ``AbsaModel.from_pretrained`` raised. The failure
            is logged together with its traceback and re-raised unchanged.
    """
    global absa_model
    logger.info("Loading ABSA models...")
    try:
        absa_model = AbsaModel.from_pretrained(
            "ronalhung/setfit-absa-restaurants-aspect",
            "ronalhung/setfit-absa-restaurants-polarity",
        )
    except Exception as exc:
        # logger.exception keeps the traceback in the log; the bare ``raise``
        # propagates the original exception without rewriting it.
        logger.exception("Failed to load models: %s", exc)
        raise
    logger.info("Models loaded successfully!")
# Registered as a startup handler so the model is loaded before requests are
# served; without the registration this coroutine is never invoked.
@app.on_event("startup")
async def startup_event():
    """Load model when the application starts."""
    await load_model()
# response_class=HTMLResponse makes FastAPI send the returned string as
# text/html instead of JSON-encoding it.
@app.get("/", response_class=HTMLResponse)
async def get_home():
    """Serve the main HTML page.

    The page is a self-contained React UI (React, Babel and Tailwind are
    pulled from CDNs) that POSTs the entered text as JSON to ``/analyze`` and
    renders the returned aspect/polarity predictions.

    Returns:
        The HTML document as a string.
    """
    html_content = """
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>ABSA - Aspect-Based Sentiment Analysis</title>
        <script src="https://cdn.tailwindcss.com"></script>
        <script src="https://unpkg.com/react@18/umd/react.development.js"></script>
        <script src="https://unpkg.com/react-dom@18/umd/react-dom.development.js"></script>
        <script src="https://unpkg.com/@babel/standalone/babel.min.js"></script>
    </head>
    <body class="bg-gray-50">
        <div id="root"></div>
        <script type="text/babel">
            const { useState, useRef } = React;

            const App = () => {
                const [text, setText] = useState('');
                const [results, setResults] = useState(null);
                const [loading, setLoading] = useState(false);
                const [error, setError] = useState('');
                const fileInputRef = useRef(null);

                const handleAnalyze = async () => {
                    if (!text.trim()) {
                        setError('Please enter some text to analyze');
                        return;
                    }
                    setLoading(true);
                    setError('');
                    try {
                        const response = await fetch('/analyze', {
                            method: 'POST',
                            headers: {
                                'Content-Type': 'application/json',
                            },
                            body: JSON.stringify({ text: text.trim() }),
                        });
                        const data = await response.json();
                        if (data.success) {
                            setResults(data);
                        } else {
                            // HTTP errors (e.g. 503 while the model loads) carry the reason in "detail".
                            const detail = typeof data.detail === 'string' ? data.detail : '';
                            setError(data.message || detail || 'Analysis failed');
                        }
                    } catch (err) {
                        setError('Failed to analyze text. Please try again.');
                        console.error('Error:', err);
                    } finally {
                        setLoading(false);
                    }
                };

                const handleFileUpload = async (event) => {
                    const file = event.target.files[0];
                    if (!file) return;
                    if (!file.name.endsWith('.txt')) {
                        setError('Please upload a .txt file');
                        return;
                    }
                    try {
                        const text = await file.text();
                        setText(text);
                        setError('');
                    } catch (err) {
                        setError('Failed to read file. Please try again.');
                        console.error('Error reading file:', err);
                    }
                };

                const clearResults = () => {
                    setText('');
                    setResults(null);
                    setError('');
                };

                const getSentimentColor = (polarity) => {
                    switch (polarity) {
                        case 'positive': return 'text-green-600 bg-green-100';
                        case 'negative': return 'text-red-600 bg-red-100';
                        case 'neutral': return 'text-gray-600 bg-gray-100';
                        case 'conflict': return 'text-yellow-600 bg-yellow-100';
                        default: return 'text-gray-600 bg-gray-100';
                    }
                };

                return (
                    <div className="min-h-screen bg-gradient-to-br from-blue-50 to-indigo-100">
                        <div className="container mx-auto px-4 py-8">
                            <div className="max-w-4xl mx-auto">
                                {/* Header */}
                                <div className="text-center mb-8">
                                    <h1 className="text-4xl font-bold text-gray-800 mb-4">
                                        Aspect-Based Sentiment Analysis
                                    </h1>
                                    <p className="text-lg text-gray-600">
                                        Analyze aspects and sentiments in restaurant reviews using SetFit models
                                    </p>
                                </div>
                                {/* Input Section */}
                                <div className="bg-white rounded-lg shadow-lg p-6 mb-6">
                                    <h2 className="text-2xl font-semibold text-gray-800 mb-4">Input Text</h2>
                                    {/* File Upload */}
                                    <div className="mb-4">
                                        <label className="block text-sm font-medium text-gray-700 mb-2">
                                            Upload Text File (.txt)
                                        </label>
                                        <input
                                            ref={fileInputRef}
                                            type="file"
                                            accept=".txt"
                                            onChange={handleFileUpload}
                                            className="block w-full text-sm text-gray-500
                                                file:mr-4 file:py-2 file:px-4
                                                file:rounded-md file:border-0
                                                file:text-sm file:font-semibold
                                                file:bg-blue-50 file:text-blue-700
                                                hover:file:bg-blue-100
                                                cursor-pointer"
                                        />
                                    </div>
                                    {/* Text Area */}
                                    <div className="mb-4">
                                        <label className="block text-sm font-medium text-gray-700 mb-2">
                                            Or type/paste your text here:
                                        </label>
                                        <textarea
                                            value={text}
                                            onChange={(e) => setText(e.target.value)}
                                            placeholder="Enter restaurant review text for analysis..."
                                            className="w-full h-32 p-3 border border-gray-300 rounded-md focus:ring-2 focus:ring-blue-500 focus:border-blue-500 resize-none"
                                        />
                                    </div>
                                    {/* Error Message */}
                                    {error && (
                                        <div className="mb-4 p-3 bg-red-100 border border-red-400 text-red-700 rounded-md">
                                            {error}
                                        </div>
                                    )}
                                    {/* Action Buttons */}
                                    <div className="flex gap-3">
                                        <button
                                            onClick={handleAnalyze}
                                            disabled={loading || !text.trim()}
                                            className="px-6 py-2 bg-blue-600 text-white rounded-md hover:bg-blue-700
                                                disabled:bg-gray-400 disabled:cursor-not-allowed
                                                flex items-center gap-2 font-medium transition-colors"
                                        >
                                            {loading ? (
                                                <>
                                                    <div className="animate-spin rounded-full h-4 w-4 border-b-2 border-white"></div>
                                                    Analyzing...
                                                </>
                                            ) : (
                                                'Analyze Text'
                                            )}
                                        </button>
                                        <button
                                            onClick={clearResults}
                                            className="px-6 py-2 bg-gray-500 text-white rounded-md hover:bg-gray-600
                                                font-medium transition-colors"
                                        >
                                            Clear
                                        </button>
                                    </div>
                                </div>
                                {/* Results Section */}
                                {results && (
                                    <div className="bg-white rounded-lg shadow-lg p-6">
                                        <h2 className="text-2xl font-semibold text-gray-800 mb-4">Analysis Results</h2>
                                        {/* Original Text */}
                                        <div className="mb-6">
                                            <h3 className="text-lg font-medium text-gray-700 mb-2">Original Text:</h3>
                                            <div className="p-3 bg-gray-50 rounded-md border">
                                                {results.text}
                                            </div>
                                        </div>
                                        {/* Predictions */}
                                        <div>
                                            <h3 className="text-lg font-medium text-gray-700 mb-4">
                                                Detected Aspects & Sentiments:
                                            </h3>
                                            {results.predictions && results.predictions.length > 0 ? (
                                                <div className="space-y-3">
                                                    {results.predictions.map((prediction, index) => (
                                                        <div key={index} className="border border-gray-200 rounded-md p-4">
                                                            <div className="flex items-center justify-between mb-2">
                                                                <span className="text-sm font-medium text-gray-600">
                                                                    Aspect Span:
                                                                </span>
                                                                <span className="font-semibold text-gray-800">
                                                                    "{prediction.span}"
                                                                </span>
                                                            </div>
                                                            <div className="flex items-center justify-between">
                                                                <span className="text-sm font-medium text-gray-600">
                                                                    Sentiment:
                                                                </span>
                                                                <span className={`px-3 py-1 rounded-full text-sm font-medium ${getSentimentColor(prediction.polarity)}`}>
                                                                    {prediction.polarity}
                                                                </span>
                                                            </div>
                                                        </div>
                                                    ))}
                                                </div>
                                            ) : (
                                                <div className="text-gray-500 text-center py-4">
                                                    No aspects detected in the text.
                                                </div>
                                            )}
                                        </div>
                                    </div>
                                )}
                            </div>
                        </div>
                    </div>
                );
            };

            ReactDOM.createRoot(document.getElementById('root')).render(<App />);
        </script>
    </body>
    </html>
    """
    return html_content
# The page served by this app POSTs JSON to /analyze, so the handler is
# registered on that path.
@app.post("/analyze", response_model=ABSAResponse)
async def analyze_text(input_data: TextInput):
    """Analyze text for aspects and sentiment.

    Args:
        input_data: Request body carrying the text to analyse.

    Returns:
        An ``ABSAResponse``. ``success`` is False, with an explanatory
        ``message``, when the text is blank or the model call raises.

    Raises:
        HTTPException: 503 when the model has not been loaded yet.
    """
    if absa_model is None:
        raise HTTPException(
            status_code=503,
            detail="Model not loaded yet. Please try again later.",
        )

    text = input_data.text.strip()
    if not text:
        return ABSAResponse(
            text=text,
            predictions=[],
            success=False,
            message="Empty text provided",
        )

    # Only the first 100 characters are logged to keep log lines short.
    logger.info("Analyzing text: %s...", text[:100])

    try:
        # Run ABSA analysis.
        # NOTE(review): this is a synchronous call inside a coroutine, so
        # other requests presumably wait while it runs -- confirm whether
        # that is acceptable for the expected load.
        predictions = absa_model(text)

        # Keep only the fields the client renders, with the same fallbacks
        # for missing keys; a falsy result yields an empty list.
        formatted_predictions = [
            {
                "span": pred.get("span", ""),
                "polarity": pred.get("polarity", "neutral"),
            }
            for pred in predictions or []
        ]
    except Exception as exc:
        # Report the failure in the payload rather than as an HTTP error,
        # and log it with its traceback.
        logger.exception("Error during analysis: %s", exc)
        return ABSAResponse(
            text=input_data.text,
            predictions=[],
            success=False,
            message=f"Analysis failed: {exc}",
        )

    return ABSAResponse(
        text=text,
        predictions=formatted_predictions,
        success=True,
        message="Analysis completed successfully",
    )
# NOTE(review): the "/health" path is an assumption; nothing else in this
# file references the endpoint, so confirm against any external monitor.
@app.get("/health")
async def health_check():
    """Health check endpoint.

    Returns:
        A dict with a fixed ``status`` and ``message`` plus ``model_loaded``,
        which is True once the module-level model has been set.
    """
    return {
        "status": "healthy",
        "model_loaded": absa_model is not None,
        "message": "ABSA service is running",
    }
if __name__ == "__main__":
    # Port and auto-reload can be overridden through the PORT and RELOAD
    # environment variables; when they are unset the values match the
    # previous hard-coded ones (8000, reload enabled).
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=int(os.environ.get("PORT", "8000")),
        reload=os.environ.get("RELOAD", "true").strip().lower() in ("1", "true", "yes"),
        log_level="info",
    )