# -*- coding: utf-8 -*-
#
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# Using this computer program means that you agree to the terms
# in the LICENSE file included with this software distribution.
# Any use not explicitly granted by the LICENSE is prohibited.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG), acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# For comments or questions, please email us at pixie@tue.mpg.de
# For commercial licensing, please contact ps-license@tuebingen.mpg.de

import os

import numpy as np
import torch
import torchvision

from .models.encoders import MLP, HRNEncoder, ResnetEncoder
from .models.moderators import TempSoftmaxFusion
from .models.SMPLX import SMPLX
from .utils import rotation_converter as converter
from .utils import tensor_cropper, util
from .utils.config import cfg


class PIXIE(object):
    def __init__(self, config=None, device="cuda:0"):
        if config is None:
            self.cfg = cfg
        else:
            self.cfg = config
        self.device = device
        # parameter settings
        self.param_list_dict = {}
        for lst in self.cfg.params.keys():
            param_list = self.cfg.params.get(lst)
            self.param_list_dict[lst] = {i: self.cfg.model.get("n_" + i) for i in param_list}
        # Build the models
        self._create_model()
        # Set up the cropping modules used to generate face/hand crops from the body predictions
        self._setup_cropper()

    def forward(self, data):
        # encode + decode
        param_dict = self.encode(
            {"body": {"image": data}},
            threshold=True,
            keep_local=True,
            copy_and_paste=False,
        )
        opdict = self.decode(param_dict["body"], param_type="body")
        return opdict

    def _setup_cropper(self):
        self.Cropper = {}
        for crop_part in ["head", "hand"]:
            data_cfg = self.cfg.dataset[crop_part]
            scale_size = (data_cfg.scale_min + data_cfg.scale_max) * 0.5
            self.Cropper[crop_part] = tensor_cropper.Cropper(
                crop_size=data_cfg.image_size,
                scale=[scale_size, scale_size],
                trans_scale=0,
            )

    def _create_model(self):
        self.model_dict = {}

        # Build all image encoders.
        # The hand encoder only works for the right hand; for the left hand,
        # flip the input and flip the results back.
        self.Encoder = {}
        for key in self.cfg.network.encoder.keys():
            if self.cfg.network.encoder.get(key).type == "resnet50":
                self.Encoder[key] = ResnetEncoder().to(self.device)
            elif self.cfg.network.encoder.get(key).type == "hrnet":
                self.Encoder[key] = HRNEncoder().to(self.device)
            self.model_dict[f"Encoder_{key}"] = self.Encoder[key].state_dict()

        # Build the parameter regressors
        self.Regressor = {}
        for key in self.cfg.network.regressor.keys():
            n_output = sum(self.param_list_dict[f"{key}_list"].values())
            channels = [2048] + self.cfg.network.regressor.get(key).channels + [n_output]
            if self.cfg.network.regressor.get(key).type == "mlp":
                self.Regressor[key] = MLP(channels=channels).to(self.device)
            self.model_dict[f"Regressor_{key}"] = self.Regressor[key].state_dict()

        # Build the extractors, which extract separate head/left-hand/right-hand
        # features from the body feature
        self.Extractor = {}
        for key in self.cfg.network.extractor.keys():
            channels = [2048] + self.cfg.network.extractor.get(key).channels + [2048]
            if self.cfg.network.extractor.get(key).type == "mlp":
                self.Extractor[key] = MLP(channels=channels).to(self.device)
            self.model_dict[f"Extractor_{key}"] = self.Extractor[key].state_dict()
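        # How the body-part fusion behaves at a high level (a sketch of how this
        # code uses TempSoftmaxFusion, not a description of its exact internals):
        # given the body-derived part feature f_body and the part-crop feature
        # f_part, the moderator predicts two weights from the concatenated features,
        #     w = softmax(MLP(cat(f_body, f_part)) / temperature)    # [bz, 2]
        # and the fused feature is w[:, 0:1] * f_body + w[:, 1:2] * f_part,
        # matching the [2048 * 2, ..., 2] channel layout configured below.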
        # Build the moderators, which weigh the part features extracted from the
        # body feature against the features encoded from the dedicated part crops
        self.Moderator = {}
        for key in self.cfg.network.moderator.keys():
            detach_inputs = self.cfg.network.moderator.get(key).detach_inputs
            detach_feature = self.cfg.network.moderator.get(key).detach_feature
            channels = [2048 * 2] + self.cfg.network.moderator.get(key).channels + [2]
            self.Moderator[key] = TempSoftmaxFusion(
                detach_inputs=detach_inputs,
                detach_feature=detach_feature,
                channels=channels,
            ).to(self.device)
            self.model_dict[f"Moderator_{key}"] = self.Moderator[key].state_dict()

        # Build the SMPL-X body model, which we also use to represent faces and
        # hands, using the relevant parts only
        self.smplx = SMPLX(self.cfg.model).to(self.device)
        self.part_indices = self.smplx.part_indices

        # -- resume model
        model_path = self.cfg.pretrained_modelpath
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path)
            for key in self.model_dict.keys():
                util.copy_state_dict(self.model_dict[key], checkpoint[key])
        else:
            print(f"pixie trained model path: {model_path} does not exist!")
            exit()

        # eval mode
        for module in [self.Encoder, self.Regressor, self.Moderator, self.Extractor]:
            for net in module.values():
                net.eval()

    def decompose_code(self, code, num_dict):
        """Convert a flattened parameter vector to a dictionary of parameters"""
        code_dict = {}
        start = 0
        for key in num_dict:
            end = start + int(num_dict[key])
            code_dict[key] = code[:, start:end]
            start = end
        return code_dict
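    # decompose_code example (illustrative sizes, not the real config values):
    # with num_dict = {"shape": 200, "exp": 50} and code of shape [bz, 250],
    # it returns {"shape": code[:, 0:200], "exp": code[:, 200:250]}. Keys are
    # consumed in insertion order, so the order of num_dict must match the
    # layout of the regressor's output vector.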
    def part_from_body(self, image, part_key, points_dict, crop_joints=None):
        """Crop a part (head/left_hand/right_hand) out of the body image;
        the joints are transformed accordingly."""
        assert part_key in ["head", "left_hand", "right_hand"]
        assert "smplx_kpt" in points_dict.keys()
        # use the 68 face keypoints for cropping the head image
        indices_key = "face" if part_key == "head" else part_key

        # get points for cropping
        part_indices = self.part_indices[indices_key]
        if crop_joints is not None:
            points_for_crop = crop_joints[:, part_indices]
        else:
            points_for_crop = points_dict["smplx_kpt"][:, part_indices]

        # crop
        cropper_key = "hand" if "hand" in part_key else part_key
        points_scale = image.shape[-2:]
        cropped_image, tform = self.Cropper[cropper_key].crop(
            image, points_for_crop, points_scale
        )
        # transform the points (must be normalized to [-1, 1]) accordingly
        cropped_points_dict = {}
        for points_key in points_dict.keys():
            points = points_dict[points_key]
            cropped_points = self.Cropper[cropper_key].transform_points(
                points, tform, points_scale, normalize=True
            )
            cropped_points_dict[points_key] = cropped_points
        return cropped_image, cropped_points_dict

    @torch.no_grad()
    def encode(
        self,
        data,
        threshold=True,
        keep_local=True,
        copy_and_paste=False,
        body_only=False,
    ):
        """Encode images to smplx parameters

        Args:
            data: dict
                key: image_type (body/head/hand)
                value:
                    image: [bz, 3, 224, 224], range [0, 1]
                    image_hd (needed if key == body): a high-res version of the image,
                        only used for cropping parts from the body image
                    head_image: optional, well-cropped head from the body image
                    left_hand_image: optional, well-cropped left hand from the body image
                    right_hand_image: optional, well-cropped right hand from the body image
            threshold: for hands, fully trust the part crop when the moderator's
                part weight exceeds 0.7
            keep_local: trust the part regressors for local articulation
                (expression and finger poses)
            copy_and_paste: always replace the body-derived part features with the
                part-crop features
            body_only: return the coarse body-only estimation, skipping part fusion
        Returns:
            param_dict: dict
                key: image_type (body/head/hand)
                value: param_dict
        """
        for key in data.keys():
            assert key in ["body", "head", "hand"]

        feature = {}
        param_dict = {}

        # Encode features
        for key in data.keys():
            part = key
            # encode feature
            feature[key] = {}
            feature[key][part] = self.Encoder[part](data[key]["image"])

            # for head/hand images
            if key == "head" or key == "hand":
                # predict head/hand-only parameters from the part feature
                part_dict = self.decompose_code(
                    self.Regressor[part](feature[key][part]),
                    self.param_list_dict[f"{part}_list"],
                )
                # if the input is part data, skip feature fusion: the share feature
                # is the same as the part feature; then predict the share parameters
                feature[key][f"{part}_share"] = feature[key][part]
                share_dict = self.decompose_code(
                    self.Regressor[f"{part}_share"](feature[key][f"{part}_share"]),
                    self.param_list_dict[f"{part}_share_list"],
                )
                # compose parameters
                param_dict[key] = {**share_dict, **part_dict}

            # for body images
            if key == "body":
                fusion_weight = {}
                f_body = feature["body"]["body"]
                # extract part features from the body feature
                for part_name in ["head", "left_hand", "right_hand"]:
                    feature["body"][f"{part_name}_share"] = self.Extractor[f"{part_name}_share"](
                        f_body
                    )

                # body-only parameters (also used for the final composition below,
                # and for cropping parts when no part crops are given)
                body_dict = self.decompose_code(
                    self.Regressor[part](feature[key][part]),
                    self.param_list_dict[f"{part}_list"],
                )

                # -- check if part crops are given; if not, crop the parts using the
                # coarse body estimation
                if (
                    "head_image" not in data[key].keys()
                    or "left_hand_image" not in data[key].keys()
                    or "right_hand_image" not in data[key].keys()
                ):
                    # - run without fusion to get a coarse estimation for cropping parts
                    # head share
                    head_share_dict = self.decompose_code(
                        self.Regressor["head_share"](feature[key]["head_share"]),
                        self.param_list_dict["head_share_list"],
                    )
                    # right hand share
                    right_hand_share_dict = self.decompose_code(
                        self.Regressor["hand_share"](feature[key]["right_hand_share"]),
                        self.param_list_dict["hand_share_list"],
                    )
                    # left hand share
                    left_hand_share_dict = self.decompose_code(
                        self.Regressor["hand_share"](feature[key]["left_hand_share"]),
                        self.param_list_dict["hand_share_list"],
                    )
                    # rename the dict keys from right to left
                    left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop(
                        "right_hand_pose"
                    )
                    left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop(
                        "right_wrist_pose"
                    )
                    param_dict[key] = {
                        **body_dict,
                        **head_share_dict,
                        **left_hand_share_dict,
                        **right_hand_share_dict,
                    }
                    if body_only:
                        param_dict["moderator_weight"] = None
                        return param_dict
                    prediction_body_only = self.decode(param_dict[key], param_type="body")
                    # crop
                    for part_name in ["head", "left_hand", "right_hand"]:
                        points_dict = {
                            "smplx_kpt": prediction_body_only["smplx_kpt"],
                            "trans_verts": prediction_body_only["transformed_vertices"],
                        }
                        image_hd = torchvision.transforms.Resize(1024)(data["body"]["image"])
                        cropped_image, cropped_joints_dict = self.part_from_body(
                            image_hd, part_name, points_dict
                        )
                        data[key][part_name + "_image"] = cropped_image
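                # Fusion strategies used in the loop below: `copy_and_paste` always
                # replaces the body-derived part feature with the part-crop feature;
                # `threshold` fully trusts a hand crop when the moderator's part
                # weight exceeds 0.7; otherwise the moderator's soft fusion is used.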
                # -- encode features from the part crops, then fuse the features
                # using the weights from the moderator
                for part_name in ["head", "left_hand", "right_hand"]:
                    part = part_name.split("_")[-1]
                    cropped_image = data[key][part_name + "_image"]
                    # if left hand, flip it so that it looks like a right hand
                    if part_name == "left_hand":
                        cropped_image = torch.flip(cropped_image, dims=(-1, ))
                    # run the part regressor
                    f_part = self.Encoder[part](cropped_image)
                    part_dict = self.decompose_code(
                        self.Regressor[part](f_part),
                        self.param_list_dict[f"{part}_list"],
                    )
                    part_share_dict = self.decompose_code(
                        self.Regressor[f"{part}_share"](f_part),
                        self.param_list_dict[f"{part}_share_list"],
                    )
                    param_dict["body_" + part_name] = {**part_dict, **part_share_dict}

                    # the moderator assigns weights, then the features are integrated
                    f_body_out, f_part_out, f_weight = self.Moderator[f"{part}_share"](
                        feature["body"][f"{part_name}_share"], f_part, work=True
                    )
                    if copy_and_paste:
                        # the copy-and-paste strategy always trusts the result from the part
                        feature["body"][f"{part_name}_share"] = f_part
                    elif threshold and part == "hand":
                        # for hands: if the part weight > 0.7 (very confident), fully trust the part
                        part_w = f_weight[:, [1]]
                        part_w[part_w > 0.7] = 1.0
                        f_body_out = (
                            feature["body"][f"{part_name}_share"] * (1.0 - part_w) + f_part * part_w
                        )
                        feature["body"][f"{part_name}_share"] = f_body_out
                    else:
                        feature["body"][f"{part_name}_share"] = f_body_out
                    fusion_weight[part_name] = f_weight
                # save the weights from the moderator; they can be used further for
                # optimization or for running specific tasks on parts
                param_dict["moderator_weight"] = fusion_weight

                # -- predict parameters from the fused body feature
                # head share
                head_share_dict = self.decompose_code(
                    self.Regressor["head_share"](feature[key]["head_share"]),
                    self.param_list_dict["head_share_list"],
                )
                # right hand share
                right_hand_share_dict = self.decompose_code(
                    self.Regressor["hand_share"](feature[key]["right_hand_share"]),
                    self.param_list_dict["hand_share_list"],
                )
                # left hand share
                left_hand_share_dict = self.decompose_code(
                    self.Regressor["hand_share"](feature[key]["left_hand_share"]),
                    self.param_list_dict["hand_share_list"],
                )
                # rename the dict keys from right to left
                left_hand_share_dict["left_hand_pose"] = left_hand_share_dict.pop("right_hand_pose")
                left_hand_share_dict["left_wrist_pose"] = left_hand_share_dict.pop(
                    "right_wrist_pose"
                )
                param_dict["body"] = {
                    **body_dict,
                    **head_share_dict,
                    **left_hand_share_dict,
                    **right_hand_share_dict,
                }
                # copy the texture and lighting parameters from the head param dict
                # to the body param dict
                param_dict["body"]["tex"] = param_dict["body_head"]["tex"]
                param_dict["body"]["light"] = param_dict["body_head"]["light"]
                if keep_local:
                    # for local articulation that neither affects the whole body nor
                    # produces unnatural poses, trust the part regressors
                    param_dict[key]["exp"] = param_dict["body_head"]["exp"]
                    param_dict[key]["right_hand_pose"] = param_dict["body_right_hand"][
                        "right_hand_pose"]
                    # the flipped left hand was regressed as a right hand, so its
                    # finger pose is stored under "right_hand_pose"
                    param_dict[key]["left_hand_pose"] = param_dict["body_left_hand"][
                        "right_hand_pose"]

        return param_dict
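    # Background on the pose representation (a note, not part of the original
    # code): the regressors output the 6D continuous rotation representation of
    # Zhou et al., CVPR 2019 (the first two columns of a rotation matrix);
    # batch_cont2matrix presumably orthonormalizes them and recovers the third
    # column via a cross product, yielding [bz, n_joints, 3, 3] matrices. The jaw
    # pose is predicted as Euler angles and converted with batch_euler2matrix.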
    def convert_pose(self, param_dict, param_type):
        """Convert pose parameters to rotation matrices

        Args:
            param_dict: smplx parameters
            param_type: should be one of body/head/hand
        Returns:
            param_dict: smplx parameters
        """
        assert param_type in ["body", "head", "hand"]

        # Convert the pose representations: the network outputs continuous 6D
        # representations or Euler angles (jaw), while the SMPL-X input poses
        # need to be rotation matrices
        for key in param_dict:
            if "pose" in key and "jaw" not in key:
                param_dict[key] = converter.batch_cont2matrix(param_dict[key])
        if param_type == "body" or param_type == "head":
            param_dict["jaw_pose"] = converter.batch_euler2matrix(param_dict["jaw_pose"])[:, None, :, :]

        # complement parameters that are missing from the given param dict
        if param_type == "head":
            batch_size = param_dict["shape"].shape[0]
            param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
            param_dict["global_pose"] = param_dict["head_pose"]
            param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )[:, :self.param_list_dict["body_list"]["partbody_pose"]]
            param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            # wrists and fingers are not observable from a head crop: fall back to
            # the model's default rotations (neck_pose serves as a rest rotation here)
            param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["right_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["right_hand_pose"] = self.smplx.right_hand_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
        elif param_type == "hand":
            batch_size = param_dict["right_hand_pose"].shape[0]
            param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
            dtype = param_dict["right_hand_pose"].dtype
            device = param_dict["right_hand_pose"].device
            # a 180-degree rotation around the x axis, diag(1, -1, -1)
            x_180_pose = torch.eye(3, dtype=dtype, device=device).unsqueeze(0)
            x_180_pose[0, 2, 2] = -1.0
            x_180_pose[0, 1, 1] = -1.0
            param_dict["global_pose"] = x_180_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
            param_dict["shape"] = self.smplx.shape_params.expand(batch_size, -1)
            param_dict["exp"] = self.smplx.expression_params.expand(batch_size, -1)
            param_dict["head_pose"] = self.smplx.head_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["neck_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["jaw_pose"] = self.smplx.jaw_pose.unsqueeze(0).expand(batch_size, -1, -1, -1)
            param_dict["partbody_pose"] = self.smplx.body_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )[:, :self.param_list_dict["body_list"]["partbody_pose"]]
            param_dict["left_wrist_pose"] = self.smplx.neck_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
            param_dict["left_hand_pose"] = self.smplx.left_hand_pose.unsqueeze(0).expand(
                batch_size, -1, -1, -1
            )
        elif param_type == "body":
            # the prediction from the head/hand share regressors is always an absolute pose
            batch_size = param_dict["shape"].shape[0]
            param_dict["abs_head_pose"] = param_dict["head_pose"].clone()
            param_dict["abs_right_wrist_pose"] = param_dict["right_wrist_pose"].clone()
            param_dict["abs_left_wrist_pose"] = param_dict["left_wrist_pose"].clone()
            # the body-hand share regressor works on right hands, so the body
            # network receives a flipped feature for the left hand and predicts its
            # parameters; flip them back so they match the input left hand
            param_dict["left_wrist_pose"] = util.flip_pose(param_dict["left_wrist_pose"])
            param_dict["left_hand_pose"] = util.flip_pose(param_dict["left_hand_pose"])
        return param_dict
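    # pose_abs2rel note (a sketch under the usual kinematic-chain convention; see
    # models/SMPLX.py for the actual implementation): an absolute joint rotation
    # R_abs is converted into the relative rotation the kinematic chain expects via
    #     R_rel = R_parent_global^{-1} @ R_abs
    # where R_parent_global accumulates the rotations from the root (global pose)
    # down to the joint's parent.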
    def decode(self, param_dict, param_type):
        """Decode model parameters to smplx vertices & joints & texture

        Args:
            param_dict: smplx parameters
            param_type: should be one of body/head/hand
        Returns:
            prediction: smplx predictions
        """
        # convert the pose representation to rotation matrices if that has not
        # happened yet (a 2D jaw pose or a 6D wrist pose indicates raw network output)
        if "jaw_pose" in param_dict.keys() and len(param_dict["jaw_pose"].shape) == 2:
            self.convert_pose(param_dict, param_type)
        elif param_dict["right_wrist_pose"].shape[-1] == 6:
            self.convert_pose(param_dict, param_type)

        # concatenate the full body pose
        partbody_pose = param_dict["partbody_pose"]
        param_dict["body_pose"] = torch.cat(
            [
                partbody_pose[:, :11],
                param_dict["neck_pose"],
                partbody_pose[:, 11:11 + 2],
                param_dict["head_pose"],
                partbody_pose[:, 13:13 + 4],
                param_dict["left_wrist_pose"],
                param_dict["right_wrist_pose"],
            ],
            dim=1,
        )

        # change the absolute head & wrist poses to poses relative to the rest of the body pose
        if param_type == "head" or param_type == "body":
            param_dict["body_pose"] = self.smplx.pose_abs2rel(
                param_dict["global_pose"], param_dict["body_pose"], abs_joint="head"
            )
        if param_type == "hand" or param_type == "body":
            param_dict["body_pose"] = self.smplx.pose_abs2rel(
                param_dict["global_pose"],
                param_dict["body_pose"],
                abs_joint="left_wrist",
            )
            param_dict["body_pose"] = self.smplx.pose_abs2rel(
                param_dict["global_pose"],
                param_dict["body_pose"],
                abs_joint="right_wrist",
            )

        if self.cfg.model.check_pose:
            # Check whether the (relative) rotation is natural; if not, set it to 0.
            # This matters especially for the head pose.
            # xyz: pitch (positive for looking down), yaw (positive for looking left),
            # roll (positive rolls the chin to the left)
            for pose_ind in [14]:    # head (joint 15, index 15-1); wrists would be 20-1 and 21-1
                curr_pose = param_dict["body_pose"][:, pose_ind]
                euler_pose = converter._compute_euler_from_matrix(curr_pose)
                for i, max_angle in enumerate([20, 70, 10]):
                    euler_pose_curr = euler_pose[:, i]
                    euler_pose_curr[euler_pose_curr != torch.clamp(
                        euler_pose_curr,
                        min=-max_angle * np.pi / 180,
                        max=max_angle * np.pi / 180,
                    )] = 0.0
                param_dict["body_pose"][:, pose_ind] = converter.batch_euler2matrix(euler_pose)

        # SMPL-X
        verts, landmarks, joints = self.smplx(
            shape_params=param_dict["shape"],
            expression_params=param_dict["exp"],
            global_pose=param_dict["global_pose"],
            body_pose=param_dict["body_pose"],
            jaw_pose=param_dict["jaw_pose"],
            left_hand_pose=param_dict["left_hand_pose"],
            right_hand_pose=param_dict["right_hand_pose"],
        )
        smplx_kpt3d = joints.clone()

        # projection
        cam = param_dict[param_type + "_cam"]
        trans_verts = util.batch_orth_proj(verts, cam)
        predicted_landmarks = util.batch_orth_proj(landmarks, cam)[:, :, :2]
        predicted_joints = util.batch_orth_proj(joints, cam)[:, :, :2]

        prediction = {
            "vertices": verts,
            "transformed_vertices": trans_verts,
            "face_kpt": predicted_landmarks,
            "smplx_kpt": predicted_joints,
            "smplx_kpt3d": smplx_kpt3d,
            "joints": joints,
            "cam": cam,
        }
        # reorder the face keypoints to match the "standard" 68-keypoint layout
        # (move the 17 jawline contour points to the front)
        prediction["face_kpt"] = torch.cat(
            [prediction["face_kpt"][:, -17:], prediction["face_kpt"][:, :-17]], dim=1
        )
        prediction.update(param_dict)
        return prediction
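    # Camera note: `cam` is a weak-perspective (scaled orthographic) camera. Under
    # the convention we assume for util.batch_orth_proj (DECA-style), it holds
    # [scale, tx, ty], i.e. projected = scale * (X[..., :2] + [tx, ty]), with the
    # Z coordinate carried through for depth ordering of the transformed vertices.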
    def decode_Tpose(self, param_dict):
        """Return the body mesh in T pose; supports body and head param dicts only"""
        verts, _, _ = self.smplx(
            shape_params=param_dict["shape"],
            expression_params=param_dict["exp"],
            jaw_pose=param_dict["jaw_pose"],
        )
        return verts
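
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original module). It
# assumes this file lives at pixielib/pixie.py as in the PIXIE repository and
# that cfg.pretrained_modelpath points to a downloaded checkpoint; `image`
# stands in for a real, preprocessed 224x224 body crop with values in [0, 1]:
#
#   import torch
#   from pixielib.pixie import PIXIE
#
#   pixie = PIXIE(device="cuda:0")
#   image = torch.rand(1, 3, 224, 224).to("cuda:0")
#   param_dict = pixie.encode({"body": {"image": image}})
#   opdict = pixie.decode(param_dict["body"], param_type="body")
#   verts = opdict["vertices"]      # [1, n_verts, 3] posed SMPL-X vertices
#   kpt2d = opdict["smplx_kpt"]     # [1, n_joints, 2] projected joints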