usmansafdarktk committed
Commit f418bc1 · Parent: 4cb09ae

Add simple Tailwind UI and serve via FastAPI static files

Files changed (2):
  1. main.py +21 -5
  2. static/index.html +7 -3
main.py CHANGED
@@ -4,6 +4,7 @@ import torch
 import gc
 from fastapi import FastAPI, HTTPException
 from fastapi.staticfiles import StaticFiles
+from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 from transformers import pipeline
 
@@ -12,8 +13,17 @@ logger = logging.getLogger(__name__)
 
 app = FastAPI(title="LaMini-LM API", description="API for text generation using LaMini-GPT-774M", version="1.0.0")
 
-# Mount static files
-app.mount("/", StaticFiles(directory="static", html=True), name="static")
+# Add CORS middleware to allow UI requests
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],  # Adjust for production to specific origins
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)
+
+# Mount static files at /ui
+app.mount("/ui", StaticFiles(directory="static", html=True), name="static")
 
 class TextGenerationRequest(BaseModel):
     instruction: str
@@ -28,7 +38,7 @@ def load_model():
     if generator is None:
         try:
             logger.info("Loading LaMini-GPT-774M model...")
-            generator = pipeline('text-generation', model='MBZUAI/LaMini-GPT-774M', device=-1)
+            generator = pipeline('text-generation', model='MBZUAI/LaMini-GPT-774M', device=-1, trust_remote_code=True)
             logger.info("Model loaded successfully.")
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
@@ -36,14 +46,19 @@ def load_model():
         except Exception as e:
            logger.error(f"Failed to load model: {str(e)}")
            generator = None
-            raise Exception(f"Model loading failed: {str(e)}")
+            raise HTTPException(status_code=503, detail=f"Model loading failed: {str(e)}")
 
 @app.get("/health")
 async def health_check():
     return {"status": "healthy"}
 
+@app.get("/")
+async def root():
+    return {"message": "Welcome to the LaMini-LM API. Use POST /generate to generate text or visit /ui for the web interface."}
+
 @app.post("/generate")
 async def generate_text(request: TextGenerationRequest):
+    logger.info(f"Received request: {request.dict()}")
     if generator is None:
         load_model()
     if generator is None:
@@ -67,7 +82,8 @@ async def generate_text(request: TextGenerationRequest):
             temperature=request.temperature,
             top_p=request.top_p,
             num_return_sequences=1,
-            do_sample=True
+            do_sample=True,
+            truncation=True
         )
         generated_text = outputs[0]['generated_text'].replace(wrapper, "").strip()
         return {"generated_text": generated_text}
static/index.html CHANGED
@@ -126,9 +126,13 @@
             }
 
             try {
-                const response = await fetch('/generate', {
+                const response = await fetch('https://usmansafder-lamini-lm-api.hf.space/generate', {
                     method: 'POST',
-                    headers: { 'Content-Type': 'application/json' },
+                    headers: {
+                        'Content-Type': 'application/json',
+                        // Add Authorization header if Space is private
+                        // 'Authorization': 'Bearer <your_hf_token>'
+                    },
                     body: JSON.stringify({
                         instruction,
                         max_length: maxLength,
@@ -141,7 +145,7 @@
                 resultDiv.classList.remove('hidden');
                 generatedText.textContent = data.generated_text;
             } else {
-                showError(data.detail?.[0]?.msg || 'Failed to generate text.');
+                showError(data.detail?.[0]?.msg || data.detail || 'Failed to generate text.');
             }
         } catch (err) {
             showError('Error connecting to the API. Please try again.');
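
Editor's note: hard-coding the absolute Space origin in the fetch call is what makes the server-side CORS middleware necessary; since the same app now serves the page at /ui, a relative '/generate' URL would keep the request same-origin and need no CORS at all. If the absolute URL stays, the preflight behavior can be checked directly. A sketch, using the URL hard-coded in the diff above; swap in your own deployment:

# Sketch: verify the CORS preflight the new middleware should answer.
import requests

resp = requests.options(
    "https://usmansafder-lamini-lm-api.hf.space/generate",
    headers={
        "Origin": "https://example.com",
        "Access-Control-Request-Method": "POST",
        "Access-Control-Request-Headers": "content-type",
    },
)
print(resp.status_code)  # 200 if the preflight is handled
# "*" or the echoed origin, depending on how Starlette treats
# allow_origins=["*"] combined with allow_credentials=True:
print(resp.headers.get("access-control-allow-origin"))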