跳转到内容

文本文件输入输出

NumPy 提供了便捷的函数来读写文本文件,这在处理 LLM 相关数据时非常有用。词汇表文件、配置文件、训练日志等通常以文本格式存储。np.loadtxtnp.savetxt 是处理这类数据的利器,它们可以高效地读写逗号分隔值(CSV)或其他分隔符格式的文件。本篇文章介绍这两个函数的使用方法、常见陷阱,以及在 LLM 场景中的实际应用。

np.loadtxt:读取文本文件

np.loadtxt 是 NumPy 中最常用的文本文件读取函数,它可以自动解析文本文件并转换为 NumPy 数组:

python
import numpy as np

# 基本用法:读取 CSV 文件
# 假设文件格式:
# 1.0, 2.0, 3.0
# 4.0, 5.0, 6.0

data = np.loadtxt('data.csv', delimiter=',')
print(f"读取的数据:\n{data}")

指定数据类型

默认情况下,loadtxt 返回 float64 数组。可以通过 dtype 参数指定其他类型:

python
# 读取为整数
int_data = np.loadtxt('int_data.csv', delimiter=',', dtype=np.int32)
print(f"整数数据 dtype: {int_data.dtype}")

# 读取为字符串
# str_data = np.loadtxt('str_data.csv', delimiter=',', dtype=str)

处理表头和注释

实际数据文件通常包含表头或注释行:

python
# skiprows: 跳过前 N 行
data = np.loadtxt('data_with_header.csv', delimiter=',', skiprows=1)

# comments: 指定注释字符(默认为 '#')
data = np.loadtxt('data_with_comments.txt', delimiter=',', comments='#')

处理缺失值

loadtxt 可以处理缺失值:

python
# 缺失值默认用 nan 表示
data = np.loadtxt('data_with_missing.csv', delimiter=',')

# 或者指定缺失值的表示方式
data = np.loadtxt('data_with_missing.csv', delimiter=',', 
                  missing_values=['NA', 'N/A', ''])

np.savetxt:保存文本文件

np.savetxt 将数组保存为文本文件:

python
# 基本用法:保存为 CSV
arr = np.array([[1.0, 2.0, 3.0],
                [4.0, 5.0, 6.0]])
np.savetxt('output.csv', arr, delimiter=',')

格式化输出

可以指定输出格式:

python
# 使用 fmt 参数控制格式
# %.2f: 保留两位小数
np.savetxt('formatted.csv', arr, delimiter=',', fmt='%.2f')

# 使用科学计数法
np.savetxt('scientific.csv', arr, delimiter=',', fmt='%.2e')

# 混合格式
np.savetxt('mixed.csv', arr, delimiter=',', fmt=['%.2f', '%.3f', '%.4f'])

添加表头和注释

python
# header: 添加表头
np.savetxt('with_header.csv', arr, delimiter=',',
           header='col1, col2, col3', comments='')

# footer: 添加脚注
np.savetxt('with_footer.csv', arr, delimiter=',',
           footer='Generated by NumPy')

在LLM场景中的应用

保存和加载词汇表

词汇表文件通常是多行文本,每行一个 token。可以用 loadtxt 读取:

python
def load_vocabulary(filepath):
    """加载词汇表文件

    词汇表格式:每行一个 token(字符串)
    """
    # 方法1:直接读取
    with open(filepath, 'r', encoding='utf-8') as f:
        vocab = [line.strip() for line in f]

    # 方法2:使用 loadtxt(返回字符串数组)
    vocab = np.loadtxt(filepath, dtype=str, delimiter='\n')
    return vocab

# 保存词汇表
def save_vocabulary(vocab, filepath):
    """保存词汇表到文件"""
    np.savetxt(filepath, vocab, fmt='%s', delimiter='\n')

# 示例词汇表
vocab = ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', 'the', 'cat', 'dog']
save_vocabulary(vocab, 'vocab.txt')

# 加载词汇表
loaded_vocab = load_vocabulary('vocab.txt')
print(f"加载的词汇表: {loaded_vocab}")
print(f"词汇表大小: {len(loaded_vocab)}")

保存和加载配置文件

训练超参数或模型配置可以用文本格式保存:

python
def save_config(filepath, config_dict):
    """保存配置为文本文件

    格式: key=value
    """
    with open(filepath, 'w') as f:
        for key, value in config_dict.items():
            f.write(f"{key}={value}\n")

def load_config(filepath):
    """从文本文件加载配置"""
    config = {}
    with open(filepath, 'r') as f:
        for line in f:
            line = line.strip()
            if line and '=' in line:
                key, value = line.split('=', 1)
                config[key] = value
    return config

# 示例配置
config = {
    'vocab_size': 50257,
    'hidden_size': 768,
    'num_layers': 12,
    'num_heads': 12,
    'intermediate_size': 3072,
    'max_position_embeddings': 1024,
    'learning_rate': 5e-5,
    'batch_size': 32,
}

