跳转到内容

5.1 Lightning 核心哲学与第一个模型

回顾一下我们在第 4 章做的事情:写了一个完整的训练循环,从最简单的 30 行代码开始,逐步添加了梯度裁剪、学习率调度、混合精度训练、评估循环、日志记录、checkpoint 管理、早停机制……最终组装成了一个功能齐全的 Trainer 类。这个类大概有 150~200 行代码,而且我们还没包括分布式训练、多 GPU 同步、FP16/BF16 自动处理、梯度 checkpointing 这些更高级的功能。如果你把这些都加上去,一个生产级的训练框架轻松就能超过 500 行。问题在于:这 500 行代码中,真正和你的模型逻辑相关的可能只有 20~30 行(forward、loss 计算),剩下的 470 行都是通用的工程样板代码。每次开始一个新项目都要重新写一遍这些样板代码不仅浪费时间,更重要的是——它们是 bug 的温床。设备管理写错了?分布式通信顺序不对?半精度转换漏了一层?这些错误往往不会立即报错,而是在训练了几十个小时后才以奇怪的方式显现出来。

PyTorch Lightning 的核心设计理念就是解决这个痛点。它把训练相关的所有工程细节抽象成一个统一的 Trainer 类,而你只需要通过继承 LightningModule 来定义模型的核心逻辑——前向传播怎么算、loss 怎么定义、优化器用什么。其他一切:设备管理(.cuda() / .cpu() / .mps())、分布式训练(DDP/FSDP/DeepSpeed)、半精度(FP16/BF16)、梯度裁剪、学习率调度、checkpoint 保存与恢复、日志记录、早停……全部由 Trainer 自动处理。你用 25 行核心代码就能完成之前 200 行才能做到的事情,而且不容易出错。

从手写到 Lightning:一次直观的对比

让我们先看一个具体的例子来感受这种差异。下面是我们第 4 章手写的训练循环的核心部分:

python
model = GPT(config).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.01)
criterion = nn.CrossEntropyLoss(ignore_index=-100)

scheduler = get_cosine_schedule_with_warmup(
    optimizer, num_warmup_steps=500, num_training_steps=total_steps
)

for epoch in range(num_epochs):
    model.train()
    for step, batch in enumerate(train_loader):
        input_ids = batch['input_ids'].cuda()
        labels = batch['labels'].cuda()

        with autocast('cuda', dtype=torch.bfloat16):
            outputs = model(input_ids, labels=labels)
            loss = outputs['loss']

        scaler.scale(loss).backward()
        scaler.unscale_(optimizer)
        clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].cuda()
            labels = batch['labels'].cuda()
            outputs = model(input_ids, labels=labels)
            val_losses.append(outputs['loss'].item())
    avg_val_loss = sum(val_losses) / len(val_losses)

    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        torch.save(model.state_dict(), "best_model.pt")

同样的功能用 PyTorch Lightning 来实现:

python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


class LitGPT(pl.LightningModule):
    def __init__(self, config: GPTConfig):
        super().__init__()
        self.config = config
        self.model = GPT(config)

    def forward(self, input_ids, labels=None):
        return self.model(input_ids, labels=labels)

    def training_step(self, batch, batch_idx):
        output = self(batch['input_ids'], batch['labels'])
        self.log('train_loss', output['loss'], prog_bar=True)
        return output['loss']

    def validation_step(self, batch, batch_idx):
        output = self(batch['input_ids'], batch['labels'])
        self.log('val_loss', output['loss'], prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.config.lr,
            weight_decay=self.config.weight_decay,
        )
        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.config.warmup_steps,
            num_training_steps=self.config.total_steps,
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def configure_callbacks(self):
        return [
            ModelCheckpoint(monitor='val_loss', save_top_k=3),
            EarlyStopping(monitor='val_loss', patience=5),
        ]


trainer = pl.Trainer(
    max_epochs=20,
    accelerator="auto",
    precision="bf16-mixed",
    gradient_clip_val=1.0,
    accumulate_grad_batches=4,
)

