import logging
import sys
import numpy as np
sys.path.append("../")
# from tdc.multi_pred import GDA
import pandas as pd
from torch.utils.data import Dataset

LOGGER = logging.getLogger(__name__)

class GDA_Dataset(Dataset):
    """
    Candidate Dataset for:
        ALL gene-to-disease interactions

    Thin wrapper around three parallel, index-aligned collections read from
    ``data_examples``: index 0 is stored as the protein sequences, index 1
    as the disease descriptions and index 2 as the scores.
    """

    def __init__(self, data_examples):
        # Indexed access (not unpacking), so only positions 0..2 are required.
        self.protein_seqs, self.disease_dess, self.scores = (
            data_examples[0],
            data_examples[1],
            data_examples[2],
        )

    def __getitem__(self, query_idx):
        """Return ``(protein_seq, disease_des, score)`` at ``query_idx``."""
        return (
            self.protein_seqs[query_idx],
            self.disease_dess[query_idx],
            self.scores[query_idx],
        )

    def __len__(self):
        """Number of examples, taken from the protein sequence collection."""
        n_examples = len(self.protein_seqs)
        return n_examples


class TDC_Pretrain_Dataset(Dataset):
    """
        Dataset of TDC:
            ALL gene-disease associations

    Loads the "DisGeNET" GDA dataset through PyTDC and exposes the rows of
    its ``'train'`` split as ``(Gene, Disease, 1)`` triples.

    Args:
        data_dir: Only used as a prefix in the printed status message; the
            data itself is obtained through ``GDA(name="DisGeNET")``.
        test: Accepted for signature parity with the other pretraining
            datasets; not used by this class.

    Raises:
        ImportError: If the ``tdc`` package is not installed.
    """
    def __init__(self, data_dir="../../data/pretrain/", test=False):
        LOGGER.info("Initializing TDC Pretraining Dataset ! ...")

        # The module-level import of GDA is commented out, so the name was
        # undefined here (NameError). Import it lazily so that this module
        # can still be imported when PyTDC is not installed and only this
        # class requires it.
        from tdc.multi_pred import GDA

        data = GDA(name="DisGeNET")  # , path=data_dir
        data.neg_sample(frac=1)
        data.binarize(threshold=0, order='ascending')
        self.datasets = data.get_split()
        self.name = "DisGeNET"
        self.dataset_df = self.datasets['train']
        # self.dataset_df = pd.read_csv(f"{data_dir}/disgenet_gda.csv")
        self.dataset_df = self.dataset_df[
            ["Gene", "Disease", "Y"]
        ].dropna()  # Drop missing values.
        # print(self.dataset_df.head())
        print(
            f"{data_dir}TDC training dataset loaded, found associations: {len(self.dataset_df.index)}"
        )
        self.protein_seqs = self.dataset_df["Gene"].values
        self.disease_dess = self.dataset_df["Disease"].values
        # NOTE(review): every example gets score 1; the "Y" column is only
        # used for its length, even though neg_sample() is called above.
        # Confirm this is intended.
        self.scores = len(self.dataset_df["Y"].values) * [1]

    def __getitem__(self, query_idx):
        """Return ``(protein_seq, disease_des, score)`` at ``query_idx``."""
        protein_seq = self.protein_seqs[query_idx]
        disease_des = self.disease_dess[query_idx]
        score = self.scores[query_idx]

        return protein_seq, disease_des, score

    def __len__(self):
        """Number of examples in the training split after ``dropna()``."""
        return len(self.protein_seqs)

