Spaces:
Running
Running
import asyncio
import base64
import io
import os
import threading
from concurrent.futures import ThreadPoolExecutor

import spaces
import torch
from diffusers import FluxPipeline
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from PIL import Image
HF_TOKEN = os.getenv("HF_TOKEN")
BASE_MODEL = "black-forest-labs/FLUX.1-schnell"
# Memoized pipelines keyed by name (filled by load_pipeline).
_cached = {}
# Run one generation at a time. Every request shares the single cached
# FluxPipeline, and diffusers pipelines keep per-call state on that shared
# object, so parallel calls from several threads are not safe. On CPU, torch
# already spreads one generation across all cores, so parallel requests would
# mostly compete for the same cores anyway. Extra requests queue on the semaphore.
# NOTE(review): on Python < 3.10 an asyncio.Semaphore created at import time is
# bound to the import-time event loop; confirm the runtime is 3.10+.
executor = ThreadPoolExecutor(max_workers=1)
semaphore = asyncio.Semaphore(1)
# Guards the one-time model load so concurrent first requests don't each load
# their own copy of the weights.
_pipeline_load_lock = threading.Lock()


def load_pipeline():
    """Return the process-wide FLUX.1-schnell pipeline, loading it on first use.

    The pipeline is memoized in the module-level ``_cached`` dict under the key
    ``"flux"``. The load is guarded by a lock (double-checked), so concurrent
    first callers, e.g. executor worker threads, trigger only one load.

    Returns:
        The cached ``FluxPipeline``, in float16 on the CPU, with attention
        slicing and VAE tiling enabled.
    """
    if "flux" in _cached:
        return _cached["flux"]
    with _pipeline_load_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if "flux" in _cached:
            return _cached["flux"]
        print("🔹 Loading FLUX.1-schnell (fast mode)")
        pipe = FluxPipeline.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            token=HF_TOKEN,  # `use_auth_token` is the deprecated spelling
        ).to("cpu", dtype=torch.float16)
        pipe.enable_attention_slicing()
        pipe.enable_vae_tiling()
        _cached["flux"] = pipe
        return pipe
# diffusers pipelines are not safe to call from several threads at once: each
# __call__ stores per-call state (e.g. scheduler timesteps / step index) on the
# shared pipeline objects. Serialize all inference on the one cached pipeline.
_inference_lock = threading.Lock()


def generate_image_sync(
    prompt: str,
    seed: int = 42,
    *,
    width: int = 768,
    height: int = 432,
    num_inference_steps: int = 4,
    output_size=(960, 540),
):
    """Render ``prompt`` with the cached FLUX pipeline (blocking).

    Intended to run on a worker thread; calls are serialized by a
    module-level lock because the shared pipeline is not thread-safe.

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for the CPU ``torch.Generator`` used for sampling.
        width: Generation width in pixels. Kept small for CPU speed.
        height: Generation height in pixels.
        num_inference_steps: Number of denoising steps.
        output_size: ``(width, height)`` the generated image is bicubic-resized
            to before being returned.

    Returns:
        The first image produced by the pipeline, resized to ``output_size``.
    """
    with _inference_lock:
        # Loading inside the lock also prevents duplicate first-time loads.
        pipe = load_pipeline()
        gen = torch.Generator(device="cpu").manual_seed(int(seed))
        image = pipe(
            prompt=prompt,
            width=width,
            height=height,
            num_inference_steps=num_inference_steps,
            # NOTE(review): FLUX.1-schnell is guidance-distilled, so this value is
            # presumably ignored; kept unchanged to preserve behavior.
            guidance_scale=3,
            generator=gen,
        ).images[0]
    # Slight upscale (default 768×432 -> 960×540) to keep the output clear.
    return image.resize(output_size, Image.BICUBIC)
async def generate_image_async(prompt, seed):
    """Run ``generate_image_sync(prompt, seed)`` on the worker pool.

    Awaits the blocking generation without stalling the event loop. The
    module-level ``semaphore`` caps how many calls are in flight; the rest
    wait here. Returns whatever ``generate_image_sync`` returns.
    """
    async with semaphore:
        running_loop = asyncio.get_running_loop()
        pending = running_loop.run_in_executor(
            executor, generate_image_sync, prompt, seed
        )
        return await pending
app = FastAPI(title="FLUX Fast API", version="3.1")
# Open CORS so any site can call the JSON API. Credentials stay disabled: the
# API uses no cookies or auth, and combining allow_credentials=True with a "*"
# origin is disallowed by the CORS spec. Starlette would instead echo the
# caller's Origin, letting any site make credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the single-page demo UI at ``/``.

    The page posts ``{"prompt", "seed"}`` as JSON to ``/api/generate`` and
    renders the returned base64 PNG inline. Errors, including network failures
    and non-JSON responses, are shown as plain text and never as HTML.
    """
    return """
