from .compress import get_module_name
from .constants import *
from .helper import recursive_getattr
from deepspeed.utils import logger
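
# `compression_config` is expected to follow DeepSpeed's compression config
# schema. This sketch is illustrative only; the authoritative key names come
# from .constants (imported above). Roughly, each technique section looks like:
#
#   "weight_quantization": {
#       "shared_parameters": {"enabled": True, "schedule_offset": 100, ...},
#       "different_groups": {
#           "wq_group_1": {"modules": ["attention.self"], "params": {...}},
#       },
#   }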
|
class compression_scheduler():
    '''
    Schedules the different compression methods: tracks the global training
    step and, once a technique's schedule offset is reached, flips the
    corresponding `*_enabled` flag on every module in its configured groups.
    '''
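
    # A minimal usage sketch (assumes `model` is a torch.nn.Module whose
    # compressible submodules carry the `*_enabled` flags toggled below, and
    # that `config` follows the shape sketched at the top of this file):
    #
    #   scheduler = compression_scheduler(model, config)
    #   for batch in loader:
    #       loss = train_step(model, batch)
    #       scheduler.step()  # advance the step count; enable techniques on schedule
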
    def __init__(self, model, compression_config):
        self.model = model
        self.compression_config = compression_config
        self.make_init()
        self.training_steps = 0
        self.weight_quantization_enabled = False
        # Tracks whether the one-time "enabled at step N" message has already
        # been logged for each technique.
        self.verbose = {
            WEIGHT_QUANTIZATION: False,
            ACTIVATION_QUANTIZATION: False,
            SPARSE_PRUNING: False,
            HEAD_PRUNING: False,
            ROW_PRUNING: False,
            CHANNEL_PRUNING: False
        }

    def make_init(self):
        # Build per-technique bookkeeping from the config: whether the
        # technique is enabled, its shared parameters, and the resolved
        # module names for each of its groups.
        self.different_compression_methods = {}
        for method, method_content in self.compression_config.items():
            if LAYER_REDUCTION in method:
                continue
            self.different_compression_methods[method] = {
                TECHNIQUE_ENABLED: False,
                SHARED_PARAMETERS: None,
                DIFFERENT_GROUPS: []
            }
            exist_module_name = set()
            shared_parameters = method_content[SHARED_PARAMETERS]
            self.different_compression_methods[method][TECHNIQUE_ENABLED] = shared_parameters[TECHNIQUE_ENABLED]
            self.different_compression_methods[method][SHARED_PARAMETERS] = shared_parameters

            for group_name, method_parameters in method_content[DIFFERENT_GROUPS].items():
                module_name_list = []
                for key_word in method_parameters[DIFFERENT_GROUPS_MODULE_SCOPE]:
                    module_name, exist_module_name = get_module_name(group_name,
                                                                     self.model,
                                                                     key_word,
                                                                     exist_module_name,
                                                                     verbose=False)
                    module_name_list.extend(module_name)
                if module_name_list:
                    # pop() on the copy returns the group's 'params' entry
                    # without mutating the original config.
                    self.different_compression_methods[method][DIFFERENT_GROUPS].append(
                        [group_name, module_name_list,
                         method_parameters.copy().pop('params')])

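    # For illustration (hypothetical group and key names), after make_init()
    # the bookkeeping for one technique might look like:
    #
    #   self.different_compression_methods["weight_quantization"] == {
    #       TECHNIQUE_ENABLED: True,
    #       SHARED_PARAMETERS: {...},
    #       DIFFERENT_GROUPS: [
    #           ["wq_group_1", ["encoder.layer.0.attention.self.query", ...], {...}],
    #       ],
    #   }
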
    def check_weight_quantization(self):
        # Enable weight quantization on every module in the configured groups
        # once the schedule offset has been reached.
        wq = self.different_compression_methods[WEIGHT_QUANTIZATION]
        if not wq[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = wq[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in wq[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.weight_quantization_enabled = True

                if not self.verbose[WEIGHT_QUANTIZATION]:
                    logger.info(f'Weight quantization is enabled at step {self.training_steps}')
                    self.weight_quantization_enabled = True
                    self.verbose[WEIGHT_QUANTIZATION] = True

    def check_activation_quantization(self):
        aq = self.different_compression_methods[ACTIVATION_QUANTIZATION]
        if not aq[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = aq[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in aq[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.activation_quantization_enabled = True
                if not self.verbose[ACTIVATION_QUANTIZATION]:
                    logger.info(f'Activation quantization is enabled at step {self.training_steps}')
                    self.verbose[ACTIVATION_QUANTIZATION] = True

    def check_sparse_pruning(self):
        # Unlike the other techniques, sparse pruning is only applied inside
        # the [schedule_offset, schedule_offset_end] step window.
        sp = self.different_compression_methods[SPARSE_PRUNING]
        if not sp[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = sp[SHARED_PARAMETERS]
            if shared_parameters[TECHNIQUE_SCHEDULE_OFFSET] <= self.training_steps <= shared_parameters[
                    TECHNIQUE_SCHEDULE_OFFSET_END]:
                for group_name, module_name_list, method_parameters in sp[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.sparse_pruning_enabled = True
                if not self.verbose[SPARSE_PRUNING]:
                    logger.info(f'Sparse pruning is enabled at step {self.training_steps}')
                    self.verbose[SPARSE_PRUNING] = True

    def check_head_pruning(self):
        hp = self.different_compression_methods[HEAD_PRUNING]
        if not hp[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = hp[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in hp[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.head_pruning_enabled = True
                if not self.verbose[HEAD_PRUNING]:
                    logger.info(f'Head pruning is enabled at step {self.training_steps}')
                    self.verbose[HEAD_PRUNING] = True

    def check_row_pruning(self):
        rp = self.different_compression_methods[ROW_PRUNING]
        if not rp[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = rp[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in rp[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.row_pruning_enabled = True
                if not self.verbose[ROW_PRUNING]:
                    logger.info(f'Row pruning is enabled at step {self.training_steps}')
                    self.verbose[ROW_PRUNING] = True

    def check_channel_pruning(self):
        cp = self.different_compression_methods[CHANNEL_PRUNING]
        if not cp[TECHNIQUE_ENABLED]:
            return
        else:
            shared_parameters = cp[SHARED_PARAMETERS]
            if self.training_steps >= shared_parameters[TECHNIQUE_SCHEDULE_OFFSET]:
                for group_name, module_name_list, method_parameters in cp[DIFFERENT_GROUPS]:
                    for module_name in module_name_list:
                        module = recursive_getattr(self.model, module_name)
                        module.channel_pruning_enabled = True
                if not self.verbose[CHANNEL_PRUNING]:
                    logger.info(f'Channel pruning is enabled at step {self.training_steps}')
                    self.verbose[CHANNEL_PRUNING] = True

    def check_all_modules(self):
        # Check the schedule for every technique; each check is a no-op if the
        # technique is disabled or its schedule offset has not been reached.
        self.check_weight_quantization()
        self.check_activation_quantization()
        self.check_sparse_pruning()
        self.check_head_pruning()
        self.check_row_pruning()
        self.check_channel_pruning()

    def step(self, step_zero_check=False):
        # With step_zero_check=True the schedules are checked without
        # advancing the step counter (e.g. to apply offset-0 techniques
        # before the first training step).
        if not step_zero_check:
            self.training_steps += 1
        self.check_all_modules()
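
# Typical driver pattern (hypothetical; the training loop that owns the
# scheduler is responsible for calling step()):
#
#   scheduler.step(step_zero_check=True)  # apply any offset-0 schedules up front
#   for step, batch in enumerate(loader):
#       ...
#       scheduler.step()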