外观
多头注意力
约 3944 字大约 13 分钟
设计思路与核心概念
1. 多头注意力的背景与动机
多头注意力(Multi-Head Attention)是Transformer架构的核心创新,由Vaswani等人在《Attention Is All You Need》中提出,主要解决以下问题:
- 表示多样性:单个注意力头可能只关注某一种语义关系,多头可以捕获不同类型的依赖
- 并行计算:多个头可以并行处理,提高计算效率
- 信息整合:不同头关注不同的表示子空间,最终整合形成丰富的表示
- 模型容量:在相同参数量下,多头机制提供更强的表达能力
2. 核心设计思想
多头注意力的核心思想是将注意力计算分解为多个并行的"头",每个头关注不同的表示子空间:
- 子空间分解:将d_model维度分解为h个d_k维的子空间
- 并行注意力:每个头独立计算注意力,关注不同的语义关系
- 信息聚合:通过线性投影整合多个头的输出
- 参数共享:所有头共享相同的输入,但使用不同的权重矩阵
3. 数学公式
多头注意力的完整公式
对于输入序列 X∈Rn×dmodel,多头注意力的计算过程为:
多头投影:
Qi=XWiQ,Ki=XWiK,Vi=XWiV
其中 WiQ,WiK,WiV∈Rdmodel×dk,dk=dmodel/h
单头注意力计算:
headi=Attention(Qi,Ki,Vi)=softmax(dkQiKiT)Vi
多头拼接:
MultiHead(Q,K,V)=Concat(head1,...,headh)WO
最终输出:
Output=MultiHead(Q,K,V)
其中:
- h 是注意力头的数量
- dk=dv=dmodel/h 是每个头的维度
- WO∈Rdmodel×dmodel 是输出投影矩阵
执行流程
1. 整体执行流程图
2. 详细计算流程图
计算步骤详解
步骤 | 操作 | PyTorch代码 | 数学表达式 | 说明 |
---|---|---|---|---|
1 | 输入序列 | x | X∈Rb×n×din | 批次×序列长度×输入维度 |
2 | 生成QKV | self.W_query(x) | Q=XWQ | 线性投影到d_out维度 |
3 | 重塑多头 | q.view(b, n, h, d_k) | 分解为h个头 | 每头维度d_k = d_out/h |
4 | 转置维度 | q.transpose(1, 2) | (b,h,n,dk) | 便于并行计算 |
5 | 计算分数 | q @ k.transpose(2, 3) | QKT | 注意力分数矩阵 |
6 | 应用掩码 | scores.masked_fill_(mask, -inf) | 屏蔽未来位置 | 因果注意力约束 |
7 | 缩放归一化 | softmax(scores / sqrt(d_k)) | softmax(dkQKT) | 注意力权重 |
8 | 加权求和 | attn_weights @ values | AV | 上下文向量 |
9 | 拼接结果 | context.view(b, n, d_out) | Concat(heads) | 多头拼接 |
10 | 输出投影 | self.out_proj(context) | OutputWO | 最终线性变换 |
完整代码实现
多头注意力.py
# 导入PyTorch库
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
"""多头注意力机制模块 (Multi-Head Attention)
实现公式:
MultiHead(Q,K,V) = Concat(head_1,...,head_h)W^O
where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
Args:
d_in (int): 输入特征维度(原始嵌入维度)
d_out (int): 输出特征维度(需能被num_heads整除)
context_length (int): 序列最大长度(用于初始化因果掩码)
dropout (float): Dropout概率(0-1之间)
num_heads (int): 并行注意力头数
qkv_bias (bool): 是否为Q/K/V的线性变换添加可学习偏置
"""
super().__init__()
# 维度验证:确保多头拆分可行性
assert (d_out % num_heads == 0), "d_out必须能被num_heads整除"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # 每个头的特征维度
# 定义Q/K/V的线性投影层(将输入映射到d_out维空间)
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) # 查询向量变换 Q = XW_q
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) # 键向量变换 K = XW_k
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) # 值向量变换 V = XW_v
# 最终输出投影层(整合多头结果)
self.out_proj = nn.Linear(d_out, d_out)
# 正则化模块
self.dropout = nn.Dropout(dropout)
# 注册因果掩码(防止看到未来信息)
# 使用上三角矩阵(含对角线),diagonal=1表示主对角线以上保留,以下置零
# 最终形状:(context_length, context_length)
self.register_buffer("mask",
torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
"""前向传播过程(实现缩放点积注意力机制)
计算流程:
1. 线性投影得到Q/K/V张量
2. 重塑张量形状分割为多个注意力头
3. 计算注意力分数矩阵
4. 应用因果掩码和Softmax归一化
5. 多头注意力加权求和
6. 拼接结果并线性投影
Args:
x (Tensor): 输入张量,形状(batch_size, num_tokens, d_in)
Returns:
Tensor: 多头注意力结果,形状(batch_size, num_tokens, d_out)
"""
b, num_tokens, d_in = x.shape # 解包输入维度
# Step 1: 线性投影得到Q/K/V (形状均为(b, num_tokens, d_out))
keys = self.W_key(x) # 键向量 K = XW_k
queries = self.W_query(x) # 查询向量 Q = XW_q
values = self.W_value(x) # 值向量 V = XW_v
# Step 2: 重塑张量形状,分割多头
# 新形状:(b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
# Step 3: 转置维度用于矩阵运算 -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2) # 将num_heads维度提前
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# Step 4: 计算缩放点积注意力分数
# (b, num_heads, num_tokens, num_tokens)
attn_scores = queries @ keys.transpose(2, 3) # QK^T
# Step 5: 应用因果掩码(将上三角区域设为负无穷)
mask_bool = self.mask.bool()[:num_tokens, :num_tokens] # 动态裁剪掩码
attn_scores.masked_fill_(mask_bool, -torch.inf) # 填充未来位置
# Step 6: 对注意力分数进行缩放和Softmax归一化
# 缩放因子为1/sqrt(d_k),防止点积过大导致梯度消失
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
# Step 7: 应用Dropout正则化
attn_weights = self.dropout(attn_weights)
# Step 8: 加权求和得到上下文向量
# (b, num_heads, num_tokens, head_dim)
context_vec = attn_weights @ values # 乘积维度:(n,n) @ (n,d) = (n,d)
# Step 9: 拼接多个头的结果
# 转置回 (b, num_tokens, num_heads, head_dim)
context_vec = context_vec.transpose(1, 2)
# 重塑为原始输出维度 (b, num_tokens, d_out)
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
# Step 10: 最终线性投影(整合特征信息)
context_vec = self.out_proj(context_vec)
return context_vec
# 测试代码 -------------------------------------------------
# 测试案例说明:
# 1. 构造两个相同的样本(batch_size=2)
# 2. 每个样本包含6个token(序列长度=6)
# 3. 每个token使用3维嵌入向量(d_in=3)
# 验证目标:
# - 保持输入输出的批次大小和序列长度一致
# - 输出维度应符合d_out=2的要求
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # 第1个词 "Your" 的嵌入向量
[0.55, 0.87, 0.66], # 第2个词 "journey" 的嵌入向量
[0.57, 0.85, 0.64], # 第3个词 "starts" 的嵌入向量
[0.22, 0.58, 0.33], # 第4个词 "with" 的嵌入向量
[0.77, 0.25, 0.10], # 第5个词 "one" 的嵌入向量
[0.05, 0.80, 0.55]] # 第6个词 "step" 的嵌入向量
)
batch = torch.stack((inputs, inputs), dim=0) # 形状(2, 6, 3)
# 初始化模块
torch.manual_seed(123) # 设置随机种子保证结果可复现
batch_size, context_length, d_in = batch.shape
d_out = 2 # 输出维度需要能被头数整除
mha = MultiHeadAttention(
d_in=d_in,
d_out=d_out,
context_length=context_length,
dropout=0.0, # 无Dropout
num_heads=2 # 2个注意力头
)
# 前向传播
context_vecs = mha(batch)
# 输出验证
print("上下文向量:\n", context_vecs)
print("输出形状:", context_vecs.shape) # 应为(2, 6, 2)
代码详细解析
1. 类设计架构
MultiHeadAttention类结构
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
# 参数初始化和多头配置
def forward(self, x):
# 多头注意力前向传播
设计要点:
- 继承自
nn.Module
,符合PyTorch模块化设计 - 支持可配置的头数和维度
- 集成因果掩码和dropout正则化
- 包含输出投影层整合多头信息
2. 核心参数分析
初始化参数详解
参数 | 类型 | 约束条件 | 作用 |
---|---|---|---|
d_in | int | > 0 | 输入特征维度 |
d_out | int | 能被num_heads整除 | 输出特征维度 |
context_length | int | > 0 | 最大序列长度 |
dropout | float | [0, 1] | Dropout比率 |
num_heads | int | d_out的因子 | 注意力头数 |
qkv_bias | bool | - | 是否使用偏置 |
关键组件设计
# 维度验证
assert (d_out % num_heads == 0), "d_out必须能被num_heads整除"
# 计算每个头的维度
self.head_dim = d_out // num_heads
# QKV投影层(共享输入,独立权重)
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
# 输出投影层(整合多头信息)
self.out_proj = nn.Linear(d_out, d_out)
3. 多头机制详解
维度分解原理
# 原始维度: (batch_size, seq_len, d_out)
# 重塑为多头: (batch_size, seq_len, num_heads, head_dim)
# 转置为: (batch_size, num_heads, seq_len, head_dim)
多头分解示例(d_out=8, num_heads=2):
原始QKV: [q1, q2, q3, q4, q5, q6, q7, q8]
头1: [q1, q2, q3, q4]
头2: [q5, q6, q7, q8]
并行计算优势
# 单头计算(串行)
for i in range(num_heads):
head_i = attention(Q_i, K_i, V_i)
# 多头计算(并行)
all_heads = attention(Q_all, K_all, V_all) # 同时计算所有头
4. 前向传播详解
步骤1:QKV生成和重塑
# 线性投影
keys = self.W_key(x) # (b, n, d_out)
queries = self.W_query(x) # (b, n, d_out)
values = self.W_value(x) # (b, n, d_out)
# 重塑为多头形式
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
维度变换过程:
(b, n, d_out)
→(b, n, h, d_k)
通过view重塑(b, n, h, d_k)
→(b, h, n, d_k)
通过transpose转置
步骤2:注意力计算
# 计算注意力分数
attn_scores = queries @ keys.transpose(2, 3) # (b, h, n, n)
# 应用因果掩码
mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
attn_scores.masked_fill_(mask_bool, -torch.inf)
# 缩放和归一化
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
并行处理特点:
- 所有头同时计算注意力分数
- 共享相同的掩码和缩放因子
- 每个头独立进行softmax归一化
步骤3:结果整合
# 加权求和
context_vec = attn_weights @ values # (b, h, n, d_k)
# 转置回原始维度顺序
context_vec = context_vec.transpose(1, 2) # (b, n, h, d_k)
# 拼接多头结果
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
# 输出投影
context_vec = self.out_proj(context_vec)
5. 与单头注意力的对比
功能对比表
特性 | 单头注意力 | 多头注意力 |
---|---|---|
表示能力 | 单一语义关系 | 多种语义关系 |
计算复杂度 | O(n²d) | O(n²d) |
参数数量 | 3×d_in×d_out | 3×d_in×d_out + d_out² |
并行度 | 低 | 高 |
信息整合 | 无 | 输出投影层 |
表示多样性 | 有限 | 丰富 |
注意力模式对比
# 单头注意力(关注单一模式)
head_1: 主要关注语法关系
权重分布: 集中在相邻词汇
# 多头注意力(关注多种模式)
head_1: 关注语法关系
head_2: 关注语义关系
head_3: 关注长距离依赖
head_4: 关注位置信息
6. 实验配置分析
输入数据设计
inputs = torch.tensor([
[0.43, 0.15, 0.89], # "Your"
[0.55, 0.87, 0.66], # "journey"
[0.57, 0.85, 0.64], # "starts"
[0.22, 0.58, 0.33], # "with"
[0.77, 0.25, 0.10], # "one"
[0.05, 0.80, 0.55] # "step"
])
batch = torch.stack((inputs, inputs), dim=0) # (2, 6, 3)
模型配置
d_in = 3 # 输入维度
d_out = 2 # 输出维度
num_heads = 2 # 注意力头数
head_dim = 1 # 每头维度 = d_out / num_heads
配置合理性验证:
- ✅
d_out % num_heads == 0
:2 % 2 = 0 - ✅
head_dim = d_out // num_heads
:1 = 2 // 2 - ✅ 批处理支持:batch_size = 2
7. 输出结果解读
预期输出特征
- 形状:
(2, 6, 2)
- 批次×序列长度×输出维度 - 多头整合: 包含2个头的综合信息
- 因果性: 保持时序约束
- 一致性: 相同输入产生相同输出
多头信息融合
# 每个位置的输出向量包含:
# - 头1的注意力信息(关注某种语义关系)
# - 头2的注意力信息(关注另一种语义关系)
# - 输出投影层的整合结果
高级特性与扩展
1. 多头可视化
注意力头可视化
def visualize_multi_head_attention(model, inputs, words):
"""可视化多头注意力权重"""
model.eval()
with torch.no_grad():
x = inputs.unsqueeze(0)
b, num_tokens, d_in = x.shape
# 获取QKV
Q = model.W_query(x).view(b, num_tokens, model.num_heads, model.head_dim)
K = model.W_key(x).view(b, num_tokens, model.num_heads, model.head_dim)
Q = Q.transpose(1, 2) # (b, h, n, d_k)
K = K.transpose(1, 2)
# 计算注意力权重
attn_scores = Q @ K.transpose(2, 3)
attn_scores.masked_fill_(
model.mask.bool()[:num_tokens, :num_tokens],
float('-inf')
)
attn_weights = torch.softmax(attn_scores / (K.shape[-1] ** 0.5), dim=-1)
# 绘制每个头的注意力权重
fig, axes = plt.subplots(1, model.num_heads, figsize=(15, 6))
for head in range(model.num_heads):
sns.heatmap(attn_weights[0, head].numpy(),
xticklabels=words,
yticklabels=words,
annot=True,
cmap='Blues',
fmt='.3f',
ax=axes[head])
axes[head].set_title(f'Head {head + 1}')
plt.tight_layout()
plt.show()
头重要性分析
def analyze_head_importance(model, inputs):
"""分析不同头的重要性"""
model.eval()
with torch.no_grad():
# 获取每个头的输出
x = inputs.unsqueeze(0)
b, num_tokens, d_in = x.shape
# 前向传播到多头计算
Q = model.W_query(x).view(b, num_tokens, model.num_heads, model.head_dim).transpose(1, 2)
K = model.W_key(x).view(b, num_tokens, model.num_heads, model.head_dim).transpose(1, 2)
V = model.W_value(x).view(b, num_tokens, model.num_heads, model.head_dim).transpose(1, 2)
attn_scores = Q @ K.transpose(2, 3)
attn_scores.masked_fill_(model.mask.bool()[:num_tokens, :num_tokens], float('-inf'))
attn_weights = torch.softmax(attn_scores / (K.shape[-1] ** 0.5), dim=-1)
# 计算每个头的注意力分布熵
head_entropy = []
for head in range(model.num_heads):
entropy = -torch.sum(attn_weights[0, head] * torch.log(attn_weights[0, head] + 1e-9), dim=-1)
head_entropy.append(entropy.mean().item())
return head_entropy
2. 性能优化
内存优化版本
class MemoryEfficientMultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
super().__init__()
assert d_out % num_heads == 0
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads
# 使用单个线性层生成QKV
self.qkv_proj = nn.Linear(d_in, 3 * d_out, bias=qkv_bias)
self.out_proj = nn.Linear(d_out, d_out)
self.dropout = nn.Dropout(dropout)
self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
b, num_tokens, d_in = x.shape
# 一次性生成QKV
qkv = self.qkv_proj(x) # (b, n, 3*d_out)
qkv = qkv.view(b, num_tokens, 3, self.num_heads, self.head_dim)
qkv = qkv.permute(2, 0, 3, 1, 4) # (3, b, h, n, d_k)
queries, keys, values = qkv[0], qkv[1], qkv[2]
# 注意力计算
attn_scores = queries @ keys.transpose(-2, -1)
attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], float('-inf'))
attn_weights = torch.softmax(attn_scores / (self.head_dim ** 0.5), dim=-1)
attn_weights = self.dropout(attn_weights)
context = attn_weights @ values
context = context.transpose(1, 2).contiguous().view(b, num_tokens, self.d_out)
return self.out_proj(context)
Flash Attention集成
try:
from flash_attn import flash_attn_func
class FlashMultiHeadAttention(nn.Module):
def forward(self, x):
# 使用Flash Attention优化
b, num_tokens, d_in = x.shape
Q = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim)
K = self.W_key(x).view(b, num_tokens, self.num_heads, self.head_dim)
V = self.W_value(x).view(b, num_tokens, self.num_heads, self.head_dim)
# Flash Attention计算
context = flash_attn_func(Q, K, V, causal=True)
context = context.view(b, num_tokens, self.d_out)
return self.out_proj(context)
except ImportError:
print("Flash Attention not available")
3. 扩展应用
交叉注意力版本
class CrossMultiHeadAttention(nn.Module):
"""交叉注意力(用于编码器-解码器架构)"""
def __init__(self, d_in, d_out, num_heads, dropout=0.0):
super().__init__()
assert d_out % num_heads == 0
self.num_heads = num_heads
self.head_dim = d_out // num_heads
self.W_query = nn.Linear(d_in, d_out)
self.W_key = nn.Linear(d_in, d_out)
self.W_value = nn.Linear(d_in, d_out)
self.out_proj = nn.Linear(d_out, d_out)
self.dropout = nn.Dropout(dropout)
def forward(self, query, key_value):
"""
query: 来自解码器的查询 (b, n_q, d_in)
key_value: 来自编码器的键值 (b, n_kv, d_in)
"""
b, n_q, _ = query.shape
n_kv = key_value.shape[1]
Q = self.W_query(query).view(b, n_q, self.num_heads, self.head_dim).transpose(1, 2)
K = self.W_key(key_value).view(b, n_kv, self.num_heads, self.head_dim).transpose(1, 2)
V = self.W_value(key_value).view(b, n_kv, self.num_heads, self.head_dim).transpose(1, 2)
attn_scores = Q @ K.transpose(-2, -1)
attn_weights = torch.softmax(attn_scores / (self.head_dim ** 0.5), dim=-1)
attn_weights = self.dropout(attn_weights)
context = attn_weights @ V
context = context.transpose(1, 2).contiguous().view(b, n_q, -1)
return self.out_proj(context)
应用场景与实践
1. 在Transformer中的应用
- 编码器: 自注意力层捕获输入序列的内部关系
- 解码器: 自注意力 + 交叉注意力实现生成
- 预训练模型: BERT、GPT等的核心组件
2. 超参数调优
- 头数选择: 通常为8、12、16等,需要平衡表示能力和计算成本
- 维度分配: 确保d_model能被num_heads整除
- dropout设置: 通常在0.1-0.3之间
3. 性能考虑
- 内存使用: 与头数和序列长度的平方成正比
- 计算复杂度: O(n²d),其中n是序列长度
- 并行化: 多头可以充分利用GPU并行计算能力
更新日志
2025/8/18 00:31
查看所有更新日志
bd1d0
-迁移目录于b0f2a
-docs: 完善大模型学习文档 - 增加设计思路与执行流程于dfb81
-update于2f7cb
-update于
版权所有
版权归属:NateHHX