import os, io, base64, asyncio, torch, spaces from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import HTMLResponse, JSONResponse from diffusers import FluxPipeline from PIL import Image from concurrent.futures import ThreadPoolExecutor HF_TOKEN = os.getenv("HF_TOKEN") BASE_MODEL = "black-forest-labs/FLUX.1-schnell" _cached = {} # moderate concurrency so CPU doesnβt choke executor = ThreadPoolExecutor(max_workers=3) semaphore = asyncio.Semaphore(3) def load_pipeline(): if "flux" in _cached: return _cached["flux"] print("πΉ Loading FLUX.1-schnell (fast mode)") pipe = FluxPipeline.from_pretrained( BASE_MODEL, torch_dtype=torch.float16, use_auth_token=HF_TOKEN, ).to("cpu", dtype=torch.float16) pipe.enable_attention_slicing() pipe.enable_vae_tiling() _cached["flux"] = pipe return pipe def generate_image_sync(prompt: str, seed: int = 42): pipe = load_pipeline() gen = torch.Generator(device="cpu").manual_seed(int(seed)) # smaller size and steps for speed w, h = 768, 432 image = pipe( prompt=prompt, width=w, height=h, num_inference_steps=4, guidance_scale=3, generator=gen, ).images[0] # slight upscale back to 960Γ540 to keep output clear return image.resize((960, 540), Image.BICUBIC) async def generate_image_async(prompt, seed): async with semaphore: loop = asyncio.get_running_loop() return await loop.run_in_executor(executor, generate_image_sync, prompt, seed) app = FastAPI(title="FLUX Fast API", version="3.1") app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/", response_class=HTMLResponse) def home(): return """