save_config('training_config.txt', config)
loaded_config = load_config('training_config.txt')
print(f"加载的配置: {loaded_config}")

保存训练日志

训练过程中的损失值和指标可以保存为 CSV 格式:

python
def save_training_log(log_path, epochs, losses, accuracies):
    """保存训练日志

    参数:
        log_path: 日志文件路径
        epochs: epoch 列表
        losses: 损失列表
        accuracies: 准确率列表
    """
    header = 'epoch,loss,accuracy'
    data = np.column_stack([epochs, losses, accuracies])
    np.savetxt(log_path, data, delimiter=',', header=header, comments='', fmt=['%d', '%.6f', '%.6f'])

# 示例训练日志
np.random.seed(42)
epochs = np.arange(1, 11)
losses = 2.5 - 0.2 * epochs + np.random.randn(10) * 0.1
accuracies = 0.5 + 0.05 * epochs + np.random.randn(10) * 0.02

save_training_log('training_log.csv', epochs, losses, accuracies)

# 读取并绘图
log_data = np.loadtxt('training_log.csv', delimiter=',', skiprows=1)
print(f"训练日志:\n{log_data}")

高级用法

读取不规则分隔符文件

python
# 使用不同的分隔符
data = np.loadtxt('data.tsv', delimiter='\t')  # Tab 分隔

# 使用正则表达式分隔(需要更多处理)
# 对于复杂格式,可能需要使用 np.genfromtxt
data = np.genfromtxt('complex_data.csv', delimiter=',')

使用 np.genfromtxt 处理复杂情况

np.genfromtxt 是更强大的文本文件读取函数,可以处理缺失值、混合类型等复杂情况:

python
# 处理缺失值
data = np.genfromtxt('data_with_missing.csv', delimiter=',', 
                     missing_values='NA', filling_values=0.0)

# 自动检测数据类型
data = np.genfromtxt('mixed_data.csv', delimiter=',', 
                     dtype=None, encoding='utf-8')

读取部分数据

对于大型文件,可能只需要读取部分数据:

python
# 只读取前 N 行
data = np.loadtxt('large_file.csv', delimiter=',', max_rows=1000)

# 读取指定行(使用 skiprows 和 max_rows 组合)
data = np.loadtxt('large_file.csv', delimiter=',', 
                  skiprows=100, max_rows=50)

常见误区

误区一:忘记指定 delimiter

如果文件使用非逗号分隔,务必指定正确的 delimiter:

python
# Tab 分隔文件
data = np.loadtxt('data.tsv', delimiter='\t')

# 空格分隔文件
data = np.loadtxt('data.txt', delimiter=' ')

误区二:大文件使用 loadtxt 效率低

对于大型文本文件,loadtxt 可能很慢,因为需要逐行解析。考虑使用二进制格式(np.save)或内存映射(np.memmap):

python
# 对于大型数组,使用二进制格式更高效
# arr = np.load('data.npy')  # 比 loadtxt 快很多

误区三:默认使用科学计数法导致精度问题

默认格式可能使用科学计数法,对于精确数据可能不合适:

python
# 指定固定小数位数
np.savetxt('precise.csv', arr, delimiter=',', fmt='%.10f')

# 或使用 '%s' 保存完整精度(对于浮点数)

误区四:编码问题

处理包含非 ASCII 字符的文件时,需要指定编码:

python
# 读取 UTF-8 编码的文件
vocab = np.loadtxt('vocab.txt', dtype=str, delimiter='\n', encoding='utf-8')

性能对比

文本文件 IO 比二进制 IO 慢很多:

python
import time

arr = np.random.randn(10000, 100)

# 测量 savetxt 性能
start = time.time()
np.savetxt('text_data.csv', arr, delimiter=',', fmt='%.6f')
text_time = time.time() - start

# 测量 save 性能
start = time.time()
np.save('binary_data.npy', arr)
binary_time = time.time() - start

import os
text_size = os.path.getsize('text_data.csv')
binary_size = os.path.getsize('binary_data.npy')

print(f"文本文件大小: {text_size / 1024 / 1024:.1f} MB, 耗时: {text_time:.2f}s")
print(f"二进制文件大小: {binary_size / 1024 / 1024:.1f} MB, 耗时: {binary_time:.4f}s")

API 总结

函数用途关键参数
np.loadtxt(fname)读取文本文件delimiter, dtype, skiprows, max_rows
np.savetxt(fname, X)保存文本文件delimiter, fmt, header, comments
np.genfromtxt(fname)读取复杂文本delimiter, dtype, missing_values, filling_values

np.loadtxtnp.savetxt 是处理 LLM 相关文本数据的实用工具,掌握它们的用法可以让你更高效地管理词汇表、配置文件和训练日志。

基于 MIT 许可发布