import os
import warnings

import torch
from transformers.utils import is_flash_attn_2_available
from transformers.models.llama import LlamaModel, LlamaConfig
from transformers.models.bart.modeling_bart import BartEncoder, BartDecoder, BartConfig
# from transformers.models.musicgen.modeling_musicgen import MusicgenModel, MusicgenDecoder, MusicgenDecoderConfig
# (MusicGen uses the same BartDecoder-style stack, just without cross-attention.)
# Enable FlashAttention-2 only if the package is installed, the GPU has compute
# capability >= 8.0 (Ampere or newer), and it is not explicitly disabled.
try:
    assert is_flash_attn_2_available()
    assert torch.cuda.get_device_capability(torch.device("cuda")) >= (8, 0)
    assert os.environ.get("DISABLE_FLASH_ATTN", "0") != "1"
    _enable_flash_attention = True
except Exception:
    _enable_flash_attention = False

if not _enable_flash_attention:
    warnings.warn("FlashAttention-2 is not supported; falling back to eager attention.")
def get_backend(name, dim, num_heads, num_layers, hidden_scale, init_std=0.02, rope_theta=10000):
    """Build a transformer backbone that consumes pre-computed embeddings.

    The token embedding table is deleted before returning, so callers must
    drive the model with `inputs_embeds` rather than `input_ids`.
    """
    # SA (causal) - FF
    if name == 'llama':
        model_cfg = LlamaConfig(
            hidden_size=dim,
            intermediate_size=dim * hidden_scale,
            num_attention_heads=num_heads,
            num_hidden_layers=num_layers,
            num_key_value_heads=num_heads,
            vocab_size=dim,  # unused: embed_tokens is deleted below
            use_cache=False,
            max_position_embeddings=4096,
            hidden_act="silu",
            initializer_range=init_std,
            rope_theta=rope_theta,
            _attn_implementation="flash_attention_2" if _enable_flash_attention else "eager",
        )
        model = LlamaModel(model_cfg)
    # SA - FF
    elif name == 'bart_enc':
        model_cfg = BartConfig(
            d_model=dim,
            max_position_embeddings=4096,
            dropout=0.,
            use_cache=False,
            _attn_implementation="flash_attention_2" if _enable_flash_attention else "eager",
            activation_function='gelu',
            init_std=init_std,
            # for BartEncoder
            encoder_layers=num_layers,
            encoder_attention_heads=num_heads,
            encoder_ffn_dim=dim * hidden_scale,
        )
        model = BartEncoder(model_cfg)
    # SA - CA - FF
    elif name == 'bart_dec':
        model_cfg = BartConfig(
            d_model=dim,
            max_position_embeddings=4096,
            dropout=0.,
            use_cache=False,
            _attn_implementation="flash_attention_2" if _enable_flash_attention else "eager",
            activation_function='gelu',
            init_std=init_std,
            # for BartDecoder
            decoder_layers=num_layers,
            decoder_attention_heads=num_heads,
            decoder_ffn_dim=dim * hidden_scale,
        )
        model = BartDecoder(model_cfg)
    else:
        raise NotImplementedError(f"Unknown backend: {name!r}")
    # The backbone is fed continuous features, so drop the token embedding table.
    delattr(model, "embed_tokens")
    return model
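
if __name__ == "__main__":
    # Minimal smoke test (an added sketch, not part of the original file).
    # Because embed_tokens is deleted, the backbones must be driven with
    # `inputs_embeds`; all dimensions below are illustrative assumptions.
    # Note: if FlashAttention-2 was enabled above, run on GPU in half precision
    # or set DISABLE_FLASH_ATTN=1, since flash-attn rejects fp32/CPU inputs.
    x = torch.randn(2, 16, 256)  # (batch, seq_len, dim)

    # Causal self-attention stack.
    llama = get_backend('llama', dim=256, num_heads=4, num_layers=2, hidden_scale=4)
    print(llama(inputs_embeds=x).last_hidden_state.shape)  # torch.Size([2, 16, 256])

    # Decoder with cross-attention over hypothetical encoder states.
    dec = get_backend('bart_dec', dim=256, num_heads=4, num_layers=2, hidden_scale=4)
    mem = torch.randn(2, 10, 256)  # stand-in for encoder hidden states
    print(dec(inputs_embeds=x, encoder_hidden_states=mem).last_hidden_state.shape)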