class GDA_Pretrain_Dataset(Dataset):
    """
    Candidate Dataset for:
        ALL gene-disease associations

    Reads ``disgenet_gda.csv`` from ``data_dir``, keeps the ``proteinSeq``,
    ``diseaseDes`` and ``score`` columns, drops rows with missing values and
    shuffles with a fixed seed (42). Because the shuffle is seeded, a
    ``"train"`` and a ``"val"`` instance built from the same file and
    ``val_ratio`` receive complementary rows.

    Args:
        data_dir: Directory containing ``disgenet_gda.csv``.
        test: If true, keep at most the first 128 examples of the split.
        split: ``"train"`` keeps everything except the last
            ``int(n * val_ratio)`` shuffled rows, ``"val"`` keeps exactly
            those last rows. Any other value keeps the full shuffled data.
        val_ratio: Fraction of rows reserved for validation.

    Every example is returned with score 1; the ``score`` column of the CSV
    is not used as a label.
    """

    def __init__(self, data_dir="../../data/pretrain/", test=False, split="train", val_ratio=0.2):
        LOGGER.info("Initializing GDA Pretraining Dataset ! ...")
        self.dataset_df = pd.read_csv(f"{data_dir}/disgenet_gda.csv")
        self.dataset_df = self.dataset_df[["proteinSeq", "diseaseDes", "score"]].dropna()
        self.dataset_df = self.dataset_df.sample(frac=1, random_state=42).reset_index(drop=True)

        num_rows = len(self.dataset_df)
        # Clamp so that out-of-range ratios cannot yield a nonsensical split.
        num_val_samples = min(max(int(num_rows * val_ratio), 0), num_rows)
        # Use an explicit boundary instead of negative slicing: with zero
        # validation rows, ``df[:-0]`` is empty and ``df[-0:]`` is the whole
        # frame, which is the opposite of what is wanted.
        split_at = num_rows - num_val_samples
        if split == "train":
            self.dataset_df = self.dataset_df.iloc[:split_at]
            print(f"{data_dir}disgenet_gda.csv loaded, found train associations: {len(self.dataset_df.index)}")
        elif split == "val":
            self.dataset_df = self.dataset_df.iloc[split_at:]
            print(f"{data_dir}disgenet_gda.csv loaded, found valid associations: {len(self.dataset_df.index)}")

        protein_seqs = self.dataset_df["proteinSeq"].values
        disease_dess = self.dataset_df["diseaseDes"].values
        if test:
            # Small smoke-test subset.
            protein_seqs = protein_seqs[:128]
            disease_dess = disease_dess[:128]
        self.protein_seqs = protein_seqs
        self.disease_dess = disease_dess
        # Sized from the rows actually kept, so scores never outnumber the
        # sequences when fewer than 128 rows are available in test mode.
        self.scores = len(self.protein_seqs) * [1]

    def __getitem__(self, query_idx):
        """Return ``(protein_seq, disease_des, score)`` at ``query_idx``."""
        protein_seq = self.protein_seqs[query_idx]
        disease_des = self.disease_dess[query_idx]
        score = self.scores[query_idx]

        return protein_seq, disease_des, score

    def __len__(self):
        """Number of examples in the selected split."""
        return len(self.protein_seqs)
#         # Separate positive and negative samples
#         positive_samples = self.dataset_df[self.dataset_df["score"] == 1]
#         negative_samples = self.dataset_df[self.dataset_df["score"] == 0]

#         # Shuffle and split the positive samples
#         positive_samples = positive_samples.sample(frac=1, random_state=42).reset_index(drop=True)
#         num_pos_val_samples = int(len(positive_samples) * val_ratio)

#         # Shuffle and split the negative samples
#         negative_samples = negative_samples.sample(frac=1, random_state=42).reset_index(drop=True)
#         num_neg_val_samples = int(len(negative_samples) * val_ratio)

        # if split == "train":
        #     self.dataset_df = pd.concat([positive_samples[:-num_pos_val_samples], negative_samples[:-num_neg_val_samples]])
        #     print(f"{data_dir}disgenet_gda.csv loaded, found associations: {len(self.dataset_df.index)}")
        # elif split == "val":
        #     self.dataset_df = pd.concat([positive_samples[-num_pos_val_samples:], negative_samples[-num_neg_val_samples:]])
        #     print(f"{data_dir}disgenet_gda.csv loaded, found associations: {len(self.dataset_df.index)}")
        # Shuffle and split data

# class GDA_Pretrain_Dataset(Dataset):
#     """
#     Candidate Dataset for:
#         ALL gene-disease associations
#     """

#     def __init__(self, data_dir="../../data/pretrain/", test=False):
#         LOGGER.info("Initializing GDA Pretraining Dataset ! ...")
        # updated = pd.read_csv(f"{data_dir}/disgenet_updated.csv")     
        
        # data = GDA(name="DisGeNET")
        # data = data.get_data()
        # data = data[['Gene_ID','Disease_ID']].dropna()
        # self.dataset_df = pd.read_csv(f"{data_dir}/disgenet_gda.csv")
        
        # num_unique_diseaseId = self.dataset_df['diseaseId'].nunique()
        # num_unique_geneId = self.dataset_df['geneId'].nunique()

        # print(f"Number of unique 'diseaseId': {num_unique_diseaseId}")
        # print(f"Number of unique 'geneId': {num_unique_geneId}")
        
