外观
单头注意力
约 3189 字大约 11 分钟
设计思路与核心概念
1. 单头注意力的背景与动机
单头注意力(Single-Head Attention)是完整自注意力机制的核心实现,相比简化版本引入了Query、Key、Value三个独立的表示空间,主要改进包括:
- 表示空间分离:通过独立的线性变换将输入映射到不同的语义空间
- 灵活的相似度计算:Query和Key在专门的空间中计算相似度,更加精确
- 可学习的特征提取:Value向量承载实际传递的信息,可以学习到最优的特征表示
- 缩放机制:引入缩放因子防止注意力分数过大导致的梯度问题
2. 核心设计思想
单头注意力的核心思想是将输入通过三个不同的线性变换分别映射到Query、Key、Value空间:
- Query(查询):表示"我想要什么信息",用于主动查找相关内容
- Key(键):表示"我能提供什么信息",用于被动匹配查询
- Value(值):表示"我实际传递的信息",是最终聚合的内容
- 注意力机制:通过Query-Key匹配决定Value的权重分配
3. 数学公式
完整的单头注意力公式
对于输入序列 X∈Rn×din,单头注意力的计算过程为:
线性变换生成QKV:
Q=XWQ,K=XWK,V=XWV
注意力分数计算:
S=QKT
缩放和归一化:
A=softmax(dkS)
上下文向量计算:
Attention(Q,K,V)=AV
其中:
- WQ,WK,WV∈Rdin×dout 是可学习的权重矩阵
- dk=dout 是Key向量的维度
- dk 是缩放因子,防止内积过大
执行流程
1. 整体执行流程图
2. 详细计算流程图
流程图节点详解
节点 | 操作描述 | 技术细节 | 矩阵维度变化 |
---|---|---|---|
输入词嵌入矩阵 | 接收序列的词向量表示 | 每个词用d_in维向量表示 | (seq_len, d_in) |
生成Query矩阵 | 通过W_Q线性变换 | self.W_query(x) | (seq_len, d_in) → (seq_len, d_out) |
生成Key矩阵 | 通过W_K线性变换 | self.W_key(x) | (seq_len, d_in) → (seq_len, d_out) |
生成Value矩阵 | 通过W_V线性变换 | self.W_value(x) | (seq_len, d_in) → (seq_len, d_out) |
计算注意力分数 | Query与Key转置相乘 | queries @ keys.T | (seq_len, d_out) @ (d_out, seq_len) → (seq_len, seq_len) |
缩放处理 | 除以sqrt(d_k)防止梯度消失 | scores / sqrt(d_out) | 数值缩放,维度不变 |
Softmax归一化 | 转换为概率分布 | torch.softmax(..., dim=-1) | 每行和为1,维度不变 |
加权求和 | 注意力权重与Value相乘 | attn_weights @ values | (seq_len, seq_len) @ (seq_len, d_out) → (seq_len, d_out) |
计算步骤详解
步骤 | 操作 | PyTorch代码 | 数学表达式 | 说明 |
---|---|---|---|---|
1 | 输入序列 | x | X∈Rn×din | n个词,每个词d_in维 |
2 | 生成Query | self.W_query(x) | Q=XWQ | 查询矩阵,形状(n, d_out) |
3 | 生成Key | self.W_key(x) | K=XWK | 键矩阵,形状(n, d_out) |
4 | 生成Value | self.W_value(x) | V=XWV | 值矩阵,形状(n, d_out) |
5 | 计算分数 | queries @ keys.T | S=QKT | 注意力分数矩阵(n, n) |
6 | 缩放处理 | attn_scores / keys.shape[-1]**0.5 | dkS | 防止梯度消失 |
7 | Softmax归一化 | torch.softmax(..., dim=-1) | A=softmax(dkS) | 转换为概率分布 |
8 | 加权求和 | attn_weights @ values | C=AV | 最终上下文向量(n, d_out) |
完整代码实现
单头注意力.py
# 导入PyTorch库及其神经网络模块
import torch
import torch.nn as nn
# 定义自注意力机制类(版本2)
class SelfAttention_v2(nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
"""
初始化函数
d_in: 输入特征维度(词嵌入向量长度)
d_out: 输出特征维度(QKV向量空间维度)
qkv_bias: 是否启用偏置项
"""
super().__init__()
# 定义三个独立的线性变换层,用于生成查询向量(Query)、键向量(Key)和值向量(Value)
# W_query将输入从d_in维映射到d_out维:Q = XW_query^T + b
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
# 同理生成Key向量:K = XW_key^T + b
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
# 同理生成Value向量:V = XW_value^T + b
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
def forward(self, x):
"""
前向传播函数
x: 输入张量形状为(seq_len, d_in),seq_len为序列长度
"""
# 通过线性变换生成Key向量 (seq_len, d_out)
keys = self.W_key(x)
# 生成Query向量 (seq_len, d_out)
queries = self.W_query(x)
# 生成Value向量 (seq_len, d_out)
values = self.W_value(x)
# 计算注意力分数矩阵:QK^T (seq_len, seq_len)
# 每个元素表示query_i和key_j的点积相似度
attn_scores = queries @ keys.T # 矩阵乘法:(n,d) @ (d,m) => (n,m)
# 计算注意力权重(应用softmax归一化)
# 缩放因子:除以√d_k(key的维度)防止内积过大导致梯度消失
# softmax(QK^T/√d_k) => (seq_len, seq_len)
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, # 分母为√d_out
dim=-1 # 在最后一个维度进行softmax
)
# 计算上下文向量:注意力权重乘以Value向量
# (seq_len, seq_len) @ (seq_len, d_out) => (seq_len, d_out)
# 每个位置的输出向量是所有Value向量的加权平均
context_vec = attn_weights @ values
return context_vec
# 定义输入数据:6个词的嵌入向量(词表征)
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # 第1个词 "Your" 的嵌入向量
[0.55, 0.87, 0.66], # 第2个词 "journey" 的嵌入向量
[0.57, 0.85, 0.64], # 第3个词 "starts" 的嵌入向量
[0.22, 0.58, 0.33], # 第4个词 "with" 的嵌入向量
[0.77, 0.25, 0.10], # 第5个词 "one" 的嵌入向量
[0.05, 0.80, 0.55]] # 第6个词 "step" 的嵌入向量
)
# 获取输入维度d_in(词向量长度=3)
d_in = inputs.shape[-1]
# 设置输出维度d_out(QKV空间维度=2)
d_out = 2
# 设置随机种子以保证结果可复现
torch.manual_seed(789)
# 实例化自注意力模型
sa_v2 = SelfAttention_v2(d_in, d_out)
# 前向传播计算并打印结果
print(sa_v2(inputs))
代码详细解析
1. 类设计架构
SelfAttention_v2类结构
class SelfAttention_v2(nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
# 参数初始化
def forward(self, x):
# 前向传播逻辑
设计要点:
- 继承自
nn.Module
,符合PyTorch模块化设计 - 支持可配置的输入输出维度
- 可选的偏置项设置,增加模型灵活性
2. 核心参数分析
线性变换层设计
参数 | 类型 | 形状 | 作用 |
---|---|---|---|
W_query | nn.Linear | (d_in, d_out) | 将输入映射到查询空间 |
W_key | nn.Linear | (d_in, d_out) | 将输入映射到键空间 |
W_value | nn.Linear | (d_in, d_out) | 将输入映射到值空间 |
qkv_bias | bool | - | 是否使用偏置项 |
维度变换分析
# 输入: (seq_len, d_in) = (6, 3)
# 经过线性变换后:
# Q: (seq_len, d_out) = (6, 2)
# K: (seq_len, d_out) = (6, 2)
# V: (seq_len, d_out) = (6, 2)
3. 前向传播详解
步骤1:QKV生成
keys = self.W_key(x) # (6, 3) -> (6, 2)
queries = self.W_query(x) # (6, 3) -> (6, 2)
values = self.W_value(x) # (6, 3) -> (6, 2)
关键点:
- 独立变换:三个线性层参数独立,学习不同的表示
- 维度映射:从输入维度d_in映射到输出维度d_out
- 语义分离:Q关注查询,K关注匹配,V关注传递
步骤2:注意力分数计算
attn_scores = queries @ keys.T # (6, 2) @ (2, 6) -> (6, 6)
数学含义:
- 点积相似度:计算每个Query与所有Key的相似度
- 矩阵形状:结果为(seq_len, seq_len)的方阵
- 对称性:由于使用相同输入,矩阵具有一定对称性
步骤3:缩放和归一化
attn_weights = torch.softmax(
attn_scores / keys.shape[-1]**0.5, # 缩放因子 = √d_out
dim=-1
)
缩放机制分析:
- 缩放因子:dk1 其中 dk=dout
- 数值稳定性:防止点积过大导致softmax饱和
- 梯度优化:保持梯度在合理范围内
步骤4:上下文聚合
context_vec = attn_weights @ values # (6, 6) @ (6, 2) -> (6, 2)
聚合机制:
- 加权平均:每个位置的输出是所有Value的加权组合
- 信息传递:权重决定了信息传递的强度
- 维度保持:输出维度与Value维度相同
4. 与简化版本的对比
功能对比表
特性 | 简化版自注意力 | 单头注意力 |
---|---|---|
输入处理 | 直接使用原始输入 | 通过线性变换生成QKV |
表示空间 | 单一空间 | 三个独立空间 |
可学习参数 | 无 | 3个线性变换矩阵 |
缩放机制 | 无 | 有 dk1 |
灵活性 | 固定 | 可配置输入输出维度 |
表达能力 | 有限 | 更强 |
优势分析
- 表示能力增强:独立的QKV空间提供更丰富的表示
- 可学习性:参数可以通过训练优化
- 数值稳定性:缩放机制改善训练稳定性
- 维度灵活性:支持不同的输入输出维度
5. 实验配置分析
输入数据设计
inputs = torch.tensor([
[0.43, 0.15, 0.89], # "Your"
[0.55, 0.87, 0.66], # "journey"
[0.57, 0.85, 0.64], # "starts"
[0.22, 0.58, 0.33], # "with"
[0.77, 0.25, 0.10], # "one"
[0.05, 0.80, 0.55] # "step"
])
数据特点:
- 序列长度:6个词
- 嵌入维度:3维向量
- 数值范围:[0, 1]之间的浮点数
- 语义模拟:模拟真实的词嵌入表示
模型配置
d_in = 3 # 输入维度(词嵌入维度)
d_out = 2 # 输出维度(注意力头维度)
配置说明:
- 维度压缩:从3维压缩到2维,便于观察
- 计算效率:较小的维度降低计算复杂度
- 教学目的:便于理解和可视化
6. 输出结果解读
预期输出格式
# 输出形状: (6, 2)
tensor([
[x11, x12], # "Your"的上下文向量
[x21, x22], # "journey"的上下文向量
[x31, x32], # "starts"的上下文向量
[x41, x42], # "with"的上下文向量
[x51, x52], # "one"的上下文向量
[x61, x62] # "step"的上下文向量
])
结果特征:
- 维度变化:从(6,3)变为(6,2)
- 上下文信息:每个向量包含全局上下文
- 可学习性:通过训练可以优化表示质量
高级特性与扩展
1. 缩放机制的重要性
为什么需要缩放?
# 不缩放的问题演示
def attention_without_scaling(Q, K, V):
scores = Q @ K.T # 可能导致数值过大
weights = torch.softmax(scores, dim=-1) # softmax饱和
return weights @ V
# 带缩放的改进版本
def scaled_attention(Q, K, V):
d_k = K.size(-1)
scores = Q @ K.T / math.sqrt(d_k) # 缩放处理
weights = torch.softmax(scores, dim=-1)
return weights @ V
缩放效果分析
- 数值稳定性:防止点积过大
- 梯度优化:保持梯度在合理范围
- 训练稳定性:减少训练过程中的数值问题
2. 偏置项的作用
有偏置 vs 无偏置
# 无偏置版本
W_q_no_bias = nn.Linear(d_in, d_out, bias=False)
# 有偏置版本
W_q_with_bias = nn.Linear(d_in, d_out, bias=True)
偏置项的影响:
- 表达能力:增加模型的非线性表达能力
- 训练稳定性:有助于模型收敛
- 参数数量:增加少量参数但提升性能
3. 批处理支持
扩展到批处理版本
class BatchSelfAttention(nn.Module):
def __init__(self, d_in, d_out, qkv_bias=False):
super().__init__()
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
def forward(self, x):
# x: (batch_size, seq_len, d_in)
B, N, D = x.shape
Q = self.W_query(x) # (B, N, d_out)
K = self.W_key(x) # (B, N, d_out)
V = self.W_value(x) # (B, N, d_out)
# 批量矩阵乘法
scores = torch.bmm(Q, K.transpose(1, 2)) # (B, N, N)
scores = scores / math.sqrt(K.size(-1))
weights = torch.softmax(scores, dim=-1)
context = torch.bmm(weights, V) # (B, N, d_out)
return context
4. 可视化工具
注意力权重可视化
import matplotlib.pyplot as plt
import seaborn as sns
def visualize_attention_weights(model, inputs, words):
"""可视化注意力权重矩阵"""
with torch.no_grad():
# 获取中间结果
Q = model.W_query(inputs)
K = model.W_key(inputs)
scores = Q @ K.T
weights = torch.softmax(scores / K.shape[-1]**0.5, dim=-1)
plt.figure(figsize=(8, 6))
sns.heatmap(weights.numpy(),
xticklabels=words,
yticklabels=words,
annot=True,
cmap='Blues',
fmt='.3f')
plt.title('Single-Head Attention Weights')
plt.xlabel('Key (Attended to)')
plt.ylabel('Query (Attending from)')
plt.show()
# 使用示例
words = ["Your", "journey", "starts", "with", "one", "step"]
visualize_attention_weights(sa_v2, inputs, words)
应用场景与实践
1. 在Transformer中的地位
- 编码器:自注意力层的基础组件
- 解码器:自注意力和交叉注意力的实现基础
- 多头注意力:多个单头注意力的并行组合
2. 性能优化建议
- 维度选择:根据任务复杂度选择合适的d_out
- 初始化策略:使用Xavier或He初始化
- 正则化:可以添加dropout防止过拟合
3. 扩展方向
- 多头注意力:并行多个单头注意力
- 相对位置编码:增加位置信息
- 稀疏注意力:减少计算复杂度
更新日志
2025/8/18 00:31
查看所有更新日志
bd1d0
-迁移目录于b0f2a
-docs: 完善大模型学习文档 - 增加设计思路与执行流程于dfb81
-update于c2341
-update于f241e
-update于dc6b2
-update于
版权所有
版权归属:NateHHX