名称: 数据增强 描述: 涵盖图像、文本、音频和表格数据的数据增强技术的全面指南。
数据增强
概述
数据增强是一种通过创建现有数据的修改版本来人工增加训练数据集大小和多样性的技术。本技能涵盖图像、文本、音频和表格数据的增强技术,包括流行的库如 Albumentations、NLPAug 和自定义增强策略。
先决条件
- 理解机器学习概念
- 了解 Python 编程
- 熟悉数据预处理
- 理解过拟合和泛化
- 图像、文本、音频处理的基础知识
关键概念
增强类型
- 图像增强: 几何变换、颜色调整、噪声注入
- 文本增强: 反向翻译、同义词替换、单词插入/删除
- 音频增强: 时间拉伸、音高移动、噪声添加、掩码
- 表格增强: SMOTE、ADASYN、高斯噪声、特征混合
增强策略
- 在线增强: 在训练期间应用增强
- 离线增强: 预计算增强样本
- 测试时增强 (TTA): 在推理时应用多个增强
- AutoAugment: 自动搜索最优增强策略
常见库
- Albumentations: 快速灵活的图像增强库
- Torchvision: PyTorch 的内置变换
- ImgAug: 强大的图像增强库
- NLPAug: 文本增强库
- Imbalanced-learn: 表格数据增强 (SMOTE, ADASYN)
实现指南
图像增强
Albumentations
import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np
class AlbumentationsAugmentor:
    """Image augmentation pipeline built on Albumentations.

    Training mode applies a rich random pipeline (flips, affine, noise,
    blur, distortion, color jitter, dropout); any other mode applies a
    deterministic resize + normalize pipeline for evaluation.
    """

    def __init__(self, mode='train', image_size=(224, 224)):
        """
        Args:
            mode: 'train' enables random augmentation; any other value
                yields the deterministic eval pipeline.
            image_size: (height, width) of the output image.
        """
        self.mode = mode
        self.image_size = image_size
        self.transform = self._get_transform()

    def _get_transform(self):
        """Build the transform pipeline for the current mode."""
        if self.mode == 'train':
            return A.Compose([
                A.Resize(height=self.image_size[0], width=self.image_size[1]),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.2),
                A.RandomRotate90(p=0.5),
                A.Rotate(limit=30, p=0.5),
                A.ShiftScaleRotate(
                    shift_limit=0.1,
                    scale_limit=0.1,
                    rotate_limit=15,
                    p=0.5
                ),
                # Exactly one noise type, 20% of the time.
                A.OneOf([
                    A.GaussNoise(p=1.0),
                    A.ISONoise(p=1.0),
                    A.MultiplicativeNoise(p=1.0),
                ], p=0.2),
                # Exactly one blur type, 20% of the time.
                A.OneOf([
                    A.MotionBlur(p=1.0),
                    A.MedianBlur(p=1.0),
                    A.GaussianBlur(p=1.0),
                ], p=0.2),
                # FIX: the IAA* aliases (IAAPiecewiseAffine, IAASharpen,
                # IAAEmboss) and Cutout were removed in Albumentations 1.0;
                # use the native replacements (PiecewiseAffine, Sharpen,
                # Emboss, CoarseDropout) instead.
                A.OneOf([
                    A.OpticalDistortion(p=1.0),
                    A.GridDistortion(p=1.0),
                    A.PiecewiseAffine(p=1.0),
                ], p=0.2),
                A.OneOf([
                    A.CLAHE(clip_limit=2),
                    A.Sharpen(),
                    A.Emboss(),
                    A.RandomBrightnessContrast(),
                ], p=0.3),
                A.HueSaturationValue(p=0.3),
                A.RandomBrightnessContrast(p=0.3),
                # CoarseDropout subsumes the deprecated Cutout transform.
                A.CoarseDropout(
                    max_holes=8,
                    max_height=32,
                    max_width=32,
                    min_holes=1,
                    min_height=8,
                    min_width=8,
                    p=0.3
                ),
                A.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
                ToTensorV2()
            ])
        else:
            return A.Compose([
                A.Resize(height=self.image_size[0], width=self.image_size[1]),
                A.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
                ToTensorV2()
            ])

    def __call__(self, image):
        """Apply the pipeline to a numpy image and return the transformed image."""
        return self.transform(image=image)['image']
