import cv2
import numpy as np
from onnxruntime import InferenceSession

import box_utils_numpy
from auto_rotate import align_face


def softmax(x):
    """Return the softmax of *x*, normalized along axis 0.

    The global maximum of *x* is subtracted before exponentiating so that
    large scores do not overflow; the result is unchanged for 1-D input.
    """
    shifted = x - np.max(x)
    exps = np.exp(shifted)
    total = exps.sum(axis=0)
    return exps / total


def crop_square(img, size, interpolation=cv2.INTER_AREA):
    """Cut the largest centred square out of *img* and resize it.

    Args:
        img: Image array whose first two axes are (height, width).
        size: Edge length, in pixels, of the returned square image.
        interpolation: OpenCV interpolation flag forwarded to ``cv2.resize``.

    Returns:
        The centre crop resized to ``(size, size)``.
    """
    height, width = img.shape[:2]
    side = min(height, width)

    # Offsets of the square's top-left corner; the longer axis loses the
    # same amount (rounded down) on its leading edge.
    top = (height - side) // 2
    left = (width - side) // 2

    square = img[top : top + side, left : left + side]
    return cv2.resize(square, (size, size), interpolation=interpolation)


class SpoofyNet:
    """Face detector plus spoof classifier backed by two ONNX Runtime sessions.

    ``find_spoof`` is the entry point: it runs the face detector on an image,
    crops every detected face, and classifies each crop with the spoof model
    using a small test-time-augmentation ensemble.
    """

    # Per-channel statistics applied to face crops (after scaling to [0, 1])
    # before they are fed to the classifier.
    _CLASSIFY_MEAN = (0.485, 0.456, 0.406)
    _CLASSIFY_STD = (0.229, 0.224, 0.225)

    def __init__(
        self,
        face_model_path="models/slim-facedetect.onnx",
        classifier_path="models/spoof.onnx",
    ):
        """Load both ONNX models.

        Args:
            face_model_path: Path to the face-detection ONNX model.
            classifier_path: Path to the spoof-classification ONNX model.
        """
        self.face_model = InferenceSession(face_model_path)
        self.face_inputname = self.face_model.get_inputs()[0].name
        self.classifier = InferenceSession(classifier_path)

    def find_boxes(
        self,
        width,
        height,
        confidences,
        boxes,
        prob_threshold,
        iou_threshold=0.3,
        top_k=-1,
    ):
        """Filter raw detector output into pixel-space boxes.

        Only the first batch element of ``confidences`` and ``boxes`` is used.
        For every class column except column 0, candidates whose score exceeds
        ``prob_threshold`` are passed through ``box_utils_numpy.hard_nms`` and
        the survivors are scaled by ``width`` / ``height``.

        Args:
            width: Factor applied to the x coordinates (columns 0 and 2).
            height: Factor applied to the y coordinates (columns 1 and 3).
            confidences: Detector scores, indexed as ``[batch, candidate, class]``.
            boxes: Detector boxes, indexed as ``[batch, candidate, 4]``.
            prob_threshold: Minimum score for a candidate to be kept.
            iou_threshold: Forwarded to ``hard_nms``.
            top_k: Forwarded to ``hard_nms``.

        Returns:
            A tuple ``(boxes, labels, probs)``: an ``int32`` array of
            ``(x0, y0, x1, y1)`` rows, the class index of each row, and the
            score of each row. All three are empty arrays when nothing passes
            the threshold.
        """
        boxes = boxes[0]
        confidences = confidences[0]
        picked_box_probs = []
        picked_labels = []
        # Column 0 is skipped -- presumably the background class; confirm
        # against the detector's label map.
        for class_index in range(1, confidences.shape[1]):
            probs = confidences[:, class_index]
            mask = probs > prob_threshold
            probs = probs[mask]
            if probs.shape[0] == 0:
                continue
            subset_boxes = boxes[mask, :]
            # Rows of (x0, y0, x1, y1, score), the layout handed to hard_nms.
            box_probs = np.concatenate([subset_boxes, probs.reshape(-1, 1)], axis=1)
            box_probs = box_utils_numpy.hard_nms(
                box_probs,
                iou_threshold=iou_threshold,
                top_k=top_k,
            )
            picked_box_probs.append(box_probs)
            picked_labels.extend([class_index] * box_probs.shape[0])
        if not picked_box_probs:
            return np.array([]), np.array([]), np.array([])
        picked_box_probs = np.concatenate(picked_box_probs)
        picked_box_probs[:, 0] *= width
        picked_box_probs[:, 1] *= height
        picked_box_probs[:, 2] *= width
        picked_box_probs[:, 3] *= height
        return (
            picked_box_probs[:, :4].astype(np.int32),
            np.array(picked_labels),
            picked_box_probs[:, 4],
        )

    @staticmethod
    def _clip_box(box, width, height):
        """Clamp an ``(x0, y0, x1, y1)`` box to the ``width`` x ``height`` image."""
        box = np.asarray(box)
        upper = np.array([width, height, width, height])
        # Cast back so callers keep receiving the same integer dtype.
        return np.clip(box, 0, upper).astype(box.dtype)

    def tta(self, src):
        """Build the test-time-augmentation variants of a face crop.

        Returns:
            A list of three images: ``src`` itself, ``src`` rotated by 180
            degrees, and a grayscale version of ``src`` expanded back to three
            channels.
        """
        rotated_180 = cv2.rotate(src, cv2.ROTATE_180)
        grayscale = cv2.cvtColor(src, cv2.COLOR_RGB2GRAY)
        grayscale = cv2.cvtColor(grayscale, cv2.COLOR_GRAY2RGB)
        return [src, rotated_180, grayscale]

    def _classify(self, face):
        """Return classifier probabilities for ``face`` averaged over ``tta``."""
        probs_all = []
        for variant in self.tta(face):
            # Normalize
            tensor = variant / 255.0
            tensor = (tensor - self._CLASSIFY_MEAN) / self._CLASSIFY_STD
            # HWC -> CHW, then add the batch axis.
            tensor = np.transpose(tensor, [2, 0, 1])
            tensor = np.expand_dims(tensor, axis=0).astype(np.float32)

            predicted = self.classifier.run(None, {"input": tensor})
            probs_all.append(softmax(predicted[0][0]))
        return np.mean(probs_all, axis=0)

    def find_spoof(self, img, threshold=0.6):
        """Detect faces in ``img`` and classify each one as real or spoofed.

        Args:
            img: Three-channel image array indexed ``[row, column, channel]``.
                Face crops are converted with ``COLOR_BGR2RGB`` before
                classification, so BGR channel order is expected.
            threshold: Minimum detector score for a face to be considered.

        Returns:
            A list with one dict per usable face:
            ``"coords"`` is ``(startX, startY, endX, endY)`` clamped to the
            image, ``"is_real"`` is ``True`` when the winning class index is
            non-zero, and ``"probs"`` is the averaged probability of the
            winning class.
        """
        img_h, img_w = img.shape[:2]
        image_mean = np.array([127, 127, 127])

        # NOTE(review): the detector receives ``img`` without a channel-order
        # conversion, unlike the classifier crops below -- confirm that this
        # matches how the detection model was trained.
        image = cv2.resize(img, (320, 240))
        image = (image - image_mean) / 128
        image = np.transpose(image, [2, 0, 1])
        image = np.expand_dims(image, axis=0)
        image = image.astype(np.float32)

        confidences, boxes = self.face_model.run(None, {self.face_inputname: image})
        boxes, _, _ = self.find_boxes(img_w, img_h, confidences, boxes, threshold)

        ret = []
        for box in boxes:
            # Boxes may extend past the image edge. A negative start index
            # would be read by the slice below as "from the end", producing
            # an empty or wrong crop, so clamp to the image first.
            startX, startY, endX, endY = self._clip_box(box, img_w, img_h)

            face = img[startY:endY, startX:endX]
            if face.size == 0:
                continue
            face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)

            # Preprocess
            face = align_face(face)
            face = crop_square(face, 256)

            final_probs = self._classify(face)
            predicted_id = np.argmax(final_probs)
            ret.append(
                {
                    "coords": (startX, startY, endX, endY),
                    "is_real": bool(predicted_id),
                    "probs": final_probs[predicted_id],
                }
            )
        return ret