videopix committed on
Commit
8ddc590
·
verified ·
1 Parent(s): 7967b4b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -76
app.py CHANGED
@@ -15,66 +15,57 @@ from diffusers import DiffusionPipeline
15
 
16
 
17
  # -------------------------------------------------------------
18
- # HuggingFace Token (optional)
19
  # -------------------------------------------------------------
20
- HF_TOKEN = os.getenv("HF_TOKEN") # <-- added
21
 
22
 
23
  # -------------------------------------------------------------
24
- # Model / device setup
25
  # -------------------------------------------------------------
26
  MODEL_REPO = "stabilityai/sdxl-turbo"
27
 
28
  device = "cuda" if torch.cuda.is_available() else "cpu"
29
  dtype = torch.float16 if torch.cuda.is_available() else torch.float32
30
 
31
- print(f"Loading {MODEL_REPO} on {device} with dtype={dtype}...")
32
 
33
- # Load with token if present
 
34
  pipe = DiffusionPipeline.from_pretrained(
35
  MODEL_REPO,
36
  torch_dtype=dtype,
37
  use_safetensors=True,
38
- token=HF_TOKEN if HF_TOKEN else None, # <-- added
39
  )
40
 
41
  pipe.to(device)
42
 
43
- # Optional CPU optimization
44
  if device == "cpu":
45
  try:
46
  pipe.enable_model_cpu_offload()
47
- except Exception:
48
  pass
49
 
50
  print("Model ready.")
51
 
52
 
53
  # -------------------------------------------------------------
54
- # Image generation core
55
  # -------------------------------------------------------------
56
- def generate_image(
57
- prompt: str,
58
- negative_prompt: str,
59
- seed: int,
60
- width: int,
61
- height: int,
62
- num_inference_steps: int,
63
- guidance_scale: float,
64
- ):
65
  generator = torch.Generator(device=device).manual_seed(seed)
66
 
67
- out = pipe(
68
  prompt=prompt,
69
  negative_prompt=negative_prompt if negative_prompt else None,
70
- guidance_scale=guidance_scale,
71
- num_inference_steps=num_inference_steps,
72
  width=width,
73
  height=height,
74
  generator=generator,
75
  )
76
 
77
- return out.images[0]
78
 
79
 
80
  # -------------------------------------------------------------
