File size: 2,301 Bytes
eb04acb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import os
import sys
import time

import torch

# Đưa project vào sys.path
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))
]
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

from cogvideox.api.api import (
    infer_forward_api,
    update_diffusion_transformer_api,
    update_edition_api
)
from cogvideox.ui.controller import flow_scheduler_dict
from cogvideox.ui.wan_fun_ui import ui, ui_eas, ui_modelscope

if __name__ == "__main__":
    # Chọn UI mode: "ui", "eas" hoặc "modelscope"
    ui_mode = "eas"

    # Các cấu hình bộ nhớ GPU
    GPU_memory_mode = "model_cpu_offload"
    weight_dtype = torch.bfloat16  # hoặc torch.float16
    config_path = "config/wan2.1/wan_civitai.yaml"

    # Cho modelscope mode
    model_name = "models/Diffusion_Transformer/Wan2.1-Fun-1.3B-InP"
    model_type = "Inpaint"
    savedir_sample = "samples"

    # Khởi tạo demo & controller tùy theo mode
    if ui_mode == "modelscope":
        demo, controller = ui_modelscope(
            model_name, model_type, savedir_sample,
            GPU_memory_mode, flow_scheduler_dict,
            weight_dtype, config_path
        )
    elif ui_mode == "eas":
        demo, controller = ui_eas(
            model_name, flow_scheduler_dict,
            savedir_sample, config_path
        )
    else:
        demo, controller = ui(
            GPU_memory_mode, flow_scheduler_dict,
            weight_dtype, config_path
        )

    # Launch Gradio trên Colab: use share=True để Gradio tự sinh URL public
    app, _, _ = demo.queue(status_update_rate=1).launch(
        share=True,               # <--- bật share, Gradio sẽ cung cấp link
        prevent_thread_lock=True  # giữ process không bị block
    )

    # Khởi API endpoints
    infer_forward_api(None, app, controller)
    update_diffusion_transformer_api(None, app, controller)
    update_edition_api(None, app, controller)

    # Giữ Python process chạy
    while True:
        time.sleep(5)