usmansafdarktk commited on
Commit
14ad9da
·
1 Parent(s): b854771

Fix main.py: add os import, health endpoint, memory management, use instruction

Browse files
Files changed (1) hide show
  1. main.py +29 -35
main.py CHANGED
@@ -15,51 +15,51 @@ app = FastAPI(title="LaMini-LM API",
15
  description="API for text generation using LaMini-GPT-774M", version="1.0.0")
16
 
17
  # Define request model
18
-
19
-
20
  class TextGenerationRequest(BaseModel):
21
- prompt: str
22
  max_length: int = 100
23
  temperature: float = 1.0
24
  top_p: float = 0.9
25
 
26
-
27
  # Load model (cached after first load)
28
  try:
29
  logger.info("Loading LaMini-GPT-774M model...")
30
- # device=-1 for CPU
31
- generator = pipeline(
32
- 'text-generation', model='MBZUAI/LaMini-GPT-774M', device=-1)
33
  logger.info("Model loaded successfully.")
 
 
 
34
  except Exception as e:
35
  logger.error(f"Failed to load model: {str(e)}")
36
- raise Exception(f"Model loading failed: {str(e)}")
37
 
 
 
 
 
 
 
 
38
 
39
  @app.post("/generate")
40
  async def generate_text(request: TextGenerationRequest):
41
- """
42
- Generate text based on the input prompt using LaMini-GPT-774M.
43
- """
44
  try:
45
  # Validate inputs
46
- if not request.prompt.strip():
47
- raise HTTPException(
48
- status_code=400, detail="Prompt cannot be empty")
49
  if request.max_length < 10 or request.max_length > 500:
50
- raise HTTPException(
51
- status_code=400, detail="max_length must be between 10 and 500")
52
  if request.temperature <= 0 or request.temperature > 2:
53
- raise HTTPException(
54
- status_code=400, detail="temperature must be between 0 and 2")
55
  if request.top_p <= 0 or request.top_p > 1:
56
- raise HTTPException(
57
- status_code=400, detail="top_p must be between 0 and 1")
58
 
59
  # Generate text
60
- logger.info(f"Generating text for prompt: {request.prompt[:50]}...")
61
- wrapper = "Instruction: You are a helpful assistant. Please respond to the following prompt.\n\nPrompt: {}\n\nResponse:".format(
62
- request.prompt)
63
  outputs = generator(
64
  wrapper,
65
  max_length=request.max_length,
@@ -68,19 +68,13 @@ async def generate_text(request: TextGenerationRequest):
68
  num_return_sequences=1,
69
  do_sample=True
70
  )
71
- generated_text = outputs[0]['generated_text'].replace(
72
- wrapper, "").strip()
73
-
74
  return {"generated_text": generated_text}
75
  except Exception as e:
76
  logger.error(f"Error during text generation: {str(e)}")
77
- raise HTTPException(
78
- status_code=500, detail=f"Text generation failed: {str(e)}")
79
-
80
 
81
- @app.get("/")
82
- async def root():
83
- """
84
- Root endpoint with basic info.
85
- """
86
- return {"message": "Welcome to the LaMini-LM API. Use POST /generate to generate text."}
 
15
  description="API for text generation using LaMini-GPT-774M", version="1.0.0")
16
 
17
  # Define request model
 
 
18
  class TextGenerationRequest(BaseModel):
19
+ instruction: str # Changed from prompt for consistency
20
  max_length: int = 100
21
  temperature: float = 1.0
22
  top_p: float = 0.9
23
 
 
24
  # Load model (cached after first load)
25
  try:
26
  logger.info("Loading LaMini-GPT-774M model...")
27
+ generator = pipeline('text-generation', model='MBZUAI/LaMini-GPT-774M', device=-1)
 
 
28
  logger.info("Model loaded successfully.")
29
+ if torch.cuda.is_available():
30
+ torch.cuda.empty_cache()
31
+ gc.collect()
32
  except Exception as e:
33
  logger.error(f"Failed to load model: {str(e)}")
34
+ generator = None # Allow server to run for health check
35
 
36
+ @app.get("/health")
37
+ async def health_check():
38
+ return {"status": "healthy"}
39
+
40
+ @app.get("/")
41
+ async def root():
42
+ return {"message": "Welcome to the LaMini-LM API. Use POST /generate to generate text."}
43
 
44
  @app.post("/generate")
45
  async def generate_text(request: TextGenerationRequest):
46
+ if generator is None:
47
+ raise HTTPException(status_code=503, detail="Model not loaded. Check server logs.")
 
48
  try:
49
  # Validate inputs
50
+ if not request.instruction.strip():
51
+ raise HTTPException(status_code=400, detail="Instruction cannot be empty")
 
52
  if request.max_length < 10 or request.max_length > 500:
53
+ raise HTTPException(status_code=400, detail="max_length must be between 10 and 500")
 
54
  if request.temperature <= 0 or request.temperature > 2:
55
+ raise HTTPException(status_code=400, detail="temperature must be between 0 and 2")
 
56
  if request.top_p <= 0 or request.top_p > 1:
57
+ raise HTTPException(status_code=400, detail="top_p must be between 0 and 1")
 
58
 
59
  # Generate text
60
+ logger.info(f"Generating text for instruction: {request.instruction[:50]}...")
61
+ wrapper = "Instruction: You are a helpful assistant. Please respond to the following instruction.\n\nInstruction: {}\n\nResponse:".format(
62
+ request.instruction)
63
  outputs = generator(
64
  wrapper,
65
  max_length=request.max_length,
 
68
  num_return_sequences=1,
69
  do_sample=True
70
  )
71
+ generated_text = outputs[0]['generated_text'].replace(wrapper, "").strip()
 
 
72
  return {"generated_text": generated_text}
73
  except Exception as e:
74
  logger.error(f"Error during text generation: {str(e)}")
75
+ raise HTTPException(status_code=500, detail=f"Text generation failed: {str(e)}")
 
 
76
 
77
+ if __name__ == "__main__":
78
+ import uvicorn
79
+ port = int(os.environ.get("PORT", 7860))
80
+ uvicorn.run(app, host="0.0.0.0", port=port)