File size: 6,256 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 |
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
from .base import *
from .features.meta_tensor import MetaTensorContainer
from .features.hybrid_engine import HybridEngineContainer
from deepspeed.model_implementations.transformers.ds_bloom import DeepSpeedBloomInference
from ..policy import TransformerPolicy
from ..policy import transformer_param_names
from ..policy import maybe_copy
from ..policy import maybe_get_lora
# Module-level registry of model classes associated with BLOOM support.
# It starts out holding only None; BLOOMLayerPolicy.__init__ adds
# transformers' BloomModel to it when that class can be imported.
# NOTE(review): presumably consulted elsewhere to test whether a model type is
# supported (and None seeded deliberately) — confirm against callers.
supported_models = {None}
class DS_BloomContainer(MetaTensorContainer, HybridEngineContainer, BaseTransformerContainer):
    """Container that builds a ``DeepSpeedBloomInference`` module for one BLOOM layer."""

    def __init__(self, **kwargs):
        """Initialise the container after validating the installed transformers version.

        Raises:
            RuntimeError: if the installed ``transformers`` is 4.44 or newer
                (BLOOM inference with DeepSpeed breaks at 4.44.0).
        """
        # Check transformers version, error if > 4.43.4 (breaks at 4.44.0)
        from importlib.metadata import version
        v_transformers = version('transformers')
        vers = v_transformers.split('.')
        major = int(vers[0])
        minor = int(vers[1])
        if major > 4 or (major == 4 and minor > 43):
            # Raise instead of calling sys.exit(): a library constructor must not
            # terminate the host process, and SystemExit bypasses the caller's
            # ``except Exception`` handlers.
            raise RuntimeError(
                f"Transformers version {v_transformers} exceeds version 4.43.4! After transformers version 4.43.4, BLOOM inference with DeepSpeed is no longer supported."
            )

        super().__init__(**kwargs)

        # All model specific things should be defined here instead of the base class.
        self.bigscience_bloom = True
        self.triangular_masking = False

    def create_module(self, config=None):
        """Create and return the ``DeepSpeedBloomInference`` module.

        Args:
            config: inference config to use; falls back to ``self.ds_model_config``
                when None.

        Returns:
            The new module, which is also stored on ``self.module``.
        """
        _config = config if config is not None else self.ds_model_config
        self.module = DeepSpeedBloomInference(_config, mp_group=self.mp_group)
        self.module.config.scale_attention = self.scale_attention
        self.module.config.invert_mask = False
        return self.module

    def attention_qkv_mp(self, mp_replace, reversed_dim=False):
        """Copy the fused QKV weight and bias into the module via ``mp_replace.copy``.

        ``reversed_dim`` is accepted but unused here; presumably kept so the
        signature matches the base container's hook — confirm.
        """
        self.module.attention.attn_qkvw = mp_replace.copy(self.module.attention.attn_qkvw, self.qkvw)
        self.module.attention.attn_qkvb = mp_replace.copy(self.module.attention.attn_qkvb, self.qkvb)

    def get_lora_matched_pair(self):
        """
        Necessary to implement for `HybridEngineContainer`

        Returns a list of (lora params, weight) pairs for, in order: the
        h->4h MLP projection, the 4h->h MLP projection, the fused QKV, and the
        attention output projection. The unpacking order must agree with the
        order used in ``set_lora_params`` (assuming ``get_lora_params`` returns
        ``self.lora_params`` — verify in the base class).
        """
        fc1_lora, fc2_lora, qkv_lora, out_lora = self.get_lora_params()
        ret = [(fc1_lora, self._h4h_w), (fc2_lora, self._4hh_w), (qkv_lora, self.qkvw), (out_lora, self.dense_w)]
        return ret

    def set_lora_params(self):
        """
        Necessary to implement for `HybridEngineContainer`

        Stores ``maybe_get_lora`` results for the four BLOOM linear layers in
        ``self.lora_params``, ordered h->4h, 4h->h, QKV, attention dense.
        """
        client = self.policy.client_module
        self.lora_params = [
            maybe_get_lora(p) for p in [
                client.mlp.dense_h_to_4h,
                client.mlp.dense_4h_to_h,
                client.self_attention.query_key_value,
                client.self_attention.dense,
            ]
        ]

    def load_params(self, module, sd, weight_quantizer, mp_replace, prefix):
        """Copy one BLOOM layer's parameters from state dict ``sd`` into ``module``.

        Each key in ``param_names`` (relative to ``prefix``) is paired by index
        with ``transformer_param_names``; the two sequences are assumed to share
        the same ordering. The index ranges below route each key to the
        sub-module that receives it.
        """
        param_names = (
            'self_attention.query_key_value.weight',
            'self_attention.query_key_value.bias',
            'self_attention.dense.weight',
            'self_attention.dense.bias',
            'mlp.dense_h_to_4h.weight',
            'mlp.dense_h_to_4h.bias',
            'mlp.dense_4h_to_h.weight',
            'mlp.dense_4h_to_h.bias',
            'post_attention_layernorm.weight',
            'post_attention_layernorm.bias',
            'input_layernorm.weight',
            'input_layernorm.bias',
        )
        # Indices 0-1: fused QKV weight/bias, copied with QKV-aware options.
        for i in range(0, 2):
            maybe_copy(module.attention,
                       sd,
                       weight_quantizer,
                       mp_replace,
                       transformer_param_names[i],
                       prefix + param_names[i],
                       qkv=True,
                       megatron_v2=self.policy.is_megatron_v2,
                       split_qkv=self.policy.split_qkv)
        # Indices 2-3: attention output projection.
        for i in range(2, 4):
            maybe_copy(module.attention, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        # Indices 4-9: MLP projections plus post_attention_layernorm, all copied into module.mlp.
        for i in range(4, 10):
            maybe_copy(module.mlp, sd, weight_quantizer, mp_replace, transformer_param_names[i],
                       prefix + param_names[i])
        # Indices 10-11: input_layernorm, copied onto the layer module itself.
        for i in range(10, 12):
            maybe_copy(module, sd, weight_quantizer, mp_replace, transformer_param_names[i], prefix + param_names[i])
class BLOOMLayerPolicy(TransformerPolicy):
    """Policy that exposes the tensors of a HuggingFace BLOOM block to DeepSpeed."""

    # HF layer class this policy matches; (re)set by every __init__ call.
    _orig_layer_class = None

    def __init__(self, client_module, inference=True, use_load_prefix=True, split_qkv=False):
        """Wrap ``client_module`` and try to register the transformers BLOOM classes.

        On any failure while importing or resolving the transformers BLOOM
        classes, a warning is printed and ``_orig_layer_class`` is reset to None.
        """
        super().__init__(inference, linear_layer=True, use_load_prefix=use_load_prefix, split_qkv=split_qkv)
        self.client_module = client_module
        try:
            import transformers
            bloom_modeling = transformers.models.bloom.modeling_bloom
            BLOOMLayerPolicy._orig_layer_class = bloom_modeling.BloomBlock
            # Mutating the module-level set in place; no rebinding, so no `global` needed.
            supported_models.add(bloom_modeling.BloomModel)
        except Exception as err:
            print(f"WARNING! Setting BLOOMLayerPolicy._orig_layer_class to None due to Exception: {err}")
            BLOOMLayerPolicy._orig_layer_class = None

    def get_hidden_heads(self):
        """Return (hidden size, head count, input-layernorm eps, default intermediate size)."""
        block = self.client_module
        attn = block.self_attention
        return (attn.hidden_size, attn.num_heads, block.input_layernorm.eps, DEFAULT_INTERMEDIATE_SIZE)

    def attention(self, enable_training=False):
        """Return (qkv weight, qkv bias, output weight, output bias) of self-attention."""
        attn = self.client_module.self_attention
        qkv, dense = attn.query_key_value, attn.dense
        return (qkv.weight, qkv.bias, dense.weight, dense.bias)

    def mlp(self, enable_training=False):
        """Return (h->4h weight, h->4h bias, 4h->h weight, 4h->h bias) of the MLP."""
        ffn = self.client_module.mlp
        up, down = ffn.dense_h_to_4h, ffn.dense_4h_to_h
        return (up.weight, up.bias, down.weight, down.bias)

    def layernorm(self):
        """Return (post-attention LN weight, bias, input LN weight, bias)."""
        block = self.client_module
        post_ln, in_ln = block.post_attention_layernorm, block.input_layernorm
        return (post_ln.weight, post_ln.bias, in_ln.weight, in_ln.bias)
|