算子融合技术

本文档深入介绍 PTO 算子融合技术,帮助开发者通过融合多个算子减少内存访问,提升整体性能。

目录


1. 算子融合概述

1.1 什么是算子融合

定义:将多个独立的算子合并为一个算子,在片上内存中完成所有计算,减少中间结果的 GM 存储和读取。

核心思想

传统方式:
  Kernel1: GM → L1 → Compute → L1 → GM
  Kernel2: GM → L1 → Compute → L1 → GM
  Kernel3: GM → L1 → Compute → L1 → GM

融合方式:
  FusedKernel: GM → L1 → Compute1 → Compute2 → Compute3 → L1 → GM

1.2 融合的优势

优势1:减少内存访问

示例:Add + ReLU + Mul

// 融合前:3 个独立算子
y = Add(x, bias);      // Load x, Store y
z = ReLU(y);           // Load y, Store z
out = Mul(z, scale);   // Load z, Store out

// 内存访问统计:
// - Load: 3 次(x, y, z)
// - Store: 3 次(y, z, out)
// - 总计: 6 次 GM 访问

// 融合后:1 个融合算子
out = FusedAddReLUMul(x, bias, scale);

// 内存访问统计:
// - Load: 1 次(x)
// - Store: 1 次(out)
// - 总计: 2 次 GM 访问

// 内存访问减少:(6 - 2) / 6 = 67%

优势2:减少内核启动开销

内核启动开销: - 每次内核启动:~10-50 μs - 3 个独立内核:30-150 μs - 1 个融合内核:10-50 μs - 节省:20-100 μs

优势3:提高数据局部性

缓存命中率提升

融合前:
  - 中间结果写回 GM
  - 可能被其他核心的数据驱逐出缓存
  - 下一个算子重新加载(缓存未命中)

融合后:
  - 中间结果保持在 L1
  - 无缓存驱逐
  - 100% 缓存命中

1.3 融合的挑战

挑战1:片上内存限制

// 问题:融合后可能超出 L1 容量
// A2/A3: L1 ~512 KB/核
// A5: L1 ~1 MB/核

// 示例:3 个算子各需 200 KB
// 融合前:每个算子独立运行,200 KB < 512 KB ✓
// 融合后:需要 600 KB > 512 KB ✗

挑战2:数据依赖复杂

// 问题:中间结果被多次使用
y = Add(x, bias);
z1 = ReLU(y);    // 使用 y
z2 = Sigmoid(y); // 再次使用 y
// 无法简单融合,需要保留 y

挑战3:计算密集型算子

// 问题:计算时间远大于内存访问时间
// GEMM: 计算密集型
// - 计算时间:100 ms
// - 内存访问时间:10 ms
// - 融合收益:< 10%(不值得)

2. 融合模式分类

2.1 逐元素融合(Element-wise Fusion)

特点: - 所有算子都是逐元素操作 - 无数据依赖 - 融合最简单,收益最大

常见模式

// 模式1:Add + ReLU
out = ReLU(Add(x, bias));

// 模式2:Add + ReLU + Mul
out = Mul(ReLU(Add(x, bias)), scale);

// 模式3:Add + BatchNorm + ReLU
out = ReLU(BatchNorm(Add(x, bias)));

// 模式4:Mul + Add + Sigmoid
out = Sigmoid(Add(Mul(x, scale), bias));

实现示例

__global__ __aicore__ void FusedAddReLUMul(
    __gm__ float* out,
    __gm__ const float* in,
    float bias,
    float scale,
    uint32_t length) {

  int block_idx = get_block_idx();
  int block_num = get_block_num();

  // 计算当前核心负责的数据范围
  int elements_per_block = (length + block_num - 1) / block_num;
  int start = block_idx * elements_per_block;
  int end = min(start + elements_per_block, length);

  using TileT = Tile<TileType::Vec, float, 16, 256>;

  for (int i = start; i < end; i += 16 * 256) {
    int size = min(16 * 256, end - i);

    TileT tile;

    // 加载数据
    TLOAD(tile, GlobalTensor(in + i));

    // 融合计算:Add + ReLU + Mul
    TADDS(tile, tile, bias);    // Add
    TRELU(tile, tile);          // ReLU
    TMULS(tile, tile, scale);   // Mul

    // 存储结果
    TSTORE(GlobalTensor(out + i), tile);
  }
}

性能分析

数据量:1M 元素(4 MB)
平台:A3(24 核)

融合前:
- Add: 0.05 ms
- ReLU: 0.05 ms
- Mul: 0.05 ms
- 总计: 0.15 ms

融合后:
- FusedAddReLUMul: 0.05 ms

加速比:3×

2.2 归约融合(Reduction Fusion)

特点: - 包含归约操作(sum, max, min) - 需要保留归约结果 - 融合中等复杂度

