brendon-ai committed
Commit dcbac7e · verified · 1 Parent(s): 5ccc276

Update app.py

Files changed (1): app.py +19 -88
app.py CHANGED
@@ -1,91 +1,22 @@
-# app.py
-from fastapi import FastAPI, HTTPException
-from pydantic import BaseModel
 import torch
 from transformers import pipeline
-import os
-from contextlib import asynccontextmanager # Import this!
 
-# --- Global variable for the model ---
-# It's important to declare this globally so it can be accessed within
-# the lifespan function and the API endpoint functions.
-generator = None
-
-# --- Lifespan Event Handler ---
-@asynccontextmanager
-async def lifespan(app: FastAPI):
-    """
-    Handles startup and shutdown events for the FastAPI application.
-    Loads the model on startup and can optionally clean up on shutdown.
-    """
-    global generator # Declare intent to modify the global 'generator' variable
-    try:
-        # --- Startup Code: Load the model ---
-        # This code runs BEFORE the application starts receiving requests.
-        if torch.cuda.is_available():
-            print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
-            device = 0 # Use GPU
-        else:
-            print("CUDA not available, using CPU.")
-            device = -1 # Use CPU
-
-        print(f"Loading model 'distilgpt2' on device: {'cuda' if device == 0 else 'cpu'}")
-        generator = pipeline('text-generation', model='distilgpt2', device=device)
-        print("Model loaded successfully!")
-
-        # 'yield' signifies that the startup code has completed, and the application
-        # can now start processing requests.
-        yield
-
-    except Exception as e:
-        print(f"Error loading model during startup: {e}")
-        # In a real application, you might want to exit here if the model is crucial:
-        # sys.exit(1) or raise an exception to prevent the app from starting unhealthy.
-
-    finally:
-        # --- Shutdown Code (Optional): Clean up resources ---
-        # This code runs AFTER the application has finished handling requests and is shutting down.
-        # For a simple model loaded like this, there might not be explicit cleanup needed.
-        # If you had database connections, external client sessions, etc., you'd close them here.
-        print("Application shutting down. Any cleanup can go here.")
-
-
-# --- Initialize FastAPI application with the lifespan handler ---
-app = FastAPI(lifespan=lifespan) # Pass the lifespan function to the FastAPI app
-
-
-# --- Define Request Body Schema ---
-class PromptRequest(BaseModel):
-    prompt: str
-    max_length: int = 50 # Default value, can be overridden by user
-    num_return_sequences: int = 1 # Default value
-
-# --- Define API Endpoint ---
-@app.post("/generate")
-async def generate_text(request: PromptRequest):
-    """
-    Generates text based on a given prompt using the loaded LLM.
-    """
-    if generator is None:
-        # This indicates a failure during startup, or the app started unhealthy
-        raise HTTPException(status_code=503, detail="Model not loaded. Service unavailable.")
-
-    try:
-        result = generator(
-            request.prompt,
-            max_length=request.max_length,
-            num_return_sequences=request.num_return_sequences
-        )
-        return {"generated_text": result[0]['generated_text']}
-    except Exception as e:
-        # Log the full exception for debugging in production
-        print(f"Error during text generation: {e}")
-        raise HTTPException(status_code=500, detail=f"Error during text generation: {e}")
-
-# --- Basic Health Check Endpoint ---
-@app.get("/")
-async def read_root():
-    """
-    A simple health check endpoint to confirm the API is running.
-    """
-    return {"message": "LLM Inference API is running!"}
+# Check for GPU
+if torch.cuda.is_available():
+    print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
+    device = 0 # Use GPU
+else:
+    print("CUDA not available, using CPU.")
+    device = -1 # Use CPU
+
+# Load a text generation pipeline
+# For a free tier/small GPU, consider a smaller model like 'distilgpt2' or 'gpt2'.
+# For larger GPUs, you can try models like 'meta-llama/Llama-2-7b-hf' (requires auth)
+# or 'mistralai/Mistral-7B-Instruct-v0.2'.
+generator = pipeline('text-generation', model='distilgpt2', device=device) # or specify a larger model
+
+# Generate text
+prompt = "The quick brown fox jumps over the lazy dog because"
+result = generator(prompt, max_length=50, num_return_sequences=1)
+
+print(result[0]['generated_text'])
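
One caveat on the new script: in the transformers text-generation pipeline, max_length counts the prompt tokens as well as the generated ones, so a long prompt leaves little room for new text. If the goal is a fixed amount of freshly generated text, max_new_tokens is the more direct knob. A minimal sketch of the adjusted call (keeping everything else from the script above):

# Sketch: cap only the newly generated tokens, independent of prompt length.
result = generator(prompt, max_new_tokens=50, num_return_sequences=1)
print(result[0]['generated_text'])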
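For anyone who was calling the removed FastAPI service: the deleted /generate endpoint accepted a JSON body matching PromptRequest (prompt, max_length, num_return_sequences). Below is a minimal client sketch, assuming the app was served with `uvicorn app:app` on uvicorn's default localhost:8000 — both are assumptions, since the commit does not show how the app was launched:

# Hypothetical client for the removed /generate endpoint.
# Assumes the service was running at http://localhost:8000 (not shown in this commit).
import requests

payload = {
    "prompt": "The quick brown fox jumps over the lazy dog because",
    "max_length": 50,           # PromptRequest default
    "num_return_sequences": 1,  # PromptRequest default
}
resp = requests.post("http://localhost:8000/generate", json=payload)
resp.raise_for_status()
print(resp.json()["generated_text"])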