From 91bf524364c1150ccb87f0481b90e00ecbf43ba9 Mon Sep 17 00:00:00 2001 From: Trunrain <270250579@qq.com> Date: Mon, 5 Jan 2026 15:19:54 +0800 Subject: [PATCH] [BugFix][kernel] fix matmul_allreduce_add_rmsnorm_kernel (#5335) ### What this PR does / why we need it? Fix matmul_allreduce_add_rmsnorm_kernel: call hccl_.Init(contextGM0, mc2InitTiling) and hccl_.SetCcTiling(mc2CcTiling) before hccl_ is passed to the kernel's Init. The test case now runs on 4 cards (ep_world_size = 4, multicard-4) instead of 8. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? pytest -sv tests/e2e/nightly/single_node/ops/singlecard_ops/test_matmul_allreduce_add_rmsnorm.py The multicard-4 CI job passes: https://github.com/vllm-project/vllm-ascend/actions/runs/20502630658/job/58914474652?pr=5335 - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/bc0a5a0c089844b17cb93f3294348f411e523586 Signed-off-by: tongrunze Co-authored-by: tongrunze --- .../op_kernel/matmul_allreduce_add_rmsnorm.cpp | 3 +++ .../ops/singlecard_ops/test_matmul_allreduce_add_rmsnorm.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm.cpp b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm.cpp index 3e76f15f..ef907abf 100644 --- a/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm.cpp +++ b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm.cpp @@ -33,6 +33,9 @@ extern "C" __global__ __aicore__ void matmul_allreduce_add_rmsnorm( __gm__ void* mc2CcTiling = (__gm__ void*)(&(tilingData->mc2CcTiling)); auto contextGM0 = AscendC::GetHcclContext(); + hccl_.Init(contextGM0, mc2InitTiling); + hccl_.SetCcTiling(mc2CcTiling); + if ASCEND_IS_AIC { MatmulAllreduceAddRmsnormAicKernel op; op.Init(x1, x2, residual, gamma, y, workspace, &tiling_data, hccl_); diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_matmul_allreduce_add_rmsnorm.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_matmul_allreduce_add_rmsnorm.py index 762802a1..34202652 100644 --- 
a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_matmul_allreduce_add_rmsnorm.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_matmul_allreduce_add_rmsnorm.py @@ -126,7 +126,7 @@ def worker(rank, ep_world_size, batch_size, m, k, n): @torch.inference_mode() def test_matmul_allreduce_add_rmsnorm_kernel(): - ep_world_size = 8 + ep_world_size = 4 batch_size = 1 m = 10000 k = 1024