Spaces:
Running
Running
import os | |
import uuid | |
import logging | |
import torch | |
import numpy as np | |
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler | |
from diffusers.utils import export_to_video | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_file | |
from fastapi.responses import FileResponse | |
from fastapi.middleware.cors import CORSMiddleware | |
app = FastAPI() | |
app.add_middleware( | |
CORSMiddleware, | |
allow_origins=["*"], # Cho phép tất cả domain gọi API (hoặc liệt kê domain cụ thể ở đây) | |
allow_credentials=True, | |
allow_methods=["*"], # Cho phép tất cả phương thức (POST, GET, etc.) | |
allow_headers=["*"], # Cho phép tất cả headers | |
) | |
# Thiết lập logging | |
logging.basicConfig(level=logging.INFO) | |
logger = logging.getLogger(__name__) | |
# Thiết lập thư mục cache cho Hugging Face ngay đầu file | |
os.environ["HF_HOME"] = "/tmp/huggingface_cache" | |
os.environ["XDG_CACHE_HOME"] = "/tmp/huggingface_cache" | |
os.makedirs(os.environ["HF_HOME"], exist_ok=True) | |
logger.info(f"HF_HOME set to {os.environ['HF_HOME']}") | |
app = FastAPI() | |
# Tạo thư mục lưu video trong /tmp | |
output_dir = "/tmp/outputs" | |
os.makedirs(output_dir, exist_ok=True) | |
# Constants | |
bases = { | |
"Cartoon": "frankjoshua/toonyou_beta6", | |
"Realistic": "emilianJR/epiCRealism", | |
"3d": "Lykon/DreamShaper", | |
"Anime": "Yntec/mistoonAnime2" | |
} | |
step_loaded = None | |
base_loaded = "Realistic" | |
motion_loaded = None | |
# Thiết lập thiết bị CPU và kiểu dữ liệu | |
device = "cpu" | |
dtype = torch.float32 | |
# Khởi tạo pipeline | |
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device) | |
pipe.scheduler = EulerDiscreteScheduler.from_config( | |
pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear" | |
) | |
pipe.safety_checker = None | |
# Mô hình dữ liệu cho request | |
class VideoRequest(BaseModel): | |
prompt: str | |
base: str = "Realistic" | |
motion: str = "" | |
step: int = 1 | |
# Endpoint tạo video | |
async def generate_video(request: VideoRequest): | |
global step_loaded, base_loaded, motion_loaded | |
prompt = request.prompt | |
base = request.base | |
motion = request.motion | |
step = request.step | |
logger.info(f"Tạo video với prompt: {prompt}, base: {base}, motion: {motion}, steps: {step}") | |
try: | |
# Kiểm tra base hợp lệ | |
if base not in bases: | |
raise HTTPException(status_code=400, detail="Base model không hợp lệ") | |
# Tải AnimateDiff Lightning checkpoint | |
if step_loaded != step: | |
repo = "ByteDance/AnimateDiff-Lightning" | |
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors" | |
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False) | |
step_loaded = step | |
# Tải mô hình cơ sở nếu thay đổi | |
if base_loaded != base: | |
pipe.unet.load_state_dict( | |
torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), | |
strict=False | |
) | |
base_loaded = base | |
# Tải motion LoRA nếu có | |
if motion_loaded != motion: | |
try: | |
pipe.unload_lora_weights() | |
if motion: | |
pipe.load_lora_weights(motion, adapter_name="motion") | |
pipe.set_adapters(["motion"], [0.7]) | |
motion_loaded = motion | |
except Exception as e: | |
logger.warning(f"Không thể tải motion LoRA: {e}") | |
motion_loaded = "" | |
# Suy luận | |
with torch.no_grad(): | |
output = pipe( | |
prompt=prompt, | |
guidance_scale=1.2, | |
num_inference_steps=step, | |
num_frames=32, | |
width=256, | |
height=256 | |
) | |
# Chuẩn hóa khung hình cho 8 giây | |
frames = output.frames[0] | |
fps = 24 | |
target_frames = fps * 8 | |
if len(frames) < target_frames: | |
frames = np.tile(frames, (target_frames // len(frames) + 1, 1, 1, 1))[:target_frames] | |
else: | |
frames = frames[:target_frames] | |
# Tạo video | |
name = str(uuid.uuid4()).replace("-", "") | |
video_path = os.path.join(output_dir, f"{name}.mp4") | |
export_to_video(frames, video_path, fps=fps) | |
if not os.path.exists(video_path): | |
raise FileNotFoundError("Video không được tạo") | |
logger.info(f"Video sẵn sàng tại {video_path}") | |
# Trả về file video | |
return FileResponse(video_path, media_type="video/mp4", filename=f"{name}.mp4") | |
except Exception as e: | |
logger.error(f"Lỗi khi tạo video: {e}") | |
raise HTTPException(status_code=500, detail=str(e)) | |
# Endpoint kiểm tra trạng thái | |
async def root(): | |
return {"message": "FastAPI AnimateDiff-Lightning API on Hugging Face Spaces"} | |
if __name__ == "__main__": | |
port = int(os.environ.get("PORT", 7860)) # Nếu PORT không có thì mặc định 7860 | |
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True) |