PyTorchLightningSkill pytorch-lightning

PyTorch Lightning 是一个用于组织和自动化 PyTorch 深度学习项目的框架,支持多GPU/TPU训练、分布式策略、数据管道管理、实验跟踪和模块化代码结构,提升神经网络训练的效率和可扩展性。关键词:深度学习框架, PyTorch, 训练自动化, 分布式训练, 多GPU, 神经网络训练, 实验日志。

深度学习 0 次安装 0 次浏览 更新于 3/16/2026

name: pytorch-lightning description: “深度学习框架(PyTorch Lightning)。将PyTorch代码组织成LightningModules,配置Trainers用于多GPU/TPU训练,实现数据管道、回调、日志记录(W&B, TensorBoard),分布式训练(DDP, FSDP, DeepSpeed),用于可扩展的神经网络训练。”

PyTorch Lightning

概述

PyTorch Lightning 是一个深度学习框架,它组织 PyTorch 代码以消除样板代码,同时保持完全灵活性。自动化训练工作流程、多设备编排,并实现神经网络训练和跨多个 GPU/TPU 扩展的最佳实践。

何时使用此技能

此技能应在以下情况使用:

  • 使用 PyTorch Lightning 构建、训练或部署神经网络
  • 将 PyTorch 代码组织成 LightningModules
  • 配置 Trainers 用于多 GPU/TPU 训练
  • 使用 LightningDataModules 实现数据管道
  • 处理回调、日志记录和分布式训练策略(DDP, FSDP, DeepSpeed)
  • 专业地结构化深度学习项目

核心能力

1. LightningModule - 模型定义

将 PyTorch 模型组织成六个逻辑部分:

  1. 初始化 - __init__()setup()
  2. 训练循环 - training_step(batch, batch_idx)
  3. 验证循环 - validation_step(batch, batch_idx)
  4. 测试循环 - test_step(batch, batch_idx)
  5. 预测 - predict_step(batch, batch_idx)
  6. 优化器配置 - configure_optimizers()

快速模板参考: 参见 scripts/template_lightning_module.py 获取完整的样板代码。

详细文档: 阅读 references/lightning_module.md 获取全面的方法文档、钩子、属性和最佳实践。

2. Trainer - 训练自动化

Trainer 自动化训练循环、设备管理、梯度操作和回调。关键特性:

  • 多 GPU/TPU 支持,可选择策略(DDP, FSDP, DeepSpeed)
  • 自动混合精度训练
  • 梯度累积和裁剪
  • 检查点保存和早停
  • 进度条和日志记录

快速设置参考: 参见 scripts/quick_trainer_setup.py 获取常见的 Trainer 配置示例。

详细文档: 阅读 references/trainer.md 获取所有参数、方法和配置选项。

3. LightningDataModule - 数据管道组织

将所有数据处理步骤封装在可重用的类中:

  1. prepare_data() - 下载和处理数据(单进程)
  2. setup() - 创建数据集并应用变换(每个 GPU)
  3. train_dataloader() - 返回训练 DataLoader
  4. val_dataloader() - 返回验证 DataLoader
  5. test_dataloader() - 返回测试 DataLoader

快速模板参考: 参见 scripts/template_datamodule.py 获取完整的样板代码。

详细文档: 阅读 references/data_module.md 获取方法细节和使用模式。

4. Callbacks - 可扩展的训练逻辑

在特定训练钩子处添加自定义功能,而无需修改 LightningModule。内置回调包括:

  • ModelCheckpoint - 保存最佳/最新模型
  • EarlyStopping - 当指标平台时停止
  • LearningRateMonitor - 跟踪学习率调度器变化
  • BatchSizeFinder - 自动确定最佳批大小

详细文档: 阅读 references/callbacks.md 获取内置回调和自定义回调创建。

5. Logging - 实验跟踪

与多个日志平台集成:

  • TensorBoard(默认)
  • Weights & Biases (WandbLogger)
  • MLflow (MLFlowLogger)
  • Neptune (NeptuneLogger)
  • Comet (CometLogger)
  • CSV (CSVLogger)

在任何 LightningModule 方法中使用 self.log("metric_name", value) 记录指标。

详细文档: 阅读 references/logging.md 获取日志器设置和配置。

6. Distributed Training - 扩展到多个设备

基于模型大小选择正确的策略:

  • DDP - 用于参数 <500M 的模型(ResNet, 较小的变换器)
  • FSDP - 用于参数 500M+ 的模型(大型变换器,推荐给 Lightning 用户)
  • DeepSpeed - 用于前沿特性和细粒度控制

配置方式:Trainer(strategy="ddp", accelerator="gpu", devices=4)

详细文档: 阅读 references/distributed_training.md 获取策略比较和配置。

7. 最佳实践

  • 设备无关代码 - 使用 self.device 而不是 .cuda()
  • 超参数保存 - 在 __init__() 中使用 self.save_hyperparameters()
  • 指标记录 - 使用 self.log() 自动跨设备聚合
  • 可重复性 - 使用 seed_everything()Trainer(deterministic=True)
  • 调试 - 使用 Trainer(fast_dev_run=True) 测试一个批次

详细文档: 阅读 references/best_practices.md 获取常见模式和陷阱。

快速工作流程

  1. 定义模型:

    class MyModel(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.save_hyperparameters()
            self.model = YourNetwork()
    
        def training_step(self, batch, batch_idx):
            x, y = batch
            loss = F.cross_entropy(self.model(x), y)
            self.log("train_loss", loss)
            return loss
    
        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters())
    
  2. 准备数据:

    # 选项 1: 直接 DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=32)
    
    # 选项 2: LightningDataModule(推荐用于可重用性)
    dm = MyDataModule(batch_size=32)
    
  3. 训练:

    trainer = L.Trainer(max_epochs=10, accelerator="gpu", devices=2)
    trainer.fit(model, train_loader)  # 或 trainer.fit(model, datamodule=dm)
    

资源

scripts/

可执行的 Python 模板,用于常见的 PyTorch Lightning 模式:

  • template_lightning_module.py - 完整的 LightningModule 样板代码
  • template_datamodule.py - 完整的 LightningDataModule 样板代码
  • quick_trainer_setup.py - 常见的 Trainer 配置示例

references/

每个 PyTorch Lightning 组件的详细文档:

  • lightning_module.md - 全面的 LightningModule 指南(方法、钩子、属性)
  • trainer.md - Trainer 配置和参数
  • data_module.md - LightningDataModule 模式和方法
  • callbacks.md - 内置和自定义回调
  • logging.md - 日志器集成和使用
  • distributed_training.md - DDP, FSDP, DeepSpeed 比较和设置
  • best_practices.md - 常见模式、技巧和陷阱