跳转到内容

缩放点积注意力

缩放点积注意力(Scaled Dot-Product Attention)是 Transformer 的核心机制。给定 Query、Key、Value 三个矩阵,注意力通过计算 Q 与 K 的点积得到注意力分数,然后使用 softmax 归一化,最后乘以 V 得到输出。本篇文章详细介绍缩放点积注意力的原理和纯 NumPy 实现,包括多头注意力的实现。

注意力机制的基本原理

Attention(Q, K, V) = softmax(QK^T / √d_k) × V

其中 d_k 是 Key 向量的维度。缩放因子 √d_k 用于防止点积值过大导致 softmax 梯度消失。

python
import numpy as np

def softmax(x, axis=-1):
    """Softmax 函数

    softmax(x_i) = exp(x_i) / sum(exp(x_j))
    """
    exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exp_x / np.sum(exp_x, axis=axis, keepdims=True)

def scaled_dot_product_attention(Q, K, V, mask=None):
    """缩放点积注意力

    参数:
        Q: Query 矩阵 (..., seq_len, d_k)
        K: Key 矩阵 (..., seq_len, d_k)
        V: Value 矩阵 (..., seq_len, d_v)
        mask: 掩码 (..., seq_len, seq_len),True 表示需要 mask
    返回:
        output: 注意力输出 (..., seq_len, d_v)
        attention_weights: 注意力权重 (..., seq_len, seq_len)
    """
    d_k = Q.shape[-1]

    # 计算 QK^T
    scores = np.einsum('...nd,...md->...nm', Q, K)

    # 缩放
    scores = scores / np.sqrt(d_k)

    # 应用掩码(如果提供)
    if mask is not None:
        scores = np.where(mask, -1e9, scores)

    # Softmax 归一化
    attention_weights = softmax(scores, axis=-1)

    # 乘以 V
    output = np.einsum('...nm,...md->...nd', attention_weights, V)

    return output, attention_weights

# 示例
np.random.seed(42)
seq_len = 5
d_k = 64

Q = np.random.randn(seq_len, d_k).astype(np.float32)
K = np.random.randn(seq_len, d_k).astype(np.float32)
V = np.random.randn(seq_len, d_k).astype(np.float32)

output, attention_weights = scaled_dot_product_attention(Q, K, V)
print(f"Q 形状: {Q.shape}")
print(f"K 形状: {K.shape}")
print(f"V 形状: {V.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {attention_weights.shape}")

多头注意力的实现

python
class MultiHeadAttention:
    """多头注意力机制"""

    def __init__(self, hidden_size, num_heads):
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        assert hidden_size % num_heads == 0

        # QKV 投影权重
        self.W_q = np.random.randn(hidden_size, hidden_size).astype(np.float32) * 0.02
        self.W_k = np.random.randn(hidden_size, hidden_size).astype(np.float32) * 0.02
        self.W_v = np.random.randn(hidden_size, hidden_size).astype(np.float32) * 0.02

        # 输出投影权重
        self.W_o = np.random.randn(hidden_size, hidden_size).astype(np.float32) * 0.02

    def split_heads(self, X, batch_size):
        """将 hidden_dim 分割为 num_heads 个头

        X: (batch, seq_len, hidden_size)
        返回: (batch, num_heads, seq_len, head_dim)
        """
        X = X.reshape(batch_size, -1, self.num_heads, self.head_dim)
        return X.transpose(0, 2, 1, 3)

    def forward(self, Q, K, V, mask=None):
        """前向传播

        参数:
            Q, K, V: (batch_size, seq_len, hidden_size)
            mask: (batch_size, seq_len, seq_len)
        返回:
            output: (batch_size, seq_len, hidden_size)
            attention_weights: (batch_size, num_heads, seq_len, seq_len)
        """
        batch_size = Q.shape[0]
        seq_len = Q.shape[1]

        # QKV 投影
        Q = Q @ self.W_q  # (batch, seq, hidden)
        K = K @ self.W_k
        V = V @ self.W_v

        # 分割为多头
        Q = self.split_heads(Q, batch_size)  # (batch, heads, seq, head_dim)
        K = self.split_heads(K, batch_size)
        V = self.split_heads(V, batch_size)

        # 计算注意力
        d_k = self.head_dim
        scores = np.einsum('bhnd,bhmd->bhnm', Q, K)  # (batch, heads, seq, seq)
        scores = scores / np.sqrt(d_k)

        # 应用掩码
        if mask is not None:
            # 调整掩码维度以匹配
            mask = mask[:, np.newaxis, :, :]  # (batch, 1, seq, seq)
            scores = np.where(mask, -1e9, scores)

        # Softmax
        attention_weights = softmax(scores, axis=-1)

        # 乘以 V
        context = np.einsum('bhnm,bhmd->bhnd', attention_weights, V)  # (batch, heads, seq, head_dim)

        # 合并多头
        context = context.transpose(0, 2, 1, 3)  # (batch, seq, heads, head_dim)
        context = context.reshape(batch_size, seq_len, self.hidden_size)  # (batch, seq, hidden)

        # 输出投影
        output = context @ self.W_o

        return output, attention_weights

