正确性验证方法

Golden 数据生成

Python Golden 生成脚本模板

#!/usr/bin/python3
import os
import numpy as np
np.random.seed(42)

def gen_reduce_scatter_golden(nranks, M, N, dtype=np.float16):
    """Generate per-rank inputs and placeholder golden files for ReduceScatter.

    For every rank ``r`` a random ``(M, N)`` matrix is drawn with
    ``np.random.randn``, cast to ``dtype`` and written to
    ``rank{r}_input.bin`` in the current working directory. One
    ``rank{r}_golden.bin`` per rank is then written with the same shape and
    ``dtype``.

    NOTE: the golden files are all-zero placeholders. Which part of ``summed``
    rank ``r`` owns depends on the operator's tiling strategy, which this
    template does not know; fill it in at the marked line.

    Args:
        nranks: Number of ranks to generate files for.
        M: Row count of each matrix.
        N: Column count of each matrix.
        dtype: NumPy dtype of every file written.
    """
    inputs = []
    for r in range(nranks):
        data = np.random.randn(M, N).astype(dtype)
        data.tofile(f"rank{r}_input.bin")
        inputs.append(data)

    # Reduce with a wide accumulator, then cast back to ``dtype``:
    #  * floating types accumulate in at least float32, so the reference is
    #    not rounded to ``dtype`` after every single addition;
    #  * np.sum promotes small integer types to the platform integer, so
    #    without the cast the golden files would not hold ``dtype``-sized items.
    acc_dtype = np.result_type(dtype, np.float32) if np.issubdtype(dtype, np.floating) else None
    summed = np.sum(inputs, axis=0, dtype=acc_dtype).astype(dtype)

    for r in range(nranks):
        golden = np.zeros_like(summed)
        # Fill in the part of ``summed`` that rank r should hold under the tiling strategy.
        golden.tofile(f"rank{r}_golden.bin")

def gen_allgather_golden(nranks, M, N, dtype=np.float16):
    """Generate the golden data expected after the AllGather stage.

    Reads ``rank{r}_input.bin`` for every rank from the current working
    directory, sums the ``(M, N)`` matrices element-wise and writes that full
    sum to ``rank{r}_allreduce_golden.bin`` for every rank: once the gather has
    completed, each rank is expected to hold the complete reduced result.

    Args:
        nranks: Number of ranks whose input files are read.
        M: Row count of each matrix.
        N: Column count of each matrix.
        dtype: NumPy dtype of the input files and of the golden files.
    """
    inputs = []
    for r in range(nranks):
        data = np.fromfile(f"rank{r}_input.bin", dtype=dtype).reshape(M, N)
        inputs.append(data)

    # Reduce with a wide accumulator, then cast back to ``dtype``:
    #  * floating types accumulate in at least float32, so the reference is
    #    not rounded to ``dtype`` after every single addition;
    #  * np.sum promotes small integer types to the platform integer, so
    #    without the cast the golden files would not hold ``dtype``-sized items.
    acc_dtype = np.result_type(dtype, np.float32) if np.issubdtype(dtype, np.floating) else None
    golden = np.sum(inputs, axis=0, dtype=acc_dtype).astype(dtype)

    for r in range(nranks):
        golden.tofile(f"rank{r}_allreduce_golden.bin")

if __name__ == "__main__":
    # Example problem size: 8 ranks, each contributing a 5416 x 1408 matrix.
    rank_count, rows, cols = 8, 5416, 1408
    # Order matters: the ReduceScatter generator writes the rank*_input.bin
    # files that the AllGather generator reads back.
    gen_reduce_scatter_golden(rank_count, rows, cols)
    gen_allgather_golden(rank_count, rows, cols)

Golden 数据组织

testdata/
├── rank0_input.bin
├── rank1_input.bin
├── ...
├── rank0_golden.bin
├── rank1_golden.bin
├── ...
├── rank0_allreduce_golden.bin
├── rank1_allreduce_golden.bin
├── ...
└── config.json

数据类型映射

| Python numpy | C++ 类型 | ACL 类型 |
| --- | --- | --- |
| np.float16 | half | aclFloat16 |
| np.float32 | float | float |
| np.int32 | int32_t | int32_t |
| np.int16 | int16_t | int16_t |

验证函数模板

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <type_traits>

namespace verify_detail {
// Built-in arithmetic types widen straight to double, which keeps every
// int32_t value exact (a float has only 24 bits of mantissa).
template <typename T>
inline double ToDouble(const T &value, std::true_type /*is_arithmetic*/)
{
    return static_cast<double>(value);
}

// Any other element type (e.g. a half-precision class) is converted through
// float, so it only needs to provide a conversion to float.
template <typename T>
inline double ToDouble(const T &value, std::false_type /*is_arithmetic*/)
{
    return static_cast<double>(static_cast<float>(value));
}
} // namespace verify_detail

/**
 * Element-wise comparison of `actual` against `expected`.
 *
 * An element matches when the two values are equal, or when
 * |actual - expected| <= atol + rtol * |expected|. A NaN on either side is
 * always reported as a mismatch. The first 10 mismatches are printed in
 * detail, followed by a summary line if any mismatch occurred.
 *
 * @param actual   Values produced by the operator under test (`count` items).
 * @param expected Golden values (`count` items).
 * @param count    Number of elements to compare; 0 trivially passes.
 * @param atol     Absolute tolerance.
 * @param rtol     Relative tolerance, scaled by |expected|.
 * @return true when every element matches.
 */
