# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import warnings
from collections.abc import Mapping
from contextlib import contextmanager
from typing import Optional, Union

import numpy as np
import torch
from transformers import is_torch_npu_available, is_torch_xpu_available


def flatten_dict(nested: dict, sep: str = "/") -> dict:
    """Flatten dictionary and concatenate nested keys with separator."""

    def recurse(nest: dict, prefix: str, into: dict) -> None:
        for k, v in nest.items():
            if sep in k:
                raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
            if isinstance(v, Mapping):
                recurse(v, prefix + k + sep, into)
            else:
                into[prefix + k] = v

    flat = {}
    recurse(nested, "", flat)
    return flat
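

# A minimal usage sketch (illustrative, not part of the module):
# >>> flatten_dict({"a": {"b": 1, "c": 2}, "d": 3})
# {'a/b': 1, 'a/c': 2, 'd': 3}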


def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
    """Compute the mean of the tensor over the masked values."""
    if axis is not None:
        return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
    else:
        return (values * mask).sum() / mask.sum()
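

# A minimal usage sketch (illustrative): only positions where the mask is 1
# contribute to the mean.
# >>> values = torch.tensor([1.0, 2.0, 3.0, 4.0])
# >>> mask = torch.tensor([1.0, 1.0, 0.0, 0.0])
# >>> masked_mean(values, mask)
# tensor(1.5000)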


def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
    """Compute the variance of the tensor over the masked values."""
    mean = masked_mean(values, mask)
    centered_values = values - mean
    variance = masked_mean(centered_values**2, mask)
    if unbiased:
        mask_sum = mask.sum()
        if mask_sum == 0:
            raise ValueError(
                "The sum of the mask is zero, which can happen when `mini_batch_size=1`; "
                "try increasing the `mini_batch_size` or `gradient_accumulation_steps`"
            )
        # note that if mask_sum == 1, the Bessel correction below divides by zero;
        # to avoid it, use a larger `mini_batch_size`
        bessel_correction = mask_sum / (mask_sum - 1)
        variance = variance * bessel_correction
    return variance
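

# A minimal usage sketch (illustrative): with the values/mask from the
# masked_mean example above, the unbiased variance of {1.0, 2.0} is
# ((1 - 1.5)**2 + (2 - 1.5)**2) / (2 - 1) = 0.5.
# >>> masked_var(values, mask)
# tensor(0.5000)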


def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
    """Whiten `values` using mean and variance computed over the masked values."""
    mean, var = masked_mean(values, mask), masked_var(values, mask)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened
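

# A minimal usage sketch (illustrative): all positions are whitened, but the
# statistics come only from the masked-in entries.
# >>> masked_whiten(values, mask)
# tensor([-0.7071,  0.7071,  2.1213,  3.5355])  # approximately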


class LengthSampler:
    """
    Samples a length uniformly from `[min_value, max_value)`.
    """

    def __init__(self, min_value: int, max_value: int):
        self.values = list(range(min_value, max_value))

    def __call__(self) -> int:
        return np.random.choice(self.values)
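

# A minimal usage sketch (illustrative): the upper bound is exclusive.
# >>> sampler = LengthSampler(min_value=4, max_value=8)
# >>> sampler()  # one of 4, 5, 6, 7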


class PPODecorators:
    optimize_device_cache = False

    @classmethod
    @contextmanager
    def empty_device_cache(cls):
        yield
        if cls.optimize_device_cache:
            if is_torch_xpu_available():
                gc.collect()
                torch.xpu.empty_cache()
                gc.collect()
            elif is_torch_npu_available():
                gc.collect()
                torch.npu.empty_cache()
                gc.collect()
            elif torch.cuda.is_available():
                gc.collect()
                torch.cuda.empty_cache()
                gc.collect()
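

# A minimal usage sketch (illustrative): the device cache is cleared only after
# the `with` block exits, and only when `optimize_device_cache` is enabled.
# >>> PPODecorators.optimize_device_cache = True
# >>> with PPODecorators.empty_device_cache():
# ...     loss = compute_loss(batch)  # `compute_loss` is a hypothetical training step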


def randn_tensor(
    shape: Union[tuple, list],
    generator: Optional[Union[list[torch.Generator], torch.Generator]] = None,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    layout: Optional[torch.layout] = None,
) -> torch.Tensor:
    """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
    passing a list of generators, you can seed each batch item individually. If CPU generators are passed, the
    tensor is always created on the CPU.
    """
    # device on which tensor is created defaults to device
    rand_device = device
    batch_size = shape[0]

    layout = layout or torch.strided
    device = device or torch.device("cpu")

    if generator is not None:
        gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
        if gen_device_type != device.type and gen_device_type == "cpu":
            rand_device = "cpu"
            if device != "mps":
                warnings.warn(
                    f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
                    f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
                    f" slightly speed up this function by passing a generator that was created on the {device} device.",
                    UserWarning,
                )
        elif gen_device_type != device.type and gen_device_type == "cuda":
            raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")

    # make sure generator list of length 1 is treated like a non-list
    if isinstance(generator, list) and len(generator) == 1:
        generator = generator[0]
    if isinstance(generator, list):
        # cast to tuple so list-typed shapes also work with the `(1,) +` concatenation
        shape = (1,) + tuple(shape[1:])
        latents = [
            torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
            for i in range(batch_size)
        ]
        latents = torch.cat(latents, dim=0).to(device)
    else:
        latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)

    return latents
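

# A minimal usage sketch (illustrative): a list of generators seeds each batch
# item independently; a single generator seeds the whole tensor.
# >>> gens = [torch.Generator().manual_seed(i) for i in range(4)]
# >>> noise = randn_tensor((4, 16), generator=gens)
# >>> noise.shape
# torch.Size([4, 16])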