File size: 5,266 Bytes
effe6e2 7efd1a1 effe6e2 cc2875a effe6e2 1214e80 effe6e2 ec37931 effe6e2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import os
import sqlite3 # May be used elsewhere
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
from contextlib import asynccontextmanager
from langchain.schema import Document
from langchain_together import ChatTogether, TogetherEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document
import chromadb
from langchain_community.vectorstores import Chroma # โ
# Updated import: Chroma now comes from langchain_community.vectorstores.
# --- 1. Environment & Constants ---
# Load variables from a .env file (if present) into os.environ before any
# configuration is read below.
load_dotenv()

# The Together API key is mandatory: abort at import time when it is missing.
TOGETHER_API_KEY = os.environ.get("TOGETHER_API_KEY")
if not TOGETHER_API_KEY:
    raise ValueError(
        "TOGETHER_API_KEY environment variable not set. Please check your .env file."
    )

# Directory of the persisted Chroma vector DB; overridable via VECTOR_DB_DIR.
VECTOR_DB_DIR = os.environ.get("VECTOR_DB_DIR", "/tmp/vector_db_chroma")

# Vector-store collection name and the Together model identifiers.
COLLECTION_NAME = "my_instrument_manual_chunks"
LLM_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
EMBEDDINGS_MODEL_NAME = "togethercomputer/m2-bert-80M-32k-retrieval"
# --- 2. Lifespan Event Handler ---
# Shared RAG components, populated by ``lifespan`` at startup. Binding them to
# None here lets request handlers test ``is None`` instead of failing with
# NameError if a request is served before startup has completed.
rag_chain = None
retriever = None
prompt = None
llm = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the RAG components once when the application starts.

    Builds the Together chat model and embeddings, opens the persistent Chroma
    collection, and assembles the prompt and RAG chain. The results are stored
    in the module-level ``llm``, ``retriever``, ``prompt`` and ``rag_chain``
    globals. Control is then yielded to FastAPI for the lifetime of the app;
    no shutdown cleanup is performed.

    Args:
        app: The FastAPI application (required by the lifespan protocol; not
            used here).

    Raises:
        RuntimeError: If any component fails to initialize. The original
            exception is chained as ``__cause__`` so its traceback is kept.
    """
    global rag_chain, retriever, prompt, llm
    print("--- Initializing RAG components ---")
    try:
        llm = ChatTogether(
            model=LLM_MODEL_NAME,
            temperature=0.3,
            api_key=TOGETHER_API_KEY,
        )
        print(f"LLM {LLM_MODEL_NAME} initialized.")

        embeddings = TogetherEmbeddings(
            model=EMBEDDINGS_MODEL_NAME,
            api_key=TOGETHER_API_KEY,
        )
        # Pass the persistent client to the vector store so exactly one
        # Chroma client is opened for VECTOR_DB_DIR.
        client = chromadb.PersistentClient(path=VECTOR_DB_DIR)
        vectorstore = Chroma(
            client=client,
            collection_name=COLLECTION_NAME,
            embedding_function=embeddings,
        )
        # Each query returns the 5 most similar chunks.
        retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
        print("Retriever initialized.")

        answer_prompt = """ You are a professional HPLC instrument troubleshooting expert who specializes in helping junior researchers and students.
Your task is to answer the user's troubleshooting questions in detail and clearly based on the HPLC instrument knowledge provided below.
If there is no direct answer in the knowledge, please provide the most reasonable speculative suggestions based on your expert judgment, or ask further clarifying questions.
Please ensure that your answers are logically clear, easy to understand, and directly address the user's questions."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", answer_prompt),
            ("user", "context: {context}\n\nquestion: {question}"),
        ])

        def format_docs(docs: List[Document]) -> str:
            """Join the retrieved chunks' text, separated by blank lines."""
            return "\n\n".join(doc.page_content for doc in docs)

        # Retrieved text fills {context}; the raw chain input becomes {question}.
        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        print("RAG chain ready.")
    except Exception as e:
        raise RuntimeError(f"Failed to initialize RAG chain: {e}") from e
    yield  # Keep app running
# --- 3. Initialize FastAPI App ---
# The ``lifespan`` hook builds the RAG components when the server starts.
app = FastAPI(
    lifespan=lifespan,
    title="LabAid AI",
    version="1.0.0",
    description="API service for a Retrieval-Augmented Generation (RAG) AI assistant.",
)
# --- 4. CORS Middleware ---
# Browser origins permitted to call this API cross-origin (localhost hosts).
origins = [
    "http://localhost",
    "http://localhost:3000",
    "http://127.0.0.1:8000",
]

# Allow any method and header from the origins above, with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# --- 5. Request/Response Models ---
# Body accepted by POST /ask.
class QueryRequest(BaseModel):
    # The user's free-text troubleshooting question.
    query: str


# Body returned by POST /ask.
class QueryResponse(BaseModel):
    # Text generated by the LLM.
    answer: str
    # Text content of each retrieved chunk that was supplied as context.
    source_documents: List[str]
# --- 6. RAG Query Endpoint ---
@app.post("/ask", response_model=QueryResponse)
async def ask_rag(request: QueryRequest):
    """Answer a troubleshooting question using the indexed instrument manual.

    Returns the generated answer plus the text of the retrieved source chunks.
    \f
    Args:
        request: Request body carrying the user's ``query``.

    Returns:
        QueryResponse: ``answer`` from the LLM and ``source_documents`` holding
        the ``page_content`` of each retrieved chunk.

    Raises:
        HTTPException: 400 for a blank query; 500 if the RAG components are
            not initialized or if retrieval/generation fails.
    """
    if rag_chain is None or retriever is None:
        raise HTTPException(status_code=500, detail="RAG chain not initialized.")

    user_query = request.query
    # An empty/whitespace question has nothing to retrieve against.
    if not user_query.strip():
        raise HTTPException(status_code=400, detail="Query must not be empty.")

    print(f"Received query: {user_query}")
    try:
        # Async variants keep the vector-store and LLM network calls from
        # blocking the event loop (and therefore every other request).
        retrieved_docs = await retriever.ainvoke(user_query)
        formatted_context = "\n\n".join(doc.page_content for doc in retrieved_docs)
        answer_chain = prompt | llm | StrOutputParser()
        answer = await answer_chain.ainvoke({
            "context": formatted_context,
            "question": user_query,
        })
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to process query: {e}"
        ) from e

    sources = [doc.page_content for doc in retrieved_docs]
    return QueryResponse(answer=answer, source_documents=sources)
# ✅ --- 7. React static file mounting ---
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
# Serve the React build's static assets (build/static) under the /static URL prefix.
app.mount(
    path="/static",
    app=StaticFiles(directory="build/static"),
    name="static",
)
# Root route: return the React front end's home page (build/index.html).
@app.get("/")
async def serve_react_app():
    # Relative path, resolved against the server process's working directory.
    index_html = "build/index.html"
    return FileResponse(index_html)
|