""" |
|
General namespace and dataclass-related classes
|
""" |
|
|
|
import argparse |
|
import copy |
|
import enum |
|
import functools |
|
import logging |
|
import os |
|
import warnings |
|
from collections.abc import Iterable |
|
from contextlib import contextmanager |
|
from dataclasses import dataclass, field |
|
from datetime import timedelta |
|
from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, get_args |
|
|
|
import torch |
|
|
|
from .constants import ( |
|
BETA_TP_AVAILABLE_PYTORCH_VERSION, |
|
FSDP2_PYTORCH_VERSION, |
|
FSDP_AUTO_WRAP_POLICY, |
|
FSDP_BACKWARD_PREFETCH, |
|
FSDP_SHARDING_STRATEGY, |
|
MITA_PROFILING_AVAILABLE_PYTORCH_VERSION, |
|
XPU_PROFILING_AVAILABLE_PYTORCH_VERSION, |
|
) |
|
from .environment import parse_flag_from_env, str_to_bool |
|
from .imports import ( |
|
is_cuda_available, |
|
is_hpu_available, |
|
is_mlu_available, |
|
is_msamp_available, |
|
is_musa_available, |
|
is_npu_available, |
|
is_transformer_engine_available, |
|
is_xpu_available, |
|
) |
|
from .versions import compare_versions, is_torch_version |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
|
from torchao.float8 import Float8LinearConfig |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class KwargsHandler: |
|
""" |
|
Internal mixin that implements a `to_kwargs()` method for a dataclass. |
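
A minimal sketch of the behavior, using the `GradScalerKwargs` subclass defined later in this module:

```python

from accelerate.utils import GradScalerKwargs

kwargs = GradScalerKwargs(backoff_factor=0.25)

kwargs.to_kwargs()  # only non-default values are returned: {"backoff_factor": 0.25}

```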
|
""" |
|
|
|
def to_dict(self): |
|
return copy.deepcopy(self.__dict__) |
|
|
|
def to_kwargs(self): |
|
""" |
|
Returns a dictionary containing the attributes with values different from the default of this class. |
|
""" |
|
|
|
from .environment import clear_environment |
|
|
|
with clear_environment(): |
|
default_dict = self.__class__().to_dict() |
|
this_dict = self.to_dict() |
|
return {k: v for k, v in this_dict.items() if default_dict[k] != v} |
|
|
|
|
|
class EnumWithContains(enum.EnumMeta): |
|
"A metaclass that adds the ability to check if `self` contains an item with the `in` operator" |
|
|
|
def __contains__(cls, item): |
|
try: |
|
cls(item) |
|
except ValueError: |
|
return False |
|
return True |
|
|
|
|
|
class BaseEnum(enum.Enum, metaclass=EnumWithContains): |
|
"An enum class that can get the value of an item with `str(Enum.key)`" |
|
|
|
def __str__(self): |
|
return self.value |
|
|
|
@classmethod |
|
def list(cls): |
|
"Method to list all the possible items in `cls`" |
|
return list(map(str, cls)) |
|
|
|
|
|
@dataclass |
|
class AutocastKwargs(KwargsHandler): |
|
""" |
|
Use this object in your [`Accelerator`] to customize how `torch.autocast` behaves. Please refer to the |
|
documentation of this [context manager](https://pytorch.org/docs/stable/amp.html#torch.autocast) for more |
|
information on each argument. |
|
|
|
Example: |
|
|
|
```python |
|
from accelerate import Accelerator |
|
from accelerate.utils import AutocastKwargs |
|
|
|
kwargs = AutocastKwargs(cache_enabled=True) |
|
accelerator = Accelerator(kwargs_handlers=[kwargs]) |
|
``` |
|
""" |
|
|
|
enabled: bool = True |
|
cache_enabled: bool = None |
|
|
|
|
|
class DDPCommunicationHookType(BaseEnum): |
|
""" |
|
Represents a type of communication hook used in DDP. |
|
|
|
Values: |
|
|
|
- **NO** -- no communication hook |
|
- **FP16** -- DDP communication hook to compress the gradients in FP16 |
|
- **BF16** -- DDP communication hook to compress the gradients in BF16 |
|
- **POWER_SGD** -- DDP communication hook to use PowerSGD |
|
- **BATCHED_POWER_SGD** -- DDP communication hook to use batched PowerSGD |
|
""" |
|
|
|
NO = "no" |
|
FP16 = "fp16" |
|
BF16 = "bf16" |
|
POWER_SGD = "power_sgd" |
|
BATCHED_POWER_SGD = "batched_power_sgd" |
|
|
|
|
|
@dataclass |
|
class DistributedDataParallelKwargs(KwargsHandler): |
|
""" |
|
Use this object in your [`Accelerator`] to customize how your model is wrapped in a |
|
`torch.nn.parallel.DistributedDataParallel`. Please refer to the documentation of this |
|
[wrapper](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) for more |
|
information on each argument. |
|
|
|
<Tip warning={true}> |
|
|
|
`gradient_as_bucket_view` is only available in PyTorch 1.7.0 and later versions. |
|
|
|
`static_graph` is only available in PyTorch 1.11.0 and later versions. |
|
|
|
</Tip> |
|
|
|
Example: |
|
|
|
```python |
|
from accelerate import Accelerator |
|
from accelerate.utils import DistributedDataParallelKwargs |
|
|
|
kwargs = DistributedDataParallelKwargs(find_unused_parameters=True) |
|
accelerator = Accelerator(kwargs_handlers=[kwargs]) |
|
``` |
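
A sketch of enabling a gradient-compression communication hook; the setup otherwise mirrors the example above and assumes a multi-GPU DDP run:

```python

from accelerate import Accelerator

from accelerate.utils import DDPCommunicationHookType, DistributedDataParallelKwargs

kwargs = DistributedDataParallelKwargs(comm_hook=DDPCommunicationHookType.FP16)

accelerator = Accelerator(kwargs_handlers=[kwargs])

```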
|
""" |
|
|
|
dim: int = 0 |
|
broadcast_buffers: bool = True |
|
bucket_cap_mb: int = 25 |
|
find_unused_parameters: bool = False |
|
check_reduction: bool = False |
|
gradient_as_bucket_view: bool = False |
|
static_graph: bool = False |
|
|
|
comm_hook: DDPCommunicationHookType = DDPCommunicationHookType.NO |
|
comm_wrapper: Literal[ |
|
DDPCommunicationHookType.NO, DDPCommunicationHookType.FP16, DDPCommunicationHookType.BF16 |
|
] = DDPCommunicationHookType.NO |
|
comm_state_option: dict = field(default_factory=dict) |
|
|
|
def to_dict(self, ignore_keys=("comm_hook", "comm_wrapper", "comm_state_option")): |
|
return {k: v for k, v in super().to_dict().items() if k not in ignore_keys} |
|
|
|
def register_comm_hook(self, model): |
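"""Registers the configured DDP communication hook (and optional compression wrapper) on the given DDP-wrapped model."""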
|
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks, powerSGD_hook |
|
|
|
hook_map: dict[DDPCommunicationHookType, Callable] = { |
|
DDPCommunicationHookType.FP16: default_hooks.fp16_compress_hook, |
|
DDPCommunicationHookType.BF16: default_hooks.bf16_compress_hook, |
|
DDPCommunicationHookType.POWER_SGD: powerSGD_hook.powerSGD_hook, |
|
DDPCommunicationHookType.BATCHED_POWER_SGD: powerSGD_hook.batched_powerSGD_hook, |
|
} |
|
|
|
wrapper_map: dict[DDPCommunicationHookType, Callable] = { |
|
DDPCommunicationHookType.FP16: default_hooks.fp16_compress_wrapper, |
|
DDPCommunicationHookType.BF16: default_hooks.bf16_compress_wrapper, |
|
} |
|
|
|
hook: Optional[Callable] = hook_map.get(self.comm_hook) |
|
wrapper: Optional[Callable] = wrapper_map.get(self.comm_wrapper) |
|
|
|
if hook and wrapper: |
|
hook = wrapper(hook) |
|
|
|
if hook: |
|
state = ( |
|
powerSGD_hook.PowerSGDState(None, **self.comm_state_option) |
|
if self.comm_hook in (DDPCommunicationHookType.POWER_SGD, DDPCommunicationHookType.BATCHED_POWER_SGD) |
|
else None |
|
) |
|
model.register_comm_hook( |
|
state=state, |
|
hook=hook, |
|
) |
|
|
|
|
|
@dataclass |
|
class GradScalerKwargs(KwargsHandler): |
|
""" |
|
Use this object in your [`Accelerator`] to customize the behavior of mixed precision, specifically how the |
|
`torch.amp.GradScaler` or `torch.cuda.amp.GradScaler` used is created. Please refer to the documentation of this |
|
[scaler](https://pytorch.org/docs/stable/amp.html?highlight=gradscaler) for more information on each argument. |
|
|
|
<Tip warning={true}> |
|
|
|
`torch.cuda.amp.GradScaler` is only available in PyTorch 1.5.0 and later versions, and `torch.amp.GradScaler` is |
|
only available in PyTorch 2.4.0 and later versions. |
|
|
|
</Tip> |
|
|
|
Example: |
|
|
|
```python |
|
from accelerate import Accelerator |
|
from accelerate.utils import GradScalerKwargs |
|
|
|
kwargs = GradScalerKwargs(backoff_factor=0.25) |
|
accelerator = Accelerator(kwargs_handlers=[kwargs]) |
|
``` |
|
""" |
|
|
|
init_scale: float = 65536.0 |
|
growth_factor: float = 2.0 |
|
backoff_factor: float = 0.5 |
|
growth_interval: int = 2000 |
|
enabled: bool = True |
|
|
|
|
|
@dataclass |
|
class InitProcessGroupKwargs(KwargsHandler): |
|
""" |
|
Use this object in your [`Accelerator`] to customize the initialization of the distributed processes. Please refer |
|
to the documentation of this |
|
[method](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more |
|
information on each argument. |
|
|
|
Note: If `timeout` is set to `None`, the default will be based upon how `backend` is set. |
|
|
|
```python |
|
from datetime import timedelta |
|
from accelerate import Accelerator |
|
from accelerate.utils import InitProcessGroupKwargs |
|
|
|
kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=800)) |
|
accelerator = Accelerator(kwargs_handlers=[kwargs]) |
|
``` |
|
""" |
|
|
|
backend: Optional[str] = "nccl" |
|
init_method: Optional[str] = None |
|
timeout: Optional[timedelta] = None |
|
|
|
def __post_init__(self): |
|
if self.timeout is None: |
|
seconds = 1800 if self.backend != "nccl" else 600 |
|
self.timeout = timedelta(seconds=seconds) |
|
|
|
|
|
|
|
Backend = Literal["MSAMP", "TE"] |
|
OptLevel = Literal["O1", "O2"] |
|
FP8Format = Literal["HYBRID", "E4M3", "E5M2"] |
|
AmaxComputeAlgorithm = Literal["max", "most_recent"] |
|
|
|
|
|
|
|
@dataclass |
|
class AORecipeKwargs(KwargsHandler): |
|
""" |
|
Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision |
|
training with `torchao` FP8. |
|
|
|
Args: |
|
config (`torchao.float8.Float8LinearConfig`, *optional*, defaults to `None`):

The configuration for the FP8 training. In general, the default config should be sufficient.

module_filter_func (`Callable`, *optional*, defaults to `None`):

Optional function that takes in a module and its layer name and returns a boolean indicating whether the

module should be converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See it for an

example.
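
Example (a minimal sketch; assumes `torchao` is installed and the hardware supports FP8):

```python

from accelerate import Accelerator

from accelerate.utils import AORecipeKwargs

kwargs = AORecipeKwargs()

accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])

```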
|
""" |
|
|
|
config: Optional["Float8LinearConfig"] = None |
|
module_filter_func: Optional[Callable] = None |
|
|
|
|
|
@dataclass |
|
class TERecipeKwargs(KwargsHandler): |
|
""" |
|
Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision |
|
training with `transformer-engine`. |
|
|
|
<Tip> |
|
|
|
For more information on the args, please refer to the API |
|
[documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html). |
|
|
|
</Tip> |
|
|
|
```python |
|
from accelerate import Accelerator |
|
from accelerate.utils import TERecipeKwargs |
|
|
|
kwargs = TERecipeKwargs(fp8_format="HYBRID") |
|
accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs]) |
|
``` |
|
|
|
Args: |
|
use_autocast_during_eval (`bool`, *optional*, defaults to `False`):

Whether to use FP8 autocast during eval mode. Generally better metrics are found when this is `False`.

margin (`int`, *optional*, defaults to `0`):

The margin to use for the gradient scaling.

interval (`int`, *optional*, defaults to `1`):

The interval to use for how often the scaling factor is recomputed.

fp8_format (`str`, *optional*, defaults to `"HYBRID"`):

The format to use for the FP8 recipe. Must be one of `HYBRID`, `E4M3` or `E5M2`. (Generally `HYBRID` is used

for training and `E4M3` or `E5M2` for evaluation.)

amax_history_len (`int`, *optional*, defaults to `1024`):

The length of the history to use for the scaling factor computation.

amax_compute_algo (`str`, *optional*, defaults to `"most_recent"`):

The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.

override_linear_precision (`tuple` of three `bool`, *optional*, defaults to `(False, False, False)`):

Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMs in higher precision.
|
""" |
|
|
|
use_autocast_during_eval: bool = None |
|
margin: int = None |
|
interval: int = None |
|
fp8_format: FP8Format = None |
|
amax_history_len: int = None |
|
amax_compute_algo: AmaxComputeAlgorithm = None |
|
override_linear_precision: tuple[bool, bool, bool] = None |
|
|
|
def __post_init__(self): |
|
env_prefix = "ACCELERATE_FP8_" |
|
if not is_transformer_engine_available(): |
|
raise ImportError("TransformerEngine is not available. Please install it or use a different backend.") |
|
if self.use_autocast_during_eval is None: |
|
self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL") |
|
if self.margin is None: |
|
self.margin = int(os.environ.get(env_prefix + "MARGIN", 0)) |
|
if self.interval is None: |
|
self.interval = int(os.environ.get(env_prefix + "INTERVAL", 1)) |
|
if self.fp8_format is None: |
|
self.fp8_format = os.environ.get(env_prefix + "FORMAT", "HYBRID") |
|
self.fp8_format = self.fp8_format.upper() |
|
if self.fp8_format not in get_args(FP8Format): |
|
raise ValueError(f"`fp8_format` must be one of {' or '.join(get_args(FP8Format))}.") |
|
if self.amax_compute_algo is None: |
|
self.amax_compute_algo = os.environ.get(env_prefix + "AMAX_COMPUTE_ALGO", "most_recent") |
|
self.amax_compute_algo = self.amax_compute_algo.lower() |
|
if self.amax_compute_algo not in get_args(AmaxComputeAlgorithm): |
|
raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(get_args(AmaxComputeAlgorithm))}") |
|
if self.amax_history_len is None: |
|
self.amax_history_len = int(os.environ.get(env_prefix + "AMAX_HISTORY_LEN", 1024)) |
|
if self.override_linear_precision is None: |
|
fprop = parse_flag_from_env(env_prefix + "OVERRIDE_FPROP") |
|
dgrad = parse_flag_from_env(env_prefix + "OVERRIDE_DGRAD") |
|
wgrad = parse_flag_from_env(env_prefix + "OVERRIDE_WGRAD") |
|
self.override_linear_precision = (fprop, dgrad, wgrad) |
|
|
|
|
|
@dataclass |
|
class MSAMPRecipeKwargs(KwargsHandler): |
|
""" |
|
Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision |
|
training with `ms-amp`. |
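
Example (a minimal sketch; assumes `ms-amp` is installed):

```python

from accelerate import Accelerator

from accelerate.utils import MSAMPRecipeKwargs

kwargs = MSAMPRecipeKwargs(opt_level="O2")

accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])

```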
|
""" |
|
|
|
opt_level: OptLevel = None |
|
|
|
def __post_init__(self): |
|
env_prefix = "ACCELERATE_FP8_" |
|
if self.opt_level is None: |
|
self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2") |
|
if self.opt_level not in get_args(OptLevel): |
|
raise ValueError(f"`opt_level` must be one of {' or '.join(get_args(OptLevel))}") |
|
|
|
|
|
@dataclass |
|
class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs): |
|
""" |
|
Deprecated. Please use one of the proper FP8 recipe kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs` |
|
instead. |
|
""" |
|
|
|
backend: Backend = None |
|
|
|
def __post_init__(self): |
|
env_prefix = "ACCELERATE_FP8_" |
|
warnings.warn( |
|
"FP8RecipeKwargs is deprecated and will be removed in Accelerate v2.0.0. " |
|
"Please use one of the proper FP8 recipe kwargs classes such as TERecipeKwargs or MSAMPRecipeKwargs instead.", |
|
FutureWarning, |
|
) |
|
default_backend = "msamp" if is_msamp_available() else "te" |
|
if self.backend is None: |
|
self.backend = os.environ.get(env_prefix + "BACKEND", default_backend) |
|
self.backend = self.backend.upper() |
|
if self.backend not in get_args(Backend): |
|
raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine) to use `FP8RecipeKwargs`.") |
|
super().__post_init__() |
|
|
|
|
|
|
|
ProfilerActivity = Literal["cpu", "xpu", "mtia", "cuda", "hpu"] |
|
|
|
|
|
@dataclass |
|
class ProfileKwargs(KwargsHandler): |
|
""" |
|
Use this object in your [`Accelerator`] to customize the initialization of the profiler. Please refer to the |
|
documentation of this [context manager](https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile) for |
|
more information on each argument. |
|
|
|
<Tip warning={true}> |
|
|
|
`torch.profiler` is only available in PyTorch 1.8.1 and later versions. |
|
|
|
</Tip> |
|
|
|
Example: |
|
|
|
```python |
|
from accelerate import Accelerator |
|
from accelerate.utils import ProfileKwargs |
|
|
|
kwargs = ProfileKwargs(activities=["cpu", "cuda"]) |
|
accelerator = Accelerator(kwargs_handlers=[kwargs]) |
|
``` |
|
|
|
Args: |
|
activities (`List[str]`, *optional*, defaults to `None`):

The list of activity groups to use in profiling. Must be one of `"cpu"`, `"xpu"`, `"mtia"`, `"hpu"` or

`"cuda"`.

schedule_option (`Dict[str, int]`, *optional*, defaults to `None`):

The schedule option to use for the profiler. Available keys are `wait`, `warmup`, `active`, `repeat` and

`skip_first`. The profiler will skip the first `skip_first` steps, then wait for `wait` steps, then do the

warmup for the next `warmup` steps, then do the active recording for the next `active` steps and then

repeat the cycle starting with `wait` steps. The optional number of cycles is specified with the `repeat`

parameter; a value of zero means the cycles will continue until the profiling is finished.

on_trace_ready (`Callable`, *optional*, defaults to `None`):

Callable that is called at each step when the schedule returns `ProfilerAction.RECORD_AND_SAVE` during

profiling.

record_shapes (`bool`, *optional*, defaults to `False`):

Save information about the operators' input shapes.

profile_memory (`bool`, *optional*, defaults to `False`):

Track tensor memory allocation/deallocation.

with_stack (`bool`, *optional*, defaults to `False`):

Record source information (file and line number) for the ops.

with_flops (`bool`, *optional*, defaults to `False`):

Use a formula to estimate the FLOPs of specific operators.

with_modules (`bool`, *optional*, defaults to `False`):

Record module hierarchy (including function names) corresponding to the callstack of the op.

output_trace_dir (`str`, *optional*, defaults to `None`):

Exports the collected trace in Chrome JSON format, which can be viewed in Chrome at 'chrome://tracing'.

Defaults to `None`, which means profiling does not store JSON files.
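
A fuller sketch with a schedule and trace export; it assumes the `Accelerator.profile` context manager is

available in your version of Accelerate and that `dataloader` is an already-prepared dataloader:

```python

from accelerate import Accelerator

from accelerate.utils import ProfileKwargs

kwargs = ProfileKwargs(

    activities=["cpu", "cuda"],

    schedule_option={"wait": 1, "warmup": 1, "active": 3, "repeat": 2, "skip_first": 1},

    output_trace_dir="traces",  # hypothetical output directory

)

accelerator = Accelerator(kwargs_handlers=[kwargs])

with accelerator.profile() as prof:

    for batch in dataloader:

        ...  # training step

        prof.step()  # advance the profiling schedule

```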
|
""" |
|
|
|
activities: Optional[list[ProfilerActivity]] = None |
|
schedule_option: Optional[dict[str, int]] = None |
|
on_trace_ready: Optional[Callable] = None |
|
record_shapes: bool = False |
|
profile_memory: bool = False |
|
with_stack: bool = False |
|
with_flops: bool = False |
|
with_modules: bool = False |
|
output_trace_dir: Optional[str] = None |
|
|
|
def _get_profiler_activity(self, activity: ProfilerActivity) -> torch.profiler.ProfilerActivity: |
|
"""Get the profiler activity from the string. |
|
|
|
Args: |
|
activity (str): The profiler activity name. |
|
|
|
Returns: |
|
torch.profiler.ProfilerActivity: The profiler activity. |
|
""" |
|
|
|
profiler_activity_map: dict[str, torch.profiler.ProfilerActivity] = { |
|
"cpu": torch.profiler.ProfilerActivity.CPU, |
|
"cuda": torch.profiler.ProfilerActivity.CUDA, |
|
} |
|
|
|
if is_hpu_available(): |
|
profiler_activity_map["hpu"] = torch.profiler.ProfilerActivity.HPU |
|
|
|
if is_torch_version(">=", XPU_PROFILING_AVAILABLE_PYTORCH_VERSION): |
|
if torch.xpu.is_available(): |
|
profiler_activity_map["xpu"] = torch.profiler.ProfilerActivity.XPU |
|
|
|
if is_torch_version(">=", MITA_PROFILING_AVAILABLE_PYTORCH_VERSION): |
|
if torch.mtia.is_available(): |
|
profiler_activity_map["mtia"] = torch.profiler.ProfilerActivity.MTIA |
|
|
|
if activity not in profiler_activity_map: |
|
raise ValueError(f"Invalid profiler activity: {activity}. Must be one of {list(profiler_activity_map)}.") |
|
return profiler_activity_map[activity] |
|
|
|
def build(self) -> torch.profiler.profile: |
|
""" |
|
Build a profiler object with the current configuration. |
|
|
|
Returns: |
|
torch.profiler.profile: The profiler object. |
|
""" |
|
activities: Optional[list[ProfilerActivity]] = None |
|
if self.activities is not None: |
|
activities = [self._get_profiler_activity(activity) for activity in self.activities] |
|
schedule: Optional[torch.profiler.schedule] = None |
|
if self.schedule_option is not None: |
|
schedule = torch.profiler.schedule(**self.schedule_option) |
|
|
|
return torch.profiler.profile( |
|
activities=activities, |
|
schedule=schedule, |
|
on_trace_ready=self.on_trace_ready, |
|
record_shapes=self.record_shapes, |
|
profile_memory=self.profile_memory, |
|
with_stack=self.with_stack, |
|
with_flops=self.with_flops, |
|
with_modules=self.with_modules, |
|
) |
|
|
|
|
|
class DistributedType(str, enum.Enum): |
|
""" |
|
Represents a type of distributed environment. |
|
|
|
Values: |
|
|
|
- **NO** -- Not a distributed environment, just a single process. |
|
- **MULTI_CPU** -- Distributed on multiple CPU nodes. |
|
- **MULTI_GPU** -- Distributed on multiple GPUs. |
|
- **MULTI_MLU** -- Distributed on multiple MLUs. |
|
- **MULTI_SDAA** -- Distributed on multiple SDAAs. |
|
- **MULTI_MUSA** -- Distributed on multiple MUSAs. |
|
- **MULTI_NPU** -- Distributed on multiple NPUs. |
|
- **MULTI_XPU** -- Distributed on multiple XPUs. |
|
- **MULTI_HPU** -- Distributed on multiple HPUs. |
|
- **DEEPSPEED** -- Using DeepSpeed. |
|
- **FSDP** -- Using FullyShardedDataParallel.

- **TP** -- Using tensor parallelism.

- **XLA** -- Using TorchXLA.

- **MEGATRON_LM** -- Using Megatron-LM.
|
""" |
|
|
|
|
|
NO = "NO" |
|
MULTI_CPU = "MULTI_CPU" |
|
MULTI_GPU = "MULTI_GPU" |
|
MULTI_NPU = "MULTI_NPU" |
|
MULTI_MLU = "MULTI_MLU" |
|
MULTI_SDAA = "MULTI_SDAA" |
|
MULTI_MUSA = "MULTI_MUSA" |
|
MULTI_XPU = "MULTI_XPU" |
|
DEEPSPEED = "DEEPSPEED" |
|
FSDP = "FSDP" |
|
TP = "TP" |
|
XLA = "XLA" |
|
MEGATRON_LM = "MEGATRON_LM" |
|
MULTI_HPU = "MULTI_HPU" |
|
|
|
|
|
class SageMakerDistributedType(str, enum.Enum): |
|
""" |
|
Represents a type of distributed environment. |
|
|
|
Values: |
|
|
|
- **NO** -- Not a distributed environment, just a single process. |
|
- **DATA_PARALLEL** -- using sagemaker distributed data parallelism. |
|
- **MODEL_PARALLEL** -- using sagemaker distributed model parallelism. |
|
""" |
|
|
|
|
|
NO = "NO" |
|
DATA_PARALLEL = "DATA_PARALLEL" |
|
MODEL_PARALLEL = "MODEL_PARALLEL" |
|
|
|
|
|
class FP8BackendType(str, enum.Enum): |
|
""" |
|
Represents the backend used for FP8. |
|
|
|
Values: |
|
|
|
- **TE** -- using TransformerEngine. |
|
- **MSAMP** -- using msamp. |
|
""" |
|
|
|
|
|
TE = "TE" |
|
MSAMP = "MSAMP" |
|
|
|
|
|
class ComputeEnvironment(str, enum.Enum): |
|
""" |
|
Represents a type of the compute environment. |
|
|
|
Values: |
|
|
|
- **LOCAL_MACHINE** -- private/custom cluster hardware. |
|
- **AMAZON_SAGEMAKER** -- Amazon SageMaker as compute environment. |
|
""" |
|
|
|
|
|
LOCAL_MACHINE = "LOCAL_MACHINE" |
|
AMAZON_SAGEMAKER = "AMAZON_SAGEMAKER" |
|
|
|
|
|
class DynamoBackend(str, BaseEnum): |
|
""" |
|
Represents a dynamo backend (see https://pytorch.org/docs/stable/torch.compiler.html). |
|
|
|
Values: |
|
|
|
- **NO** -- Do not use torch dynamo. |
|
- **EAGER** -- Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo |
|
issues. |
|
- **AOT_EAGER** -- Uses AotAutograd with no compiler, i.e., just using PyTorch eager for the AotAutograd's
|
extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups. |
|
- **INDUCTOR** -- Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton |
|
kernels. [Read |
|
more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747) |
|
- **AOT_TS_NVFUSER** -- nvFuser with AotAutograd/TorchScript. [Read |
|
more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593) |
|
- **NVPRIMS_NVFUSER** -- nvFuser with PrimTorch. [Read |
|
more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593) |
|
- **CUDAGRAPHS** -- cudagraphs with AotAutograd. [Read more](https://github.com/pytorch/torchdynamo/pull/757) |
|
- **OFI** -- Uses Torchscript optimize_for_inference. Inference only. [Read |
|
more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html) |
|
- **FX2TRT** -- Uses Nvidia TensorRT for inference optimizations. Inference only. [Read |
|
more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst) |
|
- **ONNXRT** -- Uses ONNXRT for inference on CPU/GPU. Inference only. [Read more](https://onnxruntime.ai/) |
|
- **TENSORRT** -- Uses ONNXRT to run TensorRT for inference optimizations. [Read |
|
more](https://github.com/onnx/onnx-tensorrt) |
|
- **AOT_TORCHXLA_TRACE_ONCE** -- Uses Pytorch/XLA with TorchDynamo optimization, for training. [Read |
|
more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md) |
|
- **TORCHXLA_TRACE_ONCE** -- Uses Pytorch/XLA with TorchDynamo optimization, for inference. [Read |
|
more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md) |
|
- **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read |
|
more](https://github.com/intel/intel-extension-for-pytorch). |
|
- **TVM** -- Uses Apache TVM for inference optimizations. [Read more](https://tvm.apache.org/)
|
- **HPU_BACKEND** -- Uses HPU backend for inference optimizations. |
|
|
|
""" |
|
|
|
|
|
NO = "NO" |
|
EAGER = "EAGER" |
|
AOT_EAGER = "AOT_EAGER" |
|
INDUCTOR = "INDUCTOR" |
|
AOT_TS_NVFUSER = "AOT_TS_NVFUSER" |
|
NVPRIMS_NVFUSER = "NVPRIMS_NVFUSER" |
|
CUDAGRAPHS = "CUDAGRAPHS" |
|
OFI = "OFI" |
|
FX2TRT = "FX2TRT" |
|
ONNXRT = "ONNXRT" |
|
TENSORRT = "TENSORRT" |
|
AOT_TORCHXLA_TRACE_ONCE = "AOT_TORCHXLA_TRACE_ONCE" |
|
TORCHXLA_TRACE_ONCE = "TORCHXLA_TRACE_ONCE" |
|
IPEX = "IPEX" |
|
TVM = "TVM" |
|
HPU_BACKEND = "HPU_BACKEND" |
|
|
|
|
|
class LoggerType(BaseEnum): |
|
"""Represents a type of supported experiment tracker |
|
|
|
Values: |
|
|
|
- **ALL** -- all available trackers in the environment that are supported

- **AIM** -- aim as an experiment tracker

- **TENSORBOARD** -- TensorBoard as an experiment tracker

- **WANDB** -- wandb as an experiment tracker

- **COMETML** -- comet_ml as an experiment tracker

- **MLFLOW** -- mlflow as an experiment tracker

- **CLEARML** -- clearml as an experiment tracker

- **DVCLIVE** -- dvclive as an experiment tracker

- **SWANLAB** -- swanlab as an experiment tracker
|
""" |
|
|
|
ALL = "all" |
|
AIM = "aim" |
|
TENSORBOARD = "tensorboard" |
|
WANDB = "wandb" |
|
COMETML = "comet_ml" |
|
MLFLOW = "mlflow" |
|
CLEARML = "clearml" |
|
DVCLIVE = "dvclive" |
|
SWANLAB = "swanlab" |
|
|
|
|
|
class PrecisionType(str, BaseEnum): |
|
"""Represents a type of precision used on floating point values |
|
|
|
Values: |
|
|
|
- **NO** -- using full precision (FP32) |
|
- **FP8** -- using 8-bit floating point precision

- **FP16** -- using half precision
|
- **BF16** -- using brain floating point precision |
|
""" |
|
|
|
NO = "no" |
|
FP8 = "fp8" |
|
FP16 = "fp16" |
|
BF16 = "bf16" |
|
|
|
|
|
class RNGType(BaseEnum): |
|
TORCH = "torch" |
|
CUDA = "cuda" |
|
MLU = "mlu" |
|
SDAA = "sdaa" |
|
MUSA = "musa" |
|
NPU = "npu" |
|
XLA = "xla" |
|
XPU = "xpu" |
|
HPU = "hpu" |
|
GENERATOR = "generator" |
|
|
|
|
|
class CustomDtype(enum.Enum): |
|
r""" |
|
An enum that contains multiple custom dtypes that can be used for `infer_auto_device_map`. |
|
""" |
|
|
|
FP8 = "fp8" |
|
INT4 = "int4" |
|
INT2 = "int2" |
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
class TensorInformation: |
|
shape: torch.Size |
|
dtype: torch.dtype |
|
|
|
|
|
@dataclass |
|
class DataLoaderConfiguration: |
|
""" |
|
Configuration for dataloader-related items when calling `accelerator.prepare`. |
|
|
|
Args: |
|
split_batches (`bool`, defaults to `False`): |
|
Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If |
|
`True`, the actual batch size used will be the same on any kind of distributed processes, but it must be a |
|
round multiple of `num_processes` you are using. If `False`, actual batch size used will be the one set in |
|
your script multiplied by the number of processes. |
|
dispatch_batches (`bool`, defaults to `None`): |
|
If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process |
|
and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose |
|
underlying dataset is an `IterableDataset`, `False` otherwise. |
|
even_batches (`bool`, defaults to `True`): |
|
If set to `True`, in cases where the total batch size across all processes does not exactly divide the |
|
dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among |
|
all workers. |
|
use_seedable_sampler (`bool`, defaults to `False`): |
|
Whether or not to use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]). Ensures

training results are fully reproducible when using a different sampling technique. While seed-to-seed

results may differ, on average the differences are negligible when using multiple different seeds to

compare. Should also be run with [`~utils.set_seed`] for the best results.
|
data_seed (`int`, defaults to `None`): |
|
The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator |
|
will use the current default seed from torch. |
|
non_blocking (`bool`, defaults to `False`): |
|
If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device |
|
transfers, allowing for better overlap between dataloader communication and computation. Recommended that |
|
the prepared dataloader has `pin_memory` set to `True` to work properly. |
|
use_stateful_dataloader (`bool`, defaults to `False`): |
|
If set to `True`, the dataloader prepared by the Accelerator will be backed by |
|
[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). |
|
This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed. |
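
Example (a minimal sketch):

```python

from accelerate import Accelerator

from accelerate.utils import DataLoaderConfiguration

dataloader_config = DataLoaderConfiguration(split_batches=True, even_batches=False)

accelerator = Accelerator(dataloader_config=dataloader_config)

```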
|
""" |
|
|
|
split_batches: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If" |
|
" `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a" |
|
" round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set" |
|
" in your script multiplied by the number of processes." |
|
}, |
|
) |
|
dispatch_batches: bool = field( |
|
default=None, |
|
metadata={ |
|
"help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process" |
|
" and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose" |
|
" underlying dataset is an `IterableDataset`, `False` otherwise." |
|
}, |
|
) |
|
even_batches: bool = field( |
|
default=True, |
|
metadata={ |
|
"help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the" |
|
" dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among" |
|
" all workers." |
|
}, |
|
) |
|
use_seedable_sampler: bool = field(

default=False,

metadata={

"help": "Whether or not to use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]). "

"Ensures training results are fully reproducible when using a different sampling technique. "

"While seed-to-seed results may differ, on average the differences are negligible when using "

"multiple different seeds to compare. Should also be run with [`~utils.set_seed`] for the best results."
|
}, |
|
) |
|
data_seed: int = field( |
|
default=None, |
|
metadata={ |
|
"help": "The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator" |
|
" will use the current default seed from torch." |
|
}, |
|
) |
|
non_blocking: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": "If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device" |
|
" transfers, allowing for better overlap between dataloader communication and computation. Recommended that the" |
|
" prepared dataloader has `pin_memory` set to `True` to work properly." |
|
}, |
|
) |
|
use_stateful_dataloader: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by " |
|
"[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed." |
|
}, |
|
) |
|
|
|
|
|
@dataclass |
|
class ProjectConfiguration: |
|
""" |
|
Configuration for the Accelerator object based on inner-project needs. |
|
|
|
Args: |
|
project_dir (`str`, defaults to `None`): |
|
A path to a directory for storing data. |
|
logging_dir (`str`, defaults to `None`): |
|
A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`. |
|
automatic_checkpoint_naming (`bool`, defaults to `False`): |
|
Whether saved states should be automatically iteratively named. |
|
total_limit (`int`, defaults to `None`): |
|
The maximum number of total saved states to keep. |
|
iteration (`int`, defaults to `0`): |
|
The current save iteration. |
|
save_on_each_node (`bool`, defaults to `False`): |
|
When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on |
|
the main one. |
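
Example (a minimal sketch; `"outputs"` is a hypothetical directory):

```python

from accelerate import Accelerator

from accelerate.utils import ProjectConfiguration

config = ProjectConfiguration(project_dir="outputs", automatic_checkpoint_naming=True)

accelerator = Accelerator(project_config=config)

```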
|
""" |
|
|
|
project_dir: str = field(default=None, metadata={"help": "A path to a directory for storing data."}) |
|
logging_dir: str = field( |
|
default=None, |
|
metadata={ |
|
"help": "A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`." |
|
}, |
|
) |
|
automatic_checkpoint_naming: bool = field( |
|
default=False, |
|
metadata={"help": "Whether saved states should be automatically iteratively named."}, |
|
) |
|
|
|
total_limit: int = field( |
|
default=None, |
|
metadata={"help": "The maximum number of total saved states to keep."}, |
|
) |
|
|
|
iteration: int = field( |
|
default=0, |
|
metadata={"help": "The current save iteration."}, |
|
) |
|
|
|
save_on_each_node: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": ( |
|
"When doing multi-node distributed training, whether to save models and checkpoints on each node, or" |
|
" only on the main one" |
|
) |
|
}, |
|
) |
|
|
|
def set_directories(self, project_dir: str = None): |
|
"Sets `self.project_dir` and `self.logging_dir` to the appropriate values." |
|
self.project_dir = project_dir |
|
if self.logging_dir is None: |
|
self.logging_dir = project_dir |
|
|
|
def __post_init__(self): |
|
self.set_directories(self.project_dir) |
|
|
|
|
|
@dataclass |
|
class GradientAccumulationPlugin(KwargsHandler): |
|
""" |
|
A plugin to configure gradient accumulation behavior. You can only pass one of `gradient_accumulation_plugin` or |
|
`gradient_accumulation_steps` to [`Accelerator`]. Passing both raises an error. |
|
|
|
Parameters: |
|
num_steps (`int`): |
|
The number of steps to accumulate gradients for. |
|
adjust_scheduler (`bool`, *optional*, defaults to `True`): |
|
Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be |
|
`True` if the used scheduler was not adjusted for gradient accumulation. |
|
sync_with_dataloader (`bool`, *optional*, defaults to `True`): |
|
Whether to synchronize setting the gradients when at the end of the dataloader. |
|
sync_each_batch (`bool`, *optional*): |
|
Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory
|
requirements when using gradient accumulation with distributed training, at expense of speed. |
|
|
|
Example: |
|
|
|
```python |
|
from accelerate.utils import GradientAccumulationPlugin |
|
|
|
gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2) |
|
accelerator = Accelerator(gradient_accumulation_plugin=gradient_accumulation_plugin) |
|
``` |
|
""" |
|
|
|
num_steps: int = field(default=None, metadata={"help": "The number of steps to accumulate gradients for."}) |
|
adjust_scheduler: bool = field( |
|
default=True, |
|
metadata={ |
|
"help": "Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be `True` if the used scheduler was not adjusted for gradient accumulation." |
|
}, |
|
) |
|
sync_with_dataloader: bool = field( |
|
default=True, |
|
metadata={ |
|
"help": "Whether to synchronize setting the gradients when at the end of the dataloader. Should only be set to `False` if you know what you're doing." |
|
}, |
|
) |
|
sync_each_batch: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": "Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory requirements when using gradient accumulation with distributed training, at expense of speed." |
|
}, |
|
) |
|
|
|
|
|
@dataclass |
|
class TorchDynamoPlugin(KwargsHandler): |
|
""" |
|
This plugin is used to compile a model with PyTorch 2.0 |
|
|
|
Args: |
|
backend (`DynamoBackend`, defaults to `None`): |
|
A valid Dynamo backend. See https://pytorch.org/docs/stable/torch.compiler.html for more details. |
|
mode (`str`, defaults to `None`): |
|
Possible options are 'default', 'reduce-overhead' or 'max-autotune'. |
|
fullgraph (`bool`, defaults to `None`): |
|
Whether to require the whole model to be captured into a single graph (disallowing graph breaks).
|
dynamic (`bool`, defaults to `None`): |
|
Whether to use dynamic shape for tracing. |
|
options (`Any`, defaults to `None`): |
|
A dictionary of options to pass to the backend. |
|
disable (`bool`, defaults to `False`): |
|
Turn torch.compile() into a no-op for testing |
|
use_regional_compilation (`bool`, defaults to `None`): |
|
Use it to reduce the cold start compilation time of torch.compile() by targeting repeated blocks of the |
|
same class and compiling them sequentially to hit the compiler's cache. For example, in `GPT2LMHeadModel`, |
|
the repeated block/class is `GPT2Block`, and can be accessed as `model.transformer.h[0]`. The rest of the |
|
model (e.g model.lm_head) is compiled separately. |
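
Example (a minimal sketch; assumes your version of Accelerate accepts a `dynamo_plugin` argument):

```python

from accelerate import Accelerator

from accelerate.utils import TorchDynamoPlugin

dynamo_plugin = TorchDynamoPlugin(backend="inductor", mode="max-autotune")

accelerator = Accelerator(dynamo_plugin=dynamo_plugin)

```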
|
""" |
|
|
|
backend: DynamoBackend = field( |
|
default=None, |
|
metadata={"help": f"Possible options are {[b.value.lower() for b in DynamoBackend]}"}, |
|
) |
|
mode: str = field( |
|
default=None, metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"} |
|
) |
|
fullgraph: bool = field(default=None, metadata={"help": "Whether to require the whole model to be captured into a single graph (disallowing graph breaks)"})
|
dynamic: bool = field(default=None, metadata={"help": "Whether to use dynamic shape for tracing"}) |
|
options: Any = field(default=None, metadata={"help": "A dictionary of options to pass to the backend."}) |
|
disable: bool = field(default=False, metadata={"help": "Turn torch.compile() into a no-op for testing"}) |
|
|
|
use_regional_compilation: bool = field( |
|
default=None, |
|
metadata={ |
|
"help": ( |
|
|
|
"Use it to reduce the cold start compilation time of torch.compile() by targeting repeated " |
|
"blocks of the same class and compiling them sequentially to hit the compiler's cache. For " |
|
"example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be accessed " |
|
"as `model.transformer.h[0]`. The rest of the model (e.g model.lm_head) is compiled separately." |
|
) |
|
}, |
|
) |
|
|
|
def __post_init__(self): |
|
prefix = "ACCELERATE_DYNAMO_" |
|
if self.backend is None: |
|
self.backend = os.environ.get(prefix + "BACKEND", "no") |
|
self.backend = DynamoBackend(self.backend.upper()) |
|
|
|
if self.mode is None: |
|
self.mode = os.environ.get(prefix + "MODE", "default") |
|
if self.fullgraph is None: |
|
self.fullgraph = str_to_bool(os.environ.get(prefix + "USE_FULLGRAPH", "False")) == 1 |
|
if self.use_regional_compilation is None: |
|
self.use_regional_compilation = ( |
|
str_to_bool(os.environ.get(prefix + "USE_REGIONAL_COMPILATION", "False")) == 1 |
|
) |
|
|
|
if self.dynamic is None and os.environ.get(prefix + "USE_DYNAMIC", None) is not None: |
|
self.dynamic = str_to_bool(os.environ.get(prefix + "USE_DYNAMIC", "False")) == 1 |
|
|
|
def to_dict(self): |
|
dynamo_config = copy.deepcopy(self.__dict__) |
|
dynamo_config["backend"] = dynamo_config["backend"].value.lower() |
|
return dynamo_config |
|
|
|
def to_kwargs(self): |
|
kwargs = super().to_kwargs() |
|
kwargs.pop("use_regional_compilation", None) |
|
return kwargs |
|
|
|
|
|
@dataclass |
|
class DeepSpeedPlugin: |
|
""" |
|
This plugin is used to integrate DeepSpeed. |
|
|
|
Args: |
|
hf_ds_config (`Any`, defaults to `None`): |
|
Path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`. |
|
gradient_accumulation_steps (`int`, defaults to `None`): |
|
Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value |
|
from the `Accelerator` directly. |
|
gradient_clipping (`float`, defaults to `None`): |
|
Enable gradient clipping with value. |
|
zero_stage (`int`, defaults to `None`): |
|
Possible options are 0, 1, 2, 3. Default will be taken from environment variable. |
|
is_train_batch_min (`bool`, defaults to `True`): |
|
If both train & eval dataloaders are specified, this will decide the `train_batch_size`. |
|
offload_optimizer_device (`str`, defaults to `None`): |
|
Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3. |
|
offload_param_device (`str`, defaults to `None`): |
|
Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3. |
|
offload_optimizer_nvme_path (`str`, defaults to `None`): |
|
Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3. |
|
offload_param_nvme_path (`str`, defaults to `None`): |
|
Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3. |
|
zero3_init_flag (`bool`, defaults to `None`): |
|
Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models. Only applicable with ZeRO Stage-3.
|
zero3_save_16bit_model (`bool`, defaults to `None`): |
|
Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3. |
|
transformer_moe_cls_names (`str`, defaults to `None`): |
|
Comma-separated list of Transformers MoE layer class names (case-sensitive). For example, |
|
`MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention`, `JetMoEBlock`, etc. |
|
enable_msamp (`bool`, defaults to `None`): |
|
Flag to indicate whether to enable MS-AMP backend for FP8 training. |
|
msamp_opt_level (`Optional[Literal["O1", "O2"]]`, defaults to `None`):
|
Optimization level for MS-AMP (defaults to 'O1'). Only applicable if `enable_msamp` is True. Should be one |
|
of ['O1' or 'O2']. |
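
Example (a minimal sketch; assumes DeepSpeed is installed and the script is launched in a distributed setting):

```python

from accelerate import Accelerator

from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(zero_stage=2, gradient_accumulation_steps=2)

accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)

```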
|
""" |
|
|
|
hf_ds_config: Any = field( |
|
default=None, |
|
metadata={ |
|
"help": "path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`." |
|
}, |
|
) |
|
gradient_accumulation_steps: int = field( |
|
default=None, |
|
metadata={ |
|
"help": "Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value from the `Accelerator` directly." |
|
}, |
|
) |
|
gradient_clipping: float = field(default=None, metadata={"help": "Enable gradient clipping with value"}) |
|
zero_stage: int = field( |
|
default=None, |
|
metadata={"help": "Possible options are 0,1,2,3; Default will be taken from environment variable"}, |
|
) |
|
is_train_batch_min: bool = field( |
|
default=True, |
|
metadata={"help": "If both train & eval dataloaders are specified, this will decide the train_batch_size"}, |
|
) |
|
offload_optimizer_device: str = field( |
|
default=None, |
|
metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."}, |
|
) |
|
offload_param_device: str = field( |
|
default=None, |
|
metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."}, |
|
) |
|
offload_optimizer_nvme_path: str = field( |
|
default=None, |
|
metadata={"help": "Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."}, |
|
) |
|
offload_param_nvme_path: str = field( |
|
default=None, |
|
metadata={"help": "Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."}, |
|
) |
|
zero3_init_flag: bool = field(

default=None,

metadata={

"help": "Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models. "
|
"Only applicable with ZeRO Stage-3." |
|
}, |
|
) |
|
zero3_save_16bit_model: bool = field( |
|
default=None, |
|
metadata={"help": "Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."}, |
|
) |
|
transformer_moe_cls_names: str = field(

default=None,

metadata={

"help": "Comma-separated list of Transformers MoE layer class names (case-sensitive), e.g. "

"`MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention`, `JetMoEBlock` ..."
|
}, |
|
) |
|
enable_msamp: bool = field( |
|
default=None, |
|
metadata={"help": "Flag to indicate whether to enable MS-AMP backend for FP8 training."}, |
|
) |
|
msamp_opt_level: Optional[Literal["O1", "O2"]] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Optimization level for MS-AMP (defaults to 'O1'). Only applicable if `enable_msamp` is True. Should be one of ['O1' or 'O2']." |
|
}, |
|
) |
|
|
|
def __post_init__(self): |
|
from .deepspeed import HfDeepSpeedConfig |
|
|
|
if self.gradient_accumulation_steps is None: |
|
gas = os.environ.get("ACCELERATE_GRADIENT_ACCUMULATION_STEPS", "auto") |
|
self.gradient_accumulation_steps = int(gas) if gas.isdigit() else gas |
|
|
|
if self.gradient_clipping is None: |
|
gradient_clipping = os.environ.get("ACCELERATE_GRADIENT_CLIPPING", "auto") |
|
self.gradient_clipping = gradient_clipping if gradient_clipping == "auto" else float(gradient_clipping) |
|
|
|
if self.zero_stage is None: |
|
self.zero_stage = int(os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", 2)) |
|
|
|
if self.offload_optimizer_device is None: |
|
self.offload_optimizer_device = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", "none") |
|
|
|
if self.offload_param_device is None: |
|
self.offload_param_device = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE", "none") |
|
|
|
if self.offload_optimizer_nvme_path is None: |
|
self.offload_optimizer_nvme_path = os.environ.get( |
|
"ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH", "none" |
|
) |
|
|
|
if self.offload_param_nvme_path is None: |
|
self.offload_param_nvme_path = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH", "none") |
|
|
|
if self.zero3_save_16bit_model is None: |
|
self.zero3_save_16bit_model = ( |
|
os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false") == "true" |
|
) |
|
if self.enable_msamp is None: |
|
self.enable_msamp = os.environ.get("ACCELERATE_FP8_BACKEND", None) == "MSAMP" |
|
|
|
if self.msamp_opt_level is None: |
|
self.msamp_opt_level = os.environ.get("ACCELERATE_FP8_OPT_LEVEL", "O1") |
|
|
|
if self.hf_ds_config is None: |
|
self.hf_ds_config = os.environ.get("ACCELERATE_DEEPSPEED_CONFIG_FILE", "none") |
|
if ( |
|
isinstance(self.hf_ds_config, dict) |
|
or (isinstance(self.hf_ds_config, str) and self.hf_ds_config != "none") |
|
or isinstance(self.hf_ds_config, HfDeepSpeedConfig) |
|
): |
|
if not isinstance(self.hf_ds_config, HfDeepSpeedConfig): |
|
self.hf_ds_config = HfDeepSpeedConfig(self.hf_ds_config) |
|
if "gradient_accumulation_steps" not in self.hf_ds_config.config: |
|
self.hf_ds_config.config["gradient_accumulation_steps"] = 1 |
|
if "zero_optimization" not in self.hf_ds_config.config: |
|
raise ValueError("Please specify the ZeRO optimization config in the DeepSpeed config.") |
|
|
|
self._deepspeed_config_checks() |
|
plugin_to_config_mapping = { |
|
"gradient_accumulation_steps": "gradient_accumulation_steps", |
|
"gradient_clipping": "gradient_clipping", |
|
"zero_stage": "zero_optimization.stage", |
|
"offload_optimizer_device": "zero_optimization.offload_optimizer.device", |
|
"offload_param_device": "zero_optimization.offload_param.device", |
|
"offload_param_nvme_path": "zero_optimization.offload_param.nvme_path", |
|
"offload_optimizer_nvme_path": "zero_optimization.offload_optimizer.nvme_path", |
|
"zero3_save_16bit_model": "zero_optimization.stage3_gather_16bit_weights_on_model_save", |
|
} |
|
kwargs = {v: getattr(self, k) for k, v in plugin_to_config_mapping.items() if getattr(self, k) is not None} |
|
for key in kwargs.keys(): |
|
self.fill_match(key, **kwargs, must_match=False) |
|
self.hf_ds_config.set_stage_and_offload() |
|
|
|
|
|
|
|
for key, value in plugin_to_config_mapping.items(): |
|
config_value = self.hf_ds_config.get_value(value) |
|
if config_value is not None and config_value != "auto": |
|
setattr(self, key, config_value) |
|
else: |
|
config = { |
|
"train_batch_size": "auto", |
|
"train_micro_batch_size_per_gpu": "auto", |
|
"gradient_accumulation_steps": self.gradient_accumulation_steps, |
|
"zero_optimization": { |
|
"stage": self.zero_stage, |
|
"offload_optimizer": { |
|
"device": self.offload_optimizer_device, |
|
"nvme_path": self.offload_optimizer_nvme_path |
|
if self.offload_optimizer_device == "nvme" |
|
else None, |
|
}, |
|
"offload_param": { |
|
"device": self.offload_param_device, |
|
"nvme_path": self.offload_param_nvme_path if self.offload_param_device == "nvme" else None, |
|
}, |
|
"stage3_gather_16bit_weights_on_model_save": self.zero3_save_16bit_model, |
|
}, |
|
} |
|
if self.gradient_clipping: |
|
config["gradient_clipping"] = self.gradient_clipping |
|
self.hf_ds_config = HfDeepSpeedConfig(config) |
|
|
|
self.deepspeed_config = self.hf_ds_config.config |
|
self.deepspeed_config["steps_per_print"] = float("inf") |
|
if self.zero3_init_flag is None: |
|
self.zero3_init_flag = ( |
|
str_to_bool(os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_INIT", str(self.hf_ds_config.is_zero3()))) == 1 |
|
) |
|
if self.zero3_init_flag and not self.hf_ds_config.is_zero3(): |
|
warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.") |
|
self.zero3_init_flag = False |
|
|
|
|
|
self._set_selected(False) |
|
|
|
|
|
if self.enable_msamp and "msamp" not in self.deepspeed_config: |
|
if self.zero_stage == 3: |
|
raise NotImplementedError( |
|
"MS-AMP is not supported for ZeRO Stage 3. Please use ZeRO Stage 0, 1, or 2 instead." |
|
) |
|
if self.msamp_opt_level not in ["O1", "O2"]: |
|
raise ValueError("Invalid optimization level for MS-AMP. Please use one of ['O1' or 'O2'].")
|
self.deepspeed_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level} |
|
|
|
def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs): |
|
mismatches = [] if mismatches is None else mismatches |
|
config, ds_key = self.hf_ds_config.find_config_node(ds_key_long) |
|
if config is None: |
|
return |
|
|
|
if config.get(ds_key) == "auto": |
|
if ds_key_long in kwargs: |
|
config[ds_key] = kwargs[ds_key_long] |
|
return |
|
else: |
|
raise ValueError( |
|
f"`{ds_key_long}` not found in kwargs. " |
|
f"Please specify `{ds_key_long}` without `auto` (set to correct value) in the DeepSpeed config file or " |
|
"pass it in kwargs." |
|
) |
|
|
|
if not must_match: |
|
return |
|
|
|
ds_val = config.get(ds_key) |
|
if ds_val is not None and ds_key_long in kwargs: |
|
if ds_val != kwargs[ds_key_long]: |
|
mismatches.append(f"- ds {ds_key_long}={ds_val} vs arg {ds_key_long}={kwargs[ds_key_long]}") |
|
|
|
def is_auto(self, ds_key_long): |
|
val = self.hf_ds_config.get_value(ds_key_long) |
|
if val is None: |
|
return False |
|
else: |
|
return val == "auto" |
|
|
|
def get_value(self, ds_key_long, default=None): |
|
return self.hf_ds_config.get_value(ds_key_long, default) |
|
|
|
def deepspeed_config_process(self, prefix="", mismatches=None, config=None, must_match=True, **kwargs): |
|
"""Process the DeepSpeed config with the values from the kwargs.""" |
|
mismatches = [] if mismatches is None else mismatches |
|
if config is None: |
|
config = self.deepspeed_config |
|
for key, value in config.items(): |
|
if isinstance(value, dict): |
|
self.deepspeed_config_process( |
|
prefix=prefix + key + ".", mismatches=mismatches, config=value, must_match=must_match, **kwargs |
|
) |
|
else: |
|
self.fill_match(prefix + key, mismatches, must_match=must_match, **kwargs) |
|
if len(mismatches) > 0 and prefix == "": |
|
mismatches_msg = "\n".join(mismatches) |
|
raise ValueError(

"Please correct the following DeepSpeed config values that mismatch kwargs "

f"values:\n{mismatches_msg}\nThe easiest method is to set these DeepSpeed config values to 'auto'."
|
) |
|
|
|
def set_mixed_precision(self, mixed_precision): |
|
ds_config = self.deepspeed_config |
|
kwargs = { |
|
"fp16.enabled": mixed_precision == "fp16", |
|
|
|
"bf16.enabled": mixed_precision in ("bf16", "fp8"), |
|
} |
|
if mixed_precision == "fp16": |
|
if "fp16" not in ds_config: |
|
ds_config["fp16"] = {"enabled": True, "auto_cast": True} |
|
elif mixed_precision in ("bf16", "fp8"): |
|
if "bf16" not in ds_config: |
|
ds_config["bf16"] = {"enabled": True} |
|
|
|
if mixed_precision == "fp8" and self.enable_msamp: |
|
if "msamp" not in ds_config: |
|
ds_config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level} |
|
|
|
if mixed_precision != "no": |
|
diff_dtype = "bf16" if mixed_precision == "fp16" else "fp16" |
|
if str(ds_config.get(diff_dtype, {}).get("enabled", "False")).lower() == "true": |
|
raise ValueError( |
|
f"`--mixed_precision` arg cannot be set to `{mixed_precision}` when `{diff_dtype}` is set in the DeepSpeed config file." |
|
) |
|
for dtype in ["fp16", "bf16"]: |
|
if dtype not in ds_config: |
|
ds_config[dtype] = {"enabled": False} |
|
self.fill_match("fp16.enabled", must_match=False, **kwargs) |
|
self.fill_match("bf16.enabled", must_match=False, **kwargs) |
|
|
|
def set_deepspeed_weakref(self): |
|
from .imports import is_transformers_available |
|
|
|
ds_config = copy.deepcopy(self.deepspeed_config) |
|
if self.zero3_init_flag: |
|
if not is_transformers_available(): |
|
raise Exception( |
|
"When `zero3_init_flag` is set, it requires Transformers to be installed. " |
|
"Please run `pip install transformers`." |
|
) |
|
if "gradient_accumulation_steps" not in ds_config or ds_config["gradient_accumulation_steps"] == "auto": |
|
ds_config["gradient_accumulation_steps"] = 1 |
|
if "train_micro_batch_size_per_gpu" not in ds_config or ds_config["train_micro_batch_size_per_gpu"] == "auto": |
|
ds_config["train_micro_batch_size_per_gpu"] = 1 |
|
if ds_config.get("train_batch_size", None) == "auto": |
|
del ds_config["train_batch_size"] |
|
|
|
if compare_versions("transformers", "<", "4.46"): |
|
from transformers.deepspeed import HfDeepSpeedConfig, unset_hf_deepspeed_config |
|
else: |
|
from transformers.integrations import HfDeepSpeedConfig, unset_hf_deepspeed_config |
|
|
|
unset_hf_deepspeed_config() |
|
self.dschf = HfDeepSpeedConfig(ds_config) |
|
|
|
def is_zero3_init_enabled(self): |
|
return self.zero3_init_flag |
|
|
|
@contextmanager |
|
def zero3_init_context_manager(self, enable=False): |
|
old = self.zero3_init_flag |
|
if old == enable: |
|
yield |
|
else: |
|
self.zero3_init_flag = enable |
|
self.dschf = None |
|
self.set_deepspeed_weakref() |
|
yield |
|
self.zero3_init_flag = old |
|
self.dschf = None |
|
self.set_deepspeed_weakref() |
|
|
|
def _deepspeed_config_checks(self): |
|
env_variable_names_to_ignore = [ |
|
"ACCELERATE_GRADIENT_ACCUMULATION_STEPS", |
|
"ACCELERATE_GRADIENT_CLIPPING", |
|
"ACCELERATE_DEEPSPEED_ZERO_STAGE", |
|
"ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", |
|
"ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE", |
|
"ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH", |
|
"ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH", |
|
"ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", |
|
"ACCELERATE_MIXED_PRECISION", |
|
] |
|
env_variable_names_to_ignore = [ |
|
name.replace("ACCELERATE_", "").replace("DEEPSPEED_", "").lower() for name in env_variable_names_to_ignore |
|
] |
|
|
|
deepspeed_fields_from_accelerate_config = os.environ.get("ACCELERATE_CONFIG_DS_FIELDS", "").split(",") |
|
|
|
if any(name in env_variable_names_to_ignore for name in deepspeed_fields_from_accelerate_config): |
|
raise ValueError(

f"When using `deepspeed_config_file`, the following accelerate config variables will be ignored: {env_variable_names_to_ignore}.\n"

"Please specify them appropriately in the DeepSpeed config file.\n"

"If you are using an accelerate config file, remove the other config variables mentioned in the list above.\n"
|
"The easiest method is to create a new config following the questionnaire via `accelerate config`.\n" |
|
"It will only ask for the necessary config variables when using `deepspeed_config_file`." |
|
) |
|
|
|
def set_moe_leaf_modules(self, model): |
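"""Flags the configured MoE layer classes as DeepSpeed ZeRO-3 leaf modules (requires DeepSpeed >= 0.14.0)."""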
|
if self.transformer_moe_cls_names is None: |
|
self.transformer_moe_cls_names = os.environ.get("ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES", None) |
|
if self.transformer_moe_cls_names is not None: |
|
if compare_versions("deepspeed", "<", "0.14.0"): |
|
raise ImportError("DeepSpeed version must be >= 0.14.0 to use MOE support. Please update DeepSpeed.") |
|
from deepspeed.utils import set_z3_leaf_modules |
|
|
|
class_names = self.transformer_moe_cls_names.split(",") |
|
transformer_moe_cls = [] |
|
for layer_class in class_names: |
|
transformer_cls = get_module_class_from_name(model, layer_class) |
|
if transformer_cls is None: |
|
raise Exception( |
|
f"Could not find a transformer layer class called '{layer_class}' to wrap in the model." |
|
) |
|
else: |
|
transformer_moe_cls.append(transformer_cls) |
|
set_z3_leaf_modules(model, transformer_moe_cls) |
|
|
|
def select(self, _from_accelerator_state: bool = False): |
|
""" |
|
Sets the HfDeepSpeedWeakref to use the current deepspeed plugin configuration |
|
""" |
|
if not _from_accelerator_state: |
|
raise ValueError( |
|
"A `DeepSpeedPlugin` object must be enabled manually by calling `AcceleratorState().enable_deepspeed_plugin(plugin_key)`." |
|
) |
|
self.set_deepspeed_weakref() |
|
self._set_selected(True) |
|
|
|
def _unselect(self): |
|
self._set_selected(False) |
|
|
|
def _set_selected(self, value: bool): |
|
""" |
|
Private setter for the 'selected' attribute.
|
""" |
|
self._selected = value |
|
|
|
@property |
|
def selected(self): |
|
return self._selected |
|
|
|
@selected.setter |
|
def selected(self, value): |
|
raise NotImplementedError( |
|
"'enabled' can only be set through calling 'AcceleratorState().enable_deepspeed_plugin(key)'." |
|
) |
|
|
|
|
|
@dataclass |
|
class FullyShardedDataParallelPlugin: |
|
""" |
|
This plugin is used to enable fully sharded data parallelism. |
|
|
|
Args: |
|
fsdp_version (`int`, defaults to `1`): |
|
The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to |
|
FSDP2 format. |
|
sharding_strategy (`Union[str, torch.distributed.fsdp.ShardingStrategy]`, defaults to `'FULL_SHARD'`): |
|
Sharding strategy to use. Should be either a `str` or an instance of |
|
`torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Is deprecated in favor of |
|
`reshard_after_forward`. |
|
reshard_after_forward (`Union[str, torch.distributed.fsdp.ShardingStrategy, bool]`, defaults to `'FULL_SHARD'` for `fsdp_version=1` and `True` for `fsdp_version=2`): |
|
Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of |
|
`torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. |
|
backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`): |
|
Backward prefetch strategy to use. Should be either a `str` or an instance of |
|
`torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. |
|
mixed_precision_policy (`Optional[Union[dict, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`): |
|
A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it |
|
should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of |
|
`torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. |
|
auto_wrap_policy (`Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]]`, defaults to `NO_WRAP`):
|
A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one |
|
of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See |
|
`torch.distributed.fsdp.wrap.size_based_auto_wrap_policy` for guidance on what it should look like.
|
cpu_offload (`Union[bool, torch.distributed.fsdp.CPUOffload, torch.distributed.fsdp.CPUOffloadPolicy]`, defaults to `False`): |
|
Whether to offload parameters to CPU. Should be either a `bool` or an instance of |
|
`torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or |
|
`torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. |
|
ignored_modules (`Optional[Iterable[torch.nn.Module]]`, defaults to `None`): |
|
A list of modules to ignore when wrapping with FSDP. |
|
state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`): |
|
State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or |
|
`sharded_state_dict`. |
|
state_dict_config (`Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]]`, defaults to `None`):
|
State dict config to use. Is determined based on the `state_dict_type` if not passed in. |
|
optim_state_dict_config (`Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]]`, defaults to `None`):
|
Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in. |
|
limit_all_gathers (`bool`, defaults to `True`): |
|
Whether to have FSDP explicitly synchronize the CPU thread to prevent too many in-flight all-gathers. This
|
bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number |
|
of CUDA malloc retries. |
|
use_orig_params (`bool`, defaults to `False`): |
|
Whether to use the original parameters for the optimizer. |
|
param_init_fn (`Optional[Callable[[torch.nn.Module], None]]`, defaults to `None`):
|
A `Callable[torch.nn.Module] -> None` that specifies how modules that are currently on the meta device |
|
should be initialized onto an actual device. Only applicable when `sync_module_states` is `True`. By |
|
default is a `lambda` which calls `to_empty` on the module. |
|
sync_module_states (`bool`, defaults to `False`): |
|
Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they |
|
are the same across all ranks after initialization. Defaults to `False` unless `cpu_ram_efficient_loading` |
|
is `True`, then will be forcibly enabled. |
|
forward_prefetch (`bool`, defaults to `False`): |
|
Whether to have FSDP explicitly prefetch the next upcoming all-gather while executing in the forward
|
pass. Only use with static graphs.
|
activation_checkpointing (`bool`, defaults to `False`): |
|
A technique to reduce memory usage by clearing activations of certain layers and recomputing them during a |
|
backward pass. Effectively, this trades extra computation time for reduced memory usage. |
|
cpu_ram_efficient_loading (`bool`, defaults to `False`):
|
If True, only the first process loads the pretrained model checkpoint while all other processes have empty
|
weights. Only applicable for Transformers. When using this, `sync_module_states` needs to be `True`. |
|
transformer_cls_names_to_wrap (`Optional[List[str]]`, defaults to `None`): |
|
A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is |
|
`transformer_based_wrap`. |
|
min_num_params (`Optional[int]`, defaults to `None`): |
|
The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` |
|
is `size_based_wrap`. |
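
    Example (a minimal sketch; `BertLayer` below is a hypothetical transformer layer class name, and the
    `Accelerator` is expected to run under a distributed launch such as `accelerate launch`):

    ```python
    from accelerate import Accelerator
    from accelerate.utils import FullyShardedDataParallelPlugin

    fsdp_plugin = FullyShardedDataParallelPlugin(
        sharding_strategy="FULL_SHARD",
        auto_wrap_policy="transformer_based_wrap",
        transformer_cls_names_to_wrap=["BertLayer"],
        cpu_offload=False,
    )
    accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
    ```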
|
""" |
|
|
|
fsdp_version: int = field( |
|
default=None, |
|
metadata={ |
|
"help": "The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to FSDP2 format." |
|
}, |
|
) |
|
|
|
sharding_strategy: Union[str, "torch.distributed.fsdp.ShardingStrategy"] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Sharding strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'. Is deprecated in favor of `reshard_after_forward` " |
|
}, |
|
) |
|
|
|
reshard_after_forward: Union[str, "torch.distributed.fsdp.ShardingStrategy", bool] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'" |
|
}, |
|
) |
|
backward_prefetch: Optional[Union[str, "torch.distributed.fsdp.BackwardPrefetch"]] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Backward prefetch strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. Defaults to 'NO_PREFETCH'. This becomes obsolete in FSDP2." |
|
}, |
|
) |
|
mixed_precision_policy: Optional[ |
|
Union[dict, "torch.distributed.fsdp.MixedPrecision", "torch.distributed.fsdp.MixedPrecisionPolicy"] |
|
] = field( |
|
default=None, |
|
metadata={ |
|
"help": "A config to enable mixed precision training with FullyShardedDataParallel. " |
|
"If passing in a `dict`, it should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`." |
|
"Can also be an instance of `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2." |
|
}, |
|
) |
|
auto_wrap_policy: Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]] = ( |
|
field( |
|
default=None, |
|
metadata={ |
|
"help": "A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. " |
|
"Defaults to `NO_WRAP`. See `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like" |
|
}, |
|
) |
|
) |
|
cpu_offload: Union[bool, "torch.distributed.fsdp.CPUOffload", "torch.distributed.fsdp.CPUOffloadPolicy"] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`" |
|
}, |
|
) |
|
ignored_modules: Optional[Iterable[torch.nn.Module]] = field( |
|
default=None, |
|
metadata={"help": "A list of modules to ignore when wrapping with FSDP."}, |
|
) |
|
|
|
state_dict_type: Union[str, "torch.distributed.fsdp.StateDictType"] = field( |
|
default=None, |
|
metadata={ |
|
"help": "State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or `sharded_state_dict`. Defaults to `FULL_STATE_DICT`" |
|
}, |
|
) |
|
state_dict_config: Optional[ |
|
Union[ |
|
"torch.distributed.fsdp.FullStateDictConfig", |
|
"torch.distributed.fsdp.ShardedStateDictConfig", |
|
] |
|
] = field( |
|
default=None, |
|
metadata={"help": "State dict config to use. Is determined based on the `state_dict_type` if not passed in."}, |
|
) |
|
optim_state_dict_config: Optional[ |
|
Union["torch.distributed.fsdp.FullOptimStateDictConfig", "torch.distributed.fsdp.ShardedOptimStateDictConfig"] |
|
] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in." |
|
}, |
|
) |
|
limit_all_gathers: bool = field( |
|
default=True, |
|
metadata={ |
|
"help": "Whether to have FSDP explicitly synchronizes the CPU thread to prevent " |
|
"too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. " |
|
"Enabling this can help lower the number of CUDA malloc retries." |
|
}, |
|
) |
|
use_orig_params: Optional[bool] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Whether to use the original parameters for the optimizer. Defaults to `False`. This becomes obsolete in FSDP2." |
|
}, |
|
) |
|
param_init_fn: Optional[Callable[[torch.nn.Module], None]] = field( |
|
default=None, |
|
metadata={ |
|
"help": "A Callable[torch.nn.Module] -> None that specifies how modules " |
|
"that are currently on the meta device should be initialized onto an actual device. " |
|
"Only applicable when `sync_module_states` is `True`. By default is a `lambda` which calls `to_empty` on the module." |
|
}, |
|
) |
|
sync_module_states: Optional[bool] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 " |
|
"to ensure they are the same across all ranks after initialization. Defaults to `False` unless " |
|
"`cpu_ram_efficient_loading` is `True`, then will be forcibly enabled. This becomes obsolete in FSDP2." |
|
}, |
|
) |
|
forward_prefetch: bool = field( |
|
default=None, |
|
metadata={ |
|
"help": "Whether to have FSDP explicitly prefetches the next upcoming " |
|
"all-gather while executing in the forward pass. only use with Static graphs. Defaults to `False`" |
|
}, |
|
) |
|
activation_checkpointing: bool = field( |
|
default=None, |
|
metadata={ |
|
"help": "A technique to reduce memory usage by clearing activations of " |
|
"certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time " |
|
"for reduced memory usage. Defaults to `False`" |
|
}, |
|
) |
|
cpu_ram_efficient_loading: bool = field( |
|
default=None, |
|
metadata={ |
|
"help": "If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. " |
|
"Only applicable for π€ Transformers. When using this, `sync_module_states` needs to be `True`. Defaults to `False`." |
|
}, |
|
) |
|
transformer_cls_names_to_wrap: Optional[list[str]] = field( |
|
default=None, |
|
metadata={ |
|
"help": "A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is `transformer_based_wrap`." |
|
}, |
|
) |
|
min_num_params: Optional[int] = field( |
|
default=None, |
|
metadata={ |
|
"help": "The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` is `size_based_wrap`." |
|
}, |
|
) |
|
|
|
def __post_init__(self): |
|
from torch.distributed.fsdp import ( |
|
BackwardPrefetch, |
|
ShardingStrategy, |
|
) |
|
|
|
_fsdp2_warnings = set() |
|
|
|
env_prefix = "FSDP_" |
|
|
|
if self.fsdp_version is None: |
|
self.fsdp_version = int(os.environ.get(env_prefix + "VERSION", "1")) |
|
|
|
if self.fsdp_version == 2: |
|
if not is_torch_version(">=", FSDP2_PYTORCH_VERSION): |
|
raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}") |
|
|
|
if self.sharding_strategy is not None: |
|
|
|
|
|
_fsdp2_warnings.add( |
|
"sharding_strategy is deprecated in favor of reshard_after_forward. " |
|
"This will be removed in a future version of Accelerate." |
|
) |
|
if self.fsdp_version == 1: |
|
if self.sharding_strategy is None: |
|
self.sharding_strategy = os.environ.get(env_prefix + "SHARDING_STRATEGY", "FULL_SHARD") |
|
if isinstance(self.sharding_strategy, str): |
|
if self.sharding_strategy.upper() in FSDP_SHARDING_STRATEGY: |
|
self.sharding_strategy = FSDP_SHARDING_STRATEGY.index(self.sharding_strategy.upper()) + 1 |
|
if isinstance(self.sharding_strategy, int) or self.sharding_strategy.isdigit(): |
|
self.sharding_strategy = ShardingStrategy(int(self.sharding_strategy)) |
|
else: |
|
self.sharding_strategy = ShardingStrategy[self.sharding_strategy.upper()] |
|
|
|
|
|
if self.reshard_after_forward is None and self.sharding_strategy is None: |
|
reshard_after_forward = os.environ.get( |
|
env_prefix + "RESHARD_AFTER_FORWARD", "true" if self.fsdp_version == 2 else "FULL_SHARD" |
|
) |
|
if self.fsdp_version == 2: |
|
self.reshard_after_forward = str_to_bool(reshard_after_forward.lower(), to_bool=True) |
|
else: |
|
self.reshard_after_forward = reshard_after_forward |
|
if isinstance(self.reshard_after_forward, str): |
|
if self.fsdp_version == 2: |
|
self.reshard_after_forward = str_to_bool(self.reshard_after_forward.lower(), to_bool=True) |
|
else: |
|
|
|
if self.reshard_after_forward.upper() in FSDP_SHARDING_STRATEGY: |
|
self.reshard_after_forward = FSDP_SHARDING_STRATEGY.index(self.reshard_after_forward.upper()) + 1 |
|
if isinstance(self.reshard_after_forward, int) or self.reshard_after_forward.isdigit(): |
|
self.reshard_after_forward = ShardingStrategy(int(self.reshard_after_forward)) |
|
else: |
|
self.reshard_after_forward = ShardingStrategy[self.reshard_after_forward.upper()] |
|
|
|
if self.fsdp_version == 2 and not isinstance(self.reshard_after_forward, bool): |
|
raise ValueError( |
|
f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP2, please set to a `bool`" |
|
) |
|
if self.fsdp_version == 1 and isinstance(self.reshard_after_forward, bool): |
|
raise ValueError( |
|
f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`" |
|
) |
|
|
|
if self.cpu_offload is None: |
|
self.cpu_offload = str_to_bool(os.environ.get(env_prefix + "OFFLOAD_PARAMS", "False")) == 1 |
|
|
|
self.set_cpu_offload() |
|
self.validate_cpu_offload() |
|
|
|
if self.backward_prefetch is None: |
|
self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None) |
|
if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() == "NO_PREFETCH": |
|
self.backward_prefetch = None |
|
if self.backward_prefetch is not None and not isinstance(self.backward_prefetch, BackwardPrefetch): |
|
if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() in FSDP_BACKWARD_PREFETCH: |
|
self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1 |
|
if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit(): |
|
self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch)) |
|
else: |
|
self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()] |
|
if self.fsdp_version == 2 and self.backward_prefetch is not None: |
|
_fsdp2_warnings.add("backward_prefetch is not supported in FSDP2. Setting backward prefetch to None.") |
|
self.backward_prefetch = None |
|
|
|
self.set_state_dict_type() |
|
|
|
if self.auto_wrap_policy is None: |
|
self.auto_wrap_policy = os.environ.get(env_prefix + "AUTO_WRAP_POLICY", "NO_WRAP") |
|
if isinstance(self.auto_wrap_policy, str): |
|
if self.auto_wrap_policy.upper() not in FSDP_AUTO_WRAP_POLICY: |
|
raise ValueError( |
|
f"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}" |
|
) |
|
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy |
|
|
|
if self.auto_wrap_policy.upper() == "TRANSFORMER_BASED_WRAP": |
|
self.auto_wrap_policy = transformer_auto_wrap_policy |
|
if self.transformer_cls_names_to_wrap is None: |
|
self.transformer_cls_names_to_wrap = os.environ.get(env_prefix + "TRANSFORMER_CLS_TO_WRAP", None) |
|
if isinstance(self.transformer_cls_names_to_wrap, str): |
|
self.transformer_cls_names_to_wrap = self.transformer_cls_names_to_wrap.split(",") |
|
elif self.auto_wrap_policy.upper() == "SIZE_BASED_WRAP": |
|
self.auto_wrap_policy = size_based_auto_wrap_policy |
|
if self.min_num_params is None: |
|
self.min_num_params = int(os.environ.get(env_prefix + "MIN_NUM_PARAMS", 0)) |
|
elif not isinstance(self.min_num_params, int): |
|
raise ValueError( |
|
f"`min_num_params` must be an integer. Got {self.min_num_params} of type {type(self.min_num_params)}" |
|
) |
|
elif self.auto_wrap_policy.upper() == "NO_WRAP": |
|
self.auto_wrap_policy = None |
|
|
|
if self.use_orig_params is None and self.fsdp_version == 1: |
|
self.use_orig_params = str_to_bool(os.environ.get(env_prefix + "USE_ORIG_PARAMS", "False")) == 1 |
|
if self.fsdp_version == 2 and self.use_orig_params is not None: |
|
_fsdp2_warnings.add("use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.") |
|
self.use_orig_params = None |
|
|
|
if self.sync_module_states is None and self.fsdp_version == 1: |
|
self.sync_module_states = str_to_bool(os.environ.get(env_prefix + "SYNC_MODULE_STATES", "False")) == 1 |
|
if self.fsdp_version == 2 and self.sync_module_states is not None: |
|
_fsdp2_warnings.add( |
|
"sync_module_states is obsolete in FSDP2, as it is not needed anymore." |
|
"Setting sync_module_states to None." |
|
) |
|
self.sync_module_states = None |
|
|
|
if self.forward_prefetch is None and self.fsdp_version == 1: |
|
self.forward_prefetch = str_to_bool(os.environ.get(env_prefix + "FORWARD_PREFETCH", "False")) == 1 |
|
if self.fsdp_version == 2 and self.forward_prefetch is not None: |
|
raise ValueError("forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version=1`") |
|
|
|
if self.activation_checkpointing is None: |
|
self.activation_checkpointing = ( |
|
str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1 |
|
) |
|
|
|
if self.cpu_ram_efficient_loading is None: |
|
self.cpu_ram_efficient_loading = ( |
|
str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1 |
|
) |
|
|
|
if self.fsdp_version == 1 and self.cpu_ram_efficient_loading and not self.sync_module_states: |
|
warnings.warn( |
|
"sync_module_states cannot be False since efficient cpu ram loading enabled. " |
|
"Setting sync_module_states to True." |
|
) |
|
self.sync_module_states = True |
|
|
|
if self.cpu_ram_efficient_loading != bool( |
|
str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) |
|
): |
|
env_var = env_prefix + "CPU_RAM_EFFICIENT_LOADING" |
|
warnings.warn( |
|
f"The `cpu_ram_efficient_loading` flag for `FullyShardedDataParallelPlugin` does not match the environment variable {env_var}. " |
|
"Setting environment variable to match `cpu_ram_efficient_loading`." |
|
) |
|
os.environ[env_var] = str(self.cpu_ram_efficient_loading) |
|
|
|
if isinstance(self.mixed_precision_policy, dict): |
|
self.set_mixed_precision(self.mixed_precision_policy) |
|
if self.mixed_precision_policy is not None: |
|
self.validate_mixed_precision_policy() |
|
|
|
if self.sync_module_states: |
|
if is_npu_available(): |
|
device = torch.npu.current_device() |
|
elif is_mlu_available(): |
|
device = torch.mlu.current_device() |
|
elif is_musa_available(): |
|
device = torch.musa.current_device() |
|
elif is_cuda_available(): |
|
device = torch.cuda.current_device() |
|
elif is_xpu_available(): |
|
device = torch.xpu.current_device() |
|
elif is_hpu_available(): |
|
device = torch.hpu.current_device() |
|
else: |
|
raise RuntimeError( |
|
"There are currently no available devices found, must be one of 'XPU', 'CUDA', 'MLU', 'NPU', 'MUSA', or 'HPU'." |
|
) |
|
|
|
|
|
self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False) |
|
|
|
|
|
if _fsdp2_warnings: |
|
logger.warning("Multiple deprecation warnings due to FSDP2 conversion:\n".join(_fsdp2_warnings)) |
|
|
|
def set_state_dict_type(self, state_dict_type=None): |
|
""" |
|
Set the state dict config based on the `StateDictType`. |
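
For example (illustrative), calling `set_state_dict_type("sharded_state_dict")` resolves the string to
`StateDictType.SHARDED_STATE_DICT` and, when no configs were passed, fills in
`ShardedStateDictConfig(offload_to_cpu=True)` and `ShardedOptimStateDictConfig(offload_to_cpu=True)`.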
|
""" |
|
from torch.distributed.fsdp.fully_sharded_data_parallel import ( |
|
FullOptimStateDictConfig, |
|
FullStateDictConfig, |
|
ShardedOptimStateDictConfig, |
|
ShardedStateDictConfig, |
|
StateDictType, |
|
) |
|
|
|
|
|
|
|
if state_dict_type is not None: |
|
self.state_dict_type = state_dict_type |
|
|
|
if self.state_dict_type is None: |
|
self.state_dict_type = os.environ.get( |
|
"FSDP_STATE_DICT_TYPE", "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT" |
|
) |
|
if isinstance(self.state_dict_type, str): |
|
if self.state_dict_type.isdigit(): |
|
self.state_dict_type = StateDictType(int(self.state_dict_type)) |
|
else: |
|
self.state_dict_type = StateDictType[self.state_dict_type.upper()] |
|
|
|
if self.state_dict_type == StateDictType.FULL_STATE_DICT: |
|
if self.state_dict_config is None: |
|
self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) |
|
if self.optim_state_dict_config is None: |
|
self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True) |
|
elif self.state_dict_type == StateDictType.SHARDED_STATE_DICT: |
|
if self.state_dict_config is None: |
|
self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True) |
|
if self.optim_state_dict_config is None: |
|
self.optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True) |
|
|
|
if self.fsdp_version == 2 and self.state_dict_type == StateDictType.LOCAL_STATE_DICT: |
|
raise ValueError( |
|
"FSDP2 does not support LOCAL_STATE_DICT. " |
|
"Please set `fsdp_state_dict_type` to `SHARDED_STATE_DICT` or `FULL_STATE_DICT`." |
|
) |
|
|
|
def set_auto_wrap_policy(self, model): |
|
""" |
|
Given `model`, creates an `auto_wrap_policy` based on the passed-in policy and, when applicable, the
`transformer_cls_names_to_wrap`.
|
""" |
|
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy |
|
|
|
|
|
no_split_modules = getattr(model, "_no_split_modules", None) |
|
default_transformer_cls_names_to_wrap = list(no_split_modules) if no_split_modules is not None else [] |
|
if self.auto_wrap_policy == transformer_auto_wrap_policy: |
|
if self.transformer_cls_names_to_wrap is None: |
|
self.transformer_cls_names_to_wrap = default_transformer_cls_names_to_wrap |
|
transformer_cls_to_wrap = set() |
|
for layer_class in self.transformer_cls_names_to_wrap: |
|
transformer_cls = get_module_class_from_name(model, layer_class) |
|
if transformer_cls is None: |
|
raise ValueError(f"Could not find the transformer layer class {layer_class} in the model.") |
|
transformer_cls_to_wrap.add(transformer_cls) |
|
|
|
self.auto_wrap_policy = functools.partial( |
|
self.auto_wrap_policy, transformer_layer_cls=transformer_cls_to_wrap |
|
) |
|
|
|
elif self.auto_wrap_policy == size_based_auto_wrap_policy: |
|
|
|
if self.min_num_params > 0: |
|
self.auto_wrap_policy = functools.partial(self.auto_wrap_policy, min_num_params=self.min_num_params) |
|
else: |
|
self.auto_wrap_policy = None |
|
|
|
def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False): |
|
"Sets the mixed precision policy for FSDP" |
|
mixed_precision_mapping = { |
|
"fp8": torch.bfloat16, |
|
"fp16": torch.float16, |
|
"bf16": torch.bfloat16, |
|
"fp32": torch.float32, |
|
} |
|
dtype = mixed_precision |
|
if isinstance(mixed_precision, str): |
|
dtype = mixed_precision_mapping.get(mixed_precision, None) |
|
if dtype is None: |
|
raise ValueError( |
|
f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.keys())}" |
|
) |
|
elif isinstance(mixed_precision, torch.dtype) and mixed_precision not in mixed_precision_mapping.values(): |
|
raise ValueError( |
|
f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.values())}" |
|
) |
|
|
|
buffer_type = torch.float32 if buffer_autocast else dtype |
|
|
|
if self.fsdp_version == 1: |
|
from torch.distributed.fsdp import MixedPrecision |
|
elif self.fsdp_version == 2: |
|
from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision |
|
|
|
if override or self.mixed_precision_policy is None: |
|
dtype_args = {"param_dtype": dtype, "reduce_dtype": dtype} |
|
if self.fsdp_version == 1: |
|
dtype_args["buffer_dtype"] = buffer_type |
|
else: |
|
dtype_args["output_dtype"] = dtype |
|
|
|
self.mixed_precision_policy = MixedPrecision(**dtype_args) |
|
elif isinstance(self.mixed_precision_policy, dict): |
|
|
|
valid_keys = ["param_dtype", "reduce_dtype"] + ( |
|
["buffer_dtype"] if self.fsdp_version == 1 else ["output_dtype"] |
|
) |
|
missing_keys = [k for k in valid_keys if k not in self.mixed_precision_policy] |
|
invalid_values = [ |
|
k for k, v in self.mixed_precision_policy.items() if v not in mixed_precision_mapping.values() |
|
] |
|
if missing_keys or invalid_values: |
|
raise ValueError( |
|
f"Invalid mixed precision policy: {self.mixed_precision_policy}. " |
|
f"Must be a `dict` with keys {valid_keys}." |
|
f"Values must be one of {list(mixed_precision_mapping.values())}" |
|
) |
|
self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy) |
|
|
|
def validate_mixed_precision_policy(self): |
|
""" |
|
Validates the mixed precision policy, abstracted away to not bring in the imports if not needed. |
|
""" |
|
if self.fsdp_version == 2: |
|
from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision |
|
else: |
|
from torch.distributed.fsdp import MixedPrecision |
|
|
|
if not isinstance(self.mixed_precision_policy, MixedPrecision): |
|
required_type = ( |
|
"`torch.distributed.fsdp.MixedPrecisionPolicy`" |
|
if self.fsdp_version == 2 |
|
else "`torch.distributed.fsdp.MixedPrecision`" |
|
) |
|
raise ValueError(f"mixed_precision_policy must be an instance of {required_type}.") |
|
|
|
def set_cpu_offload(self): |
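# A bare bool is normalized here: under FSDP2, False -> OffloadPolicy() and True -> CPUOffloadPolicy();
# under FSDP1 it becomes CPUOffload(offload_params=<bool>).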
|
if self.fsdp_version == 2: |
|
from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy |
|
else: |
|
from torch.distributed.fsdp import CPUOffload |
|
|
|
if isinstance(self.cpu_offload, bool): |
|
if self.fsdp_version == 2: |
|
if not self.cpu_offload: |
|
self.cpu_offload = OffloadPolicy() |
|
else: |
|
self.cpu_offload = CPUOffloadPolicy() |
|
else: |
|
self.cpu_offload = CPUOffload(offload_params=self.cpu_offload) |
|
|
|
def validate_cpu_offload(self): |
|
if self.fsdp_version == 2: |
|
from torch.distributed.fsdp import OffloadPolicy |
|
else: |
|
from torch.distributed.fsdp import CPUOffload |
|
|
|
if self.fsdp_version == 2 and not isinstance(self.cpu_offload, OffloadPolicy): |
|
raise ValueError( |
|
f"`cpu_offload` must be an instance of `torch.distributed.fsdp.OffloadPolicy` in FSDP2, got {self.cpu_offload}" |
|
) |
|
if self.fsdp_version == 1 and not isinstance(self.cpu_offload, CPUOffload): |
|
raise ValueError( |
|
f"`cpu_offload` must be an instance of `torch.distributed.fsdp.CPUOffload` in FSDP1, got {self.cpu_offload}" |
|
) |
|
|
|
|
|
@dataclass |
|
class TorchTensorParallelPlugin: |
|
""" |
|
This plugin is used to enable tensor parallelism using PyTorch >= 2.0. |
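
Example (a minimal sketch; `__post_init__` calls `init_device_mesh`, so this must run under a distributed
launch with at least `tp_size` devices):

```python
from accelerate.utils import TorchTensorParallelPlugin

tp_plugin = TorchTensorParallelPlugin(tp_size=2)
# tp_plugin.torch_device_mesh now holds a 1-D device mesh with mesh dimension name "tp"
```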
|
""" |
|
|
|
tp_size: int = field( |
|
default=1, |
|
metadata={"help": "tensor parallel size will be used in the device mesh preparation"}, |
|
) |
|
|
|
|
|
torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = field(default=None) |
|
|
|
def __post_init__(self): |
|
if not isinstance(self.tp_size, int): |
|
raise ValueError(f"`tp_size` set to {self.tp_size}, please set to an `int`.") |
|
|
|
if self.tp_size <= 1: |
|
raise ValueError("`tp_size` must be greater than 1.") |
|
|
|
if is_torch_version("<", BETA_TP_AVAILABLE_PYTORCH_VERSION): |
|
raise ValueError( |
|
f"Minimum PyTorch version {BETA_TP_AVAILABLE_PYTORCH_VERSION} needed to use tensor parallel." |
|
) |
|
from torch.distributed.device_mesh import init_device_mesh |
|
|
|
|
|
if is_hpu_available(init_hccl=True): |
|
device = "hpu" |
|
elif is_xpu_available(): |
|
device = "xpu" |
|
else: |
|
device = "cuda" |
|
|
|
mesh_dim_name = "tp" |
|
|
|
|
|
|
|
self.torch_device_mesh = init_device_mesh(device, (self.tp_size,), mesh_dim_names=(mesh_dim_name,)) |
|
|
|
|
|
@dataclass |
|
class MegatronLMPlugin: |
|
""" |
|
Plugin for Megatron-LM to enable tensor, pipeline, sequence and data parallelism. Also to enable selective |
|
activation recomputation and optimized fused kernels. |
|
|
|
Args: |
|
tp_degree (`int`, defaults to `None`): |
|
Tensor parallelism degree. |
|
pp_degree (`int`, defaults to `None`): |
|
Pipeline parallelism degree. |
|
num_micro_batches (`int`, defaults to `None`): |
|
Number of micro-batches. |
|
gradient_clipping (`float`, defaults to `None`): |
|
Gradient clipping value based on global L2 Norm (0 to disable). |
|
sequence_parallelism (`bool`, defaults to `None`): |
|
Enable sequence parallelism. |
|
recompute_activations (`bool`, defaults to `None`): |
|
Enable selective activation recomputation. |
|
use_distributed_optimizer (`bool`, defaults to `None`):
|
Enable distributed optimizer. |
|
pipeline_model_parallel_split_rank (`int`, defaults to `None`): |
|
Rank where encoder and decoder should be split. |
|
num_layers_per_virtual_pipeline_stage (`int`, defaults to `None`): |
|
Number of layers per virtual pipeline stage. |
|
is_train_batch_min (`str`, defaults to `True`): |
|
If both train & eval dataloaders are specified, this will decide the `micro_batch_size`.
|
train_iters (`int`, defaults to `None`): |
|
Total number of iterations to train over all training runs. Note that either train-iters or train-samples
|
should be provided when using `MegatronLMDummyScheduler`. |
|
train_samples (`int`, defaults to `None`): |
|
Total number of samples to train over all training runs. Note that either train-iters or train-samples |
|
should be provided when using `MegatronLMDummyScheduler`. |
|
weight_decay_incr_style (`str`, defaults to `'constant'`): |
|
Weight decay increment function. choices=["constant", "linear", "cosine"]. |
|
start_weight_decay (`float`, defaults to `None`): |
|
Initial weight decay coefficient for L2 regularization. |
|
end_weight_decay (`float`, defaults to `None`): |
|
End of run weight decay coefficient for L2 regularization. |
|
lr_decay_style (`str`, defaults to `'linear'`): |
|
Learning rate decay function. choices=['constant', 'linear', 'cosine']. |
|
lr_decay_iters (`int`, defaults to `None`): |
|
Number of iterations for learning rate decay. If None defaults to `train_iters`. |
|
lr_decay_samples (`int`, defaults to `None`): |
|
Number of samples for learning rate decay. If None defaults to `train_samples`. |
|
lr_warmup_iters (`int`, defaults to `None`): |
|
Number of iterations to linearly warmup learning rate over. |
|
lr_warmup_samples (`int`, defaults to `None`): |
|
Number of samples to linearly warmup learning rate over. |
|
lr_warmup_fraction (`float`, defaults to `None`): |
|
Fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over. |
|
min_lr (`float`, defaults to `0`): |
|
Minimum value for learning rate. The scheduler clips values below this threshold.
|
consumed_samples (`List`, defaults to `None`): |
|
Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call. |
|
no_wd_decay_cond (`Optional`, defaults to `None`): |
|
Condition to disable weight decay. |
|
scale_lr_cond (`Optional`, defaults to `None`): |
|
Condition to scale learning rate. |
|
lr_mult (`float`, defaults to `1.0`): |
|
Learning rate multiplier. |
|
megatron_dataset_flag (`bool`, defaults to `False`): |
|
Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format. |
|
seq_length (`int`, defaults to `None`): |
|
Maximum sequence length to process. |
|
encoder_seq_length (`int`, defaults to `None`): |
|
Maximum sequence length to process for the encoder. |
|
decoder_seq_length (`int`, defaults to `None`): |
|
Maximum sequence length to process for the decoder. |
|
tensorboard_dir (`str`, defaults to `None`): |
|
Path to save tensorboard logs. |
|
set_all_logging_options (`bool`, defaults to `False`): |
|
Whether to set all logging options. |
|
eval_iters (`int`, defaults to `100`): |
|
Number of iterations to run for evaluation validation/test for. |
|
eval_interval (`int`, defaults to `1000`): |
|
Interval between running evaluation on validation set. |
|
return_logits (`bool`, defaults to `False`): |
|
Whether to return logits from the model. |
|
custom_train_step_class (`Optional`, defaults to `None`): |
|
Custom train step class. |
|
custom_train_step_kwargs (`Optional`, defaults to `None`): |
|
Custom train step kwargs. |
|
custom_model_provider_function (`Optional`, defaults to `None`): |
|
Custom model provider function. |
|
custom_prepare_model_function (`Optional`, defaults to `None`): |
|
Custom prepare model function. |
|
custom_megatron_datasets_provider_function (`Optional`, defaults to `None`): |
|
Custom megatron train_valid_test datasets provider function. |
|
custom_get_batch_function (`Optional`, defaults to `None`): |
|
Custom get batch function. |
|
custom_loss_function (`Optional`, defaults to `None`): |
|
Custom loss function. |
|
other_megatron_args (`Optional`, defaults to `None`): |
|
Other Megatron-LM arguments. Please refer to the Megatron-LM documentation.
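
Example (a minimal sketch; the parallelism degrees below are illustrative and must be consistent with the
launched world size):

```python
from accelerate.utils import MegatronLMPlugin

megatron_lm_plugin = MegatronLMPlugin(
    tp_degree=2,
    pp_degree=2,
    num_micro_batches=4,
    gradient_clipping=1.0,
)
```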
|
""" |
|
|
|
tp_degree: int = field(default=None, metadata={"help": "tensor parallelism degree."}) |
|
pp_degree: int = field(default=None, metadata={"help": "pipeline parallelism degree."}) |
|
num_micro_batches: int = field(default=None, metadata={"help": "number of micro-batches."}) |
|
gradient_clipping: float = field( |
|
default=None, metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"} |
|
) |
|
sequence_parallelism: bool = field( |
|
default=None, |
|
metadata={"help": "enable sequence parallelism"}, |
|
) |
|
recompute_activations: bool = field( |
|
default=None, |
|
metadata={"help": "enable selective activation recomputation"}, |
|
) |
|
use_distributed_optimizer: bool = field( |
|
default=None, |
|
metadata={"help": "enable distributed optimizer"}, |
|
) |
|
pipeline_model_parallel_split_rank: int = field( |
|
default=None, metadata={"help": "Rank where encoder and decoder should be split."} |
|
) |
|
num_layers_per_virtual_pipeline_stage: int = field( |
|
default=None, metadata={"help": "Number of layers per virtual pipeline stage."} |
|
) |
|
is_train_batch_min: str = field( |
|
default=True, |
|
metadata={"help": "If both train & eval dataloaders are specified, this will decide the micro_batch_size"}, |
|
) |
|
train_iters: int = field( |
|
default=None, |
|
metadata={ |
|
"help": "Total number of iterations to train over all training runs. " |
|
"Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`" |
|
}, |
|
) |
|
train_samples: int = field( |
|
default=None, |
|
metadata={ |
|
"help": "Total number of samples to train over all training runs. " |
|
"Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`" |
|
}, |
|
) |
|
weight_decay_incr_style: str = field( |
|
default="constant", |
|
metadata={"help": 'Weight decay increment function. choices=["constant", "linear", "cosine"]. '}, |
|
) |
|
start_weight_decay: float = field( |
|
default=None, |
|
metadata={"help": "Initial weight decay coefficient for L2 regularization."}, |
|
) |
|
end_weight_decay: float = field( |
|
default=None, |
|
metadata={"help": "End of run weight decay coefficient for L2 regularization."}, |
|
) |
|
lr_decay_style: str = field( |
|
default="linear", |
|
metadata={"help": "Learning rate decay function. choices=['constant', 'linear', 'cosine']."}, |
|
) |
|
lr_decay_iters: int = field( |
|
default=None, |
|
metadata={"help": "Number of iterations for learning rate decay. If None defaults to `train_iters`."}, |
|
) |
|
lr_decay_samples: int = field( |
|
default=None, |
|
metadata={"help": "Number of samples for learning rate decay. If None defaults to `train_samples`."}, |
|
) |
|
lr_warmup_iters: int = field( |
|
default=None, |
|
metadata={"help": "number of iterations to linearly warmup learning rate over."}, |
|
) |
|
lr_warmup_samples: int = field( |
|
default=None, |
|
metadata={"help": "number of samples to linearly warmup learning rate over."}, |
|
) |
|
lr_warmup_fraction: float = field( |
|
default=None, |
|
metadata={"help": "fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over."}, |
|
) |
|
min_lr: float = field( |
|
default=0, |
|
metadata={"help": "Minumum value for learning rate. The scheduler clip values below this threshold."}, |
|
) |
|
consumed_samples: list[int] = field( |
|
default=None, |
|
metadata={ |
|
"help": "Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call." |
|
}, |
|
) |
|
no_wd_decay_cond: Optional[Callable] = field(default=None, metadata={"help": "Condition to disable weight decay."}) |
|
scale_lr_cond: Optional[Callable] = field(default=None, metadata={"help": "Condition to scale learning rate."}) |
|
lr_mult: float = field(default=1.0, metadata={"help": "Learning rate multiplier."}) |
|
megatron_dataset_flag: bool = field( |
|
default=False, |
|
metadata={"help": "Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format."}, |
|
) |
|
seq_length: int = field( |
|
default=None, |
|
metadata={"help": "Maximum sequence length to process."}, |
|
) |
|
encoder_seq_length: int = field( |
|
default=None, |
|
metadata={"help": "Maximum sequence length to process for the encoder."}, |
|
) |
|
decoder_seq_length: int = field( |
|
default=None, |
|
metadata={"help": "Maximum sequence length to process for the decoder."}, |
|
) |
|
tensorboard_dir: str = field( |
|
default=None, |
|
metadata={"help": "Path to save tensorboard logs."}, |
|
) |
|
set_all_logging_options: bool = field( |
|
default=False, |
|
metadata={"help": "Whether to set all logging options."}, |
|
) |
|
eval_iters: int = field( |
|
default=100, metadata={"help": "Number of iterations to run for evaluation validation/test for."} |
|
) |
|
eval_interval: int = field( |
|
default=1000, metadata={"help": "Interval between running evaluation on validation set."} |
|
) |
|
return_logits: bool = field( |
|
default=False, |
|
metadata={"help": "Whether to return logits from the model."}, |
|
) |
|
|
|
|
|
custom_train_step_class: Optional[Any] = field( |
|
default=None, |
|
metadata={"help": "Custom train step class."}, |
|
) |
|
custom_train_step_kwargs: Optional[dict[str, Any]] = field( |
|
default=None, |
|
metadata={"help": "Custom train step kwargs."}, |
|
) |
|
|
|
|
|
custom_model_provider_function: Optional[Callable] = field( |
|
default=None, |
|
metadata={"help": "Custom model provider function."}, |
|
) |
|
custom_prepare_model_function: Optional[Callable] = field( |
|
default=None, |
|
metadata={"help": "Custom prepare model function."}, |
|
) |
|
custom_megatron_datasets_provider_function: Optional[Callable] = field( |
|
default=None, |
|
metadata={"help": "Custom megatron train_valid_test datasets provider function."}, |
|
) |
|
custom_get_batch_function: Optional[Callable] = field( |
|
default=None, |
|
metadata={"help": "Custom get batch function."}, |
|
) |
|
custom_loss_function: Optional[Callable] = field( |
|
default=None, |
|
metadata={"help": "Custom loss function."}, |
|
) |
|
|
|
|
|
|
|
other_megatron_args: Optional[dict[str, Any]] = field( |
|
default=None, |
|
metadata={"help": "Other Megatron-LM arguments. Please refer Megatron-LM"}, |
|
) |
|
|
|
def __post_init__(self): |
|
prefix = "MEGATRON_LM_" |
|
if self.tp_degree is None: |
|
self.tp_degree = int(os.environ.get(prefix + "TP_DEGREE", 1)) |
|
if self.pp_degree is None: |
|
self.pp_degree = int(os.environ.get(prefix + "PP_DEGREE", 1)) |
|
if self.num_micro_batches is None: |
|
self.num_micro_batches = int(os.environ.get(prefix + "NUM_MICRO_BATCHES", 1)) |
|
if self.gradient_clipping is None: |
|
self.gradient_clipping = float(os.environ.get(prefix + "GRADIENT_CLIPPING", 1.0)) |
|
if self.recompute_activations is None: |
|
self.recompute_activations = str_to_bool(os.environ.get(prefix + "RECOMPUTE_ACTIVATIONS", "False")) == 1 |
|
if self.use_distributed_optimizer is None: |
|
self.use_distributed_optimizer = ( |
|
str_to_bool(os.environ.get(prefix + "USE_DISTRIBUTED_OPTIMIZER", "False")) == 1 |
|
) |
|
if self.sequence_parallelism is None: |
|
self.sequence_parallelism = str_to_bool(os.environ.get(prefix + "SEQUENCE_PARALLELISM", "False")) == 1 |
|
|
|
if self.pp_degree > 1 or self.use_distributed_optimizer: |
|
self.DDP_impl = "local" |
|
else: |
|
self.DDP_impl = "torch" |
|
|
|
if self.consumed_samples is not None: |
|
if len(self.consumed_samples) == 1: |
|
self.consumed_samples.extend([0, 0]) |
|
elif len(self.consumed_samples) == 2: |
|
self.consumed_samples.append(0) |
|
|
|
self.megatron_lm_default_args = { |
|
"tensor_model_parallel_size": self.tp_degree, |
|
"pipeline_model_parallel_size": self.pp_degree, |
|
"pipeline_model_parallel_split_rank": self.pipeline_model_parallel_split_rank, |
|
"num_layers_per_virtual_pipeline_stage": self.num_layers_per_virtual_pipeline_stage, |
|
"DDP_impl": self.DDP_impl, |
|
"use_distributed_optimizer": self.use_distributed_optimizer, |
|
"sequence_parallel": self.sequence_parallelism, |
|
"clip_grad": self.gradient_clipping, |
|
"num_micro_batches": self.num_micro_batches, |
|
"consumed_samples": self.consumed_samples, |
|
"no_wd_decay_cond": self.no_wd_decay_cond, |
|
"scale_lr_cond": self.scale_lr_cond, |
|
"lr_mult": self.lr_mult, |
|
"megatron_dataset_flag": self.megatron_dataset_flag, |
|
"eval_iters": self.eval_iters, |
|
"eval_interval": self.eval_interval, |
|
} |
|
if self.recompute_activations: |
|
self.megatron_lm_default_args["recompute_granularity"] = "selective" |
|
if self.tensorboard_dir is not None: |
|
self.megatron_lm_default_args["tensorboard_dir"] = self.tensorboard_dir |
|
if self.set_all_logging_options: |
|
self.set_tensorboard_logging_options() |
|
if self.other_megatron_args is not None: |
|
self.megatron_lm_default_args.update(self.other_megatron_args) |
|
|
|
def set_network_size_args(self, model, batch_data=None): |
|
model_config_type = model.config.model_type.lower() |
|
for model_type in MODEL_CONFIGS_TO_MEGATRON_PARSERS.keys(): |
|
if model_type in model_config_type: |
|
MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type](self, model, batch_data) |
|
return |
|
raise ValueError( |
|
f"Accelerate Megatron-LM integration not supports {model_config_type} model. " |
|
"You can add your own model config parser." |
|
) |
|
|
|
def set_mixed_precision(self, mixed_precision): |
|
if mixed_precision == "fp16": |
|
self.megatron_lm_default_args["fp16"] = True |
|
elif mixed_precision == "bf16": |
|
self.megatron_lm_default_args["bf16"] = True |
|
self.DDP_impl = "local" |
|
self.megatron_lm_default_args["DDP_impl"] = self.DDP_impl |
|
|
|
def set_training_args(self, micro_batch_size, dp_degree): |
|
self.data_parallel_size = dp_degree |
|
self.micro_batch_size = micro_batch_size |
|
self.global_batch_size = dp_degree * micro_batch_size * self.num_micro_batches |
|
self.megatron_lm_default_args["data_parallel_size"] = self.data_parallel_size |
|
self.megatron_lm_default_args["micro_batch_size"] = self.micro_batch_size |
|
self.megatron_lm_default_args["global_batch_size"] = self.global_batch_size |
|
|
|
def set_optimizer_type(self, optimizer): |
|
optimizer_name = optimizer.__class__.__name__.lower() |
|
if "adam" in optimizer_name: |
|
self.megatron_lm_default_args["optimizer"] = "adam" |
|
self.megatron_lm_default_args["adam_beta1"] = optimizer.defaults["betas"][0] |
|
self.megatron_lm_default_args["adam_beta2"] = optimizer.defaults["betas"][1] |
|
self.megatron_lm_default_args["adam_eps"] = optimizer.defaults["eps"] |
|
elif "sgd" in optimizer_name: |
|
self.megatron_lm_default_args["optimizer"] = "sgd" |
|
self.megatron_lm_default_args["sgd_momentum"] = optimizer.defaults["momentum"] |
|
else: |
|
raise ValueError(f"Optimizer {optimizer_name} is not supported by Megatron-LM") |
|
|
|
self.megatron_lm_default_args["lr"] = optimizer.defaults["lr"] |
|
self.megatron_lm_default_args["weight_decay"] = optimizer.defaults["weight_decay"] |
|
|
|
def set_scheduler_args(self, scheduler): |
|
if self.train_iters is None: |
|
self.train_iters = scheduler.total_num_steps // self.megatron_lm_default_args["data_parallel_size"] |
|
if self.train_samples is not None: |
|
self.train_samples = None |
|
warnings.warn( |
|
"Ignoring `train_samples` as `train_iters` based on scheduler is being used for training." |
|
) |
|
if self.lr_warmup_iters is None: |
|
self.lr_warmup_iters = scheduler.warmup_num_steps // self.megatron_lm_default_args["data_parallel_size"] |
|
if self.lr_warmup_samples is not None: |
|
warnings.warn( |
|
"Ignoring `lr_warmup_samples` as `lr_warmup_iters` based on scheduler is being used for training." |
|
) |
|
self.lr_warmup_samples = 0 |
|
|
|
self.megatron_lm_default_args["train_iters"] = self.train_iters |
|
self.megatron_lm_default_args["lr_warmup_iters"] = self.lr_warmup_iters |
|
self.megatron_lm_default_args["train_samples"] = self.train_samples |
|
self.megatron_lm_default_args["lr_warmup_samples"] = self.lr_warmup_samples |
|
self.megatron_lm_default_args["lr_decay_iters"] = self.lr_decay_iters |
|
self.megatron_lm_default_args["lr_decay_samples"] = self.lr_decay_samples |
|
self.megatron_lm_default_args["lr_warmup_fraction"] = self.lr_warmup_fraction |
|
self.megatron_lm_default_args["lr_decay_style"] = self.lr_decay_style |
|
self.megatron_lm_default_args["weight_decay_incr_style"] = self.weight_decay_incr_style |
|
self.megatron_lm_default_args["start_weight_decay"] = self.start_weight_decay |
|
self.megatron_lm_default_args["end_weight_decay"] = self.end_weight_decay |
|
self.megatron_lm_default_args["min_lr"] = self.min_lr |
|
|
|
def set_tensorboard_logging_options(self): |
|
from megatron.training.arguments import _add_logging_args |
|
|
|
parser = argparse.ArgumentParser() |
|
parser = _add_logging_args(parser) |
|
logging_args = parser.parse_known_args() |
|
self.dataset_args = vars(logging_args[0]) |
|
for key, value in self.dataset_args.items(): |
|
if key.startswith("log_"): |
|
self.megatron_lm_default_args[key] = True |
|
elif key.startswith("no_log_"): |
|
self.megatron_lm_default_args[key.replace("no_", "")] = True |
|
|
|
|
|
MODEL_CONFIGS_TO_MEGATRON_PARSERS = {} |
|
|
|
|
|
def add_model_config_to_megatron_parser(model_type: str): |
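"""
Decorator that registers the decorated function as the Megatron-LM config parser for models whose
`model_type` contains the given `model_type` string (see `MegatronLMPlugin.set_network_size_args`).
"""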
|
def add_model_config_parser_helper(func): |
|
@functools.wraps(func) |
|
def wrapper(*args, **kwargs): |
|
return func(*args, **kwargs) |
|
|
|
MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type] = func |
|
return wrapper |
|
|
|
return add_model_config_parser_helper |
|
|
|
|
|
@add_model_config_to_megatron_parser("megatron-bert") |
|
def parse_bert_config(megatron_lm_plugin, model, batch_data): |
|
model_type_name = "bert" |
|
num_layers = model.config.num_hidden_layers |
|
hidden_size = model.config.hidden_size |
|
num_attention_heads = model.config.num_attention_heads |
|
max_position_embeddings = model.config.max_position_embeddings |
|
num_labels = model.config.num_labels |
|
orig_vocab_size = model.config.vocab_size |
|
pretraining_flag = False |
|
if "maskedlm" in model.__class__.__name__.lower(): |
|
pretraining_flag = True |
|
if megatron_lm_plugin.seq_length is not None: |
|
if megatron_lm_plugin.encoder_seq_length is not None: |
|
warnings.warn("Both `seq_length` and `encoder_seq_length` are set. Using `encoder_seq_length`.") |
|
megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length |
|
elif megatron_lm_plugin.encoder_seq_length is not None: |
|
megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length |
|
elif batch_data is not None: |
|
megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1] |
|
else: |
|
megatron_lm_plugin.seq_length = max_position_embeddings |
|
megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length |
|
megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name |
|
megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers |
|
megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size |
|
megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads |
|
megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings |
|
megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag |
|
megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size |
|
megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict |
|
megatron_lm_plugin.megatron_lm_default_args["num_labels"] = num_labels |
|
|
|
|
|
@add_model_config_to_megatron_parser("gpt2") |
|
def parse_gpt2_config(megatron_lm_plugin, model, batch_data): |
|
model_type_name = "gpt" |
|
num_layers = model.config.n_layer |
|
hidden_size = model.config.n_embd |
|
num_attention_heads = model.config.n_head |
|
max_position_embeddings = model.config.n_positions |
|
orig_vocab_size = model.config.vocab_size |
|
pretraining_flag = True |
|
if megatron_lm_plugin.seq_length is not None: |
|
if megatron_lm_plugin.decoder_seq_length is not None: |
|
warnings.warn("Both `seq_length` and `decoder_seq_length` are set. Using `decoder_seq_length`.") |
|
megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length |
|
elif megatron_lm_plugin.decoder_seq_length is not None: |
|
megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length |
|
elif batch_data is not None: |
|
megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1] |
|
else: |
|
megatron_lm_plugin.seq_length = max_position_embeddings |
|
megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length |
|
megatron_lm_plugin.megatron_lm_default_args["return_logits"] = megatron_lm_plugin.return_logits |
|
megatron_lm_plugin.megatron_lm_default_args["tokenizer_type"] = "GPT2BPETokenizer" |
|
megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name |
|
megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers |
|
megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size |
|
megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads |
|
megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings |
|
megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag |
|
megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size |
|
megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict |
|
|
|
|
|
@add_model_config_to_megatron_parser("t5") |
|
def parse_t5_config(megatron_lm_plugin, model, batch_data): |
|
model_type_name = "t5" |
|
num_layers = model.config.num_layers |
|
hidden_size = model.config.d_model |
|
num_attention_heads = model.config.num_heads |
|
max_position_embeddings = model.config.n_positions if hasattr(model.config, "n_positions") else 1024 |
|
orig_vocab_size = model.config.vocab_size |
|
pretraining_flag = True |
|
if megatron_lm_plugin.encoder_seq_length is None: |
|
if batch_data is not None: |
|
megatron_lm_plugin.encoder_seq_length = batch_data["input_ids"].shape[1] |
|
else: |
|
megatron_lm_plugin.encoder_seq_length = max_position_embeddings |
|
if megatron_lm_plugin.decoder_seq_length is None: |
|
if batch_data is not None: |
|
megatron_lm_plugin.decoder_seq_length = batch_data["labels"].shape[1] |
|
else: |
|
megatron_lm_plugin.decoder_seq_length = max_position_embeddings |
|
megatron_lm_plugin.megatron_lm_default_args["encoder_seq_length"] = megatron_lm_plugin.encoder_seq_length |
|
megatron_lm_plugin.megatron_lm_default_args["decoder_seq_length"] = megatron_lm_plugin.decoder_seq_length |
|
megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name |
|
megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers |
|
megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size |
|
megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads |
|
megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings |
|
megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag |
|
megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size |
|
megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict |
|
|
|
|
|
@add_model_config_to_megatron_parser("llama") |
|
def parse_llama_config(megatron_lm_plugin, model, batch_data): |
|
model_type_name = "gpt" |
|
num_layers = model.config.num_hidden_layers |
|
pretraining_flag = True |
|
hidden_size = model.config.hidden_size |
|
num_attention_heads = model.config.num_attention_heads |
|
orig_vocab_size = model.config.vocab_size |
|
|
|
max_position_embeddings = model.config.max_position_embeddings |
|
seq_length = getattr(model.config, "max_sequence_length", None) |
|
if megatron_lm_plugin.seq_length is None: |
|
if seq_length is not None: |
|
megatron_lm_plugin.seq_length = seq_length |
|
elif megatron_lm_plugin.decoder_seq_length is not None: |
|
megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length |
|
elif batch_data is not None: |
|
megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1] |
|
else: |
|
megatron_lm_plugin.seq_length = max_position_embeddings |
|
|
|
megatron_lm_plugin.megatron_lm_default_args["return_logits"] = megatron_lm_plugin.return_logits |
|
megatron_lm_plugin.megatron_lm_default_args["tokenizer_type"] = "Llama2Tokenizer" |
|
megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name |
|
megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers |
|
megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag |
|
megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size |
|
megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads |
|
megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size |
|
megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings |
|
megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length |
|
megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict |
|
|
|
|
|
@dataclass |
|
class BnbQuantizationConfig: |
|
""" |
|
    A plugin to enable BitsAndBytes 4-bit and 8-bit quantization.
|
|
|
Args: |
|
load_in_8bit (`bool`, defaults to `False`): |
|
Enable 8bit quantization. |
|
llm_int8_threshold (`float`, defaults to `6.0`): |
|
            Value of the outlier threshold. Only relevant when `load_in_8bit=True`.
|
        load_in_4bit (`bool`, defaults to `False`):
|
Enable 4bit quantization. |
|
bnb_4bit_quant_type (`str`, defaults to `fp4`): |
|
            Set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','nf4'}.
|
bnb_4bit_use_double_quant (`bool`, defaults to `False`): |
|
Enable nested quantization where the quantization constants from the first quantization are quantized |
|
again. |
|
        bnb_4bit_compute_dtype (`str`, defaults to `fp16`):

            This sets the computational type, which might be different from the input type. For example, inputs might be
|
fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}. |
|
torch_dtype (`torch.dtype`, defaults to `None`): |
|
            This sets the dtype of the remaining non-quantized layers. The `bitsandbytes` library suggests setting the value

            to `torch.float16` for 8-bit models and to the same dtype as the compute dtype for 4-bit models.
|
skip_modules (`List[str]`, defaults to `None`): |
|
An explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`. |
|
        keep_in_fp32_modules (`List[str]`, defaults to `None`):
|
An explicit list of the modules that we don't quantize. We keep them in `torch.float32`. |
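
    Example (a minimal sketch showing how the config might be constructed; the specific values are illustrative,
    not recommendations):

    ```python
    from accelerate.utils import BnbQuantizationConfig

    # 4-bit NF4 quantization with bf16 compute; non-quantized layers follow `torch_dtype`,
    # which defaults to the compute dtype for 4-bit models (see `__post_init__`).
    bnb_config = BnbQuantizationConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype="bf16",
    )
    ```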
|
""" |
|
|
|
load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."}) |
|
|
|
llm_int8_threshold: float = field( |
|
        default=6.0, metadata={"help": "value of the outlier threshold. only relevant when load_in_8bit=True"}
|
) |
|
|
|
load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."}) |
|
|
|
bnb_4bit_quant_type: str = field( |
|
default="fp4", |
|
metadata={ |
|
"help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','nf4'}." |
|
}, |
|
) |
|
|
|
bnb_4bit_use_double_quant: bool = field( |
|
default=False, |
|
metadata={ |
|
"help": "enable nested quantization where the quantization constants from the first quantization are quantized again." |
|
}, |
|
) |
|
|
|
bnb_4bit_compute_dtype: str = field( |
|
default="fp16", |
|
metadata={ |
|
"help": "This sets the computational type which might be different than the input time. For example, inputs might be " |
|
"fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}." |
|
}, |
|
) |
|
|
|
torch_dtype: torch.dtype = field( |
|
default=None, |
|
metadata={ |
|
"help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value" |
|
"to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model " |
|
}, |
|
) |
|
|
|
skip_modules: list[str] = field( |
|
default=None, |
|
metadata={ |
|
"help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`." |
|
}, |
|
) |
|
|
|
keep_in_fp32_modules: list[str] = field( |
|
default=None, |
|
metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."}, |
|
) |
|
|
|
def __post_init__(self): |
|
""" |
|
        Safety checker that the arguments are correct; also replaces some `None` arguments with their default values.
|
""" |
|
if not isinstance(self.load_in_8bit, bool): |
|
raise ValueError("load_in_8bit must be a boolean") |
|
|
|
if not isinstance(self.load_in_4bit, bool): |
|
raise ValueError("load_in_4bit must be a boolean") |
|
|
|
if self.load_in_4bit and self.load_in_8bit: |
|
raise ValueError("load_in_4bit and load_in_8bit can't be both True") |
|
|
|
if not self.load_in_4bit and not self.load_in_8bit: |
|
raise ValueError("load_in_4bit and load_in_8bit can't be both False") |
|
|
|
if not isinstance(self.llm_int8_threshold, (int, float)): |
|
raise ValueError("llm_int8_threshold must be a float or an int") |
|
|
|
if not isinstance(self.bnb_4bit_quant_type, str): |
|
raise ValueError("bnb_4bit_quant_type must be a string") |
|
elif self.bnb_4bit_quant_type not in ["fp4", "nf4"]: |
|
raise ValueError(f"bnb_4bit_quant_type must be in ['fp4','nf4'] but found {self.bnb_4bit_quant_type}") |
|
|
|
if not isinstance(self.bnb_4bit_use_double_quant, bool): |
|
raise ValueError("bnb_4bit_use_double_quant must be a boolean") |
|
|
|
if isinstance(self.bnb_4bit_compute_dtype, str): |
|
if self.bnb_4bit_compute_dtype == "fp32": |
|
self.bnb_4bit_compute_dtype = torch.float32 |
|
elif self.bnb_4bit_compute_dtype == "fp16": |
|
self.bnb_4bit_compute_dtype = torch.float16 |
|
elif self.bnb_4bit_compute_dtype == "bf16": |
|
self.bnb_4bit_compute_dtype = torch.bfloat16 |
|
else: |
|
raise ValueError( |
|
f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}" |
|
) |
|
elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype): |
|
raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype") |
|
|
|
if self.skip_modules is not None and not isinstance(self.skip_modules, list): |
|
raise ValueError("skip_modules must be a list of strings") |
|
|
|
if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list): |
|
raise ValueError("keep_in_fp_32_modules must be a list of strings") |
|
|
|
if self.load_in_4bit: |
|
self.target_dtype = CustomDtype.INT4 |
|
|
|
if self.load_in_8bit: |
|
self.target_dtype = torch.int8 |
|
|
|
if self.load_in_4bit and self.llm_int8_threshold != 6.0: |
|
            warnings.warn("llm_int8_threshold can only be used for models loaded in 8-bit")
|
|
|
if isinstance(self.torch_dtype, str): |
|
if self.torch_dtype == "fp32": |
|
self.torch_dtype = torch.float32 |
|
elif self.torch_dtype == "fp16": |
|
self.torch_dtype = torch.float16 |
|
elif self.torch_dtype == "bf16": |
|
self.torch_dtype = torch.bfloat16 |
|
else: |
|
raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}") |
|
if self.load_in_8bit and self.torch_dtype is None: |
|
self.torch_dtype = torch.float16 |
|
|
|
if self.load_in_4bit and self.torch_dtype is None: |
|
self.torch_dtype = self.bnb_4bit_compute_dtype |
|
|
|
if not isinstance(self.torch_dtype, torch.dtype): |
|
raise ValueError("torch_dtype must be a torch.dtype") |
|
|
|
|
|
def get_module_class_from_name(module, name): |
|
""" |
|
Gets a class from a module by its name. |
|
|
|
Args: |
|
module (`torch.nn.Module`): The module to get the class from. |
|
name (`str`): The name of the class. |
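
    Example (a minimal sketch using a toy `torch.nn` model for illustration):

    ```python
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
    # Recursively searches the module tree; returns the matching class, or None if no submodule matches.
    assert get_module_class_from_name(model, "Linear") is nn.Linear
    assert get_module_class_from_name(model, "Conv2d") is None
    ```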
|
""" |
|
modules_children = list(module.children()) |
|
if module.__class__.__name__ == name: |
|
return module.__class__ |
|
elif len(modules_children) == 0: |
|
return |
|
else: |
|
for child_module in modules_children: |
|
module_class = get_module_class_from_name(child_module, name) |
|
if module_class is not None: |
|
return module_class |
|
|