跳转到内容

困惑度计算

困惑度(Perplexity,PPL)是语言模型最常用的评估指标之一。它衡量模型对序列的预测能力,困惑度越低表示模型越好。本篇文章详细介绍困惑度的计算原理和纯 NumPy 实现。

困惑度的定义

困惑度定义为:

PPL = exp(-(1/N) * Σ log P(x_i))

或者等价于:

PPL = exp(average_negative_log_likelihood)

其中 N 是 token 数量,log P(x_i) 是模型对第 i 个 token 预测的对数概率。

python
import numpy as np

def compute_perplexity(log_likelihoods):
    """计算困惑度

    参数:
        log_likelihoods: 每个 token 的对数似然,形状 (N,)
    返回:
        perplexity: 困惑度
    """
    # 平均负对数似然
    avg_nll = -np.mean(log_likelihoods)

    # 困惑度
    perplexity = np.exp(avg_nll)

    return perplexity

# 示例
log_probs = np.array([-0.5, -1.0, -0.8, -1.2, -0.6])
ppl = compute_perplexity(log_probs)

print(f"对数似然: {log_probs}")
print(f"平均负对数似然: {-np.mean(log_probs):.4f}")
print(f"困惑度: {ppl:.4f}")

交叉熵与困惑度的关系

在训练语言模型时,我们通常最小化交叉熵损失。困惑度是交叉熵的指数形式:

python
def cross_entropy_to_perplexity(cross_entropy):
    """交叉熵转困惑度"""
    return np.exp(cross_entropy)

def perplexity_to_cross_entropy(perplexity):
    """困惑度转交叉熵"""
    return np.log(perplexity)

# 示例
cross_entropy = 3.5
ppl = cross_entropy_to_perplexity(cross_entropy)
print(f"交叉熵={cross_entropy:.2f} -> 困惑度={ppl:.2f}")

ppl = 33.0
ce = perplexity_to_cross_entropy(ppl)
print(f"困惑度={ppl:.2f} -> 交叉熵={ce:.2f}")

分段困惑度

对于长序列,可以计算分段困惑度:

python
def compute_segment_perplexity(log_likelihoods, segment_length=128):
    """计算分段困惑度

    参数:
        log_likelihoods: 对数似然数组
        segment_length: 分段长度
    返回:
        segment_ppls: 每段的困惑度
    """
    n = len(log_likelihoods)
    segment_ppls = []

    for i in range(0, n, segment_length):
        segment = log_likelihoods[i:i+segment_length]
        ppl = compute_perplexity(segment)
        segment_ppls.append(ppl)

    return np.array(segment_ppls)

# 示例
log_probs = np.random.randn(512) * 2 - 1  # 模拟对数似然
segment_ppls = compute_segment_perplexity(log_probs, segment_length=128)

print("分段困惑度:")
for i, ppl in enumerate(segment_ppls):
    print(f"  段 {i+1}: {ppl:.2f}")

批量困惑度计算

python
def compute_batch_perplexity(log_likelihoods, attention_mask=None):
    """计算批量困惑度

    参数:
        log_likelihoods: (batch, seq_len) 每个 token 的对数似然
        attention_mask: (batch, seq_len) 有效位置掩码
    返回:
        perplexity: 批量平均困惑度
    """
    if attention_mask is not None:
        # 只考虑有效位置
        masked_logits = log_likelihoods * attention_mask
        total_tokens = attention_mask.sum()
        avg_nll = -masked_logits.sum() / total_tokens
    else:
        avg_nll = -np.mean(log_likelihoods)

    return np.exp(avg_nll)

# 示例
batch_size = 8
seq_len = 128

log_probs = np.random.randn(batch_size, seq_len) * 2 - 1
mask = np.random.randint(0, 2, size=(batch_size, seq_len))

ppl = compute_batch_perplexity(log_probs, mask)
print(f"批量困惑度: {ppl:.2f}")

实际应用:评估语言模型

python
def evaluate_model_perplexity(model, dataset):
    """评估模型在数据集上的困惑度

    参数:
        model: 语言模型(有 forward 方法)
        dataset: 数据集(NumPy 数组)
    返回:
        avg_ppl: 平均困惑度
    """
    total_loss = 0.0
    total_tokens = 0

    for i in range(len(dataset)):
        input_ids = dataset[i]

        # 获取模型预测
        logits = model.forward(input_ids)

        # 计算负对数似然
        loss = compute_negative_log_likelihood(logits, input_ids)
        total_loss += loss
        total_tokens += len(input_ids) - 1  # 预测下一个 token

    avg_loss = total_loss / total_tokens
    perplexity = np.exp(avg_loss)

    return perplexity

def compute_negative_log_likelihood(logits, target_ids):
    """计算负对数似然

    参数:
        logits: (seq_len, vocab_size) 模型输出
        target_ids: (seq_len,) 目标 token IDs
    返回:
        loss: 负对数似然
    """
    # 移位:预测第 i+1 个 token 时使用第 i 个 token 的 logits
    target_logits = logits[:-1]
    target_tokens = target_ids[1:]

    # 计算交叉熵
    log_probs = target_logits - np.log(np.sum(np.exp(target_logits - np.max(target_logits, axis=-1, keepdims=True)), axis=-1, keepdims=True))
    loss = -log_probs[np.arange(len(target_tokens)), target_tokens]

    return np.sum(loss)

# 示例
print("困惑度评估示例完成")

常见问题与处理

数值稳定性

python
def stable_perplexity(log_likelihoods, eps=1e-9):
    """数值稳定的困惑度计算

    当对数似然非常负时,exp 可能下溢
    """
    avg_nll = -np.mean(log_likelihoods)

    # 限制范围避免溢出
    avg_nll = min(avg_nll, 700)  # exp(700) 约等于 inf

    return np.exp(avg_nll)

# 测试
log_probs = np.array([-100, -200, -300, -100])
ppl = stable_perplexity(log_probs)
print(f"数值稳定的困惑度: {ppl:.2f}")

困惑度的解释

python
def interpret_perplexity(ppl):
    """解释困惑度"""
    print(f"\n=== 困惑度解释 ===")
    print(f"困惑度 = {ppl:.2f}\n")

    if ppl < 10:
        print("非常低的困惑度,模型预测非常准确")
    elif ppl < 30:
        print("较低的困惑度,模型表现良好")
    elif ppl < 100:
        print("中等的困惑度,模型表现一般")
    elif ppl < 300:
        print("较高的困惑度,模型预测较困难")
    else:
        print("非常高的困惑度,模型预测几乎随机")

interpret_perplexity(25.0)

困惑度是评估语言模型的核心指标。掌握其计算方法对于正确评估模型性能非常重要。

基于 MIT 许可发布