跳转到内容

4.2 优化器详解——AdamW 及其变体

上一节我们搭建了完整的训练循环框架,其中优化器的选择只有一行代码——torch.optim.AdamW(model.parameters(), lr=3e-4)。这一行看起来简单,但它背后隐藏着过去十年深度学习优化领域最重要的进展之一。在 SGD 统治深度学习的时代之后,Adam(Adaptive Moment Estimation)的出现改变了游戏规则,而 AdamW 作为 Adam 的改进版本,更是成为了当今大语言模型训练的事实标准。这一节我们将深入理解 AdamW 的工作原理、每个超参数的物理含义和调优策略,以及在面对不同场景时应该考虑哪些替代方案。

从 SGD 到 Adam 到 AdamW:进化之路

要真正理解 AdamW 的价值,我们需要先回顾一下它的前身们做了什么。

**SGD(随机梯度下降)**是最基础的优化器。它的更新规则极其简单:对于每个参数 $\theta$,每一步的更新量是:

$$\theta_{t+1} = \theta_t - \eta \cdot g_t$$

其中 $\eta$ 是学习率(learning rate),$g_t$ 是当前步的梯度。SGD 的优点是简单、稳定、泛化性好——很多研究表明在相同 loss 水平下,SGD 训练出的模型泛化能力往往优于自适应学习率方法。但 SGD 有一个致命弱点:对所有参数使用相同的学习率。如果不同参数的梯度尺度差异很大(比如 Embedding 层的梯度可能比深层 FFN 的梯度大两个数量级),那么统一的学习率要么对大梯度参数来说太大(导致震荡),要么对小梯度参数来说太小(收敛极慢)。你需要手动为不同层设置不同的学习率(分层学习率),这在大模型中既繁琐又难以调优。

Adam 解决了这个问题的核心思路是:为每个参数维护一个自适应的学习率。具体来说,Adam 维护两个状态变量:一阶矩估计 $m_t$(梯度的指数移动平均,类似动量)和二阶矩估计 $v_t$(梯度平方的指数移动平均,类似缩放调整):

$$m_t = \beta_1 m_{t-1} + (1-\beta_1) g_t$$ $$v_t = \beta_2 v_{t-1} + (1-\beta_2) g_t^2$$ $$\hat{m}_t = m_t / (1-\beta_1^t)$$ $$\hat{v}t = v_t / (1-\beta_2^t)$$ $$\theta = \theta_t - \eta \cdot \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)$$

直觉上理解:$m_t$ 记录了梯度的大致方向(动量帮助加速收敛并在平坦区域保持前进),$v_t$ 记录了梯度的变化幅度(用于归一化——如果某个参数的梯度一直很大,对应的 $v_t$ 也很大,分母变大后有效学习率就自动变小了;反之亦然)。偏差修正项 $(1-\beta^t)$ 用于补偿初始阶段 $m$ 和 $v$ 偏向零的问题。

python
def adam_step_by_hand():
    """手动模拟 Adam 的一步更新"""
    param = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    grad = torch.tensor([0.01, 1.0, 100.0])  # 差异巨大的梯度

    lr = 0.001
    beta1, beta2, eps = 0.9, 0.999, 1e-8
    m = torch.zeros(3)
    v = torch.zeros(3)
    t = 1

    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad ** 2
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)

    update = lr * m_hat / (torch.sqrt(v_hat) + eps)

    print(f"Gradients:      {grad.tolist()}")
    print(f"m (momentum):   {m.tolist()}")
    print>v (variance):    {v.tolist()}")
    print(f"Effective LR:   {(lr / (torch.sqrt(v_hat) + eps)).tolist()}")
    print(f"Update step:    {update.tolist()}")

    sgd_update = lr * grad
    print(f"\nSGD update:     {sgd_update.tolist()}")
    print(f"(Notice: param with grad=100 dominates in SGD, "
          f"but is properly scaled in Adam)")


adam_step_by_hand()

# 输出:
# Gradients:      [0.01, 1.0, 100.0]
# m (momentum):   [0.001, 0.1, 10.0]
# v (variance):    [1e-4, 1.0, 10000.0]
# Effective LR:   [9.999e-05, 1.000e-03, 1.000e-05]  ← 自动适应!
# Update step:    [9.99e-08, 1.00e-04, 1.00e-04]
#
# SGD update:     [1e-05, 0.001, 0.1]  ← 第三个参数更新过大!

