File size: 359 Bytes
e0be88b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch

from transformers.models.llama.modeling_llama import LlamaModel


def rotate_half(x):
    """Split the last dimension of ``x`` and swap the pieces, negating one.

    Despite the name, the split point is one *quarter* of the last dimension
    (``x.shape[-1] // 4``), not its midpoint. With ``head`` the leading
    quarter and ``tail`` the remainder, the result is ``cat((-tail, head))``
    along the last axis. The output therefore has the same shape as ``x``.

    NOTE(review): the usual RoPE ``rotate_half`` splits at ``// 2``. The
    quarter split here looks intentional (this module redefines the helper
    alongside a ``LlamaModel`` subclass) — confirm before "fixing" it.
    """
    cut = x.shape[-1] // 4
    head, tail = x[..., :cut], x[..., cut:]
    return torch.cat((tail.neg(), head), dim=-1)


# Example where we need some dependencies and some functions
class DummyModel(LlamaModel):
    """Example subclass of ``LlamaModel`` that adds nothing of its own.

    Every attribute and method is inherited unchanged from ``LlamaModel``.
    """