[Bug] Fix Issue#10215 (#10572)
This commit is contained in:
@@ -1025,8 +1025,6 @@ struct CollectiveMmaArrayMixedInput<
|
|||||||
// src: tCrA_load, dst: tCrA_mma
|
// src: tCrA_load, dst: tCrA_mma
|
||||||
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
|
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
|
||||||
|
|
||||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
|
||||||
|
|
||||||
// Unroll the K mode manually to set scale D to 1
|
// Unroll the K mode manually to set scale D to 1
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id) {
|
for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id) {
|
||||||
@@ -1060,6 +1058,8 @@ struct CollectiveMmaArrayMixedInput<
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) {
|
for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) {
|
||||||
warpgroup_fence_operand(intermediate_array[chunk_id_]);
|
warpgroup_fence_operand(intermediate_array[chunk_id_]);
|
||||||
@@ -1114,7 +1114,6 @@ struct CollectiveMmaArrayMixedInput<
|
|||||||
1,
|
1,
|
||||||
smem_pipe_read.index());
|
smem_pipe_read.index());
|
||||||
|
|
||||||
warpgroup_wait<K_WAIT_MAX>();
|
|
||||||
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
|
Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1148,8 +1147,6 @@ struct CollectiveMmaArrayMixedInput<
|
|||||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||||
warpgroup_commit_batch();
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
warpgroup_wait<K_WAIT_MAX>(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, so we can
|
|
||||||
// release prior barrier
|
|
||||||
if (k_block == K_BLOCK_MAX - 1) {
|
if (k_block == K_BLOCK_MAX - 1) {
|
||||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||||
++smem_pipe_release;
|
++smem_pipe_release;
|
||||||
@@ -1162,6 +1159,8 @@ struct CollectiveMmaArrayMixedInput<
|
|||||||
if (k_block == K_BLOCK_MAX - 1) {
|
if (k_block == K_BLOCK_MAX - 1) {
|
||||||
// The last k_block
|
// The last k_block
|
||||||
|
|
||||||
|
warpgroup_wait<0>();
|
||||||
|
|
||||||
CUTLASS_PRAGMA_UNROLL
|
CUTLASS_PRAGMA_UNROLL
|
||||||
for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) {
|
for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) {
|
||||||
warpgroup_fence_operand(intermediate_array[chunk_id_]);
|
warpgroup_fence_operand(intermediate_array[chunk_id_]);
|
||||||
@@ -1241,7 +1240,6 @@ struct CollectiveMmaArrayMixedInput<
|
|||||||
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
|
||||||
warpgroup_commit_batch();
|
warpgroup_commit_batch();
|
||||||
|
|
||||||
warpgroup_wait<K_WAIT_MAX>();
|
|
||||||
if (k_block == K_BLOCK_MAX - 1) {
|
if (k_block == K_BLOCK_MAX - 1) {
|
||||||
// release prior barrier
|
// release prior barrier
|
||||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||||
@@ -1264,6 +1262,8 @@ struct CollectiveMmaArrayMixedInput<
|
|||||||
|
|
||||||
if ((k_block + 1) % NumMMAsPerChunk == 0) {
|
if ((k_block + 1) % NumMMAsPerChunk == 0) {
|
||||||
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
|
||||||
|
|
||||||
|
warpgroup_wait<0>();
|
||||||
warpgroup_fence_operand(intermediate);
|
warpgroup_fence_operand(intermediate);
|
||||||
|
|
||||||
// Apply the group-wise scaling
|
// Apply the group-wise scaling
|
||||||
@@ -1296,7 +1296,7 @@ struct CollectiveMmaArrayMixedInput<
|
|||||||
smem_pipe_release.advance(k_tile_count);
|
smem_pipe_release.advance(k_tile_count);
|
||||||
|
|
||||||
// Wait on all GMMAs to complete
|
// Wait on all GMMAs to complete
|
||||||
warpgroup_wait<0>();
|
// warpgroup_wait<0>();
|
||||||
|
|
||||||
for (int count = 0; count < prologue_mma_count; ++count) {
|
for (int count = 0; count < prologue_mma_count; ++count) {
|
||||||
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
|
||||||
|
|||||||
@@ -157,8 +157,8 @@ def _per_tensor_quant_fp8(
|
|||||||
reason="cutlass_w4a8_moe_mm is only supported on sm90",
|
reason="cutlass_w4a8_moe_mm is only supported on sm90",
|
||||||
)
|
)
|
||||||
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16, 32])
|
@pytest.mark.parametrize("batch_size", [2, 4, 8, 16, 32])
|
||||||
@pytest.mark.parametrize("k", [512, 1024, 2048, 4096, 7168])
|
@pytest.mark.parametrize("k", [256, 512, 1024, 2048, 4096, 7168])
|
||||||
@pytest.mark.parametrize("n", [256, 512, 1024, 2048])
|
@pytest.mark.parametrize("n", [256, 512, 1024, 2048, 7168])
|
||||||
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
|
@pytest.mark.parametrize("num_experts", [2, 4, 6, 8])
|
||||||
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
|
def test_int4_fp8_grouped_gemm_multi_experts(batch_size, k, n, num_experts):
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
|
|||||||
Reference in New Issue
Block a user