Model Versioning (ModelVersioning)

The Model Versioning skill focuses on full-lifecycle management of machine learning models, covering versioning strategies, model registries, metadata tracking, deployment workflows, A/B testing, and model comparison. Keywords: machine learning, model management, version control, metadata, deployment, MLOps.

Category: Machine Learning · Updated 3/5/2026

Name: Model Versioning. Description: A comprehensive guide to versioning and managing machine learning models, including model registries, metadata tracking, and deployment workflows.

Model Versioning

Overview

Model versioning is the practice of tracking and managing different versions of machine learning models throughout their lifecycle. This skill covers versioning strategies, model registries, metadata management, lineage tracking, artifact storage, promotion workflows, A/B testing, and model comparison tools.

Prerequisites

  • An understanding of machine learning model development
  • Familiarity with Git and version control systems
  • Familiarity with model deployment concepts
  • An understanding of data pipelines and workflows
  • Basic knowledge of database systems

Key Concepts

Versioning Strategies

  • Semantic versioning: track changes with the MAJOR.MINOR.PATCH format
  • Timestamp-based versioning: use timestamps as unique version identifiers
  • Git-based versioning: derive versions from git commits and tags

Model Registry

  • Centralized storage: a single source of truth for model versions
  • Stage management: Development, Staging, Production, and Archived stages
  • Metadata tracking: comprehensive storage of model information
  • Model loading: retrieve models by version or stage

Metadata Management

  • Model metadata schema: structured information about a model
  • Metadata store: a database for storing and querying metadata
  • Lineage tracking: track model relationships and data sources
  • Artifact storage: manage model files and related artifacts

Deployment Workflows

  • Promotion pipelines: move models through stages
  • Rollback strategies: revert to a previous version
  • A/B testing: compare model versions in production
  • Model comparison: analyze performance across versions
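The rollback strategy listed above can be sketched with a toy registry: keep a history of Production promotions per model and revert to the previous entry. The `SimpleRegistry` class and its method names here are illustrative, not from any particular library.

```python
class SimpleRegistry:
    """A toy registry that tracks stages and supports rollback (illustrative)."""

    def __init__(self):
        self.stages = {}              # (name, version) -> stage
        self.production_history = {}  # name -> versions promoted to Production, in order

    def transition_stage(self, name, version, stage):
        if stage == "Production":
            # Demote the current production version before promoting the new one
            history = self.production_history.setdefault(name, [])
            if history:
                self.stages[(name, history[-1])] = "Archived"
            history.append(version)
        self.stages[(name, version)] = stage

    def rollback(self, name):
        """Revert Production to the previously promoted version."""
        history = self.production_history.get(name, [])
        if len(history) < 2:
            raise ValueError(f"No previous production version for {name}")
        bad = history.pop()
        self.stages[(name, bad)] = "Archived"
        self.stages[(name, history[-1])] = "Production"
        return history[-1]

registry = SimpleRegistry()
registry.transition_stage("image_classifier", "1.0.0", "Production")
registry.transition_stage("image_classifier", "1.1.0", "Production")
restored = registry.rollback("image_classifier")
print(restored)  # 1.0.0
```

Real registries (e.g. MLflow's stage transitions shown later) provide the equivalent of `transition_stage`; the point is that rollback is just a stage transition driven by recorded promotion history.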

Implementation Guide

Versioning Strategies

Semantic Versioning

from dataclasses import dataclass
from typing import Optional
from datetime import datetime

@dataclass
class ModelVersion:
    """使用语义版本控制的模型版本。"""
    major: int
    minor: int
    patch: int
    pre_release: Optional[str] = None
    build_metadata: Optional[str] = None

    def __str__(self):
        version = f"{self.major}.{self.minor}.{self.patch}"
        if self.pre_release:
            version += f"-{self.pre_release}"
        if self.build_metadata:
            version += f"+{self.build_metadata}"
        return version

    @staticmethod
    def parse(version_string: str) -> 'ModelVersion':
        """解析版本字符串。"""
        # 解析语义版本
        parts = version_string.split('+')
        version = parts[0]
        build = parts[1] if len(parts) > 1 else None

        parts = version.split('-')
        version = parts[0]
        pre = parts[1] if len(parts) > 1 else None

        major, minor, patch = map(int, version.split('.'))

        return ModelVersion(major, minor, patch, pre, build)

    def increment_major(self):
        """Increment the major version (breaking change)."""
        return ModelVersion(self.major + 1, 0, 0)

    def increment_minor(self):
        """Increment the minor version (new feature)."""
        return ModelVersion(self.major, self.minor + 1, 0)

    def increment_patch(self):
        """Increment the patch version (bug fix)."""
        return ModelVersion(self.major, self.minor, self.patch + 1)

# Usage
v1 = ModelVersion(1, 0, 0)
print(v1)  # 1.0.0

v2 = v1.increment_minor()
print(v2)  # 1.1.0

v3 = ModelVersion.parse("2.1.3-beta+build123")
print(v3)  # 2.1.3-beta+build123
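One pitfall numeric versions avoid is lexicographic sorting: as strings, "1.10.0" sorts before "1.2.3", so the "latest" version cannot be found by sorting tags. A minimal sketch of correct numeric ordering (the `SemVer` helper below is illustrative and ignores pre-release/build parts for brevity):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class SemVer:
    major: int
    minor: int
    patch: int

    @staticmethod
    def parse(s: str) -> "SemVer":
        major, minor, patch = map(int, s.split("."))
        return SemVer(major, minor, patch)

    def key(self):
        # Tuple comparison orders component by component, numerically
        return (self.major, self.minor, self.patch)

tags = ["1.2.3", "1.10.0", "0.9.9"]
print(sorted(tags))  # ['0.9.9', '1.10.0', '1.2.3'] -- string sort mis-orders 1.10.0
print(max(tags, key=lambda t: SemVer.parse(t).key()))  # 1.10.0 -- the true latest
```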

Timestamp-Based Versioning

from datetime import datetime
import pytz

class TimestampVersion:
    """基于时间戳的版本控制。"""

    def __init__(self, timezone='UTC'):
        self.timezone = pytz.timezone(timezone)

    def generate(self) -> str:
        """Generate a timestamp-based version."""
        now = datetime.now(self.timezone)
        return now.strftime("%Y%m%d-%H%M%S")

    def generate_with_microseconds(self) -> str:
        """Generate a version that includes microseconds."""
        now = datetime.now(self.timezone)
        return now.strftime("%Y%m%d-%H%M%S-%f")

    def parse(self, version_string: str) -> datetime:
        """Parse a timestamp version back into a datetime."""
        # Handle both formats: with and without microseconds
        if version_string.count('-') == 2:
            return datetime.strptime(version_string, "%Y%m%d-%H%M%S-%f")
        return datetime.strptime(version_string, "%Y%m%d-%H%M%S")

# Usage
versioner = TimestampVersion()
version = versioner.generate()
print(version)  # 20240114-123045

version_micro = versioner.generate_with_microseconds()
print(version_micro)  # 20240114-123045-123456

Git-Based Versioning

import subprocess
from typing import Optional

class GitVersion:
    """基于 Git 的版本控制。"""

    @staticmethod
    def get_commit_hash(short: bool = True) -> Optional[str]:
        """获取当前提交哈希。"""
        try:
            length = 7 if short else None
            result = subprocess.run(
                ['git', 'rev-parse', 'HEAD'],
                capture_output=True,
                text=True,
                check=True
            )
            commit_hash = result.stdout.strip()
            return commit_hash[:length] if short else commit_hash
        except subprocess.CalledProcessError:
            return None

    @staticmethod
    def get_branch() -> Optional[str]:
        """获取当前分支名称。"""
        try:
            result = subprocess.run(
                ['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
                capture_output=True,
                text=True,
                check=True
            )
            return result.stdout.strip()
        except subprocess.CalledProcessError:
            return None

    @staticmethod
    def get_tag() -> Optional[str]:
        """获取当前标签。"""
        try:
            result = subprocess.run(
                ['git', 'describe', '--tags', '--exact-match'],
                capture_output=True,
                text=True,
                check=True
            )
            return result.stdout.strip()
        except subprocess.CalledProcessError:
            return None

    @staticmethod
    def generate_version() -> str:
        """从 Git 信息生成版本。"""
        tag = GitVersion.get_tag()
        if tag:
            return tag

        branch = GitVersion.get_branch()
        commit = GitVersion.get_commit_hash(short=True)

        if branch and commit:
            return f"{branch}-{commit}"
        elif commit:
            return commit
        else:
            return "unknown"

# Usage
version = GitVersion.generate_version()
print(version)  # main-abc1234 or v1.0.0

Model Registry

MLflow Model Registry

import mlflow
import mlflow.pytorch
from mlflow.tracking import MlflowClient
from typing import Dict, Any

class MLflowModelRegistry:
    """MLflow 模型注册表包装器。"""

    def __init__(self, tracking_uri: str = None):
        if tracking_uri:
            mlflow.set_tracking_uri(tracking_uri)

        self.client = MlflowClient()

    def register_model(
        self,
        model,
        model_name: str,
        version: str,
        description: str = None,
        tags: Dict[str, Any] = None,
        metrics: Dict[str, float] = None
    ):
        """Register a model with MLflow."""
        with mlflow.start_run():
            # Log the model
            mlflow.pytorch.log_model(model, "model")

            # Set tags
            if tags:
                mlflow.set_tags(tags)

            # Log metrics
            if metrics:
                mlflow.log_metrics(metrics)

            # Register the model
            model_uri = f"runs:/{mlflow.active_run().info.run_id}/model"
            registered_model = mlflow.register_model(
                model_uri,
                model_name,
                tags={"version": version}
            )

            # Update the description
            if description:
                self.client.update_model_version(
                    name=model_name,
                    version=registered_model.version,
                    description=description
                )

        return registered_model

    def get_model_version(self, model_name: str, version: str = None):
        """Get a model version from the registry."""
        if version:
            return self.client.get_model_version(model_name, version)
        else:
            # Fall back to the latest Production version
            latest = self.client.get_latest_versions(model_name, stages=["Production"])
            return latest[0] if latest else None

    def load_model(self, model_name: str, version: str = None, stage: str = None):
        """从注册表加载模型。"""
        if stage:
            model_uri = f"models:/{model_name}/{stage}"
        elif version:
            model_uri = f"models:/{model_name}/{version}"
        else:
            model_uri = f"models:/{model_name}/Production"

        return mlflow.pytorch.load_model(model_uri)

    def transition_stage(
        self,
        model_name: str,
        version: str,
        stage: str,
        archive_existing_versions: bool = False
    ):
        """将模型转换到新阶段。"""
        self.client.transition_model_version_stage(
            name=model_name,
            version=version,
            stage=stage,
            archive_existing_versions=archive_existing_versions
        )

    def list_models(self, name_filter: str = None):
        """列出所有注册的模型。"""
        models = self.client.search_registered_models(filter_string=name_filter)
        return models

    def get_model_history(self, model_name: str):
        """Get the version history of a model."""
        versions = self.client.search_model_versions(f"name='{model_name}'")
        return versions

# Usage
registry = MLflowModelRegistry(tracking_uri="http://localhost:5000")

# Register a model
registered = registry.register_model(
    model=my_model,
    model_name="image_classifier",
    version="1.0.0",
    description="Initial model release",
    tags={"framework": "pytorch", "task": "classification"},
    metrics={"accuracy": 0.95, "f1": 0.94}
)

# Transition to Production
registry.transition_stage("image_classifier", "1", "Production")

# Load the model from Production
model = registry.load_model("image_classifier", stage="Production")

Custom Model Registry

import json
import shutil
from pathlib import Path
from datetime import datetime
from typing import Dict
import hashlib
import pickle
import torch

class ModelRegistry:
    """A simple file-based model registry."""

    def __init__(self, registry_path: str):
        self.registry_path = Path(registry_path)
        self.registry_path.mkdir(parents=True, exist_ok=True)
        self.index_file = self.registry_path / "index.json"
        self._load_index()

    def _load_index(self):
        """Load the model index."""
        if self.index_file.exists():
            with open(self.index_file, 'r') as f:
                self.index = json.load(f)
        else:
            self.index = {"models": {}}

    def _save_index(self):
        """Save the model index."""
        with open(self.index_file, 'w') as f:
            json.dump(self.index, f, indent=2)

    def _compute_hash(self, file_path: Path) -> str:
        """Compute the MD5 hash of a file."""
        hash_md5 = hashlib.md5()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    def register(
        self,
        model,
        model_name: str,
        version: str,
        metadata: Dict = None,
        metrics: Dict = None,
        artifacts: Dict = None
    ):
        """Register a model."""
        model_path = self.registry_path / model_name / version
        model_path.mkdir(parents=True, exist_ok=True)

        # Save the model weights
        model_file = model_path / "model.pth"
        torch.save(model.state_dict(), model_file)
        model_hash = self._compute_hash(model_file)

        # Save metadata
        metadata = metadata or {}
        metadata.update({
            "version": version,
            "registered_at": datetime.now().isoformat(),
            "model_hash": model_hash,
            "model_path": str(model_file),
            "metrics": metrics or {}
        })

        metadata_file = model_path / "metadata.json"
        with open(metadata_file, 'w') as f:
            json.dump(metadata, f, indent=2)

        # Save artifacts
        if artifacts:
            artifacts_dir = model_path / "artifacts"
            artifacts_dir.mkdir(exist_ok=True)
            for name, artifact in artifacts.items():
                artifact_path = artifacts_dir / name
                if isinstance(artifact, (dict, list)):
                    with open(artifact_path.with_suffix('.json'), 'w') as f:
                        json.dump(artifact, f, indent=2)
                else:
                    with open(artifact_path, 'wb') as f:
                        pickle.dump(artifact, f)

        # Update the index
        if model_name not in self.index["models"]:
            self.index["models"][model_name] = {"versions": []}

        self.index["models"][model_name]["versions"].append({
            "version": version,
            "registered_at": metadata["registered_at"],
            "metrics": metrics or {},
            "stage": "Development"
        })

        self._save_index()

        return model_path

    def load(self, model_name: str, version: str = None, stage: str = None):
        """Load a model."""
        if version:
            model_path = self.registry_path / model_name / version
        elif stage:
            # Find the version in the requested stage
            versions = self.index["models"].get(model_name, {}).get("versions", [])
            version_info = next((v for v in versions if v["stage"] == stage), None)
            if not version_info:
                raise ValueError(f"No version found in stage {stage}")
            model_path = self.registry_path / model_name / version_info["version"]
        else:
            # Fall back to the most recently registered version
            versions = self.index["models"].get(model_name, {}).get("versions", [])
            if not versions:
                raise ValueError(f"No versions found for {model_name}")
            latest = sorted(versions, key=lambda x: x["registered_at"])[-1]
            model_path = self.registry_path / model_name / latest["version"]

        # Load metadata
        metadata_file = model_path / "metadata.json"
        with open(metadata_file) as f:
            metadata = json.load(f)

        # Load the model weights (the model class must be known)
        model_file = model_path / "model.pth"
        model = MyModel()  # Replace with your actual model class
        model.load_state_dict(torch.load(model_file))
        model.eval()

        return model, metadata

    def transition_stage(self, model_name: str, version: str, stage: str):
        """Transition a model to a new stage."""
        versions = self.index["models"].get(model_name, {}).get("versions", [])
        for v in versions:
            if v["version"] == version:
                v["stage"] = stage
                break

        self._save_index()

    def list_versions(self, model_name: str):
        """List all versions of a model."""
        return self.index["models"].get(model_name, {}).get("versions", [])

    def list_models(self):
        """List all registered models."""
        return list(self.index["models"].keys())

# Usage
registry = ModelRegistry("./model_registry")

# Register a model
registry.register(
    model=my_model,
    model_name="image_classifier",
    version="1.0.0",
    metadata={"framework": "pytorch", "task": "classification"},
    metrics={"accuracy": 0.95, "f1": 0.94}
)

# Transition to Production
registry.transition_stage("image_classifier", "1.0.0", "Production")

# Load the model
model, metadata = registry.load("image_classifier", stage="Production")

Metadata Management

Model Metadata Schema

from dataclasses import dataclass, asdict
from typing import Dict, List, Optional, Any
from datetime import datetime
from enum import Enum
import json

class ModelStage(Enum):
    DEVELOPMENT = "Development"
    STAGING = "Staging"
    PRODUCTION = "Production"
    ARCHIVED = "Archived"

@dataclass
class ModelMetadata:
    """Comprehensive model metadata."""
    # Basic information
    model_name: str
    version: str
    framework: str
    task: str

    # Training information
    training_data_version: str
    training_start: str
    training_end: str
    training_duration_seconds: float

    # Architecture
    architecture: str
    parameters: int
    model_size_mb: float

    # Performance
    metrics: Dict[str, float]
    test_set: str

    # Deployment
    stage: ModelStage
    deployed_at: Optional[str] = None
    deployment_environment: Optional[str] = None

    # Extras
    tags: List[str] = None
    description: str = ""
    git_commit: Optional[str] = None
    git_branch: Optional[str] = None

    # Lineage
    parent_model: Optional[str] = None
    parent_version: Optional[str] = None
    data_sources: List[str] = None

    # Compliance
    data_retention_days: int = 365
    pii_present: bool = False
    gdpr_compliant: bool = True

    def to_dict(self) -> Dict:
        """Convert to a dictionary."""
        data = asdict(self)
        data['stage'] = self.stage.value
        return data

    @classmethod
    def from_dict(cls, data: Dict) -> 'ModelMetadata':
        """Create from a dictionary."""
        if 'stage' in data and isinstance(data['stage'], str):
            data['stage'] = ModelStage(data['stage'])
        return cls(**data)

    def to_json(self) -> str:
        """Convert to JSON."""
        return json.dumps(self.to_dict(), indent=2)

    @classmethod
    def from_json(cls, json_str: str) -> 'ModelMetadata':
        """Create from JSON."""
        data = json.loads(json_str)
        return cls.from_dict(data)

# Usage
metadata = ModelMetadata(
    model_name="image_classifier",
    version="1.0.0",
    framework="pytorch",
    task="classification",
    training_data_version="v1.0",
    training_start="2024-01-01T00:00:00",
    training_end="2024-01-02T12:00:00",
    training_duration_seconds=86400.0,
    architecture="resnet50",
    parameters=25600000,
    model_size_mb=98.5,
    metrics={"accuracy": 0.95, "f1": 0.94, "precision": 0.93, "recall": 0.95},
    test_set="test_v1.0",
    stage=ModelStage.PRODUCTION,
    tags=["vision", "classification", "production"],
    description="初始生产模型",
    git_commit="abc1234",
    git_branch="main",
    data_sources=["imagenet", "custom_dataset"],
    data_retention_days=730,
    pii_present=False,
    gdpr_compliant=True
)

# Save the metadata
with open("model_metadata.json", "w") as f:
    f.write(metadata.to_json())
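The to_dict/from_dict pair above matters because Enum members such as ModelStage are not JSON-serializable by default and must be mapped to and from their string values. A stripped-down, runnable sketch of the same round trip (the `Meta`/`Stage` names here are illustrative, trimmed from the full schema so the example is self-contained):

```python
import json
from dataclasses import dataclass, asdict
from enum import Enum

class Stage(Enum):
    PRODUCTION = "Production"

@dataclass
class Meta:
    model_name: str
    version: str
    stage: Stage

    def to_json(self) -> str:
        d = asdict(self)
        d["stage"] = self.stage.value  # Enum members are not JSON-serializable as-is
        return json.dumps(d)

    @classmethod
    def from_json(cls, s: str) -> "Meta":
        d = json.loads(s)
        d["stage"] = Stage(d["stage"])  # map the string back to the Enum member
        return cls(**d)

m = Meta("image_classifier", "1.0.0", Stage.PRODUCTION)
assert Meta.from_json(m.to_json()) == m  # lossless round trip
```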

Metadata Store

import json
import sqlite3
from typing import List, Optional, Dict
from pathlib import Path

class MetadataStore:
    """A SQLite-backed metadata store."""

    def __init__(self, db_path: str):
        self.db_path = Path(db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _init_db(self):
        """Initialize the database schema."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS models (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                model_name TEXT NOT NULL,
                version TEXT NOT NULL,
                framework TEXT NOT NULL,
                task TEXT NOT NULL,
                training_data_version TEXT,
                training_start TEXT,
                training_end TEXT,
                architecture TEXT,
                parameters INTEGER,
                model_size_mb REAL,
                stage TEXT NOT NULL,
                deployed_at TEXT,
                deployment_environment TEXT,
                tags TEXT,
                description TEXT,
                git_commit TEXT,
                git_branch TEXT,
                parent_model TEXT,
                parent_version TEXT,
                data_sources TEXT,
                data_retention_days INTEGER DEFAULT 365,
                pii_present INTEGER DEFAULT 0,
                gdpr_compliant INTEGER DEFAULT 1,
                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                UNIQUE(model_name, version)
            )
        """)

        cursor.execute("""
            CREATE TABLE IF NOT EXISTS metrics (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                model_id INTEGER NOT NULL,
                metric_name TEXT NOT NULL,
                metric_value REAL NOT NULL,
                FOREIGN KEY (model_id) REFERENCES models (id)
            )
        """)

        conn.commit()
        conn.close()

    def save_metadata(self, metadata: ModelMetadata) -> int:
        """Save model metadata."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        # Insert the model row
        cursor.execute("""
            INSERT INTO models (
                model_name, version, framework, task,
                training_data_version, training_start, training_end,
                architecture, parameters, model_size_mb,
                stage, deployed_at, deployment_environment,
                tags, description, git_commit, git_branch,
                parent_model, parent_version, data_sources,
                data_retention_days, pii_present, gdpr_compliant
            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            metadata.model_name, metadata.version, metadata.framework, metadata.task,
            metadata.training_data_version, metadata.training_start, metadata.training_end,
            metadata.architecture, metadata.parameters, metadata.model_size_mb,
            metadata.stage.value, metadata.deployed_at, metadata.deployment_environment,
            json.dumps(metadata.tags or []), metadata.description, metadata.git_commit, metadata.git_branch,
            metadata.parent_model, metadata.parent_version, json.dumps(metadata.data_sources or []),
            metadata.data_retention_days, 1 if metadata.pii_present else 0, 1 if metadata.gdpr_compliant else 0
        ))

        model_id = cursor.lastrowid

        # Insert metrics
        for metric_name, metric_value in metadata.metrics.items():
            cursor.execute("""
                INSERT INTO metrics (model_id, metric_name, metric_value)
                VALUES (?, ?, ?)
            """, (model_id, metric_name, metric_value))

        conn.commit()
        conn.close()

        return model_id

    def get_metadata(self, model_name: str, version: str) -> Optional[ModelMetadata]:
        """Get model metadata."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        cursor.execute("""
            SELECT * FROM models WHERE model_name = ? AND version = ?
        """, (model_name, version))

        row = cursor.fetchone()
        if not row:
            conn.close()
            return None

        # Fetch metrics
        cursor.execute("""
            SELECT metric_name, metric_value FROM metrics WHERE model_id = ?
        """, (row[0],))

        metrics = {metric_name: metric_value for metric_name, metric_value in cursor.fetchall()}
        conn.close()

        # Rebuild the metadata object
        return ModelMetadata(
            model_name=row[1],
            version=row[2],
            framework=row[3],
            task=row[4],
            training_data_version=row[5],
            training_start=row[6],
            training_end=row[7],
            training_duration_seconds=0,  # not stored in this schema
            architecture=row[8],
            parameters=row[9],
            model_size_mb=row[10],
            stage=ModelStage(row[11]),
            deployed_at=row[12],
            deployment_environment=row[13],
            tags=json.loads(row[14]),
            description=row[15],
            git_commit=row[16],
            git_branch=row[17],
            parent_model=row[18],
            parent_version=row[19],
            data_sources=json.loads(row[20]),
            data_retention_days=row[21],
            pii_present=bool(row[22]),
            gdpr_compliant=bool(row[23]),
            metrics=metrics
        )

    def list_models(self, stage: ModelStage = None) -> List[Dict]:
        """List all models."""
        conn = sqlite3.connect(self.db_path)
        cursor = conn.cursor()

        if stage:
            cursor.execute("""
                SELECT model_name, version, stage, deployed_at FROM models WHERE stage = ?
            """, (stage.value,))
        else:
            cursor.execute("""
                SELECT model_name, version, stage, deployed_at FROM models
            """)

        models = [
            {"model_name": row[0], "version": row[1], "stage": row[2], "deployed_at": row[3]}
            for row in cursor.fetchall()
        ]

        conn.close()
        return models

# Usage
store = MetadataStore("./model_metadata.db")
store.save_metadata(metadata)

# Get metadata
metadata = store.get_metadata("image_classifier", "1.0.0")

# List production models
production_models = store.list_models(stage=ModelStage.PRODUCTION)
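With per-version metrics stored, model comparison (one of the key concepts above) reduces to ranking versions by a chosen metric. A minimal sketch — the `compare_versions` helper is illustrative, operating on rows shaped like the version records the stores above keep:

```python
def compare_versions(rows, metric="accuracy"):
    """Rank version records by a metric, best first (illustrative helper)."""
    ranked = sorted(
        rows,
        key=lambda r: r["metrics"].get(metric, float("-inf")),  # missing metric sorts last
        reverse=True,
    )
    return [(r["version"], r["metrics"].get(metric)) for r in ranked]

rows = [
    {"version": "1.0.0", "metrics": {"accuracy": 0.95}},
    {"version": "1.1.0", "metrics": {"accuracy": 0.97}},
    {"version": "1.2.0", "metrics": {"accuracy": 0.96}},
]
print(compare_versions(rows))  # [('1.1.0', 0.97), ('1.2.0', 0.96), ('1.0.0', 0.95)]
```

The same ranking can be pushed into SQL (ORDER BY on the metrics table) when the version count grows large.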

Model Lineage Tracking

Lineage Graph

from typing import List, Dict, Optional
from dataclasses import dataclass
import networkx as nx
import matplotlib.pyplot as plt

@dataclass
class ModelNode:
    """模型谱系图中的节点。"""
    model_name: str
    version: str
    node_type: str  # "model", "data", "experiment"

class ModelLineage:
    """跟踪模型谱系。"""

    def __init__(self):
        self.graph = nx.DiGraph()

    def add_model(
        self,
        model_name: str,
        version: str,
        parent_model: Optional[str] = None,
        parent_version: Optional[str] = None,
        data_sources: Optional[List[str]] = None
    ):
        """向谱系添加模型。"""
        node_id = f"{model_name}:{version}"
        self.graph.add_node(node_id, model_name=model_name, version=version, node_type="model")

        # Add an edge from the parent model
        if parent_model and parent_version:
            parent_id = f"{parent_model}:{parent_version}"
            self.graph.add_edge(parent_id, node_id, relation="derived_from")

        # Add edges from data sources
        if data_sources:
            for data_source in data_sources:
                self.graph.add_node(data_source, node_type="data")
                self.graph.add_edge(data_source, node_id, relation="trained_on")

    def add_experiment(
        self,
        experiment_id: str,
        model_name: str,
        version: str,
        hyperparameters: Dict
    ):
        """向谱系添加实验。"""
        node_id = f"{model_name}:{version}"
        exp_id = f"experiment:{experiment_id}"

        self.graph.add_node(exp_id, node_type="experiment", hyperparameters=hyperparameters)
        self.graph.add_edge(exp_id, node_id, relation="produced")

    def get_ancestors(self, model_name: str, version: str) -> List[Dict]:
        """获取所有祖先模型。"""
        node_id = f"{model_name}:{version}"
        ancestors = nx.ancestors(self.graph, node_id)

        result = []
        for ancestor_id in ancestors:
            node = self.graph.nodes[ancestor_id]
            result.append({
                "id": ancestor_id,
                "type": node.get("node_type"),
                "data": node
            })

        return result

    def get_descendants(self, model_name: str, version: str) -> List[Dict]:
        """获取所有后代模型。"""
        node_id = f"{model_name}:{version}"
        descendants = nx.descendants(self.graph, node_id)

        result = []
        for descendant_id in descendants:
            node = self.graph.nodes[descendant_id]
            result.append({
                "id": descendant_id,
                "type": node.get("node_type"),
                "data": node
            })

        return result

    def visualize(self, output_path: str = None):
        """可视化谱系图。"""
        pos = nx.spring_layout(self.graph)

        # Color nodes by type
        colors = []
        for node in self.graph.nodes():
            node_type = self.graph.nodes[node].get("node_type", "model")
            if node_type == "model":
                colors.append("lightblue")
            elif node_type == "data":
                colors.append("lightgreen")
            else:
                colors.append("lightyellow")

        plt.figure(figsize=(12, 8))
        nx.draw(self.graph, pos, node_color=colors, with_labels=True, node_size=1000, font_size=8)

        # Add a legend
        from matplotlib.patches import Patch
        legend_elements = [
            Patch(facecolor='lightblue', label='Model'),
            Patch(facecolor='lightgreen', label='Data'),
            Patch(facecolor='lightyellow', label='Experiment')
        ]
        plt.legend(handles=legend_elements)

        plt.title("Model Lineage")
        plt.tight_layout()

        if output_path:
            plt.savefig(output_path)
        else:
            plt.show()

# Usage
lineage = ModelLineage()

lineage.add_model("base_model", "1.0.0", data_sources=["dataset_v1"])
lineage.add_model("fine_tuned", "1.0.0", parent_model="base_model", parent_version="1.0.0",
                  data_sources=["dataset_v2"])
lineage.add_model("production", "1.0.0", parent_model="fine_tuned", parent_version="1.0.0")

# Query lineage
ancestors = lineage.get_ancestors("production", "1.0.0")
descendants = lineage.get_descendants("base_model", "1.0.0")

# Visualize
lineage.visualize("lineage.png")

Artifact Storage

Artifact Manager

import shutil
import hashlib
from pathlib import Path
from typing import Dict, List, Optional
import json
from datetime import datetime

class ArtifactManager:
    """Manage model artifacts."""

    def __init__(self, storage_path: str):
        self.storage_path = Path(storage_path)
        self.storage_path.mkdir(parents=True, exist_ok=True)
        self.index_file = self.storage_path / "artifacts.json"
        self._load_index()

    def _load_index(self):
        """Load the artifact index."""
        if self.index_file.exists():
            with open(self.index_file, 'r') as f:
                self.index = json.load(f)
        else:
            self.index = {"artifacts": {}}

    def _save_index(self):
        """Save the artifact index."""
        with open(self.index_file, 'w') as f:
            json.dump(self.index, f, indent=2)

    def _compute_hash(self, file_path: Path) -> str:
        """Compute the SHA256 hash of a file."""
        hash_sha256 = hashlib.sha256()
        with open(file_path, "rb") as f:
            for chunk in iter(lambda: f.read(4096), b""):
                hash_sha256.update(chunk)
        return hash_sha256.hexdigest()

    def store(
        self,
        source_path: str,
        artifact_name: str,
        version: str,
        metadata: Dict = None
    ) -> str:
        """Store an artifact."""
        source_path = Path(source_path)
        artifact_hash = self._compute_hash(source_path)

        # Create the storage directory
        storage_dir = self.storage_path / artifact_name / version
        storage_dir.mkdir(parents=True, exist_ok=True)

        # Copy the artifact into storage
        artifact_path = storage_dir / source_path.name
        shutil.copy2(source_path, artifact_path)

        # Record metadata
        artifact_metadata = {
            "name": artifact_name,
            "version": version,
            "hash": artifact_hash,
            "size_bytes": source_path.stat().st_size,
            "stored_at": datetime.now().isoformat(),
            "path": str(artifact_path),
            "metadata": metadata or {}
        }

        # Update the index
        key = f"{artifact_name}:{version}"
        self.index["artifacts"][key] = artifact_metadata
        self._save_index()

        return str(artifact_path)

    def retrieve(self, artifact_name: str, version: str, destination: str = None) -> Path:
        """Retrieve an artifact."""
        key = f"{artifact_name}:{version}"

        if key not in self.index["artifacts"]:
            raise ValueError(f"Artifact {key} not found")

        artifact_info = self.index["artifacts"][key]
        artifact_path = Path(artifact_info["path"])

        if destination:
            dest_path = Path(destination)
            shutil.copy2(artifact_path, dest_path)
            return dest_path

        return artifact_path

    def list_artifacts(self, artifact_name: str = None) -> List[Dict]:
        """List artifacts."""
        if artifact_name:
            return [
                v for k, v in self.index["artifacts"].items()
                if k.startswith(f"{artifact_name}:")
            ]
        return list(self.index["artifacts"].values())

    def delete(self, artifact_name: str, version: str):
        """Delete an artifact."""
        key = f"{artifact_name}:{version}"

        if key not in self.index["artifacts"]:
            raise ValueError(f"Artifact {key} not found")

        artifact_info = self.index["artifacts"][key]
        artifact_path = Path(artifact_info["path"])

        # Delete the file
        if artifact_path.exists():
            artifact_path.unlink()

        # Remove now-empty directories
        artifact_path.parent.rmdir()
        if artifact_path.parent.parent.name == artifact_name:
            artifact_path.parent.parent.rmdir()

        # Update the index
        del self.index["artifacts"][key]
        self._save_index()

# Usage
manager = ArtifactManager("./artifacts")

# Store an artifact
manager.store(
    source_path="model.pth",
    artifact_name="image_classifier",
    version="1.0.0",
    metadata={"framework": "pytorch", "task": "classification"}
)

# Retrieve an artifact
artifact_path = manager.retrieve("image_classifier", "1.0.0", destination="./downloaded_model.pth")
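Because the index records a SHA256 hash at store time, a retrieved artifact can be checked for corruption or tampering before it is loaded. A sketch — `verify_artifact` is an illustrative helper, not a method of the class above:

```python
import hashlib

def verify_artifact(path: str, expected_sha256: str) -> bool:
    """Recompute a file's SHA256 and compare it to the hash recorded at store time."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            h.update(chunk)
    return h.hexdigest() == expected_sha256

# Usage sketch: refuse to load a model whose bytes changed since registration.
# info = manager.index["artifacts"]["image_classifier:1.0.0"]
# assert verify_artifact(info["path"], info["hash"]), "artifact corrupted or tampered with"
```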

Model Promotion Workflows

Promotion Pipeline

from enum import Enum
from typing import Callable, Optional
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class PromotionStage(Enum):
    DEVELOPMENT = "Development"
    STAGING = "Staging"
    PRODUCTION = "Production"
    ARCHIVED = "Archived"

class ModelPromotionPipeline:
    """管理模型通过阶段推广。"""

    def __init__(self, registry):
        self.registry = registry
        self.pre_promotion_hooks = {}
        self.post_promotion_hooks = {}

    def register_pre_hook(self, stage: PromotionStage, hook: Callable):
        """注册推广前钩子。"""
        if stage not in self.pre_promotion_hooks:
            self.pre_promotion_hooks[stage] = []
        self.pre_promotion_hooks[stage].append(hook)

    def register_post_hook(self, stage: PromotionStage, hook: Callable):
        """Register a hook that runs after promotion to a stage."""
        if stage not in self.post_promotion_hooks:
            self.post_promotion_hooks[stage] = []
        self.post_promotion_hooks[stage].append(hook)

    def promote(
        self,
        model_name: str,
        version: str,
        from_stage: PromotionStage,
        to_stage: PromotionStage,
        force: bool = False
    ):
        """Promote a model from one stage to another."""
        logger.info(f"Promoting {model_name}:{version} from {from_stage} to {to_stage}")

        # Run pre-promotion hooks
        if to_stage in self.pre_promotion_hooks:
            for hook in self.pre_promotion_hooks[to_stage]:
                hook(model_name, version, from_stage, to_stage)

        # Validation checks
        if not force and to_stage == PromotionStage.PRODUCTION:
            self._validate_for_production(model_name, version)

        # Perform the promotion
        self.registry.transition_stage(model_name, version, to_stage.value)

        logger.info(f"Successfully promoted {model_name}:{version} to {to_stage}")

        # Run post-promotion hooks
        if to_stage in self.post_promotion_hooks:
            for hook in self.post_promotion_hooks[to_stage]:
                hook(model_name, version, from_stage, to_stage)

    def _validate_for_production(self, model_name: str, version: str):
        """Validate a model before production deployment."""
        logger.info(f"Validating {model_name}:{version} for production")

        # Check that metadata exists
        metadata = self.registry.get_metadata(model_name, version)
        if not metadata:
            raise ValueError(f"Metadata not found for {model_name}:{version}")

        # Enforce a minimum accuracy threshold
        min_accuracy = 0.90
        accuracy = metadata.metrics.get("accuracy", 0)
        if accuracy < min_accuracy:
            raise ValueError(
                f"Model accuracy {accuracy} below minimum threshold {min_accuracy}"
            )

        # Check required tests
        # Add your own validation logic here

        logger.info(f"Validation passed for {model_name}:{version}")

# Example hooks
def log_promotion(model_name: str, version: str, from_stage: PromotionStage, to_stage: PromotionStage):
    """Log a promotion event."""
    logger.info(f"PROMOTION: {model_name}:{version} {from_stage} -> {to_stage}")

def notify_team(model_name: str, version: str, from_stage: PromotionStage, to_stage: PromotionStage):
    """Notify the team about a production deployment."""
    if to_stage == PromotionStage.PRODUCTION:
        # Send a notification (email, Slack, etc.)
        logger.info(f"NOTIFICATION: {model_name}:{version} deployed to production")

def run_sanity_checks(model_name: str, version: str, from_stage: PromotionStage, to_stage: PromotionStage):
    """Run sanity checks before promotion."""
    if to_stage == PromotionStage.PRODUCTION:
        # Run sanity checks
        logger.info(f"Running sanity checks for {model_name}:{version}")
        # Add your own sanity-check logic here

# Usage
pipeline = ModelPromotionPipeline(registry)

# Register hooks
pipeline.register_pre_hook(PromotionStage.PRODUCTION, run_sanity_checks)
pipeline.register_post_hook(PromotionStage.PRODUCTION, log_promotion)
pipeline.register_post_hook(PromotionStage.PRODUCTION, notify_team)

# Promote a model
pipeline.promote(
    model_name="image_classifier",
    version="1.0.0",
    from_stage=PromotionStage.STAGING,
    to_stage=PromotionStage.PRODUCTION
)

Rollback Strategies

class ModelRollback:
    """Handle model rollbacks."""

    def __init__(self, registry):
        self.registry = registry
        self.rollback_history = {}

    def create_checkpoint(self, model_name: str, version: str):
        """Create a rollback checkpoint before promoting a new version."""
        logger.info(f"Creating checkpoint for {model_name}:{version}")

        # Capture the current production version
        current = self.registry.get_model_version(model_name, stage="Production")

        if current:
            checkpoint = {
                "model_name": model_name,
                "version": current.version,
                "created_at": datetime.now().isoformat(),
                "metadata": self.registry.get_metadata(model_name, current.version)
            }

            key = f"{model_name}:{version}"
            self.rollback_history[key] = checkpoint

            logger.info(f"Checkpoint created: {current.version}")
        else:
            logger.warning(f"No production version found for {model_name}")

    def rollback(self, model_name: str, to_version: str = None):
        """Roll back to a previous version."""
        logger.info(f"Rolling back {model_name} to {to_version or 'previous version'}")

        if to_version:
            # 回滚到特定版本
            self.registry.transition_stage(model_name, to_version, "Production")
        else:
            # Roll back to the last checkpoint
            current_production = self.registry.get_model_version(model_name, stage="Production")
            if current_production:
                key = f"{model_name}:{current_production.version}"
                if key in self.rollback_history:
                    checkpoint = self.rollback_history[key]
                    self.registry.transition_stage(model_name, checkpoint["version"], "Production")
                    logger.info(f"Rolled back to {checkpoint['version']}")
                else:
                    logger.warning(f"No checkpoint found for {key}")

    def get_rollback_history(self, model_name: str) -> List[Dict]:
        """Get the rollback history for a model."""
        return [
            v for k, v in self.rollback_history.items()
            if k.startswith(f"{model_name}:")
        ]

# Usage
rollback = ModelRollback(registry)

# Create a checkpoint before promoting
rollback.create_checkpoint("image_classifier", "1.0.0")

# Promote the new version
pipeline.promote("image_classifier", "1.0.0", PromotionStage.STAGING, PromotionStage.PRODUCTION)

# Roll back if needed
rollback.rollback("image_classifier")
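
Rollbacks are most useful when triggered automatically. A minimal standalone sketch of a decision helper that watches a sliding window of request outcomes and signals when a rollback is advised; the window size and error-rate threshold here are hypothetical, and wiring it to `ModelRollback.rollback` is left to the caller:

```python
from collections import deque

class RollbackTrigger:
    """Signal a rollback when the recent error rate exceeds a threshold."""

    def __init__(self, window: int = 100, max_error_rate: float = 0.05):
        self.results = deque(maxlen=window)   # recent request outcomes
        self.max_error_rate = max_error_rate

    def record(self, success: bool) -> bool:
        """Record one request outcome; return True if rollback is advised."""
        self.results.append(success)
        if len(self.results) < self.results.maxlen:
            return False  # not enough data to judge yet
        error_rate = 1 - sum(self.results) / len(self.results)
        return error_rate > self.max_error_rate

# Simulated traffic: errors start appearing after request 120
trigger = RollbackTrigger(window=50, max_error_rate=0.10)
should_roll = False
for i in range(200):
    ok = i < 120 or i % 2 == 0
    should_roll = trigger.record(ok) or should_roll
print(should_roll)  # True once the window's error rate passes 10%
```

In practice `record` would be called from the serving path, and a True result would invoke the rollback pipeline.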

A/B Testing Setup

A/B Test Manager

from typing import Dict, List
import random
import numpy as np

class ABTestManager:
    """Manage A/B tests between model versions."""

    def __init__(self, registry):
        self.registry = registry
        self.active_tests = {}

    def create_test(
        self,
        test_name: str,
        model_a: str,
        version_a: str,
        model_b: str,
        version_b: str,
        traffic_split: float = 0.5,
        metrics: List[str] = None
    ):
        """Create an A/B test."""
        test_id = f"{test_name}_{datetime.now().strftime('%Y%m%d%H%M%S')}"

        self.active_tests[test_id] = {
            "test_name": test_name,
            "model_a": {"name": model_a, "version": version_a},
            "model_b": {"name": model_b, "version": version_b},
            "traffic_split": traffic_split,
            "metrics": metrics or ["accuracy", "latency"],
            "created_at": datetime.now().isoformat(),
            "results_a": [],
            "results_b": []
        }

        logger.info(f"Created A/B test: {test_id}")
        return test_id

    def route_request(self, test_id: str, request_id: str = None) -> str:
        """Route a request to model A or B."""
        if test_id not in self.active_tests:
            raise ValueError(f"Test {test_id} not found")

        test = self.active_tests[test_id]

        # Use request_id for consistent routing; hashlib is deterministic
        # across processes, unlike the built-in hash(), which is salted per run
        if request_id:
            import hashlib
            digest = hashlib.md5(request_id.encode()).hexdigest()
            rand_val = (int(digest, 16) % 1000) / 1000.0
        else:
            rand_val = random.random()

        if rand_val < test["traffic_split"]:
            return "A"
        else:
            return "B"

    def record_result(self, test_id: str, model: str, result: Dict):
        """Record an inference result for one arm of the test."""
        if test_id not in self.active_tests:
            raise ValueError(f"Test {test_id} not found")

        test = self.active_tests[test_id]

        if model == "A":
            test["results_a"].append(result)
        elif model == "B":
            test["results_b"].append(result)

    def get_results(self, test_id: str) -> Dict:
        """Get aggregated A/B test results."""
        if test_id not in self.active_tests:
            raise ValueError(f"Test {test_id} not found")

        test = self.active_tests[test_id]

        # Compute summary statistics
        results_a = test["results_a"]
        results_b = test["results_b"]

        stats = {
            "test_name": test["test_name"],
            "model_a": test["model_a"],
            "model_b": test["model_b"],
            "traffic_split": test["traffic_split"],
            "samples_a": len(results_a),
            "samples_b": len(results_b)
        }

        # Per-metric means and standard deviations
        for metric in test["metrics"]:
            values_a = [r.get(metric, 0) for r in results_a]
            values_b = [r.get(metric, 0) for r in results_b]

            stats[f"{metric}_a_mean"] = np.mean(values_a) if values_a else 0
            stats[f"{metric}_b_mean"] = np.mean(values_b) if values_b else 0
            stats[f"{metric}_a_std"] = np.std(values_a) if values_a else 0
            stats[f"{metric}_b_std"] = np.std(values_b) if values_b else 0

        return stats

    def conclude_test(self, test_id: str, winner: str = None) -> Dict:
        """Conclude an A/B test and promote the winner."""
        results = self.get_results(test_id)

        if winner:
            if winner == "A":
                winner_model = self.active_tests[test_id]["model_a"]
            else:
                winner_model = self.active_tests[test_id]["model_b"]

            # Promote the winning version
            self.registry.transition_stage(
                winner_model["name"],
                winner_model["version"],
                "Production"
            )

            results["winner"] = winner
            results["winner_model"] = winner_model

        # Remove the test from the active set
        del self.active_tests[test_id]

        return results

# Usage
ab_test = ABTestManager(registry)

# Create an A/B test
test_id = ab_test.create_test(
    test_name="model_comparison",
    model_a="image_classifier",
    version_a="1.0.0",
    model_b="image_classifier",
    version_b="1.1.0",
    traffic_split=0.5,
    metrics=["accuracy", "latency_ms"]
)

# Route requests (request_ids and run_inference are placeholders)
for request_id in request_ids:
    model = ab_test.route_request(test_id, request_id)
    result = run_inference(model, request_id)
    ab_test.record_result(test_id, model, result)

# Get results
results = ab_test.get_results(test_id)

# Conclude the test
ab_test.conclude_test(test_id, winner="B")
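
Per-arm means and standard deviations alone do not say whether the difference between A and B is real or just noise. A minimal significance sketch using a two-sided z-test under a normal approximation (pure numpy plus the standard library; for small samples a proper t-test would be preferable):

```python
import math
import numpy as np

def z_test_two_means(a, b):
    """Two-sided z-test for a difference in means (normal approximation).

    Returns (z_statistic, p_value). Suitable when both samples are
    reasonably large; use a t-test for small samples.
    """
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    se = math.sqrt(a.var(ddof=1) / len(a) + b.var(ddof=1) / len(b))
    z = (a.mean() - b.mean()) / se
    # Phi(x) = 0.5 * (1 + erf(x / sqrt(2))); two-sided p = 2 * (1 - Phi(|z|))
    p = 2 * (1 - 0.5 * (1 + math.erf(abs(z) / math.sqrt(2))))
    return z, p

# Hypothetical accuracy samples from arms A and B
rng = np.random.default_rng(0)
acc_a = rng.normal(0.90, 0.02, 500)
acc_b = rng.normal(0.92, 0.02, 500)

z, p = z_test_two_means(acc_a, acc_b)
print(f"z={z:.2f}, p={p:.4f}")  # a small p suggests a real difference
```

Such a check fits naturally inside `conclude_test` before a winner is promoted.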

Model Comparison

Model Comparison Tool

from typing import Dict, List
import pandas as pd

class ModelComparator:
    """Compare multiple models side by side."""

    def __init__(self, registry):
        self.registry = registry

    def compare_models(
        self,
        models: List[Dict],
        metrics: List[str] = None
    ) -> pd.DataFrame:
        """Build a comparison table for the given models."""
        comparison_data = []

        for model_info in models:
            model_name = model_info["name"]
            version = model_info["version"]

            metadata = self.registry.get_metadata(model_name, version)

            row = {
                "Model": f"{model_name}:{version}",
                "Framework": metadata.framework,
                "Architecture": metadata.architecture,
                "Parameters": metadata.parameters,
                "Size (MB)": metadata.model_size_mb,
                "Stage": metadata.stage.value
            }

            # Add metric columns
            for metric in (metrics or list(metadata.metrics.keys())):
                row[metric] = metadata.metrics.get(metric, 0)

            comparison_data.append(row)

        return pd.DataFrame(comparison_data)

    def compare_versions(
        self,
        model_name: str,
        versions: List[str]
    ) -> pd.DataFrame:
        """Compare different versions of the same model."""
        models = [{"name": model_name, "version": v} for v in versions]
        return self.compare_models(models)

    def find_best_model(
        self,
        model_name: str,
        metric: str,
        maximize: bool = True
    ) -> Dict:
        """Find the best model version by a given metric."""
        versions = self.registry.list_versions(model_name)

        best_version = None
        best_value = float('-inf') if maximize else float('inf')

        for version_info in versions:
            metadata = self.registry.get_metadata(model_name, version_info["version"])
            value = metadata.metrics.get(metric, 0)

            if (maximize and value > best_value) or (not maximize and value < best_value):
                best_value = value
                best_version = version_info["version"]

        return {
            "model_name": model_name,
            "version": best_version,
            "metric": metric,
            "value": best_value
        }

# Usage
comparator = ModelComparator(registry)

# Compare models
comparison = comparator.compare_models([
    {"name": "image_classifier", "version": "1.0.0"},
    {"name": "image_classifier", "version": "1.1.0"},
    {"name": "image_classifier", "version": "2.0.0"}
], metrics=["accuracy", "f1", "latency_ms"])

print(comparison)

# Find the best model
best = comparator.find_best_model("image_classifier", "accuracy", maximize=True)
print(f"Best model: {best}")
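
Once a comparison table exists as a DataFrame, pandas can pick the best version per metric directly, remembering that latency should be minimized while quality metrics are maximized. A standalone sketch with hypothetical numbers (not wired to the registry above):

```python
import pandas as pd

# Hypothetical comparison table for three versions of one model
df = pd.DataFrame({
    "Model": ["image_classifier:1.0.0", "image_classifier:1.1.0", "image_classifier:2.0.0"],
    "accuracy": [0.91, 0.93, 0.92],
    "f1": [0.89, 0.92, 0.93],
    "latency_ms": [12.0, 15.0, 9.0],
}).set_index("Model")

# Best row per metric: maximize quality metrics, minimize latency
best = {
    "accuracy": df["accuracy"].idxmax(),
    "f1": df["f1"].idxmax(),
    "latency_ms": df["latency_ms"].idxmin(),
}
print(best)
```

Note that no single version need win every metric, which is why `find_best_model` takes the metric as a parameter.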

Best Practices

Versioning Guidelines

  1. Use semantic versioning

    • MAJOR: breaking changes
    • MINOR: new features, backward compatible
    • PATCH: bug fixes, backward compatible
  2. Tag releases in Git

    git tag -a v1.0.0 -m "Release version 1.0.0"
    git push origin v1.0.0
    
  3. Record changes in CHANGELOG.md

    # Changelog
    
    ## [1.1.0] - 2024-01-15
    ### Added
    - New feature X
    - New feature Y
    
    ### Changed
    - Improved model performance
    
    ## [1.0.0] - 2024-01-01
    ### Added
    - Initial release
    
  4. Use Git commit hashes for reproducibility

    metadata = {
        "git_commit": GitVersion.get_commit_hash(),
        "git_branch": GitVersion.get_branch()
    }
    
  5. Store model hyperparameters

    hyperparameters = {
        "learning_rate": 0.001,
        "batch_size": 32,
        "epochs": 100,
        "optimizer": "Adam"
    }
    

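A pitfall with semantic versions stored as strings is that lexicographic sort puts "1.10.0" before "1.2.0"; sorting and comparison should work on parsed numeric tuples. A small helper, assuming plain MAJOR.MINOR.PATCH strings without pre-release tags:

```python
from typing import List, Tuple

def version_key(version: str) -> Tuple[int, int, int]:
    """Parse 'MAJOR.MINOR.PATCH' into a sortable numeric tuple."""
    major, minor, patch = (int(part) for part in version.split("."))
    return (major, minor, patch)

def latest_versions(versions: List[str], keep: int = 5) -> List[str]:
    """Return the newest `keep` versions, newest first."""
    return sorted(versions, key=version_key, reverse=True)[:keep]

versions = ["1.2.0", "1.10.0", "2.0.0", "1.2.1", "0.9.9"]
print(latest_versions(versions, keep=3))  # ['2.0.0', '1.10.0', '1.2.1']
```

The same key function also works for picking which versions to archive when keeping the latest N.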
Registry Best Practices

  1. Always validate before production deployment

    def validate_model(model, test_loader, thresholds):
        """Validate that a model meets production thresholds."""
        metrics = evaluate(model, test_loader)
    
        for metric, threshold in thresholds.items():
            if metrics[metric] < threshold:
                raise ValueError(
                    f"Model {metric} ({metrics[metric]}) below threshold ({threshold})"
                )
    
        return metrics
    
  2. Maintain model lineage

    lineage = ModelLineage()
    lineage.add_model(
        model_name="model_v2",
        version="1.0.0",
        parent_model="model_v1",
        parent_version="1.0.0"
    )
    
  3. Use consistent metadata

    metadata = ModelMetadata(
        model_name="my_model",
        version="1.0.0",
        framework="pytorch",
        task="classification",
        # ... all required fields
    )
    
  4. Archive old models

    def archive_old_models(registry, model_name, keep_versions=5):
        """Archive old model versions, keeping the most recent N."""
        versions = registry.list_versions(model_name)
    
        # Sort by registration date
        versions.sort(key=lambda x: x["registered_at"])
    
        # Archive everything except the latest N versions
        for version_info in versions[:-keep_versions]:
            registry.transition_stage(
                model_name,
                version_info["version"],
                "Archived"
            )
    
  5. Monitor production models

    def monitor_production_models(registry, alert_thresholds):
        """Monitor production models for issues."""
        production_models = registry.list_models(stage="Production")
    
        for model_info in production_models:
            model, metadata = registry.load(
                model_info["model_name"],
                model_info["version"]
            )
    
            # Check model health
            health = check_model_health(model)
    
            if not health["healthy"]:
                # Send an alert
                send_alert(
                    model_name=model_info["model_name"],
                    version=model_info["version"],
                    issue=health["issue"]
                )
    

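One concrete health check for monitored production models is input drift detection. A standalone Population Stability Index (PSI) sketch using only numpy; this is a hypothetical helper and not part of the registry API above:

```python
import numpy as np

def population_stability_index(expected, actual, bins: int = 10) -> float:
    """PSI between a reference sample and a live sample.

    Rule of thumb: < 0.1 stable, 0.1-0.25 moderate drift, > 0.25 major drift.
    """
    expected, actual = np.asarray(expected, float), np.asarray(actual, float)
    edges = np.histogram_bin_edges(expected, bins=bins)
    exp_counts, _ = np.histogram(expected, bins=edges)
    act_counts, _ = np.histogram(actual, bins=edges)
    # A small floor avoids division by zero / log(0) for empty bins
    exp_pct = np.clip(exp_counts / len(expected), 1e-6, None)
    act_pct = np.clip(act_counts / len(actual), 1e-6, None)
    return float(np.sum((act_pct - exp_pct) * np.log(act_pct / exp_pct)))

rng = np.random.default_rng(42)
reference = rng.normal(0.0, 1.0, 10_000)   # training-time feature distribution
live_ok = rng.normal(0.0, 1.0, 10_000)     # similar live traffic
live_drift = rng.normal(0.8, 1.0, 10_000)  # shifted live traffic

print(population_stability_index(reference, live_ok))     # near 0
print(population_stability_index(reference, live_drift))  # clearly elevated
```

In a monitoring loop this would run per feature, with a PSI above the chosen threshold feeding the same alerting path as `send_alert` above.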
Related Skills