# 示例
np.random.seed(42)
batch_size = 2
seq_len = 10
hidden_size = 768
num_heads = 12

mha = MultiHeadAttention(hidden_size, num_heads)
X = np.random.randn(batch_size, seq_len, hidden_size).astype(np.float32)

output, attention_weights = mha.forward(X, X, X)
print(f"输入形状: {X.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {attention_weights.shape}")

简化版单头注意力

为了便于理解,这里提供简化版的实现:

python
class SimpleAttention:
    """简化的注意力机制(单头)"""

    def __init__(self, d_model):
        self.d_model = d_model

        # 简化的 QKV 投影
        self.W_q = np.random.randn(d_model, d_model).astype(np.float32) * 0.02
        self.W_k = np.random.randn(d_model, d_model).astype(np.float32) * 0.02
        self.W_v = np.random.randn(d_model, d_model).astype(np.float32) * 0.02

    def forward(self, X, mask=None):
        """前向传播

        X: (batch_size, seq_len, d_model)
        """
        batch_size, seq_len, _ = X.shape

        # QKV 投影
        Q = X @ self.W_q
        K = X @ self.W_k
        V = X @ self.W_v

        # 计算注意力分数
        scores = Q @ K.transpose(0, 2, 1)  # (batch, seq, seq)
        scores = scores / np.sqrt(self.d_model)

        # 应用掩码
        if mask is not None:
            scores = np.where(mask, -1e9, scores)

        # Softmax
        attention_weights = softmax(scores, axis=-1)

        # 加权求和
        output = attention_weights @ V

        return output, attention_weights

# 示例
simple_attn = SimpleAttention(d_model=512)
X_simple = np.random.randn(2, 10, 512).astype(np.float32)
output, weights = simple_attn.forward(X_simple)
print(f"\n简化注意力:")
print(f"输入形状: {X_simple.shape}")
print(f"输出形状: {output.shape}")
print(f"注意力权重形状: {weights.shape}")

注意力权重的可视化

注意力权重展示了每个位置对其他位置的"关注程度":

python
def visualize_attention(attention_weights, tokens, head_idx=0):
    """可视化注意力权重

    参数:
        attention_weights: (batch, heads, seq, seq) 或 (batch, seq, seq)
        tokens: token 列表
        head_idx: 要可视化的头索引
    """
    if len(attention_weights.shape) == 4:
        # 多头注意力,选择指定的头
        weights = attention_weights[0, head_idx]  # (seq, seq)
    else:
        weights = attention_weights[0]

    seq_len = len(tokens)
    print(f"\n=== 注意力权重可视化 (Head {head_idx}) ===\n")

    # 打印表头
    header = "         " + "".join([f"{t:>8}" for t in tokens[:seq_len]])
    print(header)
    print("-" * len(header))

    # 打印矩阵
    for i, token in enumerate(tokens[:seq_len]):
        row = f"{token:>8} |"
        for j in range(seq_len):
            w = weights[i, j]
            char = '*' if w > 0.5 else ('+' if w > 0.2 else ('.' if w > 0.05 else ' '))
            row += f"{char:>8}"
        print(row)

