
简介【免费下载链接】xla-npuXLA-NPU 是一个面向华为昇腾NPU硬件的 XLA后端实现。本项目通过接入OpenXLA/XLA开源项目将XLA开源生态与华为 CANN软件栈集成对接JAX框架。JAX框架运行时可以直接加载XLA-NPU使得基于JAX框架开发的模型可以运行在昇腾NPU上提供推理场景图编译加速能力。项目地址: https://gitcode.com/cann/xla-npu概述XLA-NPU是一个面向华为昇腾NPUNeural Processing Unit硬件的XLAAccelerated Linear Algebra后端实现。本项目通过接入OpenXLA/XLA开源项目将XLA开源生态与华为 CANNCompute Architecture for Neural Networks软件栈集成对接JAX框架。JAX框架运行时可以直接加载XLA-NPU使得基于JAX框架开发的模型可以运行在昇腾NPU上提供推理场景图编译加速能力。xla_npu实现OpenXLA PJRT运行时接口通过调用CANN软件栈中Runtime接口管理设备、Stream、Event、内存等从而驱动NPU设备运行模型同时对接CANN生态中Graph Engine、AFIR等编译后端实现图编译。JAX框架通过加载XLA-NPU动态库so文件实现JAX框架对接NPU设备运行JAX脚本及网络。图1XLA-NPU架构图使用说明使用场景当前版本的xla_npu作为beta特性主要专注于推理场景下的模型优化。产品支持情况Atlas A3 训练系列产品/Atlas A3 推理系列产品Atlas A2 训练系列产品/Atlas A2 推理系列产品整体约束当前只支持使用1张NPU卡不支持集合通信。只支持jax.jit()整图编译。支持的JAX API清单|JAX API|约束| |--|--| |jax.numpy.add|支持fp32| |jax.numpy.subtract|支持fp32| |jax.numpy.multiply|支持fp32| |jax.numpy.divide|支持fp32| |jax.numpy.dot|支持fp32| |jax.numpy.tanh|支持fp32| |jax.numpy.negative|支持fp32| |jax.numpy.exp|支持fp32| |jax.numpy.maximum|支持fp32| |jax.numpy.concatenate|支持fp32| |jax.numpy.max|支持fp32| |jax.numpy.sum|支持fp32| |jax.nn.gelu|支持fp32|Demo模型及单个算子用例样例参考demo。常见问题使用afir融合后端执行测试用例报错ModuleNotFoundError: No module named runtime问题原因:ASCEND_MLIR_PYTHON_PATH环境变量指向了一个错误或者无效路径解决方法:执行测试用例前, 重新执行export ASCEND_MLIR_PYTHON_PATHxla-npu代码仓中dependency下载的Ascend-MLIR中Python可执行文件路径【免费下载链接】xla-npuXLA-NPU 是一个面向华为昇腾NPU硬件的 XLA后端实现。本项目通过接入OpenXLA/XLA开源项目将XLA开源生态与华为 CANN软件栈集成对接JAX框架。JAX框架运行时可以直接加载XLA-NPU使得基于JAX框架开发的模型可以运行在昇腾NPU上提供推理场景图编译加速能力。项目地址: https://gitcode.com/cann/xla-npu创作声明:本文部分内容由AI辅助生成(AIGC),仅供参考