| from typing import Any | |
| import torch | |
| from .attention import ATTN_CLASS_REGISTRY | |
| from .blocks import MPTBlock | |
| from .ffn import FFN_CLASS_REGISTRY | |
| from .norm import NORM_CLASS_REGISTRY | |
def pass_on_block_idx(parent: torch.nn.Module) -> None:
    """Propagate ``block_idx`` / ``max_block_idx`` from ``parent`` to all submodules.

    Does nothing unless ``parent`` defines both attributes. Otherwise every
    module nested anywhere below ``parent`` (children, grandchildren, ...) is
    given ``parent``'s values, overwriting any values it already carried.

    Args:
        parent: Module whose block-index attributes should be pushed down.
    """
    if not hasattr(parent, 'block_idx') or not hasattr(parent, 'max_block_idx'):
        return
    # ``modules()`` walks the whole subtree (parent included), so no explicit
    # recursion or "has children" test is needed.
    for module in parent.modules():
        if module is parent:
            continue
        module.block_idx = parent.block_idx
        module.max_block_idx = parent.max_block_idx
def get_act_ckpt_module(mod_name: str) -> Any:
    """Translate an activation-checkpointing target name into a module class.

    The name ``mptblock`` (matched case-insensitively) selects ``MPTBlock``.
    Any other name is looked up, in order, in the attention, FFN and norm
    class registries; the first registry containing it wins.

    Raises:
        ValueError: if ``mod_name`` is not found anywhere.
    """
    if mod_name.lower() == 'mptblock':
        return MPTBlock
    registries = (ATTN_CLASS_REGISTRY, FFN_CLASS_REGISTRY, NORM_CLASS_REGISTRY)
    for registry in registries:
        if mod_name in registry:
            return registry[mod_name]
    known_names = [name for registry in registries for name in registry.keys()]
    known_names.append('MPTBlock')
    msg = ', '.join(known_names)
    raise ValueError(f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.')
def _parse_block_count(count_str: str, ele: str) -> int:
    """Parse the numeric suffix of a target_blocks element or raise ValueError."""
    # ``isdecimal`` accepts exactly the digit strings ``int()`` can parse
    # (``isdigit`` would also accept e.g. superscripts, which ``int`` rejects).
    if not count_str.isdecimal():
        raise ValueError(f'Invalid target_blocks element {ele}')
    return int(count_str)


def parse_ele_str(ele: str, max_block_idx: int) -> list:
    """Parse a string in target_blocks and return a list of block ids to add.

    Supported formats are: first-n, middle-m, last-k, range-i-j which correspond
    to the first n, the middle m, the last k, and the range [i, j). All results
    are clipped to the valid ids ``0..max_block_idx``.

    Args:
        ele: One target_blocks element, e.g. ``'first-2'`` or ``'range-1-4'``.
        max_block_idx: Largest valid block id.

    Returns:
        The (possibly empty) list of block ids selected by ``ele``.

    Raises:
        ValueError: if ``ele`` is not in one of the supported formats. Explicit
            checks are used instead of ``assert`` so validation still happens
            under ``python -O``.
    """
    if ele.startswith('first-'):
        num = _parse_block_count(ele[6:], ele)
        return list(range(min(num, max_block_idx + 1)))
    if ele.startswith('last-'):
        num = _parse_block_count(ele[5:], ele)
        return list(range(max(max_block_idx - num + 1, 0), max_block_idx + 1))
    if ele.startswith('middle-'):
        num = _parse_block_count(ele[7:], ele)
        start = max(max_block_idx // 2 - num // 2, 0)
        end = min(start + num, max_block_idx + 1)
        return list(range(start, end))
    if ele.startswith('range-'):
        bounds = ele[6:].split('-')
        # Exactly two bounds; otherwise e.g. 'range-1-2-3' would be silently
        # truncated to 'range-1-2'.
        if len(bounds) != 2:
            raise ValueError(f'Invalid target_blocks element {ele}')
        try:
            start, end = int(bounds[0]), int(bounds[1])
        except ValueError as err:
            raise ValueError(f'Invalid target_blocks element {ele}') from err
        start = max(start, 0)
        end = min(end, max_block_idx + 1)
        return list(range(start, end))
    raise ValueError(f'Invalid target_blocks element {ele}')
def get_target_block_list(target_blocks: Any, max_block_idx: int) -> list:
    """Parse the user input and return a sorted list of unique block ids.

    Accepted forms of ``target_blocks``:
      * an int ``n``: blocks ``0..n-1``;
      * a list mixing ints (taken as-is) and element strings understood by
        ``parse_ele_str`` (``first-n``, ``middle-m``, ``last-k``, ``range-i-j``);
      * a comma-separated string of such elements (spaces are ignored).

    Args:
        target_blocks: The user-supplied block selection.
        max_block_idx: Largest valid block id, forwarded to ``parse_ele_str``.

    Returns:
        Deduplicated block ids in ascending order.

    Raises:
        ValueError: if ``target_blocks`` (or a list element) has an
            unsupported type or format.
    """
    candidate_block_ids = []
    if isinstance(target_blocks, int):
        candidate_block_ids = list(range(target_blocks))
    elif isinstance(target_blocks, list):
        for ele in target_blocks:
            if isinstance(ele, int):
                candidate_block_ids.append(ele)
            elif isinstance(ele, str):
                candidate_block_ids.extend(parse_ele_str(ele, max_block_idx))
            else:
                raise ValueError(f'target_blocks must be a list of integers or "first-n", "middle-m", "last-k", or "range-i-j" where n, m, k, i, j are integers, but got {target_blocks}')
    elif isinstance(target_blocks, str):
        for ele in target_blocks.replace(' ', '').split(','):
            candidate_block_ids.extend(parse_ele_str(ele, max_block_idx))
    else:
        raise ValueError(f'target_blocks must be either a single integer, or a list of integers, or a comma separated string made of "first-n", "last-m", "middle-k", "range-i-j", or a list of mixed integers and before-mentioned strings, but got {type(target_blocks)}')
    # Deduplicate, and sort so the order is deterministic rather than
    # depending on set iteration order.
    return sorted(set(candidate_block_ids))
def check_mapping_blocks_overlap(mapping: dict, max_block_idx: int) -> None:
    """Verify that each block id is claimed by at most one mapping entry.

    ``mapping`` maps a key to either a list of block ids or ``-1``, where
    ``-1`` stands for every block ``0..max_block_idx``. Ids outside that
    range are skipped.

    Raises:
        ValueError: on the first block id that is claimed a second time.
    """
    owner_of = [None] * (max_block_idx + 1)
    for key, block_ids in mapping.items():
        ids_to_check = range(max_block_idx + 1) if block_ids == -1 else block_ids
        for block_id in ids_to_check:
            if not 0 <= block_id <= max_block_idx:
                continue
            previous_owner = owner_of[block_id]
            if previous_owner is not None:
                raise ValueError(f'Block {block_id} is assigned to both {key} and {previous_owner}. Each block can only have one granularity of activation checkpointing. Make sure the target_blocks in activation_checkpointing_target do not overlap. For more details, refer to the docs of activation_checkpointing_fn.')
            owner_of[block_id] = key
def build_act_ckpt_mod_to_blocks(act_ckpt_target: Any, top_module: Any, max_block_idx: int) -> dict:
    """Build a mapping from checkpointing module type to the block ids it covers.

    ``-1`` as a value means "all blocks".

    * ``None`` or ``[]``: ``{top_module: -1}``.
    * a string or a list of strings: each name is resolved with
      ``get_act_ckpt_module`` and mapped to ``-1``.
    * a dict ``{name: target_blocks}``: each name is resolved with
      ``get_act_ckpt_module`` and its value is parsed with
      ``get_target_block_list``.

    Raises:
        ValueError: for any other type of ``act_ckpt_target``.
    """
    if act_ckpt_target is None or act_ckpt_target == []:
        return {top_module: -1}
    if isinstance(act_ckpt_target, str):
        return {get_act_ckpt_module(act_ckpt_target): -1}
    if isinstance(act_ckpt_target, list):
        return {get_act_ckpt_module(name): -1 for name in act_ckpt_target}
    if isinstance(act_ckpt_target, dict):
        mod_to_blocks = {}
        for name, blocks_spec in act_ckpt_target.items():
            mod_type = get_act_ckpt_module(name)
            mod_to_blocks[mod_type] = get_target_block_list(blocks_spec, max_block_idx)
        return mod_to_blocks
    raise ValueError(f'activation_checkpointing_target must be either a single string or a list or a dict, but got {type(act_ckpt_target)}')