# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Abstract base classes for TensorFlow model compression.
"""

import logging

import tensorflow as tf
assert tf.__version__.startswith('2'), 'NNI model compression only supports TensorFlow v2.x'

from . import default_layers

_logger = logging.getLogger(__name__)
class Compressor:
    """
    Common base class for all compressors.

    This class is designed for other base classes.
    Algorithms should inherit ``Pruner`` or ``Quantizer`` instead.

    Attributes
    ----------
    compressed_model : tf.keras.Model
        Compressed user model.
    wrappers : list of tf.keras.Model
        A wrapper is an instrumented TF ``Layer``, in ``Model`` format.

    Parameters
    ----------
    model : tf.keras.Model
        The user model to be compressed.
    config_list : list of JSON object
        User configuration. The format is detailed in tutorial.
    LayerWrapperClass : a class derive from Model
        The class used to instrument layers.

    Raises
    ------
    TypeError
        If ``model`` is not a ``tf.keras.Model``.
    """

    def __init__(self, model, config_list, LayerWrapperClass):
        # Explicit raise instead of `assert`: asserts are stripped under
        # `python -O`, which would silently accept an invalid model here.
        if not isinstance(model, tf.keras.Model):
            raise TypeError('model must be a tf.keras.Model, not %s' % type(model).__name__)
        self.validate_config(model, config_list)

        self._original_model = model
        self._config_list = config_list
        self._wrapper_class = LayerWrapperClass
        self._wrappers = {}  # key: id(layer), value: Wrapper(layer)

        self.compressed_model = self._instrument(model)
        self.wrappers = list(self._wrappers.values())

        if not self.wrappers:
            _logger.warning('Nothing is configured to compress, please check your model and config list')

    def set_wrappers_attribute(self, name, value):
        """
        Call ``setattr`` on all wrappers.

        Parameters
        ----------
        name : str
            Name of the attribute to set.
        value
            Value of the attribute.
        """
        for wrapper in self.wrappers:
            setattr(wrapper, name, value)

    def validate_config(self, model, config_list):
        """
        Compression algorithm should overload this function to validate configuration.

        The default implementation accepts every configuration.
        """
        pass

    def _instrument(self, layer):
        # Recursively instrument a layer tree: containers (Sequential / Model)
        # are traversed, while leaf layers matching a config block are wrapped.
        if isinstance(layer, tf.keras.Sequential):
            return self._instrument_sequential(layer)
        if isinstance(layer, tf.keras.Model):
            return self._instrument_model(layer)

        # a layer can be referenced in multiple attributes of a model,
        # but should only be instrumented once
        if id(layer) in self._wrappers:
            return self._wrappers[id(layer)]

        config = self._select_config(layer)
        if config is not None:
            wrapper = self._wrapper_class(layer, config, self)
            self._wrappers[id(layer)] = wrapper
            return wrapper

        return layer

    def _instrument_sequential(self, seq):
        # Rebuild the Sequential only when at least one child layer was replaced.
        layers = list(seq.layers)  # seq.layers is read-only property
        need_rebuild = False
        for i, layer in enumerate(layers):
            new_layer = self._instrument(layer)
            if new_layer is not layer:
                layers[i] = new_layer
                need_rebuild = True
        return tf.keras.Sequential(layers) if need_rebuild else seq

    def _instrument_model(self, model):
        # Replace wrapped layers in-place on the model's attributes, so
        # subclassed models keep working without being rebuilt.
        for key, value in list(model.__dict__.items()):  # avoid "dictionary keys changed during iteration"
            if isinstance(value, tf.keras.layers.Layer):
                new_layer = self._instrument(value)
                if new_layer is not value:
                    setattr(model, key, new_layer)
            elif isinstance(value, list):
                for i, item in enumerate(value):
                    if isinstance(item, tf.keras.layers.Layer):
                        value[i] = self._instrument(item)
        return model

    def _select_config(self, layer):
        # Find the last matching config block for given layer.
        # Returns None if the layer should not be compressed.
        layer_type = type(layer).__name__
        last_match = None
        for config in self._config_list:
            if 'op_types' in config:
                match = layer_type in config['op_types']
                # 'default' expands to the built-in list of weighted layer types.
                match_default = 'default' in config['op_types'] and layer_type in default_layers.weighted_modules
                if not match and not match_default:
                    continue
            if 'op_names' in config and layer.name not in config['op_names']:
                continue
            last_match = config
        if last_match is None or 'exclude' in last_match:
            return None
        return last_match
class Pruner(Compressor):
    """
    Base class for pruning algorithms.

    End users should use ``compress`` and callback APIs (WIP) to prune their models.

    The underlying model is instrumented as soon as the pruner object is created,
    so pre-train the model *before* constructing the pruner.
    The compressed model can only execute in eager mode.

    Algorithm developers should override ``calc_masks`` method to specify pruning strategy.

    Parameters
    ----------
    model : tf.keras.Model
        The user model to prune.
    config_list : list of JSON object
        User configuration. The format is detailed in tutorial.
    """

    def __init__(self, model, config_list):
        super().__init__(model, config_list, PrunerLayerWrapper)
        #self.callback = PrunerCallback(self)

    def compress(self):
        """
        Apply compression on a pre-trained model.

        If you want to prune the model during training, use callback API (WIP) instead.

        Returns
        -------
        tf.keras.Model
            The compressed model.
        """
        self._update_mask()
        return self.compressed_model

    def calc_masks(self, wrapper, **kwargs):
        """
        Abstract method to be overridden by algorithm. End users should ignore it.

        If the callback is set up, this method will be invoked at end of each training minibatch.
        If not, it will only be called when end user invokes ``compress``.

        Parameters
        ----------
        wrapper : PrunerLayerWrapper
            The instrumented layer.
        **kwargs
            Reserved for forward compatibility.

        Returns
        -------
        dict of (str, tf.Tensor), or None
            The key is weight ``Variable``'s name. The value is a mask ``Tensor`` of weight's shape and dtype.
            If a weight's key does not appear in the return value, that weight will not be pruned.
            Returning ``None`` means the mask is not changed since last time.
            Weight names are globally unique, e.g. `model/conv_1/kernel:0`.
        """
        # TODO: maybe it should be able to calc on weight-granularity, beside from layer-granularity
        raise NotImplementedError("Pruners must overload calc_masks()")

    def _update_mask(self):
        # Ask the algorithm for fresh masks; keep the old ones when it returns None.
        for idx, wrapper in enumerate(self.wrappers):
            new_masks = self.calc_masks(wrapper, wrapper_idx=idx)
            if new_masks is not None:
                wrapper.masks = new_masks
class PrunerLayerWrapper(tf.keras.Model):
    """
    Instrumented TF layer.

    Wrappers will be passed to pruner's ``calc_masks`` API,
    and the pruning algorithm should use wrapper's attributes to calculate masks.

    Once instrumented, underlying layer's weights will get **modified** by masks before forward pass.

    Attributes
    ----------
    layer : tf.keras.layers.Layer
        The original layer.
    config : JSON object
        Selected configuration. The format is detailed in tutorial.
    pruner : Pruner
        Bound pruner object.
    masks : dict of (str, tf.Tensor)
        Current masks. The key is weight's name and the value is mask tensor.
        Starts out empty (nothing pruned) and is subsequently replaced by the
        latest non-None return value of ``Pruner.calc_masks``.
        See ``Pruner.calc_masks`` for details.
    """

    def __init__(self, layer, config, pruner):
        super().__init__()
        self.layer = layer
        self.config = config
        self.pruner = pruner
        self.masks = {}
        _logger.info('Layer detected to compress: %s', self.layer.name)

    def call(self, *inputs):
        # Multiply each weight by its mask (if one exists), write the masked
        # values back into the layer, then delegate the forward pass to it.
        masked = []
        for weight in self.layer.weights:
            mask = self.masks.get(weight.name)
            masked.append(weight if mask is None else tf.math.multiply(weight, mask))
        # `.numpy()` only exists on eager tensors/variables; graph mode is unsupported.
        if masked and not hasattr(masked[0], 'numpy'):
            raise RuntimeError('NNI: Compressed model can only run in eager mode')
        self.layer.set_weights([w.numpy() for w in masked])
        return self.layer(*inputs)
# TODO: designed to replace `patch_optimizer`
#class PrunerCallback(tf.keras.callbacks.Callback):
#    def __init__(self, pruner):
#        super().__init__()
#        self._pruner = pruner
#
#    def on_train_batch_end(self, batch, logs=None):
#        self._pruner._update_mask()