name: architecture-design description: | 仅在需要工厂/注册模式时,用于在ML项目中创建新的可注册组件。
✅ 使用时机:
- 创建新的Dataset类(需要使用@register_dataset)
- 创建新的Model类(需要使用@register_model)
- 创建带有__init__.py工厂的新模块目录
- 从头开始初始化新的ML项目结构
- 添加新的组件类型(Augmentation、CollateFunction、Metrics)
❌ 不要使用时机:
- 修改现有函数或方法
- 修复现有代码中的错误
- 添加辅助函数或实用工具
- 不添加新可注册组件的重构
- 对单个文件的简单代码更改
- 修改配置文件
- 阅读或理解现有代码
关键指标:任务是否需要@register_*装饰器或工厂模式?如果不需要,请跳过此技能。 version: 1.2.0
架构设计 - ML项目模板
本技能定义了基于模板结构的机器学习项目的标准代码架构。在修改或扩展代码时,请遵循这些模式以保持一致性。
概述
项目遵循模块化、可扩展的架构,关注点分离清晰。每个模块(数据、模型、训练器、分析)使用工厂和注册模式独立组织,以实现最大灵活性。
核心设计模式
工厂模式
每个模块使用工厂动态创建实例:
# 示例来自 data_module/dataset/__init__.py
DATASET_FACTORY: Dict = {}
def DatasetFactory(data_name: str):
dataset = DATASET_FACTORY.get(data_name, None)
if dataset is None:
print(f"{data_name} 数据集未实现,使用简单数据集")
dataset = DATASET_FACTORY.get('simple')
return dataset
详细指南,请参考 references/factory_pattern.md。
注册模式
组件通过装饰器自行注册:
# 示例来自 data_module/dataset/simple_dataset.py
@register_dataset("simple")
class SimpleDataset(Dataset):
def __init__(self, data):
self.data = data
详细指南,请参考 references/registry_pattern.md。
自动导入模式
模块自动发现并导入子模块:
# 示例来自 data_module/dataset/__init__.py
models_dir = os.path.dirname(__file__)
import_modules(models_dir, "src.data_module.dataset")
详细指南,请参考 references/auto_import.md。
目录结构
project/
├── run/
│ ├── pipeline/ # 主要工作流脚本
│ │ ├── training/ # 训练管道
│ │ ├── prepare_data/ # 数据准备管道
│ │ └── analysis/ # 分析管道
│ └── conf/ # Hydra配置文件
│ ├── training/ # 训练配置
│ ├── dataset/ # 数据集配置
│ ├── model/ # 模型配置
│ ├── prepare_data/ # 数据准备配置
│ └── analysis/ # 分析配置
│
├── src/
│ ├── data_module/ # 数据处理模块
│ │ ├── dataset/ # 数据集实现
│ │ ├── augmentation/ # 数据增强
│ │ ├── collate_fn/ # 聚合函数
│ │ ├── compute_metrics/ # 指标计算
│ │ ├── prepare_data/ # 数据准备逻辑
│ │ ├── data_func/ # 数据实用函数
│ │ └── utils.py # 模块特定实用工具
│ │
│ ├── model_module/ # 模型实现
│ │ ├── brain_decoder/ # 脑解码器模型
│ │ └── model/ # 备用模型位置
│ │
│ ├── trainer_module/ # 训练逻辑
│ ├── analysis_module/ # 分析和评估
│ ├── llm/ # LLM相关代码
│ └── utils/ # 共享实用工具
│
├── data/
│ ├── raw/ # 原始、不可变数据
│ ├── processed/ # 清理、转换数据
│ └── external/ # 第三方数据
│
├── outputs/
│ ├── logs/ # 训练和评估日志
│ ├── checkpoints/ # 模型检查点
│ ├── tables/ # 结果表格
│ └── figures/ # 图表和可视化
│
├── pyproject.toml # 项目配置
├── uv.lock # 依赖锁定文件
├── TODO.md # 任务跟踪
├── README.md # 项目文档
└── .gitignore # Git忽略规则
详细目录结构和文件描述,请参考 references/structure.md。
模块组织
创建新数据集
添加新数据集时:
- 在
src/data_module/dataset/中创建文件 - 使用
@register_dataset("name")装饰器 - 继承自
torch.utils.data.Dataset - 实现
__init__、__len__、__getitem__
from torch.utils.data import Dataset
from typing import Dict
import torch
from src.data_module.dataset import register_dataset
@register_dataset("custom")
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, i: int) -> Dict[str, torch.Tensor]:
return self.data[i]
创建新模型
关键:模型使用配置驱动模式
添加新模型时:
- 在
src/model_module/model/或适当的模块子目录中创建文件 - 使用
@register_model('ModelName')装饰器 __init__仅接受cfg参数——所有超参数来自配置forward()返回字典:{"loss": loss, "labels": labels, "logits": logits}- 使用
self.training处理训练与推理模式
from src.model_module.brain_decoder import register_model
@register_model('MyModel')
class MyModel(nn.Module):
def __init__(self, cfg):
super().__init__()
self.cfg = cfg
self.task = cfg.dataset.task
# 所有参数来自 cfg
self.hidden_dim = cfg.model.hidden_dim
self.output_dim = cfg.dataset.target_size[cfg.dataset.task]
def forward(self, x, labels=None, **kwargs):
if self.training:
# 训练逻辑
pass
else:
# 推理逻辑
pass
return {"loss": loss, "labels": labels, "logits": logits}
添加数据增强
添加增强时:
- 在
src/data_module/augmentation/中创建文件 - 实现转换函数
- 如果需要,用工厂注册
代码风格指南
全面风格指南,请参考 references/code_style.md。
关键原则:
- 始终对函数签名使用类型提示
- 遵循导入顺序:标准库 → 第三方 → 本地
- 模块
__init__.py文件包含工厂/注册逻辑 - 模型类必须是配置驱动
配置管理
项目使用Hydra进行配置管理:
- 配置文件在
run/conf/中按模块组织 - 每个阶段(训练、分析)有自己的配置结构
- 对所有配置使用YAML文件
项目工作流程
修改代码前
- 阅读相关模块的工厂/注册模式
- 检查现有实现的一致性
- 遵循已建立的目录结构
- 对新组件使用注册装饰器
添加新功能
- 确定功能所属模块
- 检查是否存在类似功能
- 如果创建新组件类型,遵循工厂/注册模式
- 如果需要,添加配置文件
- 更新文档
代码审查清单
- [ ] 适当使用工厂/注册模式
- [ ] 遵循模块目录结构
- [ ] 有正确的类型注释
- [ ] 导入顺序正确
- [ ] 使用注册装饰器
- [ ] 如果需要,添加配置文件
额外资源
参考文件
详细信息,请查阅:
references/structure.md- 详细目录结构和文件描述references/factory_pattern.md- 工厂模式深入解释references/registry_pattern.md- 注册模式深入解释references/auto_import.md- 自动导入模式深入解释references/code_style.md- 全面代码风格指南
示例文件
工作示例在 examples/ 中:
examples/custom_dataset.py- 自定义数据集实现examples/custom_model.py- 自定义模型实现examples/augmentation_example.py- 数据增强示例examples/config_example.yaml- 配置文件示例examples/pipeline_example.sh- 管道脚本示例