Rivalcoder
Update The Model issues and Prompt
6bc8549
raw
history blame
10.5 kB
import os
import warnings
import logging
import time
from datetime import datetime
# Keep HuggingFace model downloads in a ".cache" folder under the current
# working directory, creating the folder on first run.
cache_dir = os.path.join(os.getcwd(), ".cache")
os.makedirs(cache_dir, exist_ok=True)

# Apply every environment override in one step: the HuggingFace cache
# locations first, then the switches that quiet TensorFlow. This runs before
# the library imports that follow in the file.
os.environ.update({
    'HF_HOME': cache_dir,
    'TRANSFORMERS_CACHE': cache_dir,
    'TF_CPP_MIN_LOG_LEVEL': '3',
    'TF_ENABLE_ONEDNN_OPTS': '0',
    'TF_LOGGING_LEVEL': 'ERROR',
    'TF_ENABLE_DEPRECATION_WARNINGS': '0',
})

# Ignore deprecation warnings attributed to tensorflow modules and limit the
# "tensorflow" logger to ERROR and above.
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module='tensorflow',
)
logging.getLogger('tensorflow').setLevel(logging.ERROR)
from fastapi import FastAPI, Request, HTTPException, Depends, Header
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from parser import parse_pdf_from_url, parse_pdf_from_file
from embedder import build_faiss_index, preload_model
from retriever import retrieve_chunks
from llm import query_gemini
import uvicorn
# The ASGI application object; every route below is registered on it.
app = FastAPI(
    title="HackRx Insurance Policy Assistant",
    version="1.0.0",
)

# CORS policy: any origin, method and header is allowed, with credentials.
# NOTE(review): wildcard origins together with allow_credentials=True is very
# permissive - confirm this is intended outside of a demo deployment.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Preload the model at startup
@app.on_event("startup")
async def startup_event():
    """Call ``preload_model()`` once at application startup, logging progress."""
    opening_lines = (
        "Starting up HackRx Insurance Policy Assistant...",
        "Preloading sentence transformer model...",
    )
    for line in opening_lines:
        print(line)

    preload_model()

    print("Model preloading completed. API is ready to serve requests!")
@app.get("/")
async def root():
    # Landing route: returns a fixed banner confirming the service is running.
    banner = "HackRx Insurance Policy Assistant API is running!"
    return {"message": banner}
@app.get("/health")
async def health_check():
    # Health probe: always reports a static "healthy" payload; no dependency
    # (model, index, LLM) is actually checked here.
    payload = dict(
        status="healthy",
        message="API is ready to process requests",
    )
    return payload
# Request body for POST /api/v1/hackrx/run.
class QueryRequest(BaseModel):
    # Value handed to parse_pdf_from_url() by the /run handler.
    documents: str

    # Questions to answer; the response carries one answer per entry.
    questions: list[str]
# Request body for POST /api/v1/hackrx/local.
class LocalQueryRequest(BaseModel):
    # Value handed to parse_pdf_from_file() by the /local handler.
    document_path: str

    # Questions to answer; the response carries one answer per entry.
    questions: list[str]
def verify_token(authorization: str = Header(None)):
    """Extract the bearer token from the ``Authorization`` header.

    Args:
        authorization: Raw ``Authorization`` header value, or ``None`` when
            the header is absent.

    Returns:
        The token text that follows the leading ``Bearer `` scheme prefix,
        with surrounding whitespace removed.

    Raises:
        HTTPException: 401 when the header is missing, does not start with
            ``Bearer ``, or carries no token after the prefix.
    """
    scheme_prefix = "Bearer "
    if not authorization or not authorization.startswith(scheme_prefix):
        raise HTTPException(status_code=401, detail="Invalid authorization header")

    # Remove only the leading scheme. str.replace() would also delete any
    # later "Bearer " occurring inside the token and silently corrupt it.
    # Stripping makes a whitespace-only token count as empty.
    token = authorization[len(scheme_prefix):].strip()

    # For demo purposes, accept any token. In production, validate against a database
    if not token:
        raise HTTPException(status_code=401, detail="Invalid token")
    return token
def _collect_unique_chunks(index, texts, questions):
    """Retrieve chunks for every question and merge them without duplicates.

    A dict serves as an insertion-ordered set, so the merged list keeps the
    order in which chunks were first retrieved. Iterating a ``set`` instead
    would hand the LLM the chunks in an arbitrary order.

    Args:
        index: Index object returned by ``build_faiss_index``.
        texts: Text collection returned by ``build_faiss_index``.
        questions: Questions to run through ``retrieve_chunks``.

    Returns:
        List of distinct chunks in first-seen order.
    """
    merged = {}
    for question in questions:
        # update() leaves already-present keys at their original position.
        merged.update(dict.fromkeys(retrieve_chunks(index, texts, question)))
    return list(merged)


def _normalize_answers(response, question_count):
    """Coerce the LLM response into exactly ``question_count`` answers.

    Args:
        response: Value returned by ``query_gemini``; handled shapes are a
            dict with an ``"answers"`` key or a bare string.
        question_count: Number of answers the caller must receive.

    Returns:
        A new list of length ``question_count``; missing entries are filled
        with ``"Not Found"`` and surplus entries are dropped.
    """
    if isinstance(response, dict) and "answers" in response:
        raw_answers = response["answers"]
        if isinstance(raw_answers, (list, tuple)):
            # Copy so padding never mutates the object the LLM layer returned.
            answers = list(raw_answers)
        elif raw_answers is None:
            answers = []
        else:
            # A single non-list answer would otherwise crash on .append().
            answers = [raw_answers]
    elif isinstance(response, str):
        # Fallback if response is not in expected format
        answers = [response]
    else:
        answers = []

    answers = answers[:question_count]
    answers.extend(["Not Found"] * (question_count - len(answers)))
    return answers


