跳转到内容

创建结构化数组

结构化数组是 NumPy 中一种强大的数据类型,它允许你在一个数组中存储异构的、命名字段的数据。普通 NumPy 数组要求所有元素类型相同,但结构化数组的每个元素可以包含多个不同类型的字段,每个字段都有唯一的名称。在处理语言模型数据时,结构化数组非常适合存储 token_ids、attention_mask、token_type_ids 等相关联的数据,将它们打包在一起便于管理和传递。理解结构化数组的创建和使用方法,可以让你的数据处理代码更加清晰和高效。

为什么使用结构化数组

在 LLM 数据处理中,我们经常需要处理这样一组相关联的数据:

  • token_ids:token 的整数 ID
  • attention_mask:注意力掩码(1 表示有效,0 表示填充)
  • token_type_ids:句子段 ID(用于 BERT 等模型)
  • special_tokens_mask:特殊 token 掩码

普通做法是分别用独立的数组存储这些数据,但结构化数组可以将它们打包成一个紧凑的数据结构,减少函数传递参数的数量,也便于批量处理。

python
import numpy as np

# 普通做法:多个独立数组
token_ids = np.array([464, 197, 35, 2057])
attention_mask = np.array([1, 1, 1, 1])
token_type_ids = np.array([0, 0, 0, 0])

# 结构化数组做法:将相关数据打包
data = np.array([(464, 1, 0), (197, 1, 0), (35, 1, 0), (2057, 1, 0)],
                dtype=[('token_id', np.int32),
                       ('attention_mask', np.int8),
                       ('token_type_id', np.int8)])
print(f"结构化数组:\n{data}")

定义结构化数据类型

创建结构化数组首先需要定义其数据类型(dtype)。dtype 描述了数组中每个字段的名称和数据类型:

python
# 定义 dtype:字段名 -> 数据类型
token_dtype = np.dtype([
    ('token_id', np.int32),      # token ID
    ('attention_mask', np.int8),  # 注意力掩码
    ('token_type_id', np.int8),  # token 类型 ID
    ('special_token_mask', np.int8)  # 特殊 token 掩码
])

# 查看 dtype
print(f"Token 数据类型:\n{token_dtype}")
print(f"字段名: {token_dtype.names}")  # ('token_id', 'attention_mask', ...)

创建结构化数组

有多种方式创建结构化数组:

方法一:使用 dtype 参数

python
# 创建结构化数组
tokens = np.array([(464, 1, 0, 0), (197, 1, 0, 0), (35, 1, 0, 0), (2057, 1, 0, 0)],
                  dtype=[('token_id', np.int32),
                         ('attention_mask', np.int8),
                         ('token_type_id', np.int8),
                         ('special_token_mask', np.int8)])

print(f"数组形状: {tokens.shape}")
print(f"数组 dtype: {tokens.dtype}")
print(f"数组内容:\n{tokens}")

方法二:使用字典 dtype

字段名到数据类型的字典也可以定义 dtype:

python
# 简化的字典 dtype(字段名会自动生成)
dict_dtype = np.dtype({
    'names': ['token_id', 'attention_mask', 'token_type_id'],
    'formats': ['i4', 'i1', 'i1']  # i4=int32, i1=int8
})

tokens = np.zeros(4, dtype=dict_dtype)
tokens['token_id'] = [464, 197, 35, 2057]
tokens['attention_mask'] = [1, 1, 1, 1]
tokens['token_type_id'] = [0, 0, 0, 0]

print(f"字典 dtype 创建的数组:\n{tokens}")

方法三:使用字段名简写

对于只需要字段名和类型的基本情况:

python
# 使用简化格式:'字段名:类型',多个字段用逗号分隔
simple_dtype = 'token_id:i4, attention_mask:i1, token_type_id:i1'

tokens = np.array([(464, 1, 0), (197, 1, 0), (35, 1, 0), (2057, 1, 0)],
                  dtype=simple_dtype)
print(f"简化格式创建:\n{tokens}")

支持的数据类型

结构化数组支持多种 NumPy 数据类型:

python
# 常见的数据类型
example_dtype = np.dtype([
    ('int8_field', np.int8),
    ('int16_field', np.int16),
    ('int32_field', np.int32),
    ('int64_field', np.int64),
    ('float32_field', np.float32),
    ('float64_field', np.float64),
    ('bool_field', np.bool_),
    ('bytes_field', 'S10'),  # 固定长度字节串
])

arr = np.zeros(3, dtype=example_dtype)
print(f"支持的类型示例:\n{arr.dtype}")

在LLM场景中的应用

存储训练样本

结构化数组非常适合存储单个训练样本的所有相关数据:

