import networkx as nx
import numpy as np
from Bio.Phylo import to_networkx
from networkx.drawing.nx_agraph import graphviz_layout
import plotly.graph_objects as go
import plotly.express as px
from Bio.Phylo.TreeConstruction import DistanceTreeConstructor, DistanceCalculator, _DistanceMatrix

from tools import compute_ordered_matrix,compute_umap
from phylogeny import prepare_tree
from constants import UNKNOWN_COLOR, DEFAULT_COLOR, UNKNOWN_COLOR_RGB, DEFAULT_COLOR_RGB

# ------------------------------------------------------------------------------------------------
#
#                                     Sim Matrix Plotting
#
# ------------------------------------------------------------------------------------------------

def plot_sim_matrix_fig(ordered_sim_matrix,ordered_model_names,families,colors):
    """
    Plot a similarity matrix as a grayscale heatmap.

    Two invisible red rectangles named 'rectX' and 'rectY' are added; they are
    positioned and shown by update_sim_matrix_fig to highlight a searched
    model's column and row.

    Parameters:
    - ordered_sim_matrix: square similarity matrix; the colour range is fixed
      to [0, 1] (zmin=0, zmax=1)
    - ordered_model_names: model names labelling both axes, in matrix order
    - families, colors: currently unused; kept for interface compatibility

    Returns:
    - fig: plotly Figure
    """
    fig = px.imshow(
        ordered_sim_matrix,
        x=ordered_model_names,
        y=ordered_model_names,
        zmin=0, zmax=1,
        color_continuous_scale='gray',
    )

    # px.imshow binds a 2-D heatmap to the layout's shared coloraxis, so the
    # colorbar must be styled through layout.coloraxis.colorbar; a trace-level
    # `colorbar` has no effect on a trace bound to a coloraxis.
    fig.update_layout(
        coloraxis_colorbar=dict(
            title='Similarity',
            thickness=20,
            len=0.75,
            xanchor="right",
            x=1.02,
        ),
        margin=dict(l=0, r=0, t=0, b=0),
        autosize=True,
    )

    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False,constrain='range')
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False,constrain='range')

    # Highlight rectangles for the searched models; hidden (opacity 0, zero
    # size) until update_sim_matrix_fig moves them.
    for shape_name in ('rectX', 'rectY'):
        fig.add_shape(
            go.layout.Shape(
                type="rect",
                xref="x", yref="y",
                x0=0, y0=0,
                x1=0, y1=0,
                line=dict(color="red", width=1),
                fillcolor="rgba(0,0,0,0)",
                name=shape_name,
                opacity=0,
            )
        )

    return fig

def update_sim_matrix_fig(fig, ordered_model_names, model_search_x=None, model_search_y=None):
    """
    Move/show the highlight rectangles of a figure built by plot_sim_matrix_fig.

    'rectX' covers the whole column of model_search_x and 'rectY' the whole
    row of model_search_y (cells are centred on integer indices, hence the
    +/-0.5 bounds). A rectangle whose model is not in ordered_model_names is
    hidden (opacity 0).

    Returns the same figure, updated in place.
    """
    far_edge = len(ordered_model_names) - 0.5
    whole_axis = (-0.5, far_edge)

    for shape_name, searched in (('rectX', model_search_x), ('rectY', model_search_y)):
        if searched not in ordered_model_names:
            fig.update_shapes(selector=dict(name=shape_name), opacity=0)
            continue

        position = ordered_model_names.index(searched)
        single_cell = (position - 0.5, position + 0.5)
        if shape_name == 'rectX':
            (left, right), (bottom, top) = single_cell, whole_axis
        else:
            (left, right), (bottom, top) = whole_axis, single_cell

        fig.update_shapes(
            selector=dict(name=shape_name),
            x0=left, y0=bottom,
            x1=right, y1=top,
            opacity=0.7,
        )
    return fig

# ------------------------------------------------------------------------------------------------
#
#                                     2D UMAP Plotting
#
# ------------------------------------------------------------------------------------------------

