Spaces:
Running
Running
File size: 5,284 Bytes
36ea805 e66d877 36ea805 23a42f6 e66d877 e3516bd e66d877 e3516bd e66d877 36ea805 824f888 e3516bd 824f888 c4b2683 36ea805 e3516bd 36ea805 e3516bd 36ea805 e3516bd 36ea805 e3516bd 36ea805 23a42f6 e3516bd 45939dd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import logging
import os
import threading
import uuid

import numpy as np
import torch
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from huggingface_hub import hf_hub_download
from pydantic import BaseModel
from safetensors.torch import load_file
app = FastAPI()
# Fully permissive CORS: every origin, method and header is accepted.
# Replace "*" in allow_origins with an explicit list of domains to restrict access.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Logging configuration
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hugging Face cache directory under /tmp.
# NOTE(review): huggingface_hub (imported above, directly and via diffusers)
# presumably resolves its cache paths from these variables when it is imported,
# so assigning them at this point may not redirect hub downloads -- confirm, or
# set them before those imports.
os.environ.update(
    HF_HOME="/tmp/huggingface_cache",
    XDG_CACHE_HOME="/tmp/huggingface_cache",
)
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
logger.info(f"HF_HOME set to {os.environ['HF_HOME']}")
# The FastAPI application is created (and CORS-configured) once, earlier in
# this module. It is deliberately NOT re-created here: rebinding `app` would
# discard the CORS middleware, and every route would be registered on a bare app.

# Directory where generated videos are written (under /tmp).
output_dir = "/tmp/outputs"
os.makedirs(output_dir, exist_ok=True)
# Constants
# Selectable base checkpoints: style label -> Hugging Face repo id.
bases = dict([
    ("Cartoon", "frankjoshua/toonyou_beta6"),
    ("Realistic", "emilianJR/epiCRealism"),
    ("3d", "Lykon/DreamShaper"),
    ("Anime", "Yntec/mistoonAnime2"),
])

# Bookkeeping of what the shared pipeline currently holds.
# None means nothing has been applied yet.
base_loaded = "Realistic"
step_loaded = motion_loaded = None
# Run on CPU in full float32 precision.
device = "cpu"
dtype = torch.float32

# Build the shared pipeline from the default base checkpoint.
pipe = AnimateDiffPipeline.from_pretrained(
    bases[base_loaded],
    torch_dtype=dtype,
    # huggingface_hub fixes its default cache location when it is first
    # imported, which in this module happens before HF_HOME is assigned.
    # Pass the intended location explicitly ($HF_HOME/hub, the library's own
    # default layout) unless HF_HUB_CACHE was provided by the environment.
    cache_dir=os.environ.get("HF_HUB_CACHE", os.path.join(os.environ["HF_HOME"], "hub")),
).to(device)

# Euler scheduler with "trailing" timestep spacing and a linear beta schedule.
# This is the configuration shown on the AnimateDiff-Lightning model card.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
)

# Clear any safety checker on the pipeline. This is presumably a no-op for
# AnimateDiffPipeline -- confirm.
pipe.safety_checker = None
# Request body schema
class VideoRequest(BaseModel):
    """JSON body accepted by POST /generate_video."""

    prompt: str  # text prompt passed to the pipeline
    base: str = "Realistic"  # must be one of the keys of `bases`
    motion: str = ""  # passed to pipe.load_lora_weights; "" means no motion LoRA
    step: int = 1  # inference steps; also selects the Lightning checkpoint
# Step counts for which ByteDance/AnimateDiff-Lightning publishes
# `animatediff_lightning_{N}step_diffusers.safetensors` checkpoints.
LIGHTNING_STEPS = (1, 2, 4, 8)

# Guards the shared `pipe` and the *_loaded bookkeeping globals. The endpoint
# below is a plain `def`, so FastAPI runs it in a worker thread pool. Concurrent
# requests must therefore not swap weights or run inference at the same time.
_pipe_lock = threading.Lock()


def _hub_cache_dir():
    """Return the Hub cache directory to download into.

    huggingface_hub computes its default cache when first imported, which in
    this module happens before HF_HOME is assigned. The path is therefore
    passed explicitly: HF_HUB_CACHE if the environment provides it, otherwise
    $HF_HOME/hub.
    """
    return os.environ.get("HF_HUB_CACHE", os.path.join(os.environ["HF_HOME"], "hub"))


def _ensure_lightning_step(step):
    """Load the AnimateDiff-Lightning UNet weights for `step`, unless already loaded."""
    global step_loaded
    if step_loaded == step:
        return
    repo = "ByteDance/AnimateDiff-Lightning"
    ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
    ckpt_path = hf_hub_download(repo, ckpt, cache_dir=_hub_cache_dir())
    # strict=False tolerates keys that the checkpoint and the UNet do not share.
    pipe.unet.load_state_dict(load_file(ckpt_path, device=device), strict=False)
    step_loaded = step


