# fairscale/optim/layerwise_gradient_scaler.py
import logging
from typing import List, Tuple

import torch
import torch.nn as nn


class LayerInfo:
"""
A class to record the layer attributes.
"""
def __init__(self, name: str, layer: nn.Module, scale: float = 1.0, scale_layer: bool = False) -> None:
"""
layer_name: name of the layer e.g. fc1, conv1, relu1
        layer: the layer module, e.g. Linear, Conv2d, ReLU
        scaling_factor: user-configurable scaling factor for the layer, defaults to 1.0
        found_inf_or_nan: a boolean indicating if any of the layer's parameter gradients contain an inf/nan
        growth_tracker: tracks the number of steps since the scale was last increased
scale_layer: a boolean indicating if the layer should be scaled or not
"""
self.layer_name = name
self.layer = layer
self.scaling_factor = scale
self.found_inf_or_nan = False
self.growth_tracker = 0
        self.scale_layer = scale_layer


class GradientHelper:
"""
A helper class to create instances of backward hooks. The hooks are registered in the
scale method of LayerwiseGradientScaler.
"""
def __init__(self, name: str, inputs_multiplier: float, outputs_multiplier: float):
self.layer_name = name
self.inputs_multiplier = inputs_multiplier
        self.outputs_multiplier = outputs_multiplier

    def scale_gradients(self, m: nn.Module, inputs: Tuple, outputs: Tuple) -> Tuple[torch.Tensor, ...]:
"""
Backward hook that is attached to the layers to scale the gradients.
"""
scaled_up_grads = list()
for idx in range(len(inputs)):
if inputs[idx] is not None:
if self.inputs_multiplier != 1.0 or self.outputs_multiplier != 1.0:
logging.debug(
"layer = %s \t scale = %s \t scale_down = %s"
% (self.layer_name, self.inputs_multiplier, self.outputs_multiplier)
)
scaled_up_grads.append(inputs[idx].mul(self.inputs_multiplier * self.outputs_multiplier))
else:
logging.debug("next layer is None")
scaled_up_grads.append(inputs[idx])
        return tuple(scaled_up_grads)  # type: ignore


class LayerwiseGradientScaler:
"""
LayerwiseGradientScaler enables using distinct scaling factors for each layer
of the network.
    Example:
        # Create a convolutional network
        class ConvNet(nn.Module):
            def __init__(self):
                ...
            def forward(self, x):
                ...

        # Create an instance of the model
        model = ConvNet()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

        # specify the layers to scale and their scaling factors
        layer_scale_dict = {"conv1": 2**10, "conv2": 2**8, "fc1": 2**10, "fc2": 2**9}
        scaler = LayerwiseGradientScaler(model, layer_scale_dict)

        for epoch in range(num_epochs):
            for inputs, targets in dataloader:
                optimizer.zero_grad()

                # register the hooks that scale the gradients
                scaler.scale()

                # enable mixed precision training
                with autocast():
                    predictions = model(inputs)
                    loss = loss_function(predictions, targets)
                loss.backward()

                # unscale the gradients and remove the hooks
                scaler.unscale()

                # a step is taken only if no inf/nan is found in the gradients;
                # the scaling factor of each layer is then updated
                scaler.step(optimizer)

    Args:
        model : instance of a model class, such as ConvNet above
        layer_scale_dict (dict) : dictionary mapping layer_name to scaling_factor
        growth_factor (float) : multiplier applied to a layer's scaling factor when it is grown
        backoff_factor (float) : multiplier applied to a layer's scaling factor when an inf/nan is found
        growth_interval (int) : number of steps after which the scale is multiplied by growth_factor
        min_scale (float) : smallest allowed scaling factor
        max_scale (float) : largest allowed scaling factor
"""
def __init__( # type: ignore
self,
model,
layer_scale_dict: dict,
growth_factor: float = 2.0,
backoff_factor: float = 0.5,
growth_interval: int = 10000,
min_scale: float = torch.finfo(torch.float32).tiny, # type: ignore
max_scale: float = torch.finfo(torch.float32).max, # type: ignore
) -> None:
self._model = model
self._layer_scale_dict: dict = layer_scale_dict
self._growth_factor: float = growth_factor
self._backoff_factor: float = backoff_factor
self._growth_interval: int = growth_interval
        self._apply_layerwise_scaling: bool = len(layer_scale_dict) > 0
self._min_scale = min_scale
self._max_scale = max_scale
self._handles: List = []
self.layer_info: List = []
if self._apply_layerwise_scaling:
assert self._growth_factor > 1.0, "The growth factor must be > 1.0."
assert self._backoff_factor < 1.0, "The backoff factor must be < 1.0."
            self.layer_info = self._build_layer_info()

    def _build_layer_info(self) -> List:
"""
Helper function to create a list of LayerInfo instances.
"""
layer_info_list = list()
for name, layer in self._model.named_modules():
if name != "":
                if name not in self._layer_scale_dict:
logging.debug("name = %s, layer = %s, scaling_factor = %s" % (name, layer, 1.0))
layer_info_list.append(LayerInfo(name, layer, 1.0))
else:
logging.debug(
"name = %s, layer = %s, scaling_factor = %s" % (name, layer, self._layer_scale_dict[name])
)
layer_info_list.append(LayerInfo(name, layer, self._layer_scale_dict[name], True))
        return layer_info_list

    def scale(self) -> None:
"""
        For each layer, calculate the scaling factor applied to the preceding layer's
        grad inputs and to the current layer's grad outputs. These values are used to
        register a full backward hook on the layer, and the handle returned from
        registration is appended to a list of handles. New hooks are created and
        registered at every step; the accumulated handles are removed in the unscale
        function.
"""
if not self._apply_layerwise_scaling:
return
        for idx, elt in enumerate(self.layer_info):
layer_name, layer = elt.layer_name, elt.layer
inputs_multiplier = 1.0
if idx > 0:
inputs_multiplier = self.layer_info[idx - 1].scaling_factor
outputs_multiplier = 1.0 / elt.scaling_factor
helper = GradientHelper(layer_name, inputs_multiplier, outputs_multiplier)
layer_handle = layer.register_full_backward_hook(helper.scale_gradients)
self._handles.append(layer_handle)
logging.debug("name = %s \t scale = %s" % (layer_name, elt.scaling_factor))
def _get_layers_with_finite_values(self) -> List[LayerInfo]:
layers_with_finite_values: List = []
for item in self.layer_info:
if not item.found_inf_or_nan:
layers_with_finite_values.append(item)
        return layers_with_finite_values

    def unscale(self) -> None:
"""
        For each layer, check whether any of the layer's parameters contain an inf/nan
        in their gradients. If not, the gradients of that layer are unscaled by the
        reciprocal of that layer's scaling factor. Finally, all handles recorded while
        registering the hooks are removed.
"""
if not self._apply_layerwise_scaling:
return
layers_with_finite_values = self._get_layers_with_finite_values()
for item in layers_with_finite_values:
for param_name, param in item.layer.named_parameters():
if hasattr(param, "grad") and param.grad is not None:
logging.debug("%s scaling down %s by %s" % (item.layer_name, param_name, 1.0 / item.scaling_factor))
param.grad.mul_(1.0 / item.scaling_factor)
while len(self._handles) > 0:
elt = self._handles.pop()
            elt.remove()

    def _check_for_inf_or_nan(self) -> None:
"""
        For each layer, check whether any parameter with a gradient attribute
        contains an inf/nan. If one does, that layer's found_inf_or_nan attribute
        is set to True and the remaining parameters of that layer are skipped.
"""
for elt in self.layer_info:
elt.found_inf_or_nan = False
for _, param in elt.layer.named_parameters():
if hasattr(param, "grad") and param.grad is not None:
if torch.isinf(param.grad).any().item() or torch.isnan(param.grad).any().item(): # type: ignore
elt.found_inf_or_nan = True
                        break  # skip all remaining named parameters

    def step(self, optimizer) -> None:  # type: ignore
"""
        If none of the layers' gradients contain an inf/nan, the optimizer takes a
        step; otherwise it does not. The scaling factor of each layer is then updated.
"""
# using layerwise gradient scaling
if self._apply_layerwise_scaling:
self._check_for_inf_or_nan()
inf_nan_found = any(elt.found_inf_or_nan for elt in self.layer_info)
if not inf_nan_found:
optimizer.step()
self._update_scale()
# not using layerwise gradient scaling
else:
            optimizer.step()

    def _update_scale(self) -> None:
"""
For each layer, if an inf/nan is found, then multiply the scaling factor
of that layer by the backoff factor and set the growth tracker of that
layer to 0. Else, increment the growth tracker of the layer. If growth
tracker equals the growth interval, then multiply the scaling factor of
the layer by the growth factor and reset the layer's growth tracker to 0.
        Finally, clip the scaling factor to the range [self._min_scale,
        self._max_scale]. The min/max scaling factor values are user configurable.
"""
if not self._apply_layerwise_scaling:
return
for layer in self.layer_info:
if layer.found_inf_or_nan:
if layer.scale_layer:
layer.scaling_factor = max(
self._min_scale,
min(self._backoff_factor * layer.scaling_factor, self._max_scale),
)
layer.growth_tracker = 0
else:
layer.growth_tracker += 1
if layer.scale_layer and layer.growth_tracker == self._growth_interval:
layer.scaling_factor = max(
self._min_scale,
min(self._growth_factor * layer.scaling_factor, self._max_scale),
)
                    layer.growth_tracker = 0

    def get_layer_info(self) -> List[LayerInfo]:
"""
Returns a list of LayerInfo instances of the model.
"""
        return self.layer_info

    def get_backward_hooks(self) -> List:
"""
Returns a list of tuples. Each tuple contains the layer name and the
hook attached to it.
"""
layer_name_and_hooks = list()
for name, layer in self._model.named_modules():
if name != "":
layer_name_and_hooks.append((name, layer._get_backward_hooks()))
return layer_name_and_hooks
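

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the fairscale API).
# It follows the scale -> backward -> unscale -> step pattern described in the
# LayerwiseGradientScaler docstring. The TinyNet model, its layer names, and
# the scaling factors below are assumptions made purely for demonstration;
# adapt them to your own network before use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class TinyNet(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = nn.Linear(8, 16)
            self.relu1 = nn.ReLU()
            self.fc2 = nn.Linear(16, 4)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.fc2(self.relu1(self.fc1(x)))

    model = TinyNet()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Layers not listed here (e.g. relu1) default to a scaling factor of 1.0.
    scaler = LayerwiseGradientScaler(model, {"fc1": 2**10, "fc2": 2**8})

    inputs, targets = torch.randn(32, 8), torch.randn(32, 4)
    for _ in range(3):
        optimizer.zero_grad()
        scaler.scale()  # register the per-layer backward hooks
        loss = nn.functional.mse_loss(model(inputs), targets)
        loss.backward()
        scaler.unscale()  # undo the per-layer scaling and remove the hooks
        scaler.step(optimizer)  # steps only if all gradients are finite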