import gradio as gr
import os
import tempfile
from PIL import Image
import numpy as np

# Page-wide CSS handed to gr.Blocks(css=css) when the UI is built.
# NOTE: .image-container intentionally sets no max-height. `auto` is not a
# valid max-height value (valid ones are `none`, a length, a percentage, ...),
# so such a declaration would be discarded by browsers and have no effect.
css = """
p {
    font-size: 120%;
}

li {
    font-size: 110%;
}

html, body {
    overflow: scroll;
}

video {
    max-height: 400px;
}

.container {
    height: initial;
}


.image-container {
    width: 200px; 
    margin: auto;
}

img {
    max-height: 400px;
}
"""

# Optional CSS for the header example image (the gr.Image with elem_id='example'); add these rules to `css` to enable them:
#example {
#    width: 80%;
#    height: 60%
#}

#example img {
#    width: 80%;
#    height: 80%
#}


# Bundled demo assets, each joined onto this script's directory so the paths
# do not depend on the current working directory.
# a, b: example input videos; c, d: example input images.
a, b, c, d = (
    os.path.join(os.path.dirname(__file__), rel_path)
    for rel_path in (
        "files/barkley_balloon.mp4",
        "files/eiffel_tower.mp4",
        "files/bird.bmp",
        "files/groot.jpeg",
    )
)
# w1..w7: further images, used as watermarks and (w4) as an example base image.
w1, w2, w3, w4, w5, w6, w7 = (
    os.path.join(os.path.dirname(__file__), rel_path)
    for rel_path in (
        "files/AI_generated.png",
        "files/hf-logo.png",
        "files/forest_qr_watermarking.png",
        "files/cheetah1.jpg",
        "files/frog.jpg",
        "files/Human_generated.png",
        "files/hf-logo_transpng.png",
    )
)


def process_watermark(watermark_path, opacity, size, for_video=False):
    """Scale a watermark image and apply an opacity factor to its alpha channel.

    Args:
        watermark_path: Path to an image file (str), a numpy array, or an
            already-loaded PIL image. ``None`` means "no watermark".
        opacity: Factor multiplied into every alpha value (the UI sliders
            offer 0.0-1.0; 1.0 leaves alpha unchanged).
        size: Factor applied to both width and height.
        for_video: If True, save the result as a temporary PNG and return its
            path; otherwise return the image object.

    Returns:
        ``None`` when ``watermark_path`` is ``None``; otherwise an RGBA
        ``PIL.Image.Image``, or, when ``for_video`` is True, the path of a PNG
        file that is left on disk (this function never deletes it).
    """
    if watermark_path is None:
        return None

    if isinstance(watermark_path, str):
        # Open inside a context manager so the file handle is released;
        # convert() loads the pixels and returns an independent copy, so the
        # result remains valid after the file is closed.
        with Image.open(watermark_path) as source:
            watermark = source.convert('RGBA')
    else:
        watermark = Image.fromarray(watermark_path) if isinstance(watermark_path, np.ndarray) else watermark_path
        if watermark.mode != 'RGBA':
            watermark = watermark.convert('RGBA')

    # Scale both dimensions. Clamp to 1px so a tiny watermark combined with a
    # small size factor never truncates to a zero dimension, which Pillow's
    # resize rejects.
    width, height = watermark.size
    new_size = (max(1, int(width * size)), max(1, int(height * size)))
    watermark = watermark.resize(new_size, Image.Resampling.LANCZOS)

    # Apply opacity by scaling only the alpha band; RGB is left untouched.
    red, green, blue, alpha = watermark.split()
    alpha = alpha.point(lambda value: int(value * opacity))
    watermark = Image.merge('RGBA', (red, green, blue, alpha))

    if not for_video:
        return watermark

    # Videos take the watermark as a file path. delete=False keeps the PNG on
    # disk after the handle closes; the with-block guarantees the handle is
    # closed even if save() raises.
    with tempfile.NamedTemporaryFile(delete=False, suffix='.png') as temp_file:
        watermark.save(temp_file, 'PNG')
    return temp_file.name
        
def generate_image(original_image, watermark, opacity, size):
    """Produce the output for the image tab: a gr.Image with a watermark attached.

    Args:
        original_image: Base image from the input component, or ``None`` when
            nothing was uploaded. A numpy array is converted to a PIL image;
            any other value is passed through unchanged.
        watermark: Watermark source, forwarded to ``process_watermark``.
        opacity: Alpha factor, forwarded to ``process_watermark``.
        size: Scale factor, forwarded to ``process_watermark``.

    Returns:
        ``gr.Image(base, watermark=...)``, or ``None`` if no base image was given.
    """
    if original_image is not None:
        prepared = process_watermark(watermark, opacity, size, for_video=False)
        base = Image.fromarray(original_image) if isinstance(original_image, np.ndarray) else original_image
        return gr.Image(base, watermark=prepared)
    return None

