TBROADCAST¶
简介¶
将当前 NPU 的数据广播到并行组中所有 rank。调用方 NPU 为根节点,其数据将被复制到所有其他 NPU。
只有根节点需要执行 TBROADCAST。非根节点只需确保在操作期间其目标缓冲区已分配且可写。在非根节点上调用 TBROADCAST 属于未定义行为。
大 Tile 支持:当 GlobalTensor 在行和/或列方向超出 UB(统一缓冲区)Tile 容量时,传输将通过二维滑动自动分块。
数学语义¶
操作完成后:
\[ \mathrm{dst}^{(k)}_{i,j} = \mathrm{src}^{(\text{root})}_{i,j} \quad \forall k \in [0, N) \]
其中 \(N\) 为 rank 总数,root 为调用方 NPU。
汇编语法¶
PTO-AS 形式:参见 PTO-AS 规范。
同步形式:
tbroadcast %group, %src : (!pto.group<...>, !pto.memref<...>)
降级时会为 GM→UB→GM 数据路径引入 UB 暂存 Tile;C++ 内建接口需要显式传入 stagingTileData(或 pingTile / pongTile)操作数。
C++ 内建接口¶
声明于 include/pto/comm/pto_comm_inst.hpp:
// 基础广播(单暂存 Tile)
template <typename ParallelGroupType, typename GlobalSrcData, typename TileData, typename... WaitEvents>
PTO_INST RecordEvent TBROADCAST(ParallelGroupType ¶llelGroup, GlobalSrcData &srcGlobalData,
TileData &stagingTileData, WaitEvents&... events);
// 乒乓广播(使用两个暂存 Tile 实现双缓冲)
template <typename ParallelGroupType, typename GlobalSrcData, typename TileData, typename... WaitEvents>
PTO_INST RecordEvent TBROADCAST(ParallelGroupType ¶llelGroup, GlobalSrcData &srcGlobalData,
TileData &pingTile, TileData &pongTile, WaitEvents&... events);
约束¶
- 类型约束:
ParallelGroup::value_type::RawDType必须等于GlobalSrcData::RawDType。TileData::DType必须等于GlobalSrcData::RawDType。
- 内存约束:
srcGlobalData必须指向本地内存(当前 NPU)。stagingTileData(或pingTile/pongTile)必须预先在 UB 中分配。
- ParallelGroup 约束:
parallelGroup.tensors[k]必须指向 rankk的目标缓冲区(从根节点视角看到的远端 GM)。parallelGroup.GetRootIdx()标识调用方 NPU 为广播根节点。- 所有目标 tensor 假定具有相同的形状和步幅。
- 分块模式约束(数据超出单个 UB Tile 时):
- 若
TileData具有静态ValidRow,则GetShape(DIM_3)必须能被ValidRow整除。如需支持不足一行的情况,请使用DYNAMICValidRow 的 Tile。 - 若
TileData具有静态ValidCol,则GetShape(DIM_4)必须能被ValidCol整除。如需支持不足一列的情况,请使用DYNAMICValidCol 的 Tile。
- 若
示例¶
基础广播¶
#include <pto/comm/pto_comm_inst.hpp>
using namespace pto;
template <typename T, int ROWS, int COLS, int TILE_ROWS, int TILE_COLS, int NRANKS>
void broadcast(__gm__ T* group_addrs[NRANKS], __gm__ T* my_data, int my_rank) {
// Tile 维度可以与 tensor 维度不同。
// 二维滑动分块路径会自动在行和列两个方向进行分块。
using TileT = Tile<TileType::Vec, T, TILE_ROWS, TILE_COLS, BLayout::RowMajor, -1, -1>;
using GTensor = GlobalTensor<T, Shape<1,1,1,ROWS,COLS>,
BaseShape2D<T, ROWS, COLS, Layout::ND>, Layout::ND>;
GTensor tensors[NRANKS];
for (int i = 0; i < NRANKS; ++i) {
tensors[i] = GTensor(group_addrs[i]);
}
comm::ParallelGroup<GTensor> group(tensors, NRANKS, my_rank);
GTensor srcG(my_data);
TileT stagingTile(TILE_ROWS, TILE_COLS);
// 当前 NPU 将自身数据广播到所有其他 NPU
comm::TBROADCAST(group, srcG, stagingTile);
}
乒乓广播(双缓冲)¶
使用两个 UB Tile,将下一块的 TLOAD 与当前块的 TSTORE 重叠执行。
#include <pto/comm/pto_comm_inst.hpp>
using namespace pto;
template <typename T, int ROWS, int COLS, int TILE_ROWS, int TILE_COLS, int NRANKS>
void broadcast_pingpong(__gm__ T* group_addrs[NRANKS], __gm__ T* my_data, int my_rank) {
using TileT = Tile<TileType::Vec, T, TILE_ROWS, TILE_COLS, BLayout::RowMajor, -1, -1>;
using GPerRank = GlobalTensor<T, Shape<1,1,1,ROWS,COLS>,
BaseShape2D<T, ROWS, COLS, Layout::ND>, Layout::ND>;
GPerRank tensors[NRANKS];
for (int i = 0; i < NRANKS; ++i) {
tensors[i] = GPerRank(group_addrs[i]);
}
comm::ParallelGroup<GPerRank> group(tensors, NRANKS, my_rank);
GPerRank srcG(my_data);
TileT pingTile(TILE_ROWS, TILE_COLS);
TileT pongTile(TILE_ROWS, TILE_COLS);
// 乒乓模式:将 TLOAD 与 TSTORE 重叠执行以提升吞吐量
comm::TBROADCAST(group, srcG, pingTile, pongTile);
}