def alpha_scaling(val):
    """Map a similarity in [0, 1] to a line width/alpha via val ** (1 / 0.36).

    The exponent (~2.78) compresses low similarities towards 0 while keeping
    1 at 1.
    """
    exponent = 1 / (0.35 + 1 / 100)
    return val ** exponent

def plot_umap_fig(dist_matrix, sim_matrix, model_names, families, colors,key='fig2',alpha_edges=None, alpha_names=None, alpha_markers=None):
    """
    Plot a 2D UMAP embedding of the models as a graph.

    Each model is a marker labelled with its name and coloured by family.
    Every pair of models with similarity >= 0.6 is joined by a line in the
    colour of the first model's family, with width alpha_scaling(similarity).
    A hidden trace named 'node' is added for update_umap_fig to move onto a
    searched model.

    Parameters:
    - dist_matrix: distance matrix passed to compute_umap for the embedding
    - sim_matrix: pairwise similarity matrix, indexed like model_names
    - model_names: list of model names
    - families: family of each model, in the same order as model_names
    - colors: mapping family -> colour
    - key: unused; kept for interface compatibility
    - alpha_edges: opacity of the edge traces
    - alpha_names: alpha of the model-name text (None is treated as 1, opaque)
    - alpha_markers: opacity of the model markers

    Returns:
    - fig: plotly Figure; edge traces are named '_edge', the model markers
      '_node' and the highlight marker 'node'.
    """
    embedding = compute_umap(dist_matrix,d=2)
    # 'rgba(0,0,0,None)' is not a valid plotly colour, so default to opaque.
    text_alpha = 1 if alpha_names is None else alpha_names

    fig = go.Figure()

    #-- EDGES
    n_models = len(model_names)
    edge_traces = []
    for i in range(n_models):
        for j in range(i + 1, n_models):  # each unordered pair (i < j) once
            sim = sim_matrix[i][j]
            if sim >= 0.6:  # Considered as significant similarity
                edge_traces.append(
                    go.Scatter(
                        x=[embedding[i, 0], embedding[j, 0]],
                        y=[embedding[i, 1], embedding[j, 1]],
                        mode='lines',
                        name='_edge',
                        line=dict(color=colors[families[i]], width=alpha_scaling(sim)),
                        opacity=alpha_edges,
                        showlegend=False,
                        hoverinfo='skip',
                    )
                )
    # Add all edges at once
    if edge_traces:
        fig.add_traces(edge_traces)

    #-- NODES
    marker_colors = [colors[f] for f in families]
    fig.add_trace(
        go.Scatter(
            x=embedding[:,0],
            y=embedding[:,1],
            text=model_names,
            mode='markers+text',
            textposition='top center',
            hoverinfo='text',
            hoveron='points+fills',
            showlegend=False,
            name='_node',
            marker=dict(
                color=marker_colors,
                size=8,
                line_width=2,
                opacity=alpha_markers,
            ),
            textfont=dict(
                color=f'rgba(0,0,0,{text_alpha})',
                size=8,
                family="Arial Black",
            )
        )
    )

    #-- LEGEND
    # One dummy trace per family; dict.fromkeys keeps first-appearance order so
    # the legend is stable across runs (set iteration order is not).
    legends = [
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(
                color=colors[f],
                size=8,
                line_width=2,
                opacity=1
            ),
            name=f,
        )
        for f in dict.fromkeys(families)
    ]
    fig.add_traces(legends)

    #Add highlighted node (hidden until update_umap_fig positions it)
    node = go.Scatter(
        x=[0],
        y=[0],
        mode='markers+text',
        textposition='top center',
        textfont=dict(color='red', size=16, family="Arial Black"),
        marker=dict(
            color='red',
            size=12,
            symbol='circle',
            line=dict(color='red', width=3)
        ),
        showlegend=False,
        name='node',
        opacity=0,
    )
    fig.add_trace(node)

    #Setup the layout
    fig.update_layout(
        margin=dict(l=0, r=0, t=0, b=0),
        autosize=True,
    )

    fig.update_xaxes(showticklabels=False, showgrid=False, zeroline=False,constrain='range')
    fig.update_yaxes(showticklabels=False, showgrid=False, zeroline=False,constrain='range')

    return fig

