import gradio as gr
import torch
import os
import spaces
import uuid

from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image

# Custom CSS passed to gr.Blocks(css=...) below.
# Most selectors here (.container, .title, .subtitle, .warning, .info,
# .gr-form, .gr-input, .gr-dropdown, .gr-button, .example-container) are
# attached explicitly in the UI via gr.HTML markup or elem_classes.
# NOTE(review): .gradio-container is presumably Gradio's own root container
# class — confirm against the installed Gradio version.
custom_css = """
    .container {
        max-width: 1000px;
        margin: auto;
        padding: 20px;
    }
    
    .title {
        background: linear-gradient(90deg, #00ff87 0%, #60efff 100%);
        -webkit-background-clip: text;
        -webkit-text-fill-color: transparent;
        font-size: 2.5em;
        text-align: center;
        margin-bottom: 1em;
        font-weight: bold;
        text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
    }
    
    .subtitle {
        color: #666;
        text-align: center;
        margin-bottom: 2em;
        font-size: 1.2em;
    }
    
    .warning {
        color: #ff4b4b;
        font-weight: bold;
        text-align: center;
        padding: 10px;
        margin: 10px 0;
        border-radius: 5px;
        background: rgba(255,75,75,0.1);
    }
    
    .info {
        color: #4b8bff;
        text-align: center;
        padding: 10px;
        margin: 10px 0;
        border-radius: 5px;
        background: rgba(75,139,255,0.1);
    }
    
    .gradio-container {
        background: linear-gradient(135deg, #1a1a1a 0%, #2a2a2a 100%);
    }
    
    .gr-button {
        background: linear-gradient(90deg, #00ff87 0%, #60efff 100%);
        border: none;
        color: black;
        font-weight: bold;
    }
    
    .gr-button:hover {
        background: linear-gradient(90deg, #60efff 0%, #00ff87 100%);
        transform: translateY(-2px);
        box-shadow: 0 5px 15px rgba(0,255,135,0.3);
        transition: all 0.3s ease;
    }
    
    .gr-input, .gr-dropdown {
        border: 2px solid rgba(96,239,255,0.2);
        border-radius: 8px;
        background: rgba(26,26,26,0.9);
        color: white;
    }
    
    .gr-input:focus, .gr-dropdown:focus {
        border-color: #00ff87;
        box-shadow: 0 0 10px rgba(0,255,135,0.3);
    }
    
    .gr-form {
        background: rgba(42,42,42,0.8);
        border-radius: 15px;
        padding: 20px;
        box-shadow: 0 8px 32px rgba(0,0,0,0.3);
    }
    
    .example-container {
        background: rgba(255,255,255,0.05);
        border-radius: 10px;
        padding: 15px;
        margin: 10px 0;
    }
"""

# Accounts accepted by login(): username -> plaintext password.
# NOTE(review): credentials are hard-coded in source; consider loading them
# from an environment variable / deployment secret instead.
USERS = dict(
    admin="svip",
    svip="svip8888",
)

# Base-model choices offered in the UI: display label -> Hugging Face repo id.
bases = dict(
    [
        ("卡通风格", "frankjoshua/toonyou_beta6"),
        ("写实风格", "emilianJR/epiCRealism"),
        ("3D风格", "Lykon/DreamShaper"),
        ("动漫风格", "Yntec/mistoonAnime2"),
    ]
)

# What the shared pipeline currently has loaded. generate_image() compares
# each request against these and only reloads weights on a mismatch.
# None means "nothing loaded yet"; the pipeline itself is built from
# base_loaded at startup.
step_loaded = motion_loaded = None
base_loaded = "写实风格"

# Model setup runs at import time and refuses to start without CUDA.
if not torch.cuda.is_available():
    raise NotImplementedError("未检测到GPU!")

device, dtype = "cuda", torch.float16

# Build the AnimateDiff pipeline from the default base model in fp16 on the
# GPU, then replace its scheduler with an Euler scheduler configured for
# "trailing" timestep spacing and a linear beta schedule.
pipe = AnimateDiffPipeline.from_pretrained(
    bases[base_loaded],
    torch_dtype=dtype,
).to(device)
pipe.scheduler = EulerDiscreteScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
    beta_schedule="linear",
)

# CLIP preprocessor originally labelled "safety checkers". It is not wired
# into the pipeline and is not referenced elsewhere in the visible source.
# NOTE(review): possibly dead code that costs a model download at startup —
# confirm nothing else uses feature_extractor before removing it.
from transformers import CLIPFeatureExtractor
feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32")

