MSCATTER

Tile Operation Diagram

[Figure: MSCATTER tile operation]

Introduction

Scatter data from a Tile into a GlobalTensor (GM) using per-row or per-element indices. This custom instruction performs indexed memory writes to global memory, supporting both row-level block transfers (e.g., expert weight updates) and element-level indexed transfers (e.g., sparse gradient accumulation).

MSCATTER is implemented as a SIMT kernel using cce::async_invoke with 1024 threads (32 warps × 32 lanes). Optional atomic operations handle write conflicts when multiple source elements target the same destination.

Math Interpretation

Row-Indexed Scatter

For a table with RowWidth-sized rows, given a 1D index tile idx of length NumRows:

\[ \mathrm{table}_{\mathrm{idx}_{i},\, j} = \mathrm{src}_{i,j} \quad \text{for } 0 \le i < \text{NumRows},\; 0 \le j < \text{RowWidth} \]

Each index specifies the destination row in the table where the corresponding source row is written.
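The row-indexed semantics can be expressed as a scalar reference loop. This is a minimal sketch of the math above, not the SIMT kernel; the dimensions and the function name scatter_row_ref are illustrative only:

```cpp
#include <array>
#include <cstdint>

// Scalar reference semantics of row-indexed scatter:
// table[idx_i, j] = src[i, j] for every source row i and column j.
template <int NumRows, int RowWidth, int TableRows>
void scatter_row_ref(std::array<std::array<float, RowWidth>, TableRows>& table,
                     const std::array<std::array<float, RowWidth>, NumRows>& src,
                     const std::array<int32_t, NumRows>& idx) {
    for (int i = 0; i < NumRows; ++i)       // each source row
        for (int j = 0; j < RowWidth; ++j)  // each element within the row
            table[idx[i]][j] = src[i][j];   // write the whole row at row idx[i]
}
```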

Element-Indexed Scatter

For per-element indexing where each source element specifies its destination:

\[ \mathrm{table}[\mathrm{idx}_{i,j}] = \mathrm{src}_{i,j} \]

Indices are interpreted as linear element offsets into the destination.
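A scalar sketch of the element-indexed semantics, with the destination flattened to a linear array as described above (scatter_elem_ref is an illustrative name, not part of the API):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference semantics of element-indexed scatter: each source element carries
// its own linear offset into the destination, i.e. table[idx_{i,j}] = src_{i,j}.
void scatter_elem_ref(std::vector<float>& table,
                      const std::vector<float>& src,
                      const std::vector<int32_t>& idx) {
    for (std::size_t k = 0; k < src.size(); ++k)  // k enumerates (i, j) row-major
        table[idx[k]] = src[k];
}
```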

Atomic Accumulation Mode

When ScatterAtomicOp::Add is specified:

\[ \mathrm{table}[\mathrm{idx}_{i,j}] \mathrel{+}= \mathrm{src}_{i,j} \]
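A sequential model of the atomic-add mode makes the conflict behavior concrete: duplicate indices accumulate instead of overwriting. This sketch models the result of the atomics, not their hardware implementation:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sequential model of ScatterAtomicOp::Add: table[idx_{i,j}] += src_{i,j}.
// Because addition is commutative, the final result is independent of the
// order in which conflicting threads commit their writes.
void scatter_add_ref(std::vector<float>& table,
                     const std::vector<float>& src,
                     const std::vector<int32_t>& idx) {
    for (std::size_t k = 0; k < src.size(); ++k)
        table[idx[k]] += src[k];  // duplicates accumulate rather than clobber
}
```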

Assembly Syntax

PTO-AS form: see docs/assembly/PTO-AS.md.

Row-indexed scatter:

mscatter.row %table, %src, %idx : (!pto.memref<...>, !pto.tile<NxMxT>, !pto.tile<Nx1xi32>)

Element-indexed scatter:

mscatter.elem %table, %src, %idx : (!pto.memref<...>, !pto.tile<NxMxT>, !pto.tile<NxMxi32>)

Atomic scatter:

mscatter.row.atomic_add %table, %src, %idx : (!pto.memref<...>, !pto.tile<NxMxT>, !pto.tile<Nx1xi32>)

C++ Intrinsic

Declared in include/pto/npu/a5/MScatter.hpp:

template <ScatterAtomicOp Atomic = ScatterAtomicOp::None, ScatterOOB Mode = ScatterOOB::Undefined,
          typename GlobalTable, typename TileSrc, typename TileIdx>
PTO_INTERNAL void MSCATTER(GlobalTable& table, TileSrc& src, TileIdx& indices);

Parameters:

  • table: Destination GlobalTensor in GM.
  • src: Source tile in UB with shape [NumRows, NumCols].
  • indices: Index tile containing row indices (shape [NumRows, 1]) or element indices (shape [NumRows, NumCols]).
  • Atomic: Atomic operation mode for handling write conflicts (template parameter).
  • Mode: Out-of-bounds handling mode (template parameter).

Atomic Types

The ScatterAtomicOp enum controls behavior when multiple source elements write to the same destination:

enum class ScatterAtomicOp : uint8_t {
    None = 0,  // Non-atomic write
    Add  = 1,  // Atomic addition
    Max  = 2,  // Atomic maximum
    Min  = 3   // Atomic minimum
};

Atomic Type Constraints

  • ScatterAtomicOp::None: Available for all supported data types
  • ScatterAtomicOp::Add: Available for int32_t, uint32_t, float, half
  • ScatterAtomicOp::Max: Available for int32_t, float
  • ScatterAtomicOp::Min: Available for int32_t, float

Out-of-Bounds Handling

The ScatterOOB enum controls behavior when indices exceed table bounds:

enum class ScatterOOB : uint8_t {
    Undefined = 0,  // No bounds check
    Skip      = 1,  // Skip OOB writes (no memory access)
    Clamp     = 2,  // Clamp index to valid range [0, tableSize-1]
    Wrap      = 3   // Index modulo tableSize (idx % tableSize)
};
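How each mode maps a raw index to an effective index can be sketched as a standalone helper (resolve_index is illustrative, not part of the API; it returns -1 to signal a skipped write):

```cpp
#include <cstdint>

enum class ScatterOOB : uint8_t { Undefined = 0, Skip = 1, Clamp = 2, Wrap = 3 };

// Sequential model of the per-index bounds handling. The doc specifies Wrap as
// idx % tableSize; the extra + tableSize term below is a defensive choice so
// the result stays non-negative even for a negative int32_t index.
int32_t resolve_index(int32_t idx, int32_t tableSize, ScatterOOB mode) {
    const bool oob = idx < 0 || idx >= tableSize;
    switch (mode) {
        case ScatterOOB::Skip:
            return oob ? -1 : idx;  // drop OOB writes entirely
        case ScatterOOB::Clamp:
            return idx < 0 ? 0 : (idx >= tableSize ? tableSize - 1 : idx);
        case ScatterOOB::Wrap:
            return ((idx % tableSize) + tableSize) % tableSize;
        default:
            return idx;  // Undefined: no bounds check, caller must be valid
    }
}
```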

Constraints

Data Types (A5)

  • TileSrc::DType must be one of: int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, half, bfloat16_t, float, float8_e4m3_t, float8_e5m2_t.

Index Types

  • TileIdx::DType must be int32_t or uint32_t.

Tile Constraints

  • Source tile location must be TileType::Vec (UB).
  • Index tile location must be TileType::Vec (UB).
  • TileSrc::Rows == TileIdx::Rows (matching row count).
  • TileIdx::Cols == 1 for row-indexed scatter, or TileIdx::Cols == TileSrc::Cols for element-indexed scatter.
  • Source and table must have the same data type.

Shape Constraints

  • For row-indexed mode: Table shape dimension 3 specifies number of rows, dimension 4 specifies row width.
  • For element-indexed mode: Table is treated as a linear array of size Shape[3] * Shape[4].

Examples

Row-Indexed Scatter (Weight Update)

#include <pto/npu/a5/custom/MScatter.hpp>

using namespace pto;

template <typename T, int NumRows, int RowWidth, int TableRows>
void example_weight_update(__gm__ T* table, __gm__ int32_t* indices) {
    using IdxTile = Tile<TileType::Vec, int32_t, NumRows, 1>;
    using SrcTile = Tile<TileType::Vec, T, NumRows, RowWidth>;
    using TableShape = Shape<1, 1, 1, TableRows, RowWidth>;
    using TableStride = Stride<1, 1, 1, RowWidth, 1>;
    using TableTensor = GlobalTensor<T, TableShape, TableStride, Layout::ND>;

    TableTensor tableGM(table);
    IdxTile idx;
    SrcTile src;

    TASSIGN(idx, 0x0);
    TASSIGN(src, 0x1000);

    // Scatter with skip mode for invalid indices
    MSCATTER<ScatterAtomicOp::None, ScatterOOB::Skip>(tableGM, src, idx);
}

Atomic Gradient Accumulation

#include <pto/npu/a5/custom/MScatter.hpp>

using namespace pto;

void example_gradient_accumulation(__gm__ float* gradTable, __gm__ int32_t* tokenIndices) {
    using IdxTile = Tile<TileType::Vec, int32_t, 16, 1>;
    using GradTile = Tile<TileType::Vec, float, 16, 64>;
    using TableShape = Shape<1, 1, 1, 65536, 64>;
    using TableStride = Stride<1, 1, 1, 64, 1>;
    using TableTensor = GlobalTensor<float, TableShape, TableStride, Layout::ND>;

    TableTensor tableGM(gradTable);
    IdxTile idx;
    GradTile grads;

    TASSIGN(idx, 0x0);
    TASSIGN(grads, 0x1000);

    // Atomic add for gradient accumulation
    MSCATTER<ScatterAtomicOp::Add, ScatterOOB::Skip>(tableGM, grads, idx);
}

Element-Indexed Scatter (Sparse Update)

#include <pto/npu/a5/custom/MScatter.hpp>

using namespace pto;

void example_sparse_update(__gm__ float* data, __gm__ int32_t* sparseIndices) {
    using IdxTile = Tile<TileType::Vec, int32_t, 16, 16>;
    using SrcTile = Tile<TileType::Vec, float, 16, 16>;
    using DataShape = Shape<1, 1, 1, 1024, 64>;
    using DataStride = Stride<1, 1, 1, 64, 1>;
    using DataTensor = GlobalTensor<float, DataShape, DataStride, Layout::ND>;

    DataTensor dataGM(data);
    IdxTile idx;
    SrcTile src;

    TASSIGN(idx, 0x0);
    TASSIGN(src, 0x2000);

    // Scatter with wrap mode
    MSCATTER<ScatterAtomicOp::None, ScatterOOB::Wrap>(dataGM, src, idx);
}

Manual Memory Assignment

#include <pto/npu/a5/custom/MScatter.hpp>

using namespace pto;

void example_manual() {
    using IdxTile = Tile<TileType::Vec, int32_t, 8, 1>;
    using SrcTile = Tile<TileType::Vec, half, 8, 64>;
    using TableShape = Shape<1, 1, 1, 65536, 64>;
    using TableStride = Stride<1, 1, 1, 64, 1>;
    using TableTensor = GlobalTensor<half, TableShape, TableStride, Layout::ND>;

    __gm__ half* tablePtr = /* ... */;
    TableTensor tableGM(tablePtr);

    IdxTile idx;
    SrcTile src;

    TASSIGN(idx, 0x0);
    TASSIGN(src, 0x1000);

    MSCATTER<ScatterAtomicOp::None, ScatterOOB::Clamp>(tableGM, src, idx);
}

Performance Considerations

  1. Row-indexed scatter is more efficient than element-indexed when writing structured data because it enables coalesced memory access within each row.

  2. Atomic operations: Use ScatterAtomicOp::None when indices are guaranteed unique for best performance. Atomic modes add synchronization overhead.

  3. SIMT execution: The kernel uses 1024 threads (32 warps × 32 lanes) for parallel scatter operations.

  4. Out-of-bounds mode: ScatterOOB::Undefined is fastest but requires indices to be valid. Use Skip, Clamp, or Wrap when indices may exceed bounds.

  5. Write conflicts: When multiple source elements target the same destination with ScatterAtomicOp::None, the result is non-deterministic (last writer wins).

See Also

  • TSTORE: Contiguous block transfer from Tile to GM
  • TSCATTER: Index-based scatter within tiles (UB-to-UB)
  • MGATHER: Indexed gather from GM to Tile (inverse operation)