名称:ml-pipeline-automation 描述:使用Airflow、Kubeflow、MLflow自动化机器学习工作流。适用于可重现管道、重训练计划、MLOps或遇到任务失败、依赖错误、实验跟踪问题。 关键词:ML管道、Airflow、Kubeflow、MLflow、MLOps、工作流编排、数据管道、模型训练自动化、实验跟踪、模型注册表、Airflow DAG、任务依赖、管道监控、数据质量、漂移检测、超参数调优、模型版本管理、工件管理、Kubeflow Pipelines、管道自动化、重试、传感器 许可证:MIT
ML管道自动化
使用经过生产测试的Airflow、Kubeflow和MLflow模式,编排从数据摄入到生产部署的端到端机器学习工作流。
何时使用此技能
在以下情况加载此技能:
- 构建ML管道:编排数据 → 训练 → 部署工作流
- 安排重训练:设置自动化模型重训练计划
- 实验跟踪:跟踪运行中的实验、参数、指标
- MLOps实现:构建可重现、监控的ML基础设施
- 工作流编排:管理复杂的多步骤ML工作流
- 模型注册表:管理模型版本和部署生命周期
快速开始:5步构建ML管道
# 1. 安装Airflow和MLflow(使用时检查最新版本)
pip install apache-airflow==3.1.5 mlflow==3.7.0
# 注意:这些版本截至2025年12月是最新的
# 检查PyPI获取最新稳定版本:https://pypi.org/project/apache-airflow/
# 2. 初始化Airflow数据库
airflow db init
# 3. 创建DAG文件:dags/ml_training_pipeline.py
cat > dags/ml_training_pipeline.py << 'EOF'
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta
default_args = {
'owner': 'ml-team',
'retries': 2,
'retry_delay': timedelta(minutes=5)
}
dag = DAG(
'ml_training_pipeline',
default_args=default_args,
schedule_interval='@daily',
start_date=datetime(2025, 1, 1)
)
def train_model(**context):
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)
mlflow.set_tracking_uri('http://localhost:5000')
mlflow.set_experiment('iris-training')
with mlflow.start_run():
model = RandomForestClassifier(n_estimators=100)
model.fit(X_train, y_train)
accuracy = model.score(X_test, y_test)
mlflow.log_metric('accuracy', accuracy)
mlflow.sklearn.log_model(model, 'model')
train = PythonOperator(
task_id='train_model',
python_callable=train_model,
dag=dag
)
EOF
# 4. 启动Airflow调度器和Web服务器
airflow scheduler &
airflow webserver --port 8080 &
# 5. 触发管道
airflow dags trigger ml_training_pipeline
# 访问UI:http://localhost:8080
结果:5分钟内获得具有实验跟踪功能的ML管道。
核心概念
管道阶段
- 数据收集 → 从源获取原始数据
- 数据验证 → 检查模式、质量、分布
- 特征工程 → 将原始数据转换为特征
- 模型训练 → 使用超参数调优训练
- 模型评估 → 在测试集上验证性能
- 模型部署 → 如果指标通过则推送到生产
- 监控 → 跟踪生产中的漂移、性能
编排工具比较
| 工具 | 最佳适用 | 优势 |
|---|---|---|
| Airflow | 通用ML工作流 | 成熟、灵活、Python原生 |
| Kubeflow | Kubernetes原生ML | 基于容器、可扩展 |
| MLflow | 实验跟踪 | 模型注册表、版本管理 |
| Prefect | 现代Python工作流 | 动态DAG、原生缓存 |
| Dagster | 资产导向管道 | 数据感知、可测试 |
基本Airflow DAG
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta
import logging
logger = logging.getLogger(__name__)
default_args = {
'owner': 'ml-team',
'depends_on_past': False,
'email': ['alerts@example.com'],
'email_on_failure': True,
'retries': 2,
'retry_delay': timedelta(minutes=5)
}
dag = DAG(
'ml_training_pipeline',
default_args=default_args,
description='端到端ML训练管道',
schedule_interval='@daily',
start_date=datetime(2025, 1, 1),
catchup=False
)
def validate_data(**context):
"""验证输入数据质量。"""
import pandas as pd
data_path = "/data/raw/latest.csv"
df = pd.read_csv(data_path)
# 验证检查
assert len(df) > 1000, f"数据不足:{len(df)}行"
assert df.isnull().sum().sum() < len(df) * 0.1, "空值过多"
context['ti'].xcom_push(key='data_path', value=data_path)
logger.info(f"数据验证通过:{len(df)}行")
def train_model(**context):
"""使用MLflow跟踪训练ML模型。"""
import mlflow
import mlflow.sklearn
from sklearn.ensemble import RandomForestClassifier
data_path = context['ti'].xcom_pull(key='data_path', task_ids='validate_data')
mlflow.set_tracking_uri('http://mlflow:5000')
mlflow.set_experiment('production-training')
with mlflow.start_run():
# 训练逻辑
model = RandomForestClassifier(n_estimators=100)
# model.fit(X, y) ...
mlflow.log_param('n_estimators', 100)
mlflow.sklearn.log_model(model, 'model')
validate = PythonOperator(
task_id='validate_data',
python_callable=validate_data,
dag=dag
)
train = PythonOperator(
task_id='train_model',
python_callable=train_model,
dag=dag
)
validate >> train
已知问题预防
1. 任务失败无警报
问题:管道静默失败,直到用户抱怨才被发现。
解决方案:配置失败时邮件/Slack警报:
default_args = {
'email': ['ml-team@example.com'],
'email_on_failure': True,
'email_on_retry': False
}
def on_failure_callback(context):
"""失败时发送Slack警报。"""
from airflow.providers.slack.operators.slack_webhook import SlackWebhookOperator
slack_msg = f"""
:red_circle: 任务失败:{context['task_instance'].task_id}
DAG:{context['task_instance'].dag_id}
执行日期:{context['ds']}
错误:{context.get('exception')}
"""
SlackWebhookOperator(
task_id='slack_alert',
slack_webhook_conn_id='slack_webhook',
message=slack_msg
).execute(context)
task = PythonOperator(
task_id='critical_task',
python_callable=my_function,
on_failure_callback=on_failure_callback,
dag=dag
)
2. 任务间缺少XCom数据
问题:任务期望前一个任务的XCom值,得到None,崩溃。
解决方案:始终验证XCom拉取:
def process_data(**context):
data_path = context['ti'].xcom_pull(
key='data_path',
task_ids='upstream_task'
)
if data_path is None:
raise ValueError("无来自upstream_task的data_path - 检查XCom推送")
# 处理数据...
3. DAG未在UI中显示
问题:DAG文件存在于dags/但在Airflow UI中不显示。
解决方案:检查DAG解析错误:
# 检查语法错误
python dags/my_dag.py
# 在UI中查看DAG导入错误
# 导航到:浏览 → DAG导入错误
# 常见修复:
# 1. 确保DAG对象在文件中定义
# 2. 检查循环导入
# 3. 验证所有依赖已安装
# 4. 修复语法错误
4. 硬编码路径在生产中失效
问题:路径如/Users/myname/data/在本地工作,在生产中失败。
解决方案:使用Airflow变量或环境变量:
from airflow.models import Variable
def load_data(**context):
# ❌ 错误:硬编码路径
# data_path = "/Users/myname/data/train.csv"
# ✅ 好:使用Airflow变量
data_dir = Variable.get("data_directory", "/data")
data_path = f"{data_dir}/train.csv"
# 或使用环境变量
import os
data_path = os.getenv("DATA_PATH", "/data/train.csv")
5. 卡住任务消耗资源
问题:任务无限期挂起,阻塞工作槽,浪费资源。
解决方案:设置任务执行超时:
from datetime import timedelta
task = PythonOperator(
task_id='long_running_task',
python_callable=my_function,
execution_timeout=timedelta(hours=2), # 2小时后终止
dag=dag
)
6. 无数据验证导致模型训练差
问题:在损坏/不完整数据上训练,模型在生产中表现差。
解决方案:添加数据质量验证任务:
def validate_data_quality(**context):
"""综合数据验证。"""
import pandas as pd
df = pd.read_csv(data_path)
# 模式验证
required_cols = ['user_id', 'timestamp', 'feature_a', 'target']
missing_cols = set(required_cols) - set(df.columns)
if missing_cols:
raise ValueError(f"缺失列:{missing_cols}")
# 统计验证
if df['target'].isnull().sum() > 0:
raise ValueError("目标列包含空值")
if len(df) < 1000:
raise ValueError(f"数据不足:{len(df)}行")
logger.info("✅ 数据质量验证通过")
7. 未跟踪实验导致知识丢失
问题:无法重现结果,不知哪些超参数有效。
解决方案:使用MLflow进行所有实验:
import mlflow
mlflow.set_tracking_uri('http://mlflow:5000')
mlflow.set_experiment('model-experiments')
with mlflow.start_run(run_name='rf_v1'):
# 记录所有超参数
mlflow.log_params({
'model_type': 'random_forest',
'n_estimators': 100,
'max_depth': 10,
'random_state': 42
})
# 记录所有指标
mlflow.log_metrics({
'train_accuracy': 0.95,
'test_accuracy': 0.87,
'f1_score': 0.89
})
# 记录模型
mlflow.sklearn.log_model(model, 'model')
何时加载参考
加载参考文件以获取详细生产实现:
-
Airflow DAG模式:在构建具有错误处理、动态生成、传感器、任务组或重试逻辑的复杂DAG时,加载
references/airflow-patterns.md。包含完整生产DAG示例。 -
Kubeflow和MLflow集成:在使用Kubeflow Pipelines进行容器原生编排、集成MLflow跟踪、构建KFP组件或管理模型注册表时,加载
references/kubeflow-mlflow.md。 -
管道监控:在实施数据质量检查、漂移检测、警报配置或使用Prometheus进行管道健康监控时,加载
references/pipeline-monitoring.md。
最佳实践
- 幂等任务:任务在重新运行时应产生相同结果
- 原子操作:每个任务做好一件事
- 版本化管理一切:数据、代码、模型、依赖
- 全面日志记录:记录所有重要事件及其上下文
- 错误处理:快速失败并给出清晰错误消息
- 监控:跟踪管道健康、数据质量、模型漂移
- 测试:在集成前独立测试任务
- 文档化:记录DAG目的、任务依赖关系
常见模式
条件执行
from airflow.operators.python import BranchPythonOperator
def choose_branch(**context):
accuracy = context['ti'].xcom_pull(key='accuracy', task_ids='evaluate')
if accuracy > 0.9:
return 'deploy_to_production'
else:
return 'retrain_with_more_data'
branch = BranchPythonOperator(
task_id='check_accuracy',
python_callable=choose_branch,
dag=dag
)
train >> evaluate >> branch >> [deploy, retrain]
并行训练
from airflow.utils.task_group import TaskGroup
with TaskGroup('train_models', dag=dag) as train_group:
train_rf = PythonOperator(task_id='train_rf', ...)
train_lr = PythonOperator(task_id='train_lr', ...)
train_xgb = PythonOperator(task_id='train_xgb', ...)
# 所有模型并行训练
preprocess >> train_group >> select_best
等待数据
from airflow.sensors.filesystem import FileSensor
wait_for_data = FileSensor(
task_id='wait_for_data',
filepath='/data/input/{{ ds }}.csv',
poke_interval=60, # 每60秒检查一次
timeout=3600, # 1小时后超时
mode='reschedule', # 不阻塞工作器
dag=dag
)
wait_for_data >> process_data