-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrope.py
96 lines (78 loc) · 3.89 KB
/
rope.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import torch
import math
class RopeND:
    """N-dimensional rotary position embedding (RoPE).

    Splits ``head_dim`` across ``nd`` position axes in proportion to
    ``nd_split`` and applies an independent 1-D rotary embedding per axis,
    concatenating the per-axis cos/sin tables along the channel dimension.

    Args:
        head_dim: Per-head channel count; must be divisible so that
            ``sum(split_dims) == head_dim`` (asserted below).
        nd: Number of position axes (e.g. 3 for time/height/width).
        max_lens: Maximum position index per axis. Defaults to
            ``[1024, 64, 64]``.
        nd_split: Relative share of channels given to each axis.
            Defaults to ``[2, 1, 1]``.
        bases: Per-axis RoPE base; only used when ``auto_base`` is False.
            Defaults to ``[1000, 1000, 1000]``.
        auto_base: If True, derive each base from its axis length via
            ``~8*L/pi`` rounded up to the next hundred (see comment below).
        cache_longer: Multiplier on ``max_lens`` when building the cached
            cos/sin tables, to allow indices beyond the nominal maximum.
    """

    def __init__(self, head_dim=64, nd=3, max_lens=None, nd_split=None,
                 bases=None, auto_base=True, cache_longer=1):
        # Bug fix: the original used mutable list defaults ([1024, 64, 64],
        # [2, 1, 1], [1000, 1000, 1000]) that were stored on self — a shared
        # mutable-default hazard. Same effective defaults, built per call.
        if max_lens is None:
            max_lens = [1024, 64, 64]
        if nd_split is None:
            nd_split = [2, 1, 1]
        if bases is None:
            bases = [1000, 1000, 1000]
        self.nd = nd
        self.head_dim = head_dim
        self.max_lens = max_lens
        self.nd_split = nd_split
        # Each axis gets an even number of channels (RoPE rotates pairs).
        self.split_dims = [2 * i * (head_dim // 2 // sum(nd_split)) for i in nd_split]
        assert sum(self.split_dims) == head_dim
        self.auto_base = auto_base
        if auto_base:
            # empirical, make cos(theta) = -1 when length is kL. base = kL/pi
            # And L=1 the difference (1/base)**(1/32) ~ 0.7-0.8 ~ pi/4
            # for traditional L = 4096, 8L/pi = 10.4k, base is set to 10k
            self.bases = [(int(8 * l / math.pi) // 100 + 1) * 100 for l in self.max_lens]
            print(f"Bases for rope: {self.bases}")
        else:
            self.bases = bases
        self.cache_longer = cache_longer

    def generated_cos_sin_mix2d(self, max_len, dim, device, base=1000):
        """Build cos/sin lookup tables of shape (max_len*cache_longer, dim).

        NOTE: uses an endpoint-inclusive linspace over [0, head_dim] with
        ``dim//2`` steps, normalized by the FULL head_dim — intentionally
        different from the classic ``arange(0, dim, 2)/dim`` schedule.
        """
        inv_freq = 1.0 / (base ** (
            torch.linspace(start=0, end=self.head_dim, steps=dim // 2,
                           device=device).float() / self.head_dim))
        assert inv_freq.size(0) * 2 == dim, \
            f"inv_freq.size(0) = {inv_freq.size(0)}, required dim = {dim}"
        t = torch.arange(max_len * self.cache_longer, device=device).type_as(inv_freq)
        freqs = torch.einsum("i,j->ij", t, inv_freq)
        # Duplicate so the table covers both halves rotated by rotate_half.
        freqs = torch.cat([freqs, freqs], dim=1)
        return freqs.cos().to(torch.float), freqs.sin().to(torch.float)

    def generate_pos_embs_mix2d(self, position_ids, device=None):
        """Gather per-axis cos/sin rows for ``position_ids`` and concat them.

        Args:
            position_ids: integer tensor of shape (nd, L) — row i holds the
                positions along axis i. A 1-D tensor is unsqueezed to (1, L),
                which only makes sense when nd == 1 — TODO confirm callers.
        Returns:
            (cos_emb, sin_emb), each of shape (L, head_dim), float32.
        """
        if device is None:
            device = position_ids.device
        if position_ids.dim() == 1:
            position_ids = position_ids.unsqueeze(0)
        cos_emb_all, sin_emb_all = [], []
        for i in range(self.nd):
            dim_i = self.split_dims[i]
            base_i = self.bases[i]
            max_len_i = self.max_lens[i]
            if not hasattr(self, f"cos_{i}"):
                _cos, _sin = self.generated_cos_sin_mix2d(
                    max_len=max_len_i, dim=dim_i, device=device, base=base_i)
                setattr(self, f"cos_{i}", _cos)
                setattr(self, f"sin_{i}", _sin)
            cos_tab = getattr(self, f"cos_{i}")
            sin_tab = getattr(self, f"sin_{i}")
            if cos_tab.device != device:
                # Bug fix: tables cached on a different device would make the
                # fancy-indexing below raise; migrate and re-cache them.
                cos_tab = cos_tab.to(device)
                sin_tab = sin_tab.to(device)
                setattr(self, f"cos_{i}", cos_tab)
                setattr(self, f"sin_{i}", sin_tab)
            cos_emb_all.append(cos_tab[position_ids[i, :], :])
            sin_emb_all.append(sin_tab[position_ids[i, :], :])
        cos_emb = torch.cat(cos_emb_all, dim=-1)
        sin_emb = torch.cat(sin_emb_all, dim=-1)
        return cos_emb, sin_emb

    def __call__(self, q, k, position_ids):
        '''Apply rotary position embedding to q and k.

        q: N N_head L C  (k has the same layout); position_ids: (nd, L).
        Returns rotated (q, k) in the input dtype; the rotation itself is
        computed in float32 for numerical stability.
        '''
        cos_emb, sin_emb = self.generate_pos_embs_mix2d(position_ids, device=q.device)

        def rotate_half(x):
            """Rotates half the hidden dims of the input."""
            x1 = x[..., : x.shape[-1] // 2]
            x2 = x[..., x.shape[-1] // 2:]
            return torch.cat((-x2, x1), dim=-1)

        def apply_rotary_pos_emb(q, k, cos, sin):
            """Applies Rotary Position Embedding to the query and key tensors.
            Args:
                q (`torch.Tensor`): The query tensor.
                k (`torch.Tensor`): The key tensor.
                cos (`torch.Tensor`): The cosine part of the rotary embedding.
                sin (`torch.Tensor`): The sine part of the rotary embedding.
            Returns:
                `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
            """
            # Broadcast (L, C) tables over the batch and head dimensions.
            cos = cos.unsqueeze(0).unsqueeze(0)
            sin = sin.unsqueeze(0).unsqueeze(0)
            dtype = q.dtype
            q = q.to(torch.float)
            k = k.to(torch.float)
            q_embed = (q * cos) + (rotate_half(q) * sin)
            k_embed = (k * cos) + (rotate_half(k) * sin)
            q_embed = q_embed.to(dtype)
            k_embed = k_embed.to(dtype)
            return q_embed, k_embed

        q, k = apply_rotary_pos_emb(q, k, cos_emb, sin_emb)
        return q, k