usmansafdarktk committed on
Commit
1574d49
·
1 Parent(s): c963314

Add torch import to fix model loading error

Browse files
Files changed (1) hide show
  1. main.py +23 -22
main.py CHANGED
@@ -1,37 +1,38 @@
1
  import os
 
 
 
2
  from fastapi import FastAPI, HTTPException
3
  from pydantic import BaseModel
4
  from transformers import pipeline
5
- import logging
6
 
7
- # Set up logging
8
  logging.basicConfig(level=logging.INFO)
9
  logger = logging.getLogger(__name__)
10
 
11
- # Log cache directory
12
- logger.info(f"TRANSFORMERS_CACHE set to: {os.getenv('TRANSFORMERS_CACHE', '/.cache')}")
13
-
14
- app = FastAPI(title="LaMini-LM API",
15
- description="API for text generation using LaMini-GPT-774M", version="1.0.0")
16
 
17
- # Define request model
18
  class TextGenerationRequest(BaseModel):
19
- instruction: str # Changed from prompt for consistency
20
  max_length: int = 100
21
  temperature: float = 1.0
22
  top_p: float = 0.9
23
 
24
- # Load model (cached after first load)
25
- try:
26
- logger.info("Loading LaMini-GPT-774M model...")
27
- generator = pipeline('text-generation', model='MBZUAI/LaMini-GPT-774M', device=-1)
28
- logger.info("Model loaded successfully.")
29
- if torch.cuda.is_available():
30
- torch.cuda.empty_cache()
31
- gc.collect()
32
- except Exception as e:
33
- logger.error(f"Failed to load model: {str(e)}")
34
- generator = None # Allow server to run for health check
 
 
 
 
 
35
 
36
  @app.get("/health")
37
  async def health_check():
@@ -43,10 +44,11 @@ async def root():
43
 
44
  @app.post("/generate")
45
  async def generate_text(request: TextGenerationRequest):
 
 
46
  if generator is None:
47
  raise HTTPException(status_code=503, detail="Model not loaded. Check server logs.")
48
  try:
49
- # Validate inputs
50
  if not request.instruction.strip():
51
  raise HTTPException(status_code=400, detail="Instruction cannot be empty")
52
  if request.max_length < 10 or request.max_length > 500:
@@ -56,7 +58,6 @@ async def generate_text(request: TextGenerationRequest):
56
  if request.top_p <= 0 or request.top_p > 1:
57
  raise HTTPException(status_code=400, detail="top_p must be between 0 and 1")
58
 
59
- # Generate text
60
  logger.info(f"Generating text for instruction: {request.instruction[:50]}...")
61
  wrapper = "Instruction: You are a helpful assistant. Please respond to the following instruction.\n\nInstruction: {}\n\nResponse:".format(
62
  request.instruction)
 
1
  import os
2
+ import logging
3
+ import torch
4
+ import gc
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
  from transformers import pipeline
 
8
 
 
9
  logging.basicConfig(level=logging.INFO)
10
  logger = logging.getLogger(__name__)
11
 
12
+ app = FastAPI(title="LaMini-LM API", description="API for text generation using LaMini-GPT-774M", version="1.0.0")
 
 
 
 
13
 
 
14
  class TextGenerationRequest(BaseModel):
15
+ instruction: str
16
  max_length: int = 100
17
  temperature: float = 1.0
18
  top_p: float = 0.9
19
 
20
+ generator = None
21
+
22
+ def load_model():
23
+ global generator
24
+ if generator is None:
25
+ try:
26
+ logger.info("Loading LaMini-GPT-774M model...")
27
+ generator = pipeline('text-generation', model='MBZUAI/LaMini-GPT-774M', device=-1)
28
+ logger.info("Model loaded successfully.")
29
+ if torch.cuda.is_available():
30
+ torch.cuda.empty_cache()
31
+ gc.collect()
32
+ except Exception as e:
33
+ logger.error(f"Failed to load model: {str(e)}")
34
+ generator = None
35
+ raise Exception(f"Model loading failed: {str(e)}")
36
 
37
  @app.get("/health")
38
  async def health_check():
 
44
 
45
  @app.post("/generate")
46
  async def generate_text(request: TextGenerationRequest):
47
+ if generator is None:
48
+ load_model()
49
  if generator is None:
50
  raise HTTPException(status_code=503, detail="Model not loaded. Check server logs.")
51
  try:
 
52
  if not request.instruction.strip():
53
  raise HTTPException(status_code=400, detail="Instruction cannot be empty")
54
  if request.max_length < 10 or request.max_length > 500:
 
58
  if request.top_p <= 0 or request.top_p > 1:
59
  raise HTTPException(status_code=400, detail="top_p must be between 0 and 1")
60
 
 
61
  logger.info(f"Generating text for instruction: {request.instruction[:50]}...")
62
  wrapper = "Instruction: You are a helpful assistant. Please respond to the following instruction.\n\nInstruction: {}\n\nResponse:".format(
63
  request.instruction)