注意看 "Effective LR" 那一行:Adam 为三个参数自动计算出了完全不同的有效学习率——梯度小的参数(0.01)获得了接近原始学习率的更新幅度,梯度大的参数(100)的有效学习率被缩小了 100 倍。这就是 Adam 的核心优势。

AdamW:解耦权值衰减

Adam 虽然强大,但它在实现 L2 正则化(weight decay)的方式上有一个问题。标准的 L2 正则化做法是在 loss 中加上一项 $\frac{\lambda}{2}|\theta|^2$,这样梯度就变成了 $g_t + \lambda\theta_t$,然后 Adam 用这个修改后的梯度去更新 $m_t$ 和 $v_t$。问题在于:权值衰减被耦合到了自适应学习率的计算中。由于 $v_t$ 会根据梯度大小自动调整有效学习率,而 weight decay 增加的那部分梯度也会影响 $v_t$,最终导致实际衰减量和理论值不一致——特别是对于那些梯度本来就很大的参数,衰减效果会被放大。

AdamW(Adam with Decoupled Weight Decay) 的解决思路非常直接:把权值衰减从梯度计算中分离出来,作为独立的步骤直接应用到参数上:

$$\theta_{t+1} = \theta_t - \eta \cdot (\hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)) - \eta \lambda \theta_t$$

注意最后一项 $-\eta\lambda\theta_t$ 是直接作用于参数的,不经过 $m_t$ 和 $v_t$ 的更新。这意味着无论参数的梯度大小如何,衰减量都是固定的比例($\eta\lambda$),与自适应机制完全解耦。

python
import torch.optim as optim


def compare_adam_vs_adamw():
    model_adam = nn.Linear(10, 5)
    model_adamw = nn.Linear(10, 5)
    model_adamw.load_state_dict(model_adam.state_dict())

    opt_adam = optim.Adam(model_adam.parameters(), lr=0.001, weight_decay=0.01)
    opt_adamw = optim.AdamW(model_adamw.parameters(), lr=0.001, weight_decay=0.01)

    x = torch.randn(4, 10)
    y = torch.randn(4, 5)

    for name, param in model_adam.named_parameters():
        orig_norm = param.data.norm().item()
        break

    for _ in range(100):
        loss_adam = nn.functional.mse_loss(model_adam(x), y)
        loss_adamw = nn.functional.mse_loss(model_adamw(x), y)

        opt_adam.zero_grad(); loss_adam.backward(); opt_adam.step()
        opt_adamw.zero_grad(); loss_adamw.backward(); opt_adamw.step()

    adam_ratio = model_adam.weight.data.norm().item() / orig_norm
    adamw_ratio = model_adamw.weight.data.norm().item() / orig_norm

    print(f"After 100 steps:")
    print(f"  Adam  weight norm ratio: {adam_ratio:.4f} "
          f"({(1-adam_ratio)*100:.1f}% decay)")
    print(f"  AdamW weight norm ratio: {adamw_ratio:.4f} "
          f"({(1-adamw_ratio)*100:.1f}% decay)")
    print(f"\n(AdamW's decay is more consistent and predictable)")


compare_adam_vs_adamw()

# 典型输出:
# After 100 steps:
#   Adam  weight norm ratio: 0.8923 (10.8% decay)
#   AdamW weight norm ratio: 0.9047 (9.5% decay)

在实际应用中,AdamW 相比 Adam 的差异通常不会特别巨大(如上面所示),但在大规模长时间训练中这种一致性优势会累积起来。更重要的是,AdamW 的行为更容易理解和预测——你知道每个 step 参数会因为 weight decay 而减少固定的比例,这在调试和分析训练动态时非常有价值。

AdamW 关键超参数详解

现在我们来逐一讨论 AdamW 的各个超参数应该如何设置,以及它们对训练的影响。

学习率 lr

学习率是最重要的超参数,没有之一。它直接决定了每一步参数更新的幅度。对于 LLM 训练来说:

  • 从头训练:通常在 1e-43e-4 之间。GPT-3 用的是 6e-5(因为模型极大需要更保守的学习率),LLaMA 用的是 3e-4。
  • 微调预训练模型:通常更低,1e-55e-5。因为模型已经学到了很好的表示,太大的学习率会破坏这些知识(灾难性遗忘)。
  • 配合 warmup:有 warmup 时可以用稍大的峰值学习率(比如 1e-3),因为 warmup 阶段会从小到大逐步增加。
