catlass:昇腾NPU上的算子模板库

发布时间:2026/5/21 7:28:50

catlass:昇腾NPU上的算子模板库 CANN 的矩阵计算能力来自三层结构最顶层是ops-blas这样的标准算子库中间是 GEMM 的优化实现层最底层是catlass——一套用模板元编程生成的算子 Kernel 工厂。很多人第一次看到 catlass 会觉得它就是 CUTLASS 的昇腾移植版。这个印象对了一半。catlass 确实借鉴了 CUTLASS 的模板化思想但针对昇腾达芬奇架构做了大量不同于 GPU 的设计取舍。catlass 为什么存在AI 计算的核心是矩阵乘法。一张计算图里 70-80% 的 FLOPs 来自各种形状的 GEMM——大的如 FFN 的[B, 4096] × [4096, 11008]小的如 Attention 里的[B, 40, n, 128] × [B, 40, 128, n]。这些 GEMM 的形状千差万别但底层都是同一个操作M×K 乘以 K×N。如果为每种形状手写一个 Kernel工作量巨大且难以维护。catlass 的答案是用 C 模板来生成 Kernel。开发者描述我要一个 M4096、K4096、N11008 的 GEMM——catlass 的模板引擎在编译期展开成对应的 Tile 循环、搬运指令和计算指令。不同形状的 GEMM 共享同一套模板代码只在模板参数上做区分。算子模板的核心思想catlass 把一个 GEMM Kernel 拆解成几个可组合的模板层次Tile 描述层。定义每个分块的大小——M 维度分多少、N 维度分多少、K 维度做多少次循环。这个层的配置直接影响片上 L1 Buffer 的利用率和搬运次数。数据搬运层。定义数据从 GMDDR到片上 L1 的搬运策略——要不要做转置、数据是 FP16 还是 INT8、地址对齐方式。catlass 把搬运描述为Iterator模板不同的 GEMM 形状对应不同的 Iterator 特化版本。计算指令层。定义 AI Core 上的计算指令序列——Cube Unit 做矩阵乘、Vector Unit 做偏置加和激活函数。这一层直接映射到达芬奇架构的硬件指令。这三层模板在编译期通过模板参数组合展开为特定形状的完整 Kernel 代码。开发者不需要手写任何循环或搬运指令。// catlass 模板组合示例伪代码usingGemmTraitscatlass::GemmTraitsM4096,N11008,K4096,LayoutARowMajor,LayoutBColMajor,TypeAfloat16,TypeBfloat16,TypeCfloat16,TileShape128,128,64,IteratorAIteratorContiguous,IteratorBIteratorPacked;// 编译期展开为完整 Kernelcatlass::KernelGemmTraits::launch(ptrA,ptrB,ptrC,stream);catlass 与 CUTLASS 的区别硬件抽象不同。CUTLASS 抽象的是 GPU 的 warp/thread 层次——开发者需要配置每个线程计算哪些元素、如何做 shared memory 的 bank conflict 避免。catlass 抽象的是昇腾的 Cube Unit 和 Vector Unit——开发者配置的是分块大小和数据搬运策略不涉及线程模型。编译期 vs 运行时。CUTLASS 的模板展开走 CUDA 的 JIT 编译路径运行时可能存在首次编译延迟。catlass 的模板在 ATC 编译模型时就全部展开Runtime 加载 Kernel 时已经是编译好的二进制。推理时零编译开销。融合能力不同。CUTLASS 的模板化停在 GEMM Kernel 边界——GEMM 做完后的 bias add 和 activation 需要独立的 Kernel 调用。catlass 的模板允许在 Tile 循环内部嵌入 Vector 计算指令比如 GEMM 的每个分块计算完后立即做 ReLU结果不落 DDR。// catlass 的 GEMM ReLU 融合模板usingGemmReluTraitscatlass::GemmTraitsM4096,N11008,K4096,EpilogueReluTypeC// 分块计算完后立即 ReLU;这个融合能力在 Transformer 推理中非常关键——FFN 层的GEMM → BiasAdd → ReLU → GEMM可以融合成两个带 Epilogue 的 catlass Kernel省掉中间 Tensor 的两次 DDR 读写。CANN 如何调用 catlassCANN 的算子调用链路中catlass 不直接暴露给应用层。AscendCL 推理 → GE 图优化 → ops-blas高层算子库→ catlass模板 Kernel 生成应用层调用的是 AscendCL 或 PyTorch 的标准推理接口。GE 在图优化阶段把计算图中的 GEMM 算子替换成 ops-blas 的优化实现。ops-blas 内部在初始化时通过 catlass 的模板引擎生成特定形状的 Kernel注册到 Runtime 的算子表中。推理时 Runtime 直接查表加载对应 Kernel。开发者不需要直接接触 catlass但理解它的模板化思想对于手工优化 GEMM 形状很关键——当你发现某个 GEMM 形状的推理性能异常时大概率是 catlass 为该形状生成的 Tile 参数没有匹配到最优配置。Transformer 中的 GEMM 优化以 LLaMA-7B 为例模型运行时涉及的 GEMM 主要是三种形状FFN 前向 GEMM[B, 4096] × [4096, 11008]。这是一个典型的 M 小、K 和 N 大的矩形 GEMM。catlass 为这种形状生成的 Tile 策略是 N 维度大分块减少 Tile 循环次数、M 维度小分块匹配 Batch 的动态变化。Attention 投影 GEMMQ X W_Q形状[B, n, 4096] × [4096, 4096]。三个投影Q/K/V形状相同。catlass 会让三个投影共享同一套 Tile 模板在 Stream 上流水线执行。Decoder Block 的残差 GEMM形状小且不规则。catlass 为这类小 GEMM 专门设计了 Tile 模板避免在小矩阵上跑大 Tile 导致的搬运浪费。实测中catlass 的模板化 GEMM 比手写固定 Kernel 在各种形状上的平均性能差距在 5% 以内但在开发效率和代码维护上有数量级的优势。CANN catlass 算子模板仓库ops-blas 线性代数算子库Tile 配置的艺术catlass 的模板化核心在 Tile 参数的配置。不同的 GEMM 形状对 Tile 的要求完全不同。大 GEMMM4096, K4096, N11008 Tile(M128, N128, K64) — 让 Cube Unit 保持满载 循环次数M32, N86, K64 → 总启动 176,128 次 Tile 计算 小 GEMMM1, K4096, N4096 Tile(M1, N128, K4096) — M 维度不切N 平铺 循环次数M1, N32, K1 → 32 次 不规则 GEMMM7, K1024, N256 Tile(M7, N64, K1024) — 非 2 的幂次K 全量 循环次数M1, N4, K1 → 4 次catlass 的模板引擎内置了一套启发式 Tile 选择逻辑根据 M、N、K 的比值和金丝雀测试数据自动选择理论最优的 Tile 配置。如果自动选择不是最优开发者可以通过模板参数覆盖。Kernel 生成的编译期流程catlass 的模板展开不是运行时发生的而是在 ATC 编译模型时通过编译器完成模型 ONNX → ATC 解析 GEMM 形状 → catlass 模板实例化编译期 → 模板参数推导 Tile 配置 → 生成 /tmp/.catlass_kernels/shape_hash.o → 链接进 OM 模型 → Runtime 加载时已经是编译好的二进制每个 GEMM 形状的 Kernel 编译一次后缓存到文件系统。下次遇到相同形状的 GEMM 直接复用。这个缓存机制在模型部署时非常关键——模型加载时不需要重新编译 KernelOM 文件内已经包含了所有 GEMM Kernel 的二进制。继续学习catlass 解决的是GEMM Kernel 怎么写的问题。上一层 ops-blas 解决的是GEMM 形状太多怎么管理的问题。理解了 catlass 的模板化思想后顺着 ops-blas 的仓库可以看到 CANN 是如何为数百种 GEMM 形状提供统一优化接口的。模板化 vs 手写 Kernel 的取舍模板化的代价是灵活性受限。手写 Kernel 可以为特定形状做极致优化——比如在 M1 的场景省略掉 M 维度的所有循环。catlass 的通用模板无法覆盖所有极端形状的特殊优化但是覆盖了 95% 的常用场景。剩余 5% 的极端形状可以通过手写 Ascend C Kernel 来补充catlass 提供了与手写 Kernel 的互操作接口。参考仓库catlass 算子模板库ops-blas 线性代数算子库

相关新闻