BE_T2V / main.py
hoangkha1810's picture
Update main.py
e3516bd verified
import os
import uuid
import logging
import torch
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI()
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Cho phép tất cả domain gọi API (hoặc liệt kê domain cụ thể ở đây)
allow_credentials=True,
allow_methods=["*"], # Cho phép tất cả phương thức (POST, GET, etc.)
allow_headers=["*"], # Cho phép tất cả headers
)
# Thiết lập logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Thiết lập thư mục cache cho Hugging Face ngay đầu file
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
os.environ["XDG_CACHE_HOME"] = "/tmp/huggingface_cache"
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
logger.info(f"HF_HOME set to {os.environ['HF_HOME']}")
app = FastAPI()
# Tạo thư mục lưu video trong /tmp
output_dir = "/tmp/outputs"
os.makedirs(output_dir, exist_ok=True)
# Constants
bases = {
"Cartoon": "frankjoshua/toonyou_beta6",
"Realistic": "emilianJR/epiCRealism",
"3d": "Lykon/DreamShaper",
"Anime": "Yntec/mistoonAnime2"
}
step_loaded = None
base_loaded = "Realistic"
motion_loaded = None
# Thiết lập thiết bị CPU và kiểu dữ liệu
device = "cpu"
dtype = torch.float32
# Khởi tạo pipeline
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(
pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
)
pipe.safety_checker = None
# Mô hình dữ liệu cho request
class VideoRequest(BaseModel):
prompt: str
base: str = "Realistic"
motion: str = ""
step: int = 1
# Endpoint tạo video
@app.post("/generate_video")
async def generate_video(request: VideoRequest):
global step_loaded, base_loaded, motion_loaded
prompt = request.prompt
base = request.base
motion = request.motion
step = request.step
logger.info(f"Tạo video với prompt: {prompt}, base: {base}, motion: {motion}, steps: {step}")
try:
# Kiểm tra base hợp lệ
if base not in bases:
raise HTTPException(status_code=400, detail="Base model không hợp lệ")
# Tải AnimateDiff Lightning checkpoint
if step_loaded != step:
repo = "ByteDance/AnimateDiff-Lightning"
ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
step_loaded = step
# Tải mô hình cơ sở nếu thay đổi
if base_loaded != base:
pipe.unet.load_state_dict(
torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
strict=False
)
base_loaded = base
# Tải motion LoRA nếu có
if motion_loaded != motion:
try:
pipe.unload_lora_weights()
if motion:
pipe.load_lora_weights(motion, adapter_name="motion")
pipe.set_adapters(["motion"], [0.7])
motion_loaded = motion
except Exception as e:
logger.warning(f"Không thể tải motion LoRA: {e}")
motion_loaded = ""
# Suy luận
with torch.no_grad():
output = pipe(
prompt=prompt,
guidance_scale=1.2,
num_inference_steps=step,
num_frames=32,
width=256,
height=256
)
# Chuẩn hóa khung hình cho 8 giây
frames = output.frames[0]
fps = 24
target_frames = fps * 8
if len(frames) < target_frames:
frames = np.tile(frames, (target_frames // len(frames) + 1, 1, 1, 1))[:target_frames]
else:
frames = frames[:target_frames]
# Tạo video
name = str(uuid.uuid4()).replace("-", "")
video_path = os.path.join(output_dir, f"{name}.mp4")
export_to_video(frames, video_path, fps=fps)
if not os.path.exists(video_path):
raise FileNotFoundError("Video không được tạo")
logger.info(f"Video sẵn sàng tại {video_path}")
# Trả về file video
return FileResponse(video_path, media_type="video/mp4", filename=f"{name}.mp4")
except Exception as e:
logger.error(f"Lỗi khi tạo video: {e}")
raise HTTPException(status_code=500, detail=str(e))
# Endpoint kiểm tra trạng thái
@app.get("/")
async def root():
return {"message": "FastAPI AnimateDiff-Lightning API on Hugging Face Spaces"}
if __name__ == "__main__":
port = int(os.environ.get("PORT", 7860)) # Nếu PORT không có thì mặc định 7860
uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True)