ops-blas 的 GEMM 优化:昇腾NPU上的矩阵乘法引擎

发布时间:2026/5/21 7:28:50

ops-blas 的 GEMM 优化:昇腾NPU上的矩阵乘法引擎 矩阵乘法是 AI 计算的基本操作。一个 Transformer 模型在前向推理中大概 70% 的 FLOPs 花在 GEMM 上——FFN 层是 GEMMAttention 的投影是 GEMMScore 计算在 Cube 上也是以 GEMM 形式执行的。ops-blas 是 CANN 里专门管 GEMM 的算子库。它不负责模板化 Kernel 生成那是 catlass 的事它的定位是管理数十种 GEMM 变体——从 FP16 的通用 GEMM 到 INT8 的量化 GEMM——并在不同形状下选择最优实现。为什么 GEMM 是 AI 的核心AI 模型的本质是数据乘以参数。一个 Transformer Block 的推理过程输入 X → X W_Q → Q X W_K → K X W_V → V Attention Score Q K^T Context Score V Output Context W_O FFN Output W_1 → ReLU → W_2去掉非线性激活函数和归一化Transformer Block 就是一连串 GEMM。训练时反向传播也由 GEMM 构成梯度矩阵乘权重或权重矩阵乘梯度。所以 GEMM 的性能直接决定了 NPU 的利用率。如果 GEMM 做不到硬件峰值效率的 80%整个推理服务的吞吐会被压在 GEMM 这一层。昇腾NPU如何执行矩阵计算昇腾的达芬奇架构执行 GEMM 的方式跟 GPU 不同。GPU 用大量小 CUDA Core 做并行矩阵乘昇腾用专用的 Cube Unit立方体单元。Cube Unit 一次执行一个A[16×16] × B[16×16]的小矩阵乘输出一个C[16×16]的结果块。更大的矩阵乘由 Cube Unit 分块重复执行M4096, K4096, N11008 的 GEMM 将 A 按 16×16 切分成 256×256 块 将 B 按 16×16 切分成 256×688 块 Cube Unit 逐块计算累加到 C 的对应位置 总调用次数256 × 256 × 688 ≈ 45M 次Cube Unit 的计算速度极快——单次16×16FP16 矩阵乘在 Ascend 910 上约耗时 4 个时钟周期。但 Cube Unit 的瓶颈不在计算在数据供给。Cube Unit 每秒需要消耗几十 GB 的数据才能跑满数据供给跟不上时 Cube Unit 就空转。Tile 分块为什么重要前面 catlass 文章提到 Tile 是模板化的核心。在 ops-blas 层面Tile 问题更具体——它决定了 Cube Unit 的空闲率。一个 M4096 的 GEMM如果不做 Tile就是把 A 矩阵整块从 DDR 搬进片上 L1。但 L1 只有几百 KB装不下几个完整的矩阵。所以必须分块A 矩阵 → 切成 M_tile × K_tile 的小块 B 矩阵 → 切成 K_tile × N_tile 的小块 ┌─────────────────┐ A_tile → │ Cube Unit │ → C_tile 累加 B_tile → │ 16×16 矩阵乘 │ └─────────────────┘Tile 大小选多大太小了 Cube Unit 频繁切换数据每个 Tile 的 DMA 启动开销占比大。太大了 L1 装不下部分数据 spill 到 DDR。ops-blas 的做法是维护一个 Tile 配置表——为常见 GEMM 形状预存最优 Tile 参数。不常见的形状通过测量方法在 Runtime 快速跑几个 Tile 方案选取延迟最低的确定。GEMM 形状推荐 TileCube 利用率1×4096 × 4096×11008M1, N128, K409635%8×4096 × 4096×11008M8, N128, K204862%64×4096 × 4096×11008M64, N256, K102478%256×4096 × 4096×11008M128, N256, K12885%Batch 越大 Cube 利用率越高因为 M 维度变长后 Tile 循环中的 DMA 启动开销占比下降。Memory Bound 为何是本质瓶颈ops-blas 优化的核心不是让 Cube Unit 算得更快——Cube 已经快到极限了。优化的目标是在 Cube Unit 计算的同时保证数据搬运不成为瓶颈。一个M128, N256, K128的 Tile计算量128 × 256 × 128 × 2 8.4M FLOPs乘加算两次搬运量A128×128×232KB, B128×256×264KB, C128×256×264KB 160KB计算/搬运比8.4M / 160K 52.5 FLOPs/byteCube Unit 的理论算力约 24 TFLOPS要达到这个算力需要每秒 24T / 52.5 ≈ 457 GB/s 的搬运带宽。而 DDR 的实际带宽通常只有 200-300 GB/s。这个差距就是 GEMM 是 Memory Bound 算子的定量证据。ops-blas 的优化策略就是围绕搬运不够快这个核心矛盾展开的Double Buffer用一个 Buffer 计算、另一个 Buffer 搬运下一块数据分块复用同一块 A 的 Tile 跟多个 B 的 Tile 配对减少 A 的搬运次数数据压缩FP16 比 FP32 少搬一半INT8 再减半大模型中的 GEMM 瓶颈LLaMA-70B 的解码阶段每步生成一个 Token 涉及约 10 个 GEMMGEMM 位置形状计算量 (GFLOPs)搬运量 (MB)受限类型Q 投影1×8192 × 8192×81920.130.5计算K 投影1×8192 × 8192×81920.130.5计算V 投影1×8192 × 8192×81920.130.5计算Attention Score1×128 × 128×n0.0003n0.001n搬运输出投影1×8192 × 8192×81920.130.5计算FFN 升维1×8192 × 8192×286720.471.75计算FFN 降维1×28672 × 28672×81920.471.75计算解码阶段 M1 的小 GEMM 占了大部分。M1 时 Cube Unit 利用率只有 35% 左右。ops-blas 的应对方案是用 Vector Unit 处理小 GEMMVector 在 M1 时比 Cube 更灵活或者把多个 request 的 GEMM 在 M 维度上拼接成一个更大的 GEMM。CANN ops-blas 线性代数算子库catlass 算子模板生成ops-blas 的优化思路ops-blas 内部把 GEMM 按形状分成几类每类走不同的优化路径大 GEMMM≥256, N≥1024, K≥1024Cube Unit 满载运行。优化重点是 Double Buffer 和 Tile 参数调优。ops-blas 会做一次金丝雀测试——在模型加载时用小样本跑几组 Tile 配置选出实测最快的。中 GEMM16≤M256Cube Unit 部分空转。优化方向是把多个小 GEMM 在 M 维度拼接。比如 Batch8 的 Attention 投影M8 的 GEMM 在 Cube 上跑利用率很低ops-blas 会把 8 个 request 的 Q 投影 GEMM 拼接成 M64 的大 GEMM一次计算产出 8 个结果。这个技巧叫 GEMM Batching。小 GEMMM16Cube Unit 利用率太低直接用 Cube 跑不如用 Vector Unit。ops-blas 对这类 GEMM 走 Vector 路径——用 Vector 的 SIMD 指令做矩阵乘。Vector 在小矩阵上的灵活性高于 Cube总耗时反而更短。GEMM 形状Cube 路径Vector 路径最优选择1×4096 × 4096×81920.045ms0.038msVector8×4096 × 4096×81920.12ms0.15msCube64×4096 × 4096×81920.58ms0.92msCube256×4096 × 4096×81921.9ms3.4msCubeM1 时 Vector 路径快 15%。M≥8 后 Cube 路径开始占优。GEMM 的双 Buffer 优化细节ops-blas 最核心的优化是 Double Buffer它让数据搬运和矩阵计算完全重叠。朴素执行无 Buffer [搬 A_0] → [搬 B_0] → [Cube 算 C_0] → [搬 A_1] → [搬 B_1] → [Cube 算 C_1] → ... Double Buffer 执行 [搬 A_0][搬 B_0] → 开始算 C_0 的同时 → [搬 A_1][搬 B_1] → 算 C_1 的同时 → [搬 A_2]... ↑ Cube Unit 不空等 ↑ 无缝衔接通过 Double BufferCube Unit 在计算当前 Tile 时下一块 Tile 的数据已经通过 DMA 搬到了 L1 上。Cube 不再等数据。ops-blas 实现 Double Buffer 的方式是在 L1 中分配两份 BufferBuffer_A 和 Buffer_B 各两份。计算线程操作 Buffer pair 0DMA 线程操作 Buffer pair 1。每个 Tile 结束时交换角色。实际工程实现中这个切换需要在 GEMM 循环的边界上精确控制同步点。ops-blas 把所有同步逻辑封装在模板循环内部上层的 Tile 循环只管递交任务不需要管同步。小结ops-blas 的存在意义是让上层应用AscendCL、PyTorch不需要为不同 GEMM 形状操心。无论 GEMM 是大是小、M 是 1 还是 1024、数据类型是 FP16 还是 INT8ops-blas 都能找到当前条件下最快的执行路径。理解了它的 Tile 策略和 Double Buffer 机制才能在 GEMM 性能异常时快速定位瓶颈。GEMM 的量化变体ops-blas 不止管 FP16 GEMM。量化推理场景中 INT8 GEMM 的优化复杂度更高。INT8 GEMM 的问题在于A 和 B 的量化参数scale 和 zero_point可能不同乘完后需要反量化到 FP16 再做累加。ops-blas 把反量化步骤融合到 GEMM 的 Epilogue 中——Cube Unit 输出 INT32 累加结果后Vector Unit 立即做 scale 乘法和 zero_point 偏移然后 cast 回 FP16。// ops-blas 的 INT8 GEMM 融合流程Cube:C_int32A_int8 B_int8// 矩阵乘Vector:C_fp16(C_int32-offset)*scale// 反量化Vector:C_fp16ReLU(C_fp16)// 再融合激活三个步骤在一次 Kernel Launch 内完成中间结果不落 DDR。从 ops-blas 到 catlassops-blas 和 catlass 的分工简单说就是ops-blas 是管 GEMM 形状的catlass 是写 GEMM Kernel 的。ops-blas 负责判断这个形状用 Cube 还是 Vector、Tile 大小怎么设catlass 负责把决策变成可执行的 Kernel 代码。理解了这个分工就知道 GEMM 优化在 CANN 里是怎么分层协作的。参考仓库ops-blas 线性代数算子库catlass 算子模板生成器

相关新闻