def update_umap_fig(fig, dist_matrix, model_names, families, colors, model_search_x=None, alpha_names=None, alpha_markers=None, alpha_edges=None, key='fig2'):
    """
    Update transparencies and the highlighted model of a figure built by plot_umap_fig.

    Parameters:
    - fig: figure returned by plot_umap_fig
    - dist_matrix: distance matrix used to recompute the embedding position
      of the searched model
    - model_names: list of model names
    - families: family of each model, in the same order as model_names
    - colors: mapping family -> colour
    - model_search_x: model to highlight; the highlight is hidden when it is
      not in model_names
    - alpha_names: alpha of the model-name text (None is treated as 1, opaque)
    - alpha_markers: opacity of the model markers
    - alpha_edges: opacity of the edge traces
    - key: unused; kept for interface compatibility

    Returns:
    - fig: the same figure, updated in place
    """
    # 'rgba(0,0,0,None)' is not a valid plotly colour, so default to opaque.
    text_alpha = 1 if alpha_names is None else alpha_names

    #Update nodes
    fig.update_traces(
        selector=dict(name='_node'),
        textfont=dict(
            color=f'rgba(0,0,0,{text_alpha})',
        ),
        marker=dict(
            opacity=alpha_markers
        ),
    )

    # Update edges: only their opacity. Each edge's width was set from its
    # similarity when the figure was built, so it must not be overwritten here.
    fig.update_traces(
        selector=dict(mode='lines'),
        opacity=alpha_edges
    )

    #Update highlighted node
    if model_search_x in model_names:
        searched_idx = model_names.index(model_search_x)
        # NOTE(review): assumes compute_umap caches its result so this returns
        # the same embedding as plot_umap_fig — confirm in tools.
        embedding = compute_umap(dist_matrix,d=2)
        fig.update_traces(
            selector=dict(name='node'),
            x=[embedding[searched_idx,0]],
            y=[embedding[searched_idx,1]],
            text=[model_search_x],
            marker=dict(
                color=colors[families[searched_idx]],
            ),
            hovertext=model_search_x,
            opacity=1
        )
    else:
        fig.update_traces(
            selector=dict(name='node'),
            x=[0],
            y=[0],
            text=[''],
            opacity=0
        )
    return fig

# ------------------------------------------------------------------------------------------------
#
#                                     Phylogenetic Tree Plotting
#
# ------------------------------------------------------------------------------------------------

