外观
因果注意力
约 643 字大约 2 分钟
2025-06-13
# 导入PyTorch库
import torch
import torch.nn as nn
class CausalAttention(nn.Module):
"""实现因果自注意力机制的模块"""
def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
"""
参数说明:
d_in: 输入特征维度
d_out: 输出特征维度
context_length: 上下文长度(最大序列长度)
dropout: dropout比率
qkv_bias: 是否在Q/K/V线性变换中添加偏置项
"""
super().__init__()
self.d_out = d_out
# 定义Q/K/V的线性变换层
self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias) # 查询(Query)变换
self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias) # 键(Key)变换
self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias) # 值(Value)变换
self.dropout = nn.Dropout(dropout) # dropout层
# 注册因果掩码(不可训练的缓冲区)
# 创建上三角矩阵(主对角线以上为1,其余为0),用于屏蔽未来信息
self.register_buffer(
'mask',
torch.triu(torch.ones(context_length, context_length), diagonal=1)
)
def forward(self, x):
"""
前向传播过程
输入形状: (batch_size, num_tokens, d_in)
输出形状: (batch_size, num_tokens, d_out)
"""
b, num_tokens, d_in = x.shape # 获取输入张量的形状信息
# 生成键(K)、查询(Q)、值(V)
keys = self.W_key(x) # 形状: (b, num_tokens, d_out)
queries = self.W_query(x) # 形状: (b, num_tokens, d_out)
values = self.W_value(x) # 形状: (b, num_tokens, d_out)
# 计算注意力分数(缩放点积注意力)
attn_scores = queries @ keys.transpose(1, 2) # (b, num_tokens, num_tokens)
# 应用因果掩码(屏蔽未来位置的注意力)
attn_scores.masked_fill_(
self.mask.bool()[:num_tokens, :num_tokens], # 动态调整掩码尺寸
-torch.inf # 用负无穷填充被屏蔽的位置
)
# 计算注意力权重(应用softmax和dropout)
attn_weights = torch.softmax(
attn_scores / keys.shape[-1] ** 0.5, # 缩放因子sqrt(d_k)
dim=-1
)
attn_weights = self.dropout(attn_weights)
# 计算上下文向量(加权和)
context_vec = attn_weights @ values # (b, num_tokens, d_out)
return context_vec
# ===================== 测试代码 =====================
# 创建测试输入数据(模拟6个词的嵌入向量)
# 每个词用3维向量表示,包含两个相同的样本(batch_size=2)
inputs = torch.tensor(
[[0.43, 0.15, 0.89], # 第1个词 "Your" 的嵌入向量
[0.55, 0.87, 0.66], # 第2个词 "journey" 的嵌入向量
[0.57, 0.85, 0.64], # 第3个词 "starts" 的嵌入向量
[0.22, 0.58, 0.33], # 第4个词 "with" 的嵌入向量
[0.77, 0.25, 0.10], # 第5个词 "one" 的嵌入向量
[0.05, 0.80, 0.55]] # 第6个词 "step" 的嵌入向量
)
# 创建批次数据(将输入数据堆叠两次,模拟batch_size=2的情况)
batch = torch.stack((inputs, inputs), dim=0)
print("输入数据形状:", batch.shape) # 预期输出: torch.Size([2, 6, 3])
# 初始化模型参数
d_in = inputs.shape[-1] # 输入维度 = 3
d_out = 2 # 输出维度 = 2
context_length = batch.shape[1] # 上下文长度 = 6
# 固定随机种子(保证可重复性)
torch.manual_seed(123)
# 初始化因果自注意力模块(关闭dropout)
ca = CausalAttention(d_in, d_out, context_length, 0.0)
# 前向传播
context_vecs = ca(batch)
# 输出结果形状
print("上下文向量形状:", context_vecs) # 预期输出: torch.Size([2, 6, 2])
版权所有
版权归属:NateHHX