import sys
import cv2
import math
import copy
import torch
import itertools
import numpy as np
import matplotlib.pyplot as plt

from mpl_toolkits.axes_grid1 import make_axes_locatable
from scipy.optimize import linear_sum_assignment
from scipy.stats import linregress
from ellipse import LsqEllipse
from itertools import product
from functools import reduce

from utils.utils_field import _draw_field
from utils.utils_heatmap import generate_gaussian_array_vectorized


class KeypointsWCDB(object):
    """Project a fixed set of pitch keypoints into an image through a homography.

    Keypoint ids are 1-based: id ``k`` refers to
    ``keypoint_world_coords_2D[k - 1]``. World coordinates span 0..105 x 0..68
    (presumably metres on a 105 x 68 pitch -- TODO confirm). Projected points
    are rescaled from the source image size to ``size_out``.
    """

    # Keypoint ids whose world coordinates duplicate another id's
    # (12 == 13, 15 == 14, 16 == 17, 19 == 18 in keypoint_world_coords_2D).
    # They are never projected and their heatmap mask entries are zeroed.
    _DUPLICATE_KPS = (12, 15, 16, 19)

    def __init__(self, image, homography, size_out=(960,540)):
        """Store the inputs and the static keypoint tables.

        Args:
            image: Source image. ``image.size`` is read as (width, height) and
                the image is passed to ``imshow`` -- assumes a PIL-style
                image; TODO confirm.
            homography: 3x3 matrix whose *inverse* maps homogeneous world
                points to source-image pixels (so ``homography`` itself maps
                image -> world).
            size_out: (width, height) of the output frame that projected
                keypoints are rescaled to.
        """

        self.keypoint_world_coords_2D = [[0., 0.], [52.5, 0.], [105., 0.], [0., 13.84], [16.5, 13.84], [88.5, 13.84],
                                         [105., 13.84], [0., 24.84], [5.5, 24.84], [99.5, 24.84], [105., 24.84],
                                         [0., 30.34], [0., 30.34], [105., 30.34], [105., 30.34], [0., 37.66],
                                         [0., 37.66], [105., 37.66], [105., 37.66], [0., 43.16], [5.5, 43.16],
                                         [99.5, 43.16], [105., 43.16], [0., 54.16], [16.5, 54.16], [88.5, 54.16],
                                         [105., 54.16], [0., 68.], [52.5, 68.], [105., 68.], [16.5, 26.68],
                                         [52.5, 24.85], [88.5, 26.68], [16.5, 41.31], [52.5, 43.15], [88.5, 41.31],
                                         [19.99, 32.29], [43.68, 31.53], [61.31, 31.53], [85., 32.29], [19.99, 35.7],
                                         [43.68, 36.46], [61.31, 36.46], [85., 35.7], [11., 34.], [16.5, 34.],
                                         [20.15, 34.], [46.03, 27.53], [58.97, 27.53], [43.35, 34.], [52.5, 34.],
                                         [61.5, 34.], [46.03, 40.47], [58.97, 40.47], [84.85, 34.], [88.5, 34.],
                                         [94., 34.]]  # 57

        self.keypoint_aux_world_coords_2D = [[5.5, 0], [16.5, 0], [88.5, 0], [99.5, 0], [5.5, 13.84], [99.5, 13.84],
                                             [16.5, 24.84], [88.5, 24.84], [16.5, 43.16], [88.5, 43.16], [5.5, 54.16],
                                             [99.5, 54.16], [5.5, 68], [16.5, 68], [88.5, 68], [99.5, 68]]

        # Groups of keypoint ids (presumably one group per pitch line -- TODO
        # confirm; not used by the methods in this class).
        self.lines_retrieval = [[24, 25], [5, 25], [4, 5], [26, 27], [6, 26], [12, 16], [16, 17], [12, 13], [15, 19],
                                [14, 15], [18, 19], [2, 29], [28, 29, 30], [1, 4, 8, 13, 17, 20, 24, 28],
                                [3, 7, 11, 14, 18, 23, 27, 30], [1, 2, 3], [20, 21], [9, 21], [8, 9], [22, 23],
                                [10, 22], [10, 11]]

        self.homography = homography
        self.image = image

        self.w, self.h = size_out
        self.size = (self.w, self.h)
        # Margin (half the frame on each side) used for the 'close_to_frame' test.
        self.h_extra = self.h * 0.5
        self.w_extra = self.w * 0.5

        # kp id -> {'x', 'y', 'in_frame', 'close_to_frame'}; filled lazily.
        self.keypoints_final = {}

        # One channel per keypoint plus one extra channel (presumably
        # background -- TODO confirm against generate_gaussian_array_vectorized).
        self.num_channels = len(self.keypoint_world_coords_2D) + 1
        self.mask_array = np.ones(self.num_channels, dtype=int)

    def get_tensor_w_mask(self):
        """Project keypoints and render them as Gaussian heatmaps.

        Recomputes ``keypoints_final``, zeroes the mask entry of every
        duplicate keypoint id, and delegates rendering to
        ``generate_gaussian_array_vectorized`` (down_ratio=2, sigma=2).

        Returns:
            (heatmap_tensor, mask_array). The heatmap layout is whatever
            ``generate_gaussian_array_vectorized`` returns; ``mask_array`` has
            ``num_channels`` integer entries.
        """
        self.get_kp_from_homography()
        for kp in self._DUPLICATE_KPS:
            # Mask index is the 0-based keypoint index.
            self.mask_array[kp - 1] = 0
        heatmap_tensor = generate_gaussian_array_vectorized(self.num_channels, self.keypoints_final, self.size,
                                                            down_ratio=2, sigma=2)
        return heatmap_tensor, self.mask_array

    def kpmeters2yards(self, kp):
        """Return keypoint ``kp`` (1-based) as a homogeneous point scaled by 1.09361.

        1.09361 is the metres -> yards factor.
        """
        wp = self.keypoint_world_coords_2D[kp - 1]
        wp_arr = np.array([wp[0] * 1.09361, wp[1] * 1.09361, 1.])
        return wp_arr

    def get_kp_from_homography(self):
        """Project every non-duplicate keypoint into the output frame.

        Each world point is mapped through ``inv(homography)``, dehomogenised,
        and rescaled from the source image size to ``(self.w, self.h)``.
        Results are written to ``self.keypoints_final[kp]`` as a dict with:

        - ``'x'``, ``'y'``: projected coordinates in the output frame.
        - ``'in_frame'``: True when the point lies inside the output frame
          (bounds inclusive).
        - ``'close_to_frame'``: True when the point lies inside the frame
          enlarged by half its width/height on every side.
        """
        # Invariant across keypoints: invert once, compute scale once.
        inv_h = np.linalg.inv(self.homography)
        scale_x = self.w / self.image.size[0]
        scale_y = self.h / self.image.size[1]

        for kp, wp in enumerate(self.keypoint_world_coords_2D, start=1):
            if kp in self._DUPLICATE_KPS:
                continue
            # World coordinates are used as-is (kpmeters2yards is not applied).
            img_pt = inv_h @ np.array([wp[0], wp[1], 1.])
            img_pt /= img_pt[-1]
            x = img_pt[0] * scale_x
            y = img_pt[1] * scale_y

            self.keypoints_final[kp] = {
                'x': x,
                'y': y,
                'in_frame': bool(0 <= x <= self.w and 0 <= y <= self.h),
                'close_to_frame': bool(-self.w_extra <= x <= self.w + self.w_extra
                                       and -self.h_extra <= y <= self.h + self.h_extra),
            }

    def get_lines_from_keypoints(self):
        """Ensure keypoints are projected; line extraction is not implemented here.

        NOTE(review): after the projection step the body is only ``...``,
        so this method currently returns None.
        """
        if len(self.keypoints_final) == 0:
            self.get_kp_from_homography()

            ...

    @staticmethod
    def _kp_color(kp):
        """Return the scatter colour for keypoint id ``kp``, or None if it is not drawn."""
        if kp <= 30:
            return 'r'
        if kp <= 36:
            return 'b'
        if kp <= 44:
            return 'pink'
        if kp <= 57:
            return 'green'
        return None

    def draw_keypoints(self, scale=1):
        """Show the image with every keypoint that is close to the frame.

        Each drawn keypoint gets a coloured dot and its id as a text label.
        Colour depends on the id range: 1-30 red, 31-36 blue, 37-44 pink,
        45-57 green.

        Args:
            scale: Multiplier for the figure size and the marker size.
        """
        if len(self.keypoints_final) == 0:
            self.get_kp_from_homography()

        _, ax = plt.subplots(figsize=(scale*15, scale*7.5))
        ax.imshow(self.image)
        for kp, info in self.keypoints_final.items():
            color = self._kp_color(kp)
            if color is None or not info['close_to_frame']:
                continue
            x, y = info['x'], info['y']
            ax.text(x, y, s=kp, zorder=11)
            ax.scatter(x, y, c=color, s=scale*10, zorder=10)

        plt.show()