#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import random
import argparse
import urllib.request
import uuid
import torch
import gradio as gr
from omegaconf import OmegaConf
import av
import numpy as np
from PIL import Image
# --- 1. Install flash-attn only when a GPU is available ---
if torch.cuda.is_available():
    import subprocess
    print("► GPU detected. Installing flash-attn…")
    subprocess.run(
        'pip install -q flash-attn --no-build-isolation',
        shell=True,
        check=True
    )
else:
    print("⚠️ No GPU found, skipping flash-attn installation.")
# --- 2. Download weights from the Hugging Face Hub ---
from huggingface_hub import snapshot_download, hf_hub_download
snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir="wan_models/Wan2.1-T2V-1.3B",
    resume_download=True,
    repo_type="model"
)
hf_hub_download(
    repo_id="gdhe17/Self-Forcing",
    filename="checkpoints/self_forcing_dmd.pt",
    local_dir="."
)
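# The DMD checkpoint lands at ./checkpoints/self_forcing_dmd.pt,
# matching the default --checkpoint_path defined below.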
# --- 3. Imports core pipeline & VAE ---
from pipeline import CausalInferencePipeline
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae_block3 import VAEDecoderWrapper
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
# --- 4. Imports from transformers (only AutoTokenizer / AutoModelForCausalLM are needed) ---
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- 5. Pick a shared device & dtype ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype_common = torch.float16 if device.type == "cuda" else torch.float32
print(f"► Using device={device}, dtype={dtype_common}")
# --- 6. Tokenizer + Qwen model used to rewrite/optimize prompts ---
MODEL_CKPT = "Qwen/Qwen3-8B"
tokenizer = AutoTokenizer.from_pretrained(MODEL_CKPT)
model_kwargs = {"torch_dtype": dtype_common}
if device.type == "cuda":
    model_kwargs.update({
        "attn_implementation": "flash_attention_2",
        "device_map": "auto",
    })
model = AutoModelForCausalLM.from_pretrained(MODEL_CKPT, **model_kwargs)
if device.type != "cuda":
    # With device_map="auto" the model is already dispatched by accelerate,
    # so only move it manually on CPU.
    model = model.to(device)
def simple_generate(text: str) -> str:
    """Generate text with model.generate (instead of a transformers pipeline)."""
    inputs = tokenizer(text, return_tensors="pt").to(device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        repetition_penalty=1.2,
        do_sample=False,
        pad_token_id=tokenizer.eos_token_id
    )
    # Decode only the newly generated tokens, not the echoed prompt.
    new_tokens = outputs[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)
# --- 7. System prompt for cinematic T2V prompt rewriting ---
T2V_CINEMATIC_PROMPT = (
    "You are a prompt engineer, aiming to rewrite user inputs into high-quality prompts "
    "for better video generation without affecting the original meaning.\n"
    "Task requirements:\n"
    "1. For overly concise user inputs, infer and add details to make the video more complete...\n"
    "2. Enhance main features (appearance, expression, posture, style...)\n"
    "3. Output in English, preserving quotes/titles\n"
    "4. Match the user's intent and style\n"
    "5. Emphasize motion & camera movements\n"
    "6. Add natural actions with simple verbs\n"
    "7. Length ~80–100 words\n"
    "I will now provide the prompt. Please rewrite accordingly without extra text."
)
def enhance_prompt(prompt: str) -> str:
    msgs = [
        {"role": "system", "content": T2V_CINEMATIC_PROMPT},
        {"role": "user", "content": prompt}
    ]
    text_in = tokenizer.apply_chat_template(
        msgs,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False
    )
    return simple_generate(text_in).strip()
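# enhance_prompt() is wired to the "✨ Optimize Prompt" button below and
# rewrites the textbox contents in place.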
# --- 8. Argument parsing ---
parser = argparse.ArgumentParser(
    description="Gradio demo for Self-Forcing video generation (streaming)"
)
parser.add_argument('--port', type=int, default=7860)
parser.add_argument('--host', type=str, default='0.0.0.0')
parser.add_argument('--share', action='store_true')
parser.add_argument('--checkpoint_path', type=str,
                    default='./checkpoints/self_forcing_dmd.pt')
parser.add_argument('--config_path', type=str,
                    default='./configs/self_forcing_dmd.yaml')
parser.add_argument('--trt', action='store_true',
                    help="Use the TensorRT VAE decoder")
parser.add_argument('--fps', type=float, default=15.0)
args = parser.parse_args()
# --- 9. Load the Self-Forcing config ---
try:
    cfg_user = OmegaConf.load(args.config_path)
    cfg_def = OmegaConf.load("configs/default_config.yaml")
    config = OmegaConf.merge(cfg_def, cfg_user)
except FileNotFoundError as e:
    print(f"[ERROR] Config not found: {e}")
    sys.exit(1)
# --- 10. Initialize the text encoder & transformer ---
print("► Initializing the Self-Forcing models…")
text_encoder = WanTextEncoder()\
    .eval().to(device=device, dtype=dtype_common).requires_grad_(False)
transformer = WanDiffusionWrapper(is_causal=True)\
    .eval().to(device=device, dtype=dtype_common).requires_grad_(False)
# Load the generator checkpoint (prefer the EMA weights when present)
try:
    sd = torch.load(args.checkpoint_path, map_location="cpu")
    key = 'generator_ema' if 'generator_ema' in sd else 'generator'
    transformer.load_state_dict(sd[key])
except FileNotFoundError as e:
    print(f"[ERROR] Checkpoint not found: {e}")
    sys.exit(1)
# Mutable state tracking the currently active VAE decoder
APP_STATE = {"current_use_taehv": False, "current_vae_decoder": None}
def initialize_vae_decoder(use_taehv=False, use_trt=False):
    """Initialize the VAE decoder: Default / TAEHV / TensorRT."""
    if use_trt:
        from demo_utils.vae import VAETRTWrapper
        vae = VAETRTWrapper()
        APP_STATE["current_use_taehv"] = False
        print("► Using the TensorRT VAE")
    elif use_taehv:
        from demo_utils.taehv import TAEHV
        ckpt = "checkpoints/taew2_1.pth"
        if not os.path.exists(ckpt):
            os.makedirs("checkpoints", exist_ok=True)
            urllib.request.urlretrieve(
                "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth", ckpt
            )

        class TAEHVDiffuser(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.taehv = TAEHV(checkpoint_path=ckpt).to(dtype_common)

            def decode(self, latents, return_dict=None):
                # Rescale from [0, 1] to [-1, 1] to match the default decoder's output range.
                return self.taehv.decode_video(latents).mul_(2).sub_(1)

        vae = TAEHVDiffuser()
        APP_STATE["current_use_taehv"] = True
        print("► Using the TAEHV VAE")
    else:
        vae = VAEDecoderWrapper()
        try:
            sd_vae = torch.load(
                'wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth',
                map_location="cpu"
            )
            # Keep only the decoder weights from the full VAE state dict.
            dd = {k: v for k, v in sd_vae.items()
                  if 'decoder.' in k or 'conv2' in k}
            vae.load_state_dict(dd)
        except FileNotFoundError:
            print("⚠️ Default VAE weights not found.")
        APP_STATE["current_use_taehv"] = False
        print("► Using the default VAE")
    vae.eval().to(device=device, dtype=dtype_common).requires_grad_(False)
    APP_STATE["current_vae_decoder"] = vae
# Create the decoder once at startup
initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
# --- 11. Build the pipeline ---
pipeline = CausalInferencePipeline(
    config=config,
    device=device,
    generator=transformer,
    text_encoder=text_encoder,
    vae=APP_STATE["current_vae_decoder"]
)
pipeline.to(device=device, dtype=dtype_common)
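# The pipeline holds a reference to the decoder created above; swapping
# decoders later would also require reassigning pipeline.vae.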
def frames_to_ts_file(frames, filepath, fps=15):
    """Encode a list of numpy frames into an MPEG-TS (.ts) file for streaming."""
    if not frames:
        return filepath
    h, w = frames[0].shape[:2]
    container = av.open(filepath, mode='w', format='mpegts')
    stream = container.add_stream('h264', rate=fps)
    stream.width, stream.height = w, h
    stream.pix_fmt = 'yuv420p'
    stream.options = {
        'preset': 'ultrafast', 'tune': 'zerolatency',
        'crf': '23', 'profile': 'baseline', 'level': '3.0'
    }
    try:
        for fr in frames:
            vf = av.VideoFrame.from_ndarray(fr, format='rgb24')\
                .reformat(format=stream.pix_fmt)
            for pkt in stream.encode(vf):
                container.mux(pkt)
        # Flush any buffered packets.
        for pkt in stream.encode():
            container.mux(pkt)
    finally:
        container.close()
    return filepath
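# Each block is written as an independent MPEG-TS segment; TS segments can be
# concatenated as-is, which is why gr.Video(streaming=True) can play each
# chunk as soon as it is yielded.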
@torch.no_grad()
def video_generation_handler_streaming(prompt, seed, fps):
    """Generate a video and stream it as one .ts chunk per block."""
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)
    print(f"▶️ Generating video: '{prompt}', seed={seed}")
    cond = text_encoder(text_prompts=[prompt])
    for k, v in cond.items():
        cond[k] = v.to(device=device, dtype=dtype_common)
    rnd = torch.Generator(device=device).manual_seed(seed)
    pipeline._initialize_kv_cache(1, dtype_common, device=device)
    pipeline._initialize_crossattn_cache(1, dtype_common, device=device)
    noise = torch.randn([1, 21, 16, 60, 104],
                        device=device, dtype=dtype_common, generator=rnd)
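    # Noise shape: [batch, latent frames, channels, latent H, latent W];
    # 60x104 latents correspond to 832x480 pixels with the Wan VAE's 8x spatial downsampling.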
    vae_cache, latents_cache = None, None
    if not APP_STATE["current_use_taehv"] and not args.trt:
        vae_cache = [c.to(device=device, dtype=dtype_common)
                     for c in ZERO_VAE_CACHE]
    num_blocks = 7
    start_frame = 0
    total_frames = 0
    os.makedirs("gradio_tmp", exist_ok=True)
    for ib in range(num_blocks):
        nfr = pipeline.num_frame_per_block
        block_noise = noise[:, start_frame:start_frame + nfr]
        # Denoise
        for i_step, t in enumerate(pipeline.denoising_step_list):
            tim = torch.full((1, nfr), t, dtype=torch.long, device=device)
            _, pred = pipeline.generator(
                noisy_image_or_video=block_noise,
                conditional_dict=cond,
                timestep=tim,
                kv_cache=pipeline.kv_cache1,
                crossattn_cache=pipeline.crossattn_cache,
                current_start=start_frame * pipeline.frame_seq_length
            )
            if i_step < len(pipeline.denoising_step_list) - 1:
                nxt = pipeline.denoising_step_list[i_step + 1]
                block_noise = pipeline.scheduler.add_noise(
                    pred.flatten(0, 1),
                    torch.randn_like(pred.flatten(0, 1)),
                    torch.full((nfr,), nxt, device=device, dtype=torch.long)
                ).unflatten(0, pred.shape[:2])
        # Decode the block's latents to pixel frames
        if args.trt:
            pixels, vae_cache = pipeline.vae.forward(pred.half(), *vae_cache)
        elif APP_STATE["current_use_taehv"]:
            if latents_cache is None:
                latents_cache = pred
            else:
                pred = torch.cat([latents_cache, pred], dim=1)
                latents_cache = pred[:, -3:]
            pixels = pipeline.vae.decode(pred)
        else:
            pixels, vae_cache = pipeline.vae(pred.half(), *vae_cache)
        # Skip the leading frames of the block if needed
        if ib == 0 and not args.trt:
            pixels = pixels[:, 3:]
        elif APP_STATE["current_use_taehv"] and ib > 0:
            pixels = pixels[:, 12:]
        # Convert to uint8 numpy frames (HWC, RGB)
        frames = []
        for f in range(pixels.shape[1]):
            img = pixels[0, f]
            arr = ((img.clamp(-1, 1).float() * 127.5 + 127.5)
                   .to(torch.uint8).cpu().numpy())
            arr = np.transpose(arr, (1, 2, 0))
            frames.append(arr)
            total_frames += 1
            prog = (ib + (f + 1) / pixels.shape[1]) / num_blocks * 100
            yield None, gr.HTML(
                f"<div style='padding:10px;border:1px solid #ddd;"
                f"border-radius:8px;font-family:sans-serif'>"
                f"<strong>Generating… {prog:.1f}%</strong>"
                f"</div>"
            )
        # Write this block as a .ts chunk
        ts_file = f"block_{ib:02d}_{uuid.uuid4().hex[:8]}.ts"
        ts_path = os.path.join("gradio_tmp", ts_file)
        frames_to_ts_file(frames, ts_path, fps)
        yield ts_path, gr.update()
        start_frame += nfr
    # Finished
    yield None, gr.HTML(
        f"<div style='padding:16px;border:1px solid #198754;"
        f"background:#d1e7dd;border-radius:8px'>"
        f"<h4>✅ Done! Generated {total_frames} frames.</h4>"
        f"</div>"
    )
    print("▶️ Streaming complete.")
# --- 12. Gradio UI ---
with gr.Blocks(title="Self-Forcing Streaming Demo") as demo:
    gr.Markdown("# 🚀 Real-time Self-Forcing video generation")
    gr.Markdown("Enter a prompt and press 'Start streaming' to generate a video.")
    with gr.Row():
        with gr.Column(scale=2):
            prompt = gr.Textbox(
                label="✏️ Prompt", lines=4,
                placeholder="Example: A girl in a conical hat walking across a rice field..."
            )
            btn_enh = gr.Button("✨ Optimize Prompt", variant="secondary")
            btn_start = gr.Button("🎬 Start streaming", size="lg")
            gr.Markdown("### 🎯 Examples")
            gr.Examples(
                examples=[
                    "A close-up shot of a ceramic teacup slowly pouring water into a glass mug.",
                    "A playful cat is seen playing an electronic guitar...",
                    "A dynamic over-the-shoulder perspective of a chef plating a dish..."
                ],
                inputs=[prompt]
            )
            gr.Markdown("### ⚙️ Settings")
            seed = gr.Number(
                label="Seed", value=-1, precision=0,
                info="Use -1 for a random seed"
            )
            fps = gr.Slider(
                label="Playback FPS", minimum=1, maximum=30,
                value=args.fps, step=1
            )
        with gr.Column(scale=3):
            gr.Markdown("### 📺 Live video stream")
            vid_stream = gr.Video(streaming=True, loop=True,
                                  autoplay=True, height=400)
            status_el = gr.HTML(
                "<div style='text-align:center;color:#666;"
                "border:1px dashed #ddd;padding:20px;'>"
                "🎬 Ready…</div>"
            )
    btn_enh.click(fn=enhance_prompt, inputs=[prompt], outputs=[prompt])
    btn_start.click(
        fn=video_generation_handler_streaming,
        inputs=[prompt, seed, fps],
        outputs=[vid_stream, status_el]
    )
# --- 13. Run the app ---
if __name__ == "__main__":
    # Clean up the temporary chunk directory
    if os.path.exists("gradio_tmp"):
        import shutil
        shutil.rmtree("gradio_tmp")
    os.makedirs("gradio_tmp", exist_ok=True)
    print(f"► Starting the server on {args.host}:{args.port}, share={args.share}")
    demo.queue().launch(
        server_name=args.host,
        server_port=args.port,
        share=args.share
    )