def _answer_questions(load_chunks, questions, start_time, source_label):
    """Run parse -> index -> retrieve -> LLM -> normalize and log stage timings.

    Args:
        load_chunks: Zero-argument callable returning the document's text
            chunks; it is invoked (and timed) as the PDF parsing stage.
        questions: Questions to answer, in order.
        start_time: ``time.time()`` value captured when the request began,
            used for the total-time figure.
        source_label: Text completing the "Extracted N text chunks from ..."
            log line.

    Returns:
        ``{"answers": [...]}`` with one entry per question.
    """
    timing_data = {}

    # Time PDF parsing
    stage_start = time.time()
    text_chunks = load_chunks()
    timing_data['pdf_parsing'] = round(time.time() - stage_start, 2)
    print(f"Extracted {len(text_chunks)} text chunks from {source_label}")

    # Time FAISS index building
    stage_start = time.time()
    index, texts = build_faiss_index(text_chunks)
    timing_data['faiss_index_building'] = round(time.time() - stage_start, 2)

    # Time chunk retrieval for all questions
    stage_start = time.time()
    unique_chunks = _collect_unique_chunks(index, texts, questions)
    timing_data['chunk_retrieval'] = round(time.time() - stage_start, 2)
    print(f"Retrieved {len(unique_chunks)} unique chunks")

    # Time LLM processing
    stage_start = time.time()
    print(f"Processing all {len(questions)} questions in batch...")
    response = query_gemini(questions, unique_chunks)
    timing_data['llm_processing'] = round(time.time() - stage_start, 2)

    # Time response processing
    stage_start = time.time()
    answers = _normalize_answers(response, len(questions))
    timing_data['response_processing'] = round(time.time() - stage_start, 2)
    print(f"Generated {len(answers)} answers")

    # Calculate total time
    timing_data['total_time'] = round(time.time() - start_time, 2)

    print("\n=== TIMING BREAKDOWN ===")
    print(f"PDF Parsing: {timing_data['pdf_parsing']}s")
    print(f"FAISS Index Building: {timing_data['faiss_index_building']}s")
    print(f"Chunk Retrieval: {timing_data['chunk_retrieval']}s")
    print(f"LLM Processing: {timing_data['llm_processing']}s")
    print(f"Response Processing: {timing_data['response_processing']}s")
    print(f"TOTAL TIME: {timing_data['total_time']}s")
    print("=======================\n")

    result = {"answers": answers}
    print("=== OUTPUT JSON ===")
    print(f"{result}")
    print("==================\n")
    return result


# NOTE(review): both handlers are `async def` yet call the synchronous
# parse/index/LLM helpers directly, so other requests presumably wait while
# one is being processed - confirm whether that is acceptable.
@app.post("/api/v1/hackrx/run")
async def run_query(request: QueryRequest, token: str = Depends(verify_token)):
    """Answer the submitted questions against the PDF referenced by `documents`.

    Requires a bearer token. Returns `{"answers": [...]}` with one answer per
    question; any failure is reported as HTTP 500.
    """
    start_time = time.time()
    try:
        print("\n=== INPUT JSON ===")
        print(f"Documents: {request.documents}")
        print(f"Questions: {request.questions}")
        print("==================\n")
        print(f"Processing {len(request.questions)} questions...")

        return _answer_questions(
            lambda: parse_pdf_from_url(request.documents),
            request.questions,
            start_time,
            source_label="PDF",
        )
    except Exception as e:
        total_time = time.time() - start_time
        print(f"Error after {total_time:.2f} seconds: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e


# NOTE(review): unlike /api/v1/hackrx/run, this route has no verify_token
# dependency and passes a caller-supplied path to parse_pdf_from_file().
# Left unchanged to stay compatible with existing callers - confirm the route
# is not reachable by untrusted clients.
@app.post("/api/v1/hackrx/local")
async def run_local_query(request: LocalQueryRequest):
    """Answer the submitted questions against a PDF at `document_path`.

    Returns `{"answers": [...]}` with one answer per question; any failure is
    reported as HTTP 500.
    """
    start_time = time.time()
    try:
        print("\n=== INPUT JSON ===")
        print(f"Document Path: {request.document_path}")
        print(f"Questions: {request.questions}")
        print("==================\n")
        print(f"Processing local document: {request.document_path}")
        print(f"Processing {len(request.questions)} questions...")

        return _answer_questions(
            lambda: parse_pdf_from_file(request.document_path),
            request.questions,
            start_time,
            source_label="local PDF",
        )
    except Exception as e:
        total_time = time.time() - start_time
        print(f"Error after {total_time:.2f} seconds: {str(e)}")
        raise HTTPException(
            status_code=500, detail=f"Internal server error: {str(e)}"
        ) from e
if __name__ == "__main__":
    # Script entry point: serve on all interfaces, on the port named by the
    # PORT environment variable (7860 when it is unset).
    port = int(os.getenv("PORT", 7860))
    uvicorn.run("app:app", host="0.0.0.0", port=port)