diff --git a/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm.cpp b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm.cpp index 3e76f15f..ef907abf 100644 --- a/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm.cpp +++ b/csrc/matmul_allreduce_add_rmsnorm/op_kernel/matmul_allreduce_add_rmsnorm.cpp @@ -33,6 +33,9 @@ extern "C" __global__ __aicore__ void matmul_allreduce_add_rmsnorm( __gm__ void* mc2CcTiling = (__gm__ void*)(&(tilingData->mc2CcTiling)); auto contextGM0 = AscendC::GetHcclContext(); + hccl_.Init(contextGM0, mc2InitTiling); + hccl_.SetCcTiling(mc2CcTiling); + if ASCEND_IS_AIC { MatmulAllreduceAddRmsnormAicKernel op; op.Init(x1, x2, residual, gamma, y, workspace, &tiling_data, hccl_); diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_matmul_allreduce_add_rmsnorm.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_matmul_allreduce_add_rmsnorm.py index 762802a1..34202652 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_matmul_allreduce_add_rmsnorm.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_matmul_allreduce_add_rmsnorm.py @@ -126,7 +126,7 @@ def worker(rank, ep_world_size, batch_size, m, k, n): @torch.inference_mode() def test_matmul_allreduce_add_rmsnorm_kernel(): - ep_world_size = 8 + ep_world_size = 4 batch_size = 1 m = 10000 k = 1024