首先放上苏神的两个博客:

本人在进行代码解析之前,首先尽量去理解了这个理论基础,之后在解析代码的时候和论文、博客相互映照,才最终得到一个比较好的理解。其中,第二个是重中之重,需要从头到尾地理解透彻。第一个是阐述了RoPE的数学思路、设计原理。单看这个可能不太能理解RoPE,可以参照其他一些讲解RoPE的博客去理解,使用Google或者bing可以搜到一些程序员自己搭建的博客页面,而非csdn之类的,这些感觉质量都还不错。
此外,一定要理解RoPE的原生代码
总的来说,笔者在上面这些准备工作花不少时间,之后在理解yarn的代码上又花费了大概一天的时间,才终于产出了这篇解析源代码的博客。所获得的收获是很大的,不单单是理解了一种方法,更是对位置编码有了更深的理解。
当然,这些都还是一些浅显的东西,位置编码及其本质的探讨,苏神后来也有一些博客,有兴趣的可以去看一看。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
# Inverse dim formula to find dim based on number of rotations
def _yarn_find_correction_dim(num_rotations, dim, base=10000, max_position_embeddings=2048):
return (dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi)))/(2 * math.log(base))
# 从函数名来看,应该是找到yarn应用时所需要的正确的dim,这个函数的需求是num_rotations,即θi在训练窗长内所占的圈数;返回结果是一个固定公式计算得出的数值。在yarn类中,最终会传入这个函数的是beta和alpha,通常是32和1。

# 函数应该是,基于hidden dim的维度,如0, 1, 2, ....511,计算这个维度需要的?
# 从dim乘以一个系数的角度考虑,那么这个函数是再确定应用[不变、线性插值、缩放]的dim范围!确实是应该有两个分界点的!

# Find dim range bounds based on rotations
def _yarn_find_correction_range(low_rot, high_rot, dim, base=10000, max_position_embeddings=2048):
low = math.floor(_yarn_find_correction_dim(
low_rot, dim, base, max_position_embeddings))
high = math.ceil(_yarn_find_correction_dim(
high_rot, dim, base, max_position_embeddings))
return max(low, 0), min(high, dim-1) # Clamp values just in case
# 上面的函数计算出来的数值,用在此处计算出了low和high,两个float返回值作为该函数的结果。low的值域应该是[0, min(high, dim-1)], high的值域应该是high和dim-1两者之间的更小值。这个值很有意思,因为在确定当前的θi是要走[不变、线性插值、缩放]三个条件中哪一个时判断的依据,就是。
# 由此,上第一个函数,应该是计算训练过程中转的圈数γi?其公式正好是θi * L_train / 2 * PI。拿这个公式带入最上面的函数,则:
# 分子:dim * math.log(max_position_embeddings/(num_rotations * 2 * math.pi))
# 分母:2 * math.log(base)
# 从苏神的博客中来看,唯一用到log的地方就是scale因子的确定上。

def _yarn_linear_ramp_mask(min, max, dim):
if min == max:
max += 0.001 # Prevent singularity

linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min)
ramp_func = torch.clamp(linear_func, 0, 1)
return ramp_func
# 上一个函数确定的范围,在此处得到了应用。linear_func先arange一个形如[0, 1, 2...., dim - 1]的tensor,shape为[1, dim],然后将其最大最小归一化。最后,将其截断在全部clamp(阶段)在min和max之间。

def _yarn_get_mscale(scale=1):
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0
# 根据sacle得到缩放系数sqrt(t),这个公式是固定的

以上是yarn需要用到的一些函数。逐个地顺序解析,大概知道了函数输入。输出是什么。至于为什么要这些输入输出,则需要到下面的主体class去看一下详细的操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
# 如果让我基于yarn的理论写代码,那么,我应该会:原生rope的θi计算。然后对计算每个hidden position的缩放系数,然后相乘。然后apply rotray
class LlamaYaRNScaledRotaryEmbedding(torch.nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
scale=1,
original_max_position_embeddings=2048,
extrapolation_factor=1,
attn_factor=1,
beta_fast=32,
beta_slow=1,
finetuned=False,
device=None,
):
super().__init__()

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.scale = scale
self.original_max_position_embeddings = original_max_position_embeddings
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow

