生成式模型解码策略

目前,生成式任务中模型的解码策略有五种,分别为:贪心搜索(Greedy Search)、束搜索(Beam Search)、top-k、top-p、随机采样。

贪心搜索

在t时刻中,选取当前概率最大的词。

这种采样方式只考虑当前时刻的最优解,忽略了当前低概率词后面的高概率词。

1
2
3
4
5
6
7
8
9
10
11
# 代码来自perplexity.ai
def greedy_search(model, device, tokenizer, start_text, length):
input_ids = tokenizer.encode(start_text, return_tensors='pt').to(device)
model.eval()
with torch.no_grad():
for i in range(length):
outputs = model(input_ids)
logits = outputs.logits[:, -1, :]
next_token = torch.argmax(logits, dim=-1)
input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
return tokenizer.decode(input_ids, skip_special_tokens=True)

束搜索

束搜索以句子为单位,返回可能性最大的输出序列列表。
每一步解码时,仅保留前K个可能的结果。例如在第一步解码时,我们选择前K个可能的y,分别代入第二步解码中,各取前k个候选词,即得到k^2个候选组合,最后保留概率乘积最大的前k个候选结果。
Transformers中的Beam Search高效实现
Beam Search快速理解及代码解析
Transformers仓库做语言生成的解码方法介绍
使用Transformers做限制集束搜索(Constrained Beam Search)的文本生成

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
#代码来自perplexity.ai
def beam_search(model, device, tokenizer, start_text, length, beam_size=5):
input_ids = tokenizer.encode(start_text, return_tensors='pt').to(device)
model.eval()
with torch.no_grad():
# 计算初始状态的logits和隐状态
outputs = model(input_ids)
logits = outputs.logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
topk_probs, topk_indices = torch.topk(probs, beam_size, dim=-1)
topk_log_probs = torch.log(topk_probs)
topk_inputs = torch.ones((1, beam_size), dtype=torch.long, device=device) * input_ids[0, -1]
topk_outputs = topk_inputs
topk_scores = topk_log_probs
topk_hidden = outputs.hidden

# 逐步生成文本
for i in range(1, length):
all_outputs = []
all_scores = []
all_hidden = []
for j in range(beam_size):
input_ids = torch.cat([input_ids, topk_inputs[:, j].unsqueeze(-1)], dim=-1)
outputs = model(input_ids, topk_hidden)
logits = outputs.logits[:, -1, :]
probs = F.softmax(logits, dim=-1)
scores = topk_scores[:, j].unsqueeze(-1) + torch.log(probs)
scores = scores.view(-1)
topk_scores, topk_indices = torch.topk(scores, beam_size)
topk_probs = torch.exp(topk_scores)
topk_inputs = topk_indices % probs.shape[-1]
topk_outputs = torch.cat([topk_outputs[:, j].unsqueeze(-1), topk_inputs.unsqueeze(-1)], dim=-1)
topk_hidden = outputs.hidden
all_outputs.append(topk_outputs)
all_scores.append(topk_scores)
all_hidden.append(topk_hidden)
topk_outputs = torch.cat(all_outputs, dim=1)
topk_scores = torch.cat(all_scores, dim=1)
topk_hidden = torch.cat(all_hidden, dim=1)
best_output = topk_outputs[:, 0]
return tokenizer.decode(best_output, skip_special_tokens=True)

top-k

top-k是对贪心搜索策略的优化,从其排名前k的token中进行采样。对于每个时刻t,首先筛选出k个备选token,重新计算这K个token的概率值,然后使用multinomial方法进行采样,采样时会优先取概率高的值。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
#代码来自perplexity.ai,不知道是否正确。
import torch
def top_k_sampling(logits, k):
"""
对logits进行top-k采样
:param logits: 模型输出的logits,​​形状为[batch_size, vocab_size]
:param k: 采样的候选项个数
:return: 采样结果,​​形状为[batch_size]
"""
batch_size, vocab_size = logits.shape
top_k_logits, top_k_indices = torch.topk(logits, k=k, dim=-1)
top_k_probs = torch.nn.functional.softmax(top_k_logits, dim=-1)
top_k_probs /= torch.sum(top_k_probs, dim=-1, keepdim=True)
top_k_samples = torch.multinomial(top_k_probs, num_samples=1)
samples = torch.gather(top_k_samples, dim=-1, index=top_k_indices)
samples = torch.squeeze(samples, dim=-1)
return samples

