import argparse
import glob
import json
import logging
import multiprocessing as mp
import os
import time
import uuid
from datetime import timedelta
from functools import lru_cache
from typing import List, Tuple

import boto3
from botocore.exceptions import BotoCoreError, ClientError
import gradio as gr
import requests
from huggingface_hub import HfApi
from optimum.onnxruntime import ORTModelForSequenceClassification
from transformers import AutoTokenizer, Pipeline, pipeline

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
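
# Required configuration, read from the environment:
#   HF_TOKEN                      - Hugging Face token (model downloads, dataset uploads)
#   LAKERA_API_KEY                - Lakera Guard API key
#   AZURE_CONTENT_SAFETY_ENDPOINT - Azure Content Safety endpoint URL (must end with "/")
#   AZURE_CONTENT_SAFETY_KEY      - Azure Content Safety subscription key
# AWS credentials are resolved through boto3's default credential chain.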

hf_token = os.getenv("HF_TOKEN")
hf_api = HfApi(token=hf_token)
num_processes = 2  # kept low for the hosted Space; use mp.cpu_count() to match available cores

lakera_api_key = os.getenv("LAKERA_API_KEY")
azure_content_safety_endpoint = os.getenv("AZURE_CONTENT_SAFETY_ENDPOINT")
azure_content_safety_key = os.getenv("AZURE_CONTENT_SAFETY_KEY")
bedrock_runtime_client = boto3.client('bedrock-runtime', region_name="us-east-1")


@lru_cache(maxsize=2)
def init_prompt_injection_model(prompt_injection_ort_model: str, subfolder: str = "") -> Pipeline:
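    """Load an ONNX sequence-classification model from the Hugging Face Hub
    and wrap it in a CPU text-classification pipeline.

    Cached with lru_cache so each model is loaded only once per process.
    """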
    hf_model = ORTModelForSequenceClassification.from_pretrained(
        prompt_injection_ort_model,
        export=False,
        subfolder=subfolder,
        file_name="model.onnx",
        token=hf_token,
    )
    hf_tokenizer = AutoTokenizer.from_pretrained(
        prompt_injection_ort_model, subfolder=subfolder, token=hf_token
    )
    hf_tokenizer.model_input_names = ["input_ids", "attention_mask"]

    logger.info(f"Initialized classification ONNX model {prompt_injection_ort_model} on CPU")

    return pipeline(
        "text-classification",
        model=hf_model,
        tokenizer=hf_tokenizer,
        device="cpu",
        batch_size=1,
        truncation=True,
        max_length=512,
    )


def convert_elapsed_time(diff_time: float) -> float:
    return round(diff_time, 2)


deepset_classifier = init_prompt_injection_model(
    "protectai/deberta-v3-base-injection-onnx"
)  # ONNX version of deepset/deberta-v3-base-injection
protectai_v2_classifier = init_prompt_injection_model(
    "protectai/deberta-v3-base-prompt-injection-v2", "onnx"
)


def detect_hf(
    prompt: str,
    threshold: float = 0.5,
    classifier=protectai_v2_classifier,
    label: str = "INJECTION",
) -> Tuple[bool, bool]:
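    """Classify the prompt with a local HF pipeline.

    Returns (request_succeeded, is_injection); the prompt is flagged when the
    score for `label` exceeds `threshold`.
    """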
    try:
        pi_result = classifier(prompt)
        injection_score = round(
            pi_result[0]["score"] if pi_result[0]["label"] == label else 1 - pi_result[0]["score"],
            2,
        )

        logger.info(f"Prompt injection result from the HF model: {pi_result}")

        return True, injection_score > threshold
    except Exception as err:
        logger.error(f"Failed to call HF model: {err}")
        return False, False


def detect_hf_protectai_v2(prompt: str) -> Tuple[bool, bool]:
    return detect_hf(prompt, classifier=protectai_v2_classifier)


def detect_hf_deepset(prompt: str) -> Tuple[bool, bool]:
    return detect_hf(prompt, classifier=deepset_classifier)


def detect_lakera(prompt: str) -> Tuple[bool, bool]:
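    """Check the prompt with the Lakera Guard prompt injection API."""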
    try:
        response = requests.post(
            "https://api.lakera.ai/v1/prompt_injection",
            json={"input": prompt},
            headers={"Authorization": f"Bearer {lakera_api_key}"},
        )
        response_json = response.json()
        logger.info(f"Prompt injection result from Lakera: {response_json}")

        return True, response_json["results"][0]["flagged"]
    except requests.RequestException as err:
        logger.error(f"Failed to call Lakera API: {err}")
        return False, False


def detect_azure(prompt: str) -> Tuple[bool, bool]:
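    """Check the prompt with the Azure AI Content Safety Prompt Shields API."""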
    try:
        response = requests.post(
            f"{azure_content_safety_endpoint}contentsafety/text:shieldPrompt?api-version=2024-02-15-preview",
            json={"userPrompt": prompt},
            headers={"Ocp-Apim-Subscription-Key": azure_content_safety_key},
        )
        response_json = response.json()
        logger.info(f"Prompt injection result from Azure: {response_json}")

        if "userPromptAnalysis" not in response_json:
            return False, False

        return True, response_json["userPromptAnalysis"]["attackDetected"]
    except requests.RequestException as err:
        logger.error(f"Failed to call Azure API: {err}")
        return False, False


def detect_aws_bedrock(prompt: str) -> Tuple[bool, bool]:
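    """Apply a fixed AWS Bedrock guardrail (version "1") to the prompt."""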
    try:
        response = bedrock_runtime_client.apply_guardrail(
            guardrailIdentifier="tx8t6psx14ho",
            guardrailVersion="1",
            source="INPUT",
            content=[{"text": {"text": prompt}}],
        )
    except (BotoCoreError, ClientError) as err:
        # boto3 raises on transport and non-2xx errors rather than returning them
        logger.error(f"Failed to call AWS Bedrock Guardrails API: {err}")
        return False, False

    logger.info(f"Prompt injection result from AWS Bedrock Guardrails: {response}")

    return True, response["action"] != "NONE"


detection_providers = {
    "ProtectAI v2 (HF model)": detect_hf_protectai_v2,
    "Deepset (HF model)": detect_hf_deepset,
    "Lakera Guard": detect_lakera,
    "Azure Content Safety": detect_azure,
    "AWS Bedrock Guardrails": detect_aws_bedrock,
}


def is_detected(provider: str, prompt: str) -> Tuple[str, bool, bool, float]:
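    """Time a single provider call; returns (provider, request_succeeded, is_injection, latency_seconds)."""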
    if provider not in detection_providers:
        logger.warning(f"Provider {provider} is not supported")
        return provider, False, False, 0.0

    start_time = time.monotonic()
    request_result, is_injection = detection_providers[provider](prompt)
    end_time = time.monotonic()

    return provider, request_result, is_injection, convert_elapsed_time(end_time - start_time)


def execute(prompt: str) -> List[Tuple[str, bool, bool, float]]:
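    """Fan the prompt out to every provider in a process pool, then upload the
    prompt and per-provider results to the private benchmark dataset on the Hub.
    """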
    results = []

    with mp.Pool(processes=num_processes) as pool:
        for result in pool.starmap(
            is_detected, [(provider, prompt) for provider in detection_providers.keys()]
        ):
            results.append(result)

    # Save the prompt and per-provider results for later analysis
    fileobj = json.dumps(
        {"prompt": prompt, "results": results}, indent=2, ensure_ascii=False
    ).encode("utf-8")
    result_path = f"prompts/train/{uuid.uuid4()}.json"

    hf_api.upload_file(
        path_or_fileobj=fileobj,
        path_in_repo=result_path,
        repo_id="protectai/prompt-injection-benchmark",
        repo_type="dataset",
    )
    logger.info(f"Stored prompt: {prompt}")

    return results
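

# Illustrative example (hypothetical values): execute("Ignore previous instructions ...")
# returns one row per provider, e.g. [("ProtectAI v2 (HF model)", True, True, 0.12), ...].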


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--url", type=str, default="0.0.0.0")
    args, _ = parser.parse_known_args()

    example_files = glob.glob(os.path.join(os.path.dirname(__file__), "examples", "*.txt"))
    examples = []
    for example_file in example_files:
        with open(example_file, encoding="utf-8") as f:  # ensure file handles are closed
            examples.append(f.read())

    gr.Interface(
        fn=execute,
        inputs=[
            gr.Textbox(label="Prompt"),
        ],
        outputs=[
            gr.Dataframe(
                headers=[
                    "Provider",
                    "Is processed successfully?",
                    "Is prompt injection?",
                    "Latency (seconds)",
                ],
                datatype=["str", "bool", "bool", "number"],
                label="Results",
            ),
        ],
        title="Prompt Injection Solutions Benchmark",
        description="This interface benchmarks well-known prompt injection detection providers. "
        "The results are <strong>stored in a private dataset</strong> for further analysis and improvement. This interface is for research purposes only."
        "<br /><br />"
        "HuggingFace (HF) models are hosted on Spaces while other providers are called as APIs.<br /><br />"
        '<a href="https://join.slack.com/t/laiyerai/shared_invite/zt-28jv3ci39-sVxXrLs3rQdaN3mIl9IT~w">Join our Slack community to discuss LLM Security</a><br />'
        '<a href="https://github.com/protectai/llm-guard">Secure your LLM interactions with LLM Guard</a>',
        examples=[[example] for example in examples],
        cache_examples=True,
        allow_flagging="never",
        concurrency_limit=1,
    ).launch(server_name=args.url, server_port=args.port)