File size: 5,284 Bytes
36ea805
 
 
 
 
 
 
 
 
 
 
 
e66d877
36ea805
23a42f6
 
e66d877
 
e3516bd
e66d877
e3516bd
 
e66d877
 
36ea805
 
 
 
824f888
 
 
 
 
 
e3516bd
 
824f888
c4b2683
36ea805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3516bd
36ea805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e3516bd
36ea805
 
 
e3516bd
36ea805
 
e3516bd
 
36ea805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23a42f6
 
 
e3516bd
45939dd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import os
import uuid
import logging
import torch
import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from fastapi.responses import FileResponse
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Cho phép tất cả domain gọi API (hoặc liệt kê domain cụ thể ở đây)
    allow_credentials=True,
    allow_methods=["*"],  # Cho phép tất cả phương thức (POST, GET, etc.)
    allow_headers=["*"],  # Cho phép tất cả headers
)

# Thiết lập logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Thiết lập thư mục cache cho Hugging Face ngay đầu file
os.environ["HF_HOME"] = "/tmp/huggingface_cache"
os.environ["XDG_CACHE_HOME"] = "/tmp/huggingface_cache"
os.makedirs(os.environ["HF_HOME"], exist_ok=True)
logger.info(f"HF_HOME set to {os.environ['HF_HOME']}")

app = FastAPI()

# Tạo thư mục lưu video trong /tmp
output_dir = "/tmp/outputs"
os.makedirs(output_dir, exist_ok=True)

# Constants
bases = {
    "Cartoon": "frankjoshua/toonyou_beta6",
    "Realistic": "emilianJR/epiCRealism",
    "3d": "Lykon/DreamShaper",
    "Anime": "Yntec/mistoonAnime2"
}
step_loaded = None
base_loaded = "Realistic"
motion_loaded = None

# Thiết lập thiết bị CPU và kiểu dữ liệu
device = "cpu"
dtype = torch.float32

# Khởi tạo pipeline
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
)
pipe.safety_checker = None

# Mô hình dữ liệu cho request
class VideoRequest(BaseModel):
    prompt: str
    base: str = "Realistic"
    motion: str = ""
    step: int = 1

# Endpoint tạo video
@app.post("/generate_video")
async def generate_video(request: VideoRequest):
    global step_loaded, base_loaded, motion_loaded
    prompt = request.prompt
    base = request.base
    motion = request.motion
    step = request.step

    logger.info(f"Tạo video với prompt: {prompt}, base: {base}, motion: {motion}, steps: {step}")

    try:
        # Kiểm tra base hợp lệ
        if base not in bases:
            raise HTTPException(status_code=400, detail="Base model không hợp lệ")

        # Tải AnimateDiff Lightning checkpoint
        if step_loaded != step:
            repo = "ByteDance/AnimateDiff-Lightning"
            ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
            pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
            step_loaded = step

        # Tải mô hình cơ sở nếu thay đổi
        if base_loaded != base:
            pipe.unet.load_state_dict(
                torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
                strict=False
            )
            base_loaded = base

        # Tải motion LoRA nếu có
        if motion_loaded != motion:
            try:
                pipe.unload_lora_weights()
                if motion:
                    pipe.load_lora_weights(motion, adapter_name="motion")
                    pipe.set_adapters(["motion"], [0.7])
                motion_loaded = motion
            except Exception as e:
                logger.warning(f"Không thể tải motion LoRA: {e}")
                motion_loaded = ""

        # Suy luận
        with torch.no_grad():
            output = pipe(
                prompt=prompt,
                guidance_scale=1.2,
                num_inference_steps=step,
                num_frames=32,
                width=256,
                height=256
            )

        # Chuẩn hóa khung hình cho 8 giây
        frames = output.frames[0]
        fps = 24
        target_frames = fps * 8
        if len(frames) < target_frames:
            frames = np.tile(frames, (target_frames // len(frames) + 1, 1, 1, 1))[:target_frames]
        else:
            frames = frames[:target_frames]

        # Tạo video
        name = str(uuid.uuid4()).replace("-", "")
        video_path = os.path.join(output_dir, f"{name}.mp4")
        export_to_video(frames, video_path, fps=fps)

        if not os.path.exists(video_path):
            raise FileNotFoundError("Video không được tạo")

        logger.info(f"Video sẵn sàng tại {video_path}")

        # Trả về file video
        return FileResponse(video_path, media_type="video/mp4", filename=f"{name}.mp4")

    except Exception as e:
        logger.error(f"Lỗi khi tạo video: {e}")
        raise HTTPException(status_code=500, detail=str(e))

# Endpoint kiểm tra trạng thái
@app.get("/")
async def root():
    return {"message": "FastAPI AnimateDiff-Lightning API on Hugging Face Spaces"}

if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))  # Nếu PORT không có thì mặc định 7860
    uvicorn.run("main:app", host="0.0.0.0", port=port, reload=True)