def login(username, password, users=None):
    """Check a username/password pair against the configured accounts.

    Args:
        username: Submitted username.
        password: Submitted password; anything other than a ``str`` is
            rejected.
        users: Optional mapping of username -> password to check against.
            Defaults to the module-level ``USERS``.

    Returns:
        ``(True, success_message)`` when the credentials match, otherwise
        ``(False, failure_message)``. Unknown users and wrong passwords get
        the same message, so the response does not reveal which was wrong.
    """
    import hmac

    if users is None:
        users = USERS

    expected = users.get(username)
    if expected is None or not isinstance(password, str):
        return False, "用户名或密码错误!"

    # hmac.compare_digest avoids a timing side channel, unlike ==, which
    # short-circuits on the first differing character. Encode to UTF-8 bytes
    # because compare_digest raises TypeError for non-ASCII str arguments.
    if hmac.compare_digest(expected.encode("utf-8"), password.encode("utf-8")):
        return True, "登录成功!"
    return False, "用户名或密码错误!"

@spaces.GPU(duration=30,queue=False)
def generate_image(prompt, base="写实风格", motion="", step=8, progress=gr.Progress()):
    """Generate a short video for *prompt* with AnimateDiff-Lightning.

    The shared module-level ``pipe`` is reconfigured lazily. Weights are only
    (re)loaded when the requested step count, base model or motion LoRA
    differs from what the ``step_loaded`` / ``base_loaded`` /
    ``motion_loaded`` globals say is currently loaded. Those globals are
    updated after each successful load.

    Args:
        prompt: Text prompt describing the video.
        base: Key into ``bases`` selecting which base model's UNet to load.
        motion: Hub repo id of a motion LoRA, or "" for no LoRA.
        step: Number of inference steps. It also selects the matching
            ``animatediff_lightning_{step}step`` checkpoint.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path of the exported ``.mp4`` file in the system temp directory.
    """
    global step_loaded
    global base_loaded
    global motion_loaded

    import tempfile

    print(prompt, base, motion, step)

    # Lightning checkpoints are step-specific (the file name embeds the step
    # count), so changing the step count means loading another checkpoint.
    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        # strict=False tolerates keys missing from / extra to the checkpoint.
        pipe.unet.load_state_dict(load_file(hf_hub_download(repo, ckpt), device=device), strict=False)
        step_loaded = step

    if base_loaded != base:
        # Only pipe.unet receives new weights here; the other pipeline
        # components keep whatever they were created with.
        # NOTE(review): torch.load unpickles the downloaded .bin, which can
        # execute arbitrary code if the repo is compromised. Prefer a
        # .safetensors file, or torch.load(..., weights_only=True) on torch
        # >= 1.13.
        unet_path = hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin")
        pipe.unet.load_state_dict(torch.load(unet_path, map_location=device), strict=False)
        base_loaded = base

    if motion_loaded != motion:
        # Drop any previously loaded LoRA before (optionally) loading the new one.
        pipe.unload_lora_weights()
        if motion != "":
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])  # adapter weight 0.7
        motion_loaded = motion

    progress((0, step))

    def progress_callback(i, t, z):
        # Invoked every step (callback_steps=1); i is presumably the
        # zero-based step index, hence i + 1 completed steps.
        progress((i + 1, step))

    output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step, callback=progress_callback, callback_steps=1)

    # Use the platform temp dir instead of a hard-coded "/tmp" so this also
    # works where /tmp does not exist and honors TMPDIR.
    path = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}.mp4")
    export_to_video(output.frames[0], path, fps=10)
    return path