self.yarn(device)
# 查看下面的yarn函数解析。此函数运行后可以得到两个参数:缩放系数根号t,以及θnew

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
# 在扩展模型窗长的时候,scale变化,max_position_embeddings也应该乘以相应倍数?
t = torch.arange(
self.max_seq_len_cached,
device=self.inv_freq.device,
dtype=self.inv_freq.dtype,
)

# 下面的没啥好说了,和RoPE一样不变,因为yarn只是改变θ
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()

self.register_buffer(
"cos_cached",
(emb.cos() * self.mscale)[None, None, :, :].to(dtype),
persistent=False,
)
self.register_buffer(
"sin_cached",
(emb.sin() * self.mscale)[None, None, :, :].to(dtype),
persistent=False,
)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len

t = torch.arange(
self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

self.register_buffer(
"cos_cached",
(emb.cos() * self.mscale)[None, None, :, :].to(x.dtype),
persistent=False,
)
self.register_buffer(
"sin_cached",
(emb.sin() * self.mscale)[None, None, :, :].to(x.dtype),
persistent=False,
)
# return (
# self.cos_cached[:seq_len].to(dtype=x.dtype),
# self.sin_cached[:seq_len].to(dtype=x.dtype),
# )
return (
self.cos_cached[:, :, :, ...].to(dtype=x.dtype),
self.sin_cached[:, :, :, ...].to(dtype=x.dtype),
)

def yarn(self, device):
pos_freqs = self.base ** (
torch.arange(0, self.dim, 2).float().to(device) / self.dim
)
inv_freq_extrapolation = 1.0 / pos_freqs # 外推
inv_freq_interpolation = 1.0 / (self.scale * pos_freqs) # 内推/内插

low, high = _yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
self.dim,
self.base,
self.original_max_position_embeddings,
)
inv_freq_mask = (
1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device) # 获得的是一个递减的list[1, 1, 0.9, 0.8 ..... 0], 这个其实是系数,决定你当前这个位置的position是否外推/内插/不变
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
# 此处的extrapolation_factor,可以看到默认值,是1,这个是一个总体的缩放系数,一般不需要更改
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)
# 此处的逻辑,是将任何位置的rope都看成是内插 + 外推的线性组合。看论文中的公式,这一步很好理解。如果是看苏神的博客的话。那么外推inv_freq_extrapolation就是θi,内插inv_freq_interpolation则是θi * L_train / L_test。此处用scale来直接代表,scale其实就是位置编码总体上扩展的大小。如果是8.0,则是将原来的4k窗长扩展到32k,当然后来还需要进行相应的ft以适应这种变化。
# 总的来说,这一步得到了inv_freq,也就是所有的Θnew被确定下来了
#
#
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.mscale = float(
_yarn_get_mscale(self.scale) * self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation
# 此处attn总体的缩放系数和之前的extrapolation_factor一样,也默认为1,且通常不需要更改。
# 为什么要有这两类系数,也许是做实验时候进行网格搜索之类的,结果发现不变更好;也有可能是其他方法对attention进行缩放的话,如L2norm还是啥约束attn的大小,从而使得参数更容易稀疏化,就可以在这一步进行
# 此外,yarn还有revise版本,其没有分开计算外推和内插:
# inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
# inv_freq = inv_freq / ((1-inv_freq_mask)*self.scale + inv_freq_mask)
# 以具体数值可能更好理解:[1, 1, 0.9,...0, 0] -> [0, 0, 0.1, ...., 1, 1] -> [0, 0, 0.8, ....8, 8] -> [1, 1, 1.7, .....8, 8]
# 具体的原因好像在于底层算子未按照预期运行?具体issue讨论在yarn下:https://github.com/jquesnelle/yarn/issues/24。此外,就算正确运行,这个revise似乎也有概率效果更好,但不保证奥

class LlamaDynamicYaRNScaledRotaryEmbedding(torch.nn.Module):
def __init__(
self,
dim,
max_position_embeddings=2048,
base=10000,
original_max_position_embeddings=2048,
extrapolation_factor=1,
attn_factor=1,
beta_fast=32,
beta_slow=1,
finetuned=False,
device=None,
):
super().__init__()

self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
self.original_max_position_embeddings = original_max_position_embeddings
self.extrapolation_factor = extrapolation_factor
self.attn_factor = attn_factor
self.beta_fast = beta_fast
self.beta_slow = beta_slow

# 只是在这里加了判断,如果处于fintune阶段,则重新以yarn计算rope,否则原生rope
# 这个类的本质,是将yarn集成到RoPE中,支持rope原生训练以及微调时方便地使用yarn
if finetuned:
self.yarn(
self.max_position_embeddings / self.original_max_position_embeddings,
device,
)
else:
inv_freq = 1.0 / (
base ** (torch.arange(0, dim, 2).float().to(device) / dim)
)
self.register_buffer("inv_freq", inv_freq, persistent=False)
self.mscale = 1

# Build here to make `torch.jit.trace` work.
self.max_seq_len_cached = max_position_embeddings
t = torch.arange(
self.max_seq_len_cached, device=self.inv_freq.device, dtype=torch.float32
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
dtype = torch.get_default_dtype()

self.register_buffer(
"cos_cached", (emb.cos() * self.mscale).to(dtype), persistent=False
)
self.register_buffer(
"sin_cached", (emb.sin() * self.mscale).to(dtype), persistent=False
)

def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
if seq_len > self.max_seq_len_cached:
self.max_seq_len_cached = seq_len

self.yarn(seq_len / self.max_position_embeddings, x.device)

t = torch.arange(
self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype
)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1).to(x.device)

self.register_buffer(
"cos_cached", (emb.cos() * self.mscale).to(x.dtype), persistent=False
)
self.register_buffer(
"sin_cached", (emb.sin() * self.mscale).to(x.dtype), persistent=False
)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)

