TLOAD
Introduction
Load data from a GlobalTensor in global memory (GM) into a Tile.
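Math Interpretation
The exact index mapping depends on the GlobalTensor shape and stride and on the Tile layout. Conceptually, in a 2D view with base offset \((r_0, c_0)\):
\[ \mathrm{dst}_{i,j} = \mathrm{src}_{r_0 + i,\; c_0 + j}, \qquad 0 \le i < R_v,\; 0 \le j < C_v \]
where \(R_v\) and \(C_v\) are dst.GetValidRow() and dst.GetValidCol(), the valid region that is used as the transfer size.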
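To make the 2D mapping concrete, here is a minimal host-side reference model in plain C++. It assumes a row-major (ND) source with a row stride given in elements and a densely packed destination; the function name tload_reference and its parameters are hypothetical, not part of the PTO API, and the model ignores layout conversions and padding.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference model of dst[i][j] = src[r0 + i][c0 + j].
// Not the PTO API: src is assumed to be a row-major (ND) 2D view with a
// row stride of `srcStride` elements; dst is densely packed, validRow x validCol.
template <typename T>
std::vector<T> tload_reference(const std::vector<T>& src, std::size_t srcStride,
                               std::size_t r0, std::size_t c0,
                               std::size_t validRow, std::size_t validCol) {
    // Mirrors the documented runtime requirement that valid extents be > 0.
    assert(validRow > 0 && validCol > 0);
    // The window must fit inside each source row, and (conservatively) inside src.
    assert(c0 + validCol <= srcStride);
    assert((r0 + validRow) * srcStride <= src.size());

    std::vector<T> dst(validRow * validCol);
    for (std::size_t i = 0; i < validRow; ++i) {
        for (std::size_t j = 0; j < validCol; ++j) {
            // Tile element (i, j) comes from source element (r0 + i, c0 + j).
            dst[i * validCol + j] = src[(r0 + i) * srcStride + (c0 + j)];
        }
    }
    return dst;
}
```

For a 4x4 source holding the values 0..15, tload_reference(src, 4, 1, 1, 2, 3) returns the 2x3 window {5, 6, 7, 9, 10, 11}: rows 1 and 2, columns 1 through 3.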
Assembly Syntax
PTO-AS form: see PTO-AS Specification.
Synchronous form:
%t0 = tload %sv[%c0, %c0] : (!pto.memref<...>, index, index) -> !pto.tile<...>
IR Level 1 (SSA)
%dst = pto.tload %mem : !pto.partition_tensor_view<MxNxdtype> ->
!pto.tile<loc, dtype, rows, cols, blayout, slayout, fractal, pad>
IR Level 2 (DPS)
pto.tload ins(%mem : !pto.partition_tensor_view<MxNxdtype>) outs(%dst : !pto.tile_buf<...>)
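The two IR levels differ in the usual SSA versus destination-passing style (DPS) way: at Level 1 the load yields a new tile value, while at Level 2 the caller names the destination buffer in outs(...). The C++ analogy below is only an illustration; Tile4, load_ssa and load_dps are hypothetical names, not PTO functions.

```cpp
#include <array>
#include <cassert>

// Hypothetical 4-element tile type used only for this analogy.
using Tile4 = std::array<float, 4>;

// SSA style (like IR Level 1): the load returns a fresh tile value.
Tile4 load_ssa(const float* mem) {
    assert(mem != nullptr);
    Tile4 t{};
    for (int k = 0; k < 4; ++k) t[k] = mem[k];
    return t;
}

// DPS style (like IR Level 2): the caller supplies the destination buffer,
// similar in spirit to outs(%dst : !pto.tile_buf<...>).
void load_dps(const float* mem, Tile4& dst) {
    assert(mem != nullptr);
    for (int k = 0; k < 4; ++k) dst[k] = mem[k];
}
```

Both functions produce the same contents; they differ only in who owns the destination. The C++ intrinsic TLOAD(TileData& dst, ...) takes the destination tile by reference, which is the DPS shape.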
C++ Intrinsic
Declared in include/pto/common/pto_instr.hpp:
template <typename TileData, typename GlobalData, typename... WaitEvents>
PTO_INST RecordEvent TLOAD(TileData& dst, GlobalData& src, WaitEvents&... events);
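Judging only from the parameter and return-type names, TLOAD appears to wait on any events passed as trailing arguments and to return a RecordEvent that later instructions can wait on. The sketch below models that assumed pattern in plain C++17; Event, IssueOp and issued are hypothetical stand-ins, not the PTO types, and the model does not describe PTO's actual synchronization semantics.

```cpp
#include <cassert>
#include <vector>

// Hypothetical event token; NOT pto::RecordEvent.
struct Event {
    bool done = false;
    // In this model, waiting on an event that never completed is a logic error.
    void Wait() const { assert(done); }
};

// Log of issued operation ids, for illustration only.
std::vector<int> issued;

// Sketch of the variadic-wait pattern seen in the TLOAD declaration:
// wait on every event passed in, issue the operation, then return an
// event that downstream operations can wait on.
template <typename... WaitEvents>
Event IssueOp(int opId, WaitEvents&... events) {
    (events.Wait(), ...);    // C++17 fold: wait on each event in order
    issued.push_back(opId);  // "issue" the operation
    return Event{true};      // completion token (completes immediately here)
}
```

In this model, IssueOp(2, e0, e1) is issued only after both e0 and e1 have completed; a variadic parameter pack lets a single call express an arbitrary number of such dependencies.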
Constraints
- Implementation checks (A2A3):
  - TileData::DType must be one of: int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, int64_t, uint64_t, half, bfloat16_t, float.
  - The destination tile location must be TileType::Vec or TileType::Mat.
  - sizeof(TileData::DType) == sizeof(GlobalData::DType).
  - Runtime: all src.GetShape(dim) values and dst.GetValidRow() / dst.GetValidCol() must be > 0.
  - TileType::Vec loads support only matching layouts: ND->ND, DN->DN, NZ->NZ.
  - TileType::Mat loads support ND->ND, DN->DN, NZ->NZ, plus ND->NZ and DN->ZN.
    - For ND->NZ or DN->ZN: GlobalData::staticShape[0..2] == 1 and TileData::SFractalSize == 512.
  - For int64_t / uint64_t, only ND->ND or DN->DN are supported.
- Implementation checks (A5):
  - sizeof(TileData::DType) must be 1, 2, 4, or 8 bytes and must match sizeof(GlobalData::DType).
  - For int64_t / uint64_t, TileData::PadVal must be PadValue::Null or PadValue::Zero.
  - TileType::Vec loads require one of the following layout pairs:
    - ND with row-major + SLayout::NoneBox (ND->ND),
    - DN with col-major + SLayout::NoneBox (DN->DN),
    - NZ with SLayout::RowMajor (NZ->NZ).
  - For row-major ND->ND with compile-time-known shapes, TileData::ValidCol must equal GlobalData::staticShape[4], and TileData::ValidRow must equal the product of GlobalData::staticShape[0..3].
  - TileType::Mat loads are additionally constrained by TLoadCubeCheck (e.g., only specific ND/DN/NZ conversions and L1-size limits).
  - TileType::Mat loads also handle MX-format loads: MX_A_ZZ / MX_A_ND / MX_A_DN to ZZ for scaleA, and MX_B_NN / MX_B_ND / MX_B_DN to NN for scaleB.
    - For MX_A_ZZ / MX_B_NN: GlobalData::staticShape[3] == 16 and GlobalData::staticShape[4] == 2.
    - For MX_A_ND / MX_A_DN / MX_B_ND / MX_B_DN: GlobalData::staticShape[0] == 1, GlobalData::staticShape[1] == 1, and GlobalData::staticShape[4] == 2.
    - For scaleA: dst.GetValidCol() % 2 == 0.
    - For scaleB: dst.GetValidRow() % 2 == 0.
- Valid region:
  - The implementation uses dst.GetValidRow() / dst.GetValidCol() as the transfer size.
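Several of the rules above are pure predicates over element sizes, layouts, and static shapes, so they can be written as constexpr functions and checked with static_assert. The sketch below is hypothetical: TileKind, Fmt, Mx, Shape5 and every function name are invented for illustration, element types are reduced to their byte size, and only a subset of the listed checks is encoded (it omits, among others, the SLayout pairs, PadVal, and TLoadCubeCheck).

```cpp
#include <array>
#include <cstddef>

// Hypothetical enums echoing names from the constraint list; NOT the PTO types.
enum class TileKind { Vec, Mat };
enum class Fmt { ND, DN, NZ, ZN };
enum class Mx { A_ZZ, A_ND, A_DN, B_NN, B_ND, B_DN };
using Shape5 = std::array<std::size_t, 5>;  // stands in for GlobalData::staticShape

// A2A3: Vec needs a matching ND/DN/NZ layout; Mat also allows ND->NZ and
// DN->ZN; 8-byte element types (int64_t/uint64_t) allow only ND->ND or DN->DN.
constexpr bool A2A3LayoutOk(TileKind kind, Fmt src, Fmt dst, std::size_t elemBytes) {
    const bool matching = src == dst && src != Fmt::ZN;
    if (elemBytes == 8) return src == dst && (src == Fmt::ND || src == Fmt::DN);
    if (kind == TileKind::Vec) return matching;
    return matching || (src == Fmt::ND && dst == Fmt::NZ) ||
           (src == Fmt::DN && dst == Fmt::ZN);
}

// A2A3 ND->NZ / DN->ZN: the leading three dims are 1 and SFractalSize is 512.
constexpr bool A2A3FractalConvOk(Shape5 s, std::size_t sFractalSize) {
    return s[0] == 1 && s[1] == 1 && s[2] == 1 && sFractalSize == 512;
}

// A5: element size is 1, 2, 4 or 8 bytes and identical on both sides.
constexpr bool A5ElemSizeOk(std::size_t tileBytes, std::size_t gmBytes) {
    return tileBytes == gmBytes &&
           (tileBytes == 1 || tileBytes == 2 || tileBytes == 4 || tileBytes == 8);
}

// A5 row-major ND->ND with static shapes:
// ValidCol == s[4] and ValidRow == s[0] * s[1] * s[2] * s[3].
constexpr bool A5NdShapeOk(Shape5 s, std::size_t validRow, std::size_t validCol) {
    return validCol == s[4] && validRow == s[0] * s[1] * s[2] * s[3];
}

// A5 MX loads: the shape rule depends on the source format; scaleA (MX_A_*)
// needs an even valid column count, scaleB (MX_B_*) an even valid row count.
constexpr bool A5MxOk(Mx kind, Shape5 s, std::size_t validRow, std::size_t validCol) {
    const bool zzOrNn = kind == Mx::A_ZZ || kind == Mx::B_NN;
    const bool shapeOk = zzOrNn ? (s[3] == 16 && s[4] == 2)
                                : (s[0] == 1 && s[1] == 1 && s[4] == 2);
    const bool isA = kind == Mx::A_ZZ || kind == Mx::A_ND || kind == Mx::A_DN;
    return shapeOk && (isA ? validCol % 2 == 0 : validRow % 2 == 0);
}
```

For example, A2A3LayoutOk(TileKind::Mat, Fmt::DN, Fmt::ZN, 2) holds, while the same conversion with an 8-byte element type does not, matching the int64_t / uint64_t restriction.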
Examples
Auto
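#include <pto/pto-inst.hpp>
using namespace pto;

template <typename T>
void example_auto(__gm__ T* in) {
    // 16x16 vector tile of T.
    using TileT = Tile<TileType::Vec, T, 16, 16>;
    // GM view over `in`: a 5D shape whose trailing 16x16 block is stored as ND (row-major).
    using GShape = Shape<1, 1, 1, 16, 16>;
    using GStride = BaseShape2D<T, 16, 16, Layout::ND>;
    using GTensor = GlobalTensor<T, GShape, GStride, Layout::ND>;
    GTensor gin(in);
    TileT t;
    // Auto mode: no TASSIGN, so tile placement is left to the compiler/runtime.
    TLOAD(t, gin);
}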
Manual
#include <pto/pto-inst.hpp>
using namespace pto;

template <typename T>
void example_manual(__gm__ T* in) {
    // Same 16x16 vector tile and ND global view as in the auto example.
    using TileT = Tile<TileType::Vec, T, 16, 16>;
    using GShape = Shape<1, 1, 1, 16, 16>;
    using GStride = BaseShape2D<T, 16, 16, Layout::ND>;
    using GTensor = GlobalTensor<T, GShape, GStride, Layout::ND>;
    GTensor gin(in);
    TileT t;
    // Manual mode: bind the tile to address 0x1000 explicitly before the load.
    TASSIGN(t, 0x1000);
    TLOAD(t, gin);
}
ASM Form Examples
Auto Mode
# Auto mode: compiler/runtime-managed placement and scheduling.
%dst = pto.tload %mem : !pto.partition_tensor_view<MxNxdtype> ->
!pto.tile<loc, dtype, rows, cols, blayout, slayout, fractal, pad>
Manual Mode
# Manual mode: bind resources explicitly before issuing the instruction.
# Optional for tile operands:
# pto.tassign %arg0, @tile(0x1000)
# pto.tassign %arg1, @tile(0x2000)
%dst = pto.tload %mem : !pto.partition_tensor_view<MxNxdtype> ->
!pto.tile<loc, dtype, rows, cols, blayout, slayout, fractal, pad>
PTO Assembly Form
%t0 = tload %sv[%c0, %c0] : (!pto.memref<...>, index, index) -> !pto.tile<...>
# IR Level 2 (DPS)
pto.tload ins(%mem : !pto.partition_tensor_view<MxNxdtype>) outs(%dst : !pto.tile_buf<...>)