import os
import sqlite3  # May be used elsewhere
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List
from contextlib import asynccontextmanager

from langchain_together import ChatTogether, TogetherEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_core.documents import Document

import chromadb
from langchain_community.vectorstores import Chroma  # ✅ Updated import

# --- 1. Environment & Constants ---
load_dotenv()

TOGETHER_API_KEY = os.getenv("TOGETHER_API_KEY")
if not TOGETHER_API_KEY:
    raise ValueError("TOGETHER_API_KEY environment variable not set. Please check your .env file.")

VECTOR_DB_DIR = os.getenv("VECTOR_DB_DIR", "/tmp/vector_db_chroma")
COLLECTION_NAME = "my_instrument_manual_chunks"

LLM_MODEL_NAME = "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo"
EMBEDDINGS_MODEL_NAME = "togethercomputer/m2-bert-80M-32k-retrieval"
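
# A .env file for this service might look like the following sketch (the key is a
# placeholder; VECTOR_DB_DIR is optional and falls back to /tmp/vector_db_chroma):
#
#   TOGETHER_API_KEY=your_together_api_key_here
#   VECTOR_DB_DIR=./vector_db_chroma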

# Globals populated during startup in the lifespan handler below; they default
# to None so the /ask endpoint can fail cleanly if initialization did not run.
rag_chain = None
retriever = None
prompt = None
llm = None

# --- 2. Lifespan Event Handler ---
@asynccontextmanager
async def lifespan(app: FastAPI):
    global rag_chain, retriever, prompt, llm

    print("--- Initializing RAG components ---")
    try:
        llm = ChatTogether(
            model=LLM_MODEL_NAME,
            temperature=0.3,
            api_key=TOGETHER_API_KEY
        )
        print(f"LLM {LLM_MODEL_NAME} initialized.")

        embeddings = TogetherEmbeddings(
            model=EMBEDDINGS_MODEL_NAME,
            api_key=TOGETHER_API_KEY
        )
        # Open the persisted Chroma store explicitly and hand the client to the
        # LangChain wrapper so the existing on-disk collection is reused.
        client = chromadb.PersistentClient(path=VECTOR_DB_DIR)
        vectorstore = Chroma(
            client=client,
            collection_name=COLLECTION_NAME,
            embedding_function=embeddings
        )
        retriever = vectorstore.as_retriever(search_kwargs={"k": 5})
        print("Retriever initialized.")

        answer_prompt = """ You are a professional HPLC instrument troubleshooting expert who specializes in helping junior researchers and students.
                            Your task is to answer the user's troubleshooting questions in detail and clearly based on the HPLC instrument knowledge provided below.
                            If there is no direct answer in the knowledge, please provide the most reasonable speculative suggestions based on your expert judgment, or ask further clarifying questions.
                            Please ensure that your answers are logically clear, easy to understand, and directly address the user's questions."""
        prompt = ChatPromptTemplate.from_messages([
            ("system", answer_prompt),
            ("user", "context: {context}\n\nquestion: {question}"),
        ])

        def format_docs(docs: List[Document]) -> str:
            return "\n\n".join(doc.page_content for doc in docs)

        rag_chain = (
            {"context": retriever | format_docs, "question": RunnablePassthrough()}
            | prompt
            | llm
            | StrOutputParser()
        )
        print("RAG chain ready.")

    except Exception as e:
        raise RuntimeError(f"Failed to initialize RAG chain: {e}")

    yield  # Startup complete; the app serves requests until shutdown
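
# The retriever above assumes the Chroma collection has already been populated by
# a separate ingestion step. The helper below is a minimal sketch of such a step;
# it is never called by this service, and the manual path is a hypothetical
# placeholder.
def build_vector_store(manual_path: str = "data/hplc_manual.txt") -> None:
    """Split a plain-text manual into chunks and persist them into the same
    Chroma collection that the lifespan handler reads from (sketch only)."""
    from langchain_community.document_loaders import TextLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter

    docs = TextLoader(manual_path, encoding="utf-8").load()
    chunks = RecursiveCharacterTextSplitter(
        chunk_size=1000, chunk_overlap=150
    ).split_documents(docs)

    embeddings = TogetherEmbeddings(model=EMBEDDINGS_MODEL_NAME, api_key=TOGETHER_API_KEY)
    Chroma.from_documents(
        documents=chunks,
        embedding=embeddings,
        collection_name=COLLECTION_NAME,
        persist_directory=VECTOR_DB_DIR,
    )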

# --- 3. Initialize FastAPI App ---
app = FastAPI(
    title="LabAid AI",
    description="API service for a Retrieval-Augmented Generation (RAG) AI assistant.",
    version="1.0.0",
    lifespan=lifespan  # ✅ Updated lifespan hook
)

# --- 4. CORS Middleware ---
origins = [
    "http://localhost",
    "http://localhost:3000",
    "http://127.0.0.1:8000",
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
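
# Note: once the React build is served by this same app (section 7 below),
# browser requests from "/" are same-origin; the origins listed above mainly
# matter when the React dev server runs separately on localhost:3000.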

# --- 5. Request/Response Models ---
class QueryRequest(BaseModel):
    query: str

class QueryResponse(BaseModel):
    answer: str
    source_documents: List[str]

# --- 6. RAG Query Endpoint ---
@app.post("/ask", response_model=QueryResponse)
async def ask_rag(request: QueryRequest):
    if rag_chain is None or retriever is None:
        raise HTTPException(status_code=500, detail="RAG chain not initialized.")

    try:
        user_query = request.query
        print(f"Received query: {user_query}")

        # Retrieve documents explicitly (rather than invoking rag_chain end-to-end)
        # so the retrieved chunks can also be returned as source_documents.
        retrieved_docs = retriever.invoke(user_query)
        formatted_context = "\n\n".join(doc.page_content for doc in retrieved_docs)

        answer = (prompt | llm | StrOutputParser()).invoke({
            "context": formatted_context,
            "question": user_query
        })

        sources = [doc.page_content for doc in retrieved_docs]
        return QueryResponse(answer=answer, source_documents=sources)

    except Exception as e:
        print(f"Error: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to process query: {e}")
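
# Example request against this endpoint (sketch; the question and the answer
# shown are placeholders):
#
#   curl -X POST http://127.0.0.1:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"query": "Why is my HPLC baseline drifting?"}'
#
# Response shape (QueryResponse):
#   {"answer": "...", "source_documents": ["...", "..."]}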
    
# ✅ --- 7. Serve the React static build ---
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse

# Mount the static assets from the React build (build/static)
app.mount("/static", StaticFiles(directory="build/static"), name="static")

# Serve the React index page at the root path
@app.get("/")
async def serve_react_app():
    return FileResponse("build/index.html")
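
# Local run sketch (assumes this module is saved as main.py and that the React
# frontend has been built into ./build, e.g. with `npm run build`):
#
#   uvicorn main:app --host 127.0.0.1 --port 8000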