import gradio as gr
import numpy as np
import random
import spaces
import torch
import time
import os
from diffusers import DiffusionPipeline
from custom_pipeline import FLUXPipelineWithIntermediateOutputs

# Module-wide constants
MAX_SEED = np.iinfo(np.int32).max  # 2**31 - 1: upper bound for random seeds
MAX_IMAGE_SIZE = 2048  # largest width/height (px) accepted by the UI and generator
DEFAULT_WIDTH, DEFAULT_HEIGHT = 1024, 1024  # default output resolution
DEFAULT_INFERENCE_STEPS = 1  # default number of denoising steps
GPU_DURATION = 15  # duration passed to @spaces.GPU (shortened GPU allocation time)

# Simple Hangul detection helper
def is_korean(text):
    """Return True if *text* contains any Hangul character.

    Two Unicode ranges are checked: Hangul Compatibility Jamo
    (U+3131-U+3163) and precomposed Hangul Syllables (U+AC00-U+D7A3).
    """
    for ch in text:
        if "\u3131" <= ch <= "\u3163" or "\uac00" <= ch <= "\ud7a3":
            return True
    return False

# Model setup
def setup_model():
    """Load the FLUX.1-schnell pipeline in float16 and place it on the GPU.

    Returns:
        A ``FLUXPipelineWithIntermediateOutputs`` instance moved to ``"cuda"``.
    """
    repo_id = "black-forest-labs/FLUX.1-schnell"
    loaded = FLUXPipelineWithIntermediateOutputs.from_pretrained(
        repo_id, torch_dtype=torch.float16
    )
    return loaded.to("cuda")


# Loaded once at import time; generate_image uses this module-level pipeline.
pipe = setup_model()

# Menu/UI labels: every key maps to identical display text.
labels = {
    name: name
    for name in (
        "Generated Image",
        "Prompt",
        "Enhance Image",
        "Advanced Options",
        "Seed",
        "Randomize Seed",
        "Width",
        "Height",
        "Inference Steps",
        "Inspiration Gallery",
    )
}

# Image generation
@spaces.GPU(duration=GPU_DURATION)
def generate_image(prompt, seed=None, width=DEFAULT_WIDTH, height=DEFAULT_HEIGHT,
                   randomize_seed=True, num_inference_steps=DEFAULT_INFERENCE_STEPS):
    """Generate images for *prompt* with the module-level FLUX pipeline.

    This is a generator: it yields once for every image that
    ``pipe.generate_images`` yields.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        seed: Seed for the torch RNG. Integral floats (e.g. ``42.0``) are
            converted to ``int``; any other non-int value is discarded and a
            random seed is drawn instead.
        width, height: Output size in pixels, clamped to
            ``[256, MAX_IMAGE_SIZE]`` and converted to ``int``.
        randomize_seed: When true, or when no usable seed is available, a
            fresh seed is drawn from ``[0, MAX_SEED]``.
        num_inference_steps: Denoising steps forwarded to the pipeline
            (converted to ``int``).

    Yields:
        ``(image, seed, latency_text)`` tuples. Errors are not raised; a single
        ``(None, seed, error_text)`` tuple is yielded instead.
    """
    try:
        # Input validation. Gradio number/slider widgets and API callers can
        # deliver integral floats; keep them instead of silently re-randomizing.
        if isinstance(seed, float) and seed.is_integer():
            seed = int(seed)
        if not isinstance(seed, (int, type(None))):
            seed = None
            randomize_seed = True

        # Korean prompts are sent to the model as-is (no translation step).
        if is_korean(prompt):
            print("경고: 한글 프롬프트는 직접 처리됩니다. 번역기를 사용하지 않습니다.")

        if seed is None or randomize_seed:
            seed = random.randint(0, MAX_SEED)

        # Clamp sizes to the supported range; int() keeps float slider values
        # from reaching the pipeline.
        width = int(min(max(256, width), MAX_IMAGE_SIZE))
        height = int(min(max(256, height), MAX_IMAGE_SIZE))
        num_inference_steps = int(num_inference_steps)

        generator = torch.Generator().manual_seed(seed)

        start_time = time.time()

        # torch.autocast("cuda") replaces the deprecated
        # torch.cuda.amp.autocast(); both default to float16 on CUDA.
        with torch.autocast("cuda"):
            for img in pipe.generate_images(
                prompt=prompt,
                guidance_scale=0,  # guidance disabled
                num_inference_steps=num_inference_steps,
                width=width,
                height=height,
                generator=generator,
            ):
                latency = f"처리 시간: {(time.time()-start_time):.2f} 초"
                yield img, seed, latency

    except Exception as e:
        print(f"이미지 생성 오류: {e}")
        yield None, seed, f"오류: {str(e)}"
    finally:
        # Release cached GPU memory once per request (also on error or early
        # close) instead of flushing the allocator before every yielded frame.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

