# Note that zamba does not have the `apply_rotary_pos_emb` function!
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.models.zamba.modeling_zamba import ZambaAttention
# When following ZambaAttention dependencies, the function `apply_rotary_pos_emb` is not present
# by default as it is absent from the class definition (and the file altogether).
# Note that this syntax should be able to add both `apply_rotary_pos_emb` (imported directly) and
# `rotate_half` (pulled in as a dependency of the imported function)!
class TestAttention(ZambaAttention):
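    # Stub: no call to `super().__init__()` is made, as this fixture only
    # exercises static dependency tracking, never runtime behavior.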
    def __init__(self):
        pass
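
    # The integer arguments are throwaway placeholders: the call exists only so
    # that `apply_rotary_pos_emb` is referenced, and is never actually executed.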
    def forward(self):
        _ = apply_rotary_pos_emb(1, 1, 1, 1)
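
# If dependency resolution works as described above, the generated file should
# contain a copy of `rotate_half` alongside `apply_rotary_pos_emb`, even though
# only the latter is imported here.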