[BugFix][kernel] fix matmul_allreduce_add_rmsnorm_kernel (#5335)
### What this PR does / why we need it?
fix matmul_allreduce_add_rmsnorm_kernel, add hccl Init, SetCcTiling
interface
test case use multicard-4
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
pytest -sv tests/e2e/nightly/ops/test_matmul_allreduce_add_rmsnorm.py
multicard-4 pass
https://github.com/vllm-project/vllm-ascend/actions/runs/20502630658/job/58914474652?pr=5335
- vLLM version: release/v0.13.0
- vLLM main:
bc0a5a0c08
Signed-off-by: tongrunze <t00574058@china.huawei.com>
Co-authored-by: tongrunze <t00574058@china.huawei.com>
This commit is contained in:
@@ -33,6 +33,9 @@ extern "C" __global__ __aicore__ void matmul_allreduce_add_rmsnorm(
|
||||
__gm__ void* mc2CcTiling = (__gm__ void*)(&(tilingData->mc2CcTiling));
|
||||
auto contextGM0 = AscendC::GetHcclContext<AscendC::HCCL_GROUP_ID_0>();
|
||||
|
||||
hccl_.Init(contextGM0, mc2InitTiling);
|
||||
hccl_.SetCcTiling(mc2CcTiling);
|
||||
|
||||
if ASCEND_IS_AIC {
|
||||
MatmulAllreduceAddRmsnormAicKernel<DTYPE_X1, DTYPE_Y> op;
|
||||
op.Init(x1, x2, residual, gamma, y, workspace, &tiling_data, hccl_);
|
||||
|
||||
@@ -126,7 +126,7 @@ def worker(rank, ep_world_size, batch_size, m, k, n):
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_matmul_allreduce_add_rmsnorm_kernel():
|
||||
ep_world_size = 8
|
||||
ep_world_size = 4
|
||||
batch_size = 1
|
||||
m = 10000
|
||||
k = 1024
|
||||
|
||||
Reference in New Issue
Block a user