lit_model = LitGPT(config)
trainer.fit(lit_model, train_loader, val_loader)

数一下有效代码行数:LitGPT 类大约 30 行,Trainer 配置大约 8 行,调用 fit 一行。总共不到 40 行就完成了手写版本 80+ 行的功能。但更重要的是省略了什么:没有 .cuda() 调用(Lightning 自动处理)、没有 autocast 上下文管理器(precision="bf16-mixed" 一行搞定)、没有 scaler 管理(Lightning 内部自动处理)、没有 model.train()/eval() 切换(Lightning 在正确的时机自动调用)、没有手动 checkpoint 保存(ModelCheckpoint callback 自动处理)、没有早停逻辑(EarlyStopping callback 处理)。所有这些工程细节都被封装在了 Trainer 的内部实现中,经过了数千个项目的验证,比你自己手写的可靠得多。

LightningModule vs nn.Module:继承关系与新增能力

理解 Lightning 的第一步是搞清楚 LightningModule 和 PyTorch 原生 nn.Module 之间的关系。关键事实是:LightningModule 继承自 nn.Module。这意味着你之前学到的关于 nn.Module 的一切知识——parameters()state_dict().train()/.eval().to(device)、子模块注册等——在 LightningModule 中完全适用,没有任何破坏性变更。

python
import pytorch_lightning as pl

class LitGPT(pl.LightningModule):
    def __init__(self, config):
        super().__init__()
        self.model = GPT(config)       # nn.Module 子模块 ✓
        self.save_hyperparameters()   # Lightning 新增

    def forward(self, x):             # 和 nn.Module 完全一样 ✓
        return self.model(x)


lit_model = LitGPT(config)

print(f"Is nn.Module? {isinstance(lit_model, nn.Module)}")
print(f"Parameters count: {sum(p.numel() for p in lit_model.parameters()):,}")
print(f"State dict keys (first 3): {list(lit_model.state_dict().keys())[:3]}")

lit_model.eval()
print(f"Training mode: {lit_model.training}")

hparams = lit_model.hparams
print(f"Hyperparameters saved: {list(hparams.keys())}")

除了继承 nn.Module 的全部能力之外,LightningModule 还新增了几个重要的接口方法:

方法触发时机用途
training_step(self, batch, batch_idx)每个 training batch定义训练逻辑,返回 loss
validation_step(self, batch, batch_idx)每个 validation batch定义验证逻辑
test_step(self, batch, batch_idx)每个 test batch定义测试逻辑
predict_step(self, batch, batch_idx)每个 prediction batch定义推理逻辑
configure_optimizers(self)训练开始时返回优化器和学习率调度器
configure_callbacks(self)训练开始时返回 callbacks 列表

其中 training_step 是最核心的方法——它替代了手写循环中的"前向传播 → 计算 loss → 返回 loss"这一整块逻辑。你只需要关心"给定一个 batch,怎么计算 loss",其余的一切(梯度计算、参数更新、学习率调整、日志记录)由 Trainer 接管。

Trainer 核心参数速览

pl.Trainer 是 Lightning 的另一个核心组件——它是整个训练过程的编排者。虽然它的默认值已经能覆盖大多数场景,但了解每个参数的含义有助于你在需要时做出正确的配置选择。

基本训练控制

python
trainer = pl.Trainer(
    max_epochs=20,           # 最大训练轮数
    max_steps=100_000,       # 最大步数(优先于 max_epochs)
    limit_train_batches=100, # 每个 epoch 只用前 100 个 batch(快速调试用)
    limit_val_batches=10,    # 验证集同理
)

limit_train_batcheslimit_val_batches 在开发调试阶段极其有用——你可以只用数据集的一小部分来快速验证整个流程是否跑通,而不需要等待完整遍历整个数据集。设为 0.1 表示使用 10% 的数据,设为整数表示使用指定数量的 batch。

加速器与精度

