主题
文本文件输入输出
NumPy 提供了便捷的函数来读写文本文件,这在处理 LLM 相关数据时非常有用。词汇表文件、配置文件、训练日志等通常以文本格式存储。np.loadtxt 和 np.savetxt 是处理这类数据的利器,它们可以高效地读写逗号分隔值(CSV)或其他分隔符格式的文件。本篇文章介绍这两个函数的使用方法、常见陷阱,以及在 LLM 场景中的实际应用。
np.loadtxt:读取文本文件
np.loadtxt 是 NumPy 中最常用的文本文件读取函数,它可以自动解析文本文件并转换为 NumPy 数组:
python
import numpy as np
# 基本用法:读取 CSV 文件
# 假设文件格式:
# 1.0, 2.0, 3.0
# 4.0, 5.0, 6.0
data = np.loadtxt('data.csv', delimiter=',')
print(f"读取的数据:\n{data}")指定数据类型
默认情况下,loadtxt 返回 float64 数组。可以通过 dtype 参数指定其他类型:
python
# 读取为整数
int_data = np.loadtxt('int_data.csv', delimiter=',', dtype=np.int32)
print(f"整数数据 dtype: {int_data.dtype}")
# 读取为字符串
# str_data = np.loadtxt('str_data.csv', delimiter=',', dtype=str)处理表头和注释
实际数据文件通常包含表头或注释行:
python
# skiprows: 跳过前 N 行
data = np.loadtxt('data_with_header.csv', delimiter=',', skiprows=1)
# comments: 指定注释字符(默认为 '#')
data = np.loadtxt('data_with_comments.txt', delimiter=',', comments='#')处理缺失值
loadtxt 可以处理缺失值:
python
# 缺失值默认用 nan 表示
data = np.loadtxt('data_with_missing.csv', delimiter=',')
# 或者指定缺失值的表示方式
data = np.loadtxt('data_with_missing.csv', delimiter=',',
missing_values=['NA', 'N/A', ''])np.savetxt:保存文本文件
np.savetxt 将数组保存为文本文件:
python
# 基本用法:保存为 CSV
arr = np.array([[1.0, 2.0, 3.0],
[4.0, 5.0, 6.0]])
np.savetxt('output.csv', arr, delimiter=',')格式化输出
可以指定输出格式:
python
# 使用 fmt 参数控制格式
# %.2f: 保留两位小数
np.savetxt('formatted.csv', arr, delimiter=',', fmt='%.2f')
# 使用科学计数法
np.savetxt('scientific.csv', arr, delimiter=',', fmt='%.2e')
# 混合格式
np.savetxt('mixed.csv', arr, delimiter=',', fmt=['%.2f', '%.3f', '%.4f'])添加表头和注释
python
# header: 添加表头
np.savetxt('with_header.csv', arr, delimiter=',',
header='col1, col2, col3', comments='')
# footer: 添加脚注
np.savetxt('with_footer.csv', arr, delimiter=',',
footer='Generated by NumPy')在LLM场景中的应用
保存和加载词汇表
词汇表文件通常是多行文本,每行一个 token。可以用 loadtxt 读取:
python
def load_vocabulary(filepath):
"""加载词汇表文件
词汇表格式:每行一个 token(字符串)
"""
# 方法1:直接读取
with open(filepath, 'r', encoding='utf-8') as f:
vocab = [line.strip() for line in f]
# 方法2:使用 loadtxt(返回字符串数组)
vocab = np.loadtxt(filepath, dtype=str, delimiter='\n')
return vocab
# 保存词汇表
def save_vocabulary(vocab, filepath):
"""保存词汇表到文件"""
np.savetxt(filepath, vocab, fmt='%s', delimiter='\n')
# 示例词汇表
vocab = ['[PAD]', '[UNK]', '[CLS]', '[SEP]', '[MASK]', 'the', 'cat', 'dog']
save_vocabulary(vocab, 'vocab.txt')
# 加载词汇表
loaded_vocab = load_vocabulary('vocab.txt')
print(f"加载的词汇表: {loaded_vocab}")
print(f"词汇表大小: {len(loaded_vocab)}")保存和加载配置文件
训练超参数或模型配置可以用文本格式保存:
python
def save_config(filepath, config_dict):
"""保存配置为文本文件
格式: key=value
"""
with open(filepath, 'w') as f:
for key, value in config_dict.items():
f.write(f"{key}={value}\n")
def load_config(filepath):
"""从文本文件加载配置"""
config = {}
with open(filepath, 'r') as f:
for line in f:
line = line.strip()
if line and '=' in line:
key, value = line.split('=', 1)
config[key] = value
return config
# 示例配置
config = {
'vocab_size': 50257,
'hidden_size': 768,
'num_layers': 12,
'num_heads': 12,
'intermediate_size': 3072,
'max_position_embeddings': 1024,
'learning_rate': 5e-5,
'batch_size': 32,
}
save_config('training_config.txt', config)
loaded_config = load_config('training_config.txt')
print(f"加载的配置: {loaded_config}")保存训练日志
训练过程中的损失值和指标可以保存为 CSV 格式:
python
def save_training_log(log_path, epochs, losses, accuracies):
"""保存训练日志
参数:
log_path: 日志文件路径
epochs: epoch 列表
losses: 损失列表
accuracies: 准确率列表
"""
header = 'epoch,loss,accuracy'
data = np.column_stack([epochs, losses, accuracies])
np.savetxt(log_path, data, delimiter=',', header=header, comments='', fmt=['%d', '%.6f', '%.6f'])
# 示例训练日志
np.random.seed(42)
epochs = np.arange(1, 11)
losses = 2.5 - 0.2 * epochs + np.random.randn(10) * 0.1
accuracies = 0.5 + 0.05 * epochs + np.random.randn(10) * 0.02
save_training_log('training_log.csv', epochs, losses, accuracies)
# 读取并绘图
log_data = np.loadtxt('training_log.csv', delimiter=',', skiprows=1)
print(f"训练日志:\n{log_data}")高级用法
读取不规则分隔符文件
python
# 使用不同的分隔符
data = np.loadtxt('data.tsv', delimiter='\t') # Tab 分隔
# 使用正则表达式分隔(需要更多处理)
# 对于复杂格式,可能需要使用 np.genfromtxt
data = np.genfromtxt('complex_data.csv', delimiter=',')使用 np.genfromtxt 处理复杂情况
np.genfromtxt 是更强大的文本文件读取函数,可以处理缺失值、混合类型等复杂情况:
python
# 处理缺失值
data = np.genfromtxt('data_with_missing.csv', delimiter=',',
missing_values='NA', filling_values=0.0)
# 自动检测数据类型
data = np.genfromtxt('mixed_data.csv', delimiter=',',
dtype=None, encoding='utf-8')读取部分数据
对于大型文件,可能只需要读取部分数据:
python
# 只读取前 N 行
data = np.loadtxt('large_file.csv', delimiter=',', max_rows=1000)
# 读取指定行(使用 skiprows 和 max_rows 组合)
data = np.loadtxt('large_file.csv', delimiter=',',
skiprows=100, max_rows=50)常见误区
误区一:忘记指定 delimiter
如果文件使用非逗号分隔,务必指定正确的 delimiter:
python
# Tab 分隔文件
data = np.loadtxt('data.tsv', delimiter='\t')
# 空格分隔文件
data = np.loadtxt('data.txt', delimiter=' ')误区二:大文件使用 loadtxt 效率低
对于大型文本文件,loadtxt 可能很慢,因为需要逐行解析。考虑使用二进制格式(np.save)或内存映射(np.memmap):
python
# 对于大型数组,使用二进制格式更高效
# arr = np.load('data.npy') # 比 loadtxt 快很多误区三:默认使用科学计数法导致精度问题
默认格式可能使用科学计数法,对于精确数据可能不合适:
python
# 指定固定小数位数
np.savetxt('precise.csv', arr, delimiter=',', fmt='%.10f')
# 或使用 '%s' 保存完整精度(对于浮点数)误区四:编码问题
处理包含非 ASCII 字符的文件时,需要指定编码:
python
# 读取 UTF-8 编码的文件
vocab = np.loadtxt('vocab.txt', dtype=str, delimiter='\n', encoding='utf-8')性能对比
文本文件 IO 比二进制 IO 慢很多:
python
import time
arr = np.random.randn(10000, 100)
# 测量 savetxt 性能
start = time.time()
np.savetxt('text_data.csv', arr, delimiter=',', fmt='%.6f')
text_time = time.time() - start
# 测量 save 性能
start = time.time()
np.save('binary_data.npy', arr)
binary_time = time.time() - start
import os
text_size = os.path.getsize('text_data.csv')
binary_size = os.path.getsize('binary_data.npy')
print(f"文本文件大小: {text_size / 1024 / 1024:.1f} MB, 耗时: {text_time:.2f}s")
print(f"二进制文件大小: {binary_size / 1024 / 1024:.1f} MB, 耗时: {binary_time:.4f}s")API 总结
| 函数 | 用途 | 关键参数 |
|---|---|---|
np.loadtxt(fname) | 读取文本文件 | delimiter, dtype, skiprows, max_rows |
np.savetxt(fname, X) | 保存文本文件 | delimiter, fmt, header, comments |
np.genfromtxt(fname) | 读取复杂文本 | delimiter, dtype, missing_values, filling_values |
np.loadtxt 和 np.savetxt 是处理 LLM 相关文本数据的实用工具,掌握它们的用法可以让你更高效地管理词汇表、配置文件和训练日志。