name: domain-ml
description: “在Rust中构建ML/AI应用时使用。关键词:机器学习,ML,AI,张量,模型,推理,神经网络,深度学习,训练,预测,ndarray,tch-rs,burn,candle,机器学习,人工智能,模型推理”
user-invocable: false
机器学习领域
第3层:领域约束
领域约束 → 设计影响
| 领域规则 |
设计约束 |
Rust实现 |
| 大数据 |
高效内存 |
零拷贝,流式处理 |
| GPU加速 |
CUDA/Metal支持 |
candle, tch-rs |
| 模型可移植性 |
标准格式 |
ONNX |
| 批处理 |
吞吐量优先于延迟 |
批量推理 |
| 数值精度 |
浮点数处理 |
ndarray,谨慎使用f32/f64 |
| 可复现性 |
确定性 |
种子随机数,版本控制 |
关键约束
内存效率
规则:避免复制大型张量
原因:内存带宽是瓶颈
RUST:引用,视图,原地操作
GPU利用率
规则:批量操作以提高GPU效率
原因:GPU每次内核启动的开销
RUST:批处理大小,异步数据加载
模型可移植性
规则:使用标准模型格式
原因:在Python中训练,在Rust中部署
RUST:通过tract或candle使用ONNX
向下追溯 ↓
从约束到设计(第2层):
"需要高效的数据管道"
↓ m10-性能:流式处理,批处理
↓ polars:惰性求值
"需要GPU推理"
↓ m07-并发:异步数据加载
↓ candle/tch-rs:CUDA后端
"需要模型加载"
↓ m12-生命周期:惰性初始化,缓存
↓ tract:ONNX运行时
用例 → 框架
| 用例 |
推荐 |
原因 |
| 仅推理 |
tract (ONNX) |
轻量级,可移植 |
| 训练 + 推理 |
candle, burn |
纯Rust,GPU支持 |
| PyTorch模型 |
tch-rs |
直接绑定 |
| 数据管道 |
polars |
快速,惰性求值 |
关键库
| 用途 |
库 |
| 张量 |
ndarray |
| ONNX推理 |
tract |
| ML框架 |
candle, burn |
| PyTorch绑定 |
tch-rs |
| 数据处理 |
polars |
| 嵌入 |
fastembed |
设计模式
| 模式 |
目的 |
实现 |
| 模型加载 |
一次加载,重复使用 |
OnceLock<Model> |
| 批处理 |
提高吞吐量 |
收集然后处理 |
| 流式处理 |
处理大数据 |
基于迭代器 |
| GPU异步 |
并行化 |
数据加载与计算并行 |
代码模式:推理服务器
use std::sync::OnceLock;
use tract_onnx::prelude::*;
static MODEL: OnceLock<SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>> = OnceLock::new();
fn get_model() -> &'static SimplePlan<...> {
MODEL.get_or_init(|| {
tract_onnx::onnx()
.model_for_path("model.onnx")
.unwrap()
.into_optimized()
.unwrap()
.into_runnable()
.unwrap()
})
}
async fn predict(input: Vec<f32>) -> anyhow::Result<Vec<f32>> {
let model = get_model();
let input = tract_ndarray::arr1(&input).into_shape((1, input.len()))?;
let result = model.run(tvec!(input.into()))?;
Ok(result[0].to_array_view::<f32>()?.iter().copied().collect())
}
代码模式:批量推理
async fn batch_predict(inputs: Vec<Vec<f32>>, batch_size: usize) -> Vec<Vec<f32>> {
let mut results = Vec::with_capacity(inputs.len());
for batch in inputs.chunks(batch_size) {
// 将输入堆叠成批处理张量
let batch_tensor = stack_inputs(batch);
// 对批次进行推理
let batch_output = model.run(batch_tensor).await;
// 解堆叠结果
results.extend(unstack_outputs(batch_output));
}
results
}
常见错误
| 错误 |
领域违规 |
修复方法 |
| 克隆张量 |
内存浪费 |
使用视图 |
| 单次推理 |
GPU利用率低 |
批处理 |
| 每次请求加载模型 |
速度慢 |
单例模式 |
| 同步数据加载 |
GPU空闲 |
异步管道 |
追溯至第1层
| 约束 |
第2层模式 |
第1层实现 |
| 内存效率 |
零拷贝 |
ndarray视图 |
| 模型单例 |
惰性初始化 |
OnceLock<Model> |
| 批处理 |
分块迭代 |
chunks() + 并行 |
| GPU异步 |
并发加载 |
tokio::spawn + GPU |
相关技能
| 何时使用 |
参见 |
| 性能 |
m10-性能 |
| 惰性初始化 |
m12-生命周期 |
| 异步模式 |
m07-并发 |
| 内存效率 |
m01-所有权 |