跳转到内容

LoRA原理

LoRA(Low-Rank Adaptation)是一种高效的模型微调技术,由微软研究院在 2021 年提出。它的核心思想是:大型预训练模型的权重矩阵通常是满秩的,但在微调时,我们假设权重的更新量(ΔW)可以用低秩矩阵来近似。这样可以大幅减少需要训练的参数量,同时保持与全量微调相近的性能。本篇文章详细介绍 LoRA 的原理。

LoRA 的核心思想

在传统的全量微调中,我们需要更新所有模型参数。假设模型的某个权重矩阵是 W₀ ∈ ℝ^{d×k},全量微调会更新 W₀ 为 W₀ + ΔW。

LoRA 的假设是:ΔW 可以用两个小矩阵的乘积来近似,即 ΔW = BA,其中 B ∈ ℝ^{d×r},A ∈ ℝ^{r×k},r 是远小于 d 和 k 的秩。这样:

新权重 = W₀ + α * (B @ A)

其中 α 是一个缩放因子。

python
import numpy as np

def demonstrate_lora_concept():
    """演示 LoRA 的核心概念"""
    print("=== LoRA 核心思想 ===\n")

    # 假设原始权重
    d, k = 1024, 1024  # 典型的大型权重矩阵
    r = 8  # 低秩

    print(f"原始权重 W₀: {d} × {k} = {d*k:,} 参数")

    # LoRA: 用两个小矩阵替代
    print(f"\nLoRA 分解:")
    print(f"  A: {r} × {k} = {r*k:,} 参数")
    print(f"  B: {d} × {r} = {d*r:,} 参数")
    print(f"  总计: {r*k + d*r:,} 参数")

    # 计算压缩比
    compression = (d*k) / (r*k + d*r)
    print(f"\n参数量压缩比: {compression:.1f}x")

    # 当 r=8 时
    r = 8
    compression_8 = (d*k) / (8*k + d*8)
    print(f"当 r=8 时: 压缩比 = {compression_8:.1f}x")

demonstrate_lora_concept()

LoRA 的数学推导

为什么低秩有效

LoRA 的有效性基于以下假设:

  1. 过参数化:预训练模型的权重通常具有较大的 intrinsic dimension,即存在一个较小的子空间包含了大部分有用的信息。

  2. 任务适应性:微调过程中权重的更新(ΔW)倾向于存在于一个低秩子空间中。

  3. 微软的论文表明,对于 Transformer 中的权重矩阵,ΔW 确实可以用很低的秩(如 r=4, 8)来近似,且信息损失很小。

python
def explain_lora_math():
    """解释 LoRA 的数学原理"""
    print("\n=== LoRA 数学原理 ===\n")

    print("全量微调:")
    print("  Y = X @ (W₀ + ΔW)")
    print("  需要更新 d×k 个参数\n")

    print("LoRA 微调:")
    print("  Y = X @ (W₀ + α * (B @ A))")
    print("  Y = X @ W₀ + α * (X @ B) @ A")
    print("  冻结 W₀,只更新 A 和 B")
    print("  需要更新 r×k + d×r 个参数\n")

    print("关键洞察:")
    print("  - 前向传播时,X @ B 可以预先计算")
    print("  - A 和 B 通常用随机初始化,然后训练")
    print("  - 缩放因子 α 通常设为 r 的某个倍数(如 r 或 2r)")

explain_lora_math()

LoRA 的实现细节

A 和 B 的初始化

LoRA 中 A 和 B 的初始化方式很重要:

  • A:随机初始化(通常用正态分布)
  • B:初始化为零(B = 0)

这样初始化确保在训练开始时,ΔW = BA = 0,即模型输出与原始模型完全相同。随着训练进行,ΔW 逐渐增加。

python
def initialize_lora_params(d, k, r, alpha=None):
    """初始化 LoRA 参数

    参数:
        d: 输出维度
        k: 输入维度
        r: 秩
        alpha: 缩放因子,默认设为 r
    """
    if alpha is None:
        alpha = r

    # A: (r, k) - 随机初始化
    A = np.random.randn(r, k).astype(np.float32) * 0.01

    # B: (d, r) - 零初始化
    B = np.zeros((d, r), dtype=np.float32)

    return A, B, alpha

# 示例
d, k, r = 1024, 1024, 8
A, B, alpha = initialize_lora_params(d, k, r)

print(f"A 形状: {A.shape}, 均值: {A.mean():.6f}")
print(f"B 形状: {B.shape}, 均值: {B.mean():.6f}")
print(f"α = {alpha}")

LoRA 与其他微调方法的对比

