import pandas as pd
import numpy as np
from rouge_score import rouge_scorer
from joblib import Parallel, delayed
from selfrank.algos.greedy import SelfRankGreedy
from selfrank.algos.iterative import SelfRank
from selfrank.algos.baseline import MCARank
from selfrank.algos.triplet import equality, rouge, noisy_equality
import matplotlib.pyplot as plt
from itertools import zip_longest
from uuid import uuid4
import csv, os
from functools import partial
import logging

# Logger for this module; the root logger is set up to emit INFO and above
# (logging.basicConfig is a no-op if the root logger already has handlers).
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def generate_data(max_acc, min_acc, nmodels, nanswers, nquestions) -> tuple[pd.DataFrame, list]:
    """Generate a synthetic multiple-choice answer table with known model quality.

    Model accuracies (in percent) are spread linearly from ``max_acc`` for model
    ``M1`` down to ``min_acc`` for model ``M{nmodels}``. For each model,
    ``ceil(acc / 100 * nquestions)`` randomly chosen questions receive the
    ground-truth answer. Every other question receives an answer drawn uniformly
    from the ``nanswers - 1`` wrong choices.

    The global NumPy RNG is reseeded with 42, so the result is deterministic.

    Args:
        max_acc: accuracy (percent) of the best model, ``M1``.
        min_acc: accuracy (percent) of the worst model.
        nmodels: number of simulated models.
        nanswers: number of answer choices per question (labels ``0..nanswers-1``).
        nquestions: number of questions (rows).

    Returns:
        ``(df, true_ranking)``. ``df`` has an ``int64`` ``GT`` column followed by
        ``M1..M{nmodels}``. ``true_ranking`` is ``["M1", ..., "M{nmodels}"]``,
        i.e. best model first.
    """
    # Reseed the *global* NumPy RNG (not a local Generator). This keeps both the
    # generated table and any np.random draws made afterwards by callers
    # reproducible.
    np.random.seed(42)

    # Accuracies (percent) spread evenly from max_acc (M1) down to min_acc.
    model_acc = np.linspace(max_acc, min_acc, nmodels)

    # Column 0 holds the ground truth; column m + 1 holds model m's answers.
    answers = np.zeros((nquestions, nmodels + 1), dtype=np.int64)
    for q in range(nquestions):
        answers[q, 0] = np.random.randint(nanswers)

    for m in range(nmodels):
        # Multiply before dividing and round off float noise. Otherwise, e.g.,
        # 7 / 100 * 100 == 7.000000000000001 and ceil() adds a spurious extra
        # correct answer.
        n_correct = int(np.ceil(round(model_acc[m] * nquestions / 100, 9)))
        # Clamp so that out-of-range accuracies cannot produce a negative slice.
        n_correct = min(max(n_correct, 0), nquestions)
        # A set gives O(1) membership tests in the loop below.
        correct_rows = set(np.random.permutation(nquestions)[:n_correct].tolist())
        for q in range(nquestions):
            truth = answers[q, 0]
            if q in correct_rows:
                answers[q, m + 1] = truth
            else:
                # Uniform choice among the wrong answers. This needs nanswers >= 2
                # whenever some answer must be wrong.
                wrong = [a for a in range(nanswers) if a != truth]
                answers[q, m + 1] = wrong[np.random.randint(nanswers - 1)]

    # Build the frame directly instead of round-tripping through a temp CSV file.
    columns = ["GT"] + [f"M{m}" for m in range(1, nmodels + 1)]
    df = pd.DataFrame(answers, columns=columns)

    true_ranking = columns[1:]
    return df, true_ranking

def synth_executor(acc_range: tuple[float, float], nmodels, nanswers, nquestions, noise, method) -> tuple[str, dict]:


    min_acc, max_acc = acc_range
    logger.info(f"Synth experiment: min_acc:{min_acc}, max_acc:{max_acc}, nmodels: {nmodels}, nanswers: {nanswers}, nquestions: {nquestions}, noise:{noise}, method:{method}.")

    df, true_ranking = generate_data(max_acc, min_acc, nmodels, nanswers, nquestions)

    if noise == 0.:
        comp = equality
    else:
        comp = partial(noisy_equality, p=noise)
    
    df = df.drop(columns=["GT"])
    MODELS = df.columns.tolist()

    if method == "Full":
        ranker = SelfRank(MODELS, comp, true_ranking)
        ranker.fit(df)

        # outputs of interest
        out = {
            "true_ranking": true_ranking,
            "estimated_ranking": ranker.ranking,
            "rbo": ranker.measure(metric="rbo"),
            "map-1": ranker.measure(metric='mapk', k=1),
            "map-3": ranker.measure(metric='mapk', k=3),
            "map-5": ranker.measure(metric='mapk', k=5),
            "map-10": ranker.measure(metric='mapk', k=10)
        }

    elif method == "Greedy":
        ranker = SelfRankGreedy(MODELS, comp, true_ranking)
        ranker.fit(df)
        out = {
            "true_ranking": true_ranking,
            "estimated_ranking": ranker.ranking,
            "rbo": ranker.measure(metric="rbo"),
            "map-1": ranker.measure(metric='mapk', k=1),
            "map-3": ranker.measure(metric='mapk', k=3),
            "map-5": ranker.measure(metric='mapk', k=5),
            "map-10": ranker.measure(metric='mapk', k=10)
        }
    elif method == 'MCA':
        ranker = MCARank(MODELS, comp, true_ranking)
        ranker.fit(df, measure='noisy_equality', p=noise)
        out = {
            "true_ranking": true_ranking,
            "estimated_ranking": ranker.ranking,
            "rbo": ranker.measure(metric="rbo"),
            "map-1": ranker.measure(metric='mapk', k=1),
            "map-3": ranker.measure(metric='mapk', k=3),
            "map-5": ranker.measure(metric='mapk', k=5),
            "map-10": ranker.measure(metric='mapk', k=10)
        }
    else:
        raise ValueError(f"{method} not understood.")

    eval_metrics = (
            f"<h2 style='color: purple;'> Evaluation measures </h2>"
            f"Rank-Biased Overlap: {out['rbo']:0.3f}<br>"
            f"MAP-3              : {out['map-3']:0.3f}<br>"
            f"MAP-5              : {out['map-5']:0.3f}<br>"
            f"MAP-10             : {out['map-10']: 0.3f}."
        )

    out_plot = ranker.plot("synth")
    plt.close(out_plot)

    return "synth.png", eval_metrics



