Rust机器学习领域技能 domain-ml

本技能专注于使用Rust编程语言高效构建和部署机器学习与人工智能应用。涵盖关键领域约束如内存效率、GPU加速、模型可移植性,并提供具体的设计模式、推荐框架(如candle, tract, tch-rs)和代码实现。适用于需要在Rust生态系统中实现高性能模型推理、训练以及数据处理管道的开发者。关键词:Rust机器学习,AI模型部署,高性能推理,GPU加速,ONNX,批处理,内存优化,tract,candle,tch-rs,量化交易AI基础设施。

机器学习 0 次安装 0 次浏览 更新于 2/27/2026

name: domain-ml description: “在Rust中构建ML/AI应用时使用。关键词:机器学习,ML,AI,张量,模型,推理,神经网络,深度学习,训练,预测,ndarray,tch-rs,burn,candle,机器学习,人工智能,模型推理” user-invocable: false

机器学习领域

第3层:领域约束

领域约束 → 设计影响

领域规则 设计约束 Rust实现
大数据 高效内存 零拷贝,流式处理
GPU加速 CUDA/Metal支持 candle, tch-rs
模型可移植性 标准格式 ONNX
批处理 吞吐量优先于延迟 批量推理
数值精度 浮点数处理 ndarray,谨慎使用f32/f64
可复现性 确定性 种子随机数,版本控制

关键约束

内存效率

规则:避免复制大型张量
原因:内存带宽是瓶颈
RUST:引用,视图,原地操作

GPU利用率

规则:批量操作以提高GPU效率
原因:GPU每次内核启动的开销
RUST:批处理大小,异步数据加载

模型可移植性

规则:使用标准模型格式
原因:在Python中训练,在Rust中部署
RUST:通过tract或candle使用ONNX

向下追溯 ↓

从约束到设计(第2层):

"需要高效的数据管道"
    ↓ m10-性能:流式处理,批处理
    ↓ polars:惰性求值

"需要GPU推理"
    ↓ m07-并发:异步数据加载
    ↓ candle/tch-rs:CUDA后端

"需要模型加载"
    ↓ m12-生命周期:惰性初始化,缓存
    ↓ tract:ONNX运行时

用例 → 框架

用例 推荐 原因
仅推理 tract (ONNX) 轻量级,可移植
训练 + 推理 candle, burn 纯Rust,GPU支持
PyTorch模型 tch-rs 直接绑定
数据管道 polars 快速,惰性求值

关键库

用途
张量 ndarray
ONNX推理 tract
ML框架 candle, burn
PyTorch绑定 tch-rs
数据处理 polars
嵌入 fastembed

设计模式

模式 目的 实现
模型加载 一次加载,重复使用 OnceLock<Model>
批处理 提高吞吐量 收集然后处理
流式处理 处理大数据 基于迭代器
GPU异步 并行化 数据加载与计算并行

代码模式:推理服务器

use std::sync::OnceLock;
use tract_onnx::prelude::*;

static MODEL: OnceLock<SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>> = OnceLock::new();

fn get_model() -> &'static SimplePlan<...> {
    MODEL.get_or_init(|| {
        tract_onnx::onnx()
            .model_for_path("model.onnx")
            .unwrap()
            .into_optimized()
            .unwrap()
            .into_runnable()
            .unwrap()
    })
}

async fn predict(input: Vec<f32>) -> anyhow::Result<Vec<f32>> {
    let model = get_model();
    let input = tract_ndarray::arr1(&input).into_shape((1, input.len()))?;
    let result = model.run(tvec!(input.into()))?;
    Ok(result[0].to_array_view::<f32>()?.iter().copied().collect())
}

代码模式:批量推理

async fn batch_predict(inputs: Vec<Vec<f32>>, batch_size: usize) -> Vec<Vec<f32>> {
    let mut results = Vec::with_capacity(inputs.len());

    for batch in inputs.chunks(batch_size) {
        // 将输入堆叠成批处理张量
        let batch_tensor = stack_inputs(batch);

        // 对批次进行推理
        let batch_output = model.run(batch_tensor).await;

        // 解堆叠结果
        results.extend(unstack_outputs(batch_output));
    }

    results
}

常见错误

错误 领域违规 修复方法
克隆张量 内存浪费 使用视图
单次推理 GPU利用率低 批处理
每次请求加载模型 速度慢 单例模式
同步数据加载 GPU空闲 异步管道

追溯至第1层

约束 第2层模式 第1层实现
内存效率 零拷贝 ndarray视图
模型单例 惰性初始化 OnceLock<Model>
批处理 分块迭代 chunks() + 并行
GPU异步 并发加载 tokio::spawn + GPU

相关技能

何时使用 参见
性能 m10-性能
惰性初始化 m12-生命周期
异步模式 m07-并发
内存效率 m01-所有权