架构设计Skill architecture-design

这个技能用于在机器学习项目中设计和实施可扩展的架构,通过工厂和注册模式来动态创建和管理数据集、模型、数据增强等可注册组件,支持项目模块化和扩展性。关键词:机器学习架构,工厂模式,注册模式,组件注册,设计模式,ML项目开发

机器学习 0 次安装 0 次浏览 更新于 3/13/2026

name: architecture-design description: | 仅在需要工厂/注册模式时,用于在ML项目中创建新的可注册组件。

✅ 使用时机:

  • 创建新的Dataset类(需要使用@register_dataset)
  • 创建新的Model类(需要使用@register_model)
  • 创建带有__init__.py工厂的新模块目录
  • 从头开始初始化新的ML项目结构
  • 添加新的组件类型(Augmentation、CollateFunction、Metrics)

❌ 不要使用时机:

  • 修改现有函数或方法
  • 修复现有代码中的错误
  • 添加辅助函数或实用工具
  • 不添加新可注册组件的重构
  • 对单个文件的简单代码更改
  • 修改配置文件
  • 阅读或理解现有代码

关键指标:任务是否需要@register_*装饰器或工厂模式?如果不需要,请跳过此技能。 version: 1.2.0

架构设计 - ML项目模板

本技能定义了基于模板结构的机器学习项目的标准代码架构。在修改或扩展代码时,请遵循这些模式以保持一致性。

概述

项目遵循模块化、可扩展的架构,关注点分离清晰。每个模块(数据、模型、训练器、分析)使用工厂和注册模式独立组织,以实现最大灵活性。

核心设计模式

工厂模式

每个模块使用工厂动态创建实例:

# 示例来自 data_module/dataset/__init__.py
DATASET_FACTORY: Dict = {}

def DatasetFactory(data_name: str):
    dataset = DATASET_FACTORY.get(data_name, None)
    if dataset is None:
        print(f"{data_name} 数据集未实现,使用简单数据集")
        dataset = DATASET_FACTORY.get('simple')
    return dataset

详细指南,请参考 references/factory_pattern.md

注册模式

组件通过装饰器自行注册:

# 示例来自 data_module/dataset/simple_dataset.py
@register_dataset("simple")
class SimpleDataset(Dataset):
    def __init__(self, data):
        self.data = data

详细指南,请参考 references/registry_pattern.md

自动导入模式

模块自动发现并导入子模块:

# 示例来自 data_module/dataset/__init__.py
models_dir = os.path.dirname(__file__)
import_modules(models_dir, "src.data_module.dataset")

详细指南,请参考 references/auto_import.md

目录结构

project/
├── run/
│   ├── pipeline/            # 主要工作流脚本
│   │   ├── training/        # 训练管道
│   │   ├── prepare_data/    # 数据准备管道
│   │   └── analysis/        # 分析管道
│   └── conf/                # Hydra配置文件
│       ├── training/        # 训练配置
│       ├── dataset/         # 数据集配置
│       ├── model/           # 模型配置
│       ├── prepare_data/    # 数据准备配置
│       └── analysis/        # 分析配置
│
├── src/
│   ├── data_module/         # 数据处理模块
│   │   ├── dataset/         # 数据集实现
│   │   ├── augmentation/    # 数据增强
│   │   ├── collate_fn/      # 聚合函数
│   │   ├── compute_metrics/ # 指标计算
│   │   ├── prepare_data/    # 数据准备逻辑
│   │   ├── data_func/       # 数据实用函数
│   │   └── utils.py         # 模块特定实用工具
│   │
│   ├── model_module/        # 模型实现
│   │   ├── brain_decoder/   # 脑解码器模型
│   │   └── model/           # 备用模型位置
│   │
│   ├── trainer_module/      # 训练逻辑
│   ├── analysis_module/     # 分析和评估
│   ├── llm/                 # LLM相关代码
│   └── utils/               # 共享实用工具
│
├── data/
│   ├── raw/                 # 原始、不可变数据
│   ├── processed/           # 清理、转换数据
│   └── external/            # 第三方数据
│
├── outputs/
│   ├── logs/                # 训练和评估日志
│   ├── checkpoints/         # 模型检查点
│   ├── tables/              # 结果表格
│   └── figures/             # 图表和可视化
│
├── pyproject.toml           # 项目配置
├── uv.lock                  # 依赖锁定文件
├── TODO.md                  # 任务跟踪
├── README.md                # 项目文档
└── .gitignore               # Git忽略规则

