import requests
import pandas as pd
from tqdm.auto import tqdm
from utils import *
import gradio as gr
from huggingface_hub import HfApi, hf_hub_download
from huggingface_hub.repocard import metadata_load


class DeepRL_Leaderboard:
    """Registry of leaderboards, keyed by environment id.

    Each entry has the shape ``{env_id: {'title': str, 'data': ...}}`` where
    ``data`` is whatever ``get_data_per_env(env_id)`` returned when the
    leaderboard was added. Insertion order is preserved (plain dict).
    """

    def __init__(self) -> None:
        # env id -> {'title': ..., 'data': ...}
        self.leaderboard = {}

    def add_leaderboard(self, id=None, title=None):
        """Register (or replace) the leaderboard for ``id`` titled ``title``.

        Both arguments are stripped of surrounding whitespace. Nothing happens
        when either one is ``None``. Registering fetches the environment's
        data immediately via ``get_data_per_env``.

        Note: ``id`` shadows the builtin but is kept for keyword callers.
        """
        if id is None or title is None:
            return
        env_id = id.strip()
        self.leaderboard[env_id] = {
            'title': title.strip(),
            'data': get_data_per_env(env_id),
        }

    def get_data(self):
        """Return the underlying registry dict itself (not a copy)."""
        return self.leaderboard

    def get_ids(self):
        """Return the registered environment ids in insertion order."""
        return list(self.leaderboard)

          

# Custom stylesheet for the Gradio Blocks UI. The relative path is resolved
# against the current working directory. The encoding is explicit so that
# decoding does not depend on the platform's locale default.
with open('app.css', 'r', encoding='utf-8') as f:
    BLOCK_CSS = f.read()


# Maps each environment id to the list of model ids already fetched for it,
# so update_data() only requests metadata for models it has not seen yet.
LOADED_MODEL_IDS = {}

def get_data(rl_env):
    """Fetch the metadata of every model listed for ``rl_env``.

    The complete list of ids from ``get_model_ids(rl_env)`` is recorded in
    ``LOADED_MODEL_IDS[rl_env]``. This lets ``update_data`` later fetch only
    models that appeared since.

    Returns:
        A DataFrame with a single ``metadata`` column, one row per model
        whose ``get_metadata`` result is not ``None``. When nothing could be
        fetched, the DataFrame has no columns.
    """
    fetched_ids = get_model_ids(rl_env)
    # Subscript assignment only mutates the module dict; no ``global`` needed.
    LOADED_MODEL_IDS[rl_env] = fetched_ids

    records = []
    for model_id in tqdm(fetched_ids):
        metadata = get_metadata(model_id)
        if metadata is not None:
            records.append({"metadata": metadata})
    return pd.DataFrame.from_records(records)

def get_data_per_env(rl_env):
    """Return ``(dataframe, is_empty)`` for ``rl_env``, built by ``get_data``."""
    frame = get_data(rl_env)
    is_empty = frame.empty
    return frame, is_empty



# (environment id, display title) for every leaderboard, in display order.
# Titles are stripped by add_leaderboard, so stray whitespace is harmless.
_LEADERBOARD_SPECS = (
    ('CarRacing-v0', " The Car Racing 🏎️ Leaderboard 🚀"),
    ('MountainCar-v0', "The Mountain Car ⛰️ 🚗 Leaderboard 🚀"),
    ('LunarLander-v2', "The Lunar Lander 🌕 Leaderboard 🚀"),
    ('BipedalWalker-v3', "The BipedalWalker Leaderboard 🚀"),
    ('Taxi-v3', 'The Taxi-v3🚖 Leaderboard 🚀'),
    ('FrozenLake-v1-4x4-no_slippery', 'The FrozenLake-v1-4x4-no_slippery Leaderboard 🚀'),
    ('FrozenLake-v1-8x8-no_slippery', 'The FrozenLake-v1-8x8-no_slippery Leaderboard 🚀'),
    ('FrozenLake-v1-4x4', 'The FrozenLake-v1-4x4 Leaderboard 🚀'),
    ('FrozenLake-v1-8x8', 'The FrozenLake-v1-8x8 Leaderboard 🚀'),
    ('SpaceInvadersNoFrameskip-v4', 'The SpaceInvadersNoFrameskip-v4 Leaderboard 🚀'),
)