def generate_video(original_video, watermark, opacity, size):
    """Produce the output for the video tab: a gr.Video with a watermark attached.

    Args:
        original_video: Base video from the input component, or ``None``.
        watermark: Watermark source, forwarded to ``process_watermark``.
        opacity: Alpha factor, forwarded to ``process_watermark``.
        size: Scale factor, forwarded to ``process_watermark``.

    Returns:
        ``gr.Video(original_video, watermark=<png path>)``, or ``None`` if no
        video was given.
    """
    if original_video is not None:
        # for_video=True makes process_watermark return a PNG file path.
        watermark_png = process_watermark(watermark, opacity, size, for_video=True)
        return gr.Video(original_video, watermark=watermark_png)
    return None



# Header example image, resolved relative to this script like the other asset
# paths so the demo also works when launched from a different directory.
header_example = os.path.join(os.path.dirname(__file__), "files/watermark_example.png")

with gr.Blocks(css=css) as demo:
    # Header: explanatory text on the left, example image and code on the right.
    with gr.Row():
        with gr.Column():
            gr.Markdown("# 🤗 Watermarking with Gradio: Example")
            gr.Markdown("Watermarks can be **visible** or **invisible**.")
            gr.Markdown("""They can provide information directly, or provide a link for more information.  
                           - Visible watermarks are useful to disclose when content is AI-generated.  
                           - Invisible watermarks can mark content as authentic.
                           - ...And vice versa! There are many possibilities for what watermarks can provide.
                           - Watermarks can also provide information about content created by people.""")
            gr.Markdown("They are a useful tool for **AI provenance**.")
            gr.Markdown("**Expert level:** One particularly useful form of watermarking provides a link with more information, such as with a QR code, [which you can further customize to suit the style of your imagery](https://huggingface.co/spaces/huggingface-projects/QR-code-AI-art-generator).")
            gr.Markdown("""For more information on watermarking -- what watermarking is, why it's important, and the tools available on Hugging Face -- please check out [our blogpost on AI watermarking](https://huggingface.co/blog/watermarking).""")
            gr.Markdown()
            gr.Markdown("## Try it out below!")
        with gr.Column():
            # NOTE(review): the hidden images in the first and last columns
            # appear to serve only as layout padding around the visible one.
            with gr.Column():
                gr.Image(header_example, visible=False)
            with gr.Column():
                gr.Image(header_example, show_label=False, show_download_button=False, elem_id='example', container=False, interactive=False)
                gr.Markdown('**Image Watermark Code:**')
                gr.Code('import gradio as gr\n\nwatermarked_image = gr.Image(original_image_file, watermark=watermark_file)', lines=3)
                gr.Markdown('**Video Watermark Code:**')
                gr.Code('import gradio as gr\n\nwatermarked_video = gr.Video(original_video_file, watermark=watermark_file)', lines=3)
            with gr.Column():
                gr.Image(header_example, visible=False)

    with gr.Tab("Image Watermarking"):
        with gr.Column():
            gr.Markdown("**Inputs**: Image and watermark file")
        with gr.Column():
            gr.Markdown("**Output**: Watermarked image")

        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Original Image")
                # type='filepath' hands process_watermark a path string;
                # image_mode=None asks Gradio not to force a colour mode.
                watermark_image = gr.Image(type='filepath', image_mode=None, label="Watermark Image")
                with gr.Accordion("Watermark settings", open=False):
                    opacity_slider = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.1, label="Watermark Opacity")
                    size_slider = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Watermark Size")
                generate_btn = gr.Button("Generate Watermarked Image")
            with gr.Column():
                output_image = gr.Image(label="Watermarked Image")

        # Each example row: [base image, watermark, opacity, size].
        gr.Examples(
            examples=[
                [d, w7, 1.0, 1.0],
                [w4, w5, 0.7, 0.8],
                [c, w6, 0.9, 1.2]
            ],
            inputs=[input_image, watermark_image, opacity_slider, size_slider],
            outputs=output_image,
            fn=generate_image,
            cache_examples=False
        )

        generate_btn.click(
            fn=generate_image,
            inputs=[input_image, watermark_image, opacity_slider, size_slider],
            outputs=output_image
        )

    with gr.Tab("Video Watermarking"):
        with gr.Column():
            gr.Markdown("**Inputs**: Video and watermark file")
        with gr.Column():
            gr.Markdown("**Output**: Watermarked video")

        with gr.Row():
            with gr.Column():
                input_video = gr.Video(label="Original Video")
                watermark_video = gr.Image(type='filepath', image_mode=None, label="Watermark Image")
                with gr.Accordion("Watermark settings", open=False):
                    opacity_slider_video = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.1, label="Watermark Opacity")
                    size_slider_video = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Watermark Size")
                generate_btn_video = gr.Button("Generate Watermarked Video")
            with gr.Column():
                output_video = gr.Video(label="Watermarked Video")

        # Each example row: [base video, watermark, opacity, size].
        gr.Examples(
            examples=[
                [a, w1, 1.0, 1.0],
                [b, w2, 0.8, 0.9],
                [a, w3, 0.6, 1.5],
                [b, w4, 0.7, 0.8]
            ],
            inputs=[input_video, watermark_video, opacity_slider_video, size_slider_video],
            outputs=output_video,
            fn=generate_video
        )

        generate_btn_video.click(
            fn=generate_video,
            inputs=[input_video, watermark_video, opacity_slider_video, size_slider_video],
            outputs=output_video
        )

if __name__ == "__main__":
    demo.launch()