brendon-ai committed
Commit d75dc74 · verified · 1 Parent(s): de43d85

Update app.py

Files changed (1)
  1. app.py +118 -157
app.py CHANGED
@@ -2,196 +2,157 @@
  from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
  from typing import Optional
- import torch
  import uvicorn
- from transformers import pipeline
  import os
- from contextlib import asynccontextmanager # Import this!
- import sys # Import sys for sys.exit()

- # Optional: For gated models like Llama 3 from Meta, uncomment and configure HF_TOKEN
- # from huggingface_hub import login

- # --- Global variable to store the pipeline ---
- generator = None
- # Choose a model appropriate for free tier (e.g., 7B-8B parameters)
- # For DeepSeek, DeepSeek-V2-Lite-Base (7B) might be loadable, but DeepSeek-V3 is too big.
- MODEL_NAME = "brendon-ai/gemma3-dolly-finetuned"

- # "openai-community/gpt2" # Recommended for free tier

- # --- Lifespan Event Handler ---
- @asynccontextmanager
- async def lifespan(app: FastAPI):
-     """
-     Handles startup and shutdown events for the FastAPI application.
-     Loads the model on startup and can optionally clean up on shutdown.
-     """
-     global generator
-     try:
-         # --- Optional: Login to Hugging Face Hub for gated models ---
-         # If you are using a gated model (e.g., meta-llama/Llama-3-8B-Instruct),
-         # uncomment the following lines and ensure HF_TOKEN is set as a Space Secret.
-         # hf_token = os.getenv("HF_TOKEN")
-         # if hf_token:
-         # login(token=hf_token)
-         # print("Logged into Hugging Face Hub.")
-         # else:
-         # print("HF_TOKEN not found. Make sure it's set as a Space Secret if using a gated model.")

-         # --- Startup Code: Load the model ---
-         if torch.cuda.is_available():
-             print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
-             device = 0 # Use GPU
-             # For larger models, use device_map="auto" and torch_dtype
-             # device_map = "auto"
-             # torch_dtype = torch.bfloat16 # or torch.float16 for GPUs that support it
-         else:
-             print("CUDA not available, using CPU. Inference will be very slow for this model size.")
-             device = -1 # Use CPU
-             # device_map = None
-             # torch_dtype = torch.float32 # Default for CPU

-         print(f"Attempting to load model '{MODEL_NAME}' on device: {'cuda' if device == 0 else 'cpu'}")
-
-         # The pipeline automatically handles AutoModel and AutoTokenizer.
-         # For better memory management with larger models, directly load with model_kwargs:
-         generator = pipeline(
-             'text-generation',
-             model=MODEL_NAME,
-             device=device,
-             # Pass your HF token to the model loading for gated models
-             # token=os.getenv("HF_TOKEN"), # Uncomment if using a gated model
-             # For 7B models on 16GB GPU, float16 is usually enough, but bfloat16 is better if supported
-             torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
-             # For more fine-grained control and auto device mapping for multiple GPUs:
-             # model_kwargs={"device_map": "auto", "torch_dtype": torch.float16}
-         )
-         print("Model loaded successfully!")

-         # 'yield' signifies that the startup code has completed, and the application
-         # can now start processing requests.
-         yield

      except Exception as e:
-         print(f"CRITICAL ERROR: Failed to load model during startup: {e}")
-         # Exit with a non-zero code to indicate failure if model loading fails
-         sys.exit(1)
-
-     finally:
-         # --- Shutdown Code (Optional): Clean up resources ---
-         print("Application shutting down. Any cleanup can go here.")
-
-
- # --- Initialize FastAPI application with the lifespan handler ---
- app = FastAPI(lifespan=lifespan, # Use the lifespan context manager
-     title="Text Generation API",
-     description="A simple text generation API using Hugging Face transformers",
-     version="1.0.0"
- )
-
- # Request model
- class TextGenerationRequest(BaseModel):
-     prompt: str
-     max_new_tokens: Optional[int] = 250 # Changed from max_length for better control
-     num_return_sequences: Optional[int] = 1
-     temperature: Optional[float] = 0.7 # Recommend lower temp for more coherent output
-     do_sample: Optional[bool] = True
-     top_p: Optional[float] = 0.9 # Added top_p for more control
-
- # Response model
- class TextGenerationResponse(BaseModel):
-     generated_text: str
-     prompt: str
-     model_name: str

  @app.get("/")
  async def root():
      return {
-         "message": "Text Generation API",
-         "status": "running",
          "endpoints": {
-             "generate_post": "/generate", # Renamed for clarity
-             "generate_get": "/generate_simple", # Renamed for clarity
-             "health": "/health",
-             "docs": "/docs"
-         },
-         "current_model": MODEL_NAME
      }

  @app.get("/health")
  async def health_check():
      return {
-         "status": "healthy" if generator else "unhealthy",
-         "model_loaded": generator is not None,
-         "cuda_available": torch.cuda.is_available(),
-         "model_name": MODEL_NAME
      }

- @app.post("/generate", response_model=TextGenerationResponse)
- async def generate_text_post(request: TextGenerationRequest):
-     if generator is None:
-         raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
-
      try:
-         # Generate text
-         result = generator(
-             request.prompt,
-             max_new_tokens=request.max_new_tokens, # Use max_new_tokens
-             num_return_sequences=request.num_return_sequences,
-             temperature=request.temperature,
-             do_sample=request.do_sample,
-             top_p=request.top_p, # Pass top_p
-             pad_token_id=generator.tokenizer.eos_token_id,
-             eos_token_id=generator.tokenizer.eos_token_id,
-             # Add stop sequences relevant to your instruction-tuned model format
-             # stop_sequences=["\nUser:", "\n###", "\n\n"]
-         )

-         generated_text = result[0]['generated_text']

-         return TextGenerationResponse(
-             generated_text=generated_text,
-             prompt=request.prompt,
-             model_name=MODEL_NAME
          )
-
      except Exception as e:
-         print(f"Generation failed: {str(e)}")
-         raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check Space logs for details.")

- @app.get("/generate_simple") # Changed endpoint name to avoid conflict with POST
- async def generate_text_get(
-     prompt: str,
-     max_new_tokens: int = 250, # Changed from max_length
-     temperature: float = 0.7
- ):
-     """GET endpoint for simple text generation"""
-     if generator is None:
-         raise HTTPException(status_code=503, detail="Model not loaded yet. Service unavailable.")
-
      try:
-         result = generator(
-             prompt,
-             max_new_tokens=max_new_tokens,
-             num_return_sequences=1,
-             temperature=temperature,
-             do_sample=True,
-             top_p=0.9, # Default top_p for simple GET
-             pad_token_id=generator.tokenizer.eos_token_id,
-             eos_token_id=generator.tokenizer.eos_token_id,
-         )

          return {
-             "generated_text": result[0]['generated_text'],
-             "prompt": prompt,
-             "model_name": MODEL_NAME
          }
-
      except Exception as e:
-         print(f"Generation failed: {str(e)}")
-         raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}. Check Space logs for details.")

  if __name__ == "__main__":
-     port = int(os.environ.get("PORT", 7860)) # Hugging Face Spaces uses port 7860
-     uvicorn.run(app, host="0.0.0.0", port=port)
  from fastapi import FastAPI, HTTPException
  from pydantic import BaseModel
  from typing import Optional
  import uvicorn
+ from src.RAGSample import setup_retriever, setup_rag_chain, RAGApplication
  import os
+ from dotenv import load_dotenv

+ # Load environment variables
+ load_dotenv()

+ # Create FastAPI app
+ app = FastAPI(
+     title="RAG API",
+     description="A REST API for Retrieval-Augmented Generation using local vector database",
+     version="1.0.0"
+ )

+ # Initialize RAG components (this will be done once when the server starts)
+ retriever = None
+ rag_chain = None
+ rag_application = None

+ # Pydantic model for request
+ class QuestionRequest(BaseModel):
+     question: str

+ # Pydantic model for response
+ class QuestionResponse(BaseModel):
+     question: str
+     answer: str

+ @app.on_event("startup")
+ async def startup_event():
+     """Initialize RAG components when the server starts."""
+     global retriever, rag_chain, rag_application
+     try:
+         print("Initializing RAG components...")

+         # Check if Kaggle credentials are provided via environment variables
+         kaggle_username = os.getenv("KAGGLE_USERNAME")
+         kaggle_key = os.getenv("KAGGLE_KEY")
+         kaggle_dataset = os.getenv("KAGGLE_DATASET")

+         # If no environment variables, try to load from kaggle.json
+         if not (kaggle_username and kaggle_key):
+             try:
+                 from src.kaggle_loader import KaggleDataLoader
+                 # Test if we can create a loader (this will auto-load from kaggle.json)
+                 test_loader = KaggleDataLoader()
+                 if test_loader.kaggle_username and test_loader.kaggle_key:
+                     kaggle_username = test_loader.kaggle_username
+                     kaggle_key = test_loader.kaggle_key
+                     print(f"Loaded Kaggle credentials from kaggle.json: {kaggle_username}")
+             except Exception as e:
+                 print(f"Could not load Kaggle credentials from kaggle.json: {e}")
+
+         if kaggle_username and kaggle_key and kaggle_dataset:
+             print(f"Loading Kaggle dataset: {kaggle_dataset}")
+             retriever = setup_retriever(
+                 use_kaggle_data=True,
+                 kaggle_dataset=kaggle_dataset,
+                 kaggle_username=kaggle_username,
+                 kaggle_key=kaggle_key
+             )
+         else:
+             print("Loading mental health FAQ data from local file...")
+             # Load mental health FAQ data from local file (default behavior)
+             retriever = setup_retriever()
+
+         rag_chain = setup_rag_chain()
+         rag_application = RAGApplication(retriever, rag_chain)
+         print("RAG components initialized successfully!")
      except Exception as e:
+         print(f"Error initializing RAG components: {e}")
+         raise
  @app.get("/")
  async def root():
+     """Root endpoint with API information."""
      return {
+         "message": "RAG API is running",
          "endpoints": {
+             "ask_question": "/ask",
+             "health_check": "/health",
+             "load_kaggle_dataset": "/load-kaggle-dataset"
+         }
      }

  @app.get("/health")
  async def health_check():
+     """Health check endpoint."""
      return {
+         "status": "healthy",
+         "rag_initialized": rag_application is not None
      }

+ @app.post("/ask", response_model=QuestionResponse)
+ async def ask_question(request: QuestionRequest):
+     """Ask a question and get an answer using RAG."""
+     if rag_application is None:
+         raise HTTPException(status_code=500, detail="RAG application not initialized")
+
      try:
+         print(f"Processing question: {request.question}")
+
+         # Debug: Check what retriever we're using
+         retriever_type = type(rag_application.retriever).__name__
+         print(f"DEBUG: Using retriever type: {retriever_type}")

+         answer = rag_application.run(request.question)

+         return QuestionResponse(
+             question=request.question,
+             answer=answer
          )
      except Exception as e:
+         print(f"Error processing question: {e}")
+         raise HTTPException(status_code=500, detail=f"Error processing question: {str(e)}")

+ @app.post("/load-kaggle-dataset")
+ async def load_kaggle_dataset(dataset_name: str):
+     """Load a Kaggle dataset for RAG."""
      try:
+         from src.kaggle_loader import KaggleDataLoader
+
+         # Create loader without parameters - it will auto-load from kaggle.json
+         loader = KaggleDataLoader()
+
+         # Download the dataset
+         dataset_path = loader.download_dataset(dataset_name)
+
+         # Reload the retriever with the new dataset
+         global rag_application
+         retriever = setup_retriever(use_kaggle_data=True, kaggle_dataset=dataset_name)
+         rag_chain = setup_rag_chain()
+         rag_application = RAGApplication(retriever, rag_chain)

          return {
+             "status": "success",
+             "message": f"Dataset {dataset_name} loaded successfully",
+             "dataset_path": dataset_path
          }
      except Exception as e:
+         return {"status": "error", "message": str(e)}
+
+ @app.get("/models")
+ async def get_models():
+     """Get information about available models."""
+     return {
+         "llm_model": "dolphin-llama3:8b",
+         "embedding_model": "TF-IDF embeddings",
+         "vector_database": "ChromaDB (local)"
+     }

  if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
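
A minimal client sketch (not part of the commit) for exercising the rewritten endpoints with the requests library, assuming the app is started locally on port 8000 as in the __main__ block; the example question and Kaggle dataset slug are placeholders:

import requests

BASE_URL = "http://localhost:8000"  # assumption: app started locally with "python app.py"

# Confirm the RAG components finished initializing at startup
print(requests.get(f"{BASE_URL}/health").json())
# e.g. {'status': 'healthy', 'rag_initialized': True}

# Ask a question (QuestionRequest in, QuestionResponse out)
resp = requests.post(
    f"{BASE_URL}/ask",
    json={"question": "What are common symptoms of anxiety?"},  # placeholder question
)
resp.raise_for_status()
print(resp.json()["answer"])

# Optionally point the retriever at another Kaggle dataset; dataset_name is a
# query parameter, and the slug below is a placeholder that needs valid credentials
requests.post(
    f"{BASE_URL}/load-kaggle-dataset",
    params={"dataset_name": "owner/some-dataset"},
)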