# Example image generation
def generate_example_image(prompt):
    """Run generate_image for an example prompt and return its final output.

    Returns:
        The last ``(image, seed, latency_text)`` tuple yielded by
        generate_image, or ``(None, None, error_text)`` if generation raises
        or yields nothing.
    """
    try:
        final = None
        # Drain the generator. next() alone would return only the first
        # (possibly intermediate) frame and leave the generator suspended.
        for final in generate_image(prompt, randomize_seed=True):
            pass
        if final is None:
            raise RuntimeError("generate_image produced no output")
        return final
    except Exception as e:
        print(f"예제 생성 오류: {e}")
        return None, None, f"오류: {str(e)}"

# Prompts offered in the "Inspiration Gallery" examples panel.
examples = [
    "A steampunk owl wearing Victorian-era clothing and reading a mechanical book",
    "A floating island made of books with waterfalls of knowledge cascading down",
    "A bioluminescent forest where mushrooms glow like neon signs in a cyberpunk city",
    "An ancient temple being reclaimed by nature, with robots performing archaeology",
    "A cosmic coffee shop where baristas are constellations serving drinks made of stardust",
]


# Page CSS: hide the default Gradio footer.
css = "\nfooter {\n    visibility: hidden;\n}\n"


def create_snow_effect():
    """Return a gr.HTML component containing a decorative falling-snow effect.

    The HTML bundles a CSS ``snowfall`` keyframe animation and a JavaScript
    snippet that appends a "❄" element to ``document.body`` every 200 ms,
    at a random horizontal position with a random 2-5 s fall duration, and
    removes each element after 5 s.

    NOTE(review): browsers do not run ``<script>`` elements that are inserted
    through ``innerHTML``. If gr.HTML renders its value that way (presumably
    it does), only the CSS takes effect and no snowflakes are ever created.
    In that case pass ``snow_js`` via ``gr.Blocks(js=...)`` / ``head=...``
    instead. TODO: confirm against the Gradio version in use.
    """
    # CSS: each .snowflake falls from just above the viewport to the bottom,
    # drifting 100px right and fading toward 0.3 opacity. It is fixed-position,
    # drawn on top (z-index 1000) and click-through (pointer-events: none).
    snow_css = """
    @keyframes snowfall {
        0% {
            transform: translateY(-10vh) translateX(0);
            opacity: 1;
        }
        100% {
            transform: translateY(100vh) translateX(100px);
            opacity: 0.3;
        }
    }
    .snowflake {
        position: fixed;
        color: white;
        font-size: 1.5em;
        user-select: none;
        z-index: 1000;
        pointer-events: none;
        animation: snowfall linear infinite;
    }
    """

    # JavaScript: spawn one snowflake every 200 ms and remove each after 5 s.
    snow_js = """
    function createSnowflake() {
        const snowflake = document.createElement('div');
        snowflake.innerHTML = '❄';
        snowflake.className = 'snowflake';
        snowflake.style.left = Math.random() * 100 + 'vw';
        snowflake.style.animationDuration = Math.random() * 3 + 2 + 's';
        snowflake.style.opacity = Math.random();
        document.body.appendChild(snowflake);
        
        setTimeout(() => {
            snowflake.remove();
        }, 5000);
    }
    setInterval(createSnowflake, 200);
    """

    # Wrap the CSS and JavaScript in <style>/<script> tags.
    snow_html = f"""
    <style>
        {snow_css}
    </style>
    <script>
        {snow_js}
    </script>
    """
    
    return gr.HTML(snow_html)

