跳转到内容

逐元素运算

逐元素运算是NumPy最核心的特性之一。想象一下,当你需要对序列中每个token的embedding进行缩放时,或者当你需要计算两个向量的逐元素差值来衡量相似度时,又或者当你需要在残差连接中将上一层的输出与输入相加时,这些都需要用到逐元素运算。顾名思义,逐元素运算就是对数组中的每个元素独立进行相同的操作。在深度学习中,从激活函数的应用到损失函数的计算,从残差连接到LayerNorm,几乎所有的计算都离不开逐元素运算。NumPy的逐元素运算不仅语法简洁,而且底层使用高度优化的C代码,性能比Python循环快上百倍。

基本算术运算

NumPy支持标准的算术运算符,这些运算符在数组上执行的是逐元素运算:

python
import numpy as np

a = np.array([1, 2, 3, 4, 5])
b = np.array([10, 20, 30, 40, 50])

print(f"a + b = {a + b}")  # [11, 22, 33, 44, 55]
print(f"a - b = {a - b}")  # [-9, -18, -27, -36, -45]
print(f"a * b = {a * b}")  # [10, 40, 90, 160, 250]
print(f"a / b = {a / b}")  # [0.1, 0.1, 0.1, 0.1, 0.1]

需要特别注意的是,乘法 * 是逐元素乘法,不是矩阵乘法。矩阵乘法需要使用 @ 运算符或 np.matmul 函数。

标量与数组的运算

标量与数组的运算会自动将标量广播到数组的形状:

python
arr = np.array([1, 2, 3, 4, 5])

print(f"arr + 10 = {arr + 10}")  # [11, 12, 13, 14, 15]
print(f"arr * 2 = {arr * 2}")    # [2, 4, 6, 8, 10]
print(f"arr ** 2 = {arr ** 2}")  # [1, 4, 9, 16, 25]

这种广播特性使得标量运算变得极其自然,无需任何显式循环。

比较运算与布尔运算

比较运算返回布尔数组:

python
a = np.array([1, 2, 3, 4, 5])

print(f"a > 3: {a > 3}")  # [False False False  True  True]
print(f"a == 3: {a == 3}")  # [False False  True False False]
print(f"a != 3: {a != 3}")  # [True True False True True]

布尔运算使用 &(与)、|(或)、~(非):

python
a = np.array([1, 2, 3, 4, 5])

# 组合条件
print(f"(a > 2) & (a < 5): {(a > 2) & (a < 5)}")  # [False False  True  True False]
print(f"(a < 2) | (a > 4): {(a < 2) | (a > 4)}")  # [True False False False  True]
print(f"~(a > 3): {~(a > 3)}")  # [True True True False False]

注意必须使用位运算符 &, |, ~,而不是Python的 and, or, not

在LLM场景中的应用

残差连接

残差连接是Transformer中的核心组件,将输入与输出相加:output = input + sublayer(input)。逐元素加法是实现残差连接的基础:

python
batch_size = 4
seq_len = 512
hidden_dim = 768

# 模拟输入
input_tensor = np.random.randn(batch_size, seq_len, hidden_dim).astype(np.float32)

# 模拟经过某个子层的输出
sublayer_output = np.random.randn(batch_size, seq_len, hidden_dim).astype(np.float32)

# 残差连接
residual_output = input_tensor + sublayer_output
print(f"残差输出形状: {residual_output.shape}")

残差连接的作用是让梯度能够直接流过网络,解决深层网络的梯度消失问题。

LayerNorm中的均方差计算

LayerNorm对每个样本的每个位置进行归一化,计算均方误差(MSE)是其中的关键步骤:

python
def mean_squared_error(y_true, y_pred):
    """计算均方误差"""
    return np.mean((y_true - y_pred) ** 2)

# 模拟真实值和预测值
y_true = np.random.randn(32, 512, 768).astype(np.float32)
y_pred = np.random.randn(32, 512, 768).astype(np.float32)

mse = mean_squared_error(y_true, y_pred)
print(f"MSE: {mse:.6f}")

逐元素运算在这里发挥了重要作用:y_true - y_pred 计算逐元素差值,然后平方,再求平均。

激活函数

各种激活函数都是逐元素运算的典型应用:

python
def sigmoid(x):
    return 1 / (1 + np.exp(-np.clip(x, -500, 500)))

def relu(x):
    return np.maximum(0, x)

def gelu(x):
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

# 测试激活函数
x = np.array([-2, -1, 0, 1, 2])
print(f"Sigmoid: {sigmoid(x)}")
print(f"ReLU: {relu(x)}")
print(f"GELU: {gelu(x)}")

注意力权重的缩放

在注意力机制中,需要对注意力分数进行缩放:

python
batch_size = 4
seq_len = 512
head_dim = 64

# 模拟注意力分数
attention_scores = np.random.randn(batch_size, seq_len, seq_len).astype(np.float32) * 0.1

# 缩放:除以 sqrt(head_dim)
scaled_scores = attention_scores / np.sqrt(head_dim)
print(f"缩放后分数范围: [{scaled_scores.min():.4f}, {scaled_scores.max():.4f}]")

原地运算

NumPy支持原地修改数组的操作:

python
arr = np.array([1, 2, 3, 4, 5])
arr += 10
print(f"原地加法: {arr}")  # [11 12 13 14 15]

arr *= 2
print(f"原地乘法: {arr}")  # [22 24 26 28 30]

原地操作会直接修改原始数组,节省内存。

常见误区与注意事项

误区一:混淆逐元素乘法和矩阵乘法

python
A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])

# 逐元素乘法
print(f"逐元素乘法 A * B:\n{A * B}")

# 矩阵乘法
print(f"矩阵乘法 A @ B:\n{A @ B}")

这是两个完全不同的操作!

误区二:整数除法

python
arr = np.array([1, 2, 3, 4, 5])
print(f"整数数组 / 整数: {arr / 2}")  # float64
print(f"整数数组 // 整数: {arr // 2}")  # 整数除法

在Python 3中,/ 总是返回浮点数,// 才是整数除法。

误区三:复数运算

python
a = np.array([1+2j, 3+4j])
print(f"共轭: {np.conj(a)}")
print(f"绝对值: {np.abs(a)}")  # sqrt(1^2 + 2^2)

NumPy支持完整的复数运算。

底层原理

NumPy的逐元素运算之所以快,是因为底层使用了SIMD(Single Instruction Multiple Data)指令。CPU可以单条指令处理多个浮点数,这意味着一次加法运算可以同时处理多个元素。此外,NumPy的运算是在预分配的连续内存上进行的,避免了Python对象管理的开销。

小结

逐元素运算是NumPy最基础也最重要的运算类型。算术运算、比较运算、布尔运算都可以逐元素作用于数组。在LLM场景中,逐元素运算广泛用于:残差连接、LayerNorm、激活函数、注意力权重缩放等。理解逐元素运算与矩阵乘法的区别是避免bug的关键。

面试时需要能够解释逐元素乘法与矩阵乘法的区别,理解广播机制如何工作,以及注意整数除法和位运算符的使用。

基于 MIT 许可发布