import gradio as gr import torch import os import uuid import logging import importlib.util import numpy as np from diffusers import AnimateDiffPipeline, EulerDiscreteScheduler from diffusers.utils import export_to_video from huggingface_hub import hf_hub_download from safetensors.torch import load_file # Thiết lập logging logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) # Kiểm tra backend def check_backend(): if importlib.util.find_spec("imageio") and importlib.util.find_spec("imageio_ffmpeg"): logger.info("imageio và imageio-ffmpeg đã được cài đặt. Sử dụng backend khuyến nghị.") else: logger.error("Yêu cầu imageio và imageio-ffmpeg để xuất video. Cài đặt bằng: pip install imageio imageio-ffmpeg") raise ImportError("Thiếu imageio hoặc imageio-ffmpeg. Cài đặt bằng: pip install imageio imageio-ffmpeg") check_backend() # Tạo thư mục lưu video trong không gian làm việc của Spaces output_dir = "/home/user/app/outputs" os.makedirs(output_dir, exist_ok=True) # Constants bases = { "Cartoon": "frankjoshua/toonyou_beta6", "Realistic": "emilianJR/epiCRealism", "3d": "Lykon/DreamShaper", "Anime": "Yntec/mistoonAnime2" } step_loaded = None base_loaded = "Realistic" motion_loaded = None # Thiết lập thiết bị CPU và kiểu dữ liệu device = "cpu" dtype = torch.float32 # Khởi tạo pipeline pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], torch_dtype=dtype).to(device) pipe.scheduler = EulerDiscreteScheduler.from_config( pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear" ) pipe.safety_checker = None # Hàm tạo video def generate_image(prompt, base="Realistic", motion="", step=1, progress=gr.Progress()): global step_loaded, base_loaded, motion_loaded step = int(step) logger.info(f"Tạo video với prompt: {prompt}, base: {base}, steps: {step}") try: # Tải AnimateDiff Lightning checkpoint if step_loaded != step: repo = "ByteDance/AnimateDiff-Lightning" ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors" pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False) step_loaded = step # Tải mô hình cơ sở nếu thay đổi if base_loaded != base: pipe.unet.load_state_dict( torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device), strict=False ) base_loaded = base # Tải motion LoRA nếu có if motion_loaded != motion: try: pipe.unload_lora_weights() if motion: pipe.load_lora_weights(motion, adapter_name="motion") pipe.set_adapters(["motion"], [0.7]) motion_loaded = motion except Exception as e: logger.warning(f"Không thể tải motion LoRA: {e}") motion_loaded = "" progress((0, step)) def progress_callback(i, t, z): progress((i + 1, step)) # Suy luận with torch.no_grad(): output = pipe( prompt=prompt, guidance_scale=1.2, num_inference_steps=step, num_frames=32, callback=progress_callback, callback_steps=1, width=256, height=256 ) # Chuẩn hóa khung hình cho 8 giây frames = output.frames[0] fps = 24 target_frames = fps * 8 if len(frames) < target_frames: frames = np.tile(frames, (target_frames // len(frames) + 1, 1, 1, 1))[:target_frames] else: frames = frames[:target_frames] # Tạo video name = str(uuid.uuid4()).replace("-", "") video_path = os.path.join(output_dir, f"{name}.mp4") export_to_video(frames, video_path, fps=fps) if not os.path.exists(video_path): raise FileNotFoundError("❌ Video không được tạo") logger.info(f"✅ Video sẵn sàng tại {video_path}") # Trả về gr.File để Gradio tạo URL công khai return gr.File(video_path) except Exception as e: logger.error(f"❌ Lỗi khi tạo video: {e}") raise # Giao diện Gradio css = """ body {font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background-color: #f4f4f9; color: #333;} h1 {color: #333; text-align: center; margin-bottom: 20px;} .gradio-container {max-width: 800px; margin: auto; padding: 20px; background: #fff; box-shadow: 0px 0px 20px rgba(0,0,0,0.1); border-radius: 10px;} .gr-input {margin-bottom: 15px;} .gr-button {width: 100%; background-color: #4CAF50; color: white; border: none; padding: 10px 20px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; border-radius: 5px; cursor: pointer; transition: background-color 0.3s;} .gr-button:hover {background-color: #45a049;} .gr-video {margin-top: 20px;} .gr-examples {margin-top: 30px;} .gr-examples .gr-example {display: inline-block; width: 100%; text-align: center; padding: 10px; background: #eaeaea; border-radius: 5px; margin-bottom: 10px;} .container {display: flex; flex-wrap: wrap;} .inputs, .output {padding: 20px;} .inputs {flex: 1; min-width: 300px;} .output {flex: 1; min-width: 300px;} @media (max-width: 768px) { .container {flex-direction: column-reverse;} } .svelte-1ybb3u7, .svelte-1clup3e {display: none !important;} """ with gr.Blocks(css=css) as demo: gr.HTML("