Spaces:
videopix
/
Running

Gibhili / app_old.py
videopix's picture
Update app_old.py
07d8ca5 verified
raw
history blame
4.54 kB
import os, io, base64, asyncio, torch, spaces
import threading
from concurrent.futures import ThreadPoolExecutor
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse, JSONResponse
from diffusers import FluxPipeline
from PIL import Image
# Optional Hugging Face access token (None when the env var is unset).
HF_TOKEN = os.environ.get("HF_TOKEN")
BASE_MODEL = "black-forest-labs/FLUX.1-schnell"
# Process-wide pipeline cache, filled lazily by load_pipeline().
_cached: dict = {}

# One shared cap on concurrent renders: the executor runs at most this many
# generations at once, and the semaphore makes additional requests wait on the
# event loop before they are handed to the executor.
MAX_CONCURRENT_JOBS = 3
executor = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_JOBS)
semaphore = asyncio.Semaphore(MAX_CONCURRENT_JOBS)
# Serializes the one-time pipeline load across executor threads.
_pipeline_lock = threading.Lock()


def load_pipeline():
    """Return the shared FLUX.1-schnell pipeline, loading it on first use.

    The pipeline is stored in ``_cached["flux"]`` and reused by later calls.
    Image generation runs on a thread pool (up to three workers), so the load
    is guarded by a lock: without it, several concurrent first requests could
    each load a full copy of the model.

    Returns:
        The cached ``FluxPipeline`` instance, placed on CPU in float16.
    """
    pipe = _cached.get("flux")
    if pipe is not None:
        return pipe
    with _pipeline_lock:
        # Re-check under the lock: another thread may have finished loading
        # while this one was waiting.
        if "flux" in _cached:
            return _cached["flux"]
        print("🔹 Loading FLUX.1-schnell (fast mode)")
        pipe = FluxPipeline.from_pretrained(
            BASE_MODEL,
            torch_dtype=torch.float16,
            token=HF_TOKEN,  # replaces the deprecated `use_auth_token` kwarg
        # NOTE(review): float16 on CPU can be slow or hit missing Half kernels
        # depending on the torch version; bfloat16 is a common CPU choice — confirm.
        ).to("cpu", dtype=torch.float16)
        # Memory-saving options (attention computed in slices, VAE decoded in tiles).
        pipe.enable_attention_slicing()
        pipe.enable_vae_tiling()
        _cached["flux"] = pipe
        return pipe
def generate_image_sync(
    prompt: str,
    seed: int = 42,
    *,
    width: int = 768,
    height: int = 432,
    num_inference_steps: int = 4,
    guidance_scale: float = 3,
    output_size=(960, 540),
):
    """Render *prompt* with the cached FLUX pipeline (blocking).

    Args:
        prompt: Text prompt passed to the pipeline.
        seed: Seed for a CPU ``torch.Generator``; coerced with ``int()``.
        width: Generation width. The default is a reduced size chosen for speed.
        height: Generation height. The default is a reduced size chosen for speed.
        num_inference_steps: Denoising steps passed to the pipeline.
        guidance_scale: Passed through to the pipeline unchanged.
            NOTE(review): diffusers examples use 0.0 for schnell; the distilled
            model may ignore this value — confirm.
        output_size: ``(width, height)`` to resize the result to with bicubic
            resampling, or ``None`` to return the generated image as-is.

    Returns:
        The first PIL image produced by the pipeline, optionally resized.
    """
    pipe = load_pipeline()
    gen = torch.Generator(device="cpu").manual_seed(int(seed))
    image = pipe(
        prompt=prompt,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=gen,
    ).images[0]
    if output_size is None:
        return image
    # Upscale the small render back up (960x540 by default) so the output stays clear.
    return image.resize(tuple(output_size), Image.BICUBIC)
async def generate_image_async(prompt, seed):
    """Await generate_image_sync(prompt, seed) on the shared thread pool.

    The module-level ``semaphore`` is held for the whole render, so excess
    callers wait here rather than blocking the event loop.
    """
    async with semaphore:
        event_loop = asyncio.get_running_loop()
        pending = event_loop.run_in_executor(
            executor, generate_image_sync, prompt, seed
        )
        return await pending
