跳转到内容

7.2 FSDP(Fully Sharded Data Parallel)

上一节我们学习了 DDP——每张 GPU 上保存一份完整的模型副本,通过 AllReduce 同步梯度来保持参数一致。DDP 的设计简洁高效,但它有一个根本性的限制:每张卡必须能放下完整的模型。对于 7B BF16 模型来说这需要约 14GB,加上优化器状态和梯度总共需要 ~56GB/卡——这意味着至少需要 A100 (80GB) 才能跑得动。那如果你想在几张 RTX 3090 (24GB) 或 A6000 (48GB) 上训练一个更大的模型呢?或者你想训练一个 70B 的模型呢?DDP 完全无能为力。

FSDP(Fully Sharded Data Parallel) 就是解决这个问题的方案。它是 PyTorch 原生支持的模型分片策略:不再在每张卡上保存完整模型副本,而是把模型的参数、梯度和优化器状态全部切分(shard)到各张 GPU 上。每张卡只持有总参数量的 1/N(N 为 GPU 数量),在前向和反向传播时按需收集当前层需要的完整参数。这意味着 4 张 24GB 显存的卡可以训练一个原本需要 96GB 才能放下的模型。

FSDP vs DDP:核心区别

理解 FSDP 最快的方式是把它和 DDP 做对比:

DDP (4 GPUs, 7B model):
┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐
│ Model    │  │ Model    │  │ Model    │  │ Model    │
│ (14 GB)  │  │ (14 GB)  │  │ (14 GB)  │  │ (14 GB)  │
│ (完整)   │  │ (完整)   │  │ (完整)   │  │ (完整)   │
└──────────┘  └──────────┘  └──────────┘  └──────────┘
总显存: 56 GB+ (仅模型) → 每卡需 ≥ 14 GB

FSDP (4 GPUs, 7B model):
┌──────────┐  ┌──────────┐  ┌──────────┐  ┌──────────┐
│ Params   │  │ Params   │  │ Params   │  │ Params   │
│ (3.5 GB) │  │ (3.5 GB) │  │ (3.5 GB) │  │ (3.5 GB) │
│ (1/4)    │  │ (1/4)    │  │ (1/4)    │  │ (1/4)    │
└──────────┘  └──────────┘  └──────────┘  └──────────┘
总显存: 14 GB (模型总计) → 每卡仅需 ≥ 3.5 GB

从图中可以清楚地看到:DDP 中每张卡的模型显存不随 GPU 数量变化(都是完整副本);而 FSDP 中每张卡只需要保存 1/N 的参数。

但"分片"不是免费的——FSDP 需要在前向传播的每一层执行前后做额外的通信操作:

FSDP 前向传播过程 (以 Transformer Block i 为例):

1. All-Gather: 从所有 rank 收集 Block i 的完整参数
   Rank 0: 拥有 [W_i_0] → 收集 [W_i_1, W_i_2, W_i_3] → 得到完整 W_i
   Rank 1: 拥有 [W_i_1] → 收集 [W_i_0, W_i_2, W_i_3] → 得到完整 W_i
   ...

2. Compute: 用完整的 W_i 执行 Block i 的前向计算

3. Discard: 计算完成后丢弃非本 rank 负责的那部分参数
   (只保留自己负责的分片用于后续反向传播)
   
4. 对下一层重复步骤 1-3

反向传播的过程类似,只是通信方向相反:先 Reduce-Scatter 分散梯度,然后每个 rank 用自己的梯度分片更新自己负责的参数分片。

这些额外的 All-Gather 和 Reduce-Scatter 操作就是 FSDP 相比 DDP 的主要开销来源。通信量与模型大小成正比——模型越大,通信开销越大。这也是为什么 FSDP 在小模型上可能比 DDP 还慢的原因:通信开销超过了并行计算的收益。

Sharding Strategies:三种分片级别

FSDP 提供了三种不同的分片策略(sharding strategy),让你在显存节省和通信开销之间做权衡:

策略分片内容显存占用(每卡)通信开销适用场景
NO_SHARD无(等价于 DDP)最高(= DDP)最低对比基准 / 不需要分片
SHARD_GRAD_OP梯度 + 优化器状态中等中等折中方案
FULL_SHARD参数 + 梯度 + 优化器状态最低最高大模型 / 显存紧张

SHARD_GRAD_OP:只对梯度和优化器状态做分片,参数仍然完整保存在每张卡上。这节省了优化器状态的显存(通常是最大的单一内存消耗项),但模型本身仍需要完整加载。适合模型刚好能放进单卡、但优化器状态导致 OOM 的情况。