def yarn(self, scale, device):
pos_freqs = self.base ** (
torch.arange(0, self.dim, 2).float().to(device) / self.dim
)
inv_freq_extrapolation = 1.0 / pos_freqs
inv_freq_interpolation = 1.0 / (scale * pos_freqs)

low, high = _yarn_find_correction_range(
self.beta_fast,
self.beta_slow,
self.dim,
self.base,
self.original_max_position_embeddings,
)
inv_freq_mask = (
1 - _yarn_linear_ramp_mask(low, high, self.dim // 2).float().to(device)
) * self.extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = (
inv_freq_interpolation * (1 - inv_freq_mask)
+ inv_freq_extrapolation * inv_freq_mask
)

self.register_buffer("inv_freq", inv_freq, persistent=False)
self.mscale = float(
_yarn_get_mscale(scale) * self.attn_factor
) # Get n-d magnitude scaling corrected for interpolation

以上是yarn positional embedding的类代码
由此,yarn代码解析完毕。
什么嘛!做完之后还是挺通畅简单的嘛!

PoSE代码解析

pose代码如下,这是核心代码,实际上不需要在modeling里面修改,只需要在tokenize之后、传入model之前进行操作

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
def train_preprocess_function_randomized(examples, tokenizer, scaled_max_position_embeddings, model_max_position_embeddings):

inputs = examples["text"]
# 获取输入的query/src/text
model_inputs = tokenizer(inputs, padding=False, truncation=True, max_length=model_max_position_embeddings)
# tokenize,返回token list,也有可能是多个句子的token list
position_ids = [torch.arange(len(ids), dtype=torch.long) for ids in model_inputs["input_ids"]]
# 对于多个句子, ids即一个句子tokenize后的id list,position_ids返回的是句子长度的[1, 2, 3, .. , seq_len],

for pos_ids in position_ids:
len_pos_ids = len(pos_ids)
# 获取的seq_len?
tot_pos_list = list(range(scaled_max_position_embeddings))
# 扩展后的scaled_max_position_embeddings, 代码里显示scaled_max_position_embeddings=int(training_args.model_max_position_embeddings * training_args.rope_scaling_factor)
new_pos_list = random.sample(tot_pos_list, len_pos_ids)
# 从扩展后的pose id里面,随机筛选出当前seq_len的个数
new_pos_list.sort()
# 排序后,应该是一个打乱的非连续的pos_id,本身token 位置对应的1, 2, 3...这类顺序pose不存在了
pos_ids[:] = torch.tensor(new_pos_list, dtype=torch.long)
# tensor化
model_inputs["position_ids"] = position_ids
# 更新
model_inputs["labels"] = model_inputs["input_ids"]
# token的label不变,即embedding不变, 但这一步有必要吗?

return model_inputs
# 上面这个纯随机在实际中效果不好,因而不用;默认使用的是下面的有选择截断拼接

def train_preprocess_function_pose(examples, tokenizer, scaled_max_position_embeddings, model_max_position_embeddings):

inputs = examples["text"]
raw_model_inputs = tokenizer(inputs, padding=False, truncation=True, max_length=model_max_position_embeddings*5)
# max_length变成五倍,意思是可以接受更长文本的输入。至于为什么是5,可能的一个原因是,单张80g的卡最多同时放下10000个token。pose再高单张卡放不下了。
input_ids = []
position_ids = []

for ids in raw_model_inputs["input_ids"]:
# ids 即为一个句子tokenize之后的id list
len_chunk = min(len(ids), model_max_position_embeddings)
len_input = len(ids)
# 下面将ids切成两片
lt1 = 0
rt1 = random.randint(1, (len_chunk+1)//2)
# 第一个切片,起始点是0, 终止点最多在:如果句子长,则在原来模型的窗长的一半内,通常是2048 // 2; 如果句子短,则是句子长度一半以内
rt2 = random.randint(lt1+len_chunk, len_input)
lt2 = rt2 - (len_chunk - (rt1-lt1))
# 第二个切片, 长度是len_chunk - 第一个切片的长度。终止点一定是超过了len_chunk的,起始点未必
chunked_ids = ids[lt1:rt1] + ids[lt2:rt2]
# 拼接得到新的chunk id,也就是token id的list
input_ids.append(chunked_ids)
# 把新得到的这一段切片加到input ids里
pos_ids = torch.arange(len(chunked_ids), dtype=torch.long)
# tensor化
len_pos_ids = len(pos_ids)
# lt = random.randint(0, scaled_max_position_embeddings-len_pos_ids)
lt = 0 # this revision makes the coverage possiblity more uniform for large relative positions
rt = random.randint(lt, scaled_max_position_embeddings-len_pos_ids)
#
pos_ids[:rt1-lt1] += lt
# 第一个切片,不变 + 0
pos_ids[rt1-lt1:] += rt
# 第二个切片后面加上一个随机整数
position_ids.append(pos_ids)
# 处理完毕
# 这一步的操作,是将过长的文本进行切分与拼接,这个切分与拼接包含token id以及对应的rope id,之后会返回新的拼接后的结果。
# 也就是说,中间的那一部分被扔掉了。
# 本质上,这个方法是:当输入文本没超过之前的窗长,则文本不变,pose由本来的[1, 2, ....seq_len]切成两部分,第一部分不变(代码中+0),第二部分加一个扩展后窗长(如0 - 2048*5)内的随机数,这样会让文本的位置编码变成两部分的拼接。虽然比不上直接用全数据,但有时候没有长数据可以用这种方法来将之前的文本长度扩展,实际上还是妥协的方法
# 当文本超过了之前模型的窗长, 则将其限制在之前模型能处理的窗长内,具体为随机选取两段拼接。中间的部分也扔掉。至于位置编码,则和原来的位置也关系不大,第二个切片的位置编码随机加了整数,和超过窗长的文本本来的位置也对不上
# 回顾上面的lt1、rt1、lt2、rt2,如果len_input没有超过最大窗长(如2048),实际上inpput ids是没变的,如果超过,则选取一部分。关于位置编码的部分,第二部分其实都是[50,51 52.. 100]这种list每个元素加了一个随机数。和原来的没关系
# 总感觉在超出窗长文本位置编码的处理上有更好的方法?可不可以直接保留原来的位置呢?
model_inputs = {"input_ids": input_ids, "position_ids": position_ids, "labels": input_ids}

return model_inputs

def train_preprocess_function_pi(examples, tokenizer, scaled_max_position_embeddings, model_max_position_embeddings):

inputs = examples["text"]
model_inputs = tokenizer(inputs, padding=False, truncation=True, max_length=scaled_max_position_embeddings)
position_ids = [torch.arange(len(ids), dtype=torch.long) for ids in model_inputs["input_ids"]]
model_inputs["position_ids"] = position_ids
model_inputs["labels"] = model_inputs["input_ids"]

return model_inputs

def test_preprocess_function(examples, tokenizer, inference_length):

inputs = examples["text"]
model_inputs = tokenizer(inputs, padding=False, truncation=True, max_length=inference_length)
position_ids = [torch.arange(len(ids), dtype=torch.long) for ids in model_inputs["input_ids"]]
model_inputs["position_ids"] = position_ids
model_inputs["labels"] = model_inputs["input_ids"]

return model_inputs