python
trainer = pl.Trainer(
    accelerator="gpu",          # "cpu", "gpu", "tpu", "mps", "auto"
    devices=1,                  # 使用几张卡;-1 表示全部
    precision="bf16-mixed",     # "32", "16-mixed", "bf16-mixed"
)

accelerator="auto" 会自动检测可用的硬件——有 CUDA 就用 GPU,有 Apple Silicon 就用 MPS,都没有就用 CPU。这是最推荐的设置,因为它让你的代码在不同机器上无需修改即可运行。

精度选项中 "32" 表示纯 FP32(不使用混合精度);"16-mixed" 表示 FP16 混合精度(需要 GradScaler);"bf16-mixed" 表示 BF16 混合精度(不需要 GradScaler,推荐用于 Ampere+ GPU)。

梯度相关

python
trainer = pl.Trainer(
    gradient_clip_val=1.0,         # 梯度裁剪阈值
    gradient_clip_algorithm="norm", # "norm"(默认) 或 "value"
    accumulate_grad_batches=4,      # 梯度累积步数
)

注意这里的 gradient_clip_val 直接对应我们手写循环中的 clip_grad_norm_(..., max_norm=1.0) —— Lightning 会在每次 optimizer.step() 之前自动执行裁剪。

日志与回调

python
from pytorch_lightning.loggers import WandbLogger

trainer = pl.Trainer(
    logger=WandbLogger(project="my-gpt"),
    enable_checkpointing=True,
    default_root_dir="./outputs",
)

Lightning 支持多种日志后端:TensorBoardLogger(默认)、CSVLogger、WandbLogger、CometLogger 等。切换日志后端只需改一行代码,你的 self.log() 调用会自动适配到新的后端。

第一个完整的 Lightning 项目

让我们把所有知识整合起来,用 Lightning 重写我们的 GPT 训练项目:

python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import (
    ModelCheckpoint, EarlyStopping, LearningRateMonitor,
)
from pytorch_lightning.loggers import WandbLogger


class LitGPT(pl.LightningModule):
    """基于 Lightning 的 GPT 训练模块"""

    def __init__(
        self,
        vocab_size: int = 1000,
        n_embed: int = 128,
        num_heads: int = 4,
        num_layers: int = 4,
        max_seq_len: int = 256,
        lr: float = 3e-4,
        weight_decay: float = 0.01,
        warmup_ratio: float = 0.05,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.save_hyperparameters()

        config = GPTConfig(
            vocab_size=vocab_size,
            n_embed=n_embed,
            num_heads=num_heads,
            num_layers=num_layers,
            max_seq_len=max_seq_len,
            dropout=dropout,
        )
        self.gpt = GPT(config)

    def forward(self, input_ids, labels=None):
        return self.gpt(input_ids, labels=labels)

    def training_step(self, batch, batch_idx):
        output = self(batch['input_ids'], batch['labels'])
        loss = output['loss']
        self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        output = self(batch['input_ids'], batch['labels'])
        loss = output['loss']
        self.log('val_loss', loss, prog_bar=True, on_epoch=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
        )

        total_steps = self.trainer.estimated_stepping_batches
        warmup_steps = int(total_steps * self.hparams.warmup_ratio)

        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=total_steps,
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
            },
        }

    def configure_callbacks(self):
        return [
            ModelCheckpoint(
                monitor='val_loss',
                save_top_k=3,
                filename='gpt-{epoch:02d}-{val_loss:.4f}',
            ),
            EarlyStopping(
                monitor='val_loss',
                patience=5,
                mode='min',
            ),
            LearningRateMonitor(logging_interval='step'),
        ]


def train_with_lightning():
    lit_model = LitGPT(
        vocab_size=1000,
        n_embed=128,
        num_heads=4,
        num_layers=4,
        max_seq_len=256,
        lr=3e-4,
    )

    trainer = pl.Trainer(
        max_epochs=20,
        accelerator="auto",
        precision="bf16-mixed",
        gradient_clip_val=1.0,
        accumulate_grad_batches=4,
        logger=WandbLogger(project="lit-gpt-tutorial"),
        enable_checkpointing=True,
        limit_val_batches=20,
    )

    trainer.fit(lit_model, train_loader, val_loader)

    print("\nTraining complete!")
    print(f"Best checkpoint path: {trainer.checkpoint_callback.best_model_path}")