#         num_of_c0002395 = self.dataset_df[self.dataset_df['diseaseId'] == 'C0002395'].shape[0]
        # print(f"Alzheimer Number in 2020:{num_of_c0002395}")
        
        # Convert 'Gene_ID' and 'Disease_ID' to str before merge
        # data['Gene_ID'] = data['Gene_ID'].astype(str)
        # data['Disease_ID'] = data['Disease_ID'].astype(str)

        # Similarly for 'geneId' and 'diseaseId', if they're not already of type 'str'
        # self.dataset_df['geneId'] = self.dataset_df['geneId'].astype(str)
        # self.dataset_df['diseaseId'] = self.dataset_df['diseaseId'].astype(str)

#         # Merge the two DataFrames and find the rows that differ
#         merged = df.merge(self.dataset_df, how='outer', indicator=True)
#         differences = merged[merged['_merge'] != 'both']

#         differences.to_csv('/nfs/dpa_pretrain/data/pretrain/differences.csv', index=False)

        
#         Check for overlap between TDC dataset and DisGeNET dataset
#         merged_df = pd.merge(data, self.dataset_df, how='inner', left_on=['Gene_ID','Disease_ID'], right_on=['geneId','diseaseId'])
        
#         num_matched_pairs = merged_df.shape[0]

#         print(f"Number of matched pairs TDC: {num_matched_pairs}")
        
#         merged_dis = pd.merge(data, updated, how='inner', left_on=['Gene','Disease'], right_on=['proteinSeq','diseaseDes'])
        
#         num_matched = merged_dis.shape[0]

#         print(f"Number of matched pairs DisGeNET_test: {num_matched}")
        
        # self.dataset_df = self.dataset_df[
        #     ["proteinSeq", "diseaseDes", "score"]
        # ].dropna()  # Drop missing values.
        # print(self.dataset_df.head())  "proteinSeq", "diseaseDes", "score"
        
        # print(
        #     f"{data_dir}disgenet_gda.csv loaded, found associations: {len(self.dataset_df.index)}"
        # )
#         df1 = pd.read_csv(f"{data_dir}/disgenet_gda.csv")
#         df1 = df1[
#             ["proteinSeq", "diseaseDes", "score"]
#         ].dropna()

#         # Merge the two DataFrames and find the rows that differ
#         merged = df1.merge(self.dataset_df, how='outer', indicator=True)
#         differences = merged[merged['_merge'] != 'both']

#         # Save the result to a new file
#         differences.to_csv('/nfs/dpa_pretrain/data/pretrain/differences.csv', index=False)

#         if test:
#             self.protein_seqs = self.dataset_df["proteinSeq"].values[:128]
#             self.disease_dess = self.dataset_df["diseaseDes"].values[:128]
#             self.scores = 128 * [1]
#         else:
#             self.protein_seqs = self.dataset_df["proteinSeq"].values
#             self.disease_dess = self.dataset_df["diseaseDes"].values
#             self.scores = len(self.dataset_df["score"].values) * [1]

#     def __getitem__(self, query_idx):

#         protein_seq = self.protein_seqs[query_idx]
#         disease_des = self.disease_dess[query_idx]
#         score = self.scores[query_idx]

#         return protein_seq, disease_des, score

#     def __len__(self):
#         return len(self.protein_seqs)


class PPI_Pretrain_Dataset(Dataset):
    """
    Candidate Dataset for:
        ALL protein-to-protein interactions

    Reads ``string_ppi_900_2m.csv`` from ``data_dir``, keeps the
    ``item_seq_a``, ``item_seq_b`` and ``score`` columns and drops rows with
    missing values.

    Args:
        data_dir: Directory containing ``string_ppi_900_2m.csv``.
        test: If true, keep a random sample of at most 100 rows.

    Every example is returned with score 1; the ``score`` column of the CSV
    is not used as a label.
    """

    def __init__(self, data_dir="../../data/pretrain/", test=False):
        LOGGER.info("Initializing metric learning data set! ...")
        self.dataset_df = pd.read_csv(f"{data_dir}/string_ppi_900_2m.csv")
        self.dataset_df = self.dataset_df[["item_seq_a", "item_seq_b", "score"]]
        self.dataset_df = self.dataset_df.dropna()
        if test:
            # Cap the sample size: DataFrame.sample raises ValueError when
            # asked for more rows than exist (sampling without replacement).
            self.dataset_df = self.dataset_df.sample(min(100, len(self.dataset_df)))
        print(
            f"{data_dir}/string_ppi_900_2m.csv loaded, found interactions: {len(self.dataset_df.index)}"
        )
        self.protein_seq1 = self.dataset_df["item_seq_a"].values
        self.protein_seq2 = self.dataset_df["item_seq_b"].values
        self.scores = len(self.dataset_df["score"].values) * [1]

    def __getitem__(self, query_idx):
        """Return ``(protein_seq1, protein_seq2, score)`` at ``query_idx``."""
        protein_seq1 = self.protein_seq1[query_idx]
        protein_seq2 = self.protein_seq2[query_idx]
        score = self.scores[query_idx]

        return protein_seq1, protein_seq2, score

    def __len__(self):
        """Number of protein pairs kept."""
        return len(self.protein_seq1)