torch.multinomial(input, num_samples, replacement=False)
torch.multinominal方法可以根据给定权重对数组进行多次采样,返回采样后的元素下标。

  • 参数说明
    input: 必须是torch.Tensor类型,即概率分布。可以是一维或者二维,不必手动归一化。
    num_samples:采样的次数。如果input是二维的,则表示每行的采样次数。
    replacement: 采样是否放回。默认为False,即不放回采样,此时num_samples必须小于input中的非零元素。在无放回采样中,input中值为0的元素只有所有其他元素被抽到后,才会被抽到。换句话说,有放回情况下,概率为0的元素永远不会被采样到。
  • 按概率采样
    1
    2
    3
    weights = torch.Tensor([0.2,0.2,0.3,0.3])
    torch.multinomial(weights,2)
    >>> tensor([3,2])
  • 按频率采样
    multinomial()函数的input可以是大于1的数,在函数内部会再次进行归一化。例如在处理文本对word进行采样时,直接传入词典中每个词的词频就好了,不需要搜东归一化。
    1
    2
    3
    weights = torch.Tensor([3, 2, 7, 8])
    torch.multinomial(weights, 2)
    >>> tensor([3, 1])
  • 有放回采样
    采样结果可重复出现,需要将replacement=True。
    1
    2
    3
    weights = torch.Tensor([0, 0.3, 0.7])
    torch.multinomial(weights, 10, replacement=True)
    >>> tensor([2, 1, 2, 2, 2, 1, 2, 1, 2, 1])
  • 多行同时进行采样
    1
    2
    3
    4
    5
    6
    7
    8
    传入的input可以是2维矩阵,此时会分别对每一行按各自的权重进行采样:
    weights = torch.Tensor([
    [0, 0.3, 0.7],
    [0.3, 0.7, 0]
    ])
    torch.multinomial(weights, 2)
    >>> tensor([[2, 1],
    [1, 0]])

top-p

由于top-k采样方式中很难觉得k值应该取多少,于是出现动态设置token候选列表大小的策略,即核采样(Nucleus Sampling)。在每个时间步中,解码词的概率分布中,头部几个词的出现概率已经占据了绝大部分概率空间,这部分词被称为nucleus。
具体做法是,在每个时间步中,对当前token的概率分布进行排序,并给定一个阈值P,在候选词中概率最高的词开始累积求和,使得他们出现的概率和大于等于p,即可得到一新的候选词集合V_p,然后再对V_p做一次re-scaling,再进行采样。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 代码来自perplexity.ai 未必正确。
import torch
def top_p_sampling(logits, p=0.9):
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)

# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > p
# Shift the indices to the right to keep the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0

# Scatter sorted tensors to original indexing
indices_to_remove = sorted_indices[sorted_indices_to_remove]
logits[indices_to_remove] = float('-inf')
return torch.multinomial(torch.softmax(logits, dim=-1), 1)

# Example usage
logits = torch.randn(1, 10)
sampled_token = top_p_sampling(logits, p=0.8)

temperature温度采样

在概率模型中,logits扮演着能量的角色,可以通过将logits除以温度来实现温度采样,然后将其输入Softmax并获得采样概率,就是直接re-scale原有的概率分布,温度越低(<1)会使模型倾向于高频token,而高于1的温度,则会缩小高频词和低频词之间的差距。

1
2
3
4
5
6
7
8
9
10
11
def sample(model, device, tokenizer, start_text, length, temperature=1.0):
input_ids = tokenizer.encode(start_text, return_tensors='pt').to(device)
model.eval()
with torch.no_grad():
for i in range(length):
outputs = model(input_ids)
logits = outputs.logits[:, -1, :] / temperature
probs = F.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
input_ids = torch.cat([input_ids, next_token], dim=-1)
return tokenizer.decode(input_ids, skip_special_tokens=True)