FULL_SHARD(最常用):参数、梯度、优化器状态全部切分。这是真正的"Fully Sharded"——每张卡上的总显存约为 total_model_size / N。适合大模型训练的标准选择。

python
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    FullStateDictConfig,
    ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy


def create_fsdp_model(model_class, config, sharding="FULL_SHARD"):
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    
    model = model_class(config).to(local_rank)
    
    auto_wrap_policy = transformer_auto_wrap_policy(
        transformer_layer_cls={TransformerBlock},
    )
    
    fsdp_config = dict(
        sharding_strategy={
            "FULL_SHARD": ShardingStrategy.FULL_SHARD,
            "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
            "NO_SHARD": ShardingStrategy.NO_SHARD,
        }[sharding],
        auto_wrap_policy=auto_wrap_policy,
        device_id=torch.cuda.current_device(),
        state_dict_type="full",
        state_dict_config=FullStateDictConfig(
            offload_to_cpu=True,
        ),
        limit_all_gathers=True,
    )
    
    model = FSDP(model, **fsdp_config)
    
    return model


model = create_fsdp_model(GPT, GPTConfig(), sharding="FULL_SHARD")

transformer_auto_wrap_policy — 自动决定哪些层要分片

这是 FSDP 最关键也最容易出错的配置项。它的作用是告诉 FSDP 应该把哪些模块当作一个"分片单元"(sharding unit)。对于 Transformer 模型来说,最自然的分片粒度是单个 Transformer Block——因为每个 Block 是一个独立的功能单元(包含 Attention + FFN + Norms),Block 之间的依赖关系是顺序的(Block i 的输出是 Block i+1 的输入)。

transformer_auto_wrap_policy 接收一个类(或类的元组),当 FSDP 在遍历模型结构时遇到这个类的实例,就会创建一个新的分片边界。上面的代码中我们指定了 TransformerBlock 作为分片单元,意味着 FSDP 会把每个 Block 的参数作为一个整体来分片——All-Gather 时一次收集整个 Block 的参数,计算完后一起释放。

如果选错了 wrap policy(比如把整个模型当作一个分片单元),那就退化成了没有分片效果;如果把每个 Linear 层都当作分片单元(太细粒度),会导致频繁的小规模 All-Gather 操作,通信开销爆炸。

state_dict_type="full" 和 CPU Offloading

这两个配置与 checkpoint 保存相关。state_dict_type="full" 表示在调用 save_state_dict() 时,FSDP 会从所有 rank 收集完整的模型参数再保存(而不是只保存本地分片)。这样得到的 checkpoint 是一个完整的、可以直接用 load_state_dict() 加载到单卡或任意数量 GPU 上的标准 checkpoint。

offload_to_cpu=True 进一步优化了这个过程:在收集全量参数做 checkpoint 时,先把不需要的参数 offload 到 CPU 内存中,减少 GPU 显存峰值占用。这在保存超大模型的 checkpoint 时非常有用——否则可能因为收集全量参数时的临时显存需求而 OOM。

FSDP 性能分析:什么时候值得用?

FSDP 不是万能药——它在某些场景下甚至比 DDP 更慢。让我们量化地分析一下:

python
def fsdp_vs_ddp_analysis():
    model_sizes = [(7, 14), (13, 26), (34, 68), (70, 140)]
    gpu_configs = [
        ("RTX 4090", 24),
        ("A100", 80),
        ("A6000", 48),
    ]
    
    print(f"\n{'Model':>6s} | {'FP16 Size':>10s} | "
          f"{'DDP min GPU':>12s} | {'FSDP(4x) per card':>17s}")
    print("-" * 65)
    
    for name, fp16_gb in model_sizes:
        ddp_min = fp16_gb + fp16_gb * 2 + 4  # model + grad + opt + act
        fsdp_per_card = fp16_gb / 4 + fp16_gb / 8 + 1  # shard + shard_grad + act
        
        print(f"{name:>6}B | {fp16_gb:>9d} GB | "
              f"{ddp_min:>11.1f} GB | {fsdp_per_card:>16.1f} GB")


fsdp_vs_ddp_analysis()

# 输出:
# Model | FP16 Size | DDP min GPU | FSDP(4x) per card
# -----------------------------------------------------------------
#     7B |       14 GB |       46.0 GB |            5.5 GB
#    13B |       26 GB |       58.0 GB |            9.5 GB
#    34B |       68 GB |      100.0 GB |           23.5 GB
#    70B |      140 GB |      172.0 GB |           47.5 GB

