|
|
|
|
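# Gradio streaming demo: Self-Forcing text-to-video generation on top of the
# Wan2.1-T2V-1.3B backbone, with Qwen3-8B used for optional prompt enhancement.
#
# Typical launch commands (the filename "app.py" is an assumption, not fixed by
# this script):
#   python app.py --port 7860 --host 0.0.0.0 --fps 15
#   python app.py --share   # additionally create a public Gradio link
#   python app.py --trt     # use the TensorRT VAE decoder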
|
|
|
import os |
|
import sys |
|
import random |
|
import argparse |
|
import urllib.request |
|
import uuid |
|
|
|
import torch |
|
import gradio as gr |
|
from omegaconf import OmegaConf |
|
import av |
|
import numpy as np |
|
from PIL import Image |
|
|
|
|
|
if torch.cuda.is_available(): |
|
import subprocess |
|
    print("► GPU detected. Installing flash-attn…")
|
subprocess.run( |
|
'pip install -q flash-attn --no-build-isolation', |
|
shell=True, |
|
check=True |
|
) |
|
else: |
|
    print("⚠️ No GPU found, skipping flash-attn install.")
|
|
|
|
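# Download the Wan2.1-T2V-1.3B base weights and the Self-Forcing DMD checkpoint
# from the Hugging Face Hub (cached files are reused on subsequent runs).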
|
from huggingface_hub import snapshot_download, hf_hub_download |
|
|
|
snapshot_download( |
|
repo_id="Wan-AI/Wan2.1-T2V-1.3B", |
|
local_dir="wan_models/Wan2.1-T2V-1.3B", |
|
resume_download=True, |
|
repo_type="model" |
|
) |
|
|
|
hf_hub_download( |
|
repo_id="gdhe17/Self-Forcing", |
|
filename="checkpoints/self_forcing_dmd.pt", |
|
local_dir="." |
|
) |
|
|
|
|
|
from pipeline import CausalInferencePipeline |
|
from demo_utils.constant import ZERO_VAE_CACHE |
|
from demo_utils.vae_block3 import VAEDecoderWrapper |
|
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder |
|
|
|
|
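# Qwen3-8B is used only for prompt enhancement: it rewrites short user prompts
# into richer cinematic descriptions before video generation.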
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
dtype_common = torch.float16 if device.type == "cuda" else torch.float32 |
|
print(f"► Using device={device}, dtype={dtype_common}")
|
|
|
|
|
MODEL_CKPT = "Qwen/Qwen3-8B" |
|
tokenizer = AutoTokenizer.from_pretrained(MODEL_CKPT) |
|
|
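# Enable FlashAttention-2 and automatic device placement only when a GPU is
# available; on CPU the model falls back to the default attention implementation.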
|
model_kwargs = {"torch_dtype": dtype_common} |
|
if device.type == "cuda": |
|
model_kwargs.update({ |
|
"attn_implementation": "flash_attention_2", |
|
"device_map": "auto", |
|
}) |
|
|
|
model = AutoModelForCausalLM.from_pretrained(MODEL_CKPT, **model_kwargs) |
|
if device.type != "cuda":
    # device_map="auto" has already placed the model on the GPU(s); moving a
    # dispatched model again can conflict with accelerate, so only move it on CPU.
    model.to(device)
|
|
|
def simple_generate(text: str) -> str:
    """Generate text with model.generate (rather than a transformers pipeline)."""
|
inputs = tokenizer(text, return_tensors="pt").to(device) |
|
outputs = model.generate( |
|
**inputs, |
|
max_new_tokens=256, |
|
repetition_penalty=1.2, |
|
do_sample=False, |
|
pad_token_id=tokenizer.eos_token_id |
|
) |
|
return tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
|
|
|
|
T2V_CINEMATIC_PROMPT = ( |
|
"You are a prompt engineer, aiming to rewrite user inputs into high-quality prompts " |
|
"for better video generation without affecting the original meaning.\n" |
|
"Task requirements:\n" |
|
"1. For overly concise user inputs, infer and add details to make the video more complete...\n" |
|
"2. Enhance main features (appearance, expression, posture, style...)\n" |
|
"3. Output in English, preserving quotes/titles\n" |
|
"4. Match user’s intent and style\n" |
|
"5. Emphasize motion & camera movements\n" |
|
"6. Add natural actions with simple verbs\n" |
|
"7. Length ~80–100 words\n" |
|
"I will now provide the prompt. Please rewrite accordingly without extra văn bản." |
|
) |
|
|
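# Wrap the system instruction above and the user's prompt in Qwen's chat
# template; enable_thinking=False suppresses Qwen3's reasoning trace so the
# model returns only the rewritten prompt.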
|
def enhance_prompt(prompt: str) -> str: |
|
msgs = [ |
|
{"role": "system", "content": T2V_CINEMATIC_PROMPT}, |
|
{"role": "user", "content": prompt} |
|
] |
|
text_in = tokenizer.apply_chat_template( |
|
msgs, |
|
tokenize=False, |
|
add_generation_prompt=True, |
|
enable_thinking=False |
|
) |
|
return simple_generate(text_in).strip() |
|
|
|
|
|
parser = argparse.ArgumentParser( |
|
    description="Gradio demo for Self-Forcing video generation (streaming)"
|
) |
|
parser.add_argument('--port', type=int, default=7860) |
|
parser.add_argument('--host', type=str, default='0.0.0.0') |
|
parser.add_argument('--share', action='store_true') |
|
parser.add_argument('--checkpoint_path', type=str, |
|
default='./checkpoints/self_forcing_dmd.pt') |
|
parser.add_argument('--config_path', type=str, |
|
default='./configs/self_forcing_dmd.yaml') |
|
parser.add_argument('--trt', action='store_true', |
|
                    help="Use the TensorRT VAE decoder")
|
parser.add_argument('--fps', type=float, default=15.0) |
|
args = parser.parse_args() |
|
|
|
|
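# Load the run config and overlay it on the repository-wide defaults.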
|
try: |
|
cfg_user = OmegaConf.load(args.config_path) |
|
cfg_def = OmegaConf.load("configs/default_config.yaml") |
|
config = OmegaConf.merge(cfg_def, cfg_user) |
|
except FileNotFoundError as e: |
|
    print(f"[ERROR] Config file not found: {e}")
|
sys.exit(1) |
|
|
|
|
|
print("► Initializing Self-Forcing models…")
|
text_encoder = WanTextEncoder()\ |
|
.eval().to(device=device, dtype=dtype_common).requires_grad_(False) |
|
transformer = WanDiffusionWrapper(is_causal=True)\ |
|
.eval().to(device=device, dtype=dtype_common).requires_grad_(False) |
|
|
|
|
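# Load the distilled Self-Forcing generator weights, preferring the EMA copy
# when the checkpoint contains one.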
|
try: |
|
sd = torch.load(args.checkpoint_path, map_location="cpu") |
|
key = 'generator_ema' if 'generator_ema' in sd else 'generator' |
|
transformer.load_state_dict(sd[key]) |
|
except FileNotFoundError as e: |
|
    print(f"[ERROR] Checkpoint not found: {e}")
|
sys.exit(1) |
|
|
|
|
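# Module-level state: which VAE decoder variant is active and its instance.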
|
APP_STATE = {"current_use_taehv": False, "current_vae_decoder": None} |
|
|
|
def initialize_vae_decoder(use_taehv=False, use_trt=False):
    """Initialize the VAE decoder: Default / TAEHV / TensorRT."""
|
if use_trt: |
|
from demo_utils.vae import VAETRTWrapper |
|
vae = VAETRTWrapper() |
|
APP_STATE["current_use_taehv"] = False |
|
        print("► Using TensorRT VAE")
|
elif use_taehv: |
|
from demo_utils.taehv import TAEHV |
|
ckpt = "checkpoints/taew2_1.pth" |
|
if not os.path.exists(ckpt): |
|
os.makedirs("checkpoints", exist_ok=True) |
|
urllib.request.urlretrieve( |
|
"https://github.com/madebyollin/taehv/raw/main/taew2_1.pth", ckpt |
|
) |
|
class TAEHVDiffuser(torch.nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
self.taehv = TAEHV(checkpoint_path=ckpt).to(dtype_common) |
|
def decode(self, latents, return_dict=None): |
|
return self.taehv.decode_video(latents).mul_(2).sub_(1) |
|
vae = TAEHVDiffuser() |
|
APP_STATE["current_use_taehv"] = True |
|
        print("► Using TAEHV VAE")
|
else: |
|
vae = VAEDecoderWrapper() |
|
try: |
|
sd_vae = torch.load( |
|
'wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', |
|
map_location="cpu" |
|
) |
|
dd = {k: v for k, v in sd_vae.items() |
|
if 'decoder.' in k or 'conv2' in k} |
|
vae.load_state_dict(dd) |
|
except FileNotFoundError: |
|
            print("⚠️ Default VAE weights not found.")
|
APP_STATE["current_use_taehv"] = False |
|
        print("► Using default VAE")
|
vae.eval().to(device=device, dtype=dtype_common).requires_grad_(False) |
|
APP_STATE["current_vae_decoder"] = vae |
|
|
|
|
|
initialize_vae_decoder(use_taehv=False, use_trt=args.trt) |
|
|
|
|
|
pipeline = CausalInferencePipeline( |
|
config=config, |
|
device=device, |
|
generator=transformer, |
|
text_encoder=text_encoder, |
|
vae=APP_STATE["current_vae_decoder"] |
|
) |
|
pipeline.to(device=device, dtype=dtype_common) |
|
|
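# MPEG-TS is used for the streamed chunks because each .ts segment is
# self-contained, so gr.Video(streaming=True) can start playback as soon as the
# first block has been encoded.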
|
def frames_to_ts_file(frames, filepath, fps=15):
    """Convert a list of NumPy frames into an MPEG-TS (.ts) file for streaming."""
|
if not frames: |
|
return filepath |
|
h, w = frames[0].shape[:2] |
|
container = av.open(filepath, mode='w', format='mpegts') |
|
stream = container.add_stream('h264', rate=fps) |
|
stream.width, stream.height = w, h |
|
stream.pix_fmt = 'yuv420p' |
|
    stream.options = {
        'preset': 'ultrafast', 'tune': 'zerolatency',
        'crf': '23', 'profile': 'baseline', 'level': '3.0'
    }
|
try: |
|
for fr in frames: |
|
vf = av.VideoFrame.from_ndarray(fr, format='rgb24')\ |
|
.reformat(format=stream.pix_fmt) |
|
for pkt in stream.encode(vf): |
|
container.mux(pkt) |
|
for pkt in stream.encode(): |
|
container.mux(pkt) |
|
finally: |
|
container.close() |
|
return filepath |
|
|
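# Streaming handler for the Gradio click event: yields (video_chunk_path,
# status_update) pairs so finished blocks can play while later blocks are still
# being denoised.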
|
@torch.no_grad() |
|
def video_generation_handler_streaming(prompt, seed, fps):
    """Generate the video block by block, streaming each chunk as a .ts file."""
|
if seed == -1: |
|
seed = random.randint(0, 2**32 - 1) |
|
    print(f"▶️ Generating video: '{prompt}', seed={seed}")
|
|
|
cond = text_encoder(text_prompts=[prompt]) |
|
for k, v in cond.items(): |
|
cond[k] = v.to(device=device, dtype=dtype_common) |
|
|
|
rnd = torch.Generator(device=device).manual_seed(seed) |
|
pipeline._initialize_kv_cache(1, dtype_common, device=device) |
|
pipeline._initialize_crossattn_cache(1, dtype_common, device=device) |
|
noise = torch.randn([1,21,16,60,104], |
|
device=device, dtype=dtype_common, generator=rnd) |
|
|
|
vae_cache, latents_cache = None, None |
|
if not APP_STATE["current_use_taehv"] and not args.trt: |
|
vae_cache = [c.to(device=device, dtype=dtype_common) |
|
for c in ZERO_VAE_CACHE] |
|
|
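    # The latent video (noise shape [1, 21, 16, 60, 104]) is generated causally,
    # one block of pipeline.num_frame_per_block latent frames at a time.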
|
num_blocks = 7 |
|
start_frame = 0 |
|
total_frames = 0 |
|
os.makedirs("gradio_tmp", exist_ok=True) |
|
|
|
for ib in range(num_blocks): |
|
nfr = pipeline.num_frame_per_block |
|
block_noise = noise[:, start_frame:start_frame+nfr] |
|
|
|
|
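        # Few-step distilled denoising: run the generator at each timestep in
        # denoising_step_list, re-noising the prediction to the next timestep
        # between steps.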
|
for i_step, t in enumerate(pipeline.denoising_step_list): |
|
tim = torch.full((1,nfr), t, dtype=torch.long, device=device) |
|
_, pred = pipeline.generator( |
|
noisy_image_or_video=block_noise, |
|
conditional_dict=cond, |
|
timestep=tim, |
|
kv_cache=pipeline.kv_cache1, |
|
crossattn_cache=pipeline.crossattn_cache, |
|
current_start=start_frame * pipeline.frame_seq_length |
|
) |
|
if i_step < len(pipeline.denoising_step_list)-1: |
|
nxt = pipeline.denoising_step_list[i_step+1] |
|
block_noise = pipeline.scheduler.add_noise( |
|
pred.flatten(0,1), |
|
torch.randn_like(pred.flatten(0,1)), |
|
torch.full((nfr,), nxt, device=device, dtype=torch.long) |
|
).unflatten(0, pred.shape[:2]) |
|
|
|
|
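        # Decode this block's latents with the active VAE decoder. The default
        # and TensorRT decoders thread a causal cache across blocks; TAEHV is
        # instead given the last few latent frames of the previous block as context.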
|
if args.trt: |
|
pixels, vae_cache = pipeline.vae.forward(pred.half(), *vae_cache) |
|
elif APP_STATE["current_use_taehv"]: |
|
if latents_cache is None: |
|
latents_cache = pred |
|
else: |
|
pred = torch.cat([latents_cache, pred], dim=1) |
|
latents_cache = pred[:, -3:] |
|
pixels = pipeline.vae.decode(pred) |
|
else: |
|
            pixels, vae_cache = pipeline.vae(pred.to(dtype_common), *vae_cache)  # keep dtype consistent (float32 on CPU)
|
|
|
|
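        # Trim warm-up / overlapping frames from the decoded block so that
        # consecutive blocks join without duplicated frames.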
|
if ib == 0 and not args.trt: |
|
pixels = pixels[:, 3:] |
|
elif APP_STATE["current_use_taehv"] and ib > 0: |
|
pixels = pixels[:, 12:] |
|
|
|
|
|
frames = [] |
|
for f in range(pixels.shape[1]): |
|
img = pixels[0, f] |
|
arr = ((img.clamp(-1,1).float()*127.5+127.5) |
|
.to(torch.uint8).cpu().numpy()) |
|
arr = np.transpose(arr, (1,2,0)) |
|
frames.append(arr) |
|
total_frames += 1 |
|
|
|
prog = (ib + (f+1)/pixels.shape[1]) / num_blocks * 100 |
|
            yield None, gr.HTML(
                f"<div style='padding:10px;border:1px solid #ddd;"
                f"border-radius:8px;font-family:sans-serif'>"
                f"<strong>Generating… {prog:.1f}%</strong>"
                f"</div>"
            )
|
|
|
|
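        # Encode this block's frames as a standalone .ts segment and hand the
        # path to Gradio for immediate playback.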
|
ts_file = f"block_{ib:02d}_{uuid.uuid4().hex[:8]}.ts" |
|
ts_path = os.path.join("gradio_tmp", ts_file) |
|
frames_to_ts_file(frames, ts_path, fps) |
|
yield ts_path, gr.update() |
|
start_frame += nfr |
|
|
|
|
|
    yield None, gr.HTML(
        f"<div style='padding:16px;border:1px solid #198754;"
        f"background:#d1e7dd;border-radius:8px'>"
        f"<h4>✅ Done! Generated {total_frames} frames.</h4>"
        f"</div>"
    )
|
    print("▶️ Streaming finished.")
|
|
|
|
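# Gradio UI: prompt input with an enhancement button on the left, streaming
# video player and status panel on the right.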
|
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
|
    gr.Markdown("# 🚀 Real-time Self-Forcing video generation")
|
    gr.Markdown("Enter a prompt and press 'Start streaming' to generate a video.")
|
|
|
with gr.Row(): |
|
with gr.Column(scale=2): |
|
prompt = gr.Textbox( |
|
                label="✏️ Prompt", lines=4,
                placeholder="Example: A girl wearing a conical hat walking across a rice field..."
|
) |
|
            btn_enh = gr.Button("✨ Enhance Prompt", variant="secondary")
|
            btn_start = gr.Button("🎬 Start streaming", size="lg")
|
|
|
            gr.Markdown("### 🎯 Examples")
|
gr.Examples( |
|
examples=[ |
|
"A close-up shot of a ceramic teacup slowly pouring water into a glass mug.", |
|
"A playful cat is seen playing an electronic guitar...", |
|
"A dynamic over-the-shoulder perspective of a chef plating a dish..." |
|
], |
|
inputs=[prompt] |
|
) |
|
|
|
            gr.Markdown("### ⚙️ Settings")
|
seed = gr.Number( |
|
                label="Seed", value=-1, precision=0,
                info="Use -1 for a random seed"
|
) |
|
fps = gr.Slider( |
|
                label="Playback FPS", minimum=1, maximum=30,
|
value=args.fps, step=1 |
|
) |
|
|
|
with gr.Column(scale=3): |
|
            gr.Markdown("### 📺 Live video stream")
|
            vid_stream = gr.Video(streaming=True, loop=True,
|
autoplay=True, height=400) |
|
            status_el = gr.HTML(
                "<div style='text-align:center;color:#666;"
                "border:1px dashed #ddd;padding:20px;'>"
                "🎬 Ready…</div>"
|
) |
|
|
|
btn_enh.click(fn=enhance_prompt, inputs=[prompt], outputs=[prompt]) |
|
btn_start.click( |
|
fn=video_generation_handler_streaming, |
|
inputs=[prompt, seed, fps], |
|
outputs=[vid_stream, status_el] |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
if os.path.exists("gradio_tmp"): |
|
        import shutil
        shutil.rmtree("gradio_tmp")
|
os.makedirs("gradio_tmp", exist_ok=True) |
|
|
|
    print(f"► Starting server on {args.host}:{args.port}, share={args.share}")
|
demo.queue().launch( |
|
server_name=args.host, |
|
server_port=args.port, |
|
share=args.share |
|
) |
|
|