# Gradio interface: a login page plus a main workspace that starts hidden and
# is shown after a successful login.
with gr.Blocks(css=custom_css) as demo:
    # Per-session login flag, kept on the server. Toggling `visible` only
    # changes what the browser renders; the workspace's event handlers are
    # registered either way. Generation therefore checks this flag itself
    # instead of relying on the hidden container.
    is_authenticated = gr.State(False)

    # Login page (initially visible).
    with gr.Group(visible=True) as login_container:
        gr.HTML("""
            <div class="container">
                <h1 class="title">🌟 OfficeChatAI 视频生成系统</h1>
                <p class="subtitle">欢迎使用AI视频生成系统,让创意转化为现实</p>
            </div>
        """)
        with gr.Group(elem_classes="gr-form"):
            username = gr.Textbox(label="用户名", placeholder="请输入VIP用户名")
            password = gr.Textbox(label="密码", type="password", placeholder="请输入密码")
            login_button = gr.Button("登 录", variant="primary")
            login_msg = gr.Textbox(label="登录状态", interactive=False)

    # Main workspace (initially hidden).
    with gr.Group(visible=False) as main_container:
        gr.HTML("""
            <div class="container">
                <h1 class="title">🎬 OfficeChatAI 视频生成工作室</h1>
                <p class="subtitle">专业的AI视频生成平台 | VIP尊享服务</p>
                <div class="warning">提示:首次生成视频需要较长时间,后续生成速度会显著提升</div>
                <div class="info">为获得最佳效果,建议使用英文提示词,参考示例格式</div>
            </div>
        """)
        with gr.Group(elem_classes="gr-form"):
            with gr.Row():
                prompt = gr.Textbox(
                    label='创作提示词',
                    placeholder='请输入您想要生成的视频场景描述...',
                    elem_classes="gr-input"
                )
            with gr.Row():
                select_base = gr.Dropdown(
                    label='选择基础模型',
                    choices=[
                        "卡通风格", 
                        "写实风格",
                        "3D风格",
                        "动漫风格",
                    ],
                    value=base_loaded,
                    interactive=True,
                    elem_classes="gr-dropdown"
                )
                # (label, value) pairs; the value is the LoRA repo id passed
                # to generate_image, or "" for no motion LoRA.
                select_motion = gr.Dropdown(
                    label='动作特效',
                    choices=[
                        ("默认效果", ""),
                        ("镜头拉近", "guoyww/animatediff-motion-lora-zoom-in"),
                        ("镜头拉远", "guoyww/animatediff-motion-lora-zoom-out"),
                        ("向上倾斜", "guoyww/animatediff-motion-lora-tilt-up"),
                        ("向下倾斜", "guoyww/animatediff-motion-lora-tilt-down"),
                        ("向左平移", "guoyww/animatediff-motion-lora-pan-left"),
                        ("向右平移", "guoyww/animatediff-motion-lora-pan-right"),
                        ("逆时针旋转", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
                        ("顺时针旋转", "guoyww/animatediff-motion-lora-rolling-clockwise"),
                    ],
                    value="guoyww/animatediff-motion-lora-zoom-in",
                    interactive=True,
                    elem_classes="gr-dropdown"
                )
                # (label, value) pairs; the value is the inference step count.
                select_step = gr.Dropdown(
                    label='生成质量',
                    choices=[
                        ('快速模式(1步)', 1), 
                        ('平衡模式(2步)', 2),
                        ('高质量(4步)', 4),
                        ('超高清(8步)', 8),
                    ],
                    value=4,
                    interactive=True,
                    elem_classes="gr-dropdown"
                )
            submit = gr.Button(
                value="✨ 开始生成",
                scale=1,
                variant="primary",
                elem_classes=["gr-button"]
            )

        video = gr.Video(
            label='创作结果',
            autoplay=True,
            height=512,
            width=512,
            elem_id="video_output",
            elem_classes="output-video"
        )

        with gr.Group(elem_classes="example-container"):
            gr.HTML("<h3 class='subtitle'>🎯 创作灵感</h3>")
            # NOTE(review): example runs call generate_image directly and are
            # not gated by is_authenticated. Only these fixed prompts (with
            # default settings) can be generated that way.
            gr.Examples(
                examples=[
                    ["A majestic Eiffel Tower with moving clouds in the background"], 
                    ["A lion running through a dense forest"],
                    ["An astronaut floating in space with stars twinkling"],
                    ["A flock of birds flying in formation against a blue sky"],
                    ["Statue of Liberty viewed from a approaching drone"],
                    ["A cute panda drinking tea in a bamboo forest"],
                    ["Children playing in the snow"],
                    ["Cars driving on a rainy city street"]
                ], 
                fn=generate_image,
                inputs=[prompt],
                outputs=[video],
                cache_examples="lazy",
            )

        def generate_if_authenticated(prompt, base, motion, step, authenticated, progress=gr.Progress()):
            """Run generate_image only for sessions that have logged in.

            Raises:
                gr.Error: if this session has not passed handle_login.
            """
            if not authenticated:
                raise gr.Error("请先登录!")
            return generate_image(prompt, base, motion, step, progress=progress)

        # Generate button: gated by the per-session login flag. api_name
        # keeps the public endpoint name stable.
        submit.click(
            fn=generate_if_authenticated,
            inputs=[prompt, select_base, select_motion, select_step, is_authenticated],
            outputs=[video],
            api_name="generate_image",
        )

    # Login: swap which container is visible and record the session's
    # login state.
    def handle_login(username, password):
        success, message = login(username, password)
        if success:
            return message, gr.update(visible=False), gr.update(visible=True), True
        return message, gr.update(visible=True), gr.update(visible=False), False

    login_button.click(
        fn=handle_login,
        inputs=[username, password],
        outputs=[login_msg, login_container, main_container, is_authenticated]
    )

# Enable the request queue on the Blocks app, then start the web server.
# Blocks.queue() returns the same app, so this matches chaining the calls.
demo.queue()
demo.launch()