From fe531d6f4e165d2feced37f58d47f9c809309739 Mon Sep 17 00:00:00 2001 From: Yuhao Yao <37280700+yuhyao@users.noreply.github.com> Date: Thu, 25 Sep 2025 09:51:50 +0800 Subject: [PATCH] [Bug] Fix Issue#10215 (#10572) --- ...ay_tma_gmma_rs_warpspecialized_mixed_input_.hpp | 14 +++++++------- sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp index b37d5696c..22b344794 100644 --- a/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp +++ b/sgl-kernel/csrc/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp @@ -1025,8 +1025,6 @@ struct CollectiveMmaArrayMixedInput< // src: tCrA_load, dst: tCrA_mma Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id) { @@ -1060,6 +1058,8 @@ struct CollectiveMmaArrayMixedInput< } } + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) { warpgroup_fence_operand(intermediate_array[chunk_id_]); @@ -1114,7 +1114,6 @@ struct CollectiveMmaArrayMixedInput< 1, smem_pipe_read.index()); - warpgroup_wait(); Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); } } @@ -1148,8 +1147,6 @@ struct CollectiveMmaArrayMixedInput< tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); - warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can - // release prior barrier if (k_block == K_BLOCK_MAX - 1) { pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it ++smem_pipe_release; @@ -1162,6 +1159,8 @@ struct CollectiveMmaArrayMixedInput< if (k_block == K_BLOCK_MAX - 1) { // The last k_block + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) { warpgroup_fence_operand(intermediate_array[chunk_id_]); @@ -1241,7 +1240,6 @@ struct CollectiveMmaArrayMixedInput< tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); - warpgroup_wait(); if (k_block == K_BLOCK_MAX - 1) { // release prior barrier pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it @@ -1264,6 +1262,8 @@ struct CollectiveMmaArrayMixedInput< if ((k_block + 1) % NumMMAsPerChunk == 0) { tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_wait<0>(); warpgroup_fence_operand(intermediate); // Apply the group-wise scaling @@ -1296,7 +1296,7 @@ struct CollectiveMmaArrayMixedInput< smem_pipe_release.advance(k_tile_count); // Wait on all GMMAs to complete - warpgroup_wait<0>(); + // warpgroup_wait<0>(); for (int count = 0; count < prologue_mma_count; ++count) { pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it diff --git a/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py b/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py index 3f9e60077..7acba566c 100644 --- a/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py +++ b/sgl-kernel/tests/test_cutlass_w4a8_moe_mm.py @@ -157,8 +157,8 @@ def _per_tensor_quant_fp8( reason="cutlass_w4a8_moe_mm is only supported on sm90", ) @pytest.mark.parametrize("batch_size", [2, 4, 8, 16, 32]) -@pytest.mark.parametrize("k", [512, 1024, 2048, 4096, 7168]) -@pytest.mark.parametrize("n", [256, 512, 1024, 2048]) +@pytest.mark.parametrize("k", [256, 512, 1024, 2048, 4096, 7168]) +@pytest.mark.parametrize("n", [256, 512, 1024, 2048, 7168]) @pytest.mark.parametrize("num_experts", [2, 4, 6, 8]) def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts): torch.manual_seed(0)