brendon-ai committed
Commit 5734a73 · verified · 1 Parent(s): 6806851

Update app.py

Files changed (1)
app.py +71 -19
app.py CHANGED
@@ -1,22 +1,74 @@
  import torch
  from transformers import pipeline

- # Check for GPU
- if torch.cuda.is_available():
-     print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
-     device = 0  # Use GPU
- else:
-     print("CUDA not available, using CPU.")
-     device = -1  # Use CPU
-
- # Load a text generation pipeline
- # For a free tier/small GPU, consider a smaller model like 'distilgpt2' or 'gpt2'
- # For larger GPUs, you can try models like 'meta-llama/Llama-2-7b-hf' (requires auth)
- # or 'mistralai/Mistral-7B-Instruct-v0.2'
- generator = pipeline('text-generation', model='distilgpt2', device=device)  # or specify a larger model
-
- # Generate text
- prompt = "The quick brown fox jumps over the lazy dog because"
- result = generator(prompt, max_length=50, num_return_sequences=1)
-
- print(result[0]['generated_text'])
+ # app.py
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
  import torch
  from transformers import pipeline
+ import os

+ app = FastAPI()
+
+ # --- Model Loading (Global Scope to load once) ---
+ # This part will be executed only once when the FastAPI application starts up.
+ # This saves memory and time compared to loading the model on every request.
+
+ generator = None  # Initialize generator to None
+
+ @app.on_event("startup")
+ async def load_model():
+     """
+     Load the model when the FastAPI application starts.
+     """
+     global generator
+     try:
+         # Check for GPU
+         if torch.cuda.is_available():
+             print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
+             device = 0  # Use GPU
+         else:
+             print("CUDA not available, using CPU.")
+             device = -1  # Use CPU
+
+         # Load a text generation pipeline
+         # For a free tier/small GPU, consider a smaller model like 'distilgpt2' or 'gpt2'
+         # For larger GPUs, you can try models like 'meta-llama/Llama-2-7b-hf' (requires auth)
+         # or 'mistralai/Mistral-7B-Instruct-v0.2'
+         print(f"Loading model 'distilgpt2' on device: {'cuda' if device == 0 else 'cpu'}")
+         generator = pipeline('text-generation', model='distilgpt2', device=device)
+         print("Model loaded successfully!")
+     except Exception as e:
+         print(f"Error loading model: {e}")
+         # You might want to raise an exception or log this more robustly in production.
+         # For a simple app, we'll let it fail and then handle requests later.
+
+ # --- Define Request Body Schema ---
+ class PromptRequest(BaseModel):
+     prompt: str
+     max_length: int = 50  # Default value, can be overridden by user
+     num_return_sequences: int = 1  # Default value
+
+ # --- Define API Endpoint ---
+ @app.post("/generate")
+ async def generate_text(request: PromptRequest):
+     """
+     Generates text based on a given prompt using the loaded LLM.
+     """
+     if generator is None:
+         raise HTTPException(status_code=503, detail="Model not loaded. Please try again later.")
+
+     try:
+         result = generator(
+             request.prompt,
+             max_length=request.max_length,
+             num_return_sequences=request.num_return_sequences
+         )
+         return {"generated_text": result[0]['generated_text']}
+     except Exception as e:
+         raise HTTPException(status_code=500, detail=f"Error during text generation: {e}")
+
+ # --- Basic Health Check Endpoint (Optional but Recommended) ---
+ @app.get("/")
+ async def read_root():
+     """
+     A simple health check endpoint to confirm the API is running.
+     """
+     return {"message": "LLM Inference API is running!"}