批量归一化(Batch Normalization,简称 BN)是深度学习中最重要的技术突破之一。由 Sergey Ioffe 和 Christian Szegedy 在 2015 年提出,BN 通过在每个 mini-batch 上对层输入进行标准化,使得深层神经网络的训练更加稳定和快速。如今,BN 已是几乎所有现代卷积网络的标准组件,在 ResNet、Inception、EfficientNet 等架构中发挥着关键作用。
训练深层神经网络时,每一层的输入分布会随着前面层参数的更新而不断变化。这种变化迫使每一层网络持续适应新的输入分布,导致训练效率低下。例如,在一个 50 层的网络中,第 40 层的输入分布可能在训练过程中发生剧烈变化,使得该层需要不断调整其参数来适应。
原始论文将上述现象定义为内部协变量偏移(Internal Covariate Shift, ICS):在深层网络训练过程中,由于网络中参数变化而引起内部结点数据分布发生变化的过程。
具体来说,假设第 层的输出为 ,其中 。随着梯度下降更新 和 , 的分布会不断变化,进而导致下一层 需要适应这些变化。这种连锁反应在深层网络中会被逐层放大。
数值示例:考虑一个简单两层网络,第一层权重 从 0.5 更新到 0.6(变化 20%)。如果输入特征 的平均值为 10,则输出 变为 ,变化了 20%。而在一个 10 层网络中,这种微小变化经过逐层放大后,可能导致最终输出变化超过 500%。
| 问题 | 描述 | 影响 |
|---|---|---|
| 学习速度降低 | 上层网络需要不断调整以适应输入分布变化 | 收敛速度变慢,需要更多训练轮次 |
| 梯度饱和 | 使用 sigmoid/tanh 时,输入分布变化可能使激活函数进入饱和区 | 梯度趋近于零,参数几乎不更新 |
| 对初始化敏感 | 参数初始化的微小差异在深层网络中会被放大 | 需要精心设计初始化策略 |
| 学习率受限 | 过大的学习率会导致分布剧烈变化,网络无法收敛 | 只能使用较小的学习率 |
以 sigmoid 激活函数为例:sigmoid 在 时梯度接近 0。如果某一层的输出 由于前面层的更新从 0 附近偏移到 ,则该层神经元将进入饱和区,梯度几乎为 0,参数更新停滞。
BN 的核心思想是:对每一层的输入进行标准化,使其均值为 0、方差为 1,然后再通过可学习的仿射变换恢复其表达能力。
对于输入 mini-batch ,包含 个样本,BN 的计算步骤如下:
Step 1:计算 mini-batch 的均值
Step 2:计算 mini-batch 的方差
Step 3:标准化
其中 是一个很小的常数(通常取 ),防止除零错误。
Step 4:缩放和平移(可学习参数)
其中 和 是可学习的参数,维度与 相同。
数值示例:假设 mini-batch 包含 4 个样本,
可以看到,经过 BN 后,输入被重新分布到激活函数的非饱和区域,同时通过 和 保留了必要的表达能力。
和 的存在至关重要:如果标准化后的分布不适合当前任务,网络可以通过学习 和 恢复原始的分布。特别地,当 且 时,BN 可以完全还原原始分布,即 。
BN 是可微的,因此可以正常参与梯度反向传播。 和 的梯度计算如下:
其中 是损失函数。
训练时,BN 使用当前 mini-batch 的统计量(均值和方差)进行标准化。这引入了 mini-batch 间的随机性,起到了正则化作用。
推理时,我们不再使用 mini-batch 统计量,而是使用训练过程中累积的全局统计量(running mean 和 running variance):
通常使用滑动平均(moving average)更新:
其中 通常取 0.9(动量系数)。
推理时的 BN 变换为固定的线性变换:
这意味着推理时 BN 完全退化为一个简单的线性变换,没有随机性,保证了推理结果的确定性。
在全连接层中,BN 作用于每个神经元的激活值:对 batch 中所有样本的同一神经元输出进行标准化。
z = Wx + b # 线性变换
z_norm = BN(z) # 批量归一化
a = activation(z_norm) # 激活函数
注意:BN 层会包含 参数,因此线性变换中的偏置 可以省略(或者设为 0),因为其效果会被 吸收。
在卷积层中,BN 的标准化维度略有不同。对于形状为 的特征图( 为 batch size, 为通道数, 和 为空间维度),BN 对每个通道分别计算均值和方差:
| 架构 | BN 放置位置 | 特点 |
|---|---|---|
| 原始 BN 论文 | 激活函数之前:BN(Conv(x)) → ReLU |
标准化后再激活 |
| 实践惯例 | 激活函数之后:Conv(x) → ReLU → BN |
兼容性更好 |
| ResNet | Conv → BN → ReLU |
瓶颈结构中的标准配置 |
| Pre-activation ResNet | BN → ReLU → Conv |
改善梯度流动 |
原始论文认为 BN 通过减少内部协变量偏移来加速训练。但这一解释在后来受到了挑战。
Santurkar 等人(2018)在论文"How Does Batch Normalization Help Optimization?"中证明:
内部协变量偏移并非 BN 有效的主要原因:实验表明,即使引入了人为的协变量偏移,BN 网络仍能正常训练。相反,没有 BN 的网络即使 ICS 很小,训练仍然困难。
BN 使优化景观更加平滑:这是 BN 有效性的核心机制。具体来说,BN 改善了损失函数的 Lipschitz 性质(Lipschitzness)和 -平滑性,使得:
平滑效果数值示例:假设损失函数 在没有 BN 时是 ,梯度为 。在 处的梯度为 0.003,在 处为 0.003,梯度的 Lipschitz 常数为 60。加上 BN 后,函数变为 ,梯度为 ,梯度的 Lipschitz 常数降为 2。这意味着梯度变化更加温和,优化更加稳定。
| 效果 | 说明 | 影响程度 |
|---|---|---|
| 允许更大学习率 | 平滑的损失景观允许使用 5-10 倍更大的学习率 | 显著 |
| 降低对初始化的敏感度 | 参数的初始值对训练影响减小 | 显著 |
| 正则化效果 | mini-batch 统计量的随机性类似于 Dropout | 中等(约等于 Dropout 0.5) |
| 缓解梯度消失 | 将输入保持在激活函数的非饱和区域 | 显著 |
| 加速收敛 | 通常减少 10-14 倍的训练步数 | 显著 |
BN 的正则化效果来自 mini-batch 统计量的随机性。每个 mini-batch 的均值和方差都是总体统计量的有偏估计,这种估计的随机性相当于在训练过程中引入了噪声。研究表明,当 batch size 为 32 时,BN 的正则化效果约相当于 Dropout rate 为 0.3-0.5。因此,在使用 BN 的网络中,通常可以降低或完全去除 Dropout。
BN 对 batch size 敏感。过小的 batch size 会导致统计量估计不准确,影响性能。
| Batch Size | BN 性能 | 说明 |
|---|---|---|
| 优秀 | 统计量估计准确 | |
| 16 | 良好 | 稍有波动,但可接受 |
| 8 | 一般 | 方差估计不稳定 |
| 4 | 较差 | 需要配合其他技巧 |
| 2 | 很差 | 不推荐使用 BN |
| 1 | 不可用 | 无法计算方差(分母为 0) |
直观理解:batch size = 32 时,均值和方差的估计误差约为 ;batch size = 4 时,误差约为 。过大的统计噪声会干扰训练。
使用 BN 后,可以显著提高学习率。一般建议:
BN 本身具有正则化效果,因此:
| 技巧 | 建议 | 原因 |
|---|---|---|
| 学习率 | 增大 3-10 倍 | BN 使优化景观更平滑 |
| 初始化 | 使用更大的标准差 | BN 会自动调整分布 |
| Dropout | 降低或移除 | BN 自带正则化 |
| L2 正则化 | 适当降低 weight decay | 与 BN 正则化叠加 |
| Batch Size | 保持 | 过小会导致统计量不可靠 |
| 学习率预热 | 前几个 epoch 逐步增大学习率 | 避免早期训练不稳定 |
BN 不是唯一的归一化技术。随着应用场景的多样化,研究者提出了多种替代方案。
不同的归一化方法在归一化的维度上有所不同。设输入张量为 :
| 方法 | 归一化维度 | 提出时间 | 代表作 |
|---|---|---|---|
| BatchNorm (BN) | 2015 | Ioffe & Szegedy | |
| LayerNorm (LN) | 2016 | Ba et al. | |
| InstanceNorm (IN) | 2016 | Ulyanov et al. | |
| GroupNorm (GN) | 2018 | Wu & He | |
| Batch Renorm | (改进) | 2017 | Ioffe |
假设 batch size , 个通道,:
| 方法 | 适用场景 | 优势 | 劣势 |
|---|---|---|---|
| BN | CV(图像分类、检测) | 收敛快,效果好 | 对 batch size 敏感 |
| LN | NLP(Transformer、RNN) | 不依赖 batch | 在 CV 任务中效果不如 BN |
| IN | 风格迁移、图像生成 | 适合单样本处理 | 丢弃了样本间信息 |
| GN | 小 batch CV 任务 | batch size 鲁棒 | 需要选择分组数 |
下表展示了不同 batch size 下 ResNet-50 在 ImageNet 验证集上的错误率(%)(Wu & He, 2018):
| 方法 | batch=32 | batch=16 | batch=8 | batch=4 | batch=2 |
|---|---|---|---|---|---|
| BN | 23.6 | 23.7 | 24.8 | 27.3 | 34.7 |
| GN | 24.1 | 24.2 | 24.0 | 24.2 | 24.1 |
| LN | — | — | 25.3 | — | — |
| IN | — | — | 28.4 | — | — |
关键发现:
Transformer 使用 LayerNorm 而非 BN 的原因:
import torch
import torch.nn as nn
class BatchNormManual(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.9):
super().__init__()
self.num_features = num_features
self.eps = eps
self.momentum = momentum
# 可学习参数
self.gamma = nn.Parameter(torch.ones(num_features))
self.beta = nn.Parameter(torch.zeros(num_features))
# 全局统计量(推理时使用)
self.register_buffer('running_mean', torch.zeros(num_features))
self.register_buffer('running_var', torch.ones(num_features))
self.training = True
def forward(self, x):
# x: (batch, num_features) 或 (batch, channels, height, width)
if x.dim() == 2:
# 全连接层
if self.training:
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
# 更新全局统计量
self.running_mean = self.momentum * self.running_mean + \
(1 - self.momentum) * mean
self.running_var = self.momentum * self.running_var + \
(1 - self.momentum) * var
else:
mean = self.running_mean
var = self.running_var
x_norm = (x - mean) / torch.sqrt(var + self.eps)
return self.gamma * x_norm + self.beta
elif x.dim() == 4:
# 卷积层:每个通道独立归一化
b, c, h, w = x.shape
if self.training:
mean = x.mean(dim=(0, 2, 3))
var = x.var(dim=(0, 2, 3), unbiased=False)
self.running_mean = self.momentum * self.running_mean + \
(1 - self.momentum) * mean
self.running_var = self.momentum * self.running_var + \
(1 - self.momentum) * var
else:
mean = self.running_mean
var = self.running_var
x_norm = (x - mean.view(1, c, 1, 1)) / \
torch.sqrt(var.view(1, c, 1, 1) + self.eps)
return self.gamma.view(1, c, 1, 1) * x_norm + \
self.beta.view(1, c, 1, 1)
import torch.nn as nn
# 全连接层:num_features 是特征数
bn_fc = nn.BatchNorm1d(num_features=128)
# 卷积层:num_features 是通道数
bn_conv = nn.BatchNorm2d(num_features=64)
class Bottleneck(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
stride=stride, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv3 = nn.Conv2d(out_channels, out_channels * 4, kernel_size=1)
self.bn3 = nn.BatchNorm2d(out_channels * 4)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out += identity # 残差连接
out = self.relu(out)
return out
当 batch size 必须很小时(如视频理解、3D 医学图像),可以考虑:
在分布式训练中,每个 GPU 处理不同的数据子集。标准的 BN 在每个 GPU 上独立计算统计量,这相当于使用了更小的 batch size。解决方案:
nn.SyncBatchNorm 替代 nn.BatchNorm2d# 将模型中的 BN 替换为 SyncBN
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
def freeze_bn(model):
for module in model.modules():
if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
module.eval() # 固定 running stats
module.requires_grad_(False) # 不更新 gamma, beta
BN 在训练和推理时行为不同。务必正确切换模型模式:
model.train() # 训练模式:使用 batch 统计量
# ... 训练代码 ...
model.eval() # 推理模式:使用全局 running stats
# ... 推理代码 ...
忘记切换 model.eval() 是使用 BN 时最常见的错误,会导致推理结果不稳定甚至错误。
| 年份 | 里程碑 | 贡献者 |
|---|---|---|
| 2015 | BN 论文发表,提出 ICS 解释 | Ioffe & Szegedy |
| 2016 | LayerNorm 提出,用于 RNN | Ba et al. |
| 2017 | Batch Renorm 改进小 batch 场景 | Ioffe |
| 2018 | "How Does BN Help?" 推翻 ICS 解释 | Santurkar et al. |
| 2018 | GroupNorm 提出,解决小 batch 问题 | Wu & He |
| 2020 | Normalization-Free 网络尝试 | Brock et al. |
| 2022 | 归一化替代研究趋于成熟 | 多个方向 |
2017 年 NeurIPS 大会上,Ali Rahimi 在接受"Test of Time Award"时,将现代深度学习实践比作炼金术,以 BN 的内部协变量偏移解释作为典型案例,指出许多流行的技术缺乏严格的理论基础。这场演讲引发了深度学习社区对"工程实践 vs 科学理解"之间差距的广泛讨论。
Santurkar 等人(2018)随后证明,BN 的成功与内部协变量偏移关系不大,其真正作用是平滑优化景观。这一发现说明,深度学习中的许多"直觉解释"可能并不准确,技术的有效性有时在发现其真正原因之前就已经被实践经验证实。
尽管如此,BN 在实践中仍是不可或缺的技术。即使我们对它的理解在不断深化,它在图像分类器中的有效性已被数万次引用和无数实验证明。