import torch
from torch import Tensor
from typing import Optional
from optimum.quanto import QBytesTensor


def compute_scale_for_dtype(tensor, dtype):
    """
    Compute appropriate scale for the given tensor and target dtype.
    
    Args:
        tensor: Input tensor to be quantized
        dtype: Target dtype for quantization
    Returns:
        Appropriate scale factor for the quantization
    """
    if dtype == torch.int8:
        abs_max = torch.max(torch.abs(tensor))
        return abs_max / 127.0 if abs_max > 0 else 1.0
    elif dtype == torch.uint8:
        max_val = torch.max(tensor)
        min_val = torch.min(tensor)
        range_val = max_val - min_val
        return range_val / 255.0 if range_val > 0 else 1.0
    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # For float8, we typically want to preserve the magnitude of the values
        # while fitting within the representable range of the format
        abs_max = torch.max(torch.abs(tensor))
        if dtype == torch.float8_e4m3fn:
            # e4m3fn has range [-448, 448] with no infinities
            max_representable = 448.0
        else:  # torch.float8_e5m2
            # e5m2 has range [-57344, 57344] with infinities
            max_representable = 57344.0
        
        return abs_max / max_representable if abs_max > 0 else 1.0
    else:
        raise ValueError(f"Unsupported dtype for quantization: {dtype}")

def quantize_tensor(tensor, dtype):
    """
    Quantize a floating-point tensor to the target dtype with appropriate scaling.
    
    Args:
        tensor: Input tensor (float)
        dtype: Target dtype for quantization
    Returns:
        quantized_data: Quantized tensor
        scale: Scale factor used
    """
    scale = compute_scale_for_dtype(tensor, dtype)
    
    if dtype == torch.int8:
        quantized_data = torch.clamp(torch.round(tensor / scale), -128, 127).to(dtype)
    elif dtype == torch.uint8:
        quantized_data = torch.clamp(torch.round(tensor / scale), 0, 255).to(dtype)
    elif dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        # For float8, we scale and then cast directly to the target type
        # The casting operation will handle the appropriate rounding
        scaled_tensor = tensor / scale
        quantized_data = scaled_tensor.to(dtype)
    else:
        raise ValueError(f"Unsupported dtype for quantization: {dtype}")
        
    return quantized_data, scale


def update_parameter(target, result_float):
    """
    Updates a parameter tensor, handling both regular torch.Tensor and QBytesTensor cases
    with proper rescaling for quantized tensors.
    
    Args:
        target: The parameter to update (either torch.Tensor or QBytesTensor)
        result_float: The new values to assign (torch.Tensor)
    """
    if isinstance(target, QBytesTensor):
        # Get the target dtype from the existing quantized tensor
        target_dtype = target._data.dtype
        
        # Handle device placement
        device = target._data.device
        result_float = result_float.to(device)
        
        # Compute new quantized values and scale
        quantized_data, new_scale = quantize_tensor(result_float, target_dtype)
        
        # Update the internal tensors with newly computed values
        target._data.copy_(quantized_data)
        target._scale.copy_(new_scale)
    else:
        # Regular tensor update
        target.copy_(result_float)


def get_format_params(dtype: torch.dtype) -> tuple[int, int]:
    """
    Returns (mantissa_bits, total_bits) for each format.
    mantissa_bits excludes the implicit leading 1.
    """
    if dtype == torch.float32:
        return 23, 32
    elif dtype == torch.bfloat16:
        return 7, 16
    elif dtype == torch.float16:
        return 10, 16
    elif dtype == torch.float8_e4m3fn:
        return 3, 8
    elif dtype == torch.float8_e5m2:
        return 2, 8
    elif dtype == torch.int8:
        return 0, 8  # Int8 doesn't have mantissa bits
    else:
        raise ValueError(f"Unsupported dtype: {dtype}")


