brendon-ai committed
Commit 51a0302 · verified · 1 Parent(s): dcbac7e

Update app.py

Files changed (1)
  1. app.py +130 -19
app.py CHANGED
@@ -1,22 +1,133 @@
 import torch
 from transformers import pipeline
-
-# Check for GPU
-if torch.cuda.is_available():
-    print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
-    device = 0  # Use GPU
-else:
-    print("CUDA not available, using CPU.")
-    device = -1  # Use CPU
-
-# Load a text generation pipeline
-# For a free tier/small GPU, consider a smaller model like 'distilgpt2' or 'gpt2'
-# For larger GPUs, you can try models like 'meta-llama/Llama-2-7b-hf' (requires auth)
-# or 'mistralai/Mistral-7B-Instruct-v0.2'
-generator = pipeline('text-generation', model='distilgpt2', device=device)  # or specify a larger model
-
-# Generate text
-prompt = "The quick brown fox jumps over the lazy dog because"
-result = generator(prompt, max_length=50, num_return_sequences=1)
-
-print(result[0]['generated_text'])
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+from typing import Optional
+import uvicorn
+import os
+
+# Initialize FastAPI app
+app = FastAPI(
+    title="Text Generation API",
+    description="A simple text generation API using Hugging Face transformers",
+    version="1.0.0"
+)
+
+# Request model
+class TextGenerationRequest(BaseModel):
+    prompt: str
+    max_length: Optional[int] = 50
+    num_return_sequences: Optional[int] = 1
+    temperature: Optional[float] = 1.0
+    do_sample: Optional[bool] = True
+
+# Response model
+class TextGenerationResponse(BaseModel):
+    generated_text: str
+    prompt: str
+
+# Global variable to store the pipeline
+generator = None
+
+@app.on_event("startup")
+async def load_model():
+    global generator
+
+    # Check for GPU
+    if torch.cuda.is_available():
+        print(f"CUDA is available! Using {torch.cuda.get_device_name(0)}")
+        device = 0  # Use GPU
+    else:
+        print("CUDA not available, using CPU.")
+        device = -1  # Use CPU
+
+    # Load the text generation pipeline
+    try:
+        generator = pipeline(
+            'text-generation',
+            model='distilgpt2',
+            device=device,
+            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
+        )
+        print("Model loaded successfully!")
+    except Exception as e:
+        print(f"Error loading model: {e}")
+        raise e
+
+@app.get("/")
+async def root():
+    return {
+        "message": "Text Generation API",
+        "status": "running",
+        "endpoints": {
+            "generate": "/generate",
+            "health": "/health",
+            "docs": "/docs"
+        }
+    }
+
+@app.get("/health")
+async def health_check():
+    return {
+        "status": "healthy",
+        "model_loaded": generator is not None,
+        "cuda_available": torch.cuda.is_available()
+    }
+
+@app.post("/generate", response_model=TextGenerationResponse)
+async def generate_text(request: TextGenerationRequest):
+    if generator is None:
+        raise HTTPException(status_code=503, detail="Model not loaded yet")
+
+    try:
+        # Generate text
+        result = generator(
+            request.prompt,
+            max_length=min(request.max_length, 200),  # Limit max length for safety
+            num_return_sequences=request.num_return_sequences,
+            temperature=request.temperature,
+            do_sample=request.do_sample,
+            pad_token_id=generator.tokenizer.eos_token_id
+        )
+
+        generated_text = result[0]['generated_text']
+
+        return TextGenerationResponse(
+            generated_text=generated_text,
+            prompt=request.prompt
+        )
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
+
+@app.get("/generate")
+async def generate_text_get(
+    prompt: str,
+    max_length: int = 50,
+    temperature: float = 1.0
+):
+    """GET endpoint for simple text generation"""
+    if generator is None:
+        raise HTTPException(status_code=503, detail="Model not loaded yet")
+
+    try:
+        result = generator(
+            prompt,
+            max_length=min(max_length, 200),
+            num_return_sequences=1,
+            temperature=temperature,
+            do_sample=True,
+            pad_token_id=generator.tokenizer.eos_token_id
+        )
+
+        return {
+            "generated_text": result[0]['generated_text'],
+            "prompt": prompt
+        }
+
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}")
+
+if __name__ == "__main__":
+    port = int(os.environ.get("PORT", 7860))  # Hugging Face Spaces uses port 7860
+    uvicorn.run(app, host="0.0.0.0", port=port)
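
Usage sketch (an editor-added example, not part of the commit): once the server from this commit is running, the endpoints above can be exercised with the standard requests library. BASE_URL is an assumption for illustration; http://localhost:7860 matches the default port in app.py, and a deployed Space would use its public URL instead.

import requests

BASE_URL = "http://localhost:7860"  # assumed local run; replace with your Space URL

# /health reports whether the startup hook finished loading the model
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# POST /generate takes a JSON body matching TextGenerationRequest
payload = {
    "prompt": "The quick brown fox jumps over the lazy dog because",
    "max_length": 60,
    "num_return_sequences": 1,
    "temperature": 0.9,
    "do_sample": True,
}
resp = requests.post(f"{BASE_URL}/generate", json=payload, timeout=60)
resp.raise_for_status()
print(resp.json()["generated_text"])

# GET /generate accepts the same core parameters as query strings
resp = requests.get(
    f"{BASE_URL}/generate",
    params={"prompt": "Once upon a time", "max_length": 40, "temperature": 0.8},
    timeout=60,
)
print(resp.json())

Interactive documentation for the same endpoints is served at /docs, which FastAPI generates automatically.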