"""LabAid AI: FastAPI service exposing a Retrieval-Augmented Generation assistant.

On startup the app builds an LLM client, an embeddings client, and a Chroma
retriever, then serves:

* ``POST /ask``  - answer a question using retrieved manual chunks.
* ``GET /``      - the React single-page app's ``index.html``.
* ``/static``    - the React build's static assets.
"""

import os
import sqlite3  # May be used elsewhere
from contextlib import asynccontextmanager
from typing import List

import chromadb
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from langchain.schema import Document
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_together import ChatTogether, TogetherEmbeddings
from pydantic import BaseModel

# --- 1. Environment & Constants ---
load_dotenv()

TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
if not TOGETHER_API_KEY:
    raise ValueError(
        "TOGETHER_API_KEY environment variable not set. Please check your .env file."
    )

VECTOR_DB_DIR = os.getenv("VECTOR_DB_DIR", "/tmp/vector_db_chroma")
COLLECTION_NAME = "my_instrument_manual_chunks"
LLM_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
EMBEDDINGS_MODEL_NAME = "togethercomputer/m2-bert-80M-32k-retrieval"

# System prompt sent with every question.
ANSWER_SYSTEM_PROMPT = """
You are a professional HPLC instrument troubleshooting expert who specializes in helping junior researchers and students.
Your task is to answer the user's troubleshooting questions in detail and clearly based on the HPLC instrument knowledge provided below. If there is no direct answer in the knowledge, please provide the most reasonable speculative suggestions based on your expert judgment, or ask further clarifying questions. Please ensure that your answers are logically clear, easy to understand, and directly address the user's questions."""

# RAG components populated by `lifespan` at startup. They are defined here so
# that checking them before startup yields None rather than a NameError.
rag_chain = None
answer_chain = None
retriever = None
prompt = None
llm = None


def _format_docs(docs: List[Document]) -> str:
    """Join the page content of *docs* into one context string."""
    return "\n\n".join(doc.page_content for doc in docs)


# --- 2. Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Build the RAG components before the app starts serving requests.

    Populates the module-level ``llm``, ``retriever``, ``prompt``,
    ``answer_chain`` and ``rag_chain``.

    Raises:
        RuntimeError: if any component fails to initialize; the original
            exception is attached as ``__cause__``.
    """
    global rag_chain, answer_chain, retriever, prompt, llm

    print("--- Initializing RAG components ---")
    try:
        llm = ChatTogether(
            model=LLM_MODEL_NAME,
            temperature=0.3,
            api_key=TOGETHER_API_KEY,
        )
        print(f"LLM {LLM_MODEL_NAME} initialized.")

        embeddings = TogetherEmbeddings(
            model=EMBEDDINGS_MODEL_NAME,
            api_key=TOGETHER_API_KEY,
        )

        # Hand the persistent client to the vector store so that only one
        # Chroma client is opened on VECTOR_DB_DIR.
        client = chromadb.PersistentClient(path=VECTOR_DB_DIR)
        vectorstore = Chroma(
            client=client,
            collection_name=COLLECTION_NAME,
            embedding_function=embeddings,
        )
        retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
        print("Retriever initialized.")

        prompt = ChatPromptTemplate.from_messages([
            ("system", ANSWER_SYSTEM_PROMPT),
            ("user", "context: {context}\n\nquestion: {question}"),
        ])

        # Prompt -> LLM -> plain string; built once and reused per request.
        answer_chain = prompt | llm | StrOutputParser()

        # Full question-in / answer-out chain (retrieval included).
        rag_chain = (
            {"context": retriever | _format_docs, "question": RunnablePassthrough()}
            | answer_chain
        )
        print("RAG chain ready.")
    except Exception as e:
        raise RuntimeError(f"Failed to initialize RAG chain: {e}") from e

    yield  # Keep app running


# --- 3. Initialize FastAPI App ---
app = FastAPI(
    title="LabAid AI",
    description="API service for a Retrieval-Augmented Generation (RAG) AI assistant.",
    version="1.0.0",
    lifespan=lifespan,
)

# --- 4. CORS Middleware ---
origins = [
    "http://localhost",
    "http://localhost:3000",
    "http://127.0.0.1:8000",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# --- 5. Request/Response Models ---
class QueryRequest(BaseModel):
    """Body of a ``POST /ask`` request."""

    query: str


class QueryResponse(BaseModel):
    """Answer text plus the page content of each retrieved chunk."""

    answer: str
    source_documents: List[str]


# --- 6. RAG Query Endpoint ---
@app.post("/ask", response_model=QueryResponse)
async def ask_rag(request: QueryRequest):
    """Answer ``request.query`` using retrieved context.

    Retrieval is done separately from generation so the retrieved chunks can
    be returned alongside the answer.

    Raises:
        HTTPException: 500 if the RAG components are not initialized or if
            retrieval/generation fails.
    """
    if rag_chain is None or answer_chain is None or retriever is None:
        raise HTTPException(status_code=500, detail="RAG chain not initialized.")

    try:
        user_query = request.query
        print(f"Received query: {user_query}")

        # Use the async variants so the event loop is not blocked while
        # waiting on the vector store and the LLM.
        retrieved_docs = await retriever.ainvoke(user_query)
        answer = await answer_chain.ainvoke({
            "context": _format_docs(retrieved_docs),
            "question": user_query,
        })

        sources = [doc.page_content for doc in retrieved_docs]
        return QueryResponse(answer=answer, source_documents=sources)
    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(
            status_code=500, detail=f"Failed to process query: {e}"
        ) from e


# --- 7. React static file serving ---
# Serve the React build's static assets.
app.mount("/static", StaticFiles(directory="build/static"), name="static")


# The root path returns the React app's index page.
@app.get("/")
async def serve_react_app():
    """Return the React build's ``index.html``."""
    return FileResponse("build/index.html")