@@ -103,7 +94,7 @@ async def run_generate(prompt, negative_prompt, seed, width, height, steps, guid
103
  # -------------------------------------------------------------
104
  # FastAPI App
105
  # -------------------------------------------------------------
106
- app = FastAPI(title="SDXL Turbo Text2Image", version="1.0")
107
 
108
  app.add_middleware(
109
  CORSMiddleware,
@@ -115,7 +106,7 @@ app.add_middleware(
115
 
116
 
117
  # -------------------------------------------------------------
118
- # Simple Web UI
119
  # -------------------------------------------------------------
120
  @app.get("/", response_class=HTMLResponse)
121
  def home():
@@ -123,63 +114,53 @@ def home():
123
  <!doctype html>
124
  <html>
125
  <head>
126
- <meta charset="utf-8" />
127
  <title>SDXL Turbo CPU Generator</title>
128
  <style>
129
  body { font-family: Arial; max-width: 900px; margin: 30px auto; }
130
- textarea { width: 100%; padding: 10px; border-radius: 6px; border: 1px solid #ccc; margin-bottom: 10px; }
131
- button { padding: 12px 18px; background:black; color:white; border:none; cursor:pointer; margin-top:10px; }
132
- img { margin-top:20px; max-width:100%; border-radius:10px; }
133
  #status { margin-top:10px; }
134
  </style>
135
  </head>
136
  <body>
137
- <h1>SDXL Turbo Text to Image</h1>
138
 
139
- <textarea id="prompt" rows="3" placeholder="Astronaut in a jungle, 8k, cold colors"></textarea>
140
 
141
- <textarea id="neg" rows="2" placeholder="Negative prompt (optional)"></textarea>
 
142
 
143
- <button id="btn" onclick="gen()">Generate</button>
144
 
145
  <div id="status"></div>
146
  <img id="result"/>
147
 
148
  <script>
149
- async function gen() {
150
- const btn = document.getElementById("btn");
 
151
  const status = document.getElementById("status");
152
  const img = document.getElementById("result");
153
 
154
- const prompt = document.getElementById("prompt").value;
155
- const neg = document.getElementById("neg").value;
156
-
157
- if (!prompt.trim()) {
158
- status.textContent = "Please enter a prompt.";
159
- return;
160
- }
161
-
162
- btn.disabled = true;
163
- status.textContent = "Generating...";
164
- img.src = "";
165
 
 
166
  const res = await fetch("/api/generate", {
167
  method: "POST",
168
- headers: { "Content-Type": "application/json" },
169
- body: JSON.stringify({ prompt, negative_prompt: neg })
170
  });
171
 
172
- const j = await res.json();
173
 
174
- if (j.status !== "success") {
175
- status.textContent = "Error: " + j.message;
176
- btn.disabled = false;
177
  return;
178
  }
179
 
180
- img.src = "data:image/png;base64," + j.image_base64;
181
- status.textContent = "Done. Seed: " + j.seed;
182
- btn.disabled = false;
183
  }
184
  </script>
185
 
@@ -189,16 +170,15 @@ def home():
189
 
190
 
191
  # -------------------------------------------------------------
192
- # API Endpoint
193
  # -------------------------------------------------------------
194
  @app.post("/api/generate")
195
  async def api_generate(request: Request):
196
-
197
  try:
198
- data = await request.json()
199
- prompt = data.get("prompt", "").strip()
200
- negative_prompt = data.get("negative_prompt", "").strip()
201
- except Exception:
202
  return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)
203
 
204
  if not prompt:
@@ -207,28 +187,23 @@ async def api_generate(request: Request):
207
  width = 768
208
  height = 432
209
  steps = 2
210
- guidance = 0.0 # SDXL Turbo is trained for cfg=0
211
-
212
  seed = random.randint(0, 2**31 - 1)
213
 
214
  try:
215
- img = await run_generate(
216
- prompt, negative_prompt, seed, width, height, steps, guidance
217
- )
218
 
219
  buf = io.BytesIO()
220
  img.save(buf, format="PNG")
221
- encoded = base64.b64encode(buf.getvalue()).decode()
222
-
223
- return JSONResponse(
224
- {
225
- "status": "success",
226
- "image_base64": encoded,
227
- "seed": seed,
228
- "width": width,
229
- "height": height,
230
- }
231
- )
232
 
233
  except Exception as e:
234
  return JSONResponse({"status": "error", "message": str(e)}, 500)
 
15
 
16
 
17
# -------------------------------------------------------------
# HuggingFace Token (auto)
# -------------------------------------------------------------
# Optional token from the environment for gated/private repos;
# None is fine for the public sdxl-turbo checkpoint.
HF_TOKEN = os.getenv("HF_TOKEN")


# -------------------------------------------------------------
# Model Settings
# -------------------------------------------------------------
MODEL_REPO = "stabilityai/sdxl-turbo"

# Prefer CUDA + fp16 when a GPU is available; fall back to CPU + fp32.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if torch.cuda.is_available() else torch.float32

print(f"Loading {MODEL_REPO} on {device}...")

pipe = DiffusionPipeline.from_pretrained(
    MODEL_REPO,
    torch_dtype=dtype,
    use_safetensors=True,
    token=HF_TOKEN if HF_TOKEN else None,  # auto use token
)

pipe.to(device)

if device == "cpu":
    # Best-effort memory optimization; enable_model_cpu_offload needs
    # `accelerate`, so a failure here is tolerated rather than fatal.
    try:
        pipe.enable_model_cpu_offload()
    except Exception:
        # Narrowed from a bare `except:`, which would also swallow
        # SystemExit/KeyboardInterrupt during startup.
        pass

print("Model ready.")
50
 
51
 
52
  # -------------------------------------------------------------
53
+ # Core Generation Function
54
  # -------------------------------------------------------------
55
def generate_image(
    prompt: str,
    negative_prompt: str,
    seed: int,
    width: int,
    height: int,
    steps: int,
    guidance: float,
):
    """Run one text-to-image pass through the module-level SDXL-Turbo pipe.

    Args:
        prompt: Text prompt describing the desired image.
        negative_prompt: Concepts to steer away from. An empty string is
            treated as "no negative prompt" (``None`` is passed to the pipe).
        seed: RNG seed; the same seed reproduces the same image.
        width: Output width in pixels.
        height: Output height in pixels.
        steps: Number of denoising steps.
        guidance: Classifier-free guidance scale (the API caller passes 0.0,
            which sdxl-turbo is trained for).

    Returns:
        The first image (PIL.Image) from the pipeline output.
    """
    # Seed a generator on the pipeline's device for reproducible output.
    generator = torch.Generator(device=device).manual_seed(seed)

    result = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt if negative_prompt else None,
        guidance_scale=guidance,
        num_inference_steps=steps,
        width=width,
        height=height,
        generator=generator,
    )

    return result.images[0]
69
 
70
 
71
  # -------------------------------------------------------------
 
94
  # -------------------------------------------------------------
95
  # FastAPI App
96
  # -------------------------------------------------------------
97
+ app = FastAPI(title="SDXL Turbo Generator", version="2.0")
98
 
99
  app.add_middleware(
100
  CORSMiddleware,
 
106
 
107
 
108
# -------------------------------------------------------------
# POST-only HTML UI
# -------------------------------------------------------------
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the single-page UI that POSTs to /api/generate.

    NOTE(review): the opening/closing of the HTML string literal was not
    visible in the diff view and has been reconstructed — confirm against
    the deployed file.
    """
    return """
    <!doctype html>
    <html>
    <head>
    <meta charset="utf-8"/>
    <title>SDXL Turbo CPU Generator</title>
    <style>
    body { font-family: Arial; max-width: 900px; margin: 30px auto; }
    textarea { width: 100%; padding: 10px; margin-bottom: 10px; }
    button { padding: 12px; background:black; color:white; border:none; cursor:pointer; }
    img { max-width:100%; margin-top:20px; border-radius:10px; }
    #status { margin-top:10px; }
    </style>
    </head>
    <body>

    <h1>SDXL Turbo (POST-only UI)</h1>

    <textarea id="prompt" placeholder="Enter prompt"></textarea>
    <textarea id="negative_prompt" placeholder="Negative prompt (optional)"></textarea>

    <button onclick="send()">Generate</button>

    <div id="status"></div>
    <img id="result"/>

    <script>
    async function send() {
      const prompt = document.getElementById("prompt").value;
      const negative_prompt = document.getElementById("negative_prompt").value;
      const status = document.getElementById("status");
      const img = document.getElementById("result");

      // Guard restored from the previous revision: avoid a pointless
      // round-trip (and a guaranteed 400) on an empty prompt.
      if (!prompt.trim()) {
        status.innerText = "Please enter a prompt.";
        return;
      }

      status.innerText = "Generating...";
      img.src = "";

      try {
        // POST request only
        const res = await fetch("/api/generate", {
          method: "POST",
          headers: {"Content-Type": "application/json"},
          body: JSON.stringify({ prompt, negative_prompt })
        });

        const data = await res.json();

        if (data.status !== "success") {
          status.innerText = "Error: " + data.message;
          return;
        }

        img.src = "data:image/png;base64," + data.image_base64;
        status.innerText = "Done (seed " + data.seed + ")";
      } catch (err) {
        // Without this, a network failure or non-JSON reply rejects the
        // promise and leaves the UI stuck on "Generating..." forever.
        status.innerText = "Request failed: " + err;
      }
    }
    </script>

    </body>
    </html>
    """
166
 
 
170
 
171
 
172
# -------------------------------------------------------------
# API Endpoint (POST only)
# -------------------------------------------------------------
@app.post("/api/generate")
async def api_generate(request: Request):
    """Generate one image from a JSON body ``{prompt, negative_prompt?}``.

    Returns JSON ``{status, image_base64, seed, width, height}`` on success,
    or ``{status: "error", message}`` with HTTP 400 (bad input) / 500
    (generation failure).
    """
    try:
        body = await request.json()
        prompt = body.get("prompt", "").strip()
        negative_prompt = body.get("negative_prompt", "").strip()
    except Exception:
        # Narrowed from a bare `except:`; still broad on purpose so any
        # malformed body maps to a 400 rather than crashing the worker.
        return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)

    if not prompt:
        # NOTE(review): this return line was elided in the diff view; the
        # message text is reconstructed — confirm against the deployed file.
        return JSONResponse({"status": "error", "message": "Prompt is required"}, 400)

    # Fixed output settings: sdxl-turbo is a few-step distilled model and
    # is trained for guidance_scale = 0.0.
    width = 768
    height = 432
    steps = 2
    guidance = 0.0

    seed = random.randint(0, 2**31 - 1)

    try:
        img = await run_generate(prompt, negative_prompt, seed, width, height, steps, guidance)

        # Encode the PIL image as a base64 PNG for the JSON payload.
        buf = io.BytesIO()
        img.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode()

        return JSONResponse({
            "status": "success",
            "image_base64": b64,
            "seed": seed,
            "width": width,
            "height": height
        })

    except Exception as e:
        # Top-level boundary: surface the failure to the client as a 500.
        return JSONResponse({"status": "error", "message": str(e)}, 500)