外观
GELU激活函数
约 329 字大约 1 分钟
2025-06-20
import torch
import torch.nn as nn
class GELU(nn.Module):
"""高斯误差线性单元(GELU)激活函数实现(基于Hendrycks等2016年原始论文的近似公式)"""
def __init__(self):
super().__init__()
# 无需可训练参数,与ReLU相比具有更平滑的梯度特性
def forward(self, x):
"""前向传播实现(使用tanh近似公式提高计算效率)
Args:
x (Tensor): 输入张量
Return:
Tensor: 激活后的输出
"""
return 0.5 * x * (1 + torch.tanh(
torch.sqrt(torch.tensor(2.0 / torch.pi)) * # √(2/π) 来自高斯分布特性
(x + 0.044715 * torch.pow(x, 3)) # 三次项系数优化训练稳定性
))
# 激活函数可视化对比
import matplotlib.pyplot as plt
gelu = GELU() # 实例化GELU模块
relu = nn.ReLU() # 实例化标准ReLU模块作为对比
x = torch.linspace(-3, 3, 100) # 生成[-3,3]范围内100个均匀分布点
y_gelu = gelu(x) # 计算GELU输出
y_relu = relu(x) # 计算ReLU输出
plt.figure(figsize=(8, 3)) # 设置画布尺寸(宽度8英寸,高度3英寸)
# 双图对比可视化(左GELU,右ReLU)
for i, (y, label) in enumerate(zip([y_gelu, y_relu], ["GELU", "ReLU"]), 1):
plt.subplot(1, 2, i) # 创建子图位置索引
plt.plot(x.numpy(), y.numpy()) # 张量转numpy数组绘图
plt.title(f"{label} activation function")
plt.xlabel("x") # x轴标签
plt.ylabel(f"{label}(x)") # y轴标签
plt.grid(True) # 显示网格线
plt.tight_layout() # 自动调整子图间距
plt.show() # 显示图像
版权所有
版权归属:NateHHX