# Each registration fetches that environment's data up front.
rl_leaderboard = DeepRL_Leaderboard()
for _env_id, _title in _LEADERBOARD_SPECS:
    rl_leaderboard.add_leaderboard(_env_id, _title)

# Module-wide views shared by the update and statistics functions below.
RL_ENVS = rl_leaderboard.get_ids()
RL_DETAILS = rl_leaderboard.get_data()


def update_data(rl_env):
    """Fetch metadata for models of ``rl_env`` that have not been loaded yet.

    Ids from ``get_model_ids(rl_env)`` that are already recorded in
    ``LOADED_MODEL_IDS[rl_env]`` are skipped. The remaining ids are appended
    to that record, whether or not their metadata could be fetched.

    Returns:
        A DataFrame with a single ``metadata`` column, one row per new model
        whose ``get_metadata`` result is not ``None``. When there is nothing
        new, the DataFrame has no columns.

    Raises:
        KeyError: if ``rl_env`` has no entry in ``LOADED_MODEL_IDS``.
    """
    # Build a set once so each membership test is O(1). The old per-item
    # scan of the ever-growing list was O(len(fetched) * len(loaded)).
    # NOTE(review): assumes model ids are hashable (they appear to be strings).
    already_loaded = set(LOADED_MODEL_IDS[rl_env])
    model_ids = [x for x in get_model_ids(rl_env) if x not in already_loaded]
    LOADED_MODEL_IDS[rl_env] += model_ids

    data = []
    for model_id in tqdm(model_ids):
        meta = get_metadata(model_id)
        if meta is None:
            continue
        data.append({"metadata": meta})
    return pd.DataFrame.from_records(data)



def update_data_per_env(rl_env):
    """Append newly published models of ``rl_env`` to its cached data.

    Returns:
        ``(dataframe, is_empty)``, the same shape ``get_data_per_env``
        returns. The combined frame has a fresh ``0..n-1`` index.
    """
    old_dataframe, _ = RL_DETAILS[rl_env]['data']
    new_dataframe = update_data(rl_env).fillna("")

    # Drop empty frames: they add no rows, and recent pandas (>= 2.1) warns
    # when concatenating empty entries. Fall back to both frames when
    # everything is empty so the result is still an (empty) DataFrame.
    frames = [df for df in (old_dataframe, new_dataframe) if not df.empty]
    # ignore_index avoids repeating index labels (0, 1, ..., 0, 1, ...)
    # after every refresh.
    dataframe = pd.concat(frames or [old_dataframe, new_dataframe], ignore_index=True)

    return dataframe, dataframe.empty






def get_info_display(dataframe,env_name,name_leaderboard,is_empty):
    """Render the HTML info panel for one environment's leaderboard.

    Args:
        dataframe: The environment's leaderboard data. Frames built by
            ``get_data`` carry a ``metadata`` column of dicts with a
            ``model_id``; the part before ``'/'`` is treated as the user.
            A frame with an explicit ``User`` column is also accepted.
        env_name: Environment id shown in the description sentence.
        name_leaderboard: Heading of the panel.
        is_empty: When true, only the heading is rendered.

    Returns:
        The panel as an HTML string.
    """
    if not is_empty:
        # Bug fix: frames from get_data() have no 'User' column, so the old
        # dataframe['User'] lookup raised KeyError. Derive users from the
        # "user/model" id in each metadata dict instead, keeping 'User'
        # support for callers that provide that column.
        if 'User' in dataframe.columns:
            num_unique_users = len(set(dataframe['User'].values))
        else:
            num_unique_users = len({m['model_id'].split('/')[0] for m in dataframe['metadata'].values})
        markdown = """
        <div class='infoPoint'>
        <h1> {name_leaderboard} </h1>
        <br>
        <p> This is a leaderboard of <b>{len_dataframe}</b> agents, from <b>{num_unique_users}</b> unique users, playing {env_name} 👩‍🚀. </p>
        <br>
        <p> We use lower bound result to sort the models: mean_reward - std_reward. </p>
        <br>    
        <p> You can click on the model's name to be redirected to its model card which includes documentation. </p>
        <br>
        <p> You want to try your model? Read this <a href="https://github.com/huggingface/deep-rl-class/blob/Unit1/unit1/README.md" target="_blank">Unit 1</a> of Deep Reinforcement Learning Class.
        </p>
        </div>
        """.format(len_dataframe = len(dataframe),env_name = env_name,name_leaderboard = name_leaderboard,num_unique_users = num_unique_users)

    else:
        markdown = """
        <div class='infoPoint'>
        <h1> {name_leaderboard} </h1>
        <br>
        </div>                  
        """.format(name_leaderboard =  name_leaderboard)
    return markdown

