|
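"""LabAid AI: a FastAPI service that answers HPLC troubleshooting questions via
Retrieval-Augmented Generation (RAG), using Together AI models and a persistent
Chroma vector store, and serves a React frontend from a local build/ directory."""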
import os

from dotenv import load_dotenv

from contextlib import asynccontextmanager
from typing import List

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from langchain_together import ChatTogether, TogetherEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document

from langchain_community.vectorstores import Chroma
|
|
load_dotenv() |
|
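# Configuration is read from the environment; an illustrative .env might contain:
#   TOGETHER_API_KEY=your-key-here
#   VECTOR_DB_DIR=/tmp/vector_db_chroma   # optional; defaults to this path if unset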
TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
if not TOGETHER_API_KEY:
    raise ValueError("TOGETHER_API_KEY environment variable not set. Please check your .env file.")

VECTOR_DB_DIR = os.getenv("VECTOR_DB_DIR", "/tmp/vector_db_chroma")
COLLECTION_NAME = "my_instrument_manual_chunks"
|
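# Together-hosted model identifiers for chat generation and embeddings.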
LLM_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
EMBEDDINGS_MODEL_NAME = "togethercomputer/m2-bert-80M-32k-retrieval"

# Module-level handles, populated at startup by the lifespan handler.
rag_chain = None
retriever = None
prompt = None
llm = None
|
|
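# Lifespan handler: builds the RAG components once at startup so that every
# request reuses the same LLM, retriever, and chain.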
@asynccontextmanager
async def lifespan(app: FastAPI):
    global rag_chain, retriever, prompt, llm

    print("--- Initializing RAG components ---")
    try:
        llm = ChatTogether(
            model=LLM_MODEL_NAME,
            temperature=0.3,
            api_key=TOGETHER_API_KEY,
        )
        print(f"LLM {LLM_MODEL_NAME} initialized.")

        embeddings = TogetherEmbeddings(
            model=EMBEDDINGS_MODEL_NAME,
            api_key=TOGETHER_API_KEY,
        )
        vectorstore = Chroma(
            persist_directory=VECTOR_DB_DIR,
            collection_name=COLLECTION_NAME,
            embedding_function=embeddings,
        )
        retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
        print("Retriever initialized.")
|
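        # System prompt that steers the model toward HPLC troubleshooting answers
        # grounded in the retrieved context.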
answer_prompt = """ You are a professional HPLC instrument troubleshooting expert who specializes in helping junior researchers and students. |
|
Your task is to answer the user's troubleshooting questions in detail and clearly based on the HPLC instrument knowledge provided below. |
|
If there is no direct answer in the knowledge, please provide the most reasonable speculative suggestions based on your expert judgment, or ask further clarifying questions. |
|
Please ensure that your answers are logically clear, easy to understand, and directly address the user's questions.""" |
|
prompt = ChatPromptTemplate.from_messages([ |
|
("system", answer_prompt), |
|
("user", "context: {context}\n\nquestion: {question}"), |
|
]) |
|
        def format_docs(docs: List[Document]) -> str:
            return "\n\n".join(doc.page_content for doc in docs)
|
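        # LCEL pipeline: retrieve and format context, fill the prompt, call the
        # LLM, and parse the response to a plain string.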
        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        print("RAG chain ready.")
    except Exception as e:
        raise RuntimeError(f"Failed to initialize RAG chain: {e}") from e

    yield
|
|
app = FastAPI(
    title="LabAid AI",
    description="API service for a Retrieval-Augmented Generation (RAG) AI assistant.",
    version="1.0.0",
    lifespan=lifespan,
)
|
|
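# Browser origins allowed to call this API; extend this list for a deployed frontend.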
origins = [
    "http://localhost",
    "http://localhost:3000",
    "http://127.0.0.1:8000",
]
|
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
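# Request/response schemas for the /ask endpoint.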
class QueryRequest(BaseModel):
    query: str


class QueryResponse(BaseModel):
    answer: str
    source_documents: List[str]
|
|
@app.post("/ask", response_model=QueryResponse) |
|
async def ask_rag(request: QueryRequest): |
|
if rag_chain is None or retriever is None: |
|
raise HTTPException(status_code=500, detail="RAG chain not initialized.") |
|
|
|
try: |
|
user_query = request.query |
|
print(f"Received query: {user_query}") |
|
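        # Retrieve explicitly (rather than invoking rag_chain) so the retrieved
        # chunks can be returned to the client alongside the generated answer.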
        retrieved_docs = retriever.invoke(user_query)
        formatted_context = "\n\n".join(doc.page_content for doc in retrieved_docs)
|
        answer = (prompt | llm | StrOutputParser()).invoke({
            "context": formatted_context,
            "question": user_query,
        })

        sources = [doc.page_content for doc in retrieved_docs]
        return QueryResponse(answer=answer, source_documents=sources)
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to process query: {e}") from e
|
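# Example request against the /ask endpoint (assumes the server runs locally
# on port 8000; the query text is illustrative):
#
#   curl -X POST http://localhost:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Why is my HPLC baseline drifting?"}'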
|
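# Serve the compiled React frontend (assumes a build/ directory produced by
# `npm run build` sits next to this file).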
app.mount("/static", StaticFiles(directory="build/static"), name="static")
|
|
@app.get("/") |
|
async def serve_react_app(): |
|
return FileResponse("build/index.html") |
|
|
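# Local development entry point: a minimal sketch that assumes uvicorn is
# installed; the host and port below are illustrative defaults.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)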