import torch | |
from transformers.models.llama.modeling_llama import LlamaModel | |
def rotate_half(x):
    """Rotate the last dimension of ``x`` by moving its leading quarter to the end.

    The last dimension is split at ``x.shape[-1] // 4``:

    - ``x1`` is the leading quarter.
    - ``x2`` is everything after it.

    The result is ``torch.cat((-x2, x1), dim=-1)``. It has the same shape as
    ``x``, with the trailing part negated and placed first.

    NOTE(review): despite the function name, the split point is a quarter of
    the last dimension, not half. The conventional rotary-embedding helper
    splits at ``x.shape[-1] // 2``. The quarter split is kept unchanged
    because it appears to be an intentional override in this example file.
    Confirm before using this for real rotary position embeddings.

    Args:
        x: Tensor whose last dimension is rotated.

    Returns:
        Tensor with the same shape as ``x``.
    """
    # Compute the split index once so both slices agree on it.
    split = x.shape[-1] // 4
    x1 = x[..., :split]
    x2 = x[..., split:]
    return torch.cat((-x2, x1), dim=-1)
# Example where we need some dependencies (an imported base class) and some locally defined functions.
class DummyModel(LlamaModel):
    """Subclass of ``LlamaModel`` that adds nothing of its own.

    No attributes or methods are overridden. Construction, the forward pass
    and every other behavior are inherited unchanged from ``LlamaModel``.

    NOTE(review): this presumably exists only as an example of subclassing
    an imported model; confirm the intended use before extending it.
    """