# Usage
train_augmentor = AlbumentationsAugmentor(mode='train', image_size=(224, 224))
val_augmentor = AlbumentationsAugmentor(mode='val', image_size=(224, 224))
# Apply the augmentation (original_image: a numpy image, per __call__ above)
augmented_image = train_augmentor(original_image)
Torchvision 变换
import torchvision.transforms as transforms
from torchvision.transforms import functional as F
class TorchvisionAugmentor:
    """Image augmentation built on torchvision transforms."""

    def __init__(self, mode='train', image_size=224):
        """mode='train' enables random augmentation; image_size is the square output size."""
        self.mode = mode
        self.image_size = image_size
        self.transform = self._get_transform()

    def _get_transform(self):
        """Compose the transform pipeline for the current mode."""
        size = self.image_size
        normalize = transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        if self.mode != 'train':
            # Deterministic eval pipeline: resize + tensor + normalize only.
            return transforms.Compose([
                transforms.Resize((size, size)),
                transforms.ToTensor(),
                normalize,
            ])
        steps = [
            transforms.Resize((size, size)),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.2),
            transforms.RandomRotation(degrees=30),
            transforms.RandomAffine(
                degrees=0,
                translate=(0.1, 0.1),
                scale=(0.9, 1.1),
                shear=10,
            ),
            transforms.ColorJitter(
                brightness=0.3,
                contrast=0.3,
                saturation=0.3,
                hue=0.1,
            ),
            transforms.RandomGrayscale(p=0.1),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
            transforms.RandomResizedCrop(
                size,
                scale=(0.8, 1.0),
                ratio=(0.9, 1.1),
            ),
            transforms.ToTensor(),
            normalize,
        ]
        return transforms.Compose(steps)

    def __call__(self, image):
        """Run the pipeline on a PIL image."""
        return self.transform(image)
# Usage
train_transform = TorchvisionAugmentor(mode='train', image_size=224)
augmented_image = train_transform(pil_image)
ImgAug
import imgaug as ia
import imgaug.augmenters as iaa
class ImgAugAugmentor:
    """Image augmentation built on imgaug."""

    def __init__(self, mode='train'):
        # Only 'train' mode performs any augmentation.
        self.mode = mode
        self.augmenter = self._get_augmenter()

    def _get_augmenter(self):
        """Return the imgaug pipeline (an empty sequence outside of training)."""
        if self.mode != 'train':
            return iaa.Sequential([])
        return iaa.Sequential([
            iaa.Fliplr(0.5),  # horizontal flip
            iaa.Flipud(0.2),  # vertical flip
            iaa.Affine(
                rotate=(-30, 30),
                scale=(0.9, 1.1),
                translate_percent=(-0.1, 0.1)
            ),
            iaa.Multiply((0.8, 1.2)),  # brightness
            iaa.LinearContrast((0.8, 1.2)),  # contrast
            iaa.AdditiveGaussianNoise(scale=(0, 0.05 * 255)),
            iaa.GaussianBlur(sigma=(0, 1.0)),
            iaa.Dropout(p=(0, 0.2)),
            iaa.CoarseDropout(
                (0.0, 0.05),
                size_percent=(0.02, 0.05)
            ),
            iaa.Crop(percent=(0, 0.1)),
            iaa.Pad(percent=(0, 0.1)),
            iaa.ElasticTransformation(alpha=(0, 50), sigma=(0, 5)),
            iaa.PiecewiseAffine(scale=(0.01, 0.05)),
        ])

    def __call__(self, image):
        """Augment a numpy image during training; pass it through otherwise."""
        if self.mode != 'train':
            return image
        return self.augmenter.augment_image(image)
# Usage
train_augmentor = ImgAugAugmentor(mode='train')
augmented_image = train_augmentor(numpy_image)
文本增强
反向翻译
from deep_translator import GoogleTranslator
import random
class BackTranslationAugmentor:
    """Text augmentation via back-translation (en -> pivot language -> en).

    Each round trip translates the text into a random intermediate
    language and back, producing a paraphrase. On translation failure
    the original text is kept (best effort).
    """

    def __init__(self, languages=None):
        """
        Args:
            languages: candidate pivot language codes; defaults to
                ['fr', 'de', 'es'].

        FIX: the original signature used a mutable default argument
        (languages=['fr', 'de', 'es']), so all instances shared — and
        could silently mutate — one list. A None sentinel avoids that.
        """
        self.languages = ['fr', 'de', 'es'] if languages is None else languages

    def augment(self, text, num_augmentations=1):
        """Return `num_augmentations` back-translated variants of `text`."""
        augmented_texts = []
        for _ in range(num_augmentations):
            # Pick a random pivot language for this round trip.
            intermediate_lang = random.choice(self.languages)
            try:
                # Translate into the pivot language...
                translated = GoogleTranslator(
                    source='auto',
                    target=intermediate_lang
                ).translate(text)
                # ...and back to English.
                back_translated = GoogleTranslator(
                    source=intermediate_lang,
                    target='en'
                ).translate(translated)
                augmented_texts.append(back_translated)
            except Exception as e:
                # Best effort: fall back to the original text on failure.
                print(f"反向翻译失败: {e}")
                augmented_texts.append(text)
        return augmented_texts
