import logging
import math
from typing import Dict, List, Optional, Tuple

import PIL
import PIL.Image
import torch
from diffusers import DiffusionPipeline

from rewards import clip_img_transform
from rewards.base_reward import BaseRewardLoss


class LatentNoiseTrainer:
    """Trainer for optimizing latents with reward losses.

    Each iteration renders an image from the latents with ``model.apply``,
    scores the preprocessed image with every reward loss, and (unless
    ``imageselect`` or ``no_optim`` is set) backpropagates the weighted sum of
    the losses into the latents and steps the optimizer. The iteration with
    the lowest weighted reward loss is remembered and returned.
    """

    def __init__(
        self,
        reward_losses: List[BaseRewardLoss],
        model: DiffusionPipeline,
        n_iters: int,
        n_inference_steps: int,
        seed: int,
        no_optim: bool = False,
        regularize: bool = True,
        regularization_weight: float = 0.01,
        grad_clip: float = 0.1,
        log_metrics: bool = True,
        save_all_images: bool = False,
        imageselect: bool = False,
        device: torch.device = torch.device("cuda"),
    ):
        """Store the training configuration.

        Args:
            reward_losses: Losses called as ``loss(image, prompt)``; each must
                expose ``name`` and ``weighting``.
            model: Object whose ``apply(latents=..., prompt=..., generator=...,
                num_inference_steps=...)`` renders an image tensor.
            n_iters: Maximum number of iterations run by :meth:`train`.
            n_inference_steps: Forwarded to ``model.apply``.
            seed: Seed used to re-create the sampling generator on every
                iteration.
            no_optim: If true, render once and return without scoring.
            regularize: If true, add the latent-norm regularizer to the loss.
            regularization_weight: Multiplier applied to the regularizer.
            grad_clip: Max gradient norm passed to ``clip_grad_norm_``.
            log_metrics: If true, log the per-iteration metrics.
            save_all_images: If true, save every iteration's image to the
                ``save_dir`` given to :meth:`train` (which must then be set).
            imageselect: If true, sample fresh random latents each iteration
                instead of optimizing the given ones.
            device: Device for the sampling generator, the freshly sampled
                latents, and the final re-render of the best latents.
        """
        self.reward_losses = reward_losses
        self.model = model
        self.n_iters = n_iters
        self.n_inference_steps = n_inference_steps
        self.seed = seed
        self.no_optim = no_optim
        self.regularize = regularize
        self.regularization_weight = regularization_weight
        self.grad_clip = grad_clip
        self.log_metrics = log_metrics
        self.save_all_images = save_all_images
        self.imageselect = imageselect
        self.device = device
        self.preprocess_fn = clip_img_transform(224)

    @staticmethod
    def _to_pil(image: torch.Tensor) -> PIL.Image.Image:
        """Convert the first image of a channels-second batch tensor to PIL."""
        # numpy_to_pil wants channels-last float arrays on the host.
        image_numpy = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
        return DiffusionPipeline.numpy_to_pil(image_numpy)[0]

    def train(
        self,
        latents: torch.Tensor,
        prompt: str,
        optimizer: torch.optim.Optimizer,
        save_dir: Optional[str] = None,
        multi_apply_fn=None,
        progress_callback=None,
    ) -> Tuple[
        Optional[PIL.Image.Image],
        PIL.Image.Image,
        Optional[Dict[str, float]],
        Optional[Dict[str, float]],
    ]:
        """Run the optimization loop and return the best result.

        Args:
            latents: Latents to optimize (or, with ``imageselect``, the
                template whose shape/dtype is used for random sampling).
            prompt: Text prompt forwarded to the model and the reward losses.
            optimizer: Optimizer stepped after each backward pass; it is
                expected to own ``latents``.
            save_dir: Directory for per-iteration images when
                ``save_all_images`` is enabled.
            multi_apply_fn: Optional ``fn(latents, prompt)`` used to render
                the initial and the final image instead of the loop's output.
            progress_callback: Optional callable invoked with the 1-based
                index of each completed iteration.

        Returns:
            ``(initial_image, best_image, initial_rewards, best_rewards)``.
            ``initial_image`` is ``None`` without ``multi_apply_fn``; both
            reward dicts are ``None`` when ``no_optim`` is set. Reward dicts
            map each loss name to its value, plus ``"total"`` and, when
            regularizing, ``"norm"``.

        Raises:
            RuntimeError: If no iteration produced a selectable image
                (``n_iters`` is 0 or the reward loss never compared below
                infinity, e.g. because it was NaN).
        """
        logging.info(f"Optimizing latents for prompt '{prompt}'.")
        best_loss = torch.inf
        best_image = None
        initial_image = None
        initial_rewards = None
        best_rewards = None
        best_latents = None
        latent_dim = math.prod(latents.shape[1:])
        for iteration in range(self.n_iters):
            rewards = {}
            optimizer.zero_grad()
            # Re-seed every iteration so the sampler noise is identical and
            # only the latents differ between iterations.
            generator = torch.Generator(self.device).manual_seed(self.seed)
            if self.imageselect:
                # Random search: draw fresh latents instead of optimizing.
                current_latents = torch.randn_like(
                    latents, device=self.device, dtype=latents.dtype
                )
            else:
                current_latents = latents
            image = self.model.apply(
                latents=current_latents,
                prompt=prompt,
                generator=generator,
                num_inference_steps=self.n_inference_steps,
            )
            if initial_image is None and multi_apply_fn is not None:
                initial_image = self._to_pil(
                    multi_apply_fn(latents.detach(), prompt)
                )
            if self.no_optim:
                best_image = image.detach()
                # Recorded so the multi_apply_fn re-render below has latents.
                best_latents = current_latents.detach().to(device="cpu", copy=True)
                break

            total_loss = 0
            log_parts = []
            preprocessed_image = self.preprocess_fn(image)
            for reward_loss in self.reward_losses:
                loss = reward_loss(preprocessed_image, prompt)
                loss_value = loss.item()
                log_parts.append(f"{reward_loss.name}: {loss_value:.4f}")
                total_loss += loss * reward_loss.weighting
                rewards[reward_loss.name] = loss_value
            # Selection is based on the reward loss only, before regularization.
            total_reward_loss = total_loss.item()
            rewards["total"] = total_reward_loss
            log_parts.append(f"Total: {total_reward_loss:.4f}")
            if self.regularize:
                # Cast to fp32 so squaring the norm below cannot overflow in
                # half precision.
                latent_norm = torch.linalg.vector_norm(current_latents).to(
                    torch.float32
                )
                log_norm = torch.log(latent_norm)
                # NOTE(review): presumably the negative log-density (up to
                # constants) of the norm of a latent_dim-dimensional standard
                # Gaussian -- confirm against the method's reference.
                regularization = self.regularization_weight * (
                    0.5 * latent_norm**2 - (latent_dim - 1) * log_norm
                )
                norm_value = latent_norm.item()
                log_parts.append(f"Latent norm: {norm_value}")
                rewards["norm"] = norm_value
                total_loss += regularization.to(total_loss.dtype)
            if self.log_metrics:
                to_log = ", ".join(log_parts)
                logging.info(f"Iteration {iteration}: {to_log}")
            if total_reward_loss < best_loss:
                best_loss = total_reward_loss
                # Detach so the stored best image does not pin the autograd graph.
                best_image = image.detach()
                best_rewards = rewards
                # Force a copy: on CPU a plain .cpu() would alias the tensor
                # the optimizer keeps updating in place.
                best_latents = current_latents.detach().to(device="cpu", copy=True)
            # The last iteration's update would never be rendered, so skip it.
            if iteration != self.n_iters - 1 and not self.imageselect:
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(latents, self.grad_clip)
                optimizer.step()
            if self.save_all_images:
                self._to_pil(image).save(f"{save_dir}/{iteration}.png")
            if initial_rewards is None:
                initial_rewards = rewards
            if progress_callback:
                progress_callback(iteration + 1)
        if best_image is None:
            raise RuntimeError(
                "No image was selected: n_iters must be at least 1 and the "
                "reward loss must not be NaN on every iteration."
            )
        if multi_apply_fn is not None:
            best_image_pil = self._to_pil(
                multi_apply_fn(best_latents.to(self.device), prompt)
            )
        else:
            best_image_pil = self._to_pil(best_image)
        return initial_image, best_image_pil, initial_rewards, best_rewards