python
def compare_finetuning_methods():
    """对比不同微调方法"""
    print("\n=== 微调方法对比 ===\n")

    # 假设 GPT-2 中一个注意力层的参数
    d_model = 768
    d_ff = 3072  # FFN 中间层

    # 全量微调
    full_params = d_model * d_model + d_model * d_ff  # QKV + FFN
    print(f"全量微调: {full_params:,} 参数")

    # Adapter(假设 bottleneck = 64)
    adapter_bottleneck = 64
    adapter_params = d_model * adapter_bottleneck + adapter_bottleneck * d_model
    print(f"Adapter (bottleneck={adapter_bottleneck}): {adapter_params:,} 参数")

    # LoRA (r=8)
    r = 8
    lora_params = r * d_model + d_model * r  # A 和 B
    print(f"LoRA (r={r}): {lora_params:,} 参数")

    # LoRA vs 全量
    print(f"\nLoRA 节省: {(1 - lora_params/full_params)*100:.2f}%")

compare_finetuning_methods()

LoRA 适用哪些层

LoRA 最初主要应用于 Transformer 的自注意力层中的 QKV 投影和输出投影:

python
def show_lora_targets():
    """展示 LoRA 适用的层"""
    print("\n=== LoRA 适用层 ===\n")

    print("在 Transformer 中,LoRA 通常应用于:")
    print("  1. W_q: Query 投影矩阵")
    print("  2. W_k: Key 投影矩阵")
    print("  3. W_v: Value 投影矩阵")
    print("  4. W_o: Output 投影矩阵")
    print("  5. (可选) FFN 的 W1 和 W2\n")

    print("不去 적용:")
    print("  - LayerNorm")
    print("  - 偏置")
    print("  - 嵌入层")

show_lora_targets()

LoRA 的优势

python
def list_lora_advantages():
    """列出 LoRA 的优势"""
    print("\n=== LoRA 优势 ===\n")

    advantages = [
        ("参数效率", "只需训练约 0.1%-1% 的原始参数"),
        ("显存效率", "训练时只需存储 LoRA 参数和优化器状态"),
        ("可插拔", "可以随时切换不同的 LoRA 权重,相当于模型的不同技能"),
        ("部署友好", "多个 LoRA 可以共享原始模型,按需加载"),
        ("训练稳定", "从零初始化开始,不会破坏预训练权重"),
    ]

    for title, desc in advantages:
        print(f"✓ {title}: {desc}")

list_lora_advantages()

LoRA 的局限性

python
def list_lora_limitations():
    """列出 LoRA 的局限性"""
    print("\n=== LoRA 局限性 ===\n")

    limitations = [
        ("任务类型限制", "对于需要大幅改变模型行为的任务,LoRA 可能不如全量微调"),
        ("秩的选择", "r 太大则参数多,太小则表达能力不足,需要调参"),
        ("注意力层选择", "只微调注意力层可能错过 FFN 层学到的重要知识"),
    ]

    for title, desc in limitations:
        print(f"✗ {title}: {desc}")

list_lora_limitations()

常见问题

Q: 为什么 B 初始化为零?

A: B 初始化为零确保在训练开始时 ΔW = BA = 0,模型输出与原始完全相同。如果 B 随机初始化,模型一开始就会有大幅变化,可能导致训练不稳定。

python
def explain_b_zero_init():
    """解释为什么 B 初始化为零"""
    print("\n=== B 零初始化的原因 ===\n")

    print("如果 BA 初始化为非零矩阵 W_init:")
    print("  训练开始时,Y = X @ (W₀ + W_init)")
    print("  这相当于模型从 W₀ + W_init 开始,而不是从 W₀ 开始")
    print("  可能导致训练不稳定\n")

    print("如果 B=0 初始化:")
    print("  训练开始时,Y = X @ W₀ + 0 = X @ W₀")
    print("  模型输出与原始完全相同")
    print("  然后逐渐学习 ΔW = BA\n")

    print("这类似于从原始模型平滑过渡到微调模型")

explain_b_zero_init()

Q: α 的作用是什么?

A: α 是缩放因子,用于控制 LoRA 更新量的大小。通常设置为 r 或 2r。更新公式是:

Y = X @ W₀ + (α/r) * (X @ B @ A)

这样当 r 变化时,不需要重新调整学习率。

python
def explain_alpha():
    """解释 α 的作用"""
    print("\n=== α 缩放因子 ===\n")

    print("更新公式:")
    print("  Y = X @ (W₀ + (α/r) * BA)\n")

    print("作用:")
    print("  - 控制 LoRA 贡献的大小")
    print("  - 当 r 增加时,保持 (α/r) 的比例稳定")
    print("  - 常见设置:α = r 或 α = 2r")

explain_alpha()

理解 LoRA 的原理是掌握现代大模型微调技术的基础。在下一篇文章中,我们将使用 NumPy 实现 LoRA 的前向传播。

基于 MIT 许可发布