常见模式

// 模式1:Softmax
// max → sub → exp → sum → div
out = exp(x - max(x)) / sum(exp(x - max(x)))

// 模式2:LayerNorm
// mean → sub → square → mean → sqrt → div
out = (x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps)

// 模式3:RMSNorm
// square → mean → sqrt → div
out = x / sqrt(mean(x^2) + eps)

Softmax 融合实现

__global__ __aicore__ void FusedSoftmax(
    __gm__ float* out,
    __gm__ const float* in,
    int rows,
    int cols) {

  int block_idx = get_block_idx();

  // 每个核心处理一行
  if (block_idx >= rows) return;

  using TileVec = Tile<TileType::Vec, float, 1, 256>;
  using TileScalar = Tile<TileType::Vec, float, 1, 1>;

  TileVec input, shifted, exp_vals, output;
  TileScalar max_val, sum_val;

  for (int col = 0; col < cols; col += 256) {
    int size = min(256, cols - col);

    // 加载输入
    TLOAD(input, in[block_idx * cols + col : size]);

    // 步骤1:计算最大值
    TROWMAX(max_val, input);

    // 步骤2:减去最大值(数值稳定)
    TROWEXPANDSUB(shifted, input, max_val);

    // 步骤3:计算指数
    TEXP(exp_vals, shifted);

    // 步骤4:计算和
    TROWSUM(sum_val, exp_vals);

    // 步骤5:归一化
    TROWEXPANDDIV(output, exp_vals, sum_val);

    // 存储结果
    TSTORE(out[block_idx * cols + col : size], output);
  }
}

性能分析

数据量:1024 × 1024(4 MB)
平台:A3(24 核)

融合前(5 个独立算子):
- RowMax: 0.08 ms
- Sub: 0.05 ms
- Exp: 0.12 ms
- RowSum: 0.08 ms
- Div: 0.05 ms
- 总计: 0.38 ms

融合后:
- FusedSoftmax: 0.15 ms

加速比:2.5×

2.3 矩阵融合(Matrix Fusion)

特点: - 包含矩阵乘法 - 融合后处理(Bias, Activation) - 高性能收益

常见模式

// 模式1:GEMM + Bias
out = MatMul(A, B) + bias

// 模式2:GEMM + Bias + ReLU
out = ReLU(MatMul(A, B) + bias)

// 模式3:GEMM + Bias + GELU
out = GELU(MatMul(A, B) + bias)

// 模式4:GEMM + Residual + LayerNorm
out = LayerNorm(MatMul(A, B) + residual)

GEMM + Bias + ReLU 融合实现

__global__ __aicore__ void FusedGEMMBiasReLU(
    __gm__ float* C,
    __gm__ const float* A,
    __gm__ const float* B,
    __gm__ const float* bias,
    int M, int K, int N) {

  int block_idx = get_block_idx();

  // 2D 划分
  int blocks_n = (N + TILE_N - 1) / TILE_N;
  int block_m = block_idx / blocks_n;
  int block_n = block_idx % blocks_n;

  int m_start = block_m * TILE_M;
  int n_start = block_n * TILE_N;

  if (m_start >= M || n_start >= N) return;

  using TileLeft = TileLeft<half, 128, 64>;
  using TileRight = TileRight<half, 64, 256>;
  using TileAcc = TileAcc<float, 128, 256>;
  using TileBias = Tile<TileType::Vec, float, 1, 256>;

  TileAcc acc;
  TFILL(acc, 0);

  // 矩阵乘法
  for (int k = 0; k < K; k += 64) {
    TileLeft tileA;
    TileRight tileB;

    TLOAD(tileA, A[m_start:128, k:64]);
    TLOAD(tileB, B[k:64, n_start:256]);
    TMATMUL_ACC(acc, tileA, tileB);
  }

  // 融合 Bias
  TileBias bias_tile;
  TLOAD(bias_tile, bias[n_start:256]);
  TROWEXPANDADD(acc, acc, bias_tile);

  // 融合 ReLU
  TRELU(acc, acc);

  // 存储结果
  TSTORE(C[m_start:128, n_start:256], acc);
}

性能分析

矩阵大小:1024 × 1024 × 1024
平台:A3(24 核)

融合前:
- GEMM: 2.5 ms
- Bias: 0.05 ms
- ReLU: 0.05 ms
- 总计: 2.6 ms

融合后:
- FusedGEMMBiasReLU: 2.52 ms

加速比:1.03×(收益较小,但避免了额外的内核启动)

2.4 复杂融合(Complex Fusion)

特点: - 多种操作类型混合 - 复杂的数据流 - 需要精心设计

示例:Fused Multi-Head Attention