详细目录结构和文件描述,请参考 references/structure.md

模块组织

创建新数据集

添加新数据集时:

  1. src/data_module/dataset/ 中创建文件
  2. 使用 @register_dataset("name") 装饰器
  3. 继承自 torch.utils.data.Dataset
  4. 实现 __init____len____getitem__
from torch.utils.data import Dataset
from typing import Dict
import torch
from src.data_module.dataset import register_dataset

@register_dataset("custom")
class CustomDataset(Dataset):
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
        return self.data[i]

创建新模型

关键:模型使用配置驱动模式

添加新模型时:

  1. src/model_module/model/ 或适当的模块子目录中创建文件
  2. 使用 @register_model('ModelName') 装饰器
  3. __init__ 仅接受 cfg 参数——所有超参数来自配置
  4. forward() 返回字典:{"loss": loss, "labels": labels, "logits": logits}
  5. 使用 self.training 处理训练与推理模式
from src.model_module.brain_decoder import register_model

@register_model('MyModel')
class MyModel(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.task = cfg.dataset.task

        # 所有参数来自 cfg
        self.hidden_dim = cfg.model.hidden_dim
        self.output_dim = cfg.dataset.target_size[cfg.dataset.task]

    def forward(self, x, labels=None, **kwargs):
        if self.training:
            # 训练逻辑
            pass
        else:
            # 推理逻辑
            pass

        return {"loss": loss, "labels": labels, "logits": logits}

添加数据增强

添加增强时:

  1. src/data_module/augmentation/ 中创建文件
  2. 实现转换函数
  3. 如果需要,用工厂注册

代码风格指南

全面风格指南,请参考 references/code_style.md

关键原则:

  • 始终对函数签名使用类型提示
  • 遵循导入顺序:标准库 → 第三方 → 本地
  • 模块 __init__.py 文件包含工厂/注册逻辑
  • 模型类必须是配置驱动

配置管理

项目使用Hydra进行配置管理:

  • 配置文件在 run/conf/ 中按模块组织
  • 每个阶段(训练、分析)有自己的配置结构
  • 对所有配置使用YAML文件

项目工作流程

修改代码前

  1. 阅读相关模块的工厂/注册模式
  2. 检查现有实现的一致性
  3. 遵循已建立的目录结构
  4. 对新组件使用注册装饰器

添加新功能

  1. 确定功能所属模块
  2. 检查是否存在类似功能
  3. 如果创建新组件类型,遵循工厂/注册模式
  4. 如果需要,添加配置文件
  5. 更新文档

代码审查清单

  • [ ] 适当使用工厂/注册模式
  • [ ] 遵循模块目录结构
  • [ ] 有正确的类型注释
  • [ ] 导入顺序正确
  • [ ] 使用注册装饰器
  • [ ] 如果需要,添加配置文件

额外资源

参考文件

详细信息,请查阅:

  • references/structure.md - 详细目录结构和文件描述
  • references/factory_pattern.md - 工厂模式深入解释
  • references/registry_pattern.md - 注册模式深入解释
  • references/auto_import.md - 自动导入模式深入解释
  • references/code_style.md - 全面代码风格指南

示例文件

工作示例在 examples/ 中:

  • examples/custom_dataset.py - 自定义数据集实现
  • examples/custom_model.py - 自定义模型实现
  • examples/augmentation_example.py - 数据增强示例
  • examples/config_example.yaml - 配置文件示例
  • examples/pipeline_example.sh - 管道脚本示例