MSCATTER¶
Tile Operation Diagram¶
Introduction¶
Scatter data from a Tile into a GlobalTensor (GM) using per-row or per-element indices. This custom instruction performs indexed memory writes to global memory, supporting both row-level block transfers (e.g., expert weight updates) and element-level indexed transfers (e.g., sparse gradient accumulation).
MSCATTER is implemented as a SIMT kernel using cce::async_invoke with 1024 threads (32 warps × 32 lanes). Optional atomic operations handle write conflicts when multiple source elements target the same destination.
Math Interpretation¶
Row-Indexed Scatter¶
For a table with RowWidth-sized rows, given a 1D index tile idx of length NumRows:
Each index specifies the destination row in the table where the corresponding source row is written.
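The row-indexed mapping can be sketched as a host-side reference loop. This is illustrative only — the actual instruction executes as a SIMT kernel on the NPU — and the function name is hypothetical:

```cpp
#include <vector>

// Host-side reference of row-indexed scatter semantics:
// row i of src is written to row idx[i] of the table.
void scatterRowsRef(std::vector<float>& table, int rowWidth,
                    const std::vector<float>& src, int numRows,
                    const std::vector<int>& idx) {
    for (int i = 0; i < numRows; ++i)
        for (int j = 0; j < rowWidth; ++j)
            table[idx[i] * rowWidth + j] = src[i * rowWidth + j];
}
```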
Element-Indexed Scatter¶
For per-element indexing where each source element specifies its destination:
Indices are interpreted as linear element offsets into the destination.
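Equivalently, as a host-side sketch (hypothetical helper, not the kernel itself):

```cpp
#include <cstddef>
#include <vector>

// Host-side reference of element-indexed scatter semantics:
// each source element lands at a linear element offset in the table.
void scatterElemsRef(std::vector<float>& table,
                     const std::vector<float>& src,
                     const std::vector<int>& idx) {
    for (std::size_t k = 0; k < src.size(); ++k)
        table[idx[k]] = src[k];
}
```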
Atomic Accumulation Mode¶
When ScatterAtomicOp::Add is specified, writes that target the same destination element are accumulated atomically rather than overwritten.
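The net effect of atomic accumulation in row-indexed mode can be sketched with a sequential host-side loop (the device performs the additions atomically; this sketch only shows the final result for duplicate indices):

```cpp
#include <vector>

// Net effect of ScatterAtomicOp::Add in row-indexed mode:
// rows with duplicate indices are summed into the destination row.
void scatterRowsAddRef(std::vector<float>& table, int rowWidth,
                       const std::vector<float>& src, int numRows,
                       const std::vector<int>& idx) {
    for (int i = 0; i < numRows; ++i)
        for (int j = 0; j < rowWidth; ++j)
            table[idx[i] * rowWidth + j] += src[i * rowWidth + j];
}
```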
Assembly Syntax¶
PTO-AS form: see docs/assembly/PTO-AS.md.
Row-indexed scatter:
mscatter.row %table, %src, %idx : (!pto.memref<...>, !pto.tile<NxMxT>, !pto.tile<Nx1xi32>)
Element-indexed scatter:
mscatter.elem %table, %src, %idx : (!pto.memref<...>, !pto.tile<NxMxT>, !pto.tile<NxMxi32>)
Atomic scatter:
mscatter.row.atomic_add %table, %src, %idx : (!pto.memref<...>, !pto.tile<NxMxT>, !pto.tile<Nx1xi32>)
C++ Intrinsic¶
Declared in include/pto/npu/a5/MScatter.hpp:
template <ScatterAtomicOp Atomic = ScatterAtomicOp::None, ScatterOOB Mode = ScatterOOB::Undefined,
typename GlobalTable, typename TileSrc, typename TileIdx>
PTO_INTERNAL void MSCATTER(GlobalTable& table, TileSrc& src, TileIdx& indices);
Parameters:
- table: Destination GlobalTensor in GM
- src: Source tile in UB with shape [NumRows, NumCols]
- indices: Index tile containing row indices (shape [NumRows, 1]) or element indices (shape [NumRows, NumCols])
- Atomic: Atomic operation mode for handling write conflicts (template parameter)
- Mode: Out-of-bounds handling mode (template parameter)
Atomic Types¶
The ScatterAtomicOp enum controls behavior when multiple source elements write to the same destination:
enum class ScatterAtomicOp : uint8_t {
None = 0, // Non-atomic write
Add = 1, // Atomic addition
Max = 2, // Atomic maximum
Min = 3 // Atomic minimum
};
Atomic Type Constraints¶
- AtomicNone: Available for all supported data types
- AtomicAdd: Available for `int32_t`, `uint32_t`, `float`, `half`
- AtomicMax: Available for `int32_t`, `float`
- AtomicMin: Available for `int32_t`, `float`
Out-of-Bounds Handling¶
The ScatterOOB enum controls behavior when indices exceed table bounds:
enum class ScatterOOB : uint8_t {
Undefined = 0, // No bounds check
Skip = 1, // Skip OOB writes (no memory access)
Clamp = 2, // Clamp index to valid range [0, tableSize-1]
Wrap = 3 // Index modulo tableSize (idx % tableSize)
};
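The net effect of the three checked modes can be sketched as an index remap before the write (hypothetical helper, not part of the API; handling of negative indices under Wrap is an assumption — the table above only specifies `idx % tableSize`):

```cpp
#include <cstdint>

enum class ScatterOOB : uint8_t { Undefined = 0, Skip = 1, Clamp = 2, Wrap = 3 };

// Returns false when the write should be dropped (Skip mode);
// otherwise remaps idx into [0, tableSize) per the selected mode.
bool remapIndex(int& idx, int tableSize, ScatterOOB mode) {
    if (idx >= 0 && idx < tableSize) return true;   // already in bounds
    switch (mode) {
        case ScatterOOB::Skip:  return false;       // drop the write
        case ScatterOOB::Clamp: idx = idx < 0 ? 0 : tableSize - 1; return true;
        case ScatterOOB::Wrap:  // assumed positive wrap for negative idx
            idx = ((idx % tableSize) + tableSize) % tableSize; return true;
        default:                return true;        // Undefined: no check
    }
}
```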
Constraints¶
Data Types (A5)¶
`TileSrc::DType` must be one of: `int8_t`, `uint8_t`, `int16_t`, `uint16_t`, `int32_t`, `uint32_t`, `half`, `bfloat16_t`, `float`, `float8_e4m3_t`, `float8_e5m2_t`.
Index Types¶
`TileIdx::DType` must be `int32_t` or `uint32_t`.
Tile Constraints¶
- Source tile location must be `TileType::Vec` (UB).
- Index tile location must be `TileType::Vec` (UB).
- `TileSrc::Rows == TileIdx::Rows` (matching row count).
- `TileIdx::Cols == 1` for row-indexed scatter, or `TileIdx::Cols == TileSrc::Cols` for element-indexed scatter.
- Source and table must have the same data type.
Shape Constraints¶
- For row-indexed mode: table shape dimension 3 specifies the number of rows; dimension 4 specifies the row width.
- For element-indexed mode: the table is treated as a linear array of size `Shape[3] * Shape[4]`.
Examples¶
Row-Indexed Scatter (Weight Update)¶
#include <pto/npu/a5/custom/MScatter.hpp>
using namespace pto;
template <typename T, int NumRows, int RowWidth, int TableRows>
void example_weight_update(__gm__ T* table, __gm__ int32_t* indices) {
using IdxTile = Tile<TileType::Vec, int32_t, NumRows, 1>;
using SrcTile = Tile<TileType::Vec, T, NumRows, RowWidth>;
using TableShape = Shape<1, 1, 1, TableRows, RowWidth>;
using TableStride = Stride<1, 1, 1, RowWidth, 1>;
using TableTensor = GlobalTensor<T, TableShape, TableStride, Layout::ND>;
TableTensor tableGM(table);
IdxTile idx;
SrcTile src;
TASSIGN(idx, 0x0);
TASSIGN(src, 0x1000);
// Scatter with skip mode for invalid indices
MSCATTER<ScatterAtomicOp::None, ScatterOOB::Skip>(tableGM, src, idx);
}
Atomic Gradient Accumulation¶
#include <pto/npu/a5/custom/MScatter.hpp>
using namespace pto;
void example_gradient_accumulation(__gm__ float* gradTable, __gm__ int32_t* tokenIndices) {
using IdxTile = Tile<TileType::Vec, int32_t, 16, 1>;
using GradTile = Tile<TileType::Vec, float, 16, 64>;
using TableShape = Shape<1, 1, 1, 65536, 64>;
using TableStride = Stride<1, 1, 1, 64, 1>;
using TableTensor = GlobalTensor<float, TableShape, TableStride, Layout::ND>;
TableTensor tableGM(gradTable);
IdxTile idx;
GradTile grads;
TASSIGN(idx, 0x0);
TASSIGN(grads, 0x1000);
// Atomic add for gradient accumulation
MSCATTER<ScatterAtomicOp::Add, ScatterOOB::Skip>(tableGM, grads, idx);
}
Element-Indexed Scatter (Sparse Update)¶
#include <pto/npu/a5/custom/MScatter.hpp>
using namespace pto;
void example_sparse_update(__gm__ float* data, __gm__ int32_t* sparseIndices) {
using IdxTile = Tile<TileType::Vec, int32_t, 16, 16>;
using SrcTile = Tile<TileType::Vec, float, 16, 16>;
using DataShape = Shape<1, 1, 1, 1024, 64>;
using DataStride = Stride<1, 1, 1, 64, 1>;
using DataTensor = GlobalTensor<float, DataShape, DataStride, Layout::ND>;
DataTensor dataGM(data);
IdxTile idx;
SrcTile src;
TASSIGN(idx, 0x0);
TASSIGN(src, 0x2000);
// Scatter with wrap mode
MSCATTER<ScatterAtomicOp::None, ScatterOOB::Wrap>(dataGM, src, idx);
}
Manual Memory Assignment¶
#include <pto/npu/a5/custom/MScatter.hpp>
using namespace pto;
void example_manual() {
using IdxTile = Tile<TileType::Vec, int32_t, 8, 1>;
using SrcTile = Tile<TileType::Vec, half, 8, 64>;
using TableShape = Shape<1, 1, 1, 65536, 64>;
using TableStride = Stride<1, 1, 1, 64, 1>;
using TableTensor = GlobalTensor<half, TableShape, TableStride, Layout::ND>;
__gm__ half* tablePtr = /* ... */;
TableTensor tableGM(tablePtr);
IdxTile idx;
SrcTile src;
TASSIGN(idx, 0x0);
TASSIGN(src, 0x1000);
MSCATTER<ScatterAtomicOp::None, ScatterOOB::Clamp>(tableGM, src, idx);
}
Performance Considerations¶
- Row-indexed scatter is more efficient than element-indexed when writing structured data because it enables coalesced memory access within each row.
- Atomic operations: Use `ScatterAtomicOp::None` when indices are guaranteed unique for best performance. Atomic modes add synchronization overhead.
- SIMT execution: The kernel uses 1024 threads (32 warps × 32 lanes) for parallel scatter operations.
- Out-of-bounds mode: `ScatterOOB::Undefined` is fastest but requires indices to be valid. Use `Skip`, `Clamp`, or `Wrap` when indices may exceed bounds.
- Write conflicts: When multiple source elements target the same destination with `ScatterAtomicOp::None`, the result is non-deterministic (last writer wins).
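The conflict behavior can be illustrated with a trivial host-side sketch. On the device the surviving writer under non-atomic mode is unspecified; the sequential loop below simply shows that one value survives while atomic add preserves both contributions:

```cpp
#include <vector>

// Two source elements target the same destination slot.
float conflictNonAtomic() {
    float dst = 0.f;
    std::vector<float> src = {3.f, 5.f};
    for (float v : src) dst = v;   // one write survives (order-dependent on device)
    return dst;
}

float conflictAtomicAdd() {
    float dst = 0.f;
    std::vector<float> src = {3.f, 5.f};
    for (float v : src) dst += v;  // contributions sum regardless of order
    return dst;                    // 8.f
}
```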