import ast
import os
import time
import traceback
from typing import Optional

from config_store import (
    get_process_config,
    get_inference_config,
    get_openvino_config,
    get_pytorch_config,
)

import gradio as gr
from huggingface_hub import whoami, create_repo
from huggingface_hub.errors import GatedRepoError
from gradio_huggingfacehub_search import HuggingfaceHubSearch
from optimum_benchmark.launchers.device_isolation_utils import *  # noqa
from optimum_benchmark.backends.openvino.utils import (
    TASKS_TO_OVMODELS,
    TASKS_TO_OVPIPELINES,
)
from optimum_benchmark.backends.transformers_utils import (
    TASKS_TO_AUTO_MODEL_CLASS_NAMES,
)
from optimum_benchmark.backends.diffusers_utils import (
    TASKS_TO_AUTO_PIPELINE_CLASS_NAMES,
)
from optimum_benchmark import (
    Benchmark,
    BenchmarkConfig,
    InferenceConfig,
    ProcessConfig,
    PyTorchConfig,
    OVConfig,
)
from optimum_benchmark.logging_utils import setup_logging
from optimum_benchmark.task_utils import infer_task_from_model_name_or_path


# Fixed settings of this Space: every benchmark runs on CPU, with the
# "process" launcher and the "inference" scenario. LAUNCHER and SCENARIO
# double as keys into the dict returned by parse_configs.
DEVICE = "cpu"
LAUNCHER = "process"
SCENARIO = "inference"

# Backends benchmarked side by side; run_benchmark yields one report per
# backend in this order.
BACKENDS = ["pytorch", "openvino"]

# Tasks offered in the UI: those known on the OpenVINO side (model or pipeline
# classes) that also have a transformers/diffusers auto class.
TASKS = (set(TASKS_TO_OVMODELS) | set(TASKS_TO_OVPIPELINES)) & (
    set(TASKS_TO_AUTO_MODEL_CLASS_NAMES) | set(TASKS_TO_AUTO_PIPELINE_CLASS_NAMES)
)


# UI fields whose text value holds a Python literal (e.g. a dict of kwargs)
# that must be parsed before being handed to the config classes.
_LITERAL_ARGUMENTS = frozenset(
    {
        "input_shapes",
        "reshape_kwargs",
        "generate_kwargs",
        "numactl_kwargs",
        "call_kwargs",
    }
)


def _parse_literal(value):
    """Parse a UI text field holding a Python literal; pass other values through.

    ``ast.literal_eval`` is used instead of ``eval`` because the text comes
    from users of the web UI: it only evaluates literals (dicts, lists,
    numbers, strings, booleans, None) and never executes code.

    Raises:
        ValueError, SyntaxError: if ``value`` is a string that is not a
            valid Python literal.
    """
    if isinstance(value, str):
        return ast.literal_eval(value)
    return value


def parse_configs(inputs):
    """Build the optimum-benchmark config objects from the values of the UI.

    Args:
        inputs: Mapping from Gradio components to their current values. Each
            component is identified by its ``label``: ``"model"``, ``"task"``
            and ``"openvino_model"`` are read directly; labels of the form
            ``"<section>.<argument>"`` are routed to the matching section
            (``process``, ``inference``, ``pytorch`` or ``openvino``); any
            other label is ignored.

    Returns:
        dict mapping ``"process"``, ``"inference"``, ``"pytorch"`` and
        ``"openvino"`` to ``ProcessConfig``, ``InferenceConfig``,
        ``PyTorchConfig`` and ``OVConfig`` instances respectively.

    Raises:
        KeyError: if a dotted label names an unknown section.
        ValueError, SyntaxError: if a literal field (see
            ``_LITERAL_ARGUMENTS``) does not hold a valid Python literal.
    """
    sections = {"process": {}, "inference": {}, "pytorch": {}, "openvino": {}}
    # Initialized so a missing component yields None instead of UnboundLocalError.
    model = task = openvino_model = None

    for component, value in inputs.items():
        label = component.label
        if label == "model":
            model = value
        elif label == "task":
            task = value
        elif label == "openvino_model":
            openvino_model = value
        elif "." in label:
            section, _, argument = label.partition(".")
            sections[section][argument] = value
        # any other label is not a benchmark setting and is ignored

    for arguments in sections.values():
        for name in _LITERAL_ARGUMENTS.intersection(arguments):
            arguments[name] = _parse_literal(arguments[name])

    # Blank backend fields are dropped so the config classes use their own
    # defaults. NOTE(review): this truthiness test also drops explicit False/0
    # values — confirm that matches the UI defaults before relying on it.
    pytorch_kwargs = {k: v for k, v in sections["pytorch"].items() if v}
    openvino_kwargs = {k: v for k, v in sections["openvino"].items() if v}

    return {
        "process": ProcessConfig(**sections["process"]),
        "inference": InferenceConfig(**sections["inference"]),
        "pytorch": PyTorchConfig(
            task=task,
            model=model,
            device=DEVICE,
            **pytorch_kwargs,
        ),
        # Fall back to the PyTorch model when no OpenVINO model was selected.
        "openvino": OVConfig(
            task=task,
            model=openvino_model or model,
            device=DEVICE,
            **openvino_kwargs,
        ),
    }


