cutlass-tritonSkill cutlass-triton

cutlass-triton是一个专门用于生成优化GPU内核的技能,使用CUTLASS和Triton实现高性能计算。

深度学习 1 次安装 31 次浏览 更新于 2/25/2026

name: cutlass-triton description: 高性能内核模板库和领域特定语言。生成 CUTLASS GEMM 配置,实现 Triton 内核定义,配置尾处理操作,调整瓦片大小和 warp 排列,与 cuBLAS 基准测试。 allowed-tools: Bash(*) 读写编辑 glob grep webfetch metadata: author: babysitter-sdk version: “1.0.0” category: 内核生成 backlog-id: SK-016

cutlass-triton

您是 cutlass-triton - 一个专门用于高性能内核模板库和领域特定语言的专业技能。这个技能提供了生成优化的 GPU 内核的专家能力,使用 CUTLASS 和 Triton。

概览

这个技能使得 AI 驱动的内核生成成为可能,包括:

  • 生成 CUTLASS GEMM 配置
  • 实现 Triton 内核定义
  • 配置尾处理操作
  • 处理张量布局转换
  • 调整瓦片大小和 warp 排列
  • 支持混合精度矩阵操作
  • 与 cuBLAS 实现基准测试
  • 生成自定义注意力内核

前提条件

  • CUTLASS 3.0+(仅头文件库)
  • Triton 2.0+(Python 包)
  • CUDA Toolkit 11.0+
  • Python 3.8+(对于 Triton)

能力

1. CUTLASS GEMM 配置

配置高性能 GEMM:

#include <cutlass/cutlass.h>
#include <cutlass/gemm/device/gemm.h>

// 定义 GEMM 操作类型
using ElementA = cutlass::half_t;
using ElementB = cutlass::half_t;
using ElementC = cutlass::half_t;
using ElementAccumulator = float;

using LayoutA = cutlass::layout::RowMajor;
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::RowMajor;

// 定义 CUTLASS GEMM
using Gemm = cutlass::gemm::device::Gemm<
    ElementA, LayoutA,
    ElementB, LayoutB,
    ElementC, LayoutC,
    ElementAccumulator,
    cutlass::arch::OpClassTensorOp,
    cutlass::arch::Sm80,
    cutlass::gemm::GemmShape<128, 256, 64>,  // 线程块形状
    cutlass::gemm::GemmShape<64, 64, 64>,    // Warp 形状
    cutlass::gemm::GemmShape<16, 8, 16>,     // 指令形状(张量核心)
    cutlass::epilogue::thread::LinearCombination<
        ElementC, 128 / cutlass::sizeof_bits<ElementC>::value,
        ElementAccumulator, ElementAccumulator>,
    cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
    3  // 阶段
>;

// 运行 GEMM
void runGemm(int M, int N, int K,
             ElementA* A, ElementB* B, ElementC* C,
             ElementAccumulator alpha, ElementAccumulator beta) {
    Gemm gemm_op;
    Gemm::Arguments args(
        {M, N, K},
        {A, K}, {B, K}, {C, N}, {C, N},
        {alpha, beta}
    );

    cutlass::Status status = gemm_op(args);
    if (status != cutlass::Status::kSuccess) {
        // 处理错误
    }
}

2. CUTLASS 3.0 (Cute) API

现代 CUTLASS 与 Cute:

#include <cute/tensor.hpp>
#include <cutlass/gemm/collective/collective_mma.hpp>

using namespace cute;

// 使用 Cute 定义布局
using SmemLayoutA = Layout<Shape<_128, _64>, Stride<_64, _1>>;
using SmemLayoutB = Layout<Shape<_64, _128>, Stride<_1, _64>>;

// 集体 MMA 配置
using CollectiveMma = cutlass::gemm::collective::CollectiveMma<
    cutlass::arch::Sm90,
    Shape<_128, _256, _64>,  // 瓦片形状
    ElementA, cutlass::layout::RowMajor,
    ElementB, cutlass::layout::ColumnMajor,
    ElementAccumulator,
    TiledMMA<
        MMA_Atom<SM80_16x8x16_F32F16F16F32_TN>,
        Layout<Shape<_2, _2, _1>>
    >,
    GmemTiledCopyA, SmemLayoutA, SmemCopyAtomA,
    GmemTiledCopyB, SmemLayoutB, SmemCopyAtomB
>;

3. Triton 内核开发

用 Triton DSL 写内核:

import triton
import triton.language as tl

@triton.jit
def matmul_kernel(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # 程序 ID
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    # 块偏移量
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # 指向第一个块的指针
    a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
    b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn

    # 初始化累加器
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    # 主循环
    for k in range(0, K, BLOCK_K):
        # 加载块
        a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k, other=0.0)
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k, other=0.0)

        # 计算
        acc += tl.dot(a, b)

        # 移动指针
        a_ptrs += BLOCK_K * stride_ak
        b_ptrs += BLOCK_K * stride_bk

    # 存储结果
    c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
    tl.store(c_ptrs, acc, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))


def matmul(a, b):
    M, K = a.shape
    K, N = b.shape
    c = torch.empty((M, N), device=a.device, dtype=a.dtype)

    grid = lambda meta: (
        triton.cdiv(M, meta['BLOCK_M']),
        triton.cdiv(N, meta['BLOCK_N'])
    )

    matmul_kernel[grid](
        a, b, c,
        M, N, K,
        a.stride(0), a.stride(1),
        b.stride(0), b.stride(1),
        c.stride(0), c.stride(1),
        BLOCK_M=64, BLOCK_N=64, BLOCK_K=32
    )
    return c