def draw_graphviz(tree, label_func=str, prog='twopi', args='',
                 node_size=15, edge_width=0.0, alpha_edges=None, alpha_names=None,alpha_markers=None, **kwargs):
    """
    Display a tree or clade as a Plotly graph, with layout from the graphviz engine.

    The tree is converted to a NetworkX graph and laid out with the graphviz
    program `prog`. Output is one line trace per edge, one marker+text trace
    per labelled node, and a custom legend with one entry per family.

    Parameters:
    - tree: Bio.Phylo tree or clade to draw
    - label_func: callable mapping a clade to its display label. A node gets
      no node trace when its label is None or equals the clade's class name,
      or when label_func raises LookupError/AttributeError/ValueError.
    - prog, args: graphviz program and extra arguments for graphviz_layout
    - node_size: marker size of the node traces
    - edge_width: line width of the edge traces
    - alpha_edges: opacity of the edge traces
    - alpha_names: alpha of the node text (None is treated as 1, opaque)
    - alpha_markers: opacity of the node markers
    - kwargs: accepted for backward compatibility and ignored

    Returns:
    - fig: plotly Figure; edge traces are named '_edge', node traces '_node'.
    """
    # Convert the Bio.Phylo tree to a NetworkX graph, then relabel nodes with
    # integers while keeping the original clade in the 'label' attribute.
    G = to_networkx(tree)
    Gi = nx.convert_node_labels_to_integers(G, label_attribute='label')

    # Apply the Graphviz layout
    pos = graphviz_layout(Gi, prog=prog, args=args)

    # 'rgba(0,0,0,None)' is not a valid plotly colour, so default to opaque.
    text_alpha = 1 if alpha_names is None else alpha_names

    def to_rgb_string(rgb):
        # Plotly colour string from an (r, g, b) sequence.
        return f'rgb({rgb[0]},{rgb[1]},{rgb[2]})'

    # Node labels. The comparison uses the clade's class name (not the integer
    # node id's) so that label_func=str on an unnamed clade, which yields the
    # class name, does not produce a node labelled e.g. 'Clade'.
    labels = {}
    for node, data in Gi.nodes(data=True):
        clade = data.get('label', node)
        try:
            label = label_func(clade)
        except (LookupError, AttributeError, ValueError):
            continue
        if label not in (None, clade.__class__.__name__):
            labels[node] = label

    # Colour (raw rgb tuple + plotly string) and family of every node.
    default_rgb = (0, 0, 0)
    node_rgb = {}
    node_colors = {}
    node_families = {}
    for node in Gi.nodes():
        clade = Gi.nodes[node].get('label')
        branch_color = getattr(clade, 'color', None)
        rgb = branch_color.to_rgb() if branch_color is not None else default_rgb
        node_rgb[node] = rgb
        node_colors[node] = to_rgb_string(rgb)
        node_families[node] = getattr(clade, 'family', None)

    # Create edge traces
    edge_traces = []
    for start, end in Gi.edges():
        x0, y0 = pos[start]
        x1, y1 = pos[end]

        # Colour the edge like its second endpoint (presumably the child, as
        # to_networkx adds parent->child edges), except that edges leading to
        # an unknown-coloured node use the default colour. The comparison is
        # done on the raw rgb tuple; comparing the formatted 'rgb(...)' string
        # could never match.
        # NOTE(review): assumes UNKNOWN_COLOR_RGB / DEFAULT_COLOR_RGB are
        # (r, g, b) sequences like BranchColor.to_rgb() — confirm in constants.
        rgb = node_rgb[end]
        if list(rgb) == list(UNKNOWN_COLOR_RGB):
            rgb = DEFAULT_COLOR_RGB

        edge_traces.append(
            go.Scatter(
                x=[x0, x1, None],
                y=[y0, y1, None],
                line=dict(width=edge_width, color=to_rgb_string(rgb)),
                hoverinfo='none',
                mode='lines',
                showlegend=False,
                name='_edge',
                opacity=alpha_edges,
            )
        )

    # Create node traces (one per labelled node; text stays a plain string so
    # update_tree_fig can select a node with selector text=<model name>).
    node_traces = []
    for node, text in labels.items():
        x, y = pos[node]
        node_traces.append(
            go.Scatter(
                x=[x],
                y=[y],
                text=text,
                mode='markers+text',
                textposition='top center',
                hoverinfo='text',
                showlegend=False,
                name='_node',
                marker=dict(
                    color=node_colors[node],
                    size=node_size,
                    line_width=2,
                    opacity=alpha_markers,
                ),
                textfont=dict(
                    color=f'rgba(0,0,0,{text_alpha})',
                    size=8,
                    family="Arial Black",
                )
            )
        )

    # Family -> colour (the last node of a family wins). Dict insertion order
    # gives a deterministic legend order (first appearance in the graph).
    family_colors = {}
    for node, family in node_families.items():
        if family is not None:
            family_colors[family] = node_colors[node]

    #Custom legend
    legends = [
        go.Scatter(
            x=[None],
            y=[None],
            mode='markers',
            marker=dict(
                color=color,
                size=8,
                line_width=2,
                opacity=1
            ),
            name=family,
        )
        for family, color in family_colors.items()
    ]

    # Create the figure
    fig = go.Figure(
        data=edge_traces + node_traces,
        layout=go.Layout(
            showlegend=True,
            hovermode='closest',
            margin=dict(b=1, l=1, r=1, t=1),
            xaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            yaxis=dict(showgrid=False, zeroline=False, showticklabels=False),
            legend=dict(
                yanchor="top",
                y=0.99,
                xanchor="right",
                x=0.99
            )
        )
    )

    fig.add_traces(legends)

    return fig

