TREDUCE¶
简介¶
Reduce 操作:从多个远端 NPU 收集数据并在本地执行逐元素归约。
只有根节点需要执行 TREDUCE。非根节点只需确保在操作期间其源缓冲区已就绪且保持有效。在非根节点上调用 TREDUCE 属于未定义行为。
大 Tile 支持:当 GlobalTensor 在行和/或列方向超出 UB Tile 容量时,归约操作将通过二维滑动自动分块。
数学语义¶
对有效区域内每个元素 (i, j):
\[\mathrm{dst}^{\mathrm{local}}_{i,j} = \bigoplus_{r=0}^{N-1} \mathrm{src}^{(r)}_{i,j}\]
其中 \(N\) 为 rank 总数,\(\oplus\) 为归约运算(求和、取最大值、取最小值等)。
汇编语法¶
PTO-AS 形式:参见 PTO-AS 规范。
同步形式:
treduce %group, %dst {op = #pto.reduce_op<Sum>} : (!pto.group<...>, !pto.memref<...>)
treduce %group, %dst {op = #pto.reduce_op<Max>} : (!pto.group<...>, !pto.memref<...>)
降级时会为 reduce 流水线引入内部累加 Tile 和接收 Tile;C++ 内建接口需要显式传入 accTileData、recvTileData(或 accTileData、pingTileData、pongTileData)操作数。
C++ 内建接口¶
声明于 include/pto/comm/pto_comm_inst.hpp:
// 基础 reduce(累加 Tile + 接收 Tile)
template <typename ParallelGroupType, typename GlobalDstData, typename TileData, typename... WaitEvents>
PTO_INST RecordEvent TREDUCE(ParallelGroupType ¶llelGroup, GlobalDstData &dstGlobalData,
TileData &accTileData, TileData &recvTileData, ReduceOp op, WaitEvents&... events);
// 乒乓 reduce(累加 Tile + ping/pong Tile 实现双缓冲)
template <typename ParallelGroupType, typename GlobalDstData, typename TileData, typename... WaitEvents>
PTO_INST RecordEvent TREDUCE(ParallelGroupType ¶llelGroup, GlobalDstData &dstGlobalData,
TileData &accTileData, TileData &pingTileData, TileData &pongTileData,
ReduceOp op, WaitEvents&... events);
约束¶
- 类型约束:
ParallelGroup::value_type::RawDType必须等于GlobalDstData::RawDType。TileData::DType必须等于GlobalDstData::RawDType。
- 内存约束:
dstGlobalData必须指向本地内存(当前 NPU)。accTileData、recvTileData(或accTileData、pingTileData、pongTileData)必须为预先分配的 UB Tile。
- ParallelGroup 约束:
parallelGroup.tensors[r]必须指向 rankr的源缓冲区(从根节点视角看到的远端 GM)。parallelGroup.GetRootIdx()标识调用方 NPU 为 reduce 根节点。- 所有源 tensor 假定具有相同的形状和步幅。
- 分块模式约束(数据超出单个 UB Tile 时):
- 若
TileData具有静态ValidRow,则GetShape(DIM_3)必须能被ValidRow整除。如需支持不足一行的情况,请使用DYNAMICValidRow 的 Tile。 - 若
TileData具有静态ValidCol,则GetShape(DIM_4)必须能被ValidCol整除。如需支持不足一列的情况,请使用DYNAMICValidCol 的 Tile。
- 若
示例¶
基础求和归约¶
#include <pto/comm/pto_comm_inst.hpp>
using namespace pto;
template <typename T, int SIZE, int NRANKS>
void reduce_sum(__gm__ T* group_addrs[NRANKS], __gm__ T* result, int my_rank) {
using TileT = Tile<TileType::Vec, T, 1, SIZE>;
using GTensor = GlobalTensor<T, Shape<1,1,1,1,SIZE>,
BaseShape2D<T, 1, SIZE, Layout::ND>, Layout::ND>;
GTensor tensors[NRANKS];
for (int i = 0; i < NRANKS; ++i) tensors[i] = GTensor(group_addrs[i]);
comm::ParallelGroup<GTensor> group(tensors, NRANKS, my_rank);
GTensor dstG(result);
TileT accTile, recvTile;
comm::TREDUCE(group, dstG, accTile, recvTile, comm::ReduceOp::Sum);
}
最大值归约¶
#include <pto/comm/pto_comm_inst.hpp>
using namespace pto;
template <typename T, int SIZE, int NRANKS>
void reduce_max(__gm__ T* group_addrs[NRANKS], __gm__ T* result, int my_rank) {
using TileT = Tile<TileType::Vec, T, 1, SIZE>;
using GTensor = GlobalTensor<T, Shape<1,1,1,1,SIZE>,
BaseShape2D<T, 1, SIZE, Layout::ND>, Layout::ND>;
GTensor tensors[NRANKS];
for (int i = 0; i < NRANKS; ++i) tensors[i] = GTensor(group_addrs[i]);
comm::ParallelGroup<GTensor> group(tensors, NRANKS, my_rank);
GTensor dstG(result);
TileT accTile, recvTile;
comm::TREDUCE(group, dstG, accTile, recvTile, comm::ReduceOp::Max);
}