PyTorch Deployment

This skill covers deploying PyTorch models to production, including the key steps of model export, optimization, serving, monitoring, and performance benchmarking; it is a practical guide for taking deep learning models from development to launch.

Overview

PyTorch deployment involves exporting models to production-ready formats, optimizing inference performance, and serving models through various deployment patterns. This skill covers TorchScript, ONNX export, TorchServe, model optimization techniques, inference optimization, FastAPI deployment, model versioning, A/B testing, monitoring, error handling, and performance benchmarking.

Prerequisites

  • Understanding of PyTorch and deep learning models
  • Knowledge of model training and evaluation
  • Familiarity with web frameworks (FastAPI, Flask)
  • Understanding of Docker and containerization
  • Basic awareness of cloud deployment concepts

Core Concepts

Model Export Formats

  • TorchScript: PyTorch's intermediate representation for production deployment
  • Tracing: captures the computation path from an example input
  • Scripting: captures the full Python code, including control flow
  • ONNX: Open Neural Network Exchange, for cross-framework compatibility
  • TorchServe: PyTorch's model-serving framework

Model Optimization

  • Quantization: reduce precision (FP32 → FP16/INT8) for efficiency
  • Pruning: remove less important weights from the model
  • Knowledge distillation: train a smaller model from a larger teacher model
  • Model compression: techniques for reducing model size

Inference Optimization

  • Batching: process multiple inputs together for efficiency
  • GPU utilization: multi-GPU inference for higher throughput
  • Mixed precision: use FP16 for faster computation
  • Caching: reuse the results of repeated computations (see the sketch below)
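
The caching idea can be as simple as memoizing outputs for inputs that repeat. A minimal sketch, assuming inputs arrive as CPU tensors and that hashing the raw bytes is an acceptable cache key (the helper name and FIFO eviction policy are illustrative):

import hashlib
import torch

_cache = {}

def cached_predict(model, input_tensor, max_entries=1024):
    # Key the cache on a hash of the raw tensor bytes (choose a key that
    # matches how your inputs are actually serialized)
    key = hashlib.sha1(input_tensor.numpy().tobytes()).hexdigest()
    if key not in _cache:
        if len(_cache) >= max_entries:
            _cache.pop(next(iter(_cache)))  # evict the oldest entry (FIFO)
        with torch.no_grad():
            _cache[key] = model(input_tensor)
    return _cache[key]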

Deployment Patterns

  • FastAPI server: REST API for model serving
  • TorchServe: production-ready model-serving framework
  • ONNX Runtime: high-performance inference engine
  • Docker deployment: containerized model deployment

Implementation Guide

Model Export Formats

TorchScript

TorchScript is an intermediate representation of a PyTorch model that can run in high-performance environments such as C++.

Tracing vs. scripting:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 64, 3)
        self.fc = nn.Linear(64 * 26 * 26, 10)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

model = MyModel()
model.eval()

# Method 1: tracing (captures the actual computation path)
example_input = torch.randn(1, 3, 28, 28)
traced_model = torch.jit.trace(model, example_input)
traced_model.save("model_traced.pt")

# Method 2: scripting (captures the full Python code)
scripted_model = torch.jit.script(model)
scripted_model.save("model_scripted.pt")

# Load and run inference
loaded_model = torch.jit.load("model_traced.pt")
output = loaded_model(example_input)

Handling control flow with scripting:

class ConditionalModel(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        else:
            return x / 2

# Use scripting for models with control flow
model = ConditionalModel()
scripted_model = torch.jit.script(model)
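
A quick way to see the difference: tracing ConditionalModel records only the branch taken by the example input (PyTorch emits a TracerWarning about the data-dependent condition), while scripting preserves the if. A small check, reusing the model above:

x_pos = torch.ones(2, 2)
x_neg = -torch.ones(2, 2)

# Tracing with a positive example bakes the `x * 2` branch into the graph
traced = torch.jit.trace(ConditionalModel(), x_pos)
print(traced(x_neg))          # still applies x * 2 -- wrong for negative sums

# The scripted model keeps the control flow, so both branches work
print(scripted_model(x_neg))  # applies x / 2 as expected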

ONNX Export

The Open Neural Network Exchange (ONNX) enables interoperability between different frameworks.

import torch
import torch.onnx

# Export to ONNX
model = MyModel()
model.eval()
dummy_input = torch.randn(1, 3, 28, 28)

torch.onnx.export(
    model,                      # model to export
    dummy_input,                # example input
    "model.onnx",               # output file
    export_params=True,         # store the trained parameters
    opset_version=17,           # ONNX opset version
    do_constant_folding=True,   # fold constants for optimization
    input_names=['input'],      # input names
    output_names=['output'],    # output names
    dynamic_axes={              # dynamic axes for variable batch size
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'}
    }
)

# Validate the ONNX model
import onnx
onnx_model = onnx.load("model.onnx")
onnx.checker.check_model(onnx_model)

# Run inference with ONNX Runtime
import onnxruntime as ort

session = ort.InferenceSession("model.onnx")
input_name = session.get_inputs()[0].name
output_name = session.get_outputs()[0].name

outputs = session.run([output_name], {input_name: dummy_input.numpy()})
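
After export, it is worth confirming that ONNX Runtime reproduces the PyTorch outputs within floating-point tolerance. A small parity check along these lines, reusing the session and dummy_input from above (the tolerances are illustrative):

import numpy as np

with torch.no_grad():
    torch_out = model(dummy_input).numpy()

# Element-wise comparison between PyTorch and ONNX Runtime outputs
np.testing.assert_allclose(torch_out, outputs[0], rtol=1e-3, atol=1e-5)
print("ONNX Runtime output matches PyTorch")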

Custom ONNX operators:

from torch.onnx import register_custom_op_symbolic

# Map aten::gelu to a custom ONNX op; the symbolic's signature mirrors
# aten::gelu(Tensor self, str approximate)
def custom_gelu_symbolic(g, self, approximate):
    return g.op("CustomGelu", self, approximate_s=approximate)

register_custom_op_symbolic("aten::gelu", custom_gelu_symbolic, 17)

TorchServe

TorchServe is a flexible, easy-to-use tool for serving PyTorch models.

Installation:

pip install torchserve torch-model-archiver torch-workflow-archiver

Model archiving:

# Create a model handler (handler.py)
import os
import torch

class ModelHandler:
    def __init__(self):
        self.model = None
        self.mapping = None
        self.device = None
        self.initialized = False

    def initialize(self, context):
        """Initialize the model and load weights."""
        properties = context.system_properties
        self.device = torch.device("cuda:" + str(properties.get("gpu_id")) if torch.cuda.is_available() else "cpu")

        model_dir = properties.get("model_dir")
        model_pt_path = os.path.join(model_dir, "model.pth")

        self.model = torch.load(model_pt_path, map_location=self.device)
        self.model.eval()
        self.initialized = True

    def preprocess(self, requests):
        """预处理输入数据。"""
        inputs = []
        for req in requests:
            data = req.get("data") or req.get("body")
            inputs.append(torch.tensor(data))
        return torch.stack(inputs)

    def inference(self, input_data):
        """运行推理。"""
        with torch.no_grad():
            output = self.model(input_data)
        return output

    def postprocess(self, inference_output):
        """后处理输出。"""
        return inference_output.cpu().numpy().tolist()

    def handle(self, data, context):
        """主处理器函数。"""
        try:
            data = self.preprocess(data)
            data = data.to(self.device)
            output = self.inference(data)
            return self.postprocess(output)
        except Exception as e:
            return [{"error": str(e)}]

Archiving and serving:

# Archive the model
torch-model-archiver \
  --model-name mymodel \
  --version 1.0 \
  --serialized-file model.pth \
  --handler handler.py \
  --extra-files config.json,index_to_name.json \
  --export-path model_store

# Start TorchServe
torchserve --start --ncs --model-store model_store --models mymodel=mymodel.mar

# Make a prediction
curl -X POST http://localhost:8080/predictions/mymodel \
  -H "Content-Type: application/json" \
  -d '{"data": [[...]]}'

Model Optimization

Quantization

Quantization reduces model size and speeds up inference by using lower-precision numbers.

Post-training quantization (PTQ):

import torch
import torch.nn as nn
from torch.quantization import quantize_dynamic

# Dynamic quantization (weights quantized, activations computed in floating point)
model = MyModel()

# Quantize specific layer types
quantized_model = quantize_dynamic(
    model,
    {nn.Linear, nn.LSTM},  # layer types to quantize
    dtype=torch.qint8      # quantized dtype
)

# Save the quantized model
torch.jit.save(torch.jit.script(quantized_model), "model_quantized.pt")
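
To verify the size reduction, compare the serialized size before and after quantization. A small helper sketch; the temp file path is illustrative:

import os

def serialized_size_mb(m, path="_tmp_size_check.pt"):
    # Serialize the state dict and measure the file on disk
    torch.save(m.state_dict(), path)
    size_mb = os.path.getsize(path) / 1e6
    os.remove(path)
    return size_mb

print(f"FP32: {serialized_size_mb(model):.1f} MB")
print(f"INT8: {serialized_size_mb(quantized_model):.1f} MB")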

Static quantization:

import torch
from torch.quantization import prepare, convert, get_default_qconfig

# Prepare the model for static quantization
# (static quantization expects QuantStub/DeQuantStub around the forward pass)
model = MyModel()
model.eval()

# Set the quantization configuration
model.qconfig = get_default_qconfig('fbgemm')

# Insert observers for calibration
prepared_model = prepare(model)

# Calibrate with representative data
with torch.no_grad():
    for data in calibration_dataloader:
        prepared_model(data)

# Convert to a quantized model
quantized_model = convert(prepared_model)

# Save
torch.jit.save(torch.jit.script(quantized_model), "model_static_quantized.pt")

Quantization-aware training (QAT):

import torch
from torch.quantization import prepare_qat, convert, get_default_qat_qconfig

# Prepare the model for QAT
model = MyModel()
model.train()
model.qconfig = get_default_qat_qconfig('fbgemm')

# Insert fake-quantization modules
model_prepared = prepare_qat(model, inplace=True)

# Fine-tune with quantization simulated in the forward pass
optimizer = torch.optim.SGD(model_prepared.parameters(), lr=0.01)

for epoch in range(num_epochs):
    for batch in train_dataloader:
        optimizer.zero_grad()
        loss = criterion(model_prepared(batch[0]), batch[1])
        loss.backward()
        optimizer.step()

# Convert to a quantized model
model_prepared.eval()
quantized_model = convert(model_prepared)

Pruning

Pruning removes less important weights from the model.

Unstructured (magnitude) pruning:

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune

model = MyModel()

# Prune 30% of the weights in each linear layer (by L1 magnitude)
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.3)

# Make the pruning permanent
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        prune.remove(module, 'weight')

Global unstructured pruning:

# Prune globally across all layers
parameters_to_prune = []
for name, module in model.named_modules():
    if isinstance(module, (nn.Linear, nn.Conv2d)):
        parameters_to_prune.append((module, 'weight'))

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2
)
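
To confirm the pruning level, measure the fraction of zeroed weights across the pruned parameters. A short helper sketch reusing parameters_to_prune from above:

def global_sparsity(params):
    # Count zeroed weights across all (module, parameter-name) pairs
    zeros, total = 0, 0
    for module, name in params:
        weight = getattr(module, name)
        zeros += (weight == 0).sum().item()
        total += weight.numel()
    return zeros / total

print(f"Global sparsity: {global_sparsity(parameters_to_prune):.1%}")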

Iterative pruning:

def iterative_pruning(model, train_loader, num_iterations=5, prune_amount=0.2):
    for iteration in range(num_iterations):
        print(f"Pruning iteration {iteration + 1}/{num_iterations}")

        # Prune
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear):
                prune.l1_unstructured(module, name='weight', amount=prune_amount)

        # Fine-tune
        for epoch in range(5):
            for batch in train_loader:
                # training code goes here
                pass

    # Make the pruning permanent
    for name, module in model.named_modules():
        if isinstance(module, nn.Linear):
            prune.remove(module, 'weight')

    return model

Model Compression

Knowledge distillation:

import torch
import torch.nn as nn
import torch.nn.functional as F

class DistillationLoss(nn.Module):
    def __init__(self, alpha=0.5, temperature=2.0):
        super().__init__()
        self.alpha = alpha
        self.temperature = temperature
        self.kl_div = nn.KLDivLoss(reduction='batchmean')

    def forward(self, student_logits, teacher_logits, targets):
        # Hard loss (cross-entropy against ground-truth labels)
        hard_loss = F.cross_entropy(student_logits, targets)

        # Soft loss (KL divergence against the teacher)
        soft_loss = self.kl_div(
            F.log_softmax(student_logits / self.temperature, dim=1),
            F.softmax(teacher_logits / self.temperature, dim=1)
        ) * (self.temperature ** 2)

        return self.alpha * soft_loss + (1 - self.alpha) * hard_loss

# Training loop
teacher_model = load_teacher_model()
student_model = create_student_model()
criterion = DistillationLoss(alpha=0.7, temperature=3.0)
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)

teacher_model.eval()
student_model.train()

for batch in train_loader:
    inputs, targets = batch

    with torch.no_grad():
        teacher_outputs = teacher_model(inputs)

    student_outputs = student_model(inputs)
    loss = criterion(student_outputs, teacher_outputs, targets)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

Inference Optimization

Batching

import torch
from collections import deque
import threading
import time

class BatchInferenceServer:
    def __init__(self, model, max_batch_size=32, max_wait_time=0.1):
        self.model = model
        self.model.eval()
        self.max_batch_size = max_batch_size
        self.max_wait_time = max_wait_time
        self.batch_queue = deque()
        self.results = {}
        self.lock = threading.Lock()
        self.running = False

    def start(self):
        self.running = True
        self.thread = threading.Thread(target=self._process_batches)
        self.thread.start()

    def stop(self):
        self.running = False
        self.thread.join()

    def predict(self, input_data):
        request_id = id(input_data)
        with self.lock:
            self.batch_queue.append((request_id, input_data))
        return request_id

    def get_result(self, request_id, timeout=10):
        start_time = time.time()
        while request_id not in self.results:
            if time.time() - start_time > timeout:
                raise TimeoutError("Prediction timeout")
            time.sleep(0.01)
        return self.results.pop(request_id)

    def _process_batches(self):
        while self.running:
            batch = []
            start_time = time.time()

            # Collect requests until the batch is full or the wait window expires
            while len(batch) < self.max_batch_size and \
                  (time.time() - start_time) < self.max_wait_time:
                with self.lock:
                    if self.batch_queue:
                        batch.append(self.batch_queue.popleft())
                        continue
                time.sleep(0.001)  # sleep outside the lock so predict() is not blocked

            if batch:
                request_ids, inputs = zip(*batch)
                batch_tensor = torch.stack(inputs)

                with torch.no_grad():
                    outputs = self.model(batch_tensor)

                with self.lock:
                    for req_id, output in zip(request_ids, outputs):
                        self.results[req_id] = output
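
A short usage sketch for the server above; the input shape and timeouts are illustrative:

server = BatchInferenceServer(model, max_batch_size=32, max_wait_time=0.05)
server.start()

# Submit a single sample and block until its result is ready
request_id = server.predict(torch.randn(3, 28, 28))
result = server.get_result(request_id, timeout=5)

server.stop()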

GPU Utilization

import copy
import torch

def multi_gpu_inference(model, inputs):
    """Distribute a large batch across all available GPUs.

    Note: collecting results from mp.spawn workers into a plain Python list
    does not work (each spawned process gets its own copy), so this version
    keeps one model replica per GPU in a single process; CUDA kernels launch
    asynchronously, so the devices still overlap.
    """
    num_gpus = torch.cuda.device_count()

    # One replica of the model per GPU
    replicas = [copy.deepcopy(model).to(i).eval() for i in range(num_gpus)]

    # Split the input batch across GPUs
    chunks = torch.chunk(inputs, num_gpus)

    outputs = []
    with torch.no_grad():
        for i, chunk in enumerate(chunks):
            outputs.append(replicas[i](chunk.to(i)))

    # Gather the results back on the CPU
    return torch.cat([out.cpu() for out in outputs], dim=0)

Mixed Precision

import torch
from torch.cuda.amp import autocast, GradScaler

# Mixed precision for inference
def mixed_precision_inference(model, inputs):
    model.eval()
    with autocast():
        with torch.no_grad():
            outputs = model(inputs)
    return outputs

# Mixed precision for training
scaler = GradScaler()

for batch in train_loader:
    inputs, targets = batch
    inputs, targets = inputs.to(device), targets.to(device)

    optimizer.zero_grad()

    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()

Deployment Patterns

FastAPI Server

from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import io

app = FastAPI(title="PyTorch Model API")

# Load the model
class ImageClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc = nn.Linear(64 * 13 * 13, 10)

    def forward(self, x):
        x = self.pool(torch.relu(self.conv1(x)))
        x = x.view(x.size(0), -1)
        return self.fc(x)

model = ImageClassifier()
model.load_state_dict(torch.load("model.pth"))
model.eval()

# Preprocessing
transform = transforms.Compose([
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

class PredictionResponse(BaseModel):
    class_id: int
    class_name: str
    confidence: float

@app.post("/predict", response_model=PredictionResponse)
async def predict(file: UploadFile = File(...)):
    try:
        # Read and preprocess the image
        image_data = await file.read()
        image = Image.open(io.BytesIO(image_data)).convert("RGB")
        input_tensor = transform(image).unsqueeze(0)

        # Inference
        with torch.no_grad():
            outputs = model(input_tensor)
            probabilities = torch.softmax(outputs, dim=1)
            confidence, predicted = torch.max(probabilities, 1)

        return PredictionResponse(
            class_id=predicted.item(),
            class_name=f"class_{predicted.item()}",
            confidence=confidence.item()
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {"status": "healthy"}

@app.get("/model/info")
async def model_info():
    return {
        "model_type": "ImageClassifier",
        "parameters": sum(p.numel() for p in model.parameters()),
        "input_shape": "(batch, 3, 28, 28)",
        "output_shape": "(batch, 10)"
    }

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)

TorchServe Configuration

config.properties:

inference_address=http://0.0.0.0:8080
management_address=http://0.0.0.0:8081
metrics_address=http://0.0.0.0:8082
number_of_netty_threads=4
job_queue_size=10
model_store=model_store
load_models=all
number_of_gpu=1
default_response_timeout=120

Docker deployment:

FROM pytorch/torchserve:latest

# Copy the model archives
COPY model_store /home/model-server/model-store

# Copy the configuration
COPY config.properties /home/model-server/config.properties

# Expose ports
EXPOSE 8080 8081 8082

# Start TorchServe
CMD ["torchserve", \
     "--start", \
     "--model-store", "/home/model-server/model-store", \
     "--models", "mymodel=mymodel.mar", \
     "--ts-config", "/home/model-server/config.properties"]

ONNX Runtime Server

Python ONNX Runtime server:

from fastapi import FastAPI
import numpy as np
import onnxruntime as ort
from pydantic import BaseModel

app = FastAPI()

# Load the ONNX model
session = ort.InferenceSession("model.onnx")

class InputData(BaseModel):
    data: list

@app.post("/predict")
async def predict(input_data: InputData):
    input_array = np.array(input_data.data, dtype=np.float32)

    # Run inference
    outputs = session.run(
        None,
        {session.get_inputs()[0].name: input_array}
    )

    return {"output": outputs[0].tolist()}

Model Versioning

Versioning Strategy

import os
import json
from datetime import datetime
import torch

class ModelVersionManager:
    def __init__(self, base_path="models"):
        self.base_path = base_path
        os.makedirs(base_path, exist_ok=True)

    def save_model(self, model, version, metadata=None):
        """保存带有版本和元数据的模型。"""
        version_path = os.path.join(self.base_path, f"v{version}")
        os.makedirs(version_path, exist_ok=True)

        # Save the model weights
        model_path = os.path.join(version_path, "model.pth")
        torch.save(model.state_dict(), model_path)

        # Save the metadata
        metadata = metadata or {}
        metadata.update({
            "version": version,
            "saved_at": datetime.now().isoformat(),
            "model_path": model_path
        })

        metadata_path = os.path.join(version_path, "metadata.json")
        with open(metadata_path, "w") as f:
            json.dump(metadata, f, indent=2)

        return version_path

    def load_model(self, version, model_class):
        """通过版本加载模型。"""
        version_path = os.path.join(self.base_path, f"v{version}")
        model_path = os.path.join(version_path, "model.pth")

        model = model_class()
        model.load_state_dict(torch.load(model_path))
        model.eval()

        return model

    def list_versions(self):
        """列出所有可用版本。"""
        versions = []
        for item in os.listdir(self.base_path):
            if item.startswith("v"):
                version_path = os.path.join(self.base_path, item)
                metadata_path = os.path.join(version_path, "metadata.json")
                if os.path.exists(metadata_path):
                    with open(metadata_path) as f:
                        versions.append(json.load(f))
        return sorted(versions, key=lambda x: x["version"])
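
Usage might look like this, assuming MyModel is the model class from earlier and the metadata values are illustrative:

manager = ModelVersionManager(base_path="models")

# Save the current model as version 2 with training metadata
manager.save_model(model, version=2, metadata={"accuracy": 0.93})

# Load a specific version for serving
model_v2 = manager.load_model(2, MyModel)

# Inspect the available versions
for entry in manager.list_versions():
    print(entry["version"], entry["saved_at"])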

A/B Testing Models

import hashlib
import random
from typing import Dict, Optional
import torch

class ABTestModelRouter:
    def __init__(self, models: Dict[str, torch.nn.Module], traffic_split: Dict[str, float]):
        """
        参数:
            models: 字典,模型名称 -> 模型
            traffic_split: 字典,模型名称 -> 流量百分比(总和必须为 1.0)
        """
        self.models = models
        self.traffic_split = traffic_split
        self.model_names = list(traffic_split.keys())
        self.cumulative_split = []
        cumulative = 0
        for name in self.model_names:
            cumulative += traffic_split[name]
            self.cumulative_split.append(cumulative)

    def get_model(self, request_id: Optional[str] = None) -> torch.nn.Module:
        """根据流量分割选择模型。"""
        # 使用 request_id 进行一致性路由(同一请求始终路由到同一模型)
        if request_id:
            hash_val = hash(request_id) % 1000
            rand_val = hash_val / 1000.0
        else:
            rand_val = random.random()

        for i, threshold in enumerate(self.cumulative_split):
            if rand_val < threshold:
                return self.models[self.model_names[i]]

        return self.models[self.model_names[-1]]

    def predict(self, input_data, request_id: Optional[str] = None):
        """进行 A/B 测试。"""
        model = self.get_model(request_id)
        model.eval()

        with torch.no_grad():
            output = model(input_data)

        return output

# Usage
model_a = create_model_v1()
model_b = create_model_v2()

router = ABTestModelRouter(
    models={"v1": model_a, "v2": model_b},
    traffic_split={"v1": 0.7, "v2": 0.3}
)

# Predictions route 70% to v1 and 30% to v2
output = router.predict(input_data, request_id="user_123")

Model Monitoring

import time
import json
from collections import defaultdict
from datetime import datetime
import torch

class ModelMonitor:
    def __init__(self, model_name: str):
        self.model_name = model_name
        self.metrics = defaultdict(list)
        self.start_time = time.time()

    def log_prediction(self, request_id: str, input_shape: tuple,
                       output_shape: tuple, latency: float,
                       model_version: str):
        """记录预测指标。"""
        self.metrics["predictions"].append({
            "request_id": request_id,
            "timestamp": datetime.now().isoformat(),
            "input_shape": input_shape,
            "output_shape": output_shape,
            "latency_ms": latency,
            "model_version": model_version
        })

    def log_error(self, request_id: str, error_type: str, error_message: str):
        """记录预测错误。"""
        self.metrics["errors"].append({
            "request_id": request_id,
            "timestamp": datetime.now().isoformat(),
            "error_type": error_type,
            "error_message": error_message
        })

    def get_summary(self):
        """获取监控摘要。"""
        predictions = self.metrics["predictions"]
        errors = self.metrics["errors"]

        if predictions:
            avg_latency = sum(p["latency_ms"] for p in predictions) / len(predictions)
            total_predictions = len(predictions)
        else:
            avg_latency = 0
            total_predictions = 0

        return {
            "model_name": self.model_name,
            "uptime_seconds": time.time() - self.start_time,
            "total_predictions": total_predictions,
            "total_errors": len(errors),
            "average_latency_ms": avg_latency,
            "error_rate": len(errors) / max(total_predictions, 1) * 100
        }

# Usage with FastAPI
import uuid

monitor = ModelMonitor("image_classifier")

@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    request_id = str(uuid.uuid4())
    start_time = time.time()

    try:
        # Preprocess and predict (preprocessing omitted for brevity)
        output = model(input_tensor)
        latency = (time.time() - start_time) * 1000

        monitor.log_prediction(
            request_id=request_id,
            input_shape=tuple(input_tensor.shape),
            output_shape=tuple(output.shape),
            latency=latency,
            model_version="1.0"
        )

        return {"output": output.tolist()}

    except Exception as e:
        monitor.log_error(request_id, type(e).__name__, str(e))
        raise HTTPException(status_code=500, detail=str(e))
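
Exposing the summary over HTTP makes it easy to poll or scrape. A minimal endpoint on the same app; this is a sketch, not a full metrics pipeline:

@app.get("/metrics")
async def metrics():
    # Return the running counters collected by ModelMonitor
    return monitor.get_summary()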

Error Handling

import logging
from functools import wraps
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class ModelInferenceError(Exception):
    """模型推理错误的基类。"""
    pass

class ModelLoadError(ModelInferenceError):
    """模型加载失败时引发的异常。"""
    pass

class InputValidationError(ModelInferenceError):
    """输入验证失败时引发的异常。"""
    pass

def handle_inference_errors(func):
    """处理推理错误的装饰器。"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except ModelLoadError as e:
            logger.error(f"Model load error: {e}")
            raise
        except InputValidationError as e:
            logger.warning(f"Input validation error: {e}")
            raise
        except torch.cuda.OutOfMemoryError:
            logger.error("CUDA out of memory")
            raise ModelInferenceError("GPU memory exhausted")
        except Exception as e:
            logger.error(f"Unexpected error during inference: {e}")
            raise ModelInferenceError(f"Inference failed: {str(e)}")
    return wrapper

class SafeModelWrapper:
    """用于安全模型推理的错误处理包装器。"""
    def __init__(self, model_path, device="cuda"):
        self.model_path = model_path
        self.device = device
        self.model = None
        self._load_model()

    def _load_model(self):
        """带错误处理的模型加载。"""
        try:
            self.model = torch.load(self.model_path, map_location=self.device)
            self.model.eval()
            logger.info(f"Model loaded successfully from {self.model_path}")
        except FileNotFoundError:
            raise ModelLoadError(f"Model file not found: {self.model_path}")
        except Exception as e:
            raise ModelLoadError(f"Failed to load model: {str(e)}")

    @handle_inference_errors
    def predict(self, input_data):
        """带错误处理的安全预测。"""
        if not isinstance(input_data, torch.Tensor):
            raise InputValidationError("Input must be a torch.Tensor")

        if input_data.dim() != 4:
            raise InputValidationError(f"Expected 4D input, got {input_data.dim()}D")

        with torch.no_grad():
            output = self.model(input_data.to(self.device))

        return output.cpu()

    def get_model_info(self):
        """获取模型信息。"""
        total_params = sum(p.numel() for p in self.model.parameters())
        trainable_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)

        return {
            "model_path": self.model_path,
            "device": str(self.device),
            "total_parameters": total_params,
            "trainable_parameters": trainable_params
        }
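
For external dependencies (for example, a feature store or a remote embedding service), the circuit breaker pattern referenced in the checklist below stops calling a failing dependency for a cool-down period instead of retrying endlessly. A minimal sketch; the thresholds are illustrative:

import time

class CircuitBreaker:
    """Open the circuit after repeated failures; retry after a cool-down."""
    def __init__(self, failure_threshold=5, reset_timeout=30.0):
        self.failure_threshold = failure_threshold
        self.reset_timeout = reset_timeout
        self.failures = 0
        self.opened_at = None

    def call(self, func, *args, **kwargs):
        # While open, fail fast until the cool-down expires
        if self.opened_at is not None:
            if time.time() - self.opened_at < self.reset_timeout:
                raise ModelInferenceError("Circuit open: dependency unavailable")
            self.opened_at = None  # half-open: allow one trial call
        try:
            result = func(*args, **kwargs)
            self.failures = 0  # a success closes the circuit
            return result
        except Exception:
            self.failures += 1
            if self.failures >= self.failure_threshold:
                self.opened_at = time.time()
            raise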

Performance Benchmarking

import time
import torch
import numpy as np
from typing import List, Dict
import json

class ModelBenchmark:
    def __init__(self, model, input_shape, warmup_runs=10, benchmark_runs=100):
        self.model = model
        self.model.eval()
        self.input_shape = input_shape
        self.warmup_runs = warmup_runs
        self.benchmark_runs = benchmark_runs
        self.device = next(model.parameters()).device

    def _generate_input(self, batch_size=1):
        """为基准测试生成随机输入。"""
        return torch.randn(batch_size, *self.input_shape, device=self.device)

    def benchmark_latency(self, batch_sizes: List[int] = [1, 8, 16, 32]):
        """针对不同批量大小进行推理延迟基准测试。"""
        results = {}

        for batch_size in batch_sizes:
            # Warmup
            for _ in range(self.warmup_runs):
                input_data = self._generate_input(batch_size)
                with torch.no_grad():
                    _ = self.model(input_data)

            # Benchmark
            latencies = []
            for _ in range(self.benchmark_runs):
                input_data = self._generate_input(batch_size)

                if self.device.type == 'cuda':
                    torch.cuda.synchronize()
                start_time = time.perf_counter()

                with torch.no_grad():
                    _ = self.model(input_data)

                if self.device.type == 'cuda':
                    torch.cuda.synchronize()
                end_time = time.perf_counter()

                latencies.append((end_time - start_time) * 1000)  # convert to milliseconds

            # Cast numpy values to plain floats so the results stay JSON-serializable
            results[batch_size] = {
                "mean_ms": float(np.mean(latencies)),
                "std_ms": float(np.std(latencies)),
                "min_ms": float(np.min(latencies)),
                "max_ms": float(np.max(latencies)),
                "p50_ms": float(np.percentile(latencies, 50)),
                "p95_ms": float(np.percentile(latencies, 95)),
                "p99_ms": float(np.percentile(latencies, 99))
            }

        return results

    def benchmark_throughput(self, duration_seconds=30):
        """在一定时间内基准测试吞吐量。"""
        input_data = self._generate_input(1)

        start_time = time.time()
        predictions = 0

        while (time.time() - start_time) < duration_seconds:
            with torch.no_grad():
                _ = self.model(input_data)
            predictions += 1

        elapsed = time.time() - start_time
        throughput = predictions / elapsed

        return {
            "duration_seconds": elapsed,
            "total_predictions": predictions,
            "throughput_per_second": throughput
        }

    def benchmark_memory(self, batch_sizes: List[int] = [1, 8, 16, 32]):
        """基准测试 GPU 内存使用情况。"""
        if self.device.type != 'cuda':
            return {"error": "GPU 内存基准测试需要 CUDA"}

        results = {}
        torch.cuda.reset_peak_memory_stats()

        for batch_size in batch_sizes:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

            input_data = self._generate_input(batch_size)

            with torch.no_grad():
                _ = self.model(input_data)

            results[batch_size] = {
                "allocated_mb": torch.cuda.max_memory_allocated() / 1024 / 1024,
                "reserved_mb": torch.cuda.max_memory_reserved() / 1024 / 1024
            }

        return results

    def run_full_benchmark(self):
        """运行全面基准测试。"""
        print("Running Model Benchmark...")
        print("=" * 50)

        # Latency benchmark
        print("\n1. Latency Benchmark:")
        latency_results = self.benchmark_latency()
        for batch_size, metrics in latency_results.items():
            print(f"  Batch {batch_size}: {metrics['mean_ms']:.2f}ms (p95: {metrics['p95_ms']:.2f}ms)")

        # Throughput benchmark
        print("\n2. Throughput Benchmark:")
        throughput_results = self.benchmark_throughput()
        print(f"  {throughput_results['throughput_per_second']:.2f} predictions/second")

        # Memory benchmark
        print("\n3. Memory Benchmark:")
        memory_results = self.benchmark_memory()
        for batch_size, metrics in memory_results.items():
            print(f"  Batch {batch_size}: {metrics['allocated_mb']:.2f} MB allocated")

        # Model info
        print("\n4. Model Info:")
        total_params = sum(p.numel() for p in self.model.parameters())
        print(f"  Total parameters: {total_params:,}")

        return {
            "latency": latency_results,
            "throughput": throughput_results,
            "memory": memory_results,
            "parameters": total_params
        }

# Usage
model = load_model()
benchmark = ModelBenchmark(model, input_shape=(3, 224, 224))
results = benchmark.run_full_benchmark()

# Save the results
with open("benchmark_results.json", "w") as f:
    json.dump(results, f, indent=2)

Best Practices

Pre-deployment Checklist

  • Model export
    • Model exported to a production format (TorchScript/ONNX)
    • Exported model tested and validated
    • Model size optimized (quantization/pruning)
  • Performance
    • Inference latency meets the SLA (< 100 ms for real-time use)
    • Throughput tested under expected load
    • GPU memory usage optimized
    • Batching configured
  • Reliability
    • Error handling implemented
    • Graceful degradation on failure
    • Circuit breaker pattern for external dependencies
    • Retry logic for transient failures
  • Monitoring
    • Metrics collection (latency, throughput, errors)
    • Logging configured
    • Health check endpoint
    • Alert thresholds set
  • Security
    • Input validation implemented
    • Rate limiting configured
    • Authentication/authorization for the API
    • Model files stored securely
  • Deployment
    • Docker container created
    • Environment variables configured
    • CI/CD pipeline set up
    • Blue-green deployment strategy

Post-deployment Checklist

  • Validation
    • Smoke tests passed
    • A/B testing started
    • Model performance monitored
    • Error rate within acceptable bounds
  • Documentation
    • API documentation updated
    • Model version recorded
    • Known issues documented
    • Runbook created

Performance Optimization Tips

  1. Use TorchScript for production
    • Export models to TorchScript for faster inference
    • Use tracing for models without control flow
    • Use scripting for models with dynamic control flow
  2. Apply quantization
    • Use dynamic quantization for quick deployment
    • Use static quantization for better performance
    • Use QAT for minimal accuracy loss
  3. Optimize batch size
    • Find the batch size that best suits your hardware (see the sketch after this list)
    • Use larger batches for better GPU utilization
    • Factor latency requirements into the choice of batch size
  4. Use mixed precision
    • Enable FP16 for faster computation
    • Use GradScaler for training stability
    • Test the accuracy impact before deploying
  5. Monitor model performance
    • Track latency, throughput, and error rates
    • Set alerts for performance degradation
    • Monitor GPU memory usage
    • Track prediction drift
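
One way to pick a batch size is to sweep with the ModelBenchmark class from the benchmarking section and compare effective throughput (batch size divided by mean latency). A sketch:

benchmark = ModelBenchmark(model, input_shape=(3, 224, 224))
latency = benchmark.benchmark_latency(batch_sizes=[1, 4, 8, 16, 32, 64])

# Effective throughput: samples per second at each batch size
for bs, metrics in latency.items():
    throughput = bs / (metrics["mean_ms"] / 1000)
    print(f"batch {bs:>3}: {metrics['mean_ms']:.2f} ms -> {throughput:.0f} samples/s")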
