import argparse
import logging

import numpy as np
import torch
import trimesh

from cube3d.inference.utils import load_config, load_model_weights, parse_structured
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder

# Default half-extent of the normalized bounding cube used by `rescale`.
MESH_SCALE = 0.96


def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
    """Center vertices on the origin and scale them uniformly into a cube.

    The vertices are translated so the midpoint of their axis-aligned bounding
    box sits at the origin, then scaled by a single factor so that the longest
    bounding-box side spans ``[-mesh_scale, mesh_scale]``. For example, with
    ``mesh_scale=1.0`` the result fits in ``[-1, -1, -1]`` to ``[1, 1, 1]``.

    Args:
        vertices: Array-like of shape (N, D), one point per row. It is not
            modified; a new array is returned.
        mesh_scale: Half side length of the target cube.

    Returns:
        The centered and scaled vertices, same shape as the input. If the
        bounding box has zero extent (a single point, or all points identical),
        the vertices are only centered, because no finite scale factor exists.
    """
    # asanyarray accepts plain sequences while keeping ndarray subclasses intact.
    vertices = np.asanyarray(vertices)
    bbmin = vertices.min(0)
    bbmax = vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    extent = (bbmax - bbmin).max()
    if extent == 0:
        # Degenerate bounding box: dividing by the extent would yield inf/NaN.
        return vertices - center
    scale = 2.0 * mesh_scale / extent
    return (vertices - center) * scale


def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
    """
    Load a mesh from a file, clean it, and rescale its vertices.
    The file is loaded with ``force="mesh"``. Cleaning removes infinite values,
    degenerate faces, duplicate faces and unreferenced vertices, in that order.
    The remaining vertices are then passed through `rescale` with its default
    scale.
    Parameters:
        file_path: str, path of the mesh file to load
    Returns:
        mesh: trimesh.Trimesh, the cleaned and rescaled mesh
    Raises:
        ValueError: if no vertices or no faces are left after cleaning
    """
    loaded: trimesh.Trimesh = trimesh.load(file_path, force="mesh")

    loaded.remove_infinite_values()
    # Each selector is evaluated only after the previous filter was applied,
    # so the second mask is computed on the already-filtered face list.
    for select_faces in (loaded.nondegenerate_faces, loaded.unique_faces):
        loaded.update_faces(select_faces())
    loaded.remove_unreferenced_vertices()

    has_geometry = len(loaded.vertices) > 0 and len(loaded.faces) > 0
    if not has_geometry:
        raise ValueError("Mesh has no vertices or faces after cleaning")

    loaded.vertices = rescale(loaded.vertices)
    return loaded


def load_and_process_mesh(file_path: str, n_samples: int = 8192):
    """
    Builds a batched point cloud with normals from a mesh file.
    The mesh is loaded through `load_scaled_mesh`, points are sampled on its
    surface, and every sampled point is paired with the normal of the face it
    was drawn from.
    Args:
        file_path (str): The file path to the 3D mesh file.
        n_samples (int, optional): How many surface points to sample. Defaults to 8192.
    Returns:
        torch.Tensor: A float32 tensor of shape (1, n_samples, 6). The last axis holds
                        the position (x, y, z) followed by the normal (nx, ny, nz).
    """
    scaled_mesh = load_scaled_mesh(file_path)

    surface_points, hit_faces = trimesh.sample.sample_surface(scaled_mesh, n_samples)
    point_normals = scaled_mesh.face_normals[hit_faces]

    # Columns 0-2 are positions, columns 3-5 are normals: (num_samples, 6).
    features = np.concatenate((surface_points, point_normals), axis=1)
    # Add the leading batch axis before handing the data to torch.
    batched = features.reshape(1, -1, 6)
    return torch.from_numpy(batched).float()


@torch.inference_mode()
def run_shape_decode(
    shape_model: OneDAutoEncoder,
    output_ids: torch.Tensor,
    resolution_base: float = 8.0,
    chunk_size: int = 100_000,
):
    """
    Decodes the shape from the given output IDs and extracts the geometry.
    Only the first ``shape_model.cfg.num_encoder_latents`` entries along
    dimension 1 are used, and IDs are clamped into the valid code range
    ``[0, shape_model.cfg.num_codes - 1]``. ``output_ids`` itself is left
    unmodified.
    Args:
        shape_model (OneDAutoEncoder): The shape model.
        output_ids (torch.Tensor): The tensor containing the output IDs.
        resolution_base (float, optional): The base resolution for geometry extraction. Defaults to 8.0.
        chunk_size (int, optional): The chunk size for processing. Defaults to 100,000.
    Returns:
        The first value returned by ``shape_model.extract_geometry``. The script
        at the bottom of this file reads it as ``result[0][0]`` (vertices) and
        ``result[0][1]`` (faces), i.e. presumably one (vertices, faces) entry
        per batch item.
    """
    num_latents = shape_model.cfg.num_encoder_latents
    shape_ids = (
        output_ids[:, :num_latents, ...]
        # Out-of-place clamp: the slice is a view, so an in-place clamp_ would
        # silently overwrite the caller's tensor.
        .clamp(0, shape_model.cfg.num_codes - 1)
        # reshape (not view) so a non-contiguous layout cannot raise.
        .reshape(-1, num_latents)
    )
    latents = shape_model.decode_indices(shape_ids)
    mesh_v_f, _ = shape_model.extract_geometry(
        latents,
        resolution_base=resolution_base,
        chunk_size=chunk_size,
        use_warp=True,
    )
    return mesh_v_f


if __name__ == "__main__":
    # Example script: encode a mesh into shape indices, decode the indices back
    # into a mesh, and export the recovered mesh.
    parser = argparse.ArgumentParser(
        description="cube shape encode and decode example script"
    )
    parser.add_argument(
        "--mesh-path",
        type=str,
        required=True,
        help="Path to the input mesh file.",
    )
    parser.add_argument(
        "--config-path",
        type=str,
        default="cube3d/configs/open_model.yaml",
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    parser.add_argument(
        "--recovered-mesh-path",
        type=str,
        default="recovered_mesh.obj",
        help="Path to save the recovered mesh file.",
    )
    args = parser.parse_args()

    # Without explicit configuration the root logger stays at WARNING and the
    # INFO message below would never be shown.
    logging.basicConfig(level=logging.INFO)

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    logging.info("Using device: %s", device)

    cfg = load_config(args.config_path)

    # Build the autoencoder from the `shape_model` section of the config, then
    # load the checkpoint weights into it.
    shape_model = OneDAutoEncoder(
        parse_structured(OneDAutoEncoder.Config, cfg.shape_model)
    )
    load_model_weights(
        shape_model,
        args.shape_ckpt_path,
    )
    shape_model = shape_model.eval().to(device)

    point_cloud = load_and_process_mesh(args.mesh_path)
    # No gradients are needed here, so skip autograd bookkeeping while encoding.
    with torch.inference_mode():
        output = shape_model.encode(point_cloud.to(device))
    # The fourth element of the encoder output is a dict that holds the indices.
    indices = output[3]["indices"]
    print("Got the following shape indices:")
    print(indices)
    print("Indices shape: ", indices.shape)

    mesh_v_f = run_shape_decode(shape_model, indices)
    # Take the vertices and faces of the first decoded entry.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    mesh.export(args.recovered_mesh_path)