# Gradio UI layout
with gr.Blocks(theme="soft", css=css) as demo:

    # Header badges (OpenFree hub and Discord links).
    gr.HTML(
        """
        <div class='container' style='display:flex; justify-content:center; gap:12px;'>
            <a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">
                <img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">
            </a>
    
            <a href="https://discord.gg/openfreeai" target="_blank">
                <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
            </a>
        </div>
            """
    )

    # Decorative snowfall overlay.
    create_snow_effect()

    with gr.Column(elem_id="app-container"):
        with gr.Row():
            # Left: the generated image.
            with gr.Column(scale=3):
                result = gr.Image(label=labels["Generated Image"],
                                  show_label=False,
                                  interactive=False)
            # Right: prompt, generate button and generation options.
            with gr.Column(scale=1):
                prompt = gr.Text(
                    label=labels["Prompt"],
                    placeholder="생성하고 싶은 이미지를 설명해주세요...",
                    lines=3,
                    show_label=False,
                    container=False,
                )
                enhanceBtn = gr.Button(f"🚀 {labels['Enhance Image']}")

                # gr.Column has no label parameter (its first parameter is the
                # numeric `scale`), so the title string was misplaced there.
                # An Accordion displays the title; open=True keeps the options
                # visible by default, as the plain column did.
                with gr.Accordion(labels["Advanced Options"], open=True):
                    with gr.Row():
                        latency = gr.Text(show_label=False)
                    with gr.Row():
                        seed = gr.Number(
                            label=labels["Seed"],
                            value=42,
                            precision=0,
                            minimum=0,
                            maximum=MAX_SEED
                        )
                        randomize_seed = gr.Checkbox(
                            label=labels["Randomize Seed"],
                            value=True
                        )
                    with gr.Row():
                        width = gr.Slider(
                            label=labels["Width"],
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=DEFAULT_WIDTH
                        )
                        height = gr.Slider(
                            label=labels["Height"],
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            step=32,
                            value=DEFAULT_HEIGHT
                        )
                        num_inference_steps = gr.Slider(
                            label=labels["Inference Steps"],
                            minimum=1,
                            maximum=4,
                            step=1,
                            value=DEFAULT_INFERENCE_STEPS
                        )

        with gr.Row():
            gr.Markdown(f"### 🌟 {labels['Inspiration Gallery']}")
        with gr.Row():
            # generate_example_image returns (image, seed, latency), so list
            # all three outputs to match its return value.
            gr.Examples(
                examples=examples,
                fn=generate_example_image,
                inputs=[prompt],
                outputs=[result, seed, latency],
                cache_examples=False
            )

    # Event wiring
    def validated_generate(*args):
        """Run generate_image to completion and return its final output tuple.

        ``args`` follow generate_image's positional order
        (prompt, seed, width, height, randomize_seed, num_inference_steps).
        On failure, returns ``(None, args[1], error_text)``, echoing back the
        incoming seed.
        """
        try:
            final = None
            # Drain the generator. next() alone would return only the first
            # (possibly intermediate) frame and leave the generator suspended.
            for final in generate_image(*args):
                pass
            if final is None:
                raise RuntimeError("generate_image produced no output")
            return final
        except Exception as e:
            print(f"검증 생성 오류: {e}")
            return None, args[1], f"오류: {str(e)}"

    # Use the same inputs as the live-update trigger below, so the button
    # honours the "Randomize Seed" checkbox and the inference-step slider.
    enhanceBtn.click(
        fn=generate_image,
        inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
        outputs=[result, seed, latency],
        show_progress="hidden",
        show_api=False,
        queue=False
    )

    # Regenerate live whenever the prompt, size or step count changes.
    gr.on(
        triggers=[prompt.input, width.input, height.input, num_inference_steps.input],
        fn=validated_generate,
        inputs=[prompt, seed, width, height, randomize_seed, num_inference_steps],
        outputs=[result, seed, latency],
        show_progress="hidden",
        show_api=False,
        trigger_mode="always_last",
        queue=False
    )

if __name__ == "__main__":
    # Bind on all interfaces, hide the API page and enable the MCP server.
    launch_options = {
        "show_api": False,
        "share": True,
        "server_name": "0.0.0.0",
        "mcp_server": True,
    }
    demo.launch(**launch_options)