// QKV 投影 + Softmax + 输出投影
// Q = Linear(x)
// K = Linear(x)
// V = Linear(x)
// Attention = Softmax(Q @ K^T / sqrt(d))
// Out = Attention @ V
// Result = Linear(Out)

3. 融合实现技术

3.1 手动融合步骤

步骤1:识别融合机会

# 分析计算图
def analyze_fusion_opportunities(graph):
    candidates = []

    for node in graph.nodes:
        # 查找连续的逐元素操作
        if is_elementwise(node):
            chain = find_elementwise_chain(node)
            if len(chain) >= 2:
                candidates.append(chain)

    return candidates

步骤2:验证融合可行性

// 检查清单
bool can_fuse(Op op1, Op op2) {
  // 1. 检查数据依赖
  if (op2.input != op1.output) return false;

  // 2. 检查中间结果是否被其他算子使用
  if (op1.output.num_users > 1) return false;

  // 3. 检查片上内存容量
  size_t required_memory = op1.memory + op2.memory;
  if (required_memory > L1_CAPACITY) return false;

  // 4. 检查数据类型兼容性
  if (op1.output_type != op2.input_type) return false;

  return true;
}

步骤3:实现融合 kernel

// 模板化融合 kernel
template<typename Op1, typename Op2, typename Op3>
__global__ __aicore__ void FusedKernel(
    __gm__ float* out,
    __gm__ const float* in,
    Op1 op1, Op2 op2, Op3 op3) {

  using TileT = Tile<TileType::Vec, float, 16, 256>;
  TileT tile;

  TLOAD(tile, in);

  // 依次执行融合的操作
  op1(tile, tile);
  op2(tile, tile);
  op3(tile, tile);

  TSTORE(out, tile);
}

步骤4:性能验证

// 对比融合前后的性能
void benchmark_fusion() {
  // 融合前
  auto start = GetTime();
  kernel1<<<...>>>();
  kernel2<<<...>>>();
  kernel3<<<...>>>();
  auto time_unfused = GetTime() - start;

  // 融合后
  start = GetTime();
  fused_kernel<<<...>>>();
  auto time_fused = GetTime() - start;

  float speedup = time_unfused / time_fused;
  printf("Speedup: %.2fx\n", speedup);
}

3.2 编译器自动融合

未来特性:PTO Tile Fusion

// 使用 pragma 指示编译器融合
#pragma pto_fusion_begin
{
  TADD(y, x, bias);
  TRELU(z, y);
  TMUL(out, z, scale);
}
#pragma pto_fusion_end

// 编译器自动:
// 1. 分析数据依赖
// 2. 检查融合可行性
// 3. 生成融合 kernel
// 4. 优化内存访问

4. 融合收益分析

4.1 理论收益计算

公式

加速比 = T_unfused / T_fused

其中:
T_unfused = Σ(T_compute_i + T_memory_i + T_launch_i)
T_fused = T_compute_fused + T_memory_fused + T_launch_fused

通常:
T_compute_fused ≈ Σ T_compute_i(计算时间不变)
T_memory_fused << Σ T_memory_i(内存访问大幅减少)
T_launch_fused << Σ T_launch_i(启动开销减少)

示例计算

Add + ReLU + Mul 融合:

融合前:
- Add: 0.01 ms (compute) + 0.04 ms (memory) + 0.02 ms (launch) = 0.07 ms
- ReLU: 0.01 ms + 0.04 ms + 0.02 ms = 0.07 ms
- Mul: 0.01 ms + 0.04 ms + 0.02 ms = 0.07 ms
- 总计: 0.21 ms

融合后:
- Compute: 0.03 ms(3 个操作)
- Memory: 0.04 ms(只加载和存储一次)
- Launch: 0.02 ms(只启动一次)
- 总计: 0.09 ms

加速比: 0.21 / 0.09 = 2.3×

4.2 实际收益测量

测量方法

// 使用 msprof 测量
void measure_fusion_benefit() {
  // 1. 测量融合前
  msprof_start();
  run_unfused_kernels();
  auto metrics_unfused = msprof_stop();

  // 2. 测量融合后
  msprof_start();
  run_fused_kernel();
  auto metrics_fused = msprof_stop();

  // 3. 分析收益
  printf("Memory access reduction: %.1f%%\n",
         100.0 * (1 - metrics_fused.memory_bytes / 
                      metrics_unfused.memory_bytes));

  printf("Kernel launch reduction: %d → %d\n",
         metrics_unfused.num_kernels,
         metrics_fused.num_kernels);

  printf("Overall speedup: %.2fx\n",
         metrics_unfused.time / metrics_fused.time);
}

5. 融合策略选择

5.1 决策树