# Usage
augmentor = BackTranslationAugmentor(languages=['fr', 'de', 'es'])
augmented_texts = augmentor.augment("This is a sample text for augmentation.")
同义词替换
import nltk
from nltk.corpus import wordnet
import random
nltk.download('wordnet')
nltk.download('omw-1.4')
class SynonymReplacementAugmentor:
    """Text augmentation that randomly swaps words for WordNet synonyms."""

    def __init__(self, replacement_prob=0.3):
        # Per-word probability of attempting a synonym swap.
        self.replacement_prob = replacement_prob

    def get_synonyms(self, word):
        """Return WordNet synonyms for `word`, excluding the word itself."""
        candidates = set()
        for synset in wordnet.synsets(word):
            for lemma in synset.lemmas():
                name = lemma.name().replace('_', ' ')
                if name.lower() != word.lower():
                    candidates.add(name)
        return list(candidates)

    def augment(self, text):
        """Return `text` with some words replaced by random synonyms."""
        out = []
        for token in text.split():
            replacement = token
            if random.random() < self.replacement_prob:
                options = self.get_synonyms(token)
                if options:
                    replacement = random.choice(options)
            out.append(replacement)
        return ' '.join(out)
# Usage
augmentor = SynonymReplacementAugmentor(replacement_prob=0.3)
augmented_text = augmentor.augment("The quick brown fox jumps over lazy dog")
NLPAug
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.char as nac
import nlpaug.augmenter.sentence as nas
class NLPAugAugmentor:
    """Text augmentation built on NLPAug (word / char / sentence level)."""

    def __init__(self):
        # Word-level augmenters.
        self.synonym_aug = naw.SynonymAug(aug_src='wordnet')
        self.contextual_aug = naw.ContextualWordEmbsAug(model_path='bert-base-uncased')
        self.back_translation_aug = naw.BackTranslationAug(
            from_model_name='facebook/wmt19-en-de',
            to_model_name='facebook/wmt19-de-en'
        )
        self.insertion_aug = naw.RandomWordAug(action="insert")
        self.swap_aug = naw.RandomWordAug(action="swap")
        self.deletion_aug = naw.RandomWordAug(action="delete")
        # Character-level augmenters.
        self.ocr_aug = nac.OcrAug()
        self.keyboard_aug = nac.KeyboardAug()
        # Sentence-level augmenter.
        self.synonym_sentence_aug = nas.ContextualWordEmbsForSentenceAug(
            model_path='bert-base-uncased'
        )

    def augment_word_level(self, text, method='synonym', num_augmentations=1):
        """Word-level augmentation; an unknown method returns the text unchanged."""
        dispatch = {
            'synonym': lambda: self.synonym_aug.augment(text),
            'contextual': lambda: self.contextual_aug.augment(text),
            'insert': lambda: self.insertion_aug.augment(text),
            'swap': lambda: self.swap_aug.augment(text),
            'delete': lambda: self.deletion_aug.augment(text),
        }
        make = dispatch.get(method, lambda: text)
        return [make() for _ in range(num_augmentations)]

    def augment_char_level(self, text, method='ocr', num_augmentations=1):
        """Character-level augmentation (OCR or keyboard noise)."""
        dispatch = {
            'ocr': lambda: self.ocr_aug.augment(text),
            'keyboard': lambda: self.keyboard_aug.augment(text),
        }
        make = dispatch.get(method, lambda: text)
        return [make() for _ in range(num_augmentations)]

    def augment_sentence_level(self, text, num_augmentations=1):
        """Sentence-level augmentation via contextual word embeddings."""
        return [self.synonym_sentence_aug.augment(text)
                for _ in range(num_augmentations)]
