import functools
import os

import cv2
import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoProcessor, CLIPVisionModel

from detection import detect_image, detect_video
from model import LinearClassifier


@functools.lru_cache(maxsize=1)
def _load_clip():
    """Load the CLIP processor and vision backbone shared by both detectors.

    Cached so the ``clip-vit-large-patch14`` weights are read from disk once
    per process instead of on every request.
    """
    processor = AutoProcessor.from_pretrained("clip-vit-large-patch14")
    clip_model = CLIPVisionModel.from_pretrained("clip-vit-large-patch14", output_attentions=True)
    clip_model.eval()
    return processor, clip_model


@functools.lru_cache(maxsize=4)
def load_model(detection_type):
    """Return ``(processor, clip_model, detection_model)`` for *detection_type*.

    *detection_type* selects the linear-head checkpoint at
    ``pretrained_models/{detection_type}/clip_weights.pth``. Callers in this
    file pass ``"facial"`` or ``"general"``. Results are cached per
    ``detection_type``, so repeated requests reuse the already-loaded models.

    NOTE(review): the cached models are shared between requests. This assumes
    ``detect_image`` / ``detect_video`` do not mutate them (e.g. by
    registering hooks); confirm.
    """
    device = torch.device("cpu")

    processor, clip_model = _load_clip()

    model_path = f"pretrained_models/{detection_type}/clip_weights.pth"
    # The checkpoint is fed straight to load_state_dict, so it is a state dict.
    # The classifier's input size is recovered from its "linear.weight" shape.
    # NOTE: torch.load unpickles the file; only load trusted local checkpoints.
    checkpoint = torch.load(model_path, map_location="cpu")
    input_dim = checkpoint["linear.weight"].shape[1]

    detection_model = LinearClassifier(input_dim)
    detection_model.load_state_dict(checkpoint)
    # Inference only: a freshly constructed nn.Module starts in training mode.
    detection_model = detection_model.to(device).eval()

    return processor, clip_model, detection_model

def process_image(image, detection_type):
    """Run the image detector for *detection_type* on a single image.

    Returns a ``(pred_score, attn_map)`` tuple. ``pred_score`` is
    ``1 - results["pred_score"]`` as reported by ``detect_image``, and
    ``attn_map`` is passed through unchanged.
    """
    loaded_models = load_model(detection_type)
    outcome = detect_image(image, *loaded_models)

    # NOTE(review): this path inverts the detector score (1 - score), while
    # process_video returns detect_video's score as-is. Confirm that both match
    # the "0 - REAL, 1 - FAKE" convention shown in the UI.
    inverted_score = 1 - outcome["pred_score"]
    return inverted_score, outcome["attn_map"]

def _read_video_frames(video):
    """Decode every frame of the video at *video* into RGB PIL images.

    The capture handle is released in a ``finally`` block, so it is not leaked
    if decoding or colour conversion raises part-way through.
    NOTE(review): all frames are held in memory at once; long videos may be
    costly.
    """
    cap = cv2.VideoCapture(video)
    frames = []
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            # cv2 frames are converted BGR -> RGB before handing them to PIL.
            frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        cap.release()
    return frames


def process_video(video, detection_type):
    """Run the video detector for *detection_type* on the video at *video*.

    Returns a ``(pred_score, attn_map)`` tuple taken directly from
    ``detect_video``'s result.

    Raises:
        gr.Error: if no frame could be decoded (unreadable or empty video),
            so the user sees a message instead of a failure deep inside
            ``detect_video``.
    """
    frames = _read_video_frames(video)
    if not frames:
        raise gr.Error("Could not read any frames from the uploaded video.")

    processor, clip_model, detection_model = load_model(detection_type)
    results = detect_video(frames, processor, clip_model, detection_model)

    pred_score = results["pred_score"]
    attn_map = results["attn_map"]

    return pred_score, attn_map

def change_input(input_type):
    """Toggle which upload widget is visible for the chosen input type.

    Returns two ``gr.update`` objects, one for the image input and one for the
    video input. Both clear the widget's current value. ``"Image"`` shows only
    the image widget, ``"Video"`` shows only the video widget, and any other
    value hides both.
    """
    visibility_by_type = {
        "Image": (True, False),
        "Video": (False, True),
    }
    show_image, show_video = visibility_by_type.get(input_type, (False, False))
    image_update = gr.update(value=None, visible=show_image)
    video_update = gr.update(value=None, visible=show_video)
    return image_update, video_update

