外观
多头注意力
约 1050 字大约 4 分钟
2025-06-13
# 导入PyTorch库
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
"""多头注意力机制模块 (Multi-Head Attention)
实现公式:
MultiHead(Q,K,V) = Concat(head_1,...,head_h)W^O
where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
Args:
d_in (int): 输入特征维度(原始嵌入维度)
d_out (int): 输出特征维度(需能被num_heads整除)
context_length (int): 序列最大长度(用于初始化因果掩码)
dropout (float): Dropout概率(0-1之间)
num_heads (int): 并行注意力头数
qkv_bias (bool): 是否为Q/K/V的线性变换添加可学习偏置
"""
super().__init__()
# 维度验证:确保多头拆分可行性
assert (d_out % num_heads == 0), "d_out必须能被num_heads整除"
self.d_out = d_out
self.num_heads = num_heads
self.head_dim = d_out // num_heads # 每个头的特征维度
# 定义Q/K/V的线性投影层(将输入映射到d_out维空间)
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) # 查询向量变换 Q = XW_q
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) # 键向量变换 K = XW_k
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) # 值向量变换 V = XW_v
# 最终输出投影层(整合多头结果)
self.out_proj = nn.Linear(d_out, d_out)
# 正则化模块
self.dropout = nn.Dropout(dropout)
# 注册因果掩码(防止看到未来信息)
# 使用上三角矩阵(含对角线),diagonal=1表示主对角线以上保留,以下置零
# 最终形状:(context_length, context_length)
self.register_buffer("mask",
torch.triu(torch.ones(context_length, context_length), diagonal=1))
def forward(self, x):
"""前向传播过程(实现缩放点积注意力机制)
计算流程:
1. 线性投影得到Q/K/V张量
2. 重塑张量形状分割为多个注意力头
3. 计算注意力分数矩阵
4. 应用因果掩码和Softmax归一化
5. 多头注意力加权求和
6. 拼接结果并线性投影
Args:
x (Tensor): 输入张量,形状(batch_size, num_tokens, d_in)
Returns:
Tensor: 多头注意力结果,形状(batch_size, num_tokens, d_out)
"""
b, num_tokens, d_in = x.shape # 解包输入维度
# Step 1: 线性投影得到Q/K/V (形状均为(b, num_tokens, d_out))
keys = self.W_key(x) # 键向量 K = XW_k
queries = self.W_query(x) # 查询向量 Q = XW_q
values = self.W_value(x) # 值向量 V = XW_v
# Step 2: 重塑张量形状,分割多头
# 新形状:(b, num_tokens, num_heads, head_dim)
keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
values = values.view(b, num_tokens, self.num_heads, self.head_dim)
# Step 3: 转置维度用于矩阵运算 -> (b, num_heads, num_tokens, head_dim)
keys = keys.transpose(1, 2) # 将num_heads维度提前
queries = queries.transpose(1, 2)
values = values.transpose(1, 2)
# Step 4: 计算缩放点积注意力分数
# (b, num_heads, num_tokens, num_tokens)
attn_scores = queries @ keys.transpose(2, 3) # QK^T
# Step 5: 应用因果掩码(将上三角区域设为负无穷)
mask_bool = self.mask.bool()[:num_tokens, :num_tokens] # 动态裁剪掩码
attn_scores.masked_fill_(mask_bool, -torch.inf) # 填充未来位置
# Step 6: 对注意力分数进行缩放和Softmax归一化
# 缩放因子为1/sqrt(d_k),防止点积过大导致梯度消失
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
# Step 7: 应用Dropout正则化
attn_weights = self.dropout(attn_weights)
# Step 8: 加权求和得到上下文向量
# (b, num_heads, num_tokens, head_dim)
context_vec = attn_weights @ values # 乘积维度:(n,n) @ (n,d) = (n,d)
# Step 9: 拼接多个头的结果
# 转置回 (b, num_tokens, num_heads, head_dim)
context_vec = context_vec.transpose(1, 2)
# 重塑为原始输出维度 (b, num_tokens, d_out)
context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
# Step 10: 最终线性投影(整合特征信息)
context_vec = self.out_proj(context_vec)
return context_vec
# 测试代码 -------------------------------------------------
# 测试案例说明:
# 1. 构造两个相同的样本(batch_size=2)
# 2. 每个样本包含6个token(序列长度=6)
# 3. 每个token使用3维嵌入向量(d_in=3)
# 验证目标:
# - 保持输入输出的批次大小和序列长度一致
# - 输出维度应符合d_out=2的要求
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # 第1个词 "Your" 的嵌入向量
[0.55, 0.87, 0.66], # 第2个词 "journey" 的嵌入向量
[0.57, 0.85, 0.64], # 第3个词 "starts" 的嵌入向量
[0.22, 0.58, 0.33], # 第4个词 "with" 的嵌入向量
[0.77, 0.25, 0.10], # 第5个词 "one" 的嵌入向量
[0.05, 0.80, 0.55]] # 第6个词 "step" 的嵌入向量
)
batch = torch.stack((inputs, inputs), dim=0) # 形状(2, 6, 3)
# 初始化模块
torch.manual_seed(123) # 设置随机种子保证结果可复现
batch_size, context_length, d_in = batch.shape
d_out = 2 # 输出维度需要能被头数整除
mha = MultiHeadAttention(
d_in=d_in,
d_out=d_out,
context_length=context_length,
dropout=0.0, # 无Dropout
num_heads=2 # 2个注意力头
)
# 前向传播
context_vecs = mha(batch)
# 输出验证
print("上下文向量:\n", context_vecs)
print("输出形状:", context_vecs.shape) # 应为(2, 6, 2)
版权所有
版权归属:NateHHX