feat: use latest modeling code
- configuration_stablelm.py +3 -1
- modeling_stablelm.py +109 -13
configuration_stablelm.py
CHANGED

@@ -45,7 +45,7 @@ class StableLmConfig(PretrainedConfig):
         intermediate_size (`int`, *optional*, defaults to 6912):
             Dimension of the MLP representations.
         hidden_size (`int`, *optional*, defaults to 2560):
-
+            Number of hidden layers in the Transformer decoder.
         num_hidden_layers (`int`, *optional*, defaults to 32):
             Number of hidden layers in the Transformer decoder.
         num_attention_heads (`int`, *optional*, defaults to 32):

@@ -134,12 +134,14 @@ class StableLmConfig(PretrainedConfig):
     ):
         self.vocab_size = vocab_size
         self.max_position_embeddings = max_position_embeddings
+
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
         self.num_hidden_layers = num_hidden_layers
         self.num_attention_heads = num_attention_heads
         self.num_key_value_heads = num_key_value_heads
         self.hidden_act = hidden_act
+
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.use_cache = use_cache
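For orientation, a minimal sketch of constructing the config with the defaults documented above. The local import is an assumption (it presumes this repository's configuration_stablelm.py is importable), and the values are simply the defaults quoted in the docstring.

    # Minimal sketch -- values are the defaults quoted in the docstring above.
    from configuration_stablelm import StableLmConfig  # assumes this repo's file is on the path

    config = StableLmConfig(
        hidden_size=2560,         # hidden size of the decoder
        intermediate_size=6912,   # dimension of the MLP representations
        num_hidden_layers=32,     # number of Transformer decoder layers
        num_attention_heads=32,
    )
    print(config.hidden_size, config.num_hidden_layers)
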
modeling_stablelm.py
CHANGED

@@ -103,7 +103,7 @@ class StableLmRotaryEmbedding(nn.Module):
         )
 
 
-# Copied from transformers.models.
+# Copied from transformers.models.falcon.modeling_falcon.FalconLinearScalingRotaryEmbedding with Falcon->StableLm
 class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
     """StableLmRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 

@@ -123,7 +123,7 @@ class StableLmLinearScalingRotaryEmbedding(StableLmRotaryEmbedding):
         self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
 
 
-# Copied from transformers.models.
+# Copied from transformers.models.falcon.modeling_falcon.FalconDynamicNTKScalingRotaryEmbedding with Falcon->StableLm
 class StableLmDynamicNTKScalingRotaryEmbedding(StableLmRotaryEmbedding):
     """StableLmRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
 
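These two classes are the linear and dynamic-NTK RoPE scaling variants. Assuming this file mirrors the upstream transformers implementation, they are selected through the config's `rope_scaling` field rather than instantiated directly; a hedged sketch:

    # Sketch under the assumption that the config exposes `rope_scaling` as upstream does:
    # {"type": "linear" | "dynamic", "factor": > 1.0} switches the rotary embedding class.
    from configuration_stablelm import StableLmConfig  # assumes this repo's file is on the path

    config = StableLmConfig(rope_scaling={"type": "linear", "factor": 2.0})
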
@@ -374,6 +374,102 @@ class StableLmAttention(nn.Module):
         return attn_output, attn_weights, past_key_value
 
 
+class StableLmSdpaAttention(StableLmAttention):
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+        if output_attentions:
+            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
+            logger.warning_once(
+                "StableLmModel is using StableLmSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
+                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
+            )
+            return super().forward(
+                hidden_states=hidden_states,
+                attention_mask=attention_mask,
+                position_ids=position_ids,
+                past_key_value=past_key_value,
+                output_attentions=output_attentions,
+                use_cache=use_cache,
+            )
+
+        bsz, q_len, _ = hidden_states.size()
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
+
+        kv_seq_len = key_states.shape[-2]
+        if past_key_value is not None:
+            if self.layer_idx is None:
+                raise ValueError(
+                    f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+                    "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+                    "with a layer index."
+                )
+            kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+
+        # Partial rotary embedding
+        query_rot, query_pass = (
+            query_states[..., : self.rotary_emb.dim],
+            query_states[..., self.rotary_emb.dim :],
+        )
+        key_rot, key_pass = (
+            key_states[..., : self.rotary_emb.dim],
+            key_states[..., self.rotary_emb.dim :],
+        )
+        # [batch_size, seq_length, num_heads, head_dim // config.partial_rotary_factor]
+        query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
+
+        # [batch_size, seq_length, num_heads, head_dim]
+        query_states = torch.cat((query_rot, query_pass), dim=-1)
+        key_states = torch.cat((key_rot, key_pass), dim=-1)
+
+        if past_key_value is not None:
+            # Specific to RoPE models with partial rotation
+            cache_kwargs = {"sin": sin, "cos": cos, "partial_rotation_size": self.rotary_emb.dim}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # Repeat k/v heads if n_kv_heads < n_heads
+        key_states = repeat_kv(key_states, self.num_key_value_groups)
+        value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
+        # Reference: https://github.com/pytorch/pytorch/issues/112577.
+        if query_states.device.type == "cuda" and attention_mask is not None:
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        attn_output = torch.nn.functional.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=attention_mask,
+            dropout_p=self.attention_dropout.p if self.training else 0.0,
+            # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
+            is_causal=self.is_causal and attention_mask is None and q_len > 1,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
+
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
+
+
 class StableLmFlashAttention2(StableLmAttention):
     """
     StableLM flash attention module. This module inherits from `StableLmAttention` as the weights of the module stays
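The new class ultimately reduces to one call into PyTorch's fused kernel. A self-contained toy sketch of that call, with shapes following the `[bsz, num_heads, q_len, head_dim]` layout used above (the sizes are illustrative, not taken from the diff):

    import torch
    import torch.nn.functional as F

    # Toy shapes: batch 1, 32 heads, 8 query positions, head_dim 80 (2560 / 32).
    q = torch.randn(1, 32, 8, 80)
    k = torch.randn(1, 32, 8, 80)
    v = torch.randn(1, 32, 8, 80)

    # With no explicit mask and q_len > 1 the module passes is_causal=True,
    # so the kernel applies the usual lower-triangular causal mask itself.
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=True)
    print(out.shape)  # torch.Size([1, 32, 8, 80])
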

@@ -574,6 +670,7 @@ class StableLmFlashAttention2(StableLmAttention):
 
 ATTENTION_CLASSES = {
     "eager": StableLmAttention,
+    "sdpa": StableLmSdpaAttention,
     "flash_attention_2": StableLmFlashAttention2,
 }
 
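`config._attn_implementation` indexes into this table, and it is normally set at load time. A hedged loading example: the checkpoint id below is a placeholder, while `attn_implementation` and `trust_remote_code` are the standard `from_pretrained` arguments.

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "stabilityai/stablelm-2-1_6b"  # placeholder: any checkpoint shipping this remote code

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",  # "eager" or "flash_attention_2" select the other entries above
        trust_remote_code=True,
    )
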

@@ -669,7 +766,7 @@ STABLELM_START_DOCSTRING = r"""
 
 
 @add_start_docstrings(
-    "The bare
+    "The bare StableLm Model outputting raw hidden-states without any specific head on top.",
     STABLELM_START_DOCSTRING,
 )
 class StableLmPreTrainedModel(PreTrainedModel):

@@ -680,6 +777,7 @@ class StableLmPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _supports_flash_attn_2 = True
     _supports_cache_class = True
+    _supports_sdpa = True
 
     def _init_weights(self, module):
         std = self.config.initializer_range

@@ -764,7 +862,7 @@ STABLELM_INPUTS_DOCSTRING = r"""
 
 
 @add_start_docstrings(
-    "The bare
+    "The bare StableLm Model outputting raw hidden-states without any specific head on top.",
     STABLELM_START_DOCSTRING,
 )
 class StableLmModel(StableLmPreTrainedModel):

@@ -858,6 +956,11 @@ class StableLmModel(StableLmPreTrainedModel):
         if self._attn_implementation == "flash_attention_2":
             # 2d mask is passed through the layers
             attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+        # for output_attentions case used fallback to eager attention realization
+        elif self._attn_implementation == "sdpa" and not output_attentions:
+            attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
+            )
         else:
             # 4d mask is passed through the layers
             attention_mask = _prepare_4d_causal_attention_mask(
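As a rough illustration of the mask shapes handled here (toy code, not the helpers from the diff): flash-attention consumes the raw 2d padding mask, while the sdpa and eager paths receive it expanded into a 4d additive causal mask of shape `[batch, 1, q_len, kv_len]`.

    import torch

    # Toy expansion of a 2d padding mask into a 4d additive causal mask.
    padding_mask = torch.tensor([[1, 1, 1, 0]])                 # [batch, kv_len], 0 = padding
    causal = torch.tril(torch.ones(4, 4, dtype=torch.bool))     # lower-triangular causal pattern
    keep = causal & padding_mask.bool()[:, None, :]             # broadcast padding over query rows
    additive = torch.where(keep, torch.tensor(0.0), torch.tensor(float("-inf")))[:, None]
    print(additive.shape)  # torch.Size([1, 1, 4, 4])
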

@@ -999,7 +1102,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
         >>> # Generate
         >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
         >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
-        'The weather is always wonderful in the San
+        'The weather is always wonderful in the summer in the city of San Diego. The city is located on the coast of the Pacific Ocean and is surrounded by'
         ```"""
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

@@ -1048,7 +1151,6 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
             attentions=outputs.attentions,
         )
 
-    # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM.prepare_inputs_for_generation
     def prepare_inputs_for_generation(
         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
     ):

@@ -1089,12 +1191,6 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
             if past_key_values:
                 position_ids = position_ids[:, -input_ids.shape[1] :]
 
-        if past_key_value := getattr(self.model.layers[0].self_attn, "past_key_value", None):
-            # generation with static cache
-            seen_tokens = past_key_value.get_seq_length()
-            input_ids = input_ids[:, seen_tokens:]
-            position_ids = position_ids[:, seen_tokens:]
-
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
         if inputs_embeds is not None and past_key_values is None:
             model_inputs = {"inputs_embeds": inputs_embeds}

@@ -1123,7 +1219,7 @@ class StableLmForCausalLM(StableLmPreTrainedModel):
 
 @add_start_docstrings(
     """
-    The
+    The StableLm transformer with a sequence classification head on top (linear layer).
 
     [`StableLmForSequenceClassification`] uses the last token in order to do the classification, as other causal
     models (e.g. GPT-2) do.
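
A hedged sketch of the "last token" classification behaviour described above, using a deliberately tiny, randomly initialized config so it runs quickly. The local imports are assumptions (they presume this repository's files are importable); the config field names are the ones visible in the diffs above.

    import torch
    from configuration_stablelm import StableLmConfig            # assumes this repo's files are on the path
    from modeling_stablelm import StableLmForSequenceClassification

    # Tiny random-weight model purely for illustration.
    config = StableLmConfig(
        vocab_size=256, hidden_size=64, intermediate_size=128,
        num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
        num_labels=2, pad_token_id=0,
    )
    model = StableLmForSequenceClassification(config).eval()

    input_ids = torch.tensor([[11, 22, 33, 0]])                   # one sequence, last position is padding
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits                # taken at the last non-pad token
    print(logits.shape)  # torch.Size([1, 2])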