Doaneer's picture
Upload 5065 files
e0be88b verified
raw
history blame contribute delete
704 Bytes
# Note that zamba does not have the `apply_rotary_pos_emb` function!
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from transformers.models.zamba.modeling_zamba import ZambaAttention
# When following ZambaAttention's dependencies, the function `apply_rotary_pos_emb` is not picked up
# by default, because it is absent from the class definition (and from the Zamba modeling file altogether).
# Note that this syntax should add both `apply_rotary_pos_emb`, which is imported directly, and
# `rotate_half`, which is a dependency of the imported function!
# Fixture class: a ZambaAttention subclass whose `forward` references `apply_rotary_pos_emb`.
# That function is imported here from the Llama modeling module, not from Zamba's.
# Presumably this checks that dependency tracing brings in the imported function together
# with its own dependency `rotate_half`. TODO: confirm against the tool that consumes this file.
class TestAttention(ZambaAttention):
    # Trivial constructor: it takes no arguments and does not call `super().__init__()`,
    # so none of ZambaAttention's own initialization runs for this subclass.
    def __init__(self):
        pass

    # Appears to exist only so that the class body references `apply_rotary_pos_emb`.
    # The return value is discarded.
    # NOTE(review): the integer arguments look like placeholders. Presumably this method is
    # never actually executed; verify before calling it.
    def forward(self):
        _ = apply_rotary_pos_emb(1, 1, 1, 1)