4. Triton 自动调优

自动内核调优:

@triton.autotune(
    configs=[
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'BLOCK_K': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=3, num_warps=4),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'BLOCK_K': 32}, num_stages=3, num_warps=8),
        triton.Config({'BLOCK_M': 128, 'BLOCK_N': 256, 'BLOCK_K': 64}, num_stages=4, num_warps=8),
    ],
    key=['M', 'N', 'K']
)
@triton.jit
def matmul_autotune(
    a_ptr, b_ptr, c_ptr,
    M, N, K,
    stride_am, stride_ak,
    stride_bk, stride_bn,
    stride_cm, stride_cn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # 相同的内核主体...
    pass

5. 尾处理操作

自定义后处理:

// CUTLASS 尾处理带激活函数
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu<
    ElementC,
    128 / cutlass::sizeof_bits<ElementC>::value,
    ElementAccumulator,
    ElementAccumulator
>;

// 融合偏置 + 激活
using EpilogueWithBias = cutlass::epilogue::thread::LinearCombinationBias<
    ElementC,
    128 / cutlass::sizeof_bits<ElementC>::value,
    ElementAccumulator,
    ElementAccumulator,
    cutlass::epilogue::thread::ReLu
>;
# Triton 尾处理
@triton.jit
def fused_matmul_relu(
    a_ptr, b_ptr, bias_ptr, c_ptr,
    M, N, K,
    # ... 步长 ...
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    # ... 矩阵乘法计算 ...

    # 尾处理:添加偏置和 ReLU
    bias = tl.load(bias_ptr + offs_n)
    acc = acc + bias[None, :]
    acc = tl.maximum(acc, 0.0)

    tl.store(c_ptrs, acc, mask=mask)

6. Triton 中的 Flash Attention

优化的注意力内核:

@triton.jit
def flash_attention_kernel(
    Q, K, V, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_vz, stride_vh, stride_vn, stride_vk,
    stride_oz, stride_oh, stride_om, stride_ok,
    Z, H, M, N,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_z = tl.program_id(1)
    pid_h = tl.program_id(2)

    # 初始化
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # 加载 Q 块
    q_ptrs = Q + pid_z * stride_qz + pid_h * stride_qh + \
             offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk
    q = tl.load(q_ptrs, mask=offs_m[:, None] < M)

    # 运行最大值和总和用于在线 softmax
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc = tl.zeros([BLOCK_M, BLOCK_K], dtype=tl.float32)

    # 遍历 K, V 块
    for start_n in range(0, N, BLOCK_N):
        # 加载 K, V 块
        # 计算注意力分数
        # 在线 softmax 更新
        # 累积输出
        pass

    # 存储输出
    o_ptrs = Out + pid_z * stride_oz + pid_h * stride_oh + \
             offs_m[:, None] * stride_om + offs_k[None, :] * stride_ok
    tl.store(o_ptrs, acc, mask=offs_m[:, None] < M)

7. 基准测试

比较性能:

import torch
import triton

def benchmark_matmul(M, N, K, dtype=torch.float16):
    a = torch.randn((M, K), device='cuda', dtype=dtype)
    b = torch.randn((K, N), device='cuda', dtype=dtype)

    # Triton
    triton_fn = lambda: triton_matmul(a, b)
    triton_ms = triton.testing.do_bench(triton_fn)

    # cuBLAS
    cublas_fn = lambda: torch.matmul(a, b)
    cublas_ms = triton.testing.do_bench(cublas_fn)

    # TFLOPS
    tflops = 2 * M * N * K / 1e12
    print(f"Triton: {triton_ms:.2f} ms ({tflops/triton_ms*1e3:.1f} TFLOPS)")
    print(f"cuBLAS: {cublas_ms:.2f} ms ({tflops/cublas_ms*1e3:.1f} TFLOPS)")
    print(f"Ratio: {cublas_ms/triton_ms:.2f}x")

# 基准测试不同大小
for size in [1024, 2048, 4096, 8192]:
    print(f"
=== {size}x{size}x{size} ===")
    benchmark_matmul(size, size, size)

流程集成

这个技能与以下流程集成:

  • tensor-core-programming.js - 张量核心工作流
  • custom-cuda-operator-development.js - 自定义操作符
  • ml-inference-optimization.js - ML 推理

输出格式

{
  "operation": "generate-kernel",
  "framework": "triton",
  "kernel_type": "matmul",
  "configuration": {
    "BLOCK_M": 128,
    "BLOCK_N": 128,
    "BLOCK_K": 32,
    "num_stages": 3,
    "num_warps": 8
  },
  "performance": {
    "tflops": 145.2,
    "vs_cublas": 0.95,
    "memory_bound": false
  },
  "generated_files": ["matmul_kernel.py"]
}

依赖关系

  • CUTLASS 3.0+
  • Triton 2.0+
  • CUDA Toolkit 11.0+
  • PyTorch(用于 Triton 集成)

限制

  • CUTLASS 模板增加编译时间
  • Triton 需要 Python 环境
  • 张量核心需要特定数据类型/对齐
  • 性能因 GPU 架构而异