class PPI_Dataset(Dataset):
    """
    Candidate Dataset for:
        ALL protein-to-protein interactions

    Holds two index-aligned collections of protein sequences together with
    a collection of per-pair scores, all supplied by the caller.
    """

    def __init__(self, protein_seq1, protein_seq2, score):
        self.protein_seq1, self.protein_seq2 = protein_seq1, protein_seq2
        self.scores = score

    def __getitem__(self, query_idx):
        """Return ``(protein_seq1, protein_seq2, score)`` at ``query_idx``."""
        return (
            self.protein_seq1[query_idx],
            self.protein_seq2[query_idx],
            self.scores[query_idx],
        )

    def __len__(self):
        """Number of pairs, taken from the first sequence collection."""
        n_pairs = len(self.protein_seq1)
        return n_pairs


class DDA_Dataset(Dataset):
    """
    Candidate Dataset for:
        ALL disease-to-disease associations

    Holds two index-aligned collections of disease descriptions together
    with a collection of per-pair labels, all supplied by the caller.
    """

    def __init__(self, diseaseDes1, diseaseDes2, label):
        self.diseaseDes1, self.diseaseDes2 = diseaseDes1, diseaseDes2
        self.label = label

    def __getitem__(self, query_idx):
        """Return ``(diseaseDes1, diseaseDes2, label)`` at ``query_idx``."""
        return (
            self.diseaseDes1[query_idx],
            self.diseaseDes2[query_idx],
            self.label[query_idx],
        )

    def __len__(self):
        """Number of pairs, taken from the first description collection."""
        n_pairs = len(self.diseaseDes1)
        return n_pairs


class DDA_Pretrain_Dataset(Dataset):
    """
    Candidate Dataset for:
        ALL disease-to-disease associations

    Reads ``disgenet_dda.csv`` from ``data_dir`` and drops every row that
    has a missing value in any column.

    Args:
        data_dir: Directory containing ``disgenet_dda.csv``; a trailing
            path separator is optional.
        test: If true, keep a random sample of at most 100 rows.

    Every example is returned with score 1; the ``jaccard_variant`` column
    of the CSV is not used as a label.
    """

    def __init__(self, data_dir="../../data/pretrain/", test=False):
        # Local import keeps this class self-contained; the module does not
        # import ``os`` at the top.
        import os

        LOGGER.info("Initializing metric learning data set! ...")
        # Join instead of concatenating, so that ``data_dir`` works with or
        # without a trailing separator (as in the sibling datasets).
        self.dataset_df = pd.read_csv(os.path.join(data_dir, "disgenet_dda.csv"))
        self.dataset_df = self.dataset_df.dropna()  # Drop missing values.
        if test:
            # Cap the sample size: DataFrame.sample raises ValueError when
            # asked for more rows than exist (sampling without replacement).
            self.dataset_df = self.dataset_df.sample(min(100, len(self.dataset_df)))
        print(
            f"{data_dir}disgenet_dda.csv loaded, found associations: {len(self.dataset_df.index)}"
        )
        self.disease_des1 = self.dataset_df["diseaseDes1"].values
        self.disease_des2 = self.dataset_df["diseaseDes2"].values
        self.scores = len(self.dataset_df["jaccard_variant"].values) * [1]

    def __getitem__(self, query_idx):
        """Return ``(disease_des1, disease_des2, score)`` at ``query_idx``."""
        disease_des1 = self.disease_des1[query_idx]
        disease_des2 = self.disease_des2[query_idx]
        score = self.scores[query_idx]

        return disease_des1, disease_des2, score

    def __len__(self):
        """Number of disease pairs kept."""
        return len(self.disease_des1)