Data Augmentation

Data augmentation is a machine learning technique that artificially increases the size and diversity of a training dataset by creating modified versions of existing data, improving a model's ability to generalize. It covers augmentation methods for image, text, audio, and tabular data, including geometric transforms, synonym replacement, pitch shifting, and oversampling. Keywords: data augmentation, machine learning, training data, image processing, text processing, audio processing, tabular data, augmentation strategies, SMOTE, back translation.

Machine Learning · 0 installs · 0 views · Updated 3/5/2026

Name: Data Augmentation. Description: A comprehensive guide to data augmentation techniques for image, text, audio, and tabular data.

Data Augmentation

Overview

Data augmentation is a technique that artificially increases the size and diversity of a training dataset by creating modified versions of existing data. This skill covers augmentation techniques for image, text, audio, and tabular data, including popular libraries such as Albumentations and NLPAug as well as custom augmentation strategies.

Prerequisites

  • Understanding of machine learning concepts
  • Knowledge of Python programming
  • Familiarity with data preprocessing
  • Understanding of overfitting and generalization
  • Basics of image, text, and audio processing

Key Concepts

Augmentation Types

  • Image augmentation: geometric transforms, color adjustments, noise injection
  • Text augmentation: back translation, synonym replacement, word insertion/deletion
  • Audio augmentation: time stretching, pitch shifting, noise addition, masking
  • Tabular augmentation: SMOTE, ADASYN, Gaussian noise, feature mixup

Augmentation Strategies

  • Online augmentation: apply augmentations during training
  • Offline augmentation: precompute augmented samples
  • Test-time augmentation (TTA): apply multiple augmentations at inference time
  • AutoAugment: automatically search for optimal augmentation policies

Common Libraries

  • Albumentations: fast and flexible image augmentation library
  • Torchvision: PyTorch's built-in transforms
  • ImgAug: powerful image augmentation library
  • NLPAug: text augmentation library
  • Imbalanced-learn: tabular data augmentation (SMOTE, ADASYN)

Implementation Guide

Image Augmentation

Albumentations

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2
import numpy as np

class AlbumentationsAugmentor:
    """Image augmentation with Albumentations."""

    def __init__(self, mode='train', image_size=(224, 224)):
        self.mode = mode
        self.image_size = image_size
        self.transform = self._get_transform()

    def _get_transform(self):
        """Build the transform pipeline."""
        if self.mode == 'train':
            return A.Compose([
                A.Resize(height=self.image_size[0], width=self.image_size[1]),
                A.HorizontalFlip(p=0.5),
                A.VerticalFlip(p=0.2),
                A.RandomRotate90(p=0.5),
                A.Rotate(limit=30, p=0.5),
                A.ShiftScaleRotate(
                    shift_limit=0.1,
                    scale_limit=0.1,
                    rotate_limit=15,
                    p=0.5
                ),
                A.OneOf([
                    A.GaussNoise(p=1.0),
                    A.ISONoise(p=1.0),
                    A.MultiplicativeNoise(p=1.0),
                ], p=0.2),
                A.OneOf([
                    A.MotionBlur(p=1.0),
                    A.MedianBlur(p=1.0),
                    A.GaussianBlur(p=1.0),
                ], p=0.2),
                A.OneOf([
                    A.OpticalDistortion(p=1.0),
                    A.GridDistortion(p=1.0),
                    A.PiecewiseAffine(p=1.0),
                ], p=0.2),
                A.OneOf([
                    A.CLAHE(clip_limit=2),
                    A.Sharpen(),
                    A.Emboss(),
                    A.RandomBrightnessContrast(),
                ], p=0.3),
                A.HueSaturationValue(p=0.3),
                A.RandomBrightnessContrast(p=0.3),
                A.CoarseDropout(
                    max_holes=8,
                    max_height=32,
                    max_width=32,
                    min_holes=1,
                    min_height=8,
                    min_width=8,
                    p=0.3
                ),
                A.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
                ToTensorV2()
            ])
        else:
            return A.Compose([
                A.Resize(height=self.image_size[0], width=self.image_size[1]),
                A.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                ),
                ToTensorV2()
            ])

    def __call__(self, image):
        """Apply the augmentation."""
        return self.transform(image=image)['image']

# Usage
train_augmentor = AlbumentationsAugmentor(mode='train', image_size=(224, 224))
val_augmentor = AlbumentationsAugmentor(mode='val', image_size=(224, 224))

# Apply augmentation
augmented_image = train_augmentor(original_image)

Torchvision Transforms

import torchvision.transforms as transforms

class TorchvisionAugmentor:
    """Image augmentation with torchvision."""

    def __init__(self, mode='train', image_size=224):
        self.mode = mode
        self.image_size = image_size
        self.transform = self._get_transform()

    def _get_transform(self):
        """Build the transform pipeline."""
        if self.mode == 'train':
            return transforms.Compose([
                transforms.Resize((self.image_size, self.image_size)),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.2),
                transforms.RandomRotation(degrees=30),
                transforms.RandomAffine(
                    degrees=0,
                    translate=(0.1, 0.1),
                    scale=(0.9, 1.1),
                    shear=10
                ),
                transforms.ColorJitter(
                    brightness=0.3,
                    contrast=0.3,
                    saturation=0.3,
                    hue=0.1
                ),
                transforms.RandomGrayscale(p=0.1),
                transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
                transforms.RandomResizedCrop(
                    self.image_size,
                    scale=(0.8, 1.0),
                    ratio=(0.9, 1.1)
                ),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])
        else:
            return transforms.Compose([
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ])

    def __call__(self, image):
        """Apply the augmentation."""
        return self.transform(image)

# Usage
train_transform = TorchvisionAugmentor(mode='train', image_size=224)
augmented_image = train_transform(pil_image)

ImgAug

import imgaug as ia
import imgaug.augmenters as iaa

class ImgAugAugmentor:
    """Image augmentation with imgaug."""

    def __init__(self, mode='train'):
        self.mode = mode
        self.augmenter = self._get_augmenter()

    def _get_augmenter(self):
        """Build the augmentation pipeline."""
        if self.mode == 'train':
            return iaa.Sequential([
                iaa.Fliplr(0.5),  # horizontal flip
                iaa.Flipud(0.2),  # vertical flip
                iaa.Affine(
                    rotate=(-30, 30),
                    scale=(0.9, 1.1),
                    translate_percent=(-0.1, 0.1)
                ),
                iaa.Multiply((0.8, 1.2)),  # brightness
                iaa.LinearContrast((0.8, 1.2)),  # contrast
                iaa.AdditiveGaussianNoise(scale=(0, 0.05 * 255)),
                iaa.GaussianBlur(sigma=(0, 1.0)),
                iaa.Dropout(p=(0, 0.2)),
                iaa.CoarseDropout(
                    (0.0, 0.05),
                    size_percent=(0.02, 0.05)
                ),
                iaa.Crop(percent=(0, 0.1)),
                iaa.Pad(percent=(0, 0.1)),
                iaa.ElasticTransformation(alpha=(0, 50), sigma=(0, 5)),
                iaa.PiecewiseAffine(scale=(0.01, 0.05)),
            ])
        else:
            return iaa.Sequential([])

    def __call__(self, image):
        """Apply the augmentation."""
        if self.mode == 'train':
            image_aug = self.augmenter.augment_image(image)
            return image_aug
        return image

# Usage
train_augmentor = ImgAugAugmentor(mode='train')
augmented_image = train_augmentor(numpy_image)

Text Augmentation

Back Translation

from deep_translator import GoogleTranslator
import random

class BackTranslationAugmentor:
    """Text augmentation via back translation."""

    def __init__(self, languages=['fr', 'de', 'es']):
        self.languages = languages

    def augment(self, text, num_augmentations=1):
        """Augment text using back translation."""
        augmented_texts = []

        for _ in range(num_augmentations):
            # Pick a random intermediate language
            intermediate_lang = random.choice(self.languages)

            try:
                # Translate to the intermediate language
                translated = GoogleTranslator(
                    source='auto',
                    target=intermediate_lang
                ).translate(text)

                # Translate back to English
                back_translated = GoogleTranslator(
                    source=intermediate_lang,
                    target='en'
                ).translate(translated)

                augmented_texts.append(back_translated)
            except Exception as e:
                print(f"Back translation failed: {e}")
                augmented_texts.append(text)

        return augmented_texts

# Usage
augmentor = BackTranslationAugmentor(languages=['fr', 'de', 'es'])
augmented_texts = augmentor.augment("This is a sample text for augmentation.")

Synonym Replacement

import nltk
from nltk.corpus import wordnet
import random

nltk.download('wordnet')
nltk.download('omw-1.4')

class SynonymReplacementAugmentor:
    """Text augmentation via synonym replacement."""

    def __init__(self, replacement_prob=0.3):
        self.replacement_prob = replacement_prob

    def get_synonyms(self, word):
        """Get synonyms for a word."""
        synonyms = set()
        for syn in wordnet.synsets(word):
            for lemma in syn.lemmas():
                synonym = lemma.name().replace('_', ' ')
                if synonym.lower() != word.lower():
                    synonyms.add(synonym)
        return list(synonyms)

    def augment(self, text):
        """Augment text using synonym replacement."""
        words = text.split()
        augmented_words = []

        for word in words:
            if random.random() < self.replacement_prob:
                synonyms = self.get_synonyms(word)
                if synonyms:
                    augmented_words.append(random.choice(synonyms))
                else:
                    augmented_words.append(word)
            else:
                augmented_words.append(word)

        return ' '.join(augmented_words)

# Usage
augmentor = SynonymReplacementAugmentor(replacement_prob=0.3)
augmented_text = augmentor.augment("The quick brown fox jumps over lazy dog")

NLPAug

import nlpaug.augmenter.word as naw
import nlpaug.augmenter.char as nac
import nlpaug.augmenter.sentence as nas

class NLPAugAugmentor:
    """Text augmentation with NLPAug."""

    def __init__(self):
        # Word-level augmenters
        self.synonym_aug = naw.SynonymAug(aug_src='wordnet')
        self.contextual_aug = naw.ContextualWordEmbsAug(model_path='bert-base-uncased')
        self.back_translation_aug = naw.BackTranslationAug(
            from_model_name='facebook/wmt19-en-de',
            to_model_name='facebook/wmt19-de-en'
        )
        self.insertion_aug = naw.RandomWordAug(action="insert")
        self.swap_aug = naw.RandomWordAug(action="swap")
        self.deletion_aug = naw.RandomWordAug(action="delete")

        # Character-level augmenters
        self.ocr_aug = nac.OcrAug()
        self.keyboard_aug = nac.KeyboardAug()

        # Sentence-level augmenters
        self.synonym_sentence_aug = nas.ContextualWordEmbsForSentenceAug(
            model_path='bert-base-uncased'
        )

    def augment_word_level(self, text, method='synonym', num_augmentations=1):
        """Word-level augmentation."""
        augmented_texts = []

        for _ in range(num_augmentations):
            if method == 'synonym':
                augmented = self.synonym_aug.augment(text)
            elif method == 'contextual':
                augmented = self.contextual_aug.augment(text)
            elif method == 'insert':
                augmented = self.insertion_aug.augment(text)
            elif method == 'swap':
                augmented = self.swap_aug.augment(text)
            elif method == 'delete':
                augmented = self.deletion_aug.augment(text)
            else:
                augmented = text

            augmented_texts.append(augmented)

        return augmented_texts

    def augment_char_level(self, text, method='ocr', num_augmentations=1):
        """Character-level augmentation."""
        augmented_texts = []

        for _ in range(num_augmentations):
            if method == 'ocr':
                augmented = self.ocr_aug.augment(text)
            elif method == 'keyboard':
                augmented = self.keyboard_aug.augment(text)
            else:
                augmented = text

            augmented_texts.append(augmented)

        return augmented_texts

    def augment_sentence_level(self, text, num_augmentations=1):
        """Sentence-level augmentation."""
        augmented_texts = []

        for _ in range(num_augmentations):
            augmented = self.synonym_sentence_aug.augment(text)
            augmented_texts.append(augmented)

        return augmented_texts

# Usage
augmentor = NLPAugAugmentor()

# Word-level augmentation
augmented_texts = augmentor.augment_word_level(
    "This is a sample text.",
    method='synonym',
    num_augmentations=3
)

# Character-level augmentation
augmented_texts = augmentor.augment_char_level(
    "This is a sample text.",
    method='ocr',
    num_augmentations=3
)

Audio Augmentation

import numpy as np
import librosa
import soundfile as sf

class AudioAugmentor:
    """Audio augmentation for speech and music."""

    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate

    def add_noise(self, audio, noise_factor=0.005):
        """Add Gaussian noise."""
        noise = np.random.randn(len(audio))
        augmented = audio + noise_factor * noise
        return augmented

    def time_shift(self, audio, shift_max=0.2):
        """Randomly shift the audio in time."""
        shift = int(np.random.uniform(-shift_max, shift_max) * len(audio))
        return np.roll(audio, shift)

    def pitch_shift(self, audio, n_steps=2):
        """Shift the pitch."""
        return librosa.effects.pitch_shift(audio, sr=self.sample_rate, n_steps=n_steps)

    def speed_change(self, audio, speed_factor=1.2):
        """Change the playback speed."""
        return librosa.effects.time_stretch(audio, rate=speed_factor)

    def time_stretch(self, audio, rate=1.2):
        """Time-stretch (changes duration)."""
        return librosa.effects.time_stretch(audio, rate=rate)

    def add_reverb(self, audio, decay=0.5):
        """Add reverb."""
        delay = int(0.05 * self.sample_rate)
        impulse = np.zeros(delay + int(decay * self.sample_rate))
        impulse[delay] = 1
        impulse[delay:] *= np.exp(-np.arange(len(impulse) - delay) / (decay * self.sample_rate))

        augmented = np.convolve(audio, impulse)[:len(audio)]
        return augmented

    def frequency_mask(self, audio, freq_mask_param=10):
        """Frequency masking (expects a 2-D spectrogram, freq x time)."""
        freq_mask = np.random.randint(0, freq_mask_param + 1)
        f0 = np.random.uniform(0, freq_mask_param - freq_mask)
        f0 = int(f0)
        augmented = audio.copy()
        augmented[f0:f0 + freq_mask, :] = 0
        return augmented

    def time_mask(self, audio, time_mask_param=10):
        """Time masking (expects a 2-D spectrogram, freq x time)."""
        time_mask = np.random.randint(0, time_mask_param + 1)
        t0 = np.random.uniform(0, time_mask_param - time_mask)
        t0 = int(t0)
        augmented = audio.copy()
        augmented[:, t0:t0 + time_mask] = 0
        return augmented

    def augment(self, audio, augmentations=None):
        """Apply random augmentations."""
        if augmentations is None:
            augmentations = [
                ('noise', 0.3),
                ('time_shift', 0.3),
                ('pitch_shift', 0.2),
                ('speed_change', 0.2)
            ]

        augmented = audio.copy()

        for aug_name, prob in augmentations:
            if np.random.random() < prob:
                if aug_name == 'noise':
                    augmented = self.add_noise(augmented)
                elif aug_name == 'time_shift':
                    augmented = self.time_shift(augmented)
                elif aug_name == 'pitch_shift':
                    n_steps = np.random.randint(-2, 3)
                    augmented = self.pitch_shift(augmented, n_steps)
                elif aug_name == 'speed_change':
                    speed = np.random.uniform(0.8, 1.2)
                    augmented = self.speed_change(augmented, speed)

        return augmented

# Usage
augmentor = AudioAugmentor(sample_rate=16000)
augmented_audio = augmentor.augment(original_audio)
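Note that `frequency_mask` and `time_mask` above expect a 2-D spectrogram (frequency × time), not a raw waveform. A minimal NumPy-only sketch of the same SpecAugment-style masking, assuming the mask offset should be bounded by the axis size rather than by the mask parameter:

```python
import numpy as np

def spec_augment(spec, freq_mask_param=10, time_mask_param=10):
    """Zero out one random frequency band and one random time band."""
    spec = spec.copy()
    n_freq, n_time = spec.shape

    # Frequency mask: width in [0, freq_mask_param], start anywhere it fits
    f = np.random.randint(0, freq_mask_param + 1)
    f0 = np.random.randint(0, max(1, n_freq - f))
    spec[f0:f0 + f, :] = 0

    # Time mask: same idea along the time axis
    t = np.random.randint(0, time_mask_param + 1)
    t0 = np.random.randint(0, max(1, n_time - t))
    spec[:, t0:t0 + t] = 0
    return spec

spectrogram = np.random.rand(128, 200)  # e.g. a mel spectrogram (hypothetical input)
masked = spec_augment(spectrogram)
```

The input is left unmodified because the function operates on a copy.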

Tabular Data Augmentation

import numpy as np
import pandas as pd
from imblearn.over_sampling import SMOTE, ADASYN
from imblearn.under_sampling import RandomUnderSampler

class TabularAugmentor:
    """Augmentation for tabular data."""

    def __init__(self):
        self.smote = None
        self.adasyn = None

    def smote_augmentation(self, X, y, sampling_strategy='auto'):
        """SMOTE oversampling."""
        self.smote = SMOTE(sampling_strategy=sampling_strategy, random_state=42)
        X_resampled, y_resampled = self.smote.fit_resample(X, y)
        return X_resampled, y_resampled

    def adasyn_augmentation(self, X, y, sampling_strategy='auto'):
        """ADASYN oversampling."""
        self.adasyn = ADASYN(sampling_strategy=sampling_strategy, random_state=42)
        X_resampled, y_resampled = self.adasyn.fit_resample(X, y)
        return X_resampled, y_resampled

    def random_oversampling(self, X, y):
        """Random oversampling."""
        from imblearn.over_sampling import RandomOverSampler
        ros = RandomOverSampler(random_state=42)
        X_resampled, y_resampled = ros.fit_resample(X, y)
        return X_resampled, y_resampled

    def random_undersampling(self, X, y, sampling_strategy='auto'):
        """Random undersampling."""
        rus = RandomUnderSampler(sampling_strategy=sampling_strategy, random_state=42)
        X_resampled, y_resampled = rus.fit_resample(X, y)
        return X_resampled, y_resampled

    def gaussian_noise_augmentation(self, X, noise_level=0.01):
        """Add Gaussian noise to the features."""
        noise = np.random.normal(0, noise_level, X.shape)
        X_augmented = X + noise
        return X_augmented

    def feature_mixup(self, X, y, alpha=0.2, n_samples=100):
        """Feature mixup augmentation."""
        X_augmented = []
        y_augmented = []

        for _ in range(n_samples):
            # Sample two random indices
            idx1, idx2 = np.random.choice(len(X), 2, replace=False)

            # Mix the features
            lam = np.random.beta(alpha, alpha)
            x_mixed = lam * X[idx1] + (1 - lam) * X[idx2]

            X_augmented.append(x_mixed)
            y_augmented.append(y[idx1])  # keep the first sample's (hard) label

        return np.array(X_augmented), np.array(y_augmented)

    def bootstrap_sampling(self, X, y, n_samples=None):
        """Bootstrap sampling."""
        if n_samples is None:
            n_samples = len(X)

        indices = np.random.choice(len(X), n_samples, replace=True)
        X_bootstrapped = X[indices]
        y_bootstrapped = y[indices]

        return X_bootstrapped, y_bootstrapped

# Usage
augmentor = TabularAugmentor()

# SMOTE augmentation
X_augmented, y_augmented = augmentor.smote_augmentation(X_train, y_train)

# Gaussian noise augmentation
X_noisy = augmentor.gaussian_noise_augmentation(X_train, noise_level=0.01)

# Feature mixup
X_mixup, y_mixup = augmentor.feature_mixup(X_train, y_train, n_samples=100)
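The `feature_mixup` method above keeps only the first sample's hard label; classic mixup also interpolates the labels with the same mixing coefficient. A hedged NumPy sketch using one-hot (soft) labels, with illustrative function names that are not part of any library:

```python
import numpy as np

def mixup_soft(X, y_onehot, alpha=0.2, n_samples=100, rng=None):
    """Mixup that interpolates both the features and the one-hot labels."""
    rng = rng or np.random.default_rng(0)
    X_mix, y_mix = [], []
    for _ in range(n_samples):
        i, j = rng.choice(len(X), size=2, replace=False)
        lam = rng.beta(alpha, alpha)
        X_mix.append(lam * X[i] + (1 - lam) * X[j])
        y_mix.append(lam * y_onehot[i] + (1 - lam) * y_onehot[j])  # soft label
    return np.array(X_mix), np.array(y_mix)

X = np.random.rand(50, 4)
y = np.eye(3)[np.random.randint(0, 3, size=50)]  # one-hot labels, 3 classes
X_mix, y_mix = mixup_soft(X, y, alpha=0.2, n_samples=10)
```

Each mixed label is a convex combination of two one-hot vectors, so it still sums to 1.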

Augmentation Strategies

Online vs. Offline Augmentation

class OnlineAugmentation:
    """Apply augmentations during training (online)."""

    def __init__(self, augmentor):
        self.augmentor = augmentor

    def __call__(self, batch):
        """Apply augmentation to a batch."""
        augmented_batch = []
        for item in batch:
            augmented_item = self.augmentor(item)
            augmented_batch.append(augmented_item)
        return augmented_batch

# Usage with a PyTorch DataLoader
from torch.utils.data import Dataset, DataLoader

class OnlineAugmentedDataset(Dataset):
    """Dataset with online augmentation."""

    def __init__(self, data, labels, augmentor):
        self.data = data
        self.labels = labels
        self.augmentor = augmentor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        label = self.labels[idx]

        # Apply augmentation during training
        augmented_item = self.augmentor(item)

        return augmented_item, label

# Create a dataset with online augmentation
dataset = OnlineAugmentedDataset(X_train, y_train, augmentor)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

class OfflineAugmentation:
    """Precompute augmented samples (offline)."""

    def __init__(self, augmentor, augmentations_per_sample=2):
        self.augmentor = augmentor
        self.augmentations_per_sample = augmentations_per_sample

    def augment_dataset(self, X, y):
        """Build the augmented dataset."""
        X_augmented = []
        y_augmented = []

        for i in range(len(X)):
            # Add the original sample
            X_augmented.append(X[i])
            y_augmented.append(y[i])

            # Add augmented samples
            for _ in range(self.augmentations_per_sample):
                augmented = self.augmentor(X[i])
                X_augmented.append(augmented)
                y_augmented.append(y[i])

        return np.array(X_augmented), np.array(y_augmented)

# Usage
offline_augmentor = OfflineAugmentation(augmentor, augmentations_per_sample=2)
X_augmented, y_augmented = offline_augmentor.augment_dataset(X_train, y_train)

Probability Settings

class ProbabilisticAugmentor:
    """Augmentor that applies transforms with given probabilities."""

    def __init__(self):
        self.augmentations = []

    def add_augmentation(self, augmentor, probability=0.5):
        """Register an augmentation with a probability."""
        self.augmentations.append((augmentor, probability))

    def augment(self, data):
        """Apply augmentations based on their probabilities."""
        augmented_data = data

        for augmentor, prob in self.augmentations:
            if np.random.random() < prob:
                augmented_data = augmentor(augmented_data)

        return augmented_data

# Usage
prob_augmentor = ProbabilisticAugmentor()
prob_augmentor.add_augmentation(horizontal_flip, probability=0.5)
prob_augmentor.add_augmentation(rotation, probability=0.3)
prob_augmentor.add_augmentation(color_jitter, probability=0.4)

augmented_data = prob_augmentor.augment(original_data)

Custom Augmentations

import numpy as np
from PIL import Image, ImageFilter, ImageEnhance

class CustomAugmentor:
    """Custom augmentation functions."""

    @staticmethod
    def cutout(image, n_holes=8, max_h_size=16, max_w_size=16):
        """Cutout augmentation."""
        if isinstance(image, Image.Image):
            image = np.array(image)

        h, w = image.shape[:2]

        for _ in range(n_holes):
            y = np.random.randint(h)
            x = np.random.randint(w)

            y1 = np.clip(y - max_h_size // 2, 0, h)
            y2 = np.clip(y + max_h_size // 2, 0, h)
            x1 = np.clip(x - max_w_size // 2, 0, w)
            x2 = np.clip(x + max_w_size // 2, 0, w)

            if len(image.shape) == 3:
                image[y1:y2, x1:x2, :] = 0
            else:
                image[y1:y2, x1:x2] = 0

        return image

    @staticmethod
    def mixup(image1, image2, alpha=0.2):
        """Mixup augmentation."""
        if isinstance(image1, Image.Image):
            image1 = np.array(image1)
        if isinstance(image2, Image.Image):
            image2 = np.array(image2)

        lam = np.random.beta(alpha, alpha)
        mixed = lam * image1 + (1 - lam) * image2

        return mixed.astype(np.uint8)

    @staticmethod
    def cutmix(image1, image2, alpha=1.0):
        """CutMix augmentation."""
        if isinstance(image1, Image.Image):
            image1 = np.array(image1)
        if isinstance(image2, Image.Image):
            image2 = np.array(image2)

        h, w = image1.shape[:2]

        # Generate a random bounding box
        lam = np.random.beta(alpha, alpha)
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(w * cut_rat)
        cut_h = int(h * cut_rat)

        cx = np.random.randint(w)
        cy = np.random.randint(h)

        bbx1 = np.clip(cx - cut_w // 2, 0, w)
        bby1 = np.clip(cy - cut_h // 2, 0, h)
        bbx2 = np.clip(cx + cut_w // 2, 0, w)
        bby2 = np.clip(cy + cut_h // 2, 0, h)

        # Apply the CutMix patch
        mixed = image1.copy()
        mixed[bby1:bby2, bbx1:bbx2] = image2[bby1:bby2, bbx1:bbx2]

        return mixed

    @staticmethod
    def mosaic(images):
        """Mosaic augmentation (4 images)."""
        if len(images) != 4:
            raise ValueError("Mosaic requires exactly 4 images")

        h, w = images[0].shape[:2]
        mosaic = np.zeros((h * 2, w * 2, images[0].shape[2]), dtype=images[0].dtype)

        # Place the images in a 2x2 grid
        mosaic[:h, :w] = images[0]
        mosaic[:h, w:] = images[1]
        mosaic[h:, :w] = images[2]
        mosaic[h:, w:] = images[3]

        return mosaic

# Usage
# Cutout
cutout_image = CustomAugmentor.cutout(image, n_holes=8, max_h_size=32)

# Mixup
mixed_image = CustomAugmentor.mixup(image1, image2, alpha=0.2)

# CutMix
cutmix_image = CustomAugmentor.cutmix(image1, image2, alpha=1.0)

# Mosaic
mosaic_image = CustomAugmentor.mosaic([img1, img2, img3, img4])

Validation Set Handling

class ValidationAugmentation:
    """Augmentation handling for the validation set."""

    def __init__(self, augmentor):
        self.augmentor = augmentor

    def augment_with_test_time_augmentation(self, data, n_augmentations=5):
        """Test-time augmentation (TTA)."""
        augmented_samples = []

        # Create several augmented versions
        for _ in range(n_augmentations):
            augmented = self.augmentor(data)
            augmented_samples.append(augmented)

        return augmented_samples

    def average_predictions(self, predictions):
        """Average the TTA predictions."""
        return np.mean(predictions, axis=0)

# Usage
val_augmentor = ValidationAugmentation(augmentor)

# Test-time augmentation
augmented_samples = val_augmentor.augment_with_test_time_augmentation(
    validation_sample,
    n_augmentations=5
)

# Get a prediction for each augmented sample
predictions = [model.predict(sample) for sample in augmented_samples]

# Average the predictions
final_prediction = val_augmentor.average_predictions(predictions)

Production Considerations

Efficient Augmentation

import multiprocessing as mp
from functools import partial

class EfficientAugmentor:
    """Efficient augmentation with multiprocessing."""

    def __init__(self, augmentor, n_workers=None):
        self.augmentor = augmentor
        self.n_workers = n_workers or mp.cpu_count()

    def augment_batch(self, batch):
        """Augment a batch of samples."""
        with mp.Pool(self.n_workers) as pool:
            augmented_batch = pool.map(self.augmentor, batch)
        return augmented_batch

    def augment_dataset(self, dataset, batch_size=32):
        """Augment an entire dataset efficiently."""
        augmented_data = []

        for i in range(0, len(dataset), batch_size):
            batch = dataset[i:i + batch_size]
            augmented_batch = self.augment_batch(batch)
            augmented_data.extend(augmented_batch)

        return augmented_data

# Usage
efficient_augmentor = EfficientAugmentor(augmentor, n_workers=4)
augmented_data = efficient_augmentor.augment_dataset(X_train, batch_size=32)

Augmentation Caching

import hashlib
import pickle
from pathlib import Path

class CachedAugmentor:
    """Augmentor with caching for reproducibility."""

    def __init__(self, augmentor, cache_dir='./augmentation_cache'):
        self.augmentor = augmentor
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)

    def _get_cache_key(self, data):
        """Generate a cache key for the data."""
        if isinstance(data, np.ndarray):
            data_bytes = data.tobytes()
        elif isinstance(data, (str, bytes)):
            data_bytes = data.encode() if isinstance(data, str) else data
        else:
            data_bytes = pickle.dumps(data)

        return hashlib.md5(data_bytes).hexdigest()

    def augment(self, data, use_cache=True):
        """Augment with caching."""
        cache_key = self._get_cache_key(data)
        cache_file = self.cache_dir / f"{cache_key}.pkl"

        if use_cache and cache_file.exists():
            with open(cache_file, 'rb') as f:
                return pickle.load(f)

        # Apply the augmentation
        augmented = self.augmentor(data)

        # Cache the result
        if use_cache:
            with open(cache_file, 'wb') as f:
                pickle.dump(augmented, f)

        return augmented

# Usage
cached_augmentor = CachedAugmentor(augmentor)
augmented_data = cached_augmentor.augment(original_data)

Common Patterns

AutoAugment

import random

class AutoAugment:
    """AutoAugment-style policy."""

    def __init__(self):
        self.policies = [
            ('rotate', 30, 0.5),
            ('translate_x', 0.1, 0.5),
            ('translate_y', 0.1, 0.5),
            ('shear_x', 0.1, 0.5),
            ('shear_y', 0.1, 0.5),
            ('contrast', 0.3, 0.5),
            ('brightness', 0.3, 0.5),
            ('sharpness', 0.3, 0.5),
            ('posterize', 4, 0.5),
            ('solarize', 256, 0.5),
        ]

    def apply_policy(self, image):
        """Apply a random policy."""
        # Pick 2 random policies
        selected_policies = random.sample(self.policies, 2)

        augmented = image
        for policy_name, magnitude, prob in selected_policies:
            if random.random() < prob:
                augmented = self._apply_transform(augmented, policy_name, magnitude)

        return augmented

    def _apply_transform(self, image, transform_name, magnitude):
        """Apply a specific transform."""
        # Implement each transform
        if transform_name == 'rotate':
            return self._rotate(image, magnitude)
        elif transform_name == 'translate_x':
            return self._translate_x(image, magnitude)
        # ... other transforms

        return image

# Usage
autoaugment = AutoAugment()
augmented_image = autoaugment.apply_policy(original_image)
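Two of the policies listed above, `posterize` and `solarize`, can be implemented with plain NumPy for uint8 images. A minimal sketch (the standalone function names here are illustrative, not part of any library):

```python
import numpy as np

def posterize(image, bits=4):
    """Keep only the top `bits` bits of each channel value."""
    shift = 8 - bits
    return ((image >> shift) << shift).astype(np.uint8)

def solarize(image, threshold=128):
    """Invert all pixel values at or above the threshold."""
    return np.where(image >= threshold, 255 - image, image).astype(np.uint8)

img = np.arange(0, 256, dtype=np.uint8).reshape(16, 16)
post = posterize(img, bits=4)   # values become multiples of 16
sol = solarize(img, threshold=128)  # 255 maps to 0, 0 stays 0
```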

RandAugment

class RandAugment:
    """RandAugment-style policy."""

    def __init__(self, n=2, m=10):
        self.n = n  # number of augmentations
        self.m = m  # magnitude

        self.augmentations = [
            self._rotate,
            self._translate_x,
            self._translate_y,
            self._shear_x,
            self._shear_y,
            self._contrast,
            self._brightness,
            self._sharpness,
        ]

    def __call__(self, image):
        """Apply RandAugment."""
        # Pick n random augmentations
        selected = random.choices(self.augmentations, k=self.n)

        augmented = image
        for aug in selected:
            augmented = aug(augmented, self.m)

        return augmented

    def _rotate(self, image, magnitude):
        """Random rotation."""
        angle = random.uniform(-magnitude, magnitude)
        return self._rotate_image(image, angle)

    def _contrast(self, image, magnitude):
        """Random contrast."""
        factor = random.uniform(1 - magnitude/10, 1 + magnitude/10)
        return self._adjust_contrast(image, factor)

    # ... other augmentation methods

# Usage
randaugment = RandAugment(n=2, m=10)
augmented_image = randaugment(original_image)

Best Practices

  1. Start simple

    • Begin with basic augmentations (flips, rotations)
    • Increase complexity gradually
    • Monitor the impact on model performance
  2. Use task-appropriate augmentations

    • Classification: more aggressive augmentation
    • Detection: be careful with spatial augmentations (bounding boxes must be adjusted)
    • Segmentation: apply the same augmentation to the image and its mask
  3. Do not augment the validation/test set

    • Apply only normalization, not augmentation
    • Use test-time augmentation (TTA) for inference
  4. Monitor performance

    • Track training/validation loss with and without augmentation
    • Validate augmentations with visualizations
    • Check that labels are preserved after augmentation
  5. Apply augmentations probabilistically

    • Apply each augmentation with an appropriate probability
    • Avoid over-augmentation that distorts the data
    • Balance diversity against data quality
  6. Handle class imbalance

    • Use SMOTE or ADASYN for tabular data
    • Apply more augmentation to minority classes
    • Consider weighted sampling
  7. Optimize for performance

    • Use multiprocessing for batch augmentation
    • Cache augmented samples where possible
    • Use efficient libraries such as Albumentations
  8. Debug augmentations

    • Visualize augmented samples
    • Check label preservation
    • Verify the correctness of the augmentation pipeline
  9. Consider task-specific requirements

    • Medical imaging: limited augmentation
    • Satellite imagery: rotation-invariant augmentation
    • Text: preserve semantic meaning
  10. Reproducibility

    • Set random seeds for consistent results
    • Cache augmentations for debugging
    • Document the augmentation pipeline
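The reproducibility point above can be as simple as seeding every random source the pipeline touches. A minimal sketch covering the Python and NumPy generators (add the equivalent `torch.manual_seed` call if PyTorch is in use):

```python
import random
import numpy as np

def seed_everything(seed=42):
    """Seed every random source used by the augmentation pipeline."""
    random.seed(seed)
    np.random.seed(seed)
    # If PyTorch is used: torch.manual_seed(seed)

seed_everything(42)
a = np.random.rand(3)
seed_everything(42)
b = np.random.rand(3)  # identical draw after re-seeding
```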

Related Skills