# Usage
augmentor = NLPAugAugmentor()
# Word-level augmentation
augmented_texts = augmentor.augment_word_level(
    "This is a sample text.",
    method='synonym',
    num_augmentations=3
)
# Character-level augmentation
augmented_texts = augmentor.augment_char_level(
    "This is a sample text.",
    method='ocr',
    num_augmentations=3
)
音频增强
import numpy as np
import librosa
import soundfile as sf
from scipy.signal import butter, lfilter
class AudioAugmentor:
    """Audio augmentation for speech and music.

    Waveform methods (noise, shift, pitch, speed, reverb) operate on 1-D
    signals; frequency_mask/time_mask operate on 2-D (freq, time)
    spectrograms, SpecAugment style.
    """

    def __init__(self, sample_rate=16000):
        # Sample rate used by pitch shifting and reverb timing.
        self.sample_rate = sample_rate

    def add_noise(self, audio, noise_factor=0.005):
        """Add white Gaussian noise scaled by `noise_factor`."""
        noise = np.random.randn(len(audio))
        return audio + noise_factor * noise

    def time_shift(self, audio, shift_max=0.2):
        """Circularly shift the signal by up to `shift_max` of its length."""
        shift = int(np.random.uniform(-shift_max, shift_max) * len(audio))
        return np.roll(audio, shift)

    def pitch_shift(self, audio, n_steps=2):
        """Shift pitch by `n_steps` semitones (duration preserved)."""
        return librosa.effects.pitch_shift(audio, sr=self.sample_rate, n_steps=n_steps)

    def speed_change(self, audio, speed_factor=1.2):
        """Change playback speed (duration changes)."""
        return librosa.effects.time_stretch(audio, rate=speed_factor)

    def time_stretch(self, audio, rate=1.2):
        """Time-stretch the signal (changes duration; same as speed_change)."""
        return librosa.effects.time_stretch(audio, rate=rate)

    def add_reverb(self, audio, decay=0.5):
        """Add a simple exponentially decaying reverb via convolution."""
        delay = int(0.05 * self.sample_rate)  # 50 ms pre-delay
        impulse = np.zeros(delay + int(decay * self.sample_rate))
        impulse[delay] = 1
        impulse[delay:] *= np.exp(-np.arange(len(impulse) - delay) / (decay * self.sample_rate))
        return np.convolve(audio, impulse)[:len(audio)]

    def frequency_mask(self, audio, freq_mask_param=10):
        """Zero a random horizontal band of a (freq, time) spectrogram.

        `freq_mask_param` is the maximum mask width in frequency bins.
        FIX: the mask start now ranges over the whole frequency axis;
        the old code drew it from [0, freq_mask_param - width], which
        confined every mask to the first few bins regardless of the
        spectrogram's height.
        """
        mask_width = np.random.randint(0, freq_mask_param + 1)
        f0 = np.random.randint(0, max(audio.shape[0] - mask_width, 0) + 1)
        augmented = audio.copy()
        augmented[f0:f0 + mask_width, :] = 0
        return augmented

    def time_mask(self, audio, time_mask_param=10):
        """Zero a random vertical band of a (freq, time) spectrogram.

        Same FIX as frequency_mask, applied to the time axis.
        """
        mask_width = np.random.randint(0, time_mask_param + 1)
        t0 = np.random.randint(0, max(audio.shape[1] - mask_width, 0) + 1)
        augmented = audio.copy()
        augmented[:, t0:t0 + mask_width] = 0
        return augmented

    def augment(self, audio, augmentations=None):
        """Apply a random subset of waveform augmentations.

        `augmentations` is a list of (name, probability) pairs; defaults
        to noise / time_shift / pitch_shift / speed_change.
        """
        if augmentations is None:
            augmentations = [
                ('noise', 0.3),
                ('time_shift', 0.3),
                ('pitch_shift', 0.2),
                ('speed_change', 0.2)
            ]
        augmented = audio.copy()
        for aug_name, prob in augmentations:
            if np.random.random() < prob:
                if aug_name == 'noise':
                    augmented = self.add_noise(augmented)
                elif aug_name == 'time_shift':
                    augmented = self.time_shift(augmented)
                elif aug_name == 'pitch_shift':
                    n_steps = np.random.randint(-2, 3)
                    augmented = self.pitch_shift(augmented, n_steps)
                elif aug_name == 'speed_change':
                    speed = np.random.uniform(0.8, 1.2)
                    augmented = self.speed_change(augmented, speed)
        return augmented
