外观
层归一化
约 1952 字大约 7 分钟
设计思路与核心概念
1. 层归一化的背景与动机
层归一化(Layer Normalization)是由 Ba et al. 在2016年提出的一种归一化技术,主要解决以下问题:
- 批归一化的局限性:批归一化依赖于批次大小,在RNN和小批次场景下效果不佳
- 训练稳定性:通过归一化激活值分布,加速训练收敛并提高模型稳定性
- 梯度流优化:缓解深层网络中的梯度消失和梯度爆炸问题
2. 核心设计思想
层归一化的核心思想是对每个样本的特征维度进行归一化,而不是像批归一化那样对批次维度进行归一化:
- 归一化维度:沿着特征维度(最后一维)计算均值和方差
- 独立性:每个样本独立进行归一化,不依赖其他样本
- 可学习参数:引入缩放参数γ和平移参数β,保持模型表达能力
3. 数学公式
对于输入向量 x∈Rd,层归一化的计算公式为:
LayerNorm(x)=γ⊙σ2+ϵx−μ+β
其中:
- μ=d1∑i=1dxi (特征维度均值)
- σ2=d1∑i=1d(xi−μ)2 (特征维度方差)
- γ,β∈Rd (可学习的缩放和平移参数)
- ϵ (数值稳定项,防止除零)
执行流程
1. 整体执行流程图
2. 详细计算流程图
计算步骤详解
步骤 | 操作 | PyTorch代码 | 说明 |
---|---|---|---|
1 | 输入张量 | x | 形状: (batch_size, emb_dim) |
2 | 计算均值 | mean = x.mean(dim=-1, keepdim=True) | 沿特征维度计算均值 |
3 | 计算方差 | var = x.var(dim=-1, keepdim=True, unbiased=False) | 沿特征维度计算方差 |
4 | 标准化 | norm_x = (x - mean) / torch.sqrt(var + eps) | 零均值单位方差标准化 |
5 | 缩放变换 | scaled = gamma * norm_x | 应用可学习缩放参数 |
6 | 平移变换 | output = scaled + beta | 应用可学习平移参数 |
3. 详细执行步骤
参数初始化阶段
- 初始化缩放参数 γ 为全1向量
- 初始化平移参数 β 为全0向量
- 设置数值稳定项 ε = 1e-5
前向传播阶段
- 计算输入张量沿特征维度的均值
- 计算输入张量沿特征维度的方差
- 执行标准化操作
- 应用可学习的仿射变换
验证阶段
- 检查归一化后的均值是否接近0
- 检查归一化后的方差是否接近1
完整代码实现
层归一化.py
import torch
import torch.nn as nn
# 设置随机种子保证实验可重复性
torch.manual_seed(123)
# 生成示例输入数据:batch_size=2, sequence_length=5(模拟两个样本的特征向量)
batch_example = torch.randn(2, 5)
# 关闭科学计数法显示,便于观察数值分布
torch.set_printoptions(sci_mode=False)
class LayerNorm(nn.Module):
"""层归一化实现(参考论文《Layer Normalization》核心思想)"""
def __init__(self, emb_dim):
"""
初始化层归一化模块
:param emb_dim: 特征维度大小(对应输入张量的最后一维)
"""
super().__init__()
self.eps = 1e-5 # 数值稳定项,防止除以零(类似BN中的epsilon)
self.scale = nn.Parameter(torch.ones(emb_dim)) # 可学习的缩放参数γ
self.shift = nn.Parameter(torch.zeros(emb_dim)) # 可学习的平移参数β
def forward(self, x):
"""
前向传播过程(核心计算公式与原始论文一致)
:param x: 输入张量,形状为(batch_size, embedding_dim)
"""
# 计算沿特征维度的均值(保持维度用于广播)
mean = x.mean(dim=-1, keepdim=True)
# 计算沿特征维度的方差(使用有偏估计与PyTorch官方实现保持一致)
var = x.var(dim=-1, keepdim=True, unbiased=False)
# 标准化计算:(x - μ)/√(σ² + ε)
norm_x = (x - mean) / torch.sqrt(var + self.eps)
# 应用仿射变换:γ * x_normalized + β
return self.scale * norm_x + self.shift
# 实例化层归一化模块(特征维度=5)
ln = LayerNorm(emb_dim=5)
# 前向传播计算
out_ln = ln(batch_example)
# 验证归一化效果
mean = out_ln.mean(dim=-1, keepdim=True) # 沿特征维度计算均值
var = out_ln.var(dim=-1, unbiased=False, keepdim=True) # 沿特征维度计算方差
print("归一化后均值:\n", mean) # 理论值应接近0
print("归一化后方差:\n", var) # 理论值应接近1
代码详细解析
1. 类设计说明
LayerNorm类结构
class LayerNorm(nn.Module):
def __init__(self, emb_dim):
# 参数初始化
def forward(self, x):
# 前向传播逻辑
设计要点:
- 继承自
nn.Module
,符合PyTorch模块化设计 emb_dim
参数指定特征维度大小- 使用
nn.Parameter
包装可学习参数,确保参数能被优化器更新
2. 关键参数说明
参数名 | 类型 | 作用 | 初始值 |
---|---|---|---|
eps | float | 数值稳定项,防止除零错误 | 1e-5 |
scale | Parameter | 可学习的缩放参数γ | ones(emb_dim) |
shift | Parameter | 可学习的平移参数β | zeros(emb_dim) |
3. 前向传播详解
步骤1:统计量计算
mean = x.mean(dim=-1, keepdim=True) # 沿最后一维计算均值
var = x.var(dim=-1, keepdim=True, unbiased=False) # 计算方差
关键点:
dim=-1
:沿特征维度(最后一维)计算keepdim=True
:保持维度用于后续广播操作unbiased=False
:使用有偏估计,与PyTorch官方实现一致
步骤2:标准化操作
norm_x = (x - mean) / torch.sqrt(var + self.eps)
数学含义:将输入转换为均值为0、方差为1的标准正态分布
步骤3:仿射变换
return self.scale * norm_x + self.shift
作用:通过可学习参数恢复模型的表达能力
4. 实验验证
输入数据特征
- 形状:(2, 5) - 2个样本,每个样本5个特征
- 数据类型:随机正态分布
- 随机种子:123(保证实验可重复性)
验证指标
- 均值检查:归一化后每个样本的特征均值应接近0
- 方差检查:归一化后每个样本的特征方差应接近1
运行结果分析
预期输出
归一化后均值:
tensor([[-0.0000],
[ 0.0000]])
归一化后方差:
tensor([[1.0000],
[1.0000]])
结果解释
- 均值接近0:说明标准化成功消除了特征的偏移
- 方差接近1:说明特征被缩放到标准范围
- 每行独立:每个样本独立进行归一化,体现了层归一化的核心特点
与其他归一化方法的对比
归一化方法 | 归一化维度 | 依赖性 | 适用场景 |
---|---|---|---|
批归一化(BN) | 批次维度 | 依赖批次大小 | CNN、大批次训练 |
层归一化(LN) | 特征维度 | 样本独立 | RNN、Transformer、小批次 |
实例归一化(IN) | 空间维度 | 样本独立 | 风格迁移、生成模型 |
组归一化(GN) | 特征组 | 样本独立 | 小批次、目标检测 |
在Transformer中的应用
层归一化在Transformer架构中扮演关键角色:
- 位置:通常放在多头注意力和前馈网络之后
- 作用:稳定训练过程,加速收敛
- 变体:Pre-LN(层归一化前置)vs Post-LN(层归一化后置)
# Transformer中的典型用法
class TransformerBlock(nn.Module):
def __init__(self, d_model):
self.attention = MultiHeadAttention(d_model)
self.ffn = FeedForward(d_model)
self.ln1 = LayerNorm(d_model) # 注意力后的层归一化
self.ln2 = LayerNorm(d_model) # 前馈网络后的层归一化
def forward(self, x):
# Post-LN结构
x = x + self.ln1(self.attention(x))
x = x + self.ln2(self.ffn(x))
return x
更新日志
2025/8/18 00:31
查看所有更新日志
bd1d0
-迁移目录于bcc6d
-docs: 完善大模型架构文档 - 增加设计思路与执行流程于dfb81
-update于1a489
-update于
版权所有
版权归属:NateHHX