File size: 2,797 Bytes
b45529f
c8eea54
 
 
 
 
 
 
041221d
5562446
041221d
5562446
041221d
5562446
 
 
 
 
 
 
 
 
 
 
 
041221d
 
5562446
 
041221d
 
c8eea54
b45529f
 
5562446
041221d
5562446
 
 
 
 
 
 
 
 
 
 
 
 
 
041221d
 
5562446
041221d
 
b45529f
041221d
b45529f
 
 
 
 
 
 
 
 
 
 
 
ad7a082
041221d
b45529f
c8eea54
041221d
5562446
041221d
5562446
041221d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import gradio as gr
import torch
from diffusers.utils import export_to_video
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
import os
from uuid import uuid4

# Select the compute device: prefer CUDA when a GPU is visible, else CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on {device}...")

# Load the model only once during startup.
pipe = None  # bound on successful load; stays None if loading fails outright
try:
    print("Loading model...")
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    # VAE is loaded in float32 while the rest of the pipeline runs in
    # bfloat16 (mixed precision to save memory on the main transformer).
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    # Flow-matching scheduler configuration for Wan2.1 T2V.
    scheduler = UniPCMultistepScheduler(
        prediction_type='flow_prediction',
        use_flow_sigmas=True,
        num_train_timesteps=1000,
        flow_shift=5.0
    )
    pipe = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
    pipe.scheduler = scheduler
    pipe.to(device)  # Move model to GPU or CPU based on availability
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    # BUG FIX: the original unconditionally called `pipe.to("cpu")` here,
    # but if from_pretrained itself raised, `pipe` was never assigned and
    # the handler crashed with NameError, hiding the real load error.
    # Only attempt the CPU fallback when the pipeline object exists.
    device = "cpu"
    if pipe is not None:
        pipe.to(device)

# Define the generation function
def generate_video(prompt, negative_prompt="", height=720, width=1280, num_frames=81, guidance_scale=5.0):
    """Generate a video from a text prompt and save it as an MP4.

    Args:
        prompt: Text description of the video to generate.
        negative_prompt: Concepts to steer the generation away from.
        height: Output frame height in pixels (coerced to int).
        width: Output frame width in pixels (coerced to int).
        num_frames: Number of frames to generate (coerced to int).
        guidance_scale: Classifier-free guidance strength.

    Returns:
        Path to the saved MP4 on success, or None if generation failed
        (Gradio renders None as an empty output).
    """
    try:
        print(f"Generating video with prompt: {prompt}")
        # gr.Number delivers floats; the pipeline requires integer
        # dimensions and frame counts, so cast defensively here.
        output = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=int(height),
            width=int(width),
            num_frames=int(num_frames),
            guidance_scale=guidance_scale,
        ).frames[0]

        # Ensure the destination directory exists before writing.
        os.makedirs("outputs", exist_ok=True)
        output_path = os.path.join("outputs", f"{uuid4()}.mp4")
        export_to_video(output, output_path, fps=16)

        print(f"Video generated and saved to {output_path}")
        return output_path  # Gradio returns this as downloadable file/video
    except Exception as e:
        # Boundary handler: log and return None so the UI degrades
        # gracefully instead of surfacing a traceback to the user.
        print(f"Error during video generation: {e}")
        return None

# Gradio Interface wiring the generation function to simple form inputs.
# NOTE: the original passed live=True, which re-runs the handler on every
# input change (each keystroke) — ruinous for a long-running GPU video
# generation. Generation now runs only on explicit Submit.
iface = gr.Interface(
    fn=generate_video,
    inputs=[
        gr.Textbox(label="Prompt"),
        gr.Textbox(label="Negative Prompt", value=""),
        gr.Number(label="Height", value=720),
        gr.Number(label="Width", value=1280),
        gr.Number(label="Number of Frames", value=81),
        gr.Number(label="Guidance Scale", value=5.0)
    ],
    outputs=gr.File(label="Generated Video"),
    title="Wan2.1 Video Generator",
    description="Generate realistic videos from text prompts using the Wan2.1 T2V model."
)

# Start the Gradio server: public share link, bound to all interfaces
# on port 7860. Failures (e.g. port in use) are logged rather than
# crashing the process with a traceback.
try:
    iface.launch(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
    )
except Exception as e:
    print(f"Error launching Gradio app: {e}")