def run_benchmark(inputs, oauth_token: Optional[gr.OAuthToken]):
    """Benchmark the selected model on every backend and stream the reports.

    This is a generator used as a Gradio event handler. Each ``yield`` is a
    tuple with one Markdown string per backend, in ``BACKENDS`` order. Results
    (config and report) are pushed to the ``<user>/benchmarks`` dataset repo,
    under a timestamped folder with one subfolder per backend.

    Args:
        inputs: Mapping from Gradio components to their values (see
            ``parse_configs``).
        oauth_token: OAuth token of the logged-in user, or ``None`` when the
            user is not logged in.

    Raises:
        gr.Error: if the user is not logged in or the configuration entered
            in the UI cannot be parsed.
    """
    if oauth_token is None:
        raise gr.Error("Please login to be able to run the benchmark.")

    token = oauth_token.token
    timestamp = time.strftime("%Y-%m-%d-%H:%M:%S")
    user_name = whoami(token)["name"]
    repo_id = f"{user_name}/benchmarks"
    folder = timestamp

    try:
        create_repo(repo_id=repo_id, repo_type="dataset", token=token, exist_ok=True)
    except Exception as error:
        # Best effort: the benchmarks still run and their reports are still
        # shown; only pushing the results to the Hub will fail later.
        gr.Info(
            f"❌ Error while creating the repo {repo_id} where benchmarks are to be saved: {error}"
        )
    else:
        gr.Info(f"📩 Benchmarks will be saved under {repo_id} in the folder {folder}")

    try:
        configs = parse_configs(inputs)
    except Exception as error:
        # Surface malformed UI values (e.g. an invalid kwargs literal) as a
        # readable message instead of an opaque crash of the event handler.
        raise gr.Error(f"❌ Invalid benchmark configuration: {error}") from error

    outputs = {backend: "Running..." for backend in BACKENDS}

    def snapshot():
        # One Markdown string per report component, in BACKENDS order.
        return tuple(outputs[backend] for backend in BACKENDS)

    yield snapshot()

    for backend in BACKENDS:
        benchmark_name = f"{folder}/{backend}"
        try:
            benchmark_config = BenchmarkConfig(
                name=benchmark_name,
                backend=configs[backend],
                launcher=configs[LAUNCHER],
                scenario=configs[SCENARIO],
            )
            benchmark_report = Benchmark.launch(benchmark_config)
            report_text = f"\n{benchmark_report.to_markdown_text()}\n"
        except GatedRepoError:
            outputs[backend] = f"🔒 Model {configs[backend].model} is gated."
            yield snapshot()
            gr.Info("🔒 Gated Repo Error while trying to access the model.")
            continue
        except Exception:
            outputs[backend] = f"\n```python-traceback\n{traceback.format_exc()}```\n"
            yield snapshot()
            gr.Info(f"❌ Error while running benchmark for {backend} backend.")
            continue

        # Pushing is handled separately so that a Hub failure does not hide a
        # report that was computed successfully.
        try:
            for artifact in (benchmark_config, benchmark_report):
                artifact.push_to_hub(
                    repo_id=repo_id,
                    subfolder=benchmark_name,
                    token=token,
                )
        except Exception:
            outputs[backend] = (
                f"{report_text}\n❌ Failed to push the results to {repo_id}:\n"
                f"```python-traceback\n{traceback.format_exc()}```\n"
            )
            yield snapshot()
            gr.Info(f"❌ Error while pushing results for {backend} backend to the Hub.")
        else:
            outputs[backend] = report_text
            yield snapshot()
            gr.Info(f"✅ Benchmark for {backend} backend ran successfully.")


