Data Augmentation and Cleaning: Getting the Most out of Your Data

What does this section cover?

In real projects, the raw data you receive usually has problems: label noise, class imbalance, too few samples, irrelevant or harmful content, and so on. Feed such data straight into training and the results are predictably poor.

In this section we take a systematic look at how to improve data quality and grow data volume:
- Text augmentation: synonym replacement, back-translation, random noise — generate more samples from limited data
- Deduplication: MinHash/SimHash for large-scale text dedup
- Class balancing: oversampling / undersampling / focal loss for long-tailed distributions
- Quality filtering: language detection, length filtering, toxicity detection
1. Text Data Augmentation

1.1 Why augment at all?

In NLP, labeled data is an expensive resource. A high-quality domain sentiment-analysis dataset can take weeks of human annotation and cost hundreds of thousands of yuan. The core idea of data augmentation is:

Generate new training samples by applying meaning-preserving transformations to existing ones, growing the training set without extra annotation cost.

One caveat: not every augmentation method suits every task. Synonym replacement can shift sentiment polarity ("好" → "棒" is safe, but some contexts carry subtle differences), and back-translation can introduce translationese.

1.2 Synonym Replacement
```python
import random

class SynonymReplacer:
    """Dictionary-based synonym-replacement augmenter."""

    def __init__(self, synonym_dict=None):
        self.synonym_dict = synonym_dict or {
            "好": ["棒", "优秀", "出色", "赞", "不错"],
            "差": ["烂", "糟糕", "差劲", "不行"],
            "喜欢": ["爱", "钟爱", "喜爱"],
            "讨厌": ["厌恶", "反感", "不喜欢"],
            "大": ["巨大", "庞大", "硕大"],
            "小": ["微小", "细小", "渺小"],
        }
        # Longest dictionary key, used for greedy matching below.
        self.max_key_len = max(len(k) for k in self.synonym_dict)

    def _tokenize(self, text):
        """Greedy longest-match against the dictionary, so multi-character
        entries like "喜欢" stay one token instead of being split into
        single characters (which would never match)."""
        tokens, i = [], 0
        while i < len(text):
            for size in range(min(self.max_key_len, len(text) - i), 0, -1):
                if size == 1 or text[i:i + size] in self.synonym_dict:
                    tokens.append(text[i:i + size])
                    i += size
                    break
        return tokens

    def augment(self, text, replace_ratio=0.3, n_augments=2):
        """Return [original, aug_1, ..., aug_n] with synonyms swapped in."""
        words = self._tokenize(text)
        results = [text]
        for _ in range(n_augments):
            new_words = words.copy()
            n_replace = max(1, int(len(words) * replace_ratio))
            positions = random.sample(range(len(words)), min(n_replace, len(words)))
            for pos in positions:
                if new_words[pos] in self.synonym_dict:
                    new_words[pos] = random.choice(self.synonym_dict[new_words[pos]])
            results.append("".join(new_words))
        return results

replacer = SynonymReplacer()
test_texts = [
    "这个产品非常好",
    "质量太差了",
    "我很喜欢这个设计",
]
print("=== Synonym-replacement augmentation ===")
for text in test_texts:
    augmented = replacer.augment(text, replace_ratio=0.4, n_augments=3)
    print(f"\nOriginal: {text}")
    for i, aug_text in enumerate(augmented[1:], 1):
        print(f"  Augment {i}: {aug_text}")
```

1.3 Back-Translation
Back-translation works by translating Chinese → English → Chinese (possibly through more intermediate languages). Each translation hop introduces small changes, so the result stays semantically close to the original while being phrased differently:
```python
import difflib
from transformers import pipeline

def demonstrate_back_translation():
    """Demonstrate back-translation augmentation (with a simulated fallback)."""
    try:
        zh_en = pipeline("translation", model="Helsinki-NLP/opus-mt-zh-en", device=-1)
        en_zh = pipeline("translation", model="Helsinki-NLP/opus-mt-en-zh", device=-1)
        original = "这款手机拍照效果非常出色,电池续航也很给力"
        print(f"Original: {original}")
        english = zh_en(original)[0]["translation_text"]
        print(f"zh→en: {english}")
        back_zh = en_zh(english)[0]["translation_text"]
        print(f"en→zh: {back_zh}")
        # Show what changed. SequenceMatcher aligns the two strings instead
        # of comparing them position by position.
        print("\nChanges:")
        matcher = difflib.SequenceMatcher(None, original, back_zh)
        for op, i1, i2, j1, j2 in matcher.get_opcodes():
            if op != "equal":
                print(f"  '{original[i1:i2]}' → '{back_zh[j1:j2]}'")
    except Exception as e:
        print(f"Translation models unavailable: {e}")
        print("\nSimulated back-translation examples:")
        examples = [
            ("这款手机拍照效果非常出色,电池续航也很给力",
             "这部手机的摄影功能十分优异,电池使用时间也相当不错"),
            ("客服态度很好,物流也很快",
             "客户服务态度佳,配送速度也很快"),
        ]
        for orig, bt in examples:
            print(f"\n  Original:        {orig}")
            print(f"  Back-translated: {bt}")

demonstrate_back_translation()
```

The pros and cons of back-translation:
- ✅ Preserves semantics relatively well (more reliable than random replacement)
- ✅ Introduces natural variation in phrasing (translation models have learned many ways to say the same thing)
- ❌ Requires loading extra translation models (heavy on resources)
- ❌ May introduce "translationese" (unidiomatic phrasing)
- ❌ Repeated round trips can cause semantic drift — a simple guard is sketched below
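One cheap guard against drift, and against no-op round trips, is to keep an augmented sample only if it is similar to, but not identical with, its source. This is a minimal sketch using `difflib.SequenceMatcher` from the standard library as a rough similarity proxy; the 0.5–0.95 band is an illustrative assumption to tune on your own data, not a value from the source:

```python
import difflib

def keep_backtranslation(original, candidate, lo=0.5, hi=0.95):
    """Keep a back-translated sample only if it changed, but not too much.
    The [lo, hi] band is a heuristic — tune it on your own data."""
    ratio = difflib.SequenceMatcher(None, original, candidate).ratio()
    return lo <= ratio <= hi

print(keep_backtranslation("这款手机拍照效果非常出色,电池续航也很给力",
                           "这部手机的摄影功能十分优异,电池使用时间也相当不错"))
```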
1.4 EDA (Easy Data Augmentation)

EDA combines four simple but effective text-editing operations:
```python
import random

class EDAAugmentor:
    """
    Easy Data Augmentation: a combination of four simple operations.
    Paper: "EDA: Easy Data Augmentation Techniques for Boosting
    Performance on Text Classification Tasks"
    """

    def __init__(self, alpha_sr=0.1, alpha_ri=0.1, alpha_rs=0.1,
                 p_rd=0.1, num_aug=2):
        self.alpha_sr = alpha_sr  # fraction of tokens to synonym-replace
        self.alpha_ri = alpha_ri  # fraction of tokens to randomly insert
        self.alpha_rs = alpha_rs  # fraction of tokens to randomly swap
        self.p_rd = p_rd          # per-token random-deletion probability
        self.num_aug = num_aug    # augmented copies per input

    def _get_words(self, text):
        # Character-level tokenization: a simple choice for Chinese text.
        return list(text)

    def _synonym_replace(self, words, synonym_dict):
        n_sr = max(1, int(self.alpha_sr * len(words)))
        new_words = words.copy()
        positions = random.sample(range(len(new_words)), min(n_sr, len(new_words)))
        for idx in positions:
            if new_words[idx] in synonym_dict:
                new_words[idx] = random.choice(synonym_dict[new_words[idx]])
        return new_words

    def _random_insert(self, words, synonym_dict):
        n_ri = max(1, int(self.alpha_ri * len(words)))
        new_words = words.copy()
        for _ in range(n_ri):
            add_word = random.choice(list(synonym_dict.keys()))
            new_words.insert(random.randint(0, len(new_words)), add_word)
        return new_words

    def _random_swap(self, words):
        n_rs = max(1, int(self.alpha_rs * len(words)))
        new_words = words.copy()
        for _ in range(n_rs):
            if len(new_words) >= 2:
                idx1, idx2 = random.sample(range(len(new_words)), 2)
                new_words[idx1], new_words[idx2] = new_words[idx2], new_words[idx1]
        return new_words

    def _random_delete(self, words):
        if len(words) == 1:
            return words
        new_words = [w for w in words if random.random() > self.p_rd]
        # Never return an empty text: keep at least one random token.
        return new_words if new_words else [random.choice(words)]

    def augment(self, text, synonym_dict=None):
        # Values must be lists: random.choice over a multi-character string
        # would pick a single character and corrupt the word.
        synonym_dict = synonym_dict or {
            "的": ["之"], "很": ["非常"], "好": ["棒"], "不": ["没"],
            "是": ["为"], "了": ["过"], "在": ["于"], "和": ["与"],
        }
        words = self._get_words(text)
        aug_texts = []
        for _ in range(self.num_aug):
            aug_type = random.choice(["sr", "ri", "rs", "rd"])
            if aug_type == "sr":
                a_words = self._synonym_replace(words, synonym_dict)
            elif aug_type == "ri":
                a_words = self._random_insert(words, synonym_dict)
            elif aug_type == "rs":
                a_words = self._random_swap(words)
            else:
                a_words = self._random_delete(words)
            aug_texts.append("".join(a_words))
        return aug_texts

eda = EDAAugmentor(num_aug=3)
texts = ["这个产品真的很好用", "今天天气非常不错"]
for text in texts:
    augmented = eda.augment(text)
    print(f"\nOriginal: {text}")
    for i, aug in enumerate(augmented, 1):
        print(f"  EDA-{i}: {aug}")
```

1.5 Batch Augmentation with the Datasets Library
```python
from datasets import Dataset, load_dataset

def build_augmented_dataset(base_dataset, augment_fn, copies=2):
    """Expand a dataset by appending augmented copies of each example."""
    data = {"text": [], "label": []}
    for example in base_dataset:
        text, label = example["text"], example["label"]
        data["text"].append(text)
        data["label"].append(label)
        # augment() returns [original, aug_1, ...]; skip the original,
        # which we have already added once.
        for aug_text in augment_fn(text, n_augments=copies)[1:]:
            data["text"].append(aug_text)
            data["label"].append(label)
    return Dataset.from_dict(data)

# Note: IMDB is English, so the Chinese default dictionary leaves the text
# unchanged here; this run only demonstrates the pipeline mechanics. Pass a
# dictionary matching your corpus language for real augmentation.
ds = load_dataset("imdb", split="train").select(range(100))
augmentor = SynonymReplacer()
augmented_ds = build_augmented_dataset(ds, augmentor.augment, copies=2)
print(f"Original dataset size:  {len(ds)}")
print(f"Augmented dataset size: {len(augmented_ds)}")
print(f"Expansion factor: {len(augmented_ds) / len(ds):.1f}x")
```
Corpora scraped at scale are full of duplicated or highly similar texts. Deduplication saves storage and compute and, more importantly, keeps the model from overfitting to repeated samples during training.

2.1 Exact Deduplication
```python
def exact_dedup(dataset, text_col="text"):
    """Exact deduplication: keep only the first copy of identical texts."""
    seen = set()
    unique_indices = []
    for i, example in enumerate(dataset):
        # Python's built-in hash() is salted per process; it is fine within
        # a single run, but use hashlib for hashes you need to persist.
        text_hash = hash(example[text_col])
        if text_hash not in seen:
            seen.add(text_hash)
            unique_indices.append(i)
    deduped = dataset.select(unique_indices)
    removed = len(dataset) - len(deduped)
    print(f"Exact dedup: {len(dataset)} → {len(deduped)} ({removed} duplicates removed)")
    return deduped
```

2.2 SimHash: Near-Duplicate Detection
对于"大致相同但不完全一致"的文本(如同一篇文章的不同版本、同一事件的多篇报道),需要使用近似去重算法。SimHash 是工业界最常用的方案之一:
```python
import hashlib

class SimHashDeduplicator:
    """
    SimHash near-duplicate detector.

    How it works:
      1. Tokenize the text into a set of features.
      2. Hash each feature to a fixed-length bit vector.
      3. Sum the vectors with weights (+count where a bit is 1, -count where 0).
      4. Threshold each dimension (positive → 1, else 0) to get a fingerprint.
      5. Two texts are near-duplicates if the Hamming distance between their
         fingerprints is at most a threshold.
    """

    def __init__(self, hash_bits=64, similarity_threshold=0.85):
        self.hash_bits = hash_bits
        self.threshold = int(hash_bits * (1 - similarity_threshold))

    def _tokenize(self, text):
        # CJK characters become single-character tokens; consecutive
        # non-CJK characters are grouped into one token.
        tokens, current = [], []
        for ch in text:
            if '\u4e00' <= ch <= '\u9fff':
                if current:
                    tokens.append(''.join(current))
                    current = []
                tokens.append(ch)
            else:
                current.append(ch)
        if current:
            tokens.append(''.join(current))
        return [t for t in tokens if t]

    def _hash_feature(self, feature):
        md5_hash = hashlib.md5(feature.encode('utf-8')).digest()
        integer_value = int.from_bytes(md5_hash[:8], byteorder='big')
        return integer_value >> (64 - self.hash_bits)

    def compute_simhash(self, text):
        tokens = self._tokenize(text)
        vector = [0] * self.hash_bits
        token_counts = {}
        for t in tokens:
            token_counts[t] = token_counts.get(t, 0) + 1
        for token, count in token_counts.items():
            h = self._hash_feature(token)
            for i in range(self.hash_bits):
                if h & (1 << i):
                    vector[i] += count
                else:
                    vector[i] -= count
        fingerprint = 0
        for i in range(self.hash_bits):
            if vector[i] > 0:
                fingerprint |= (1 << i)
        return fingerprint

    def hamming_distance(self, hash1, hash2):
        x = hash1 ^ hash2
        distance = 0
        while x:
            distance += 1
            x &= x - 1  # clear the lowest set bit
        return distance

    def deduplicate(self, texts, show_progress=True):
        # Note: comparing each new fingerprint against all kept ones is
        # O(n^2); at web scale, bucket fingerprints (e.g. by 16-bit chunks)
        # so only candidate pairs are compared.
        unique_indices = []
        seen = set()
        total = len(texts)
        for i, text in enumerate(texts):
            fp = self.compute_simhash(text)
            is_duplicate = any(
                self.hamming_distance(fp, seen_fp) <= self.threshold
                for seen_fp in seen
            )
            if not is_duplicate:
                seen.add(fp)
                unique_indices.append(i)
            if show_progress and (i + 1) % 1000 == 0:
                print(f"  progress: {i+1}/{total}, kept {len(unique_indices)}, "
                      f"dup rate {(i+1-len(unique_indices))/(i+1)*100:.1f}%")
        return unique_indices

deduplicator = SimHashDeduplicator(hash_bits=64, similarity_threshold=0.90)
test_corpus = [
    "人工智能正在改变世界",
    "人工智能正在改变世界。",
    "AI正在改变我们的世界",
    "深度学习是人工智能的一个分支",
    "深度学习是 AI 的一个重要分支",
    "今天的天气真好",
    "自然语言处理很有趣",
    "NLP 是一个非常有趣的研究方向",
] * 10

indices = deduplicator.deduplicate(test_corpus)
unique_texts = [test_corpus[i] for i in indices]
print(f"\nSimHash dedup results:")
print(f"  before: {len(test_corpus)} texts")
print(f"  after:  {len(unique_texts)} texts")
print(f"\nUnique texts kept:")
for t in unique_texts:
    print(f"  • {t}")
```
3.1 Diagnosing the Imbalance
```python
from collections import Counter
from datasets import load_dataset
import matplotlib.pyplot as plt
import numpy as np

def diagnose_class_imbalance(dataset, label_col="label"):
    """Diagnose class imbalance in a dataset."""
    labels = dataset[label_col]
    counts = Counter(labels)
    unique_labels = sorted(counts.keys())
    frequencies = [counts[l] for l in unique_labels]

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    colors = plt.cm.Set3(np.linspace(0, 1, len(unique_labels)))

    # Panel 1: raw counts per class
    axes[0].bar(range(len(unique_labels)), frequencies, color=colors)
    axes[0].set_xticks(range(len(unique_labels)))
    axes[0].set_xticklabels([str(l) for l in unique_labels])
    axes[0].set_ylabel('Sample Count')
    axes[0].set_title('Samples per Class')

    # Panel 2: class proportions
    total = sum(frequencies)
    percentages = [f / total * 100 for f in frequencies]
    axes[1].pie(frequencies, labels=[str(l) for l in unique_labels],
                autopct='%1.1f%%', colors=colors)
    axes[1].set_title('Class Proportions')

    # Panel 3: sorted shares against the perfectly balanced line
    sorted_pct = sorted(percentages, reverse=True)
    imbalance_ratio = sorted_pct[0] / max(sorted_pct[-1], 1)
    axes[2].bar(range(len(sorted_pct)), sorted_pct, color='coral')
    axes[2].axhline(y=100 / len(unique_labels), color='green',
                    linestyle='--', label='balanced')
    axes[2].set_ylabel('Percentage (%)')
    axes[2].set_title(f'Imbalance (max/min = {imbalance_ratio:.1f}x)')
    axes[2].legend()
    plt.tight_layout()
    plt.show()

    print("\nSummary:")
    print(f"  total samples: {total}")
    print(f"  classes: {len(unique_labels)}")
    print(f"  largest class:  {max(frequencies)} samples ({max(percentages):.1f}%)")
    print(f"  smallest class: {min(frequencies)} samples ({min(percentages):.1f}%)")
    print(f"  imbalance ratio: {imbalance_ratio:.1f}x")
    if imbalance_ratio > 10:
        print("  ⚠️ severely imbalanced — mitigation strongly recommended")
    elif imbalance_ratio > 3:
        print("  ⚠️ mildly imbalanced — watch minority-class metrics")
    else:
        print("  ✓ reasonably balanced")

ds = load_dataset("ag_news", split="train").select(range(2000))
diagnose_class_imbalance(ds, "label")
```

3.2 Oversampling
```python
import random
from collections import Counter

def oversample_minority_classes(dataset, label_col="label", strategy="random",
                                random_seed=42):
    """
    Oversampling: duplicate minority-class samples to balance the dataset.

    Args:
        strategy: only "random" (duplicate with replacement) is implemented
                  here; SMOTE-style interpolation needs an extra library.
    """
    random.seed(random_seed)
    labels = dataset[label_col]
    counts = Counter(labels)
    max_count = max(counts.values())
    all_indices = list(range(len(dataset)))
    extra_indices = []
    for label, count in counts.items():
        if count < max_count:
            class_indices = [i for i in all_indices if labels[i] == label]
            n_needed = max_count - count
            if strategy == "random":
                extra_indices.extend(random.choices(class_indices, k=n_needed))
    balanced_indices = all_indices + extra_indices
    random.shuffle(balanced_indices)
    balanced_ds = dataset.select(balanced_indices)
    new_counts = Counter(balanced_ds[label_col])
    print(f"before oversampling: {dict(counts)}")
    print(f"after oversampling:  {dict(new_counts)}")
    return balanced_ds
```

3.3 Undersampling and Weighted Losses
```python
import torch
from collections import Counter
from torch.utils.data import WeightedRandomSampler

def create_balanced_sampler(dataset, label_col="label"):
    """Build a WeightedRandomSampler that draws every class roughly equally
    often: rare classes are sampled more, frequent classes less."""
    labels = dataset[label_col]
    counts = Counter(labels)
    # One weight per sample: the inverse frequency of its class.
    weights = torch.tensor([1.0 / counts[label] for label in labels],
                           dtype=torch.float)
    return WeightedRandomSampler(
        weights=weights,
        num_samples=len(weights),
        replacement=True,
    )

def compute_class_weights(dataset, label_col="label"):
    """Compute per-class weights for a weighted loss function."""
    labels = dataset[label_col]
    counts = Counter(labels)
    total = len(labels)
    n_classes = len(counts)
    weights = {label: total / (n_classes * count)
               for label, count in counts.items()}
    print("Class weights (minority classes get larger weights):")
    for label in sorted(weights.keys()):
        print(f"  class {label}: weight={weights[label]:.3f} (count={counts[label]})")
    return weights
```
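The section opener also listed focal loss as an option for long-tailed data. Below is a minimal sketch of how the class weights plug into `nn.CrossEntropyLoss`, plus a focal loss written from its standard formulation (Lin et al., 2017); the toy weights and `gamma=2.0` are illustrative assumptions:

```python
import torch
import torch.nn.functional as F

# Plugging class weights into weighted cross-entropy: weight_tensor[i] must
# hold the weight of class i (toy values standing in for compute_class_weights output).
class_weights = {0: 2.5, 1: 0.6, 2: 1.0}
weight_tensor = torch.tensor([class_weights[i] for i in sorted(class_weights)])
weighted_ce = torch.nn.CrossEntropyLoss(weight=weight_tensor)

class FocalLoss(torch.nn.Module):
    """Focal loss: down-weights easy examples so training focuses on hard
    (often minority-class) ones. loss = (1 - p_t)^gamma * CE."""
    def __init__(self, gamma=2.0):
        super().__init__()
        self.gamma = gamma

    def forward(self, logits, targets):
        ce = F.cross_entropy(logits, targets, reduction="none")
        pt = torch.exp(-ce)  # probability assigned to the true class
        return ((1 - pt) ** self.gamma * ce).mean()

logits = torch.randn(4, 3)            # toy batch: 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 0])
print(weighted_ce(logits, targets), FocalLoss(gamma=2.0)(logits, targets))
```

4. A Rule-Based Quality-Filtering Engine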
```python
import re
from datasets import Dataset

class DataQualityFilter:
    """
    Data-quality filter: combines multiple rules to clean a dataset.
    """

    def __init__(self, rules=None):
        self.rules = rules or {}
        self.stats = {"total": 0, "passed": 0, "rejected": {}}

    def add_rule(self, name, filter_fn, description=""):
        self.rules[name] = {"fn": filter_fn, "desc": description}

    def check_length(self, min_len=5, max_len=5000):
        """Add a length rule."""
        def rule(example):
            text = example.get("text", "")
            return min_len <= len(text.strip()) <= max_len
        self.add_rule("length_check", rule,
                      f"length in [{min_len}, {max_len}]")
        return self

    def check_language(self, target_lang="zh", threshold=0.7):
        """Add a language-detection rule (needs the langdetect package)."""
        try:
            from langdetect import detect

            def rule(example):
                text = example.get("text", "").strip()
                if not text:
                    return False
                try:
                    return detect(text) == target_lang
                except Exception:
                    return False
            self.add_rule("language_check", rule, f"language is {target_lang}")
        except ImportError:
            print("langdetect not installed; skipping the language rule")
        return self

    def check_repetition(self, max_repeat_ratio=0.3):
        """Add a repeated-content rule."""
        def rule(example):
            text = example.get("text", "").strip()
            if not text:
                return False
            repeat_ratio = 1 - len(set(text)) / len(text)
            return repeat_ratio < max_repeat_ratio
        self.add_rule("repetition_check", rule,
                      f"repeated-char ratio < {max_repeat_ratio}")
        return self

    def check_special_patterns(self, blocked_patterns=None):
        """Add a pattern-matching rule."""
        patterns = blocked_patterns or [
            r"http[s]?://\S+",            # URLs
            r"<[^>]+>",                   # HTML tags
            r"[\U0001F600-\U0001F64F]",   # emoji
            r"^[\s\W\d]+$",               # symbols/digits only
        ]
        compiled = [re.compile(p) for p in patterns]

        def rule(example):
            text = example.get("text", "").strip()
            return not any(p.search(text) for p in compiled)
        self.add_rule("pattern_check", rule, "no URL/HTML/special patterns")
        return self

    def filter_dataset(self, dataset, verbose=True):
        """Run all rules over a dataset."""
        self.stats = {"total": len(dataset), "passed": 0, "rejected": {}}

        def apply_all_rules(example):
            for name, rule_info in self.rules.items():
                if not rule_info["fn"](example):
                    self.stats["rejected"][name] = \
                        self.stats["rejected"].get(name, 0) + 1
                    return False
            self.stats["passed"] += 1
            return True

        # Single-process on purpose: with num_proc > 1 the rules would run
        # in worker processes and the stats mutations above would be lost.
        filtered = dataset.filter(apply_all_rules)
        if verbose:
            print("\n" + "=" * 50)
            print("Data-quality filtering report")
            print("=" * 50)
            print(f"total samples: {self.stats['total']:,}")
            print(f"passed: {self.stats['passed']:,} "
                  f"({self.stats['passed']/self.stats['total']*100:.1f}%)")
            print(f"rejected: {self.stats['total'] - self.stats['passed']:,}")
            if self.stats["rejected"]:
                print("\nrejections per rule:")
                for name, count in sorted(self.stats["rejected"].items(),
                                          key=lambda x: -x[1]):
                    pct = count / self.stats['total'] * 100
                    desc = self.rules[name]["desc"]
                    print(f"  {name:<20s} {count:>6d} ({pct:>5.1f}%) — {desc}")
        return filtered

filter_engine = DataQualityFilter()
filter_engine.check_length(min_len=3, max_len=2000)\
             .check_repetition(max_repeat_ratio=0.5)\
             .check_special_patterns()

dummy_data = {
    "text": [
        "这是一个正常的中文文本",
        "短",                                  # too short
        "A" * 3000,                            # too long
        "访问 https://example.com 了解更多",    # contains a URL
        "哈哈哈哈哈哈哈哈哈",                    # highly repetitive
        "正常的产品评论,质量不错",
        "<script>alert(1)</script>",           # HTML tag
        "又一个正常样本",
    ],
    "label": [1, 0, 1, 0, 1, 1, 0, 1],
}
dummy_ds = Dataset.from_dict(dummy_data)
cleaned_ds = filter_engine.filter_dataset(dummy_ds)
```

Example output:
```
==================================================
Data-quality filtering report
==================================================
total samples: 8
passed: 3 (37.5%)
rejected: 5

rejections per rule:
  length_check              2 ( 25.0%) — length in [3, 2000]
  pattern_check             2 ( 25.0%) — no URL/HTML/special patterns
  repetition_check          1 ( 12.5%) — repeated-char ratio < 0.5
```
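The opening list also promised toxicity detection. A common approach is to score texts with an off-the-shelf classifier and drop anything above a threshold. This is a minimal sketch assuming the `unitary/toxic-bert` checkpoint from the Hugging Face Hub (an English model — pick one matching your corpus language) and an illustrative 0.5 threshold; the returned label names depend on the chosen checkpoint:

```python
from transformers import pipeline

def drop_toxic(texts, threshold=0.5):
    """Keep only texts whose top predicted label is not 'toxic' above the threshold."""
    clf = pipeline("text-classification", model="unitary/toxic-bert", device=-1)
    kept = []
    for text, result in zip(texts, clf(texts, truncation=True)):
        if not (result["label"] == "toxic" and result["score"] >= threshold):
            kept.append(text)
    return kept

print(drop_toxic(["Have a nice day!", "I hate you, you idiot."]))
```

Summary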
In this section we worked through the key techniques for data augmentation and cleaning:

- The text-augmentation toolkit: synonym replacement (simple and effective), back-translation (good semantic fidelity but resource-hungry), EDA (four operations combined), and a batch augmentation pipeline over a Dataset
- SimHash near-deduplication: map each text to a fixed-length fingerprint, then compare Hamming distances to find near-duplicates; exact dedup only catches identical texts
- Imbalanced data: diagnose (visualize the distribution) → oversample (duplicate minority classes) → resample (WeightedRandomSampler) → weight the loss (compute_class_weights, or focal loss); treat imbalance ratios above 10x as mandatory to address
- A quality-filter rule engine: length checks, language detection, repetition detection, pattern matching (URLs, HTML, etc.); rules compose, and a basic quality gate belongs in every production pipeline

Next up is the final section of Chapter 5: custom datasets, including uploading to the Hub and the builder pattern.