python
# 定义训练样本的结构
sample_dtype = np.dtype([
    ('input_ids', 'i4', (128,)),      # 输入 token IDs,最多 128 个
    ('attention_mask', 'i1', (128,)), # 注意力掩码
    ('token_type_ids', 'i1', (128,)), # token 类型 ID
    ('position_ids', 'i2', (128,)),   # 位置 ID
    ('labels', 'i4', (128,)),         # 标签(用于语言建模)
])

# 创建训练样本数组
num_samples = 1000
dataset = np.zeros(num_samples, dtype=sample_dtype)

# 填充数据
np.random.seed(42)
for i in range(num_samples):
    seq_len = np.random.randint(32, 128)
    dataset[i]['input_ids'] = np.random.randint(0, 50257, size=128)
    dataset[i]['attention_mask'] = np.concatenate([
        np.ones(seq_len),
        np.zeros(128 - seq_len)
    ]).astype(np.int8)
    dataset[i]['labels'] = np.concatenate([
        dataset[i]['input_ids'][1:],
        np.zeros(1)
    ]).astype(np.int32)

print(f"数据集形状: {dataset.shape}")
print(f"数据集 dtype: {dataset.dtype}")
print(f"第一个样本 input_ids 前10个: {dataset[0]['input_ids'][:10]}")

批量处理

结构化数组可以方便地进行批量处理:

python
def create_batch_from_samples(samples, batch_size):
    """从样本数组中提取 batch

    返回结构化数组的视图(不复制数据)
    """
    return samples[:batch_size]

# 示例
batch = create_batch_from_samples(dataset, batch_size=32)

print(f"Batch 形状: {batch.shape}")
print(f"Batch 共享内存: {np.shares_memory(dataset, batch)}")
print(f"Batch 第一条的 input_ids 前5个: {batch[0]['input_ids'][:5]}")

内存布局

结构化数组的内存布局非常紧凑,相关数据紧密存储在一起:

python
# 对比:普通做法 vs 结构化数组
seq_len = 128
batch_size = 32

# 普通做法:多个独立数组
token_ids = np.zeros((batch_size, seq_len), dtype=np.int32)
attention_mask = np.zeros((batch_size, seq_len), dtype=np.int8)
token_type_ids = np.zeros((batch_size, seq_len), dtype=np.int8)

total_memory_normal = token_ids.nbytes + attention_mask.nbytes + token_type_ids.nbytes
print(f"普通做法内存: {total_memory_normal / 1024:.1f} KB")

# 结构化数组
batch_dtype = np.dtype([
    ('token_ids', np.int32, (seq_len,)),
    ('attention_mask', np.int8, (seq_len,)),
    ('token_type_ids', np.int8, (seq_len,))
])
batch_structured = np.zeros(batch_size, dtype=batch_dtype)

total_memory_structured = batch_structured.nbytes
print(f"结构化数组内存: {total_memory_structured / 1024:.1f} KB")
print(f"内存节省: {(1 - total_memory_structured / total_memory_normal) * 100:.1f}%")

常见误区

误区一:忘记字段名必须唯一

字段名必须是唯一的字符串:

python
# 错误:重复字段名
try:
    dup_dtype = np.dtype([('field', np.int32), ('field', np.int32)])
except Exception as e:
    print(f"错误: {e}")

# 正确
correct_dtype = np.dtype([('field1', np.int32), ('field2', np.int32)])

误区二:混淆 dtype 的字段名和变量名

结构化数组的字段名是 dtype 的一部分,不是变量名:

python
arr = np.zeros(3, dtype=[('a', np.int32), ('b', np.float64)])

# 访问字段用字符串索引
print(arr['a'])  # 正确
# print(arr.a)  # 错误!NumPy 结构化数组不支持属性访问

误区三:假设结构化数组支持所有普通数组操作

某些操作在结构化数组上行为不同:

python
arr = np.array([(1, 2.0), (3, 4.0)], dtype=[('x', np.int32), ('y', np.float64)])

# reshape 可能不按预期工作
try:
    reshaped = arr.reshape(1, 2)  # 可能产生意外结果
except Exception as e:
    print(f"reshape 可能有问题: {e}")

API 总结

创建方式示例
元组列表np.array([(1, 2), (3, 4)], dtype=[('a', np.int32), ('b', np.int32)])
字典 dtypenp.dtype({'names': ['a', 'b'], 'formats': ['i4', 'i4']})
简化字符串np.dtype('a:i4, b:i4')
嵌套结构np.dtype([('x', np.int32), ('y', [('a', np.int32), ('b', np.int32)])])

结构化数组是处理 LLM 数据的利器,它将相关联的字段打包在一起,既节省内存,又便于管理。掌握结构化数组的创建和使用,能让你的数据处理代码更加优雅高效。

基于 MIT 许可发布