# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A collection of utilities for ensuring that training can always occur. Heavily influenced by the
[toma](https://github.com/BlackHC/toma) library.
"""

import functools
import gc
import importlib.metadata
import inspect
import warnings

import torch
from packaging import version

from .imports import (
    is_cuda_available,
    is_hpu_available,
    is_ipex_available,
    is_mlu_available,
    is_mps_available,
    is_musa_available,
    is_npu_available,
    is_sdaa_available,
    is_xpu_available,
)
from .versions import compare_versions


def clear_device_cache(garbage_collection=False):
    """
    Clears the device cache by calling `torch.{backend}.empty_cache`. Can also run `gc.collect()`, but do note that
    this is a *considerable* slowdown and should be used sparingly.
    """
    if garbage_collection:
        gc.collect()

    if is_xpu_available():
        torch.xpu.empty_cache()
    elif is_mlu_available():
        torch.mlu.empty_cache()
    elif is_sdaa_available():
        torch.sdaa.empty_cache()
    elif is_musa_available():
        torch.musa.empty_cache()
    elif is_npu_available():
        torch.npu.empty_cache()
    elif is_mps_available(min_version="2.0"):
        torch.mps.empty_cache()
    elif is_cuda_available():
        torch.cuda.empty_cache()
    elif is_hpu_available():
        # torch.hpu.empty_cache() is not available on HPU as it reserves all device memory for the current process
        pass


def release_memory(*objects):
    """
    Releases memory from `objects` by setting them to `None` and calls `gc.collect()` and `torch.cuda.empty_cache()`.
    Returned objects should be reassigned to the same variables.

    Args:
        objects (`Iterable`):
            An iterable of objects

    Returns:
        A list of `None` objects to replace `objects`

    Example:

        ```python
        >>> import torch
        >>> from accelerate.utils import release_memory

        >>> a = torch.ones(1000, 1000).cuda()
        >>> b = torch.ones(1000, 1000).cuda()
        >>> a, b = release_memory(a, b)
        ```
    """
    if not isinstance(objects, list):
        objects = list(objects)

    for i in range(len(objects)):
        objects[i] = None

    clear_device_cache(garbage_collection=True)
    return objects


def should_reduce_batch_size(exception: Exception) -> bool:
    """
    Checks if `exception` relates to CUDA out-of-memory, XPU out-of-memory, CUDNN not supported, or CPU out-of-memory

    Args:
        exception (`Exception`):
            An exception
    """
    _statements = [
        " out of memory.",  # OOM for CUDA, HIP, XPU
        "cuDNN error: CUDNN_STATUS_NOT_SUPPORTED.",  # CUDNN SNAFU
        "DefaultCPUAllocator: can't allocate memory",  # CPU OOM
        "FATAL ERROR :: MODULE:PT_DEVMEM Allocation failed",  # HPU OOM
    ]
    if isinstance(exception, RuntimeError) and len(exception.args) == 1:
        return any(err in exception.args[0] for err in _statements)
    return False
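

# Illustrative sketch (not part of this module): `should_reduce_batch_size` can drive a manual retry
# loop around a hypothetical `train(batch_size)` function. This is exactly the pattern that
# `find_executable_batch_size` below automates:
#
#     batch_size = 128
#     while batch_size > 0:
#         try:
#             train(batch_size)
#             break
#         except Exception as e:
#             if should_reduce_batch_size(e):
#                 clear_device_cache(garbage_collection=True)
#                 batch_size //= 2
#             else:
#                 raise
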

def find_executable_batch_size(
    function: callable = None, starting_batch_size: int = 128, reduce_batch_size_fn: callable = None
):
    """
    A basic decorator that will try to execute `function`. If it fails from exceptions related to out-of-memory or
    CUDNN, the batch size is reduced (halved by default, or via `reduce_batch_size_fn`) and `function` is retried.

    `function` must take in a `batch_size` parameter as its first argument.

    Args:
        function (`callable`, *optional*):
            A function to wrap
        starting_batch_size (`int`, *optional*):
            The batch size to try and fit into memory
        reduce_batch_size_fn (`callable`, *optional*):
            A function called to compute the next batch size to try; defaults to halving the current batch size.

    Example:

        ```python
        >>> from accelerate.utils import find_executable_batch_size


        >>> @find_executable_batch_size(starting_batch_size=128)
        ... def train(batch_size, model, optimizer):
        ...     ...


        >>> train(model, optimizer)
        ```
    """
    if function is None:
        return functools.partial(
            find_executable_batch_size,
            starting_batch_size=starting_batch_size,
            reduce_batch_size_fn=reduce_batch_size_fn,
        )

    batch_size = starting_batch_size

    if reduce_batch_size_fn is None:

        def reduce_batch_size_fn():
            nonlocal batch_size
            batch_size = batch_size // 2
            return batch_size

    def decorator(*args, **kwargs):
        nonlocal batch_size
        clear_device_cache(garbage_collection=True)
        params = list(inspect.signature(function).parameters.keys())
        # Guard against user error
        if len(params) < (len(args) + 1):
            arg_str = ", ".join([f"{arg}={value}" for arg, value in zip(params[1:], args[1:])])
            raise TypeError(
                f"Batch size was passed into `{function.__name__}` as the first argument when called. "
                f"Remove this as the decorator already does so: `{function.__name__}({arg_str})`"
            )
        while True:
            if batch_size == 0:
                raise RuntimeError("No executable batch size found, reached zero.")
            try:
                return function(batch_size, *args, **kwargs)
            except Exception as e:
                if should_reduce_batch_size(e):
                    clear_device_cache(garbage_collection=True)
                    batch_size = reduce_batch_size_fn()
                else:
                    raise

    return decorator


def get_xpu_available_memory(device_index: int):
    """
    Returns the currently available memory (in bytes) on the XPU device at `device_index`. Falls back to
    `torch.xpu.max_memory_allocated` with a warning when no accurate query API is available.
    """
    if version.parse(torch.__version__).release >= version.parse("2.6").release:
        # The torch.xpu.mem_get_info API is available starting from PyTorch 2.6. It further requires PyTorch
        # built with a SYCL runtime which supports the API to query available device memory. If not available,
        # an exception will be raised. The version of the SYCL runtime used to build PyTorch is reported with
        # print(torch.version.xpu) and corresponds to the version of the Intel DPC++ SYCL compiler. The first
        # version to support the required feature is 20250001.
        try:
            return torch.xpu.mem_get_info(device_index)[0]
        except Exception:
            pass
    elif is_ipex_available():
        ipex_version = version.parse(importlib.metadata.version("intel_extension_for_pytorch"))
        if compare_versions(ipex_version, ">=", "2.5"):
            from intel_extension_for_pytorch.xpu import mem_get_info

            return mem_get_info(device_index)[0]

    warnings.warn(
        "The XPU `mem_get_info` API is available in IPEX version >=2.5 or PyTorch >=2.6. The currently returned "
        "available memory is incorrect. Please consider upgrading your IPEX or PyTorch version."
    )
    return torch.xpu.max_memory_allocated(device_index)
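

# Illustrative usage sketch (for demonstration only, not part of this module); assumes an XPU device
# is present at index 0:
#
#     free_bytes = get_xpu_available_memory(0)
#     print(f"~{free_bytes / 1024**3:.1f} GiB of XPU memory currently free")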