Model Optimization
Overview
Model optimization is the process of preparing machine learning models for production deployment by reducing model size and improving inference speed while preserving accuracy. This skill covers quantization, pruning, knowledge distillation, model compression, architecture optimization, inference optimization, ONNX optimization, TensorRT integration, and benchmarking tools.
Prerequisites
- Understanding of PyTorch and deep learning
- Knowledge of model architectures and training
- Familiarity with model deployment concepts
- Understanding of numeric precision formats (FP32, FP16, INT8)
- Basic knowledge of model serving
Core Concepts
Quantization
- Dynamic quantization: weights are quantized on the fly at inference time
- Static quantization: weights are pre-quantized using calibration data
- Quantization-aware training (QAT): training simulates quantization effects
- Per-channel quantization: each output channel is quantized with its own scale
- INT8/FP16: reduced-precision formats for better efficiency
Pruning
- Structured pruning: remove entire channels/filters
- Unstructured pruning: remove individual weights based on magnitude
- Global vs. local pruning: prune across the whole model vs. per layer
- Iterative pruning: prune gradually, fine-tuning in between
Knowledge Distillation
- Teacher-student: a large teacher model supervises a small student model
- Soft targets: the teacher's softened outputs are used as training targets (see the loss formula below)
- Feature distillation: match intermediate representations
- Self-distillation: the model teaches itself (e.g., via an EMA copy)
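For reference, the soft-target objective implemented by the DistillationLoss class later in this guide combines a temperature-scaled KL term against the teacher with the usual cross-entropy against the true labels:

\mathcal{L}_{KD} = \alpha\, T^{2}\, \mathrm{KL}\!\big(\mathrm{softmax}(z_t/T)\,\big\|\,\mathrm{softmax}(z_s/T)\big) + (1-\alpha)\,\mathrm{CE}(z_s, y)

where z_t and z_s are the teacher and student logits, T is the temperature, and alpha balances the soft and hard losses.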
Model Compression
- Weight sharing: cluster weights with k-means and share centroid values
- Low-rank factorization: SVD-based decomposition of layers
- Architecture design: efficient architectures (MobileNet, etc.)
Inference Optimization
- Batching: process multiple inputs together
- Caching: reuse results of repeated computations
- GPU optimization: cuDNN, half precision, kernel fusion
- ONNX/TensorRT: hardware-specific optimization
Implementation Guide
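The code examples below repeatedly reference a MyModel placeholder, along with train_loader/val_loader data loaders, none of which are defined in this guide. Any nn.Module and DataLoader will do; a minimal hypothetical stand-in that makes the snippets runnable end to end could look like:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    """Hypothetical placeholder model used throughout the examples."""
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Linear(32, num_classes)

    def forward(self, x):
        x = self.features(x).flatten(1)
        return self.classifier(x)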
Quantization
Post-Training Quantization
Dynamic quantization:
import torch
import torch.nn as nn

def get_model_size(model):
    """Get model size in MB."""
    param_size = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
    buffer_size = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
    return (param_size + buffer_size) / 1024**2

def apply_dynamic_quantization(model, layers_to_quantize=(nn.Linear, nn.LSTM)):
    """Apply dynamic quantization to model."""
    quantized_model = torch.quantization.quantize_dynamic(
        model,
        set(layers_to_quantize),
        dtype=torch.qint8
    )
    return quantized_model

# Example
model = MyModel()
quantized_model = apply_dynamic_quantization(model)

# Save quantized model
torch.save(quantized_model.state_dict(), "quantized_model.pth")

# Load and use: re-apply dynamic quantization to a fresh model, then load the state dict
quantized_model = apply_dynamic_quantization(MyModel())
quantized_model.load_state_dict(torch.load("quantized_model.pth"))
quantized_model.eval()

# Compare sizes
original_size = get_model_size(model)
quantized_size = get_model_size(quantized_model)
print(f"Original size: {original_size:.2f} MB")
print(f"Quantized size: {quantized_size:.2f} MB")
print(f"Reduction: {(1 - quantized_size/original_size)*100:.1f}%")
Static quantization:
import torch
from torch.quantization import prepare, convert, get_default_qconfig

def apply_static_quantization(model, calibration_dataloader):
    """Apply static quantization with calibration."""
    model.eval()
    # Set quantization configuration
    model.qconfig = get_default_qconfig('fbgemm')
    # Prepare model for quantization
    prepared_model = prepare(model)
    # Calibrate with representative data
    print("Calibrating model...")
    with torch.no_grad():
        for inputs, _ in calibration_dataloader:
            prepared_model(inputs)
    # Convert to quantized model
    quantized_model = convert(prepared_model)
    return quantized_model

# Usage
model = MyModel()
calibration_loader = get_calibration_loader()
quantized_model = apply_static_quantization(model, calibration_loader)
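get_calibration_loader above is assumed rather than defined; static quantization only needs a few hundred representative batches. A hypothetical sketch:

import torch
from torch.utils.data import DataLoader, TensorDataset

def get_calibration_loader(num_samples=512, batch_size=32):
    """Hypothetical calibration loader; in practice, sample from real training/validation data."""
    inputs = torch.randn(num_samples, 3, 224, 224)
    labels = torch.zeros(num_samples, dtype=torch.long)
    return DataLoader(TensorDataset(inputs, labels), batch_size=batch_size)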
Per-channel quantization:
def apply_per_channel_quantization(model):
    """Apply per-channel quantization for better accuracy."""
    model.eval()
    # Per-channel quantization configuration
    model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
    # For Conv2d and Linear layers, use per-channel weight observers
    for name, module in model.named_modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            module.qconfig = torch.quantization.QConfig(
                activation=torch.quantization.MinMaxObserver.with_args(
                    dtype=torch.quint8  # fbgemm kernels expect unsigned activations
                ),
                weight=torch.quantization.PerChannelMinMaxObserver.with_args(
                    dtype=torch.qint8,
                    qscheme=torch.per_channel_symmetric
                )
            )
    prepared_model = prepare(model)
    # Calibrate with representative data here (as in apply_static_quantization)...
    quantized_model = convert(prepared_model)
    return quantized_model
Quantization-Aware Training (QAT)
import torch
import torch.nn as nn
from torch.quantization import prepare_qat, convert, get_default_qat_qconfig

def quantization_aware_training(model, train_loader, val_loader, epochs=10, lr=0.001):
    """Train model with quantization awareness."""
    model.train()
    # Set QAT configuration
    model.qconfig = get_default_qat_qconfig('fbgemm')
    # Prepare model for QAT
    model_prepared = prepare_qat(model, inplace=True)
    # Setup optimizer
    optimizer = torch.optim.Adam(model_prepared.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    # Training loop
    for epoch in range(epochs):
        model_prepared.train()
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            outputs = model_prepared(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()
        # Validate
        model_prepared.eval()
        val_loss = 0
        with torch.no_grad():
            for inputs, targets in val_loader:
                outputs = model_prepared(inputs)
                val_loss += criterion(outputs, targets).item()
        print(f"Epoch {epoch}, Val Loss: {val_loss/len(val_loader):.4f}")
    # Convert to quantized model
    model_prepared.eval()
    quantized_model = convert(model_prepared)
    return quantized_model

# Usage
model = MyModel()
quantized_model = quantization_aware_training(model, train_loader, val_loader)
INT8 and FP16
INT8 quantization:
def int8_quantization(model, calibration_loader):
    """Quantize model to INT8."""
    model.eval()
    # INT8 configuration (fbgemm kernels expect unsigned 8-bit activations and symmetric weights)
    model.qconfig = torch.quantization.QConfig(
        activation=torch.quantization.MinMaxObserver.with_args(
            dtype=torch.quint8
        ),
        weight=torch.quantization.MinMaxObserver.with_args(
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric
        )
    )
    prepared_model = prepare(model)
    # Calibrate
    with torch.no_grad():
        for inputs, _ in calibration_loader:
            prepared_model(inputs)
    quantized_model = convert(prepared_model)
    return quantized_model
FP16 (half precision):
def convert_to_fp16(model):
    """Convert model parameters and buffers to FP16."""
    model = model.half()
    return model

# Usage
model = MyModel()
model = convert_to_fp16(model)

# Inference with FP16
model.eval()
inputs = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    inputs = inputs.half()  # Convert input to FP16
    outputs = model(inputs)
Mixed-precision training:
from torch.cuda.amp import autocast, GradScaler

def mixed_precision_training(model, train_loader, epochs=10, lr=0.001):
    """Train with mixed precision (FP16)."""
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    scaler = GradScaler()
    for epoch in range(epochs):
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            with autocast():
                outputs = model(inputs)
                loss = criterion(outputs, targets)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
    return model
Pruning
Structured Pruning
Channel pruning:
import torch.nn.utils.prune as prune

def structured_prune_channels(model, prune_ratio=0.3):
    """Prune entire channels from convolutional layers."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            # Prune prune_ratio of output channels by L2 norm
            prune.ln_structured(
                module,
                name='weight',
                amount=prune_ratio,
                n=2,
                dim=0  # Prune along output channel dimension
            )
    return model

# Make pruning permanent
def make_pruning_permanent(model):
    """Remove pruning reparameterization."""
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.remove(module, 'weight')
    return model

# Usage
model = MyModel()
model = structured_prune_channels(model, prune_ratio=0.3)
model = make_pruning_permanent(model)
Filter pruning:
def filter_pruning(model, prune_ratio=0.3):
    """Prune filters based on L1 norm."""
    # Calculate the L1 norm of each filter
    filter_norms = {}
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            filter_norms[name] = module.weight.data.abs().sum(dim=(1, 2, 3))
    # Prune the lowest-norm filters in each layer
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            filter_norm = filter_norms[name]
            num_filters = filter_norm.size(0)
            num_prune = int(num_filters * prune_ratio)
            if num_prune == 0:
                continue
            # Get indices of filters to prune (lowest norm)
            _, prune_indices = torch.topk(filter_norm, num_prune, largest=False)
            # Create pruning mask with the same shape as the weight tensor
            mask = torch.ones(num_filters, device=module.weight.device)
            mask[prune_indices] = 0
            mask = mask.view(-1, 1, 1, 1).expand_as(module.weight)
            # Apply pruning
            prune.custom_from_mask(module, name='weight', mask=mask)
    return model
Unstructured Pruning
L1 unstructured pruning:
def unstructured_prune(model, prune_ratio=0.2):
    """Apply unstructured L1 pruning."""
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            prune.l1_unstructured(
                module,
                name='weight',
                amount=prune_ratio
            )
    return model

# Usage
model = MyModel()
model = unstructured_prune(model, prune_ratio=0.2)
Global unstructured pruning:
def global_unstructured_prune(model, prune_ratio=0.2):
    """Prune globally across all layers."""
    parameters_to_prune = []
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            parameters_to_prune.append((module, 'weight'))
    prune.global_unstructured(
        parameters_to_prune,
        pruning_method=prune.L1Unstructured,
        amount=prune_ratio
    )
    return model
Iterative Pruning
def iterative_pruning(model, train_loader, val_loader,
                      num_iterations=5, prune_ratio=0.2,
                      fine_tune_epochs=5):
    """Iteratively prune and fine-tune model."""
    criterion = nn.CrossEntropyLoss()
    for iteration in range(num_iterations):
        print(f"\nPruning iteration {iteration + 1}/{num_iterations}")
        # Prune
        model = global_unstructured_prune(model, prune_ratio)
        # Fine-tune
        optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
        model.train()
        for epoch in range(fine_tune_epochs):
            for inputs, targets in train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()
        # Evaluate
        model.eval()
        val_accuracy = evaluate(model, val_loader)
        print(f"Val accuracy after pruning: {val_accuracy:.2f}%")
    # Make pruning permanent
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            prune.remove(module, 'weight')
    return model
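The evaluate helper used here (and again in the distillation examples) is not defined in this guide. A minimal sketch, assuming a classification model and a loader yielding (inputs, targets):

def evaluate(model, dataloader):
    """Hypothetical helper: top-1 accuracy (%) of a classifier."""
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for inputs, targets in dataloader:
            predictions = model(inputs).argmax(dim=1)
            correct += (predictions == targets).sum().item()
            total += targets.size(0)
    return 100.0 * correct / total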
Knowledge Distillation
Basic Knowledge Distillation
import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    """Knowledge distillation loss."""
    def __init__(self, alpha=0.5, temperature=3.0):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, targets):
        # Hard loss (cross-entropy with true labels)
        hard_loss = F.cross_entropy(student_logits, targets)
        # Soft loss (KL divergence with teacher)
        soft_loss = self.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1)
        ) * (self.temperature ** 2)
        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

def knowledge_distillation(teacher_model, student_model,
                           train_loader, val_loader,
                           epochs=50, lr=0.001,
                           alpha=0.5, temperature=3.0):
    """Train student model with knowledge distillation."""
    teacher_model.eval()
    student_model.train()
    optimizer = torch.optim.Adam(student_model.parameters(), lr=lr)
    criterion = DistillationLoss(alpha=alpha, temperature=temperature)
    for epoch in range(epochs):
        student_model.train()
        total_loss = 0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            # Get teacher outputs (no gradients)
            with torch.no_grad():
                teacher_outputs = teacher_model(inputs)
            # Get student outputs
            student_outputs = student_model(inputs)
            # Compute distillation loss
            loss = criterion(student_outputs, teacher_outputs, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        # Evaluate
        student_model.eval()
        val_acc = evaluate(student_model, val_loader)
        print(f"Epoch {epoch}, Loss: {total_loss/len(train_loader):.4f}, Val Acc: {val_acc:.2f}%")
    return student_model
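A hypothetical usage sketch; TeacherModel, StudentModel, and teacher.pth are placeholders for your own pretrained teacher and a smaller student with matching output classes:

# Usage (illustrative names)
teacher = TeacherModel()
teacher.load_state_dict(torch.load("teacher.pth"))
student = StudentModel()
student = knowledge_distillation(teacher, student, train_loader, val_loader,
                                 epochs=50, alpha=0.5, temperature=3.0)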
Feature-Based Distillation
class FeatureDistillationLoss(nn.Module):
    """Feature-based distillation loss."""
    def __init__(self, feature_weights=None):
        super().__init__()
        self.feature_weights = feature_weights or [1.0, 1.0, 1.0]
        self.mse_loss = nn.MSELoss()

    def forward(self, student_features, teacher_features):
        """Compute feature distillation loss."""
        total_loss = 0
        for i, (s_feat, t_feat) in enumerate(zip(student_features, teacher_features)):
            weight = self.feature_weights[i] if i < len(self.feature_weights) else 1.0
            total_loss += weight * self.mse_loss(s_feat, t_feat)
        return total_loss

def feature_distillation(teacher_model, student_model,
                         train_loader, val_loader,
                         epochs=50, lr=0.001):
    """Train student with feature distillation."""
    teacher_model.eval()
    student_model.train()
    optimizer = torch.optim.Adam(student_model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    feature_criterion = FeatureDistillationLoss()
    for epoch in range(epochs):
        student_model.train()
        total_loss = 0
        total_cls_loss = 0
        total_feat_loss = 0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            # Teacher features/outputs without gradients
            with torch.no_grad():
                teacher_features, teacher_outputs = teacher_model.forward_features(inputs)
            # Student features/outputs with gradients
            student_features, student_outputs = student_model.forward_features(inputs)
            # Compute losses
            cls_loss = criterion(student_outputs, targets)
            feat_loss = feature_criterion(student_features, teacher_features)
            loss = cls_loss + 0.1 * feat_loss
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            total_cls_loss += cls_loss.item()
            total_feat_loss += feat_loss.item()
        print(f"Epoch {epoch}, Loss: {total_loss/len(train_loader):.4f}")
    return student_model
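forward_features is assumed to return (list_of_intermediate_features, logits) on both models; it is not a standard nn.Module method. One hypothetical way to expose it is a wrapper like the following (mismatched teacher/student feature shapes would additionally need 1x1 projection layers):

class FeatureModel(nn.Module):
    """Hypothetical wrapper exposing intermediate features alongside logits."""
    def __init__(self, stages, classifier):
        super().__init__()
        self.stages = nn.ModuleList(stages)  # e.g., the backbone's blocks
        self.classifier = classifier

    def forward_features(self, x):
        features = []
        for stage in self.stages:
            x = stage(x)
            features.append(x)
        logits = self.classifier(x.flatten(1))
        return features, logits

    def forward(self, x):
        _, logits = self.forward_features(x)
        return logits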
Self-Distillation
import copy

def self_distillation(model, train_loader, val_loader,
                      epochs=50, lr=0.001, temperature=3.0):
    """Train model with self-distillation against an EMA copy of itself."""
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = DistillationLoss(alpha=0.5, temperature=temperature)
    # Create an EMA copy of the model
    ema_model = create_ema_model(model)
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        for inputs, targets in train_loader:
            optimizer.zero_grad()
            # Get EMA outputs
            with torch.no_grad():
                ema_outputs = ema_model(inputs)
            # Get current outputs
            current_outputs = model(inputs)
            # Compute distillation loss
            loss = criterion(current_outputs, ema_outputs, targets)
            loss.backward()
            optimizer.step()
            # Update EMA model
            update_ema_model(model, ema_model, decay=0.99)
            total_loss += loss.item()
        print(f"Epoch {epoch}, Loss: {total_loss/len(train_loader):.4f}")
    return model

def create_ema_model(model):
    """Create EMA copy of model."""
    ema_model = copy.deepcopy(model)
    for param in ema_model.parameters():
        param.requires_grad_(False)
    ema_model.eval()
    return ema_model

def update_ema_model(model, ema_model, decay=0.99):
    """Update EMA model parameters."""
    with torch.no_grad():
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(decay).add_(param.data, alpha=1 - decay)
Model Compression
Weight Sharing
from sklearn.cluster import KMeans

def apply_weight_sharing(model, num_clusters=16):
    """Apply weight sharing (k-means clustering of weights)."""
    for name, module in model.named_modules():
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            weight = module.weight.data
            # Flatten weight for clustering
            weight_flat = weight.view(-1, 1).cpu().numpy()
            # K-means clustering
            kmeans = KMeans(n_clusters=num_clusters, random_state=0)
            labels = kmeans.fit_predict(weight_flat)
            centroids = kmeans.cluster_centers_
            # Replace weights with their cluster centroids
            weight_shared = torch.tensor(centroids[labels].reshape(weight.shape),
                                         dtype=weight.dtype, device=weight.device)
            module.weight.data = weight_shared
    return model
Low-Rank Factorization
def low_rank_factorization_conv(conv_layer, rank):
    """Factorize a convolutional layer using truncated SVD."""
    # Get weight: (out_channels, in_channels, kH, kW)
    weight = conv_layer.weight.data
    # Reshape to 2D: (out_channels, in_channels * kH * kW)
    weight_2d = weight.view(conv_layer.out_channels, -1)
    # SVD
    U, S, V = torch.svd(weight_2d)
    # Truncate to the top-`rank` components
    U_r = U[:, :rank]
    S_r = torch.diag(S[:rank])
    V_r = V[:, :rank]
    # Create two layers: a rank-reduced k x k conv followed by a 1x1 conv
    layer1 = nn.Conv2d(
        conv_layer.in_channels,
        rank,
        conv_layer.kernel_size,
        stride=conv_layer.stride,
        padding=conv_layer.padding,
        bias=False
    )
    layer2 = nn.Conv2d(
        rank,
        conv_layer.out_channels,
        kernel_size=1,
        bias=conv_layer.bias is not None
    )
    # Set weights
    layer1.weight.data = V_r.t().view(rank, conv_layer.in_channels,
                                      conv_layer.kernel_size[0], conv_layer.kernel_size[1])
    layer2.weight.data = (U_r @ S_r).view(conv_layer.out_channels, rank, 1, 1)
    if conv_layer.bias is not None:
        layer2.bias.data = conv_layer.bias.data
    return nn.Sequential(layer1, layer2)
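An illustrative usage, factorizing a single hypothetical 3x3 layer and checking the parameter reduction (the layer and rank are arbitrary):

# Usage
conv = nn.Conv2d(128, 256, kernel_size=3, padding=1)
factorized = low_rank_factorization_conv(conv, rank=32)
original_params = sum(p.numel() for p in conv.parameters())
factorized_params = sum(p.numel() for p in factorized.parameters())
print(f"Original: {original_params:,} params, factorized: {factorized_params:,} params")
# In a real model, replace the layer attribute in place, e.g.:
# model.layer3.conv1 = low_rank_factorization_conv(model.layer3.conv1, rank=32)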
Architecture Optimization
Depthwise Separable Convolutions
class DepthwiseSeparableConv(nn.Module):
    """Depthwise separable convolution for efficient models."""
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_channels,
            bias=False
        )
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.pointwise(x)
        x = self.bn2(x)
        x = self.relu(x)
        return x

# Replace standard conv with depthwise separable
def replace_with_depthwise(model):
    """Replace Conv2d layers with DepthwiseSeparableConv."""
    for name, module in list(model.named_children()):
        if isinstance(module, nn.Conv2d):
            if module.kernel_size == (1, 1):
                # Keep 1x1 conv as is (pointwise)
                continue
            # Replace with depthwise separable
            depthwise_conv = DepthwiseSeparableConv(
                module.in_channels,
                module.out_channels,
                kernel_size=module.kernel_size[0],
                stride=module.stride[0],
                padding=module.padding[0]
            )
            setattr(model, name, depthwise_conv)
    return model
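To make the efficiency gain concrete, a quick parameter-count comparison for a hypothetical 128-to-256-channel 3x3 layer:

standard = nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=False)
separable = DepthwiseSeparableConv(128, 256, kernel_size=3, padding=1)
print(f"Standard conv: {sum(p.numel() for p in standard.parameters()):,} params")        # 294,912
print(f"Depthwise separable: {sum(p.numel() for p in separable.parameters()):,} params")  # ~34,700 incl. BatchNorm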
MobileNet Block
class MobileNetBlock(nn.Module):
    """Inverted residual block from MobileNetV2."""
    def __init__(self, in_channels, out_channels, stride, expand_ratio):
        super().__init__()
        hidden_dim = in_channels * expand_ratio
        layers = []
        # Expansion
        if expand_ratio != 1:
            layers.extend([
                nn.Conv2d(in_channels, hidden_dim, 1, bias=False),
                nn.BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True)
            ])
        # Depthwise
        layers.extend([
            nn.Conv2d(hidden_dim, hidden_dim, 3, stride=stride,
                      padding=1, groups=hidden_dim, bias=False),
            nn.BatchNorm2d(hidden_dim),
            nn.ReLU6(inplace=True)
        ])
        # Pointwise (linear)
        layers.append(nn.Conv2d(hidden_dim, out_channels, 1, bias=False))
        layers.append(nn.BatchNorm2d(out_channels))
        self.conv = nn.Sequential(*layers)
        # Skip connection
        self.use_skip = (stride == 1 and in_channels == out_channels)

    def forward(self, x):
        out = self.conv(x)
        if self.use_skip:
            return x + out
        return out
Inference Optimization
Batching
from collections import deque
import threading
import time

class BatchInference:
    """Batch inference for improved throughput."""
    def __init__(self, model, max_batch_size=32, max_wait_time=0.1):
        self.model = model
        self.model.eval()
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.batch_queue = deque()
        self.results = {}
        self.lock = threading.Lock()
        self.running = False

    def start(self):
        """Start batch processing thread."""
        self.running = True
        self.thread = threading.Thread(target=self._process_batches)
        self.thread.start()

    def stop(self):
        """Stop batch processing."""
        self.running = False
        self.thread.join()

    def predict(self, input_data):
        """Add input to batch queue."""
        request_id = id(input_data)
        with self.lock:
            self.batch_queue.append((request_id, input_data))
        return request_id

    def get_result(self, request_id, timeout=10):
        """Get prediction result."""
        start_time = time.time()
        while True:
            with self.lock:
                if request_id in self.results:
                    return self.results.pop(request_id)
            if time.time() - start_time > timeout:
                raise TimeoutError("Prediction timeout")
            time.sleep(0.01)

    def _process_batches(self):
        """Collect requests into batches and run inference."""
        while self.running:
            batch = []
            start_time = time.time()
            # Collect up to max_batch_size requests, waiting at most max_wait_time
            while len(batch) < self.max_batch_size and \
                    (time.time() - start_time) < self.max_wait_time:
                with self.lock:
                    if self.batch_queue:
                        batch.append(self.batch_queue.popleft())
                        continue
                time.sleep(0.001)  # release the lock while waiting for new requests
            if batch:
                request_ids, inputs = zip(*batch)
                batch_tensor = torch.stack(inputs)
                with torch.no_grad():
                    outputs = self.model(batch_tensor)
                with self.lock:
                    for req_id, output in zip(request_ids, outputs):
                        self.results[req_id] = output
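A minimal usage sketch (a single caller is shown for brevity; in a real service, predict/get_result would be called from request-handling threads while the batch thread runs in the background):

# Usage
batcher = BatchInference(model, max_batch_size=32, max_wait_time=0.05)
batcher.start()
request_id = batcher.predict(torch.randn(3, 224, 224))  # one unbatched sample
output = batcher.get_result(request_id)
batcher.stop()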
Caching
import hashlib
import pickle

class ModelCache:
    """Cache model predictions."""
    def __init__(self, cache_size=1000):
        self.cache_size = cache_size
        self.cache = {}

    def _get_key(self, input_data):
        """Generate cache key from input."""
        if isinstance(input_data, torch.Tensor):
            input_hash = hashlib.md5(input_data.detach().cpu().numpy().tobytes()).hexdigest()
        else:
            input_hash = hashlib.md5(pickle.dumps(input_data)).hexdigest()
        return input_hash

    def get(self, input_data):
        """Get cached prediction."""
        key = self._get_key(input_data)
        return self.cache.get(key)

    def set(self, input_data, output):
        """Cache prediction."""
        key = self._get_key(input_data)
        # Evict oldest entry if cache is full
        if len(self.cache) >= self.cache_size:
            self.cache.pop(next(iter(self.cache)))
        self.cache[key] = output

    def clear(self):
        """Clear cache."""
        self.cache.clear()

# Usage
cache = ModelCache(cache_size=1000)

def predict_with_cache(model, input_data):
    """Predict with caching."""
    # Check cache
    cached_output = cache.get(input_data)
    if cached_output is not None:
        return cached_output
    # Run inference
    with torch.no_grad():
        output = model(input_data)
    # Cache result
    cache.set(input_data, output)
    return output
GPU Optimization
import torch

def optimize_for_gpu(model):
    """Optimize model for GPU inference."""
    # Enable cuDNN benchmark for optimal convolution algorithms
    torch.backends.cudnn.benchmark = True
    # Disable deterministic mode for better performance
    torch.backends.cudnn.deterministic = False
    # Move to GPU and use half precision if available
    if torch.cuda.is_available():
        model = model.cuda().half()
    return model

def optimize_inference(model, input_shape):
    """Optimize model for inference."""
    model.eval()
    # Optimize with torch.compile (PyTorch 2.0+)
    if hasattr(torch, 'compile'):
        model = torch.compile(model)
    # Dummy input for warm-up
    dummy_input = torch.randn(*input_shape)
    if torch.cuda.is_available():
        dummy_input = dummy_input.cuda()
        model = model.cuda()
    # Warm up
    with torch.no_grad():
        for _ in range(10):
            _ = model(dummy_input)
    return model
ONNX Optimization
ONNX Export and Optimization
import torch
import onnx
import onnxruntime as ort

def export_to_onnx(model, input_shape, output_path="model.onnx"):
    """Export PyTorch model to ONNX."""
    model.eval()
    dummy_input = torch.randn(*input_shape)
    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},
            'output': {0: 'batch_size'}
        }
    )
    # Verify model
    onnx_model = onnx.load(output_path)
    onnx.checker.check_model(onnx_model)
    return output_path
def optimize_onnx_model(onnx_path, optimized_path="model_optimized.onnx"):
    """Optimize ONNX model with onnxoptimizer."""
    from onnxoptimizer import optimize
    onnx_model = onnx.load(onnx_path)
    # Apply optimization passes
    optimized_model = optimize(
        onnx_model,
        passes=[
            'eliminate_unused_initializer',
            'eliminate_nop_transpose',
            'eliminate_nop_pad',
            'eliminate_identity',
            'eliminate_deadend',
            'fuse_add_bias_into_conv',
            'fuse_bn_into_conv',
            'fuse_consecutive_concats',
            'fuse_consecutive_reduce_unsqueeze',
            'fuse_consecutive_squeezes',
            'fuse_consecutive_transposes',
            'fuse_matmul_add_bias_into_gemm',
            'fuse_pad_into_conv',
            'fuse_transpose_into_gemm',
        ]
    )
    onnx.save(optimized_model, optimized_path)
    return optimized_path

# Usage
model = MyModel()
export_to_onnx(model, (1, 3, 224, 224))
optimize_onnx_model("model.onnx")
ONNX Runtime Optimization
def create_optimized_onnx_session(onnx_path,
                                  providers=('CUDAExecutionProvider', 'CPUExecutionProvider')):
    """Create optimized ONNX Runtime session."""
    sess_options = ort.SessionOptions()
    # Enable graph optimization
    sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    # Enable memory pattern and memory arena optimizations
    sess_options.enable_mem_pattern = True
    sess_options.enable_cpu_mem_arena = True
    # Set execution mode
    sess_options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
    # Create session
    session = ort.InferenceSession(
        onnx_path,
        sess_options=sess_options,
        providers=list(providers)
    )
    return session

# Usage
import numpy as np
session = create_optimized_onnx_session("model.onnx")
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)
# Run inference
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name
outputs = session.run([output_name], {input_name: input_data})
TensorRT Integration
TensorRT Conversion
import torch
from torch2trt import torch2trt, TRTModule

def convert_to_tensorrt(model, input_shape, fp16_mode=True):
    """Convert PyTorch model to TensorRT via torch2trt."""
    model.eval()
    model = model.cuda()
    # Create dummy input
    x = torch.ones(input_shape).cuda()
    # Convert to TensorRT
    model_trt = torch2trt(
        model,
        [x],
        fp16_mode=fp16_mode,
        max_workspace_size=1 << 30  # 1GB
    )
    return model_trt

# Usage
model = MyModel()
model_trt = convert_to_tensorrt(model, (1, 3, 224, 224))

# Save TensorRT model
torch.save(model_trt.state_dict(), 'model_trt.pth')

# Load TensorRT model
model_trt = TRTModule()
model_trt.load_state_dict(torch.load('model_trt.pth'))
TensorRT from ONNX
import tensorrt as trt
import pycuda.driver as cuda
import pycuda.autoinit

def build_tensorrt_engine(onnx_path, engine_path="model.trt", fp16=True):
    """Build TensorRT engine from ONNX model."""
    logger = trt.Logger(trt.Logger.WARNING)
    builder = trt.Builder(logger)
    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    parser = trt.OnnxParser(network, logger)
    # Parse ONNX model
    with open(onnx_path, 'rb') as model:
        if not parser.parse(model.read()):
            for i in range(parser.num_errors):
                print(parser.get_error(i))
            raise RuntimeError("Failed to parse ONNX model")
    # Build configuration
    config = builder.create_builder_config()
    config.max_workspace_size = 1 << 30  # 1GB (newer TensorRT versions use config.set_memory_pool_limit instead)
    if fp16 and builder.platform_has_fast_fp16:
        config.set_flag(trt.BuilderFlag.FP16)
    # Build engine
    engine = builder.build_engine(network, config)
    # Save serialized engine
    with open(engine_path, 'wb') as f:
        f.write(engine.serialize())
    return engine

# Usage
engine = build_tensorrt_engine("model.onnx", fp16=True)
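Running inference with the saved engine requires explicit device-buffer management. A hedged sketch, assuming a single input and a single output binding and the TensorRT 8.x bindings API (details differ across TensorRT versions):

import numpy as np

def tensorrt_inference(engine_path, input_array):
    """Hypothetical sketch: run one inference with a serialized TensorRT engine."""
    logger = trt.Logger(trt.Logger.WARNING)
    with open(engine_path, 'rb') as f:
        engine = trt.Runtime(logger).deserialize_cuda_engine(f.read())
    context = engine.create_execution_context()
    input_array = np.ascontiguousarray(input_array.astype(np.float32))
    context.set_binding_shape(0, input_array.shape)  # needed when the batch dimension is dynamic
    output_array = np.empty(tuple(context.get_binding_shape(1)), dtype=np.float32)
    d_input = cuda.mem_alloc(input_array.nbytes)
    d_output = cuda.mem_alloc(output_array.nbytes)
    cuda.memcpy_htod(d_input, input_array)            # host -> device
    context.execute_v2(bindings=[int(d_input), int(d_output)])
    cuda.memcpy_dtoh(output_array, d_output)          # device -> host
    return output_array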
Benchmarking Tools
Model Profiling
import torch
import time
import numpy as np

class ModelProfiler:
    """Profile model performance."""
    def __init__(self, model, input_shape, device='cuda'):
        self.model = model
        self.input_shape = input_shape
        self.device = device
        self.model.eval()
        if device == 'cuda':
            self.model = self.model.cuda()

    def profile_inference(self, num_runs=100, warmup=10):
        """Profile inference latency."""
        dummy_input = torch.randn(*self.input_shape).to(self.device)
        # Warmup
        with torch.no_grad():
            for _ in range(warmup):
                _ = self.model(dummy_input)
        # Benchmark
        latencies = []
        with torch.no_grad():
            for _ in range(num_runs):
                if self.device == 'cuda':
                    torch.cuda.synchronize()
                start = time.perf_counter()
                _ = self.model(dummy_input)
                if self.device == 'cuda':
                    torch.cuda.synchronize()
                end = time.perf_counter()
                latencies.append((end - start) * 1000)  # ms
        return {
            'mean_ms': np.mean(latencies),
            'std_ms': np.std(latencies),
            'min_ms': np.min(latencies),
            'max_ms': np.max(latencies),
            'p50_ms': np.percentile(latencies, 50),
            'p95_ms': np.percentile(latencies, 95),
            'p99_ms': np.percentile(latencies, 99)
        }

    def profile_memory(self):
        """Profile GPU memory usage."""
        if self.device != 'cuda':
            return {'error': 'GPU memory profiling requires CUDA'}
        dummy_input = torch.randn(*self.input_shape).cuda()
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.empty_cache()
        with torch.no_grad():
            _ = self.model(dummy_input)
        return {
            'allocated_mb': torch.cuda.max_memory_allocated() / 1024 / 1024,
            'reserved_mb': torch.cuda.max_memory_reserved() / 1024 / 1024
        }

    def profile_throughput(self, duration_seconds=10):
        """Profile inference throughput."""
        dummy_input = torch.randn(*self.input_shape).to(self.device)
        start_time = time.time()
        num_inferences = 0
        with torch.no_grad():
            while time.time() - start_time < duration_seconds:
                _ = self.model(dummy_input)
                num_inferences += 1
        elapsed = time.time() - start_time
        throughput = num_inferences / elapsed
        return {
            'duration_seconds': elapsed,
            'num_inferences': num_inferences,
            'throughput_per_second': throughput
        }

# Usage
profiler = ModelProfiler(model, (1, 3, 224, 224), device='cuda')
latency = profiler.profile_inference()
memory = profiler.profile_memory()
throughput = profiler.profile_throughput()
print(f"Latency: {latency['p95_ms']:.2f} ms (p95)")
print(f"Memory: {memory['allocated_mb']:.2f} MB")
print(f"Throughput: {throughput['throughput_per_second']:.2f} inferences/sec")
Model Comparison
def compare_models(models, input_shape, device='cuda'):
    """Compare multiple models."""
    results = {}
    for name, model in models.items():
        print(f"\nProfiling {name}...")
        profiler = ModelProfiler(model, input_shape, device)
        latency = profiler.profile_inference()
        memory = profiler.profile_memory()
        throughput = profiler.profile_throughput()
        # Count parameters
        num_params = sum(p.numel() for p in model.parameters())
        results[name] = {
            'parameters': num_params,
            'latency_p95_ms': latency['p95_ms'],
            'memory_mb': memory['allocated_mb'],
            'throughput_per_sec': throughput['throughput_per_second']
        }
    # Print comparison
    print("\n" + "=" * 80)
    print(f"{'Model':<20} {'Params':>12} {'Latency (ms)':>15} {'Memory (MB)':>12} {'Throughput':>12}")
    print("=" * 80)
    for name, metrics in results.items():
        print(f"{name:<20} {metrics['parameters']:>12,} "
              f"{metrics['latency_p95_ms']:>15.2f} "
              f"{metrics['memory_mb']:>12.2f} "
              f"{metrics['throughput_per_sec']:>12.2f}")
    return results

# Usage
models = {
    'Original': original_model,
    'Quantized': quantized_model,
    'Pruned': pruned_model
}
compare_models(models, (1, 3, 224, 224))
Best Practices
- Start simple
  - Begin with basic optimizations (e.g., FP16)
  - Apply more aggressive techniques incrementally
  - Monitor accuracy after every optimization step
  - Compare against a baseline
- Measure before optimizing
  - Profile the model before optimizing
  - Record baseline metrics (latency, memory, throughput)
  - Use realistic input data
  - Test on the target hardware
- Use the right technique for the goal
  - Quantization to reduce size
  - Pruning to improve speed
  - Distillation to compress models
  - Architecture design for efficiency
- Preserve accuracy
  - Set an acceptable accuracy-drop threshold
  - Use calibration data for quantization
  - Fine-tune after pruning
  - Validate on a representative dataset
- Test on target hardware
  - Optimize for the production hardware
  - Test in the actual deployment environment
  - Account for GPU/CPU constraints
  - Profile memory usage
- Handle edge cases
  - Handle variable input sizes
  - Handle batch size 1
  - Handle different data types
  - Test with real-world data
- Version control
  - Keep a backup of the original model
  - Document optimization steps
  - Track model versions
  - Keep results reproducible
- Monitor in production
  - Track inference latency
  - Monitor memory usage
  - Set alerts for anomalies
  - Log optimization metrics
- Use production frameworks
  - ONNX for cross-platform deployment
  - TensorRT for NVIDIA GPUs
  - OpenVINO for Intel CPUs
  - TFLite for mobile deployment
- Continuously improve
  - A/B test different optimization strategies
  - Collect production metrics
  - Keep optimizing based on data
  - Stay current with new techniques