# Usage
augmentor = AudioAugmentor(sample_rate=16000)
augmented_audio = augmentor.augment(original_audio)
表格数据增强
import numpy as np
import pandas as pd
from imblearn.over_sampling import SMOTE, ADASYN
from imblearn.under_sampling import RandomUnderSampler
class TabularAugmentor:
    """Augmentation utilities for tabular data (resampling, noise, mixup)."""

    def __init__(self):
        # The fitted resamplers are kept for later inspection.
        self.smote = None
        self.adasyn = None

    def smote_augmentation(self, X, y, sampling_strategy='auto'):
        """Oversample minority classes with SMOTE."""
        self.smote = SMOTE(sampling_strategy=sampling_strategy, random_state=42)
        X_resampled, y_resampled = self.smote.fit_resample(X, y)
        return X_resampled, y_resampled

    def adasyn_augmentation(self, X, y, sampling_strategy='auto'):
        """Oversample minority classes with ADASYN."""
        self.adasyn = ADASYN(sampling_strategy=sampling_strategy, random_state=42)
        X_resampled, y_resampled = self.adasyn.fit_resample(X, y)
        return X_resampled, y_resampled

    def random_oversampling(self, X, y):
        """Duplicate minority-class samples at random."""
        from imblearn.over_sampling import RandomOverSampler
        ros = RandomOverSampler(random_state=42)
        X_resampled, y_resampled = ros.fit_resample(X, y)
        return X_resampled, y_resampled

    def random_undersampling(self, X, y, sampling_strategy='auto'):
        """Drop majority-class samples at random."""
        rus = RandomUnderSampler(sampling_strategy=sampling_strategy, random_state=42)
        X_resampled, y_resampled = rus.fit_resample(X, y)
        return X_resampled, y_resampled

    def gaussian_noise_augmentation(self, X, noise_level=0.01):
        """Return X with zero-mean Gaussian noise added to every feature."""
        noise = np.random.normal(0, noise_level, X.shape)
        return X + noise

    def feature_mixup(self, X, y, alpha=0.2, n_samples=100):
        """Create `n_samples` convex combinations of random sample pairs.

        FIX: the old code always assigned y[idx1], even when the mixing
        weight lam was near 0 and the mixed row was almost equal to
        X[idx2] — with Beta(0.2, 0.2) that happens roughly half the time.
        The label now follows the dominant component.
        """
        X_augmented = []
        y_augmented = []
        for _ in range(n_samples):
            # Pick two distinct rows to blend.
            idx1, idx2 = np.random.choice(len(X), 2, replace=False)
            lam = np.random.beta(alpha, alpha)
            X_augmented.append(lam * X[idx1] + (1 - lam) * X[idx2])
            # Label of the sample with the larger mixing weight.
            y_augmented.append(y[idx1] if lam >= 0.5 else y[idx2])
        return np.array(X_augmented), np.array(y_augmented)

    def bootstrap_sampling(self, X, y, n_samples=None):
        """Sample (X, y) with replacement; defaults to the original size."""
        if n_samples is None:
            n_samples = len(X)
        indices = np.random.choice(len(X), n_samples, replace=True)
        return X[indices], y[indices]
# Usage
augmentor = TabularAugmentor()
# SMOTE augmentation
X_augmented, y_augmented = augmentor.smote_augmentation(X_train, y_train)
# Gaussian-noise augmentation
X_noisy = augmentor.gaussian_noise_augmentation(X_train, noise_level=0.01)
# Feature mixup
X_mixup, y_mixup = augmentor.feature_mixup(X_train, y_train, n_samples=100)
增强策略
在线 vs 离线增强
class OnlineAugmentation:
    """Apply augmentation on the fly (online), batch by batch."""

    def __init__(self, augmentor):
        # Callable applied to every item of a batch.
        self.augmentor = augmentor

    def __call__(self, batch):
        """Return a new list with the augmentor applied to each item."""
        return [self.augmentor(item) for item in batch]
# 使用 PyTorch DataLoader
from torch.utils.data import Dataset, DataLoader
class OnlineAugmentedDataset(Dataset):
    """PyTorch dataset that augments samples lazily in __getitem__."""

    def __init__(self, data, labels, augmentor):
        self.data = data
        self.labels = labels
        # Runs on every access, so each epoch sees fresh variants.
        self.augmentor = augmentor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Augment at access time (online augmentation).
        return self.augmentor(self.data[idx]), self.labels[idx]