def copy_stochastic(
    target: torch.Tensor,
    source: torch.Tensor,
    eps: Optional[float] = None
) -> None:
    """
    Performs stochastic rounding from source tensor to target tensor.

    Args:
        target: Destination tensor (determines the target format)
        source: Source tensor (typically float32)
        eps: Optional minimum value for stochastic rounding (for numerical stability)
    """
    with torch.no_grad():
        # If target is float32, just copy directly
        if target.dtype == torch.float32:
            target.copy_(source)
            return

        # Special handling for int8
        if target.dtype == torch.int8:
            # Scale the source values to utilize the full int8 range
            scaled = source * 127.0  # Scale to [-127, 127]

            # Add random noise for stochastic rounding
            noise = torch.rand_like(scaled) - 0.5
            rounded = torch.round(scaled + noise)

            # Clamp to int8 range
            clamped = torch.clamp(rounded, -127, 127)
            target.copy_(clamped.to(torch.int8))
            return

        mantissa_bits, _ = get_format_params(target.dtype)

        # Convert source to int32 view
        source_int = source.view(dtype=torch.int32)

        # Calculate number of bits to round
        bits_to_round = 23 - mantissa_bits  # 23 is float32 mantissa bits

        # Create random integers for stochastic rounding
        rand = torch.randint_like(
            source,
            dtype=torch.int32,
            low=0,
            high=(1 << bits_to_round),
        )

        # Add random values to the bits that will be rounded off
        result = source_int.clone()
        result.add_(rand)

        # Mask to keep only the bits we want
        # Create mask with 1s in positions we want to keep
        mask = (-1) << bits_to_round
        result.bitwise_and_(mask)

        # Handle minimum value threshold if specified
        if eps is not None:
            eps_int = torch.tensor(
                eps, dtype=torch.float32).view(dtype=torch.int32)
            zero_mask = (result.abs() < eps_int)
            result[zero_mask] = torch.sign(source_int[zero_mask]) * eps_int

        # Convert back to float32 view
        result_float = result.view(dtype=torch.float32)

        # Special handling for float8 formats
        if target.dtype == torch.float8_e4m3fn:
            result_float.clamp_(-448.0, 448.0)
        elif target.dtype == torch.float8_e5m2:
            result_float.clamp_(-57344.0, 57344.0)

        # Copy the result to the target tensor
        update_parameter(target, result_float)
        # target.copy_(result_float)
        del result, rand, source_int


class Auto8bitTensor:
    def __init__(self, data: Tensor, *args, **kwargs):
        if isinstance(data, dict):  # Add constructor from state dict
            self._load_from_state_dict(data)
        else:
            abs_max = data.abs().max().item()
            scale = abs_max / 127.0 if abs_max > 0 else 1.0

            self.quantized = (data / scale).round().clamp(-127, 127).to(torch.int8)
            self.scale = scale
            self.orig_dtype = data.dtype

    def dequantize(self) -> Tensor:
        return self.quantized.to(dtype=torch.float32) * self.scale

    def to(self, *args, **kwargs):
        # Handle the dtype argument whether it's positional or keyword
        dtype = None
        if args and isinstance(args[0], torch.dtype):
            dtype = args[0]
            args = args[1:]
        elif 'dtype' in kwargs:
            dtype = kwargs['dtype']
            del kwargs['dtype']

        if dtype is not None:
            # First dequantize then convert to requested dtype
            return self.dequantize().to(dtype=dtype, *args, **kwargs)

        # If no dtype specified, just pass through to parent
        return self.dequantize().to(*args, **kwargs)

    def state_dict(self):
        """Returns a dictionary containing the current state of the tensor."""
        return {
            'quantized': self.quantized,
            'scale': self.scale,
            'orig_dtype': self.orig_dtype
        }

    def _load_from_state_dict(self, state_dict):
        """Loads the tensor state from a state dictionary."""
        self.quantized = state_dict['quantized']
        self.scale = state_dict['scale']
        self.orig_dtype = state_dict['orig_dtype']

    def __str__(self):
        return f"Auto8bitTensor({self.dequantize()})"


def stochastic_grad_accummulation(param):
    if hasattr(param, "_accum_grad"):
        grad_fp32 = param._accum_grad.clone().to(torch.float32)
        grad_fp32.add_(param.grad.to(torch.float32))
        copy_stochastic(param._accum_grad, grad_fp32)
        del grad_fp32
        del param.grad
    else:
        param._accum_grad = param.grad.clone()
        del param.grad