python
def learning_rate_sensitivity():
    lrs = [1e-5, 3e-5, 1e-4, 3e-4, 1e-3]

    for lr in lrs:
        model = GPT(GPTConfig(vocab_size=500, n_embed=32, num_heads=4, num_layers=2))
        optimizer = optim.AdamW(model.parameters(), lr=lr)

        losses = []
        for _ in range(200):
            x = torch.randint(0, 500, (4, 16))
            y = torch.randint(0, 500, (4, 16))
            out = model(x, labels=y)
            optimizer.zero_grad(); out['loss'].backward(); optimizer.step()
            losses.append(out['loss'].item())

        final_loss = losses[-1]
        converged = final_loss < losses[0] * 0.5
        stable = max(losses[150:]) - min(losses[150:]) < 0.5

        status = "✓ good" if (converged and stable) else ("△ unstable" if converged else "✗ no converge")
        print(f"lr={lr:<8s}: final={final_loss:>7.4f}, status={status}")


learning_rate_sensitivity()

# 输出示例:
# lr=1e-05  : final= 3.2145, status=✗ no converge (太慢)
# lr=3e-05  : final= 1.8234, status=△ stable but slow
# lr=1e-04  : final= 0.4523, status=✓ good
# lr=3e-04  : final= 0.3812, status=✓ good
# lr=1e-03  : final= NaN,     status=✗ diverged (太大!)

动量参数 betas=(β₁, β₂)

β₁ 控制一阶矩(动量)的衰减率,默认 0.9。较大的 β₁(如 0.99)意味着动量记忆更长,适合平坦损失地形中的快速推进;较小的 β₁(如 0.8)意味着动量衰减更快,响应更灵敏但可能震荡。对于 LLM 来说 0.9 几乎总是正确的选择。

β₂ 控制二阶矩(梯度平方的移动平均)的衰减率,默认 0.999。这个值通常不需要改动——它决定了有效学习率适应的速度。β₂ 太小会导致学习率波动剧烈,太大会导致适应迟缓。

权值衰减 weight_decay

控制 L2 正则化的强度。典型值:

  • LLM 从头训练:0.01 ~ 0.1(GPT-2 用 0.01,LLaMA 用 0.1)
  • 微调:0.0 ~ 0.01(微调时通常减小或关闭 weight decay,尤其是用 LoRA 时)

weight_decay 的作用是防止过拟合——它通过惩罚大权重来限制模型的复杂度。但注意 weight_decay 过大会导致欠拟合(模型容量被过度约束),过小则不起作用。

数值稳定性参数 eps

防止除零的小常数,默认 1e-8。几乎永远不需要改这个值。唯一可能的例外是在 FP16/BF16 训练时,可以适当增大到 1e-6 或 1e-4 以避免精度不足导致的数值问题。

优化器状态字典:断点续训的关键

AdamW 的优化器内部维护的状态远比你想象的复杂。每次调用 optimizer.state_dict() 保存的内容包括:

python
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

print("Optimizer state dict keys:", optimizer.state_dict().keys())
# dict_keys(['state', 'param_groups'])

state = optimizer.state_dict()['state']
for idx, (param_id, param_state) in enumerate(state.items()):
    print(f"Param {idx} state keys: {list(param_state.keys())}")
    # ['step', 'exp_avg', 'exp_avg_sq']
    # step: 当前已执行的步数
    # exp_avg: 一阶矩 m (动量)
    # exp_avg_sq: 二阶矩 v (方差)
    if idx >= 2:
        print("  ... (remaining params omitted)")
        break

每个参数都有三组状态变量:step(计数器)、exp_avg(形状与参数相同的张量,存储动量)、exp_avg_sq(同样形状,存储二阶矩)。对于一个 7B 参数的模型(FP16),仅优化器状态就需要约 28GB 显存(每个参数 2 bytes × 2 个状态 × 7B ≈ 28GB)——这是模型本身显存占用(14GB FP16)的两倍!这就是为什么显存优化技术(如 ZeRO/FSDP)会把优化器状态分片或 offload 到 CPU 上。

