FA PTO PyTorch 移植示例

概述

本示例演示了如何使用 PTO 实现 Flash Attention 内核,并通过 torch_npu 将其作为自定义 PyTorch 算子对外暴露。示例展示了在 Ascend AI 处理器上实现高性能自定义内核集成,并具备自动 tile 适配能力。

支持的 AI 处理器

  • A2/A3/A5

1. 环境准备

创建虚拟环境并安装依赖:

python -m venv virEnv
source virEnv/bin/activate
python3 -m pip install -r requirements.txt

确保已配置 Ascend Toolkit 和 PTO 库:

export ASCEND_HOME_PATH=[YOUR_ASCEND_PATH/SYSTEM_ASCEND_PATH]
source [YOUR_ASCEND_PATH/SYSTEM_ASCEND_PATH]/latest/bin/setenv.bash
export PTO_LIB_PATH=[YOUR_PATH]/pto-isa

2. 构建 Wheel 包

项目支持通过 SOC_VERSION 环境变量为不同的 SOC 版本进行构建。构建系统会根据目标 SOC 自动配置正确的优化宏(例如 PTO_NPU_ARCH_A2A3PTO_NPU_ARCH_A5)。

默认构建(A2 / A3):

python3 setup.py bdist_wheel

为特定 SOC 构建(例如 A5):

# A5 示例
SOC_VERSION=ascend910_9599 python3 setup.py bdist_wheel

3. 安装 Wheel 包

pip install dist/*.whl --force-reinstall

4. 运行测试

运行验证脚本,将内核结果与黄金参考值进行比较。测试涵盖多种序列长度(1k 至 32k)并验证动态 tile 逻辑。

cd test
python3 test.py

特性

  • 动态 Tiling:根据输入序列长度自动选择最佳 tile 大小(128 或 256)。
  • 跨架构支持:通过构建时配置,统一的代码库同时支持 A2/A3 和 A5 架构。