def _ensure_base(base):
    """Load the UNet weights of base model `base`, unless it is already the current one.

    Only the UNet is swapped; the text encoder and VAE stay those of the base
    the pipeline was originally built from.
    """
    global base_loaded
    if base_loaded == base:
        return
    weights_path = hf_hub_download(
        bases[base], "unet/diffusion_pytorch_model.bin", cache_dir=_hub_cache_dir()
    )
    # The .bin file is a pickle downloaded from the Hub. weights_only=True
    # refuses to unpickle anything other than tensors and plain containers.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    # strict=False leaves parameters missing from this checkpoint untouched
    # (presumably including the Lightning weights loaded by
    # _ensure_lightning_step -- confirm).
    pipe.unet.load_state_dict(state_dict, strict=False)
    base_loaded = base


def _ensure_motion(motion):
    """Apply the motion LoRA named by `motion` ("" = none), best effort.

    On failure, a warning is logged and generation proceeds without a LoRA.
    """
    global motion_loaded
    if motion_loaded == motion:
        return
    try:
        pipe.unload_lora_weights()
        if motion:
            # NOTE(review): `motion` comes straight from the request, so clients
            # choose which LoRA gets downloaded and loaded. Consider an allow-list.
            pipe.load_lora_weights(motion, adapter_name="motion", cache_dir=_hub_cache_dir())
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion
    except Exception as e:
        logger.warning(f"Không thể tải motion LoRA: {e}")
        motion_loaded = ""


def _loop_frames(frames, target_frames):
    """Return exactly `target_frames` frames, cycling through `frames` as needed.

    The frames are kept as a plain list of their original objects. Converting
    the pipeline's PIL output with np.tile yields a uint8 array, which
    diffusers' export_to_video treats as floats in [0, 1] and multiplies by 255.
    That overflows uint8 and corrupts the colors.

    Raises:
        ValueError: if `frames` is empty.
    """
    frames = list(frames)
    if not frames:
        raise ValueError("pipeline returned no frames")
    return [frames[i % len(frames)] for i in range(target_frames)]


# Endpoint that generates a video
@app.post("/generate_video")
def generate_video(request: VideoRequest):
    """Generate an 8-second MP4 for `request.prompt` and return it as a file.

    The endpoint is a plain `def`, not `async def`, so FastAPI runs the
    CPU-bound inference in its thread pool. As `async def`, it blocked the
    event loop, and every other route stalled until generation finished.

    Raises:
        HTTPException: 400 for an unknown base or unsupported step count;
            500 for any failure while loading weights or generating the video.
    """
    prompt = request.prompt
    base = request.base
    motion = request.motion
    step = request.step
    logger.info(f"Tạo video với prompt: {prompt}, base: {base}, motion: {motion}, steps: {step}")

    # Validate outside the generic handler below, so client errors stay 400s
    # instead of being re-raised as 500s.
    if base not in bases:
        raise HTTPException(status_code=400, detail="Base model không hợp lệ")
    if step not in LIGHTNING_STEPS:
        raise HTTPException(
            status_code=400, detail=f"step must be one of {list(LIGHTNING_STEPS)}"
        )

    try:
        with _pipe_lock:
            _ensure_lightning_step(step)
            _ensure_base(base)
            _ensure_motion(motion)
            with torch.no_grad():
                output = pipe(
                    prompt=prompt,
                    guidance_scale=1.2,
                    num_inference_steps=step,
                    num_frames=32,
                    width=256,
                    height=256,
                )

        # Fill 8 seconds at 24 fps by looping (or truncating) the generated clip.
        fps = 24
        frames = _loop_frames(output.frames[0], fps * 8)

        name = uuid.uuid4().hex
        video_path = os.path.join(output_dir, f"{name}.mp4")
        export_to_video(frames, video_path, fps=fps)
        if not os.path.exists(video_path):
            raise FileNotFoundError("Video không được tạo")
        logger.info(f"Video sẵn sàng tại {video_path}")
        # NOTE(review): files in output_dir are never removed. Consider deleting
        # them after the response is sent.
        return FileResponse(video_path, media_type="video/mp4", filename=f"{name}.mp4")
    except HTTPException:
        # Never downgrade a deliberate HTTP error to a 500.
        raise
    except Exception as e:
        logger.exception(f"Lỗi khi tạo video: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# Status-check endpoint
@app.get("/")
async def root():
    """Return a static message confirming the API is reachable."""
    payload = {"message": "FastAPI AnimateDiff-Lightning API on Hugging Face Spaces"}
    return payload
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly, so it is
    # imported here rather than at module level.
    import uvicorn

    # Default to port 7860 when PORT is not set.
    port = int(os.environ.get("PORT", 7860))
    # reload=True needs an "module:attribute" import string. Derive the module
    # from this file's name instead of hard-coding "main".
    # NOTE(review): with reload enabled, the reloader's worker imports the
    # module again, so the pipeline above is presumably loaded twice -- confirm
    # whether reload is wanted outside development.
    module_name = os.path.splitext(os.path.basename(__file__))[0]
    uvicorn.run(f"{module_name}:app", host="0.0.0.0", port=port, reload=True)