# 示例
simple_attn = SimpleAttention(d_model=64)
tokens = ['[CLS]', 'The', 'cat', 'eats', 'the', 'fish', '[SEP]']
X_demo = np.random.randn(1, len(tokens), 64).astype(np.float32)
_, attention_weights = simple_attn.forward(X_demo)

# 找出注意力最集中的位置
print(f"注意力权重范围: [{attention_weights.min():.3f}, {attention_weights.max():.3f}")
print(f"每行注意力之和(应为1): {attention_weights[0].sum(axis=1)}")

注意力机制的特性

自注意力(Self-Attention)

自注意力是指 Q、K、V 都来自同一个输入序列:

python
def self_attention(X, mask=None):
    """自注意力"""
    return scaled_dot_product_attention(X, X, X, mask)

# 示例
X = np.random.randn(2, 5, 64).astype(np.float32)
output, _ = self_attention(X)
print(f"自注意力输出形状: {output.shape}")

交叉注意力(Cross-Attention)

交叉注意力是指 Q 来自一个序列,K、V 来自另一个序列:

python
def cross_attention(Q, K, V, mask=None):
    """交叉注意力

    例如:解码器关注编码器的输出
    Q: 解码器隐藏状态
    K, V: 编码器隐藏状态
    """
    return scaled_dot_product_attention(Q, K, V, mask)

# 示例
encoder_output = np.random.randn(2, 10, 64).astype(np.float32)
decoder_hidden = np.random.randn(2, 5, 64).astype(np.float32)
output, _ = cross_attention(decoder_hidden, encoder_output, encoder_output)
print(f"交叉注意力输出形状: {output.shape}")

数值稳定性问题

softmax 在输入值很大时可能溢出:

python
def stable_attention(Q, K, V):
    """数值稳定的注意力实现

    使用 log-sum-exp 技巧
    """
    d_k = Q.shape[-1]
    scores = Q @ K.transpose(0, 2, 1)
    scores = scores / np.sqrt(d_k)

    # 减去最大值提高数值稳定性
    scores_max = scores.max(axis=-1, keepdims=True)
    scores = scores - scores_max

    exp_scores = np.exp(scores)
    attention_weights = exp_scores / exp_scores.sum(axis=-1, keepdims=True)

    return attention_weights @ V

# 测试数值稳定性
Q = np.random.randn(2, 10, 64).astype(np.float32) * 10  # 大值
K = np.random.randn(2, 10, 64).astype(np.float32) * 10
V = np.random.randn(2, 10, 64).astype(np.float32)

output = stable_attention(Q, K, V)
print(f"数值稳定注意力输出形状: {output.shape}")
print(f"输出值范围: [{output.min():.4f}, {output.max():.4f}]")

常见误区

误区一:忘记缩放因子

不缩放会导致 softmax 梯度消失:

python
# 错误:忘记缩放
scores = Q @ K.transpose(0, 2, 1)  # d_k 大时,scores 可能很大
# softmax 会变成接近 one-hot

# 正确:缩放
scores = scores / np.sqrt(d_k)

误区二:混淆注意力权重的维度

注意力权重的形状是 (batch, heads, seq, seq) 或 (batch, seq, seq):

python
# 多头注意力
# attention_weights: (batch, heads, query_seq, key_seq)

# 自注意力时,query_seq == key_seq == seq_len
# 但维度不能混淆

误区三:掩码使用不当

掩码应该将需要隐藏的位置设为 True 或 -inf:

python
# 正确的掩码方式
mask = np.triu(np.ones((seq_len, seq_len)), k=1).astype(bool)  # 上三角为 True
scores = np.where(mask, -1e9, scores)

API 总结

函数/类描述
softmax(x, axis)Softmax 函数
scaled_dot_product_attention(Q, K, V)缩放点积注意力
MultiHeadAttention多头注意力
SimpleAttention简化版单头注意力

缩放点积注意力是 Transformer 的核心。理解其原理和实现,对于掌握现代语言模型至关重要。

基于 MIT 许可发布