import gc |
|
import warnings |
|
from collections.abc import Mapping |
|
from contextlib import contextmanager |
|
from typing import Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
from transformers import is_torch_npu_available, is_torch_xpu_available |
|
|
|
|
|
def flatten_dict(nested: dict, sep: str = "/") -> dict: |
|
"""Flatten dictionary and concatenate nested keys with separator.""" |
|
|
|
def recurse(nest: dict, prefix: str, into: dict) -> None: |
|
for k, v in nest.items(): |
|
if sep in k: |
|
raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") |
|
if isinstance(v, Mapping): |
|
recurse(v, prefix + k + sep, into) |
|
else: |
|
into[prefix + k] = v |
|
|
|
flat = {} |
|
recurse(nested, "", flat) |
|
return flat |
|
|
|
|
|
def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:

    """Compute the mean of a tensor over the entries selected by the mask."""
|
if axis is not None: |
|
return (values * mask).sum(axis=axis) / mask.sum(axis=axis) |
|
else: |
|
return (values * mask).sum() / mask.sum() |
|
|
|
|
|
def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: |
|
"""Compute variance of tensor with masked values.""" |
|
mean = masked_mean(values, mask) |
|
centered_values = values - mean |
|
variance = masked_mean(centered_values**2, mask) |
|
if unbiased: |
|
mask_sum = mask.sum() |
|
if mask_sum == 0: |
|
raise ValueError( |
|
"The sum of the mask is zero, which can happen when `mini_batch_size=1`;" |
|
"try increase the `mini_batch_size` or `gradient_accumulation_steps`" |
|
) |
|
|
|
|
|
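        # Bessel's correction n / (n - 1) makes the estimate unbiased; note that a
        # mask covering a single entry (mask_sum == 1) would divide by zero here.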
bessel_correction = mask_sum / (mask_sum - 1) |
|
variance = variance * bessel_correction |
|
return variance |
|
|
|
|
|
def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: |
|
"""Whiten values with masked values.""" |
|
mean, var = masked_mean(values, mask), masked_var(values, mask) |
|
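    # The small epsilon keeps rsqrt finite when the masked variance is (near) zero.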
whitened = (values - mean) * torch.rsqrt(var + 1e-8) |
|
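    # With shift_mean=False the mean is added back, so only the scale is normalized.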
if not shift_mean: |
|
whitened += mean |
|
return whitened |
|
|
|
|
|
class LengthSampler: |
|
""" |
|
Samples a length |
|
""" |
|
|
|
def __init__(self, min_value: int, max_value: int): |
|
self.values = list(range(min_value, max_value)) |
|
|
|
def __call__(self) -> int: |
|
return np.random.choice(self.values) |
|
|
|
|
|
class PPODecorators: |
|
optimize_device_cache = False |
|
|
|
@classmethod |
|
@contextmanager |
|
def empty_device_cache(cls): |
|
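        # Run the wrapped block first; once it finishes, optionally release cached device memory.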
yield |
|
if cls.optimize_device_cache: |
|
if is_torch_xpu_available(): |
|
gc.collect() |
|
torch.xpu.empty_cache() |
|
gc.collect() |
|
elif is_torch_npu_available(): |
|
gc.collect() |
|
torch.npu.empty_cache() |
|
gc.collect() |
|
elif torch.cuda.is_available(): |
|
gc.collect() |
|
torch.cuda.empty_cache() |
|
gc.collect() |
|
|
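# Usage sketch (assumes a hypothetical memory-heavy `compute_loss` step): enable cache
# clearing and wrap the block whose temporaries should be freed afterwards.
#
#     PPODecorators.optimize_device_cache = True
#     with PPODecorators.empty_device_cache():
#         loss = compute_loss(batch)  # hypothetical memory-heavy step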
|
|
|
def randn_tensor( |
|
shape: Union[tuple, list], |
|
generator: Optional[Union[list[torch.Generator], torch.Generator]] = None, |
|
device: Optional[torch.device] = None, |
|
dtype: Optional[torch.dtype] = None, |
|
layout: Optional[torch.layout] = None, |
|
) -> torch.Tensor: |
|
"""A helper function to create random tensors on the desired `device` with the desired `dtype`. When |
|
    passing a list of generators, you can seed each sample in the batch individually. If CPU generators are passed, the

    tensor is created on the CPU and then moved to `device`.
|
""" |
|
|
|
rand_device = device |
|
batch_size = shape[0] |
|
|
|
layout = layout or torch.strided |
|
device = device or torch.device("cpu") |
|
|
|
if generator is not None: |
|
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type |
|
if gen_device_type != device.type and gen_device_type == "cpu": |
|
rand_device = "cpu" |
|
if device != "mps": |
|
warnings.warn( |
|
f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." |
|
f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" |
|
f" slighly speed up this function by passing a generator that was created on the {device} device.", |
|
UserWarning, |
|
) |
|
elif gen_device_type != device.type and gen_device_type == "cuda": |
|
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") |
|
|
|
|
|
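    # A single-element generator list is equivalent to passing that generator directly.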
if isinstance(generator, list) and len(generator) == 1: |
|
generator = generator[0] |
|
|
|
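    # One generator per batch element: draw each sample separately so every element gets its
    # own reproducible stream, then concatenate and move the result to the target device.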
if isinstance(generator, list): |
|
        shape = (1, *shape[1:])  # per-sample shape; works whether `shape` is a tuple or a list
|
latents = [ |
|
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) |
|
for i in range(batch_size) |
|
] |
|
latents = torch.cat(latents, dim=0).to(device) |
|
else: |
|
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) |
|
|
|
return latents |
|
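# Usage sketch (assumes CUDA is available): two per-sample CPU generators seeding a (2, 4) tensor,
# which is created on the CPU and then moved to the requested device.
#
#     gens = [torch.Generator().manual_seed(i) for i in range(2)]
#     noise = randn_tensor((2, 4), generator=gens, device=torch.device("cuda"), dtype=torch.float32)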
|