def update_task(model_id):
    """Infer the task of ``model_id`` so the task dropdown can be pre-filled.

    Returns:
        The inferred task name, when it belongs to ``TASKS``.

    Raises:
        gr.Error: if the model is gated, if task inference fails for any other
            reason, or if the inferred task is not in ``TASKS``.
    """
    try:
        task_name = infer_task_from_model_name_or_path(model_id)
    except GatedRepoError:
        message = (
            f"Model {model_id} is gated, please use optimum-benchmark locally to benchmark it."
        )
        raise gr.Error(message)
    except Exception:
        message = (
            f"Error while inferring task for {model_id}, please select a task manually."
        )
        raise gr.Error(message)
    else:
        if task_name in TASKS:
            return task_name
        message = (
            f"Task {task_name} is not supported by OpenVINO, please select a task manually."
        )
        raise gr.Error(message)


with gr.Blocks() as demo:
    # Login button: run_benchmark needs the user's OAuth token to run and to
    # push the results under the user's namespace.
    gr.LoginButton()

    # Header: logo, title and short description of the Space.
    gr.HTML(
        """<img src="https://huggingface.co/spaces/optimum/optimum-benchmark-ui/resolve/main/huggy_bench.png" style="display: block; margin-left: auto; margin-right: auto; width: 30%;">"""
        "<h1 style='text-align: center'>🤗 Optimum-Benchmark Interface 🏋️</h1>"
        "<p style='text-align: center'>"
        "This Space uses <a href='https://github.com/huggingface/optimum-benchmark.git'>Optimum-Benchmark</a> to automatically benchmark a model from the Hub on different backends."
        "<br>The results (config and report) will be pushed under your namespace in a benchmark repository on the Hub."
        "</p>"
    )

    # Component labels ("model", "openvino_model", "task", "<section>.<arg>")
    # are how parse_configs identifies each value, so they must not change.
    with gr.Column(variant="panel"):
        # NOTE(review): `sumbit_on_select` is kept as spelled; it presumably
        # matches the parameter name of gradio_huggingfacehub_search — verify
        # against that package before "fixing" the spelling.
        model = HuggingfaceHubSearch(
            placeholder="Search for a PyTorch model",
            sumbit_on_select=True,
            search_type="model",
            label="model",
        )

        openvino_model = HuggingfaceHubSearch(
            placeholder="Search for an OpenVINO model (optional)",
            label="openvino_model",
            sumbit_on_select=True,
            search_type="model",
        )

        with gr.Row():
            task = gr.Dropdown(
                info="Task to run the benchmark on.",
                elem_id="task-dropdown",
                # TASKS is a set of strings whose iteration order varies
                # between runs (hash randomization); sort for a stable list.
                choices=sorted(TASKS),
                label="task",
            )

    with gr.Column(variant="panel"):
        with gr.Accordion(label="Process Config", open=False, visible=True):
            process_config = get_process_config()
        with gr.Accordion(label="Inference Config", open=False, visible=True):
            inference_config = get_inference_config()

    with gr.Row() as backend_configs:
        with gr.Accordion(label="PyTorch Config", open=False, visible=True):
            pytorch_config = get_pytorch_config()
        with gr.Accordion(label="OpenVINO Config", open=False, visible=True):
            openvino_config = get_openvino_config()

    with gr.Row():
        button = gr.Button(value="Run Benchmark", variant="primary")

    with gr.Row():
        with gr.Accordion(label="PyTorch Report", open=True, visible=True):
            pytorch_report = gr.Markdown()
        with gr.Accordion(label="OpenVINO Report", open=True, visible=True):
            openvino_report = gr.Markdown()

    # Pre-fill the task dropdown when a model is submitted.
    model.submit(inputs=model, outputs=task, fn=update_task)

    button.click(
        fn=run_benchmark,
        # A *set* of inputs makes Gradio pass a single {component: value} dict,
        # which parse_configs relies on (it reads each component's label).
        inputs={
            task,
            model,
            openvino_model,
            *process_config.values(),
            *inference_config.values(),
            *pytorch_config.values(),
            *openvino_config.values(),
        },
        # An explicit list, in the same order (BACKENDS) as the tuples
        # run_benchmark yields.
        outputs=[
            pytorch_report,
            openvino_report,
        ],
        concurrency_limit=1,
    )


if __name__ == "__main__":
    # Logging settings for the main process; the LOG_TO_FILE / LOG_LEVEL
    # variables are presumably read by optimum-benchmark — verify there.
    os.environ.update({"LOG_TO_FILE": "0", "LOG_LEVEL": "INFO"})
    setup_logging(level="INFO", prefix="MAIN-PROCESS")

    # Enable the request queue (at most 10 pending requests) and serve the UI.
    app = demo.queue(max_size=10)
    app.launch()