def get_color(index):
    """Return the colour at `index` of plotly's qualitative 'Plotly' palette, wrapping around."""
    palette = px.colors.qualitative.Plotly
    wrapped = index % len(palette)
    return palette[wrapped]

def plot_tree(sim_matrix, models, families,colors, alpha_names=None, alpha_markers=None, alpha_edges=None):
    """
    Build a neighbour-joining tree from a similarity matrix and draw it with Plotly.

    Similarities are turned into distances with -log(sim). Values are clipped
    at 1e-10 first so log() stays finite. The tree is built with Bio.Phylo's
    NJ constructor, ladderized, coloured by prepare_tree, then drawn by
    draw_graphviz (node size 15, edge width 6).

    Parameters:
    - sim_matrix: square similarity matrix between models
    - models: list of model names, in matrix order
    - families: list of family names for each model
    - colors: family colour mapping, forwarded to prepare_tree
    - alpha_names, alpha_markers, alpha_edges: transparencies forwarded to
      draw_graphviz

    Returns:
    - fig: Plotly figure object with the phylogenetic tree
    """
    distances = -np.log(np.maximum(sim_matrix, 1e-10))  # avoid log(0)

    # Bio.Phylo's _DistanceMatrix takes the lower triangle, diagonal included.
    lower_triangle = [
        [dist_row[col] for col in range(row_idx + 1)]
        for row_idx, dist_row in enumerate(distances)
    ]
    distance_matrix = _DistanceMatrix(names=models, matrix=lower_triangle)

    tree_builder = DistanceTreeConstructor(DistanceCalculator('identity'), 'nj')
    nj_tree = tree_builder.nj(distance_matrix)
    nj_tree.ladderize(reverse=False)

    # Attach colours/families to the clades
    prepare_tree(nj_tree, models, families, colors)

    return draw_graphviz(
        nj_tree,
        node_size=15,
        edge_width=6,
        alpha_names=alpha_names,
        alpha_markers=alpha_markers,
        alpha_edges=alpha_edges,
    )

def update_tree_fig(fig, model_names, model_search=None,alpha_names=None, alpha_markers=None, alpha_edges=None):
    """
    Update transparencies and the highlighted model of a tree figure built by plot_tree.

    Node traces are named '_node'. The highlighted one is renamed 'node' and
    given a bigger marker, a red border and red 16pt text, so that the next
    call can find it and reset it.

    Parameters:
    - fig: figure returned by plot_tree
    - model_names: list of model names
    - model_search: model to highlight; nothing is highlighted when it is not
      in model_names
    - alpha_names: alpha of the node text (None is treated as 1, opaque)
    - alpha_markers: opacity of the node markers
    - alpha_edges: opacity of the edge traces

    Returns:
    - fig: the same figure, updated in place
    """
    # 'rgba(0,0,0,None)' is not a valid plotly colour, so default to opaque.
    text_alpha = 1 if alpha_names is None else alpha_names
    text_color = f'rgba(0,0,0,{text_alpha})'

    # Demote the previously highlighted node (if any) back to a regular node,
    # restoring the style plot_tree gives nodes (marker size 15, border width
    # 2 with default colour, 8pt text). This runs first so that the opacity
    # update below also reaches this node.
    fig.update_traces(
        selector=dict(name='node'),
        marker=dict(
            size=15,
            line=dict(color=None, width=2),  # None drops the red border colour
        ),
        textfont=dict(
            color=text_color, size=8, family="Arial Black",
        ),
        name='_node'
    )

    #Update nodes
    fig.update_traces(
        selector=dict(name='_node'),
        marker=dict(
            opacity=alpha_markers,
        ),
        textfont=dict(
            color=text_color,
        )
    )

    # Update edges
    fig.update_traces(
        selector=dict(name='_edge'),
        opacity=alpha_edges,
    )

    # Highlight the searched model (node traces carry their label as `text`)
    if model_search in model_names:
        fig.update_traces(
            selector=dict(name='_node',text=model_search),
            marker=dict(
                size=22,  # Bigger than normal nodes
                line=dict(color='red', width=4)  # Red border
            ),
            textfont=dict(
                color='red', size=16, family="Arial Black",
            ),
            name='node'
        )

    return fig