app = FastAPI(title="FLUX Fast API", version="3.1")
# Any origin may call the API, but without credentials. The API has no auth and
# the bundled UI calls it same-origin, so credentials are never needed. With
# allow_origins=["*"], allow_credentials=True makes Starlette reflect the
# caller's Origin on cookie-bearing requests, accepting credentialed calls from
# every site.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/", response_class=HTMLResponse)
def home():
    """Serve the single-page UI for the generator.

    The page POSTs ``{"prompt", "seed"}`` as JSON to ``/api/generate`` and
    shows the returned base64 PNG inline. Error text from the server is
    inserted with ``textContent`` so it is never parsed as HTML. Network or
    non-JSON failures are reported instead of leaving the page stuck.
    """
    return """
<html><head><meta charset="utf-8"><title>FLUX Fast</title>
<style>body{font-family:Arial;text-align:center;padding:2rem}
input,button{margin:.5rem;padding:.6rem;width:300px;border-radius:6px;border:1px solid #ccc}
button{background:#444;color:#fff}button:hover{background:#333}
img{margin-top:1rem;max-width:90%;border-radius:12px}</style></head>
<body><h2>🎨 FLUX Fast Generator</h2>
<form id='f'><input id='prompt' placeholder='Describe image...' required><br>
<input id='seed' type='number' value='42'><br>
<button id='go'>Generate</button></form><div id='out'></div>
<script>
const form = document.getElementById("f");
const promptInput = document.getElementById("prompt");
const seedInput = document.getElementById("seed");
const button = document.getElementById("go");
const resultDiv = document.getElementById("out");

// Render an error via textContent so server-provided text is never parsed as HTML.
function showError(message) {
  const p = document.createElement("p");
  p.style.color = "red";
  p.textContent = "❌ " + message;
  resultDiv.textContent = "";
  resultDiv.appendChild(p);
}

form.addEventListener("submit", async (e) => {
  e.preventDefault();
  const prompt = promptInput.value.trim();
  if (!prompt) {
    showError("Please enter a prompt");
    return;
  }
  const parsedSeed = parseInt(seedInput.value, 10);
  const payload = {
    prompt: prompt,
    seed: Number.isNaN(parsedSeed) ? 42 : parsedSeed
  };
  resultDiv.innerHTML = "<p>⏳ Generating...</p>";
  // Prevent double submits from queueing extra expensive generations.
  button.disabled = true;
  try {
    const res = await fetch("/api/generate", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(payload)
    });
    let json;
    try {
      json = await res.json();
    } catch (_) {
      throw new Error("Server returned HTTP " + res.status);
    }
    if (json.status === "success") {
      resultDiv.innerHTML = `<img src="data:image/png;base64,${json.image_base64}"/><p>✅ Done!</p>`;
    } else {
      showError(json.message || ("HTTP " + res.status));
    }
  } catch (err) {
    showError(err.message || String(err));
  } finally {
    button.disabled = false;
  }
});
</script>
</body></html>
"""
@app.post("/api/generate")
async def api_generate(request: Request):
    """Generate an image for a JSON request and return it base64-encoded.

    Expected body: a JSON object with ``prompt`` (required, non-empty after
    stripping) and optional ``seed`` (anything ``int()`` accepts; 42 when
    missing or null).

    Responses:
        200 ``{"status": "success", "prompt", "image_base64"}`` (PNG data).
        400 ``{"status": "error", "message"}`` for an unparseable or non-object
            body, a missing prompt, or a non-integer seed.
        500 ``{"status": "error", "message"}`` when generation or encoding fails.
    """
    try:
        data = await request.json()
    except Exception:
        # Body could not be parsed as JSON.
        return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)
    if not isinstance(data, dict):
        # Valid JSON that is not an object (list, string, number) has no fields.
        return JSONResponse({"status": "error", "message": "Invalid JSON"}, 400)

    raw_prompt = data.get("prompt")
    # Treat an explicit null as missing; str(None) would yield the prompt "None".
    prompt = "" if raw_prompt is None else str(raw_prompt).strip()
    if not prompt:
        return JSONResponse({"status": "error", "message": "Prompt required"}, 400)

    raw_seed = data.get("seed")
    try:
        seed = 42 if raw_seed is None else int(raw_seed)
    except (TypeError, ValueError, OverflowError):
        return JSONResponse(
            {"status": "error", "message": "Seed must be an integer"}, 400
        )

    try:
        image = await generate_image_async(prompt, seed)
        buf = io.BytesIO()
        image.save(buf, format="PNG")
        img64 = base64.b64encode(buf.getvalue()).decode("utf-8")
    except Exception as e:
        # Request boundary: log and report the failure instead of crashing.
        print(f"❌ Error: {e}")
        return JSONResponse({"status": "error", "message": str(e)}, 500)
    return JSONResponse({"status": "success", "prompt": prompt, "image_base64": img64})
@spaces.GPU
def keep_alive():
    """Return a fixed readiness string.

    Decorated with ``spaces.GPU``, presumably so the ZeroGPU runtime detects a
    GPU-decorated function at startup (TODO confirm). The body itself does no
    GPU work.
    """
    ready_message = "ZeroGPU Ready"
    return ready_message
if __name__ == "__main__":
    # uvicorn is only needed when this file is run directly as a script.
    import uvicorn

    print("🚀 Launching Fast FLUX API")
    # Invoke the @spaces.GPU-decorated function once before serving.
    keep_alive()
    uvicorn.run(app, host="0.0.0.0", port=7860)