# Create a dataset with online (on-access) augmentation
dataset = OnlineAugmentedDataset(X_train, y_train, augmentor)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
class OfflineAugmentation:
    """Precompute augmented copies of a dataset (offline augmentation)."""

    def __init__(self, augmentor, augmentations_per_sample=2):
        self.augmentor = augmentor
        # Number of augmented copies generated per original sample.
        self.augmentations_per_sample = augmentations_per_sample

    def augment_dataset(self, X, y):
        """Return arrays holding each original sample followed by its copies."""
        out_X, out_y = [], []
        for i in range(len(X)):
            # Keep the original sample...
            out_X.append(X[i])
            out_y.append(y[i])
            # ...then append its augmented variants with the same label.
            for _ in range(self.augmentations_per_sample):
                out_X.append(self.augmentor(X[i]))
                out_y.append(y[i])
        return np.array(out_X), np.array(out_y)
# Usage
offline_augmentor = OfflineAugmentation(augmentor, augmentations_per_sample=2)
X_augmented, y_augmented = offline_augmentor.augment_dataset(X_train, y_train)
概率设置
class ProbabilisticAugmentor:
    """Chain of augmentations, each applied with its own probability."""

    def __init__(self):
        # List of (callable, probability) pairs, applied in insertion order.
        self.augmentations = []

    def add_augmentation(self, augmentor, probability=0.5):
        """Register `augmentor` to fire with the given probability."""
        self.augmentations.append((augmentor, probability))

    def augment(self, data):
        """Run the chain; each step fires independently at random."""
        result = data
        for fn, prob in self.augmentations:
            if np.random.random() < prob:
                result = fn(result)
        return result