<html><head><meta charset='utf-8'><title>FLUX Fast</title>
<style>body{font-family:Arial;text-align:center;padding:2rem}
input,button{margin:.5rem;padding:.6rem;width:300px;border-radius:6px;border:1px solid #ccc}
button{background:#444;color:#fff}button:hover{background:#333}
img{margin-top:1rem;max-width:90%;border-radius:12px}</style></head>
<body><h2>🎨 FLUX Fast Generator</h2>
<form id='f'><input id='prompt' placeholder='Describe image...' required><br>
<input id='seed' type='number' value='42'><br>
<button id='go'>Generate</button></form><div id='out'></div>
<script>
const form = document.getElementById("f");
const promptInput = document.getElementById("prompt");
const seedInput = document.getElementById("seed");
const button = document.getElementById("go");
const resultDiv = document.getElementById("out");

// Show a status line as text (never HTML) so server messages can't inject markup.
function showMessage(text, isError) {
  const p = document.createElement("p");
  if (isError) p.style.color = "red";
  p.textContent = text;
  resultDiv.replaceChildren(p);
}

form.addEventListener("submit", async (e) => {
  e.preventDefault();
  const prompt = promptInput.value.trim();
  if (!prompt) {
    showMessage("❌ Please enter a prompt", true);
    return;
  }
  const seed = parseInt(seedInput.value, 10);
  const payload = {
    prompt: prompt,
    seed: Number.isNaN(seed) ? 42 : seed
  };
  showMessage("⏳ Generating...", false);
  button.disabled = true;
  try {
    const res = await fetch("/api/generate", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(payload)
    });
    let json;
    try {
      json = await res.json();
    } catch {
      throw new Error(`Server returned HTTP ${res.status}`);
    }
    if (json.status === "success") {
      resultDiv.innerHTML = `<img src="data:image/png;base64,${json.image_base64}"/><p>✅ Done!</p>`;
    } else {
      showMessage(`❌ ${json.message}`, true);
    }
  } catch (err) {
    showMessage(`❌ ${err.message}`, true);
  } finally {
    button.disabled = false;
  }
});
</script>
</body></html>
"""
# Inclusive seed range accepted by torch.Generator.manual_seed.
_SEED_MIN = -0x8000_0000_0000_0000
_SEED_MAX = 0xFFFF_FFFF_FFFF_FFFF


@app.post("/api/generate")
async def api_generate(request: Request):
    """Generate an image from a JSON request body.

    Body: ``{"prompt": str, "seed": int}``; ``seed`` is optional and defaults to 42.

    Returns a JSON response:
      * ``{"status": "success", "prompt", "image_base64"}``, where the image is
        a base64-encoded PNG;
      * ``{"status": "error", "message"}`` with HTTP 400 for a malformed body,
        a missing prompt or an invalid seed;
      * ``{"status": "error", "message"}`` with HTTP 500 if generation fails.
    """
    try:
        data = await request.json()
    except Exception:
        # Undecodable / non-JSON body.
        return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)
    if not isinstance(data, dict):
        return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)

    raw_prompt = data.get("prompt")
    # A JSON null must not turn into the literal prompt "None".
    prompt = "" if raw_prompt is None else str(raw_prompt).strip()
    if not prompt:
        return JSONResponse({"status": "error", "message": "Prompt required"}, 400)

    try:
        seed = int(data.get("seed", 42))
    except (TypeError, ValueError, OverflowError):
        # OverflowError: Python's json accepts Infinity.
        return JSONResponse({"status": "error", "message": "Invalid seed"}, 400)
    if not _SEED_MIN <= seed <= _SEED_MAX:
        # Otherwise torch raises deep inside generation and the client gets a 500.
        return JSONResponse({"status": "error", "message": "Invalid seed"}, 400)

    try:
        image = await generate_image_async(prompt, seed)
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        img64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    except Exception as e:
        # Top-level request boundary: log and report the failure to the client.
        print(f"❌ Error: {e}")
        return JSONResponse({"status": "error", "message": str(e)}, 500)
    return JSONResponse({"status": "success", "prompt": prompt, "image_base64": img64})
def keep_alive():
    """Return the fixed readiness banner ``"ZeroGPU Ready"``.

    NOTE(review): ``spaces`` is imported but nothing visible here carries
    ``@spaces.GPU``. If this Space runs on ZeroGPU hardware, confirm whether this
    function was meant to be decorated.
    """
    status_message = "ZeroGPU Ready"
    return status_message
if __name__ == "__main__":
    # Local / Spaces entry point: serve the FastAPI app with uvicorn.
    import uvicorn

    print("🚀 Launching Fast FLUX API")
    keep_alive()
    # Defaults to 7860, the port the original hard-coded; $PORT overrides it.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv("PORT", "7860")))