template <typename T>
bool VerifyResult(const T *actual, const T *expected, size_t count,
                  float atol = 1.0f, float rtol = 0.01f)
{
    const size_t max_errors = 10;  // cap on detailed mismatch lines
    size_t error_count = 0;        // size_t: cannot overflow for any valid `count`
    double max_diff = 0.0;

    for (size_t i = 0; i < count; ++i) {
        const double a = verify_detail::ToDouble(actual[i], std::is_arithmetic<T>());
        const double e = verify_detail::ToDouble(expected[i], std::is_arithmetic<T>());
        if (a == e) {
            continue;  // exact match; also covers identical infinities (inf - inf is NaN)
        }

        const double diff = std::abs(a - e);
        const double threshold = static_cast<double>(atol) + static_cast<double>(rtol) * std::abs(e);

        // Written as a negated "<=" so that a NaN diff (every comparison with
        // NaN is false) counts as a mismatch instead of slipping through.
        if (!(diff <= threshold)) {
            if (error_count < max_errors) {
                printf("  Mismatch at [%zu]: actual=%f, expected=%f, diff=%f, threshold=%f\n",
                       i, a, e, diff, threshold);
            }
            error_count++;
            max_diff = std::max(max_diff, diff);
        }
    }

    if (error_count > 0) {
        printf("  Total errors: %zu / %zu (max_diff=%f)\n", error_count, count, max_diff);
    }
    return error_count == 0;
}

精度标准

| 数据类型 | 推荐 atol | 推荐 rtol | 说明 |
| --- | --- | --- | --- |
| float (FP32) | 1e-5 | 1e-4 | 高精度 |
| half (FP16) | 1.0 | 0.01 | AtomicAdd 累积误差较大 |
| int32 / int16 | 0 | 0 | 精确匹配 |

FP16 AtomicAdd 精度注意:

- 多 rank AtomicAdd 累积会引入浮点误差
- rank 数越多,累积误差越大
- 建议 FP16 使用 atol=1.0, rtol=0.01 或更宽松的阈值


分阶段验证

对于多阶段算子(如 ReduceScatter (RS) + Barrier + AllGather (AG)),建议分阶段验证,以便在出错时定位到具体阶段:

// Stage 1: after ReduceScatter (RS) has finished, check reduced_output
// against the RS golden data.
RunReduceScatterOnly(...);
// Synchronize the stream before the result is verified.
aclrtSynchronizeStream(stream);
bool rs_pass = VerifyReduceScatter(reduced_output, rs_golden);

// Stage 2: after the complete AllReduce, check the final result
// against the AllReduce golden data.
RunFullAllReduce(...);
aclrtSynchronizeStream(stream);
bool ar_pass = VerifyAllReduce(reduced_output, ar_golden);

main.cpp 模板(多 rank 测试)

#include "acl/acl.h"
#include "comm_mpi.h"
#include "hccl_context.h"
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <vector>

extern void launchTPutTest(uint8_t *local, uint8_t *remote, void *stream);

int main(int argc, char **argv)
{
    MPI_Init(&argc, &argv);
    int rank, nranks;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    MPI_Comm_size(MPI_COMM_WORLD, &nranks);

    aclInit(nullptr);
    aclrtSetDevice(rank);
    aclrtStream stream;
    aclrtCreateStream(&stream);

    HcclRootInfo rootInfo;
    if (rank == 0) HcclGetRootInfo(&rootInfo);
    MPI_Bcast(&rootInfo, sizeof(rootInfo), MPI_BYTE, 0, MPI_COMM_WORLD);
    HcclComm comm;
    HcclCommInitRootInfo(nranks, &rootInfo, rank, &comm);

    size_t dataSize = ROWS * COLS * sizeof(half);
    uint8_t *localBuf, *remoteBuf;
    // ... 获取通信窗口地址 ...

    std::vector<half> hostData(ROWS * COLS);
    for (int i = 0; i < ROWS * COLS; i++) hostData[i] = (half)(rank * 1000 + i);
    aclrtMemcpy(localBuf, dataSize, hostData.data(), dataSize, ACL_MEMCPY_HOST_TO_DEVICE);

    MPI_Barrier(MPI_COMM_WORLD);

    launchTPutTest(localBuf, remoteBuf, stream);
    aclrtSynchronizeStream(stream);

    MPI_Barrier(MPI_COMM_WORLD);

    std::vector<half> result(ROWS * COLS);
    aclrtMemcpy(result.data(), dataSize, localBuf, dataSize, ACL_MEMCPY_DEVICE_TO_HOST);

    bool pass = true;
    for (int i = 0; i < ROWS * COLS; i++) {
        half expected = /* 根据通信语义计算 */;
        if (abs((float)result[i] - (float)expected) > 1e-3) {
            printf("FAIL: rank %d, idx %d, got %f, expected %f\n",
                   rank, i, (float)result[i], (float)expected);
            pass = false;
            break;
        }
    }
    printf("Rank %d: %s\n", rank, pass ? "PASS" : "FAIL");

    HcclCommDestroy(comm);
    aclrtDestroyStream(stream);
    aclrtResetDevice(rank);
    aclFinalize();
    MPI_Finalize();
    return pass ? 0 : 1;
}