# Usage
prob_augmentor = ProbabilisticAugmentor()
prob_augmentor.add_augmentation(horizontal_flip, probability=0.5)
prob_augmentor.add_augmentation(rotation, probability=0.3)
prob_augmentor.add_augmentation(color_jitter, probability=0.4)
augmented_data = prob_augmentor.augment(original_data)
自定义增强
import numpy as np
from PIL import Image, ImageFilter, ImageEnhance
class CustomAugmentor:
    """Custom augmentation primitives: cutout, mixup, cutmix, mosaic."""

    @staticmethod
    def _to_array(image):
        """Coerce PIL images (or any array-like) to a numpy array."""
        # np.array converts PIL images and passes ndarrays through,
        # so the methods below no longer depend on PIL directly.
        if isinstance(image, np.ndarray):
            return image
        return np.array(image)

    @staticmethod
    def cutout(image, n_holes=8, max_h_size=16, max_w_size=16):
        """Zero out `n_holes` random rectangles and return a new image.

        FIX: the input is copied first — the old version silently
        mutated the caller's array in place.
        """
        image = CustomAugmentor._to_array(image).copy()
        h, w = image.shape[:2]
        for _ in range(n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)
            y1 = np.clip(y - max_h_size // 2, 0, h)
            y2 = np.clip(y + max_h_size // 2, 0, h)
            x1 = np.clip(x - max_w_size // 2, 0, w)
            x2 = np.clip(x + max_w_size // 2, 0, w)
            # Zeroing the leading two axes works for both HW and HWC images.
            image[y1:y2, x1:x2] = 0
        return image

    @staticmethod
    def mixup(image1, image2, alpha=0.2):
        """Blend two images with a Beta(alpha, alpha) mixing weight."""
        image1 = CustomAugmentor._to_array(image1)
        image2 = CustomAugmentor._to_array(image2)
        lam = np.random.beta(alpha, alpha)
        mixed = lam * image1 + (1 - lam) * image2
        return mixed.astype(np.uint8)

    @staticmethod
    def cutmix(image1, image2, alpha=1.0):
        """Paste a random rectangle of image2 into a copy of image1."""
        image1 = CustomAugmentor._to_array(image1)
        image2 = CustomAugmentor._to_array(image2)
        h, w = image1.shape[:2]
        # Box area is proportional to (1 - lam), as in the CutMix paper.
        lam = np.random.beta(alpha, alpha)
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(w * cut_rat)
        cut_h = int(h * cut_rat)
        cx = np.random.randint(w)
        cy = np.random.randint(h)
        bbx1 = np.clip(cx - cut_w // 2, 0, w)
        bby1 = np.clip(cy - cut_h // 2, 0, h)
        bbx2 = np.clip(cx + cut_w // 2, 0, w)
        bby2 = np.clip(cy + cut_h // 2, 0, h)
        mixed = image1.copy()
        mixed[bby1:bby2, bbx1:bbx2] = image2[bby1:bby2, bbx1:bbx2]
        return mixed

    @staticmethod
    def mosaic(images):
        """Tile exactly four equally-sized images into a 2x2 mosaic.

        FIX: the output shape is derived from the input's full shape,
        so grayscale (HW) images no longer crash on a missing channel
        axis.
        """
        if len(images) != 4:
            raise ValueError("Mosaic 需要恰好 4 张图像")
        images = [CustomAugmentor._to_array(img) for img in images]
        h, w = images[0].shape[:2]
        out_shape = (h * 2, w * 2) + images[0].shape[2:]
        mosaic = np.zeros(out_shape, dtype=images[0].dtype)
        # Place the images in a 2x2 grid.
        mosaic[:h, :w] = images[0]
        mosaic[:h, w:] = images[1]
        mosaic[h:, :w] = images[2]
        mosaic[h:, w:] = images[3]
        return mosaic
# Usage
# Cutout
cutout_image = CustomAugmentor.cutout(image, n_holes=8, max_h_size=32)
# Mixup
mixed_image = CustomAugmentor.mixup(image1, image2, alpha=0.2)
# CutMix
cutmix_image = CustomAugmentor.cutmix(image1, image2, alpha=1.0)
# Mosaic
mosaic_image = CustomAugmentor.mosaic([img1, img2, img3, img4])
验证集处理
class ValidationAugmentation:
    """Test-time augmentation (TTA) helpers for validation/inference."""

    def __init__(self, augmentor):
        self.augmentor = augmentor

    def augment_with_test_time_augmentation(self, data, n_augmentations=5):
        """Return `n_augmentations` augmented variants of `data`."""
        return [self.augmentor(data) for _ in range(n_augmentations)]

    def average_predictions(self, predictions):
        """Average the per-variant predictions along the first axis."""
        return np.mean(predictions, axis=0)
# Usage
val_augmentor = ValidationAugmentation(augmentor)
# Test-time augmentation
augmented_samples = val_augmentor.augment_with_test_time_augmentation(
    validation_sample,
    n_augmentations=5
)
# Predict on each augmented sample
predictions = [model.predict(sample) for sample in augmented_samples]
# Average the predictions
final_prediction = val_augmentor.average_predictions(predictions)
生产考虑
高效增强
import multiprocessing as mp
from functools import partial
class EfficientAugmentor:
    """Parallel augmentation using a multiprocessing pool.

    NOTE: the augmentor callable must be picklable (a module-level
    function), since work is shipped to worker processes.
    """

    def __init__(self, augmentor, n_workers=None):
        self.augmentor = augmentor
        # Default to one worker per CPU core.
        self.n_workers = n_workers or mp.cpu_count()

    def augment_batch(self, batch):
        """Augment one batch in parallel (spins up a short-lived pool)."""
        with mp.Pool(self.n_workers) as pool:
            return pool.map(self.augmentor, batch)

    def augment_dataset(self, dataset, batch_size=32):
        """Augment a whole dataset efficiently.

        FIX: reuse a single pool for the entire dataset instead of
        creating and tearing down one pool per batch, which dominated
        the runtime for small batch sizes.
        """
        augmented_data = []
        with mp.Pool(self.n_workers) as pool:
            for i in range(0, len(dataset), batch_size):
                batch = dataset[i:i + batch_size]
                augmented_data.extend(pool.map(self.augmentor, batch))
        return augmented_data
# Usage
efficient_augmentor = EfficientAugmentor(augmentor, n_workers=4)
augmented_data = efficient_augmentor.augment_dataset(X_train, batch_size=32)
增强缓存
import hashlib
import pickle
from pathlib import Path
class CachedAugmentor:
    """Augmentor wrapper that caches results on disk for reproducibility."""

    def __init__(self, augmentor, cache_dir='./augmentation_cache'):
        self.augmentor = augmentor
        self.cache_dir = Path(cache_dir)
        # Create the cache directory up front (idempotent).
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _get_cache_key(self, data):
        """Derive a stable MD5 key from the input's byte representation."""
        if isinstance(data, np.ndarray):
            payload = data.tobytes()
        elif isinstance(data, str):
            payload = data.encode()
        elif isinstance(data, bytes):
            payload = data
        else:
            # Fallback: serialize arbitrary objects.
            payload = pickle.dumps(data)
        return hashlib.md5(payload).hexdigest()

    def augment(self, data, use_cache=True):
        """Return the cached augmentation if present, else compute and store it."""
        cache_file = self.cache_dir / f"{self._get_cache_key(data)}.pkl"
        if use_cache and cache_file.exists():
            with open(cache_file, 'rb') as f:
                return pickle.load(f)
        augmented = self.augmentor(data)
        if use_cache:
            with open(cache_file, 'wb') as f:
                pickle.dump(augmented, f)
        return augmented
# Usage
cached_augmentor = CachedAugmentor(augmentor)
augmented_data = cached_augmentor.augment(original_data)
常见模式
AutoAugment
import random
class AutoAugment:
    """AutoAugment-style policy sampler.

    NOTE(review): `_apply_transform` dispatches to per-op helpers
    (e.g. `_rotate`, `_translate_x`) that are elided in this guide and
    must be supplied by the user before those ops can run.
    """

    def __init__(self):
        # (operation name, magnitude, application probability)
        self.policies = [
            ('rotate', 30, 0.5),
            ('translate_x', 0.1, 0.5),
            ('translate_y', 0.1, 0.5),
            ('shear_x', 0.1, 0.5),
            ('shear_y', 0.1, 0.5),
            ('contrast', 0.3, 0.5),
            ('brightness', 0.3, 0.5),
            ('sharpness', 0.3, 0.5),
            ('posterize', 4, 0.5),
            ('solarize', 256, 0.5),
        ]

    def apply_policy(self, image):
        """Sample two random policies and apply each with its probability."""
        augmented = image
        for name, magnitude, prob in random.sample(self.policies, 2):
            if random.random() < prob:
                augmented = self._apply_transform(augmented, name, magnitude)
        return augmented

    def _apply_transform(self, image, transform_name, magnitude):
        """Dispatch to the concrete transform implementation."""
        if transform_name == 'rotate':
            return self._rotate(image, magnitude)
        elif transform_name == 'translate_x':
            return self._translate_x(image, magnitude)
        # ... remaining transforms elided in this guide
        return image
# Usage
autoaugment = AutoAugment()
augmented_image = autoaugment.apply_policy(original_image)
RandAugment
class RandAugment:
    """RandAugment policy: apply n random ops at a shared magnitude m.

    NOTE(review): only `_rotate` and `_contrast` are shown here; the
    other methods referenced in `self.augmentations` (and the
    `_rotate_image` / `_adjust_contrast` helpers) are elided in this
    guide and must be supplied before the class can be instantiated.
    """
    def __init__(self, n=2, m=10):
        self.n = n  # number of augmentations applied per call
        self.m = m  # shared magnitude for every augmentation
        self.augmentations = [
            self._rotate,
            self._translate_x,
            self._translate_y,
            self._shear_x,
            self._shear_y,
            self._contrast,
            self._brightness,
            self._sharpness,
        ]

    def __call__(self, image):
        """Apply RandAugment to an image."""
        # Choose n augmentations (random.choices samples with replacement).
        selected = random.choices(self.augmentations, k=self.n)
        augmented = image
        for aug in selected:
            augmented = aug(augmented, self.m)
        return augmented

    def _rotate(self, image, magnitude):
        """Rotate by a random angle in [-magnitude, magnitude] degrees."""
        angle = random.uniform(-magnitude, magnitude)
        return self._rotate_image(image, angle)

    def _contrast(self, image, magnitude):
        """Scale contrast by a random factor around 1.0."""
        factor = random.uniform(1 - magnitude/10, 1 + magnitude/10)
        return self._adjust_contrast(image, factor)
    # ... remaining augmentation methods elided
# Usage
randaugment = RandAugment(n=2, m=10)
augmented_image = randaugment(original_image)
最佳实践
- 从简单开始
- 从基本增强开始(翻转、旋转)
- 逐渐增加复杂性
- 监控对模型性能的影响
- 使用适当的增强
- 分类:更激进的增强
- 检测:小心空间增强(需要调整边界框)
- 分割:对图像和掩码应用相同的增强
- 不要增强验证/测试集
- 只应用归一化,不应用增强
- 使用测试时增强 (TTA) 进行推理
- 监控性能
- 跟踪有和无增强的训练/验证损失
- 使用可视化验证增强
- 检查增强后的标签保留
- 使用基于概率的应用
- 以适当的概率应用增强
- 避免过度增强导致数据失真
- 平衡多样性和数据质量
- 处理类别不平衡
- 对表格数据使用 SMOTE 或 ADASYN
- 对少数类别应用更多增强
- 考虑加权采样
- 优化性能
- 使用多进程进行批次增强
- 尽可能缓存增强样本
- 使用高效库如 Albumentations
- 调试增强
- 可视化增强样本
- 检查标签保留
- 验证增强管道的正确性
- 考虑任务特定需求
- 医学影像:有限的增强
- 卫星影像:旋转不变增强
- 文本:保留语义含义
- 可重复性
- 设置随机种子以获得一致结果
- 缓存增强以进行调试
- 文档化增强管道