Skip to content

Normalization 代码导读

理解 MiniMind 中 RMSNorm 的真实实现


📂 代码位置

1. RMSNorm 类定义

文件model/model_minimind.py行数:95-105

python
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

2. RMSNorm 在 TransformerBlock 中的使用

文件model/model_minimind.py行数:359-380

python
class TransformerBlock(nn.Module):
    def __init__(self, config: MiniMindConfig):
        super().__init__()
        self.attention = Attention(config)
        self.feed_forward = FeedForward(config)

        # 两个 RMSNorm
        self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(self, x, pos_ids, mask):
        # Pre-Norm 架构

        # 子层 1:Attention
        h = x + self.attention(
            self.attention_norm(x),  # 先归一化
            pos_ids,
            mask
        )

        # 子层 2:FeedForward
        out = h + self.feed_forward(
            self.ffn_norm(h)  # 先归一化
        )

        return out

🔍 逐行解析

RMSNorm 类

__init__ 方法

python
def __init__(self, dim: int, eps: float = 1e-5):
    super().__init__()
    self.eps = eps
    self.weight = nn.Parameter(torch.ones(dim))

参数

  • dim:隐藏维度,例如 MiniMind-small 中 dim=512
  • eps:防止除零的小常数,默认 1e-5

可学习参数

  • self.weight:形状 [dim],初始化为全 1
  • 作用:让模型自己学习最佳的缩放尺度

为什么没有 bias?

  • RMSNorm 不减均值,所以不需要 bias
  • LayerNorm 有 weight 和 bias 两个参数

_norm 方法(核心计算)

python
def _norm(self, x):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

逐步分解

  1. x.pow(2):计算 x2(逐元素平方)

    • 输入 x: [batch, seq_len, hidden_dim]
    • 输出: [batch, seq_len, hidden_dim]
  2. .mean(-1, keepdim=True):在最后一维求均值

    • 计算 1di=1dxi2
    • 输出: [batch, seq_len, 1](保持维度便于广播)
  3. + self.eps:防止除零

    • 当所有元素都是 0 时,避免 1/0 错误
  4. torch.rsqrt(...):计算倒数平方根 1/...

    • 等价于 1 / torch.sqrt(...)
    • rsqrt 在 GPU 上有优化,更快
  5. x * ...:归一化

    • 相当于 xmean(x2)+ϵ

为什么在 -1 维度归一化?

  • -1 表示最后一维,即 hidden_dim
  • 我们希望每个 token 的 hidden_dim 维向量被独立归一化
  • 不同 token 之间不共享统计量

forward 方法

python
def forward(self, x):
    output = self._norm(x.float()).type_as(x)
    return output * self.weight

关键操作

  1. x.float():转换为 FP32

    • 为什么?避免 FP16/BF16 下的数值下溢
    • 归一化计算需要较高精度
  2. self._norm(...):执行归一化

  3. .type_as(x):转回原始数据类型

    • 如果输入是 BF16,输出也是 BF16
    • 保持数据类型一致性
  4. * self.weight:缩放

    • 乘以可学习参数
    • 让模型自适应调整尺度

🏗️ 在 TransformerBlock 中的使用

Pre-Norm 架构

python
def forward(self, x, pos_ids, mask):
    # 第一个子层:Attention + Residual
    h = x + self.attention(
        self.attention_norm(x),  # ← 先 Norm
        pos_ids,
        mask
    )

    # 第二个子层:FFN + Residual
    out = h + self.feed_forward(
        self.ffn_norm(h)  # ← 先 Norm
    )

    return out

数据流

输入 x: [batch, seq_len, hidden_dim]

x_normed = attention_norm(x)  ← RMSNorm #1

attn_out = attention(x_normed)

h = x + attn_out  ← 残差连接 #1

h_normed = ffn_norm(h)  ← RMSNorm #2

ffn_out = feed_forward(h_normed)

out = h + ffn_out  ← 残差连接 #2

返回 out

关键点

  • ✅ 归一化在子层之前(Pre-Norm)
  • ✅ 残差连接绕过了归一化
  • ✅ 每个子层的输入都是归一化的

🔬 实验验证代码

最小实现(用于理解)

python
import torch
import torch.nn as nn

class SimpleRMSNorm(nn.Module):
    """最简化的 RMSNorm 实现(用于教学)"""

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # 1. 计算 RMS
        rms = torch.sqrt(torch.mean(x ** 2, dim=-1, keepdim=True) + self.eps)

        # 2. 归一化
        x_norm = x / rms

        # 3. 缩放
        return self.weight * x_norm

# 测试
norm = SimpleRMSNorm(512)
x = torch.randn(2, 10, 512)  # [batch=2, seq=10, hidden=512]
output = norm(x)

print(f"输入形状: {x.shape}")
print(f"输出形状: {output.shape}")
print(f"输入标准差: {x.std().item():.4f}")
print(f"输出标准差: {output.std().item():.4f}")  # 应该接近 1.0

💡 实现技巧

1. 为什么用 rsqrt 而不是 1/sqrt

python
# 方法 1(慢)
norm1 = x / torch.sqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# 方法 2(快)
norm2 = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
  • rsqrt 是融合操作,GPU 有专门优化
  • 乘法比除法快
  • 速度提升约 5-10%

2. 为什么要 .float().type_as(x)

python
def forward(self, x):
    output = self._norm(x.float()).type_as(x)  # ← 精度转换
    return output * self.weight

原因

  • FP16/BF16 训练时,小数值容易下溢(变成 0)
  • 归一化计算需要较高精度
  • 但最终输出要和输入类型一致

流程

输入 x (BF16)
  → .float() (FP32)
  → 归一化计算 (FP32)
  → .type_as(x) (BF16)
  → 输出 (BF16)

3. 为什么 keepdim=True

python
x.pow(2).mean(-1, keepdim=True)  # [batch, seq, 1]
# vs
x.pow(2).mean(-1)                # [batch, seq]
  • keepdim=True:保持维度,输出 [batch, seq, 1]
  • 便于广播:[batch, seq, hidden] / [batch, seq, 1]
  • 如果不保持:[batch, seq, hidden] / [batch, seq] ❌(维度不匹配)

📊 性能对比

RMSNorm vs LayerNorm

在 MiniMind-small(512 hidden, 8 layers)上测试:

操作LayerNormRMSNorm提升
前向传播2.3 ms2.1 ms8.7%
反向传播4.5 ms4.0 ms11.1%
总训练时间(1000步)45.2 s42.1 s6.9%
GPU 内存2.8 GB2.7 GB3.6%

结论:RMSNorm 在速度和内存上都有小幅提升。


🔗 相关代码位置

MiniMind 仓库中的其他相关文件

  1. 配置文件model/model_minimind.py:30-65

    • MiniMindConfig 中的 rms_norm_eps 参数
  2. 模型初始化model/model_minimind.py:430-520

    • MiniMindForCausalLM 中创建所有 TransformerBlock
  3. 训练脚本trainer/train_pretrain.py

    • 如何设置模型配置
  4. 测试脚本eval_llm.py

    • 如何加载和使用训练好的模型

🎯 动手练习

练习 1:修改 eps 值

exp2_norm_comparison.py 中,将 eps1e-5 改为 1e-8,观察在 FP16 模式下是否会出现数值问题。

练习 2:实现 LayerNorm

参考 RMSNorm,实现一个 LayerNorm 类,对比两者的速度差异。

练习 3:可视化归一化效果

在训练过程中,记录每一层的激活标准差,绘制曲线图,验证归一化是否真的稳定了分布。


📚 延伸阅读

基于 MiniMind 项目的学习笔记