# codeclm/utils/offload_profiler.py
import torch
from torch.func import functional_call
import queue
import threading
from typing import Any, Dict, List, Optional
import omegaconf
from pydantic import BaseModel, validator
from functools import wraps
def _callable_once(func):
    """Decorator: allow the wrapped method to be called at most once per instance."""
@wraps(func)
def wrapper(self, *args, **kwargs):
method_called_flag = f"_called_once_{func.__name__}"
if getattr(self, method_called_flag, False):
raise RuntimeError(f"{func.__name__} can only be called once.")
setattr(self, method_called_flag, True)
return func(self, *args, **kwargs)
return wrapper
class OffloadCleanCacheWrapperParam(BaseModel):
module: Any
method_name: str
diff_mem_gb_thre: float
class OffloadParam(BaseModel):
offload_module: Any
cpu_mem_gb: float
pre_copy_step: Optional[int] = None
clean_cache_after_forward: Optional[bool] = None
dtype: Optional[str] = None
offload_layer_dict: Dict[str, int] = {}
ignore_layer_list: List[str] = []
clean_cache_wrapper: Optional[OffloadCleanCacheWrapperParam] = None
debug: Optional[bool] = None
@validator('dtype')
def parse_dtype(cls, value):
if value is None:
return None
dtype_map = {
'torch.float16': torch.float16,
'torch.float32': torch.float32,
'torch.float64': torch.float64,
'torch.int64': torch.int64,
}
if value not in dtype_map:
raise ValueError(f"Unsupported dtype: {value}")
return dtype_map[value]
    def init_param_dict(self):
        """Keyword arguments for OffloadProfiler.__init__ (only fields that are set)."""
        param_dict = {}
        param_dict['cpu_mem_gb'] = self.cpu_mem_gb
if self.pre_copy_step is not None:
param_dict['pre_copy_step'] = self.pre_copy_step
if self.clean_cache_after_forward is not None:
param_dict['clean_cache_after_forward'] = self.clean_cache_after_forward
if self.debug is not None:
param_dict['debug'] = self.debug
return param_dict
    def offload_layer_param_dict(self):
        """Keyword arguments for OffloadProfiler.offload_layer."""
        param_dict = {}
param_dict['module'] = self.offload_module
param_dict['offload_layer_dict'] = self.offload_layer_dict
param_dict['ignore_layer_list'] = self.ignore_layer_list
param_dict['dtype'] = self.dtype
return param_dict
    def clean_cache_param_dict(self):
        """Keyword arguments for OffloadProfiler.clean_cache_wrapper."""
        param_dict = {}
if self.clean_cache_wrapper is not None:
param_dict['module'] = self.clean_cache_wrapper.module
param_dict['method_name'] = self.clean_cache_wrapper.method_name
param_dict['diff_mem_gb_thre'] = self.clean_cache_wrapper.diff_mem_gb_thre
return param_dict
@staticmethod
def recursive_print(model, indent=0):
for field_name, field_info in model.__fields__.items():
field_value = getattr(model, field_name)
print(" " * indent + f"{field_name}:")
if issubclass(type(field_value), BaseModel):
print(" " * (indent + 2) + f"--- Nested model: {field_value.__class__.__name__}")
OffloadParam.recursive_print(field_value, indent + 4)
else:
print(" " * (indent + 2) + f"class: {field_value.__class__.__name__}")
                # nn.Module reprs are huge; print only the class name for them
                if not isinstance(field_value, torch.nn.Module):
                    print(" " * (indent + 2) + f"value: {field_value}")
def show(self):
print("-"*20 + "[OffloadParam]" + "-"*20)
OffloadParam.recursive_print(self)
print("-"*40)
class OffloadParamParse:
    """Namespace for building an OffloadParam from an omegaconf config."""
@staticmethod
def _get_model(root_model: torch.nn.Module, model_dir: str):
        assert(model_dir.startswith("self")), f"model_dir {model_dir} must start with `self`"
model = root_model
for layer in model_dir.split('.'):
if layer == "self":
continue
            assert(hasattr(model, layer)), f"model has no attribute [{layer}]!"
model = getattr(model, layer)
return model
@staticmethod
    def parse_config(root_model: torch.nn.Module, cfg: omegaconf.DictConfig) -> OffloadParam:
        assert(hasattr(cfg, "offload_module") and hasattr(cfg, "cpu_mem_gb") and hasattr(cfg, "dtype")), \
            "config must define offload_module, cpu_mem_gb and dtype"
offload_module = OffloadParamParse._get_model(root_model, cfg.offload_module)
cpu_mem_gb = cfg.cpu_mem_gb
dtype = cfg.dtype
pre_copy_step = cfg.pre_copy_step \
if hasattr(cfg, "pre_copy_step") else None
clean_cache_after_forward = cfg.clean_cache_after_forward \
if hasattr(cfg, "clean_cache_after_forward") else None
offload_layer_dict = {k: v for k, v in cfg.offload_layer_dict.items()} \
if hasattr(cfg, "offload_layer_dict") else {}
ignore_layer_list = cfg.ignore_layer_list \
if hasattr(cfg, "ignore_layer_list") else []
debug = cfg.debug if hasattr(cfg, "debug") else None
clean_cache_wrapper = None
if hasattr(cfg, "clean_cache_wrapper"):
clean_cache_cfg = cfg.clean_cache_wrapper
cc_module = OffloadParamParse._get_model(root_model, clean_cache_cfg.module)
cc_method_name = clean_cache_cfg.method_name
diff_mem_gb_thre = clean_cache_cfg.diff_mem_gb_thre
clean_cache_wrapper = OffloadCleanCacheWrapperParam(
module=cc_module,
method_name=cc_method_name,
diff_mem_gb_thre=diff_mem_gb_thre)
return OffloadParam(
offload_module=offload_module,
cpu_mem_gb=cpu_mem_gb,
pre_copy_step=pre_copy_step,
clean_cache_after_forward=clean_cache_after_forward,
dtype=dtype,
offload_layer_dict=offload_layer_dict,
ignore_layer_list=ignore_layer_list,
clean_cache_wrapper=clean_cache_wrapper,
debug=debug
)
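# Expected config shape for OffloadParamParse.parse_config (a sketch; the keys
# are exactly those read above, the values and module paths are illustrative):
#
#   offload_module: self.cfm_wrapper          # dotted attribute path from root_model
#   cpu_mem_gb: 8.0                           # CPU budget for offloaded weights
#   dtype: torch.float16                      # see OffloadParam.parse_dtype
#   pre_copy_step: 2                          # optional
#   clean_cache_after_forward: true           # optional
#   offload_layer_dict: {estimator: 3}        # optional, layer name -> offload depth
#   ignore_layer_list: [estimator.adaln_single]   # optional, prefixes to skip
#   debug: false                              # optional
#   clean_cache_wrapper:                      # optional
#     module: self.cfm_wrapper
#     method_name: forward
#     diff_mem_gb_thre: 1.0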
class LayerParamStruct:
    """Reference-counted GPU state dict staged for one offloaded layer."""
    def __init__(self):
        self.count = 0
        self.device_state = None
class OffloadProfiler:
    """Streams offloaded layer weights from CPU to GPU just ahead of execution,
    using a background copy thread and a small ring buffer of staged state dicts."""
def __init__(self, device_index=0, cpu_mem_gb=-1, pre_copy_step=1, clean_cache_after_forward=False, debug=False):
self.clean_cache_after_forward = clean_cache_after_forward
self.cpu_mem_gb = cpu_mem_gb
self.cpu_mem_b_count = 0
self.device_index = device_index
self.execution_order = []
self.execution_order_idx = {}
        self.pin_memory = False
        # probe whether pinned (page-locked) host memory is usable on this system
        test_data = torch.rand(1, 1, device='cpu')
        pin_data = test_data.pin_memory()
        self.pin_memory = pin_data.is_pinned()
        print(f"pin:{self.pin_memory}")
self.copy_stream = torch.cuda.Stream()
self.copy_queue = queue.Queue()
        self.layer_param: Dict[str, LayerParamStruct] = {}
self.model_map = {}
self.stop_flag = False
self.copy_condition = threading.Condition()
self.queue_condition = threading.Condition()
self.mem_line_b = 0
self.copy_thread = threading.Thread(target=self._copy_thread_fun)
self.copy_thread.daemon = True
self.copy_thread.start()
self.cur_copy_idx = 0
self.execute_over = False
self.pre_copy_step = pre_copy_step
        # ring buffer of staged GPU state dicts: pre_copy_step prefetch slots,
        # plus one for the running layer and one spare
        self.tmp_state_list = [None] * (pre_copy_step + 2)
        self.tmp_state_idx = 0
self.debug = debug
    def stop(self):
        """Stop the background copy thread and release cached state."""
self.stop_flag = True
with self.queue_condition:
self.queue_condition.notify()
self.copy_thread.join()
del self.layer_param
del self.model_map
del self.copy_stream
def _copy_thread_fun(self):
        while not self.stop_flag:
            layer_name = "--"
            with self.queue_condition:
                while self.copy_queue.qsize() == 0 and not self.stop_flag:
                    self.queue_condition.wait()
                if self.stop_flag:
                    break
                layer_name = self.copy_queue.get()
with torch.cuda.stream(self.copy_stream):
if layer_name in self.model_map:
model = self.model_map[layer_name]
self.tmp_state_list[self.tmp_state_idx] = {
k: v.to(torch.device(f"cuda:{self.device_index}"), non_blocking=False)
for k, v in model.state_dict().items()
}
self.copy_stream.synchronize()
device_state = self.tmp_state_list[self.tmp_state_idx]
self.tmp_state_idx = (self.tmp_state_idx + 1) % len(self.tmp_state_list)
with self.copy_condition:
if layer_name in self.layer_param:
self.layer_param[layer_name].count += 1
else:
self.layer_param[layer_name] = LayerParamStruct()
self.layer_param[layer_name].count = 1
self.layer_param[layer_name].device_state = device_state
self.copy_condition.notify()
else:
print(f"get model error! {layer_name}")
print("copy thread stop..")
def _get_new_step_copy_begin_end(self, tag_name):
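        # Computes the [copy_begin, copy_end) window of execution-order indices
        # to prefetch. Worked example (hypothetical numbers): with 10 recorded
        # layers, pre_copy_step=2, cur_copy_idx=3 and the current layer at
        # index 3, copy_end = 3 + 2 + 1 = 6, so indices 3..5 get queued and
        # cur_copy_idx advances to 6. If the window is stale or negative (the
        # execution order jumped), cur_copy_idx is reset to the current index
        # and the window is recomputed.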
pre_copy_step = self.pre_copy_step
pre_copy_step = min(pre_copy_step, len(self.execution_order) // 2)
cur_exe_idx = self.execution_order_idx[tag_name]
copy_begin = self.cur_copy_idx
copy_end = cur_exe_idx + pre_copy_step + 1
if copy_end - copy_begin > len(self.execution_order):
copy_end %= len(self.execution_order)
if copy_end - copy_begin > pre_copy_step + 1 or copy_end - copy_begin < 0:
# jump
self.cur_copy_idx = cur_exe_idx
copy_begin, copy_end = self._get_new_step_copy_begin_end(tag_name=tag_name)
return copy_begin, copy_end
def make_forward_wrapper(self, module, tag_name, ignore_layer_list=[]):
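        # Two phases: the first call for a layer temporarily restores the
        # original forward (so functional_call does not recurse), records the
        # layer in the execution order and enqueues its own H2D copy; later
        # calls prefetch the next `pre_copy_step` layers instead. Either way
        # the wrapper blocks until this layer's GPU state dict is staged, then
        # runs it through torch.func.functional_call so the CPU-resident
        # parameters are never mutated.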
original_forward = module.forward
layer_param_size = 0
for name, param in module.named_parameters():
layer_param_size += param.data.numel() * param.data.element_size() / 1024 / 1024 #MB
        target_cpu_mem_b = self.cpu_mem_gb * 1024 * 1024 * 1024
offload = False
for name, param in module.named_parameters():
p_name = f"{tag_name}.{name}" if tag_name else name
            # a bare `continue` inside an inner loop cannot skip the outer
            # loop's param, so test the ignore list with any() instead
            if any(p_name.startswith(i_layer) for i_layer in ignore_layer_list):
                if self.debug:
                    print(f"ignore layer param: {p_name}")
                continue
            if target_cpu_mem_b >= 0 and self.cpu_mem_b_count >= target_cpu_mem_b:
break
cpu_data = torch.empty_strided(size=param.data.size(),
stride=param.data.stride(),
dtype=param.data.dtype,
layout=param.data.layout,
device='cpu',
pin_memory=self.pin_memory)
cpu_data.copy_(param.data)
param.data = cpu_data
param_size = param.data.numel() * param.data.element_size()
self.cpu_mem_b_count += param_size
offload = True
if self.debug:
print(f"layer: {tag_name}, type: {module.__class__.__name__}, size(MB): {layer_param_size}, offload: {offload}, sum_offload_size(MB): {self.cpu_mem_b_count/1024/1024}")
if offload:
copy_condition = self.copy_condition
queue_condition = self.queue_condition
copy_queue = self.copy_queue
layer_param = self.layer_param
def forward_wrapper(*args, **kwargs):
module.forward = original_forward
                execute_over = tag_name in self.execution_order_idx
                if not execute_over:
self.model_map[tag_name] = module
self.execution_order.append(tag_name)
self.execution_order_idx[tag_name] = len(self.execution_order) - 1
copy_queue.put(tag_name)
with queue_condition:
queue_condition.notify()
else:
copy_begin, copy_end = self._get_new_step_copy_begin_end(tag_name=tag_name)
if copy_end > copy_begin:
for idx in range(copy_begin, copy_end):
idx = idx % len(self.execution_order)
copy_tag_name = self.execution_order[idx]
copy_queue.put(copy_tag_name)
with queue_condition:
queue_condition.notify()
self.cur_copy_idx = copy_end % len(self.execution_order)
run_state = None
with self.copy_condition:
while tag_name not in self.layer_param:
copy_condition.wait()
run_state = self.layer_param[tag_name].device_state
self.layer_param[tag_name].count -= 1
module.eval()
with torch.no_grad():
output = functional_call(module, run_state, args=args, kwargs=kwargs)
with self.copy_condition:
if self.layer_param[tag_name].count == 0:
del self.layer_param[tag_name]
                # empty_cache() is costly, so it only runs once reserved memory
                # exceeds a moving threshold (mem_line_b). If a cleanup frees
                # more than 1 GiB, the threshold is set midway between the new
                # and old reserved sizes; otherwise just above the old size.
                diff_mem_b_thre = 1 * (1024 ** 3)
                if self.clean_cache_after_forward:
reserved = torch.cuda.memory_reserved()
if reserved > self.mem_line_b:
torch.cuda.empty_cache()
cur_reserved = torch.cuda.memory_reserved()
diff_mem = reserved - cur_reserved
if diff_mem > diff_mem_b_thre:
self.mem_line_b = cur_reserved + (reserved - cur_reserved) / 2 + 10
else:
self.mem_line_b = reserved + 10
if self.debug:
print(f"child mem line update, clean cache:{reserved/1024/1024}, cur mem: {cur_reserved/1024/1024} new limit: {self.mem_line_b / 1024 / 1024}, child name: {tag_name}")
module.forward = forward_wrapper
return output
module.forward = forward_wrapper
torch.cuda.empty_cache()
return module
def reset_empty_cache_mem_line(self):
self.mem_line_b = 0
torch.cuda.empty_cache()
def clean_cache_wrapper(self, module, method_name='', diff_mem_gb_thre=1):
if not hasattr(module, method_name) or not callable(getattr(module, method_name)):
print(f"no this method {method_name}")
return module
original_fun = getattr(module, method_name)
diff_mem_b_thre = diff_mem_gb_thre * (1024 ** 3)
self.reset_empty_cache_mem_line()
def clean_wrapper(*args, **kwargs):
setattr(module, method_name, original_fun)
output = original_fun(*args, **kwargs)
reserved = torch.cuda.memory_reserved()
if reserved > self.mem_line_b:
torch.cuda.empty_cache()
cur_reserved = torch.cuda.memory_reserved()
diff_mem = reserved - cur_reserved
if diff_mem > diff_mem_b_thre:
self.mem_line_b = cur_reserved + (reserved - cur_reserved) / 2 + 10
else:
self.mem_line_b = reserved + 10
if self.debug:
print(f"mem line update, clean cache:{reserved/1024/1024}, cur mem: {cur_reserved/1024/1024} new limit: {self.mem_line_b / 1024 / 1024}")
setattr(module, method_name, clean_wrapper)
return output
setattr(module, method_name, clean_wrapper)
return module
@_callable_once
    def offload_layer(self, module, offload_layer_dict={}, ignore_layer_list=[], dtype: Optional[torch.dtype] = None):
return self._offload_layer(
module=module,
tag="",
offload_layer_dict=offload_layer_dict,
ignore_layer_list=ignore_layer_list,
dtype=dtype
)
    def _offload_layer(self, module, tag="", offload_layer_dict={}, ignore_layer_list=[], dtype: Optional[torch.dtype] = None):
"""
Offload specific layers of a PyTorch model to a specified depth.
A model can only be offloaded once.
Args:
module (torch.nn.Module):
The PyTorch model containing the layers to offload. This is the model that will be modified in place.
tag (str, optional):
A string identifier for the model.
Default is an empty string.
offload_layer_dict (dict, optional):
A dictionary where keys are layer names and values represent the depth at which the offloading should occur.
For example,
```offload_layer_dict = {'cfm_wrapper': 5, 'hubert': 4}``` means that the `cfm_wrapper` layer should
be offloaded at depth 5, and the `hubert` layer should be offloaded at depth 4.
Default is an empty dictionary.
ignore_layer_list (list, optional):
A list of layer names or parameter identifiers to be ignored during the offloading process.
Layers in this list will not be offloaded, even if they are present in the `offload_layer_dict`.
For example,
```ignore_layer_list = ['cfm_wrapper.estimator.h', 'cfm_wrapper.estimator.adaln_single']```
                means that layers starting with `cfm_wrapper.estimator.h` or `cfm_wrapper.estimator.adaln_single` will not be offloaded.
Default is an empty list.
dtype (torch.dtype, optional):
The data type (e.g., `torch.float16`, `torch.float32`) to which the offloaded layers should be converted.
If `None`, the data type of the layers will remain unchanged. Default is `None`.
        Returns:
            torch.nn.Module: the modified module (also modified in place).
"""
for p in module._parameters.values():
if p is not None:
p.data = p.data.to(torch.device(f"cuda:{self.device_index}"))
if dtype is not None:
p.data = p.data.to(dtype)
for b in module._buffers.values():
if b is not None:
b.data = b.data.to(torch.device(f"cuda:{self.device_index}"))
if dtype is not None:
b.data = b.data.to(dtype)
for attr_name, attr in module.__dict__.items():
if isinstance(attr, torch.Tensor) and not attr_name.startswith('_'):
attr.data = attr.data.to(torch.device(f"cuda:{self.device_index}"))
if dtype is not None:
attr.data = attr.data.to(dtype)
for name, child in module.named_children():
current_tag = f"{tag}.{name}" if tag else name
child = child.to(torch.device(f"cuda:{self.device_index}"))
if dtype is not None:
child = child.to(dtype)
torch.cuda.empty_cache()
setattr(module, name, child)
pre_name = current_tag.split('.')[0]
if pre_name not in offload_layer_dict:
param_size = 0
for p in child.parameters():
param_size += p.data.numel() * p.data.element_size()
param_size = param_size / 1024 / 1024
if self.debug:
print(f"not offload layer {current_tag}, size: {param_size}MB")
continue
has_children = any(child.named_children())
layer_count = current_tag.count('.') + 1
layer_deep = offload_layer_dict[pre_name]
if layer_count >= layer_deep:
has_children = False
if has_children:
self._offload_layer(module=child,
tag=current_tag,
offload_layer_dict=offload_layer_dict,
ignore_layer_list=ignore_layer_list,
dtype=dtype)
continue
ignore = False
for i_layer in ignore_layer_list:
if current_tag.startswith(i_layer):
ignore = True
if self.debug:
print(f"ignore layer offload: {current_tag}")
break
if hasattr(child, "forward") and not ignore:
child = self.make_forward_wrapper(
child, current_tag, ignore_layer_list=ignore_layer_list
)
return module
def get_execution_order(self):
return self.execution_order
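# ---------------------------------------------------------------------------
# Usage sketch (a minimal, hypothetical example; `root_model` and the
# `self.cfm_wrapper` path are placeholders, the config keys mirror
# OffloadParamParse.parse_config):
#
#   cfg = omegaconf.OmegaConf.create({
#       "offload_module": "self.cfm_wrapper",
#       "cpu_mem_gb": 8.0,
#       "dtype": "torch.float16",
#       "pre_copy_step": 2,
#       "offload_layer_dict": {"estimator": 3},
#   })
#   offload_param = OffloadParamParse.parse_config(root_model, cfg)
#   offload_param.show()
#   profiler = OffloadProfiler(device_index=0, **offload_param.init_param_dict())
#   profiler.offload_layer(**offload_param.offload_layer_param_dict())
#   if offload_param.clean_cache_wrapper is not None:
#       profiler.clean_cache_wrapper(**offload_param.clean_cache_param_dict())
#   ...  # run inference; wrapped layers stream their weights on demand
#   profiler.stop()
# ---------------------------------------------------------------------------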