def reload_all_data():
    """Refresh the cached data of every registered environment.

    Each environment's ``RL_DETAILS[...]['data']`` entry is replaced by the
    result of ``update_data_per_env``. Returns a green HTML confirmation
    message.
    """
    # Only dict entries are mutated, so no ``global`` statement is required.
    for env_name in RL_ENVS:
        refreshed = update_data_per_env(env_name)
        RL_DETAILS[env_name]['data'] = refreshed

    return """<div style="color: green">
                <p> ✅ Leaderboard updated! Click `Show Statistics` to see the current statistics.</p>
                </div>
               """


def reload_leaderboard(rl_env):
    """Return the info-panel HTML for ``rl_env`` built from its cached data."""
    details = RL_DETAILS[rl_env]
    data_dataframe, is_empty = details['data']
    return get_info_display(data_dataframe, rl_env, details['title'], is_empty)
            
def get_units_stat():
    """Plot the number of model submissions per course unit.

    A model counts towards:
      * Unit 1 when its tags contain ``'stable-baselines3'``;
      * Unit 2 when its tags contain ``'custom-implementation'``;
      * Unit 3 when it is a stable-baselines3 model whose tags include
        SpaceInvadersNoFrameskip-v4, compared case-insensitively.
    Environments whose cached data is empty are skipped.
    """
    space_invaders_tag = 'SpaceInvadersNoFrameskip-v4'.lower()
    units = {'Unit 1': 0, 'Unit 2': 0, 'Unit 3': 0}
    for rl_env in RL_ENVS:
        rl_env_metadata, is_empty = RL_DETAILS[rl_env]['data']
        if is_empty:
            continue
        for m in rl_env_metadata['metadata'].values:
            tags = m['tags']
            if 'stable-baselines3' in tags:
                units['Unit 1'] += 1
                # Bug fix: the original used `tag.lower` without calling it,
                # comparing against bound methods, so Unit 3 was always 0.
                if space_invaders_tag in [tag.lower() for tag in tags]:
                    units['Unit 3'] += 1
            if 'custom-implementation' in tags:
                units['Unit 2'] += 1

    # Bug fix: `list(units.values)` passed the bound method to list() and
    # raised TypeError; call it to get the counts.
    return plot_bar(
        value=list(units.values()),
        name=list(units.keys()),
        x_name="Units",
        y_name="Number of model submissions",
        title="Number of model submissions per unit",
    )

    
      
def get_models_stat():
    """Plot the number of model submissions for each RL environment.

    Environments whose cached data is empty are left out of the plot.
    """
    models = {}
    for rl_env in RL_ENVS:
        rl_env_metadata, is_empty = RL_DETAILS[rl_env]['data']
        if not is_empty:
            models[rl_env] = len(rl_env_metadata['metadata'].values)

    # Bug fix: `list(units.values)` passed the bound method to list() and
    # raised TypeError; call it to get the counts.
    return plot_bar(
        value=list(models.values()),
        name=list(models.keys()),
        x_name="RL Environment",
        y_name="Number of model submissions",
        title="Number of model submissions per RL environment",
    )

def get_user_stat():
    """Plot the number of distinct submitting users for each RL environment.

    A user is the part of a model's ``model_id`` before the first ``'/'``.
    Environments whose cached data is empty are left out of the plot.
    """
    users = {}
    for rl_env in RL_ENVS:
        rl_env_metadata, is_empty = RL_DETAILS[rl_env]['data']
        if not is_empty:
            users[rl_env] = len({m['model_id'].split('/')[0] for m in rl_env_metadata['metadata'].values})

    # Bug fix: `list(users.values)` passed the bound method to list() and
    # raised TypeError; call it to get the counts.
    return plot_bar(
        value=list(users.values()),
        name=list(users.keys()),
        x_name="RL Environment",
        y_name="Number of user submissions",
        title="Number of user submissions per RL environment",
    )

