外观
控制随机性的解码策略
约 1560 字大约 5 分钟
代码思路与设计
核心思想
本文档实现了三种控制大语言模型文本生成随机性的解码策略:
- 温度缩放(Temperature Scaling):通过调节温度参数控制概率分布的平滑度
- Top-k采样:只从概率最高的k个候选词中采样
- 综合解码函数:结合温度缩放和Top-k采样的完整生成函数
设计目标
- 提供可控的文本生成质量
- 平衡生成文本的创造性和连贯性
- 支持不同应用场景的需求(创意写作 vs 技术文档)
解码策略流程图
1. 温度缩放实现
import torch
import tiktoken
import matplotlib.pyplot as plt
# 模型设置
model.to("cpu")
model.eval()
# 基础文本生成
tokenizer = tiktoken.get_encoding("gpt2")
token_ids = generate_text_simple(
model=model,
idx=text_to_token_ids("Every effort moves you", tokenizer),
max_new_tokens=25,
context_size=GPT_CONFIG_124M["context_length"]
)
print("Output text:\n", token_ids_to_text(token_ids, tokenizer))
# 示例词汇表和logits
vocab = {
"closer": 0,
"every": 1,
"effort": 2,
"forward": 3,
"inches": 4,
"moves": 5,
"pizza": 6,
"toward": 7,
"you": 8,
}
inverse_vocab = {v: k for k, v in vocab.items()}
# 模拟下一个token的logits
next_token_logits = torch.tensor(
[4.51, 0.89, -1.90, 6.75, 1.63, -1.62, -1.89, 6.28, 1.79]
)
# 标准softmax概率
probas = torch.softmax(next_token_logits, dim=0)
next_token_id = torch.argmax(probas).item()
print(f"贪心选择: {inverse_vocab[next_token_id]}")
# 随机采样
torch.manual_seed(123)
next_token_id = torch.multinomial(probas, num_samples=1).item()
print(f"随机采样: {inverse_vocab[next_token_id]}")
# 采样频率统计函数
def print_sampled_tokens(probas):
torch.manual_seed(123)
sample = [torch.multinomial(probas, num_samples=1).item()
for i in range(1_000)]
sampled_ids = torch.bincount(torch.tensor(sample))
for i, freq in enumerate(sampled_ids):
print(f"{freq} x {inverse_vocab[i]}")
print("\n采样频率统计:")
print_sampled_tokens(probas)
# 温度缩放函数
def softmax_with_temperature(logits, temperature):
scaled_logits = logits / temperature
return torch.softmax(scaled_logits, dim=0)
# 不同温度下的概率分布
temperatures = [1, 0.1, 5]
scaled_probas = [softmax_with_temperature(next_token_logits, T)
for T in temperatures]
# 可视化温度效果
x = torch.arange(len(vocab))
bar_width = 0.15
fig, ax = plt.subplots(figsize=(5, 3))
for i, T in enumerate(temperatures):
rects = ax.bar(x + i * bar_width, scaled_probas[i],
bar_width, label=f'Temperature = {T}')
ax.set_ylabel('Probability')
ax.set_xticks(x)
ax.set_xticklabels(vocab.keys(), rotation=90)
ax.legend()
plt.tight_layout()
plt.show()
2. Top-k采样实现
# Top-k采样示例
top_k = 3
top_logits, top_pos = torch.topk(next_token_logits, top_k)
print("Top logits:", top_logits)
print("Top positions:", top_pos)
# 将非Top-k的logits设为负无穷
new_logits = torch.where(
condition=next_token_logits < top_logits[-1],
input=torch.tensor(float('-inf')),
other=next_token_logits
)
print("过滤后的logits:", new_logits)
# 计算Top-k概率分布
topk_probas = torch.softmax(new_logits, dim=0)
print("Top-k概率分布:", topk_probas)
3. 综合解码函数
def generate(model, idx, max_new_tokens, context_size,
temperature=0.0, top_k=None, eos_id=None):
"""
综合解码函数,支持温度缩放和Top-k采样
Args:
model: 语言模型
idx: 输入token序列
max_new_tokens: 最大生成token数
context_size: 上下文窗口大小
temperature: 温度参数,0表示贪心解码
top_k: Top-k采样参数
eos_id: 结束token ID
"""
for _ in range(max_new_tokens):
# 截取上下文窗口
idx_cond = idx[:, -context_size:]
with torch.no_grad():
# 模型前向传播
logits = model(idx_cond)
logits = logits[:, -1, :] # 获取最后一个位置的logits
# 应用Top-k过滤
if top_k is not None:
top_logits, _ = torch.topk(logits, top_k)
min_val = top_logits[:, -1]
logits = torch.where(
logits < min_val,
torch.tensor(float('-inf')).to(logits.device),
logits
)
# 应用温度缩放和采样
if temperature > 0.0:
logits = logits / temperature
probs = torch.softmax(logits, dim=-1)
idx_next = torch.multinomial(probs, num_samples=1)
else:
# 贪心解码
idx_next = torch.argmax(logits, dim=-1, keepdim=True)
# 检查结束条件
if idx_next == eos_id:
break
# 拼接新token
idx = torch.cat((idx, idx_next), dim=1)
return idx
# 使用综合解码函数生成文本
torch.manual_seed(123)
token_ids = generate(
model=model,
idx=text_to_token_ids("Every effort moves you", tokenizer),
max_new_tokens=15,
context_size=GPT_CONFIG_124M["context_length"],
top_k=25,
temperature=1.4
)
print("生成文本:\n", token_ids_to_text(token_ids, tokenizer))
代码执行结果
1. 温度缩放效果
贪心选择: forward
随机采样: forward
采样频率统计:
0 x closer
0 x every
0 x effort
518 x forward
0 x inches
0 x moves
0 x pizza
482 x toward
0 x you
分析:
- 贪心选择总是选择概率最高的"forward"
- 在1000次随机采样中,"forward"出现518次,"toward"出现482次
- 其他低概率词汇几乎不被选中
2. 不同温度参数的影响
温度值 | 效果 | 适用场景 |
---|---|---|
T = 0.1 | 概率分布更加尖锐,接近贪心选择 | 技术文档、事实性文本 |
T = 1.0 | 标准概率分布 | 平衡的文本生成 |
T = 5.0 | 概率分布更加平滑,增加随机性 | 创意写作、诗歌生成 |
3. Top-k采样结果
Top logits: tensor([6.75, 6.28, 4.51])
Top positions: tensor([3, 7, 0])
过滤后的logits: tensor([4.51, -inf, -inf, 6.75, -inf, -inf, -inf, 6.28, -inf])
Top-k概率分布: tensor([0.0474, 0.0000, 0.0000, 0.5106, 0.0000, 0.0000, 0.0000, 0.4420, 0.0000])
分析:
- Top-3采样只保留了"closer"、"forward"、"toward"三个候选词
- 其他词汇的概率被设为0
- 有效减少了低质量候选词的干扰
4. 综合解码策略生成文本
生成文本:
Every effort moves you toward a better understanding of the complex dynamics that drive success in
参数设置:
top_k=25
:从前25个候选词中选择temperature=1.4
:适度增加随机性max_new_tokens=15
:生成15个新token
效果评估:
- 文本连贯性好,语法正确
- 保持了一定的创造性和多样性
- 避免了过度重复和低质量输出
技术总结
核心优势
- 可控性强:通过调节温度和Top-k参数精确控制生成质量
- 适应性好:不同参数组合适用于不同应用场景
- 计算高效:相比复杂的解码算法,计算开销较小
参数选择建议
- 创意写作:
temperature=1.2-2.0, top_k=40-100
- 技术文档:
temperature=0.3-0.7, top_k=10-25
- 对话系统:
temperature=0.8-1.2, top_k=25-50
- 代码生成:
temperature=0.1-0.5, top_k=5-15
实际应用价值
这些解码策略在实际的大语言模型部署中广泛应用,是控制模型输出质量的重要工具。通过合理的参数调节,可以在保持文本连贯性的同时,实现所需的创造性水平。
更新日志
2025/8/19 16:09
查看所有更新日志
245db
-feat: AI实验室文档结构优化与代码整理 v1.0.24于b52e1
-feat: 新增无标签数据上进行预训练章节于
版权所有
版权归属:NateHHX