def benchmark_executor(data, mmlu_subject, evaluation, nmodels, nrows, method
    ) -> tuple[pd.DataFrame, plt.figure]:
        """Main execution flow for benchmarks"""

        logger.info(f"Benchmark experiment: benchmark:{data}, mmlu subject: {mmlu_subject}, evaluation:{evaluation}, nmodels:{nmodels}, nquestions: {nrows}, method: {method}.")
        seed = 40
        np.random.seed(seed)

        match data:
            case "MMLU":
                adf = pd.read_pickle(f"data/mmlu_subject_{mmlu_subject}.pkl")

            case "CNN/DM":
                adf = pd.read_pickle(f"data/cnndm.pkl")

            case "XSUM":
                adf = pd.read_pickle(f"data/xsum.pkl")

            case _:
                raise ValueError(f"'{data}' not understood.")

        MODELS = adf.model.unique()

        # Sample fewer models if so needed
        if nmodels != "All":
            if nmodels < len(MODELS):

                MODELS = np.random.choice(MODELS, nmodels, replace=False).tolist()
                adf = adf[adf.model.isin(MODELS)]

        match data:
            case "MMLU":
                keys = [
                    "id",
                    "trial_id",
                    "perturbation",
                ]  # MMLU has this extra parameter
            case "CNN/DM" | "XSUM":
                keys = ["id", "trial_id"]
            case _:
                pass

        df = adf.pivot_table(
            columns="model",
            index=keys,
            values="output",
            aggfunc="first",
        )

        # Filter by number of rows
        df.dropna(inplace=True)
        if nrows != "All":
            if nrows < df.shape[0]:
                df = df.sample(nrows, random_state=seed)

        # Compute true ranking
        adf = adf.set_index(keys).loc[df.index].reset_index()

        if evaluation == "Rouge":

            def __true_rouge(x, scorer):
                return scorer.score(x["reference"], x["output"])["rouge2"].fmeasure

            scorer = rouge_scorer.RougeScorer(["rouge2"], use_stemmer=True)
            adf["rouge"] = Parallel(n_jobs=-1, batch_size=128)(
                delayed(__true_rouge)(i, scorer) for _, i in adf.iterrows()
            )

            # Method 2 - look at "win rates" - for each question, see which model
            # wins (i.e. has the best ROUGE score)
            idx = adf.groupby(["id", "trial_id"])["rouge"].idxmax()
            win_rates = adf.loc[idx].model.value_counts()
            win_rate_rank = win_rates.index.tolist()

            # include models with nowins at the bottom
            no_wins = list(set(MODELS) - set(win_rate_rank))
            true_ranking = win_rate_rank + no_wins
            evaluator = rouge

        elif evaluation == "Equality":

            # Compute the true ranking (multiple choice - so use equality between
            # LLM response and reference-value)
            adf["C"] = (adf.output == adf.reference).astype(int)
            true_ranking = (
                adf.groupby("model")["C"]
                .apply(lambda x: sum(x) / len(x))
                .sort_values(ascending=False)
                .index.tolist()
            )
            evaluator = equality

        else:
            raise ValueError(f"'{evaluation}' not understood.")

        match method:
            case "Full":
                ranker = SelfRank(MODELS, evaluator, true_ranking)

            case "Greedy":
                ranker = SelfRankGreedy(MODELS, evaluator, true_ranking)

            case "MCA":
                raise NotImplementedError
            case _:
                raise ValueError(f"'{method}' not understood.")

        # generate outputs
        ranker.fit(df)
        ranks = ranker.ranking
        
        ranks = [
            j + i for i, j in zip_longest(ranks, ["🥇 ", "🥈 ", "🥉 "], fillvalue="")
        ]
        out_df = pd.DataFrame({"rank": range(1, len(true_ranking) + 1), "model": ranks})

        out_metrics = {
            "rbo": ranker.measure(metric="rbo"),
            "map-1": ranker.measure(metric="mapk", k=1),
            "map-3": ranker.measure(metric="mapk", k=3),
            "map-5": ranker.measure(metric="mapk", k=5),
            "map-10": ranker.measure(metric="mapk", k=10),
            "evaluations": evaluator.calls,
        }
        eval_metrics = (
            f"<h2 style='color: purple;'> Evaluation measures </h2>"
            f"Rank-Biased Overlap: {out_metrics['rbo']:0.3f}<br>"
            f"MAP-3              : {out_metrics['map-3']:0.3f}<br>"
            f"MAP-5              : {out_metrics['map-5']:0.3f}<br>"
            f"MAP-10             : {out_metrics['map-10']: 0.3f}."
        )

        out_plot = ranker.plot()
        plt.close(out_plot)

        return out_df, "output.png", eval_metrics