hoangkha1810 commited on
Commit
36ea805
·
verified ·
1 Parent(s): 86aa991

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +138 -0
main.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ```python
2
+ import os
3
+ import uuid
4
+ import logging
5
+ import torch
6
+ import numpy as np
7
+ from fastapi import FastAPI, HTTPException
8
+ from pydantic import BaseModel
9
+ from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler
10
+ from diffusers.utils import export_to_video
11
+ from huggingface_hub import hf_hub_download
12
+ from safetensors.torch import load_file
13
+ from fastapi.responses import FileResponse
14
+
15
+ # Thiết lập logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
+
19
+ app = FastAPI()
20
+
21
+ # Tạo thư mục lưu video
22
+ output_dir = "/app/outputs"
23
+ os.makedirs(output_dir, exist_ok=True)
24
+
25
+ # Constants
26
+ bases = {
27
+ "Cartoon": "frankjoshua/toonyou_beta6",
28
+ "Realistic": "emilianJR/epiCRealism",
29
+ "3d": "Lykon/DreamShaper",
30
+ "Anime": "Yntec/mistoonAnime2"
31
+ }
32
+ step_loaded = None
33
+ base_loaded = "Realistic"
34
+ motion_loaded = None
35
+
36
+ # Thiết lập thiết bị CPU và kiểu dữ liệu
37
+ device = "cpu"
38
+ dtype = torch.float32
39
+
40
+ # Khởi tạo pipeline
41
+ pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device)
42
+ pipe.scheduler = EulerDiscreteScheduler.from_config(
43
+ pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
44
+ )
45
+ pipe.safety_checker = None
46
+
47
+ # Mô hình dữ liệu cho request
48
+ class VideoRequest(BaseModel):
49
+ prompt: str
50
+ base: str = "Realistic"
51
+ motion: str = ""
52
+ step: int = 1
53
+
54
+ # Endpoint tạo video
55
+ @app.post("/generate_video")
56
+ async def generate_video(request: VideoRequest):
57
+ global step_loaded, base_loaded, motion_loaded
58
+ prompt = request.prompt
59
+ base = request.base
60
+ motion = request.motion
61
+ step = request.step
62
+
63
+ logger.info(f"Tạo video với prompt: {prompt}, base: {base}, motion: {motion}, steps: {step}")
64
+
65
+ try:
66
+ # Kiểm tra base hợp lệ
67
+ if base not in bases:
68
+ raise HTTPException(status_code=400, detail="Base model không hợp lệ")
69
+
70
+ # Tải AnimateDiff Lightning checkpoint
71
+ if step_loaded != step:
72
+ repo = "ByteDance/AnimateDiff-Lightning"
73
+ ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
74
+ pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
75
+ step_loaded = step
76
+
77
+ # Tải mô hình cơ sở nếu thay đổi
78
+ if base_loaded != base:
79
+ pipe.unet.load_state_dict(
80
+ torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
81
+ strict=False
82
+ )
83
+ base_loaded = base
84
+
85
+ # Tải motion LoRA nếu có
86
+ if motion_loaded != motion:
87
+ try:
88
+ pipe.unload_lora_weights()
89
+ if motion:
90
+ pipe.load_lora_weights(motion, adapter_name="motion")
91
+ pipe.set_adapters(["motion"], [0.7])
92
+ motion_loaded = motion
93
+ except Exception as e:
94
+ logger.warning(f"Không thể tải motion LoRA: {e}")
95
+ motion_loaded = ""
96
+
97
+ # Suy luận
98
+ with torch.no_grad():
99
+ output = pipe(
100
+ prompt=prompt,
101
+ guidance_scale=1.2,
102
+ num_inference_steps=step,
103
+ num_frames=32,
104
+ width=256,
105
+ height=256
106
+ )
107
+
108
+ # Chuẩn hóa khung hình cho 8 giây
109
+ frames = output.frames[0]
110
+ fps = 24
111
+ target_frames = fps * 8
112
+ if len(frames) < target_frames:
113
+ frames = np.tile(frames, (target_frames // len(frames) + 1, 1, 1, 1))[:target_frames]
114
+ else:
115
+ frames = frames[:target_frames]
116
+
117
+ # Tạo video
118
+ name = str(uuid.uuid4()).replace("-", "")
119
+ video_path = os.path.join(output_dir, f"{name}.mp4")
120
+ export_to_video(frames, video_path, fps=fps)
121
+
122
+ if not os.path.exists(video_path):
123
+ raise FileNotFoundError("Video không được tạo")
124
+
125
+ logger.info(f"Video sẵn sàng tại {video_path}")
126
+
127
+ # Trả về file video
128
+ return FileResponse(video_path, media_type="video/mp4", filename=f"{name}.mp4")
129
+
130
+ except Exception as e:
131
+ logger.error(f"Lỗi khi tạo video: {e}")
132
+ raise HTTPException(status_code=500, detail=str(e))
133
+
134
+ # Endpoint kiểm tra trạng thái
135
+ @app.get("/")
136
+ async def root():
137
+ return {"message": "FastAPI AnimateDiff-Lightning API on Hugging Face Spaces"}
138
+ ```