import os
import warnings
import logging
# Set up cache directory for HuggingFace models
cache_dir = os.path.join(os.getcwd(), ".cache")
os.makedirs(cache_dir, exist_ok=True)
os.environ['HF_HOME'] = cache_dir
os.environ['TRANSFORMERS_CACHE'] = cache_dir
# Suppress TensorFlow warnings
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['TF_ENABLE_ONEDNN_OPTS'] = '0'
os.environ['TF_LOGGING_LEVEL'] = 'ERROR'
os.environ['TF_ENABLE_DEPRECATION_WARNINGS'] = '0'
# Suppress specific TensorFlow deprecation warnings
warnings.filterwarnings('ignore', category=DeprecationWarning, module='tensorflow')
logging.getLogger('tensorflow').setLevel(logging.ERROR)

from fastapi import FastAPI, Request, HTTPException, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from parser import parse_pdf_from_url, parse_pdf_from_file
from embedder import build_faiss_index
from retriever import retrieve_chunks
from llm import query_gemini
import uvicorn

app = FastAPI(title="HackRx Insurance Policy Assistant", version="1.0.0")

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
async def root():
return {"message": "HackRx Insurance Policy Assistant API is running!"}
@app.get("/health")
async def health_check():
return {"status": "healthy", "message": "API is ready to process requests"}

class QueryRequest(BaseModel):
    documents: str
    questions: list[str]

class LocalQueryRequest(BaseModel):
    document_path: str
    questions: list[str]

def verify_token(authorization: str = Header(None)):
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(status_code=401, detail="Invalid authorization header")
    token = authorization.replace("Bearer ", "")
    # For demo purposes, accept any token. In production, validate against a database
    if not token:
        raise HTTPException(status_code=401, detail="Invalid token")
    return token
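
# Note: verify_token only checks that *some* non-empty bearer token was sent.
# A production setup would compare it against a stored secret, e.g. (sketch,
# assuming an API_TOKEN environment variable is configured for this service):
#
#     expected = os.environ.get("API_TOKEN")
#     if expected and token != expected:
#         raise HTTPException(status_code=401, detail="Invalid token")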
@app.post("/api/v1/hackrx/run")
async def run_query(request: QueryRequest, token: str = Depends(verify_token)):
try:
print(f"Processing {len(request.questions)} questions...")
text_chunks = parse_pdf_from_url(request.documents)
print(f"Extracted {len(text_chunks)} text chunks from PDF")
index, texts = build_faiss_index(text_chunks)
# Get relevant chunks for all questions at once
all_chunks = set()
for question in request.questions:
top_chunks = retrieve_chunks(index, texts, question)
all_chunks.update(top_chunks)
# Process all questions in a single LLM call
print(f"Processing all {len(request.questions)} questions in batch...")
response = query_gemini(request.questions, list(all_chunks))
# Extract answers from the JSON response
if isinstance(response, dict) and "answers" in response:
answers = response["answers"]
# Ensure we have the right number of answers
while len(answers) < len(request.questions):
answers.append("Not Found")
answers = answers[:len(request.questions)]
else:
# Fallback if response is not in expected format
answers = [response] if isinstance(response, str) else []
# Ensure we have the right number of answers
while len(answers) < len(request.questions):
answers.append("Not Found")
answers = answers[:len(request.questions)]
print(f"Generated {len(answers)} answers")
return { "answers": answers }
except Exception as e:
print(f"Error: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
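
# Example call (sketch: the document URL, token, and question are placeholders,
# and the port matches the default used in the __main__ block below):
#
#   curl -X POST http://localhost:7860/api/v1/hackrx/run \
#     -H "Authorization: Bearer <any-token>" \
#     -H "Content-Type: application/json" \
#     -d '{"documents": "https://example.com/policy.pdf", "questions": ["What is covered?"]}'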
@app.post("/api/v1/hackrx/local")
async def run_local_query(request: LocalQueryRequest):
try:
print(f"Processing local document: {request.document_path}")
print(f"Processing {len(request.questions)} questions...")
# Parse local PDF file
text_chunks = parse_pdf_from_file(request.document_path)
print(f"Extracted {len(text_chunks)} text chunks from local PDF")
index, texts = build_faiss_index(text_chunks)
# Get relevant chunks for all questions at once
all_chunks = set()
for question in request.questions:
top_chunks = retrieve_chunks(index, texts, question)
all_chunks.update(top_chunks)
# Process all questions in a single LLM call
print(f"Processing all {len(request.questions)} questions in batch...")
response = query_gemini(request.questions, list(all_chunks))
# Extract answers from the JSON response
if isinstance(response, dict) and "answers" in response:
answers = response["answers"]
# Ensure we have the right number of answers
while len(answers) < len(request.questions):
answers.append("Not Found")
answers = answers[:len(request.questions)]
else:
# Fallback if response is not in expected format
answers = [response] if isinstance(response, str) else []
# Ensure we have the right number of answers
while len(answers) < len(request.questions):
answers.append("Not Found")
answers = answers[:len(request.questions)]
print(f"Generated {len(answers)} answers")
return { "answers": answers }
except Exception as e:
print(f"Error: {str(e)}")
raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
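
# Example call for the local-file endpoint (sketch: no Authorization header is
# required here, and the file path and question are placeholders):
#
#   curl -X POST http://localhost:7860/api/v1/hackrx/local \
#     -H "Content-Type: application/json" \
#     -d '{"document_path": "./policy.pdf", "questions": ["What is covered?"]}'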

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port)