从这个表中可以得出几个实用的结论:

  • 7B 模型:DDP 需要 46GB/卡 → 至少 A6000 (48GB);FSDP 4×RTX 4090 只需 5.5GB/卡 → 消费级显卡就能跑
  • 13B 模型:DDP 需要 58GB/卡 → 必须 A100 (80GB);FSDP 4×RTX 4090 需要 9.5GB/卡 → 可以
  • 34B 模型:DDP 几乎不可能(100GB/卡超过任何消费级显卡);FSDP 4×A100 需要 23.5GB/卡 → 可行
  • 70B 模型:DDP 完全不可行;FSDP 8×A100 需要 ~23GB/卡 → 这是 LLaMA-70B 微调的标准配置

关于速度:FSDP 由于额外的 All-Gather/RScatter 通信,通常比 DDP 慢 10%~30%(取决于模型大小和网络带宽)。但对于那些不用 FSDP 就根本跑不起来的场景来说,这点性能损失完全是可以接受的。

在 Lightning 中使用 FSDP

好消息是在 PyTorch Lightning 中启用 FSDP 只需要改一行参数:

python
import pytorch_lightning as pl
from pytorch_lightning.strategies import FSDPStrategy


trainer = pl.Trainer(
    strategy=FSDPStrategy(
        sharding_strategy="FULL_SHARD",
        activation_checkpointing=True,
        state_dict_type="full",
        limit_all_gathers=True,
        # auto_wrap_policy 由 Lightning 自动推断
    ),
    devices=4,
    accelerator="gpu",
    precision="bf16-mixed",
)

trainer.fit(lit_model, train_loader, val_loader)

Lightning 会自动处理:

  • FSDP 包装模型的正确时机
  • DistributedSampler 的设置
  • 只在 rank 0 上保存 checkpoint
  • 正确的 set_epoch() 调用
  • CPU offloading 配置

你唯一需要确保的是模型定义中使用了标准的 nn.Module 子类作为 Transformer Block(Lightning 能自动识别并设置正确的 auto_wrap_policy)。如果你的模型结构比较特殊,也可以手动指定:

python
strategy = FSDPStrategy(
    ...,
    auto_wrap_policy=transformer_auto_wrap_policy(
        transformer_layer_cls={TransformerBlock},
    ),
)

常见问题排查

问题一:"CUDA out of memory" 即使使用了 FSDP

可能原因及解决方案:

  1. 序列长度太长 —— 减小 max_length 或开启 gradient_checkpointing
  2. batch size 太大 —— 进一步减小 per_device_batch_size
  3. 未开启 activation_checkpointing —— 加上 activation_checkpointing=True
  4. 某些层没有被正确分片 —— 检查 auto_wrap_policy 是否覆盖了所有大参数层
  5. 使用了 FULL_SHARD 但 GPU 数量不够 —— 尝试 SHARD_GRAD_OP 或增加 GPU 数量

问题二:训练速度比预期慢很多

FSDP 的通信开销受以下因素影响:

  1. 网络带宽 —— NVLink 比 PCIe 快得多(900 GB/s vs 64 GB/s)。多卡在同一节点(NVLink互联)比跨节点(InfiniBand/Ethernet)快得多
  2. 模型大小 —— 模型越大,每次 All-Gather 传输的数据越多
  3. 分片粒度 —— 太细(每个 Linear 层分片)会导致频繁的小规模通信

检查方法:使用 torch.profilernsys 分析实际的时间花在哪里。如果大部分时间花在通信上,考虑增大 batch size 来摊薄通信开销,或者换用 DeepSpeed ZeRO(有更精细的通信优化)。

问题三:Checkpoint 文件异常大

这是因为 state_dict_type="full" 会保存完整的模型参数。对于 70B 模型的 checkpoint 可能需要 140GB+ 磁盘空间。解决方案:

  • 使用 state_dict_type="sharded" 保存分片格式的 checkpoint(每 rank 一个文件,恢复时也需要同样数量的 rank)
  • 开启 offload_to_cpu=True 并使用低精度保存
  • 定期清理旧的 checkpoint

到这里,我们已经掌握了两种数据并行方案:DDP(适合模型能放入单卡的场景)和 FSDP(适合超大规模模型的场景)。但在实际的大规模预训练任务(100B+ 参数)中,还有一个被广泛使用的第三方方案——微软的 DeepSpeed ZeRO。它提供了比 FSDP 更灵活的配置选项和更多的优化特性(如 CPU/NVMe Offloading),是许多顶级 LLM 训练框架(如 Megatron-LM、DeepSpeed-MiLE)的基础组件。下一节我们将深入 DeepSpeed ZeRO。

基于 MIT 许可发布