# LabAid_v1 / api_service.py
# Last commit 1214e80 ("update") by J266501 (verified).
import os
import sqlite3 # May be used elsewhere
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
from contextlib import asynccontextmanager
from langchain.schema import Document
from langchain_together import ChatTogether, TogetherEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document
import chromadb
from langchain_community.vectorstores import Chroma # โœ… Updated import
# --- 1. Environment & Constants ---
# Load variables from a local .env file (if one exists) before reading them.
load_dotenv()

# The Together API key is mandatory: fail fast at import time when it is
# missing or empty.
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
if TOGETHER_API_KEY is None or TOGETHER_API_KEY == "":
    raise ValueError(
        "TOGETHER_API_KEY environment variable not set. Please check your .env file."
    )

# On-disk location of the persisted Chroma database (overridable via env).
VECTOR_DB_DIR = os.environ.get("VECTOR_DB_DIR", "/tmp/vector_db_chroma")
# Chroma collection that holds the instrument-manual chunks.
COLLECTION_NAME = "my_instrument_manual_chunks"
# Together-hosted chat model used to generate answers.
LLM_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
# Together-hosted embedding model used by the vector store.
EMBEDDINGS_MODEL_NAME = "togethercomputer/m2-bert-80M-32k-retrieval"
# --- 2. Lifespan Event Handler ---
# Module-level handles for the RAG components, populated by lifespan().
# They start as None so request handlers can detect an uninitialized
# service and report it, instead of failing with a NameError.
rag_chain = None
retriever = None
prompt = None
llm = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the RAG components once when the application starts.

    Creates the Together chat model and embeddings, opens the persistent
    Chroma collection ``COLLECTION_NAME`` under ``VECTOR_DB_DIR``, wraps it
    as a retriever returning 5 chunks, and assembles the answer prompt and
    ``rag_chain``. Results are stored in the module globals ``llm``,
    ``retriever``, ``prompt`` and ``rag_chain``.

    Args:
        app: The FastAPI application (required by the lifespan protocol;
            not used here).

    Raises:
        RuntimeError: If any component fails to initialize; the original
            exception is chained as ``__cause__``.
    """
    global rag_chain, retriever, prompt, llm
    print("--- Initializing RAG components ---")
    try:
        llm = ChatTogether(
            model=LLM_MODEL_NAME,
            temperature=0.3,
            api_key=TOGETHER_API_KEY,
        )
        print(f"LLM {LLM_MODEL_NAME} initialized.")

        embeddings = TogetherEmbeddings(
            model=EMBEDDINGS_MODEL_NAME,
            api_key=TOGETHER_API_KEY,
        )
        # Open the on-disk store once and hand that same client to the
        # LangChain wrapper; previously this client was created but unused
        # while Chroma opened a second client on the same directory.
        client = chromadb.PersistentClient(path=VECTOR_DB_DIR)
        vectorstore = Chroma(
            client=client,
            persist_directory=VECTOR_DB_DIR,
            collection_name=COLLECTION_NAME,
            embedding_function=embeddings,
        )
        retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
        print("Retriever initialized.")

        # System instructions for the assistant (text identical to before,
        # built by concatenation so it does not depend on source indentation).
        answer_prompt = (
            " You are a professional HPLC instrument troubleshooting expert who specializes in helping junior researchers and students.\n"
            "Your task is to answer the user's troubleshooting questions in detail and clearly based on the HPLC instrument knowledge provided below.\n"
            "If there is no direct answer in the knowledge, please provide the most reasonable speculative suggestions based on your expert judgment, or ask further clarifying questions.\n"
            "Please ensure that your answers are logically clear, easy to understand, and directly address the user's questions."
        )
        prompt = ChatPromptTemplate.from_messages([
            ("system", answer_prompt),
            ("user", "context: {context}\n\nquestion: {question}"),
        ])

        def format_docs(docs: List[Document]) -> str:
            """Join the retrieved chunks' text, separated by blank lines."""
            return "\n\n".join(doc.page_content for doc in docs)

        # Question -> retrieved context + passthrough question -> prompt -> LLM -> text.
        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        print("RAG chain ready.")
    except Exception as e:
        raise RuntimeError(f"Failed to initialize RAG chain: {e}") from e
    yield  # Application serves requests while suspended here.
# --- 3. Initialize FastAPI App ---
# The lifespan hook builds the RAG components when the app starts up.
app = FastAPI(
    lifespan=lifespan,
    title="LabAid AI",
    version="1.0.0",
    description="API service for a Retrieval-Augmented Generation (RAG) AI assistant.",
)
# --- 4. CORS Middleware ---
# Browser origins permitted to call this API cross-origin (local front ends).
origins = [
    "http://localhost",
    "http://localhost:3000",
    "http://127.0.0.1:8000",
]

# Credentials are allowed; any method and header is accepted from the
# origins above.
_cors_options = dict(
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(CORSMiddleware, **_cors_options)
# --- 5. Request/Response Models ---
class QueryRequest(BaseModel):
    # Body of POST /ask: the user's free-text question.
    query: str


class QueryResponse(BaseModel):
    # Body returned by POST /ask.
    # answer: text generated by the LLM.
    # source_documents: raw text of each retrieved chunk used as context.
    answer: str
    source_documents: list[str]
# --- 6. RAG Query Endpoint ---
@app.post("/ask", response_model=QueryResponse)
async def ask_rag(request: QueryRequest):
    """Answer a question using chunks retrieved from the vector store.

    Retrieves chunks for ``request.query``, passes their text as context to
    the answer prompt, and returns the generated answer together with the
    raw text of every retrieved chunk.

    Raises:
        HTTPException: 500 if the RAG components were not initialized, or if
            retrieval/generation fails (the underlying error is chained).
    """
    if rag_chain is None or retriever is None or prompt is None or llm is None:
        raise HTTPException(status_code=500, detail="RAG chain not initialized.")
    try:
        user_query = request.query
        print(f"Received query: {user_query}")
        # Use the async Runnable APIs: synchronous .invoke() inside an
        # async endpoint blocks the event loop (and every other request)
        # for the duration of the retrieval and LLM network calls.
        retrieved_docs = await retriever.ainvoke(user_query)
        formatted_context = "\n\n".join(doc.page_content for doc in retrieved_docs)
        answer_chain = prompt | llm | StrOutputParser()
        answer = await answer_chain.ainvoke({
            "context": formatted_context,
            "question": user_query,
        })
        sources = [doc.page_content for doc in retrieved_docs]
        return QueryResponse(answer=answer, source_documents=sources)
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to process query: {e}"
        ) from e
# --- 7. Serve the React static build ---
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles

# Mount the compiled React assets from build/static under the /static URL path.
app.mount("/static", StaticFiles(directory="build/static"), name="static")


@app.get("/")
async def serve_react_app():
    """Return the React app's entry page (build/index.html) for the root URL."""
    index_page = "build/index.html"
    return FileResponse(index_page)