目前,生成式任务中模型的解码策略有五种,分别为:贪心搜索(Greedy Search)、束搜索(Beam Search)、top-k、top-p、随机采样。
贪心搜索
在t时刻中,选取当前概率最大的词。

这种采样方式只考虑当前时刻的最优解,忽略了当前低概率词后面的高概率词。
1 | # 代码来自perplexity.ai |
束搜索
束搜索以句子为单位,返回可能性最大的输出序列列表。
每一步解码时,仅保留前K个可能的结果。例如在第一步解码时,我们选择前K个可能的y,分别代入第二步解码中,各取前k个候选词,即得到k^2个候选组合,最后保留概率乘积最大的前k个候选结果。
Transformers中的Beam Search高效实现
Beam Search快速理解及代码解析
Transformers仓库做语言生成的解码方法介绍
使用Transformers做限制集束搜索(Constrained Beam Search)的文本生成
1 | #代码来自perplexity.ai |
top-k
top-k是对贪心搜索策略的优化,从其排名前k的token中进行采样。对于每个时刻t,首先筛选出k个备选token,重新计算这K个token的概率值,然后使用multinomial方法进行采样,采样时会优先取概率高的值。
1 | #代码来自perplexity.ai,不知道是否正确。 |
torch.multinomial(input, num_samples, replacement=False)
torch.multinominal方法可以根据给定权重对数组进行多次采样,返回采样后的元素下标。
- 参数说明
input: 必须是torch.Tensor类型,即概率分布。可以是一维或者二维,不必手动归一化。
num_samples:采样的次数。如果input是二维的,则表示每行的采样次数。
replacement: 采样是否放回。默认为False,即不放回采样,此时num_samples必须小于input中的非零元素。在无放回采样中,input中值为0的元素只有所有其他元素被抽到后,才会被抽到。换句话说,有放回情况下,概率为0的元素永远不会被采样到。 - 按概率采样
1
2
3weights = torch.Tensor([0.2,0.2,0.3,0.3])
torch.multinomial(weights,2)
3,2]) tensor([ - 按频率采样
multinomial()函数的input可以是大于1的数,在函数内部会再次进行归一化。例如在处理文本对word进行采样时,直接传入词典中每个词的词频就好了,不需要搜东归一化。1
2
3weights = torch.Tensor([3, 2, 7, 8])
torch.multinomial(weights, 2)
3, 1]) tensor([ - 有放回采样
采样结果可重复出现,需要将replacement=True。1
2
3weights = torch.Tensor([0, 0.3, 0.7])
torch.multinomial(weights, 10, replacement=True)
2, 1, 2, 2, 2, 1, 2, 1, 2, 1]) tensor([ - 多行同时进行采样
1
2
3
4
5
6
7
8传入的input可以是2维矩阵,此时会分别对每一行按各自的权重进行采样:
weights = torch.Tensor([
[0, 0.3, 0.7],
[0.3, 0.7, 0]
])
torch.multinomial(weights, 2)
2, 1], tensor([[
[1, 0]])
top-p
由于top-k采样方式中很难觉得k值应该取多少,于是出现动态设置token候选列表大小的策略,即核采样(Nucleus Sampling)。在每个时间步中,解码词的概率分布中,头部几个词的出现概率已经占据了绝大部分概率空间,这部分词被称为nucleus。
具体做法是,在每个时间步中,对当前token的概率分布进行排序,并给定一个阈值P,在候选词中概率最高的词开始累积求和,使得他们出现的概率和大于等于p,即可得到一新的候选词集合V_p,然后再对V_p做一次re-scaling,再进行采样。
1 | # 代码来自perplexity.ai 未必正确。 |
temperature温度采样
在概率模型中,logits扮演着能量的角色,可以通过将logits除以温度来实现温度采样,然后将其输入Softmax并获得采样概率,就是直接re-scale原有的概率分布,温度越低(<1)会使模型倾向于高频token,而高于1的温度,则会缩小高频词和低频词之间的差距。
1 | def sample(model, device, tokenizer, start_text, length, temperature=1.0): |