def get_stat():
    """Build the three dashboard plots from the cached leaderboard data.

    Counts gathered over every environment whose cached ``is_empty`` flag is
    ``False``:
      * models per course unit (Unit 1: stable-baselines3 tag; Unit 2:
        custom-implementation tag; Unit 3: stable-baselines3 plus a
        SpaceInvadersNoFrameskip-v4 tag, compared case-insensitively);
      * distinct users per environment (``model_id`` prefix before ``'/'``);
      * models per environment.

    Returns:
        ``(units_plot, user_plot, model_plot)``.
    """
    target_tag = 'spaceinvadersNoFrameskip-v4'.lower()
    per_unit = {'Unit 1': 0, 'Unit 2': 0, 'Unit 3': 0}
    unique_users_per_env = {}
    models_per_env = {}

    for env_name in RL_ENVS:
        frame, empty_flag = RL_DETAILS[env_name]['data']
        if empty_flag is not False:
            continue
        entries = frame['metadata'].values

        per_unit['Unit 1'] += sum(1 for meta in entries if 'stable-baselines3' in meta['tags'])
        per_unit['Unit 2'] += sum(1 for meta in entries if 'custom-implementation' in meta['tags'])
        per_unit['Unit 3'] += sum(
            1
            for meta in entries
            if 'stable-baselines3' in meta['tags']
            and target_tag in [t.lower() for t in meta['tags']]
        )

        unique_users_per_env[env_name] = len({meta['model_id'].split('/')[0] for meta in entries})
        models_per_env[env_name] = len(entries)

    units_plot = plot_bar(
        value=list(per_unit.values()),
        name=list(per_unit.keys()),
        x_name="Units",
        y_name="Number of model submissions",
        title="Number of model submissions per unit",
    )
    user_plot = plot_barh(
        value=list(unique_users_per_env.values()),
        name=list(unique_users_per_env.keys()),
        x_name="RL Environment",
        y_name="Number of unique user submissions",
        title="Number of unique user submissions per RL environment",
    )
    model_plot = plot_barh(
        value=list(models_per_env.values()),
        name=list(models_per_env.keys()),
        x_name="RL Environment",
        y_name="Number of model submissions",
        title="Number of model submissions per RL environment",
    )
    return units_plot, user_plot, model_plot

                 



# Gradio UI: a notification banner plus a "Dashboard" tab with three
# statistics plots, styled with the CSS loaded into BLOCK_CSS.
block = gr.Blocks(css=BLOCK_CSS)
with block:
    # Placeholder banner shown until reload_all_data() finishes; its returned
    # HTML (the green "updated" message) replaces this content.
    notification = gr.HTML("""<div style="color: green">
                <p> ⌛ Updating leaderboard... </p>
                </div>
               """)
    # Registered with Blocks.load, i.e. triggered when the page loads: refresh
    # every environment's cached data.
    block.load(reload_all_data,[],[notification])
    
    with gr.Tabs():
        # NOTE(review): rl_tab is currently unused; no select handler is
        # attached, so reload_leaderboard() is not wired into this UI.
        with gr.TabItem("Dashboard") as rl_tab:
            # The dashboard shows: model submissions per unit, model
            # submissions per environment, and unique users per environment.
            reload = gr.Button('Show Statistics')

            # Display order: units, models per env, users per env.
            # NOTE(review): confirm the pinned Gradio version accepts
            # gr.Plot(type=...).
            units_plot = gr.Plot(type="matplotlib")
            model_plot = gr.Plot(type="matplotlib")
            user_plot = gr.Plot(type="matplotlib")

            # get_stat() returns (units_plot, user_plot, model_plot); the
            # outputs list follows that return order, not the display order.
            reload.click(get_stat,[],[units_plot,user_plot,model_plot])

# Starts the web server at module level, so it also runs on import.
block.launch()