主题
7.2 FSDP(Fully Sharded Data Parallel)
上一节我们学习了 DDP——每张 GPU 上保存一份完整的模型副本,通过 AllReduce 同步梯度来保持参数一致。DDP 的设计简洁高效,但它有一个根本性的限制:每张卡必须能放下完整的模型。对于 7B BF16 模型来说这需要约 14GB,加上优化器状态和梯度总共需要 ~56GB/卡——这意味着至少需要 A100 (80GB) 才能跑得动。那如果你想在几张 RTX 3090 (24GB) 或 A6000 (48GB) 上训练一个更大的模型呢?或者你想训练一个 70B 的模型呢?DDP 完全无能为力。
FSDP(Fully Sharded Data Parallel) 就是解决这个问题的方案。它是 PyTorch 原生支持的模型分片策略:不再在每张卡上保存完整模型副本,而是把模型的参数、梯度和优化器状态全部切分(shard)到各张 GPU 上。每张卡只持有总参数量的 1/N(N 为 GPU 数量),在前向和反向传播时按需收集当前层需要的完整参数。这意味着 4 张 24GB 显存的卡可以训练一个原本需要 96GB 才能放下的模型。
FSDP vs DDP:核心区别
理解 FSDP 最快的方式是把它和 DDP 做对比:
DDP (4 GPUs, 7B model):
┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐
│ Model │ │ Model │ │ Model │ │ Model │
│ (14 GB) │ │ (14 GB) │ │ (14 GB) │ │ (14 GB) │
│ (完整) │ │ (完整) │ │ (完整) │ │ (完整) │
└──────────┘ └──────────┘ └──────────┘ └──────────┘
总显存: 56 GB+ (仅模型) → 每卡需 ≥ 14 GB
FSDP (4 GPUs, 7B model):
┌──────────┐ ┌──────────┐ ┌──────────┐ ┌──────────┐
│ Params │ │ Params │ │ Params │ │ Params │
│ (3.5 GB) │ │ (3.5 GB) │ │ (3.5 GB) │ │ (3.5 GB) │
│ (1/4) │ │ (1/4) │ │ (1/4) │ │ (1/4) │
└──────────┘ └──────────┘ └──────────┘ └──────────┘
总显存: 14 GB (模型总计) → 每卡仅需 ≥ 3.5 GB从图中可以清楚地看到:DDP 中每张卡的模型显存不随 GPU 数量变化(都是完整副本);而 FSDP 中每张卡只需要保存 1/N 的参数。
但"分片"不是免费的——FSDP 需要在前向传播的每一层执行前后做额外的通信操作:
FSDP 前向传播过程 (以 Transformer Block i 为例):
1. All-Gather: 从所有 rank 收集 Block i 的完整参数
Rank 0: 拥有 [W_i_0] → 收集 [W_i_1, W_i_2, W_i_3] → 得到完整 W_i
Rank 1: 拥有 [W_i_1] → 收集 [W_i_0, W_i_2, W_i_3] → 得到完整 W_i
...
2. Compute: 用完整的 W_i 执行 Block i 的前向计算
3. Discard: 计算完成后丢弃非本 rank 负责的那部分参数
(只保留自己负责的分片用于后续反向传播)
4. 对下一层重复步骤 1-3反向传播的过程类似,只是通信方向相反:先 Reduce-Scatter 分散梯度,然后每个 rank 用自己的梯度分片更新自己负责的参数分片。
这些额外的 All-Gather 和 Reduce-Scatter 操作就是 FSDP 相比 DDP 的主要开销来源。通信量与模型大小成正比——模型越大,通信开销越大。这也是为什么 FSDP 在小模型上可能比 DDP 还慢的原因:通信开销超过了并行计算的收益。
Sharding Strategies:三种分片级别
FSDP 提供了三种不同的分片策略(sharding strategy),让你在显存节省和通信开销之间做权衡:
| 策略 | 分片内容 | 显存占用(每卡) | 通信开销 | 适用场景 |
|---|---|---|---|---|
| NO_SHARD | 无(等价于 DDP) | 最高(= DDP) | 最低 | 对比基准 / 不需要分片 |
| SHARD_GRAD_OP | 梯度 + 优化器状态 | 中等 | 中等 | 折中方案 |
| FULL_SHARD | 参数 + 梯度 + 优化器状态 | 最低 | 最高 | 大模型 / 显存紧张 |
SHARD_GRAD_OP:只对梯度和优化器状态做分片,参数仍然完整保存在每张卡上。这节省了优化器状态的显存(通常是最大的单一内存消耗项),但模型本身仍需要完整加载。适合模型刚好能放进单卡、但优化器状态导致 OOM 的情况。
FULL_SHARD(最常用):参数、梯度、优化器状态全部切分。这是真正的"Fully Sharded"——每张卡上的总显存约为 total_model_size / N。适合大模型训练的标准选择。
python
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
FullStateDictConfig,
ShardingStrategy,
)
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
def create_fsdp_model(model_class, config, sharding="FULL_SHARD"):
local_rank = int(os.environ.get("LOCAL_RANK", 0))
model = model_class(config).to(local_rank)
auto_wrap_policy = transformer_auto_wrap_policy(
transformer_layer_cls={TransformerBlock},
)
fsdp_config = dict(
sharding_strategy={
"FULL_SHARD": ShardingStrategy.FULL_SHARD,
"SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP,
"NO_SHARD": ShardingStrategy.NO_SHARD,
}[sharding],
auto_wrap_policy=auto_wrap_policy,
device_id=torch.cuda.current_device(),
state_dict_type="full",
state_dict_config=FullStateDictConfig(
offload_to_cpu=True,
),
limit_all_gathers=True,
)
model = FSDP(model, **fsdp_config)
return model
model = create_fsdp_model(GPT, GPTConfig(), sharding="FULL_SHARD")transformer_auto_wrap_policy — 自动决定哪些层要分片
这是 FSDP 最关键也最容易出错的配置项。它的作用是告诉 FSDP 应该把哪些模块当作一个"分片单元"(sharding unit)。对于 Transformer 模型来说,最自然的分片粒度是单个 Transformer Block——因为每个 Block 是一个独立的功能单元(包含 Attention + FFN + Norms),Block 之间的依赖关系是顺序的(Block i 的输出是 Block i+1 的输入)。
transformer_auto_wrap_policy 接收一个类(或类的元组),当 FSDP 在遍历模型结构时遇到这个类的实例,就会创建一个新的分片边界。上面的代码中我们指定了 TransformerBlock 作为分片单元,意味着 FSDP 会把每个 Block 的参数作为一个整体来分片——All-Gather 时一次收集整个 Block 的参数,计算完后一起释放。
如果选错了 wrap policy(比如把整个模型当作一个分片单元),那就退化成了没有分片效果;如果把每个 Linear 层都当作分片单元(太细粒度),会导致频繁的小规模 All-Gather 操作,通信开销爆炸。
state_dict_type="full" 和 CPU Offloading
这两个配置与 checkpoint 保存相关。state_dict_type="full" 表示在调用 save_state_dict() 时,FSDP 会从所有 rank 收集完整的模型参数再保存(而不是只保存本地分片)。这样得到的 checkpoint 是一个完整的、可以直接用 load_state_dict() 加载到单卡或任意数量 GPU 上的标准 checkpoint。
offload_to_cpu=True 进一步优化了这个过程:在收集全量参数做 checkpoint 时,先把不需要的参数 offload 到 CPU 内存中,减少 GPU 显存峰值占用。这在保存超大模型的 checkpoint 时非常有用——否则可能因为收集全量参数时的临时显存需求而 OOM。
FSDP 性能分析:什么时候值得用?
FSDP 不是万能药——它在某些场景下甚至比 DDP 更慢。让我们量化地分析一下:
python
def fsdp_vs_ddp_analysis():
model_sizes = [(7, 14), (13, 26), (34, 68), (70, 140)]
gpu_configs = [
("RTX 4090", 24),
("A100", 80),
("A6000", 48),
]
print(f"\n{'Model':>6s} | {'FP16 Size':>10s} | "
f"{'DDP min GPU':>12s} | {'FSDP(4x) per card':>17s}")
print("-" * 65)
for name, fp16_gb in model_sizes:
ddp_min = fp16_gb + fp16_gb * 2 + 4 # model + grad + opt + act
fsdp_per_card = fp16_gb / 4 + fp16_gb / 8 + 1 # shard + shard_grad + act
print(f"{name:>6}B | {fp16_gb:>9d} GB | "
f"{ddp_min:>11.1f} GB | {fsdp_per_card:>16.1f} GB")
fsdp_vs_ddp_analysis()
# 输出:
# Model | FP16 Size | DDP min GPU | FSDP(4x) per card
# -----------------------------------------------------------------
# 7B | 14 GB | 46.0 GB | 5.5 GB
# 13B | 26 GB | 58.0 GB | 9.5 GB
# 34B | 68 GB | 100.0 GB | 23.5 GB
# 70B | 140 GB | 172.0 GB | 47.5 GB从这个表中可以得出几个实用的结论:
- 7B 模型:DDP 需要 46GB/卡 → 至少 A6000 (48GB);FSDP 4×RTX 4090 只需 5.5GB/卡 → 消费级显卡就能跑
- 13B 模型:DDP 需要 58GB/卡 → 必须 A100 (80GB);FSDP 4×RTX 4090 需要 9.5GB/卡 → 可以
- 34B 模型:DDP 几乎不可能(100GB/卡超过任何消费级显卡);FSDP 4×A100 需要 23.5GB/卡 → 可行
- 70B 模型:DDP 完全不可行;FSDP 8×A100 需要 ~23GB/卡 → 这是 LLaMA-70B 微调的标准配置
关于速度:FSDP 由于额外的 All-Gather/RScatter 通信,通常比 DDP 慢 10%~30%(取决于模型大小和网络带宽)。但对于那些不用 FSDP 就根本跑不起来的场景来说,这点性能损失完全是可以接受的。
在 Lightning 中使用 FSDP
好消息是在 PyTorch Lightning 中启用 FSDP 只需要改一行参数:
python
import pytorch_lightning as pl
from pytorch_lightning.strategies import FSDPStrategy
trainer = pl.Trainer(
strategy=FSDPStrategy(
sharding_strategy="FULL_SHARD",
activation_checkpointing=True,
state_dict_type="full",
limit_all_gathers=True,
# auto_wrap_policy 由 Lightning 自动推断
),
devices=4,
accelerator="gpu",
precision="bf16-mixed",
)
trainer.fit(lit_model, train_loader, val_loader)Lightning 会自动处理:
- FSDP 包装模型的正确时机
- DistributedSampler 的设置
- 只在 rank 0 上保存 checkpoint
- 正确的
set_epoch()调用 - CPU offloading 配置
你唯一需要确保的是模型定义中使用了标准的 nn.Module 子类作为 Transformer Block(Lightning 能自动识别并设置正确的 auto_wrap_policy)。如果你的模型结构比较特殊,也可以手动指定:
python
strategy = FSDPStrategy(
...,
auto_wrap_policy=transformer_auto_wrap_policy(
transformer_layer_cls={TransformerBlock},
),
)常见问题排查
问题一:"CUDA out of memory" 即使使用了 FSDP
可能原因及解决方案:
- 序列长度太长 —— 减小 max_length 或开启 gradient_checkpointing
- batch size 太大 —— 进一步减小 per_device_batch_size
- 未开启 activation_checkpointing —— 加上
activation_checkpointing=True - 某些层没有被正确分片 —— 检查 auto_wrap_policy 是否覆盖了所有大参数层
- 使用了 FULL_SHARD 但 GPU 数量不够 —— 尝试 SHARD_GRAD_OP 或增加 GPU 数量
问题二:训练速度比预期慢很多
FSDP 的通信开销受以下因素影响:
- 网络带宽 —— NVLink 比 PCIe 快得多(900 GB/s vs 64 GB/s)。多卡在同一节点(NVLink互联)比跨节点(InfiniBand/Ethernet)快得多
- 模型大小 —— 模型越大,每次 All-Gather 传输的数据越多
- 分片粒度 —— 太细(每个 Linear 层分片)会导致频繁的小规模通信
检查方法:使用 torch.profiler 或 nsys 分析实际的时间花在哪里。如果大部分时间花在通信上,考虑增大 batch size 来摊薄通信开销,或者换用 DeepSpeed ZeRO(有更精细的通信优化)。
问题三:Checkpoint 文件异常大
这是因为 state_dict_type="full" 会保存完整的模型参数。对于 70B 模型的 checkpoint 可能需要 140GB+ 磁盘空间。解决方案:
- 使用
state_dict_type="sharded"保存分片格式的 checkpoint(每 rank 一个文件,恢复时也需要同样数量的 rank) - 开启
offload_to_cpu=True并使用低精度保存 - 定期清理旧的 checkpoint
到这里,我们已经掌握了两种数据并行方案:DDP(适合模型能放入单卡的场景)和 FSDP(适合超大规模模型的场景)。但在实际的大规模预训练任务(100B+ 参数)中,还有一个被广泛使用的第三方方案——微软的 DeepSpeed ZeRO。它提供了比 FSDP 更灵活的配置选项和更多的优化特性(如 CPU/NVMe Offloading),是许多顶级 LLM 训练框架(如 Megatron-LM、DeepSpeed-MiLE)的基础组件。下一节我们将深入 DeepSpeed ZeRO。