是否融合?
├─ 中间结果只使用一次?
│   ├─ 是 → 继续
│   └─ 否 → 不融合
├─ 融合后不超出 L1 容量?
│   ├─ 是 → 继续
│   └─ 否 → 不融合或部分融合
├─ 预期加速比 > 1.2×?
│   ├─ 是 → 融合
│   └─ 否 → 不融合

5.2 融合优先级

高优先级(强烈推荐融合): 1. 多个逐元素操作 2. Softmax 类归约操作 3. BatchNorm + Activation 4. GEMM + Bias + Activation

中优先级(视情况融合): 1. LayerNorm + 后续操作 2. Attention 内部操作 3. 卷积 + Bias + Activation

低优先级(通常不融合): 1. 大型矩阵乘法(已经计算密集) 2. 中间结果被多次使用 3. 融合后超出片上内存


6. 实战案例

案例1:ResNet Block 融合

原始实现

# ResNet Block
def resnet_block(x, weight1, weight2, bias1, bias2):
    # Conv1 + BN1 + ReLU
    y = conv2d(x, weight1)
    y = batch_norm(y)
    y = relu(y)

    # Conv2 + BN2
    y = conv2d(y, weight2)
    y = batch_norm(y)

    # Residual + ReLU
    y = y + x
    y = relu(y)

    return y

融合策略

// 融合1:Conv + BN + ReLU
__global__ void FusedConvBNReLU(...) {
  // Conv
  TMATMUL(output, input, weight);

  // BN(融合)
  TROWEXPANDSUB(output, output, mean);
  TROWEXPANDDIV(output, output, std);
  TROWEXPANDMUL(output, output, gamma);
  TROWEXPANDADD(output, output, beta);

  // ReLU(融合)
  TRELU(output, output);
}

// 融合2:Add + ReLU
__global__ void FusedAddReLU(...) {
  TADD(output, y, residual);
  TRELU(output, output);
}

性能提升

原始:8 个 kernel,1.2 ms
融合后:4 个 kernel,0.7 ms
加速比:1.7×

案例2:Transformer Layer 融合

原始实现

# Transformer Layer
def transformer_layer(x, Wq, Wk, Wv, Wo):
    # QKV 投影
    Q = linear(x, Wq)
    K = linear(x, Wk)
    V = linear(x, Wv)

    # Attention
    scores = matmul(Q, K.T) / sqrt(d)
    attn = softmax(scores)
    out = matmul(attn, V)

    # 输出投影
    out = linear(out, Wo)

    return out

融合策略

// 融合1:QKV 投影(3 个 GEMM 合并)
__global__ void FusedQKVProjection(...) {
  // 一次性计算 Q, K, V
  TMATMUL(Q, x, Wq);
  TMATMUL(K, x, Wk);
  TMATMUL(V, x, Wv);
}

// 融合2:Attention Score + Softmax
__global__ void FusedAttentionSoftmax(...) {
  // Score
  TMATMUL(scores, Q, K_T);

  // Scale(融合)
  TMULS(scores, scores, 1.0 / sqrt(d));

  // Softmax(融合)
  TROWMAX(max_val, scores);
  TROWEXPANDSUB(shifted, scores, max_val);
  TEXP(exp_vals, shifted);
  TROWSUM(sum_val, exp_vals);
  TROWEXPANDDIV(attn, exp_vals, sum_val);
}

性能提升

原始:12 个 kernel,3.5 ms
融合后:6 个 kernel,2.1 ms
加速比:1.7×

7. 最佳实践

7.1 融合设计原则

DO: - 优先融合逐元素操作 - 融合 Softmax 等归约操作 - 在 GEMM 后融合 Bias 和 Activation - 保持融合 kernel 简单易懂 - 测量实际性能收益

DON'T: - 不要融合中间结果被多次使用的算子 - 不要融合导致 L1 溢出的操作 - 不要过度融合(保持可维护性) - 不要假设融合一定更快(需要测量)

7.2 融合检查清单

融合前检查: - [ ] 中间结果只使用一次 - [ ] 融合后不超出 L1 容量 - [ ] 数据类型兼容 - [ ] 无复杂的控制流

融合后验证: - [ ] 数值正确性验证 - [ ] 性能提升 > 20% - [ ] 代码可维护性良好 - [ ] 建立性能回归测试

7.3 调试技巧

验证正确性

// 对比融合前后的输出
void verify_fusion() {
  // 运行融合前的版本
  run_unfused_kernels(input, output_ref);

  // 运行融合后的版本
  run_fused_kernel(input, output_test);

  // 对比结果
  float max_diff = 0;
  for (int i = 0; i < size; i++) {
    float diff = abs(output_ref[i] - output_test[i]);
    max_diff = max(max_diff, diff);
  }

  assert(max_diff < 1e-5);
  printf("Fusion verified: max_diff = %.2e\n", max_diff);
}

参考资源