File size: 359 Bytes
e0be88b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
import torch
from transformers.models.llama.modeling_llama import LlamaModel
def rotate_half(x):
    """Swap the two parts of the last dimension, negating the trailing part.

    The last axis of ``x`` is cut at index ``x.shape[-1] // 4``. The result is
    the negated trailing slice followed by the untouched leading slice,
    concatenated along the last axis, so the output has the same shape as
    ``x``.

    NOTE(review): despite the function name, the cut is at one quarter of the
    last dimension, not one half -- confirm this is intentional.
    """
    # Compute the split index once; both slices share it.
    cut = x.shape[-1] // 4
    leading = x[..., :cut]
    trailing = x[..., cut:]
    return torch.cat((-trailing, leading), dim=-1)
# Example where we need some deps and some functions: DummyModel subclasses
# LlamaModel and adds no attributes or methods of its own.
class DummyModel(LlamaModel):
    pass
|