hoangkha1810 committed on
Commit b9db3fe · verified · 1 Parent(s): 7e1663e

Create app.py

Files changed (1):
  1. app.py +164 -0
app.py ADDED
import gradio as gr
import torch
import os
import uuid
from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler, MotionAdapter
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file

# Constants
bases = {
    "Cartoon": "frankjoshua/toonyou_beta6",
    "Realistic": "emilianJR/epiCRealism",
    "3d": "Lykon/DreamShaper",
    "Anime": "Yntec/mistoonAnime2"
}
step_loaded = None
base_loaded = "Realistic"
motion_loaded = None

# Device and dtype setup for CPU inference
device = "cpu"
dtype = torch.float32  # use float32 rather than float16 on CPU

# Initialize the pipeline. AnimateDiffPipeline requires a motion_adapter
# component; a fresh MotionAdapter is created here and its weights are
# overwritten by the AnimateDiff-Lightning checkpoint in generate_image().
adapter = MotionAdapter().to(device, dtype)
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], motion_adapter=adapter, torch_dtype=dtype).to(device)
# Trailing timestep spacing and a linear beta schedule match the
# AnimateDiff-Lightning distillation setup.
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear"
)
# Disable the safety checker to speed up inference
pipe.safety_checker = None

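# Optional CPU memory saving (a sketch, assuming this diffusers version exposes
# enable_attention_slicing on AnimateDiffPipeline); uncomment on RAM-constrained hosts:
# pipe.enable_attention_slicing()
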
# Video generation function
def generate_image(prompt, base="Realistic", motion="", step=1, progress=gr.Progress()):
    global step_loaded, base_loaded, motion_loaded
    step = int(step)
    print(f"Generating video with prompt: {prompt}, base: {base}, steps: {step}")

    # Load the AnimateDiff-Lightning checkpoint for the requested step count
    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        step_loaded = step

    # Load the base model
    if base_loaded != base:
        pipe.unet.load_state_dict(
            torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device),
            strict=False
        )
        base_loaded = base

    # Load the motion LoRA (optional)
    if motion_loaded != motion:
        pipe.unload_lora_weights()
        if motion != "":
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion

    progress((0, step))
    def progress_callback(i, t, z):
        progress((i + 1, step))

    # Inference; gradients disabled to save memory. callback/callback_steps is
    # the older diffusers step-callback API (deprecated in newer releases in
    # favor of callback_on_step_end).
    with torch.no_grad():
        output = pipe(
            prompt=prompt,
            guidance_scale=1.2,
            num_inference_steps=step,
            callback=progress_callback,
            callback_steps=1
        )

    # Export the video
    name = str(uuid.uuid4()).replace("-", "")
    path = f"/tmp/{name}.mp4"
    export_to_video(output.frames[0], path, fps=10)
    return path

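# Optional local smoke test (hypothetical, not part of the original app): calls
# the generator directly, bypassing the UI. gr.Progress() does nothing outside
# a Gradio event, and CPU inference can take several minutes per clip.
# if __name__ == "__main__":
#     print(generate_image("Focus: Astronaut in Space", base="Realistic", step=2))
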
# Gradio interface
css = """
body {font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background-color: #f4f4f9; color: #333;}
h1 {color: #333; text-align: center; margin-bottom: 20px;}
.gradio-container {max-width: 800px; margin: auto; padding: 20px; background: #fff; box-shadow: 0px 0px 20px rgba(0,0,0,0.1); border-radius: 10px;}
.gr-input {margin-bottom: 15px;}
.gr-button {width: 100%; background-color: #4CAF50; color: white; border: none; padding: 10px 20px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; border-radius: 5px; cursor: pointer; transition: background-color 0.3s;}
.gr-button:hover {background-color: #45a049;}
.gr-video {margin-top: 20px;}
.gr-examples {margin-top: 30px;}
.gr-examples .gr-example {display: inline-block; width: 100%; text-align: center; padding: 10px; background: #eaeaea; border-radius: 5px; margin-bottom: 10px;}
.container {display: flex; flex-wrap: wrap;}
.inputs, .output {padding: 20px;}
.inputs {flex: 1; min-width: 300px;}
.output {flex: 1; min-width: 300px;}
@media (max-width: 768px) {
    .container {flex-direction: column-reverse;}
}
.svelte-1ybb3u7, .svelte-1clup3e {display: none !important;}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML("<h1>Instant⚡ Text to Video</h1>")
    with gr.Row(elem_id="container"):
        with gr.Column(elem_id="inputs"):
            prompt = gr.Textbox(label="Prompt", placeholder="Enter text to generate video...", elem_id="gr-input")
            select_base = gr.Dropdown(
                label="Base model",
                choices=["Cartoon", "Realistic", "3d", "Anime"],
                value=base_loaded,
                interactive=True,
                elem_id="gr-input"
            )
            select_motion = gr.Dropdown(
                label="Motion",
                choices=[
                    ("Default", ""),
                    ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
                    ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
                    ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
                    ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
                    ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
                    ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
                    ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
                    ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
                ],
                value="guoyww/animatediff-motion-lora-zoom-in",
                interactive=True,
                elem_id="gr-input"
            )
            select_step = gr.Dropdown(
                label="Inference steps",
                choices=[("1-Step", 1), ("2-Step", 2), ("4-Step", 4), ("8-Step", 8)],
                value=1,
                interactive=True,
                elem_id="gr-input"
            )
            submit = gr.Button("Generate Video", variant="primary", elem_id="gr-button")
        with gr.Column(elem_id="output"):
            video = gr.Video(label="AnimateDiff-Lightning", autoplay=True, height=512, width=512, elem_id="gr-video")

    prompt.submit(fn=generate_image, inputs=[prompt, select_base, select_motion, select_step], outputs=video)
    submit.click(fn=generate_image, inputs=[prompt, select_base, select_motion, select_step], outputs=video)

    gr.Examples(
        examples=[
            ["Focus: Eiffel Tower (Animate: Clouds moving)"],
            ["Focus: Trees In forest (Animate: Lion running)"],
            ["Focus: Astronaut in Space"],
            ["Focus: Group of Birds in sky (Animate: Birds Moving) (Shot From distance)"],
            ["Focus: Statue of liberty (Shot from Drone) (Animate: Drone coming toward statue)"],
            ["Focus: Panda in Forest (Animate: Drinking Tea)"],
            ["Focus: Kids Playing (Season: Winter)"],
            ["Focus: Cars in Street (Season: Rain, Daytime) (Shot from Distance) (Movement: Cars running)"]
        ],
        fn=generate_image,
        inputs=[prompt],
        outputs=video,
        cache_examples=False,
        elem_id="gr-examples"
    )

demo.launch()
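# Note (assumption about Gradio versioning): progress bars need the request
# queue; recent Gradio releases enable it by default, while older 3.x versions
# require launching with demo.queue().launch() instead.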