if __name__ == "__main__":
    train_with_lightning()

这段代码的运行效果和第 4 章的手写 Trainer 完全等价,但代码量减少了约 75%,且消除了几乎所有可能的工程错误来源。当你运行它时,Lightning 会在终端输出详细的训练进度信息:

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
...
Epoch 0: 100%|████████████| 625/625 [01:23<00:00, 7.50it/s, v_num=1, train_loss=3.2145, lr=2.98e-05]
Validating: 100%|████████████| 20/20 [00:02<00:00, 9.87it/s, val_loss=3.1567]
Epoch 1: 100%|████████████| 625/625 [01:21<00:00, 7.69it/s, v_num=1, train_loss=2.8765, lr=8.45e-05]
Validating: 100%|████████████| 20/20 [00:02<00:00, 9.91it/s, val_loss=2.7834]
...

每一行都包含了当前 epoch、进度条、速度(it/s)、各种 log 指标和学习率——这些都是免费的,不需要你写一行额外的日志代码。

常见误区与注意事项

在使用 Lightning 的过程中,有几个新手容易踩的坑值得提前说明。

误区一:在 training_step 中手动调用 .cuda().to(device)

Lightning 会自动把 batch 数据移动到正确的设备上。如果你再手动调用 .cuda(),在单卡训练时不会出错(重复移动到同一设备),但在多卡或 CPU 训练时会报错或产生意外行为。正确做法是在 training_step 中直接使用接收到的 batch,假设它已经在正确的设备上了。

误区二:在 training_step 中忘记返回 loss

training_step 必须返回一个标量张量(loss),因为 Trainer 需要它来调用 .backward()。如果你返回了 None 或者一个字典但没有包含 'loss' 键,训练将无法进行。如果除了 loss 还想返回其他信息(比如中间层的输出用于分析),可以用字典形式返回并包含 loss 键:

python
def training_step(self, batch, batch_idx):
    output = self(batch['input_ids'], batch['labels'])
    return {'loss': output['loss'], 'logits': output['logits']}

误区三:混淆 on_stepon_epoch 参数

self.log() 有两个重要的布尔参数:

  • on_step=True:每一步都记录(适合 train_loss 这种每步都在变化的指标)
  • on_epoch=True:每个 epoch 结束时自动计算平均值并记录(适合 val_loss 这种需要聚合的指标)

对于训练指标通常同时设置两个(on_step=True, on_epoch=True);对于验证指标通常只设置 on_epoch=True(避免日志过于频繁)。

误区四:configure_optimizers 中返回格式错误

这是最常见的 Lightning 报错之一。configure_optimizers 支持多种返回格式,最常用的两种是:

python
def configure_optimizers(self):
    opt = torch.optim.AdamW(self.parameters(), lr=3e-4)
    sched = get_cosine_schedule_with_warmup(opt, ...)
    
    # 格式一:简化版(只有优化器)
    return opt
    
    # 格式二:完整版(优化器 + 调度器)
    return {
        'optimizer': opt,
        'lr_scheduler': {
            'scheduler': sched,
            'interval': 'step',     # 'step' 或 'epoch'
            'frequency': 1,
        },
    }

注意调度器的 interval 参数——大多数 LLM 训练应该用 'step'(每步更新学习率),而不是默认的 'epoch'(每 epoch 更新一次)。

到这里,你已经掌握了 Lightning 的基本用法。下一节我们将深入 Lightning 的生命周期系统——了解 Trainer 在训练过程中究竟按什么顺序调用了哪些钩子方法,以及如何利用 Callback 系统来实现自定义的训练行为。

基于 MIT 许可发布