def determine_model_type(image_path):
    """Guess the model type ("Facial" or "General") from an example path.

    The whole path is matched case-insensitively. "facial" is checked before
    "general", and "Facial" is returned when neither keyword appears.
    """
    lowered_path = image_path.lower()
    for keyword, label in (("facial", "Facial"), ("general", "General")):
        if keyword in lowered_path:
            return label
    return "Facial"  # default value


def process_input(input_type, model_type, image, video):
    """Dispatch the UI inputs to the image or video pipeline.

    ``model_type == "Facial"`` selects the ``"facial"`` checkpoint, and anything
    else selects ``"general"``. Returns ``(None, None)`` when the selected input
    type has no uploaded data or the input type is unrecognised.
    """
    if model_type == "Facial":
        detection_type = "facial"
    else:
        detection_type = "general"

    if input_type == "Image":
        if image is not None:
            return process_image(image, detection_type)
    elif input_type == "Video":
        if video is not None:
            return process_video(video, detection_type)

    return None, None


def process_example(image_path):
    """Open an example image and infer its model type from the path.

    Returns ``(PIL image, "Facial" | "General")``.

    NOTE(review): expects a filesystem path string. Gradio may pass the
    preprocessed input component value (a PIL image for ``type="pil"``)
    instead; confirm how this is invoked.
    """
    inferred_type = determine_model_type(image_path)
    opened_image = Image.open(image_path)
    return opened_image, inferred_type

def _list_examples(directory):
    """Return the sorted paths of the visible regular files in *directory*.

    ``os.listdir`` returns entries in arbitrary order, so the result is sorted
    to keep the example gallery stable between launches. Hidden entries (e.g.
    ``.DS_Store``, ``.gitkeep``) and sub-directories are skipped because they
    are not example images.

    Raises:
        FileNotFoundError: if *directory* does not exist.
    """
    return [
        os.path.join(directory, name)
        for name in sorted(os.listdir(directory))
        if not name.startswith(".")
        and os.path.isfile(os.path.join(directory, name))
    ]


# Paths of the bundled example images shown below the demo.
fake_examples = _list_examples("examples/fake")
real_examples = _list_examples("examples/real")

# Build the Gradio UI. Components are laid out in the order they are created
# inside this context, so statement order here defines the page layout.
with gr.Blocks() as demo:

    gr.Markdown("## Deepfake Detection : Facial / General")

    # Chooses which upload widget is visible (wired to change_input below).
    input_type = gr.Radio(["Image", "Video"], label="Choose Input Type", value="Image")

    # Chooses the linear-head checkpoint: "Facial" -> "facial", else "general"
    # (mapping done in process_input).
    model_type = gr.Radio(["Facial", "General"], label="Choose Model Type", value="General")

    # Display height/width shared by the upload widgets and the attention map.
    H, W = 300, 300
    image_input = gr.Image(type="pil", label="Upload Image", visible=True, height=H, width=W)
    video_input = gr.Video(label="Upload Video", visible=False, height=H, width=W)

    process_button = gr.Button("Run Model")

    pred_score_output = gr.Textbox(label="Prediction Score : 0 - REAL, 1 - FAKE")
    attn_map_output = gr.Image(type="pil", label="Attention Map", height=H, width=W)

    # Add example images (one gallery for fake examples, one for real).
    # NOTE(review): with cache_examples=False, Gradio only calls `fn` on click
    # when run_on_click=True (default False). So process_example, and its
    # model_type auto-selection, likely never runs. If enabled, `fn` would
    # receive the preprocessed image_input value (a PIL image) rather than a
    # path; confirm before relying on it.
    gr.Examples(
        examples=fake_examples,
        inputs=[image_input],
        outputs=[image_input, model_type],
        fn=process_example,
        cache_examples=False,
        examples_per_page=10,
        label="Fake Examples"
    )
    gr.Examples(
        examples=real_examples,
        inputs=[image_input],
        outputs=[image_input, model_type],
        fn=process_example,
        cache_examples=False,
        examples_per_page=10,
        label="Real Examples"
    )

    # Switching the input type toggles the visibility of the upload widgets
    # and clears both.
    input_type.change(fn=change_input, inputs=[input_type], outputs=[image_input, video_input])

    # Run detection; process_input returns (pred_score, attn_map), or
    # (None, None) when the selected input is empty.
    process_button.click(
        fn=process_input, 
        inputs=[input_type, model_type, image_input, video_input], 
        outputs=[pred_score_output, attn_map_output]
    )

# Start the Gradio server only when run as a script, not on import.
if __name__ == "__main__":
    demo.launch()