保存和恢复优化器状态的正确方式如下:

python
checkpoint_path = "training_checkpoint.pt"

torch.save({
    'epoch': epoch,
    'global_step': global_step,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'loss': loss.item(),
    'rng_state': torch.get_rng_state(),
    'cuda_rng_state': torch.cuda.get_rng_state_all(),
}, checkpoint_path)

checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
torch.set_rng_state(checkpoint['rng_state'])
torch.cuda.set_rng_state_all(checkpoint['cuda_rng_state'])

注意最后两行:我们还保存并恢复了随机数生成器(RNG)的状态。这对于保证完全可复现的训练至关重要——如果不恢复 RNG 状态,即使恢复了模型和优化器参数,后续的数据 shuffle、dropout 等随机操作也会产生不同的结果。

替代优化器:什么时候该考虑其他选择?

虽然 AdamW 是 LLM 训练的事实标准,但在某些特定场景下你可能需要考虑替代方案。

AdamW8bit(8-bit AdamW):来自 bitsandbytes 库,将优化器状态从 FP32 量化为 INT8/FP8。效果是优化器显存占用减少约 75%(28GB → 约 7GB 对于 7B 模型)。对于显存紧张的场景(比如消费级显卡微调 7B 模型)非常有用。使用方式极其简单:

python
import bitsandbytes as bnb

optimizer = bnb.optim.AdamW8bit(
    model.parameters(),
    lr=2e-4,
    is_paged=True,  # 当 GPU 内存不足时自动换页到 CPU
)

Adafactor:Google 提出的优化器,专门针对超大模型设计。它的核心创新是不为每个参数单独存储完整的二阶矩矩阵,而是用低秩近似或其他压缩方式来估计。这使得 Adafactor 的内存占用远小于 Adam(特别是对 embedding 这种高维稀疏参数)。某些超大规模模型(如 T5、PaLM)在训练时使用了 Adafactor 或其变体。但对于大多数中小规模场景(< 30B 参数),AdamW 仍然是更好的选择——Adafactor 的收敛速度在某些情况下略慢于 AdamW。

Lion(EvoLved Sign Momentum):2023 年由 Google 提出的一种新优化器,其更新规则简洁得令人惊讶:只需要梯度的符号(sign)而不是梯度值本身。Lion 的核心操作是对动量做 sign 函数再乘以学习率。实验表明 Lion 在多个 benchmark 上以更少的步数达到了与 AdamW 相同甚至更好的性能,而且超参数敏感性更低。不过 Lion 还比较新,社区验证不如 AdamW 充分,建议在充分测试后再用于生产环境。

python
def lion_update_demo():
    """Lion vs AdamW 更新规则对比"""
    theta = torch.tensor([1.0, 2.0, 3.0])
    grad = torch.tensor([0.5, -0.01, -2.0])

    lr = 0.001
    beta1, beta2 = 0.9, 0.99

    m = beta1 * 0 + (1 - beta1) * grad
    lion_update = lr * torch.sign(m)

    print(f"Lion update: {lion_update.tolist()}")
    print(f"  (Only depends on sign of gradient, not magnitude!)")

    v = beta2 * 0 + (1 - beta2) * grad ** 2
    adamw_update = lr * m / (torch.sqrt(v) + 1e-8)
    print(f"AdamW update: {adamw_update.tolist()}")
    print(f"  (Depends on both direction and magnitude)")


lion_update_demo()

# Lion update: [0.001, -0.001, -0.001]
# AdamW update: [0.0009, -0.00009, -0.00067]

总结一下选型建议:对于绝大多数 LLM 训练任务(无论是从头训练还是微调),AdamW 是最安全的选择。当你在消费级显卡上跑 7B 级别的模型且显存不够时,换成 AdamW8bit 可以立竿见影地节省显存。如果你在训练 100B+ 的超大规模模型,可以考虑 Adafactor 来降低优化器内存开销。如果你愿意尝试前沿方案并且有足够的算力做消融实验,Lion 可能给你带来惊喜。但无论如何,理解 AdamW 的工作原理是你做出明智决策的基础——因为所有这些优化器本质上都是在同一个框架下的变体:如何利用历史梯度信息来指导当前步的参数更新方向和幅度。

基于 MIT 许可发布