diff --git a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu index d252c29c2..748dd2137 100644 --- a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu +++ b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu @@ -485,7 +485,8 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device()); torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int); - if (a.size(1) > 128) { + if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) { + // For H20 with K > 128, use Pingpong Schedule run_get_group_gemm_starts( expert_offsets, a_ptrs, @@ -517,7 +518,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape( expert_offsets, workspace); } else { - // Small K + // For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule run_get_group_gemm_starts( expert_offsets, a_ptrs,