From dc48c4c0e3848b5792039605d47f015cff26872e Mon Sep 17 00:00:00 2001
From: Qi Yuhang <45795032+HydraQYH@users.noreply.github.com>
Date: Tue, 14 Oct 2025 07:24:48 +0800
Subject: [PATCH] [sgl-kernel][2/N]Support Expert Specialization Grouped GEMM
 (#11534)

---
 .../expert_specialization/es_fp8_blockwise.cu |  78 +++++++---
 .../es_fp8_blockwise_launcher.cuh             | 138 +++++++-----------
 sgl-kernel/python/sgl_kernel/__init__.py      |   2 +-
 ...cilization.py => expert_specialization.py} |   0
 4 files changed, 112 insertions(+), 106 deletions(-)
 rename sgl-kernel/python/sgl_kernel/{expert_specilization.py => expert_specialization.py} (100%)

diff --git a/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu
index f7e4b61ef..a5d2904fc 100644
--- a/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu
+++ b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise.cu
@@ -68,24 +68,58 @@ void es_fp8_blockwise_scaled_grouped_mm(
   torch::Tensor lm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
   torch::Tensor mm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
   torch::Tensor hm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
-  expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
-      out_ptrs,
-      a_ptrs,
-      b_ptrs,
-      a_scales_ptrs,
-      b_scales_ptrs,
-      layout_sfa,
-      layout_sfb,
-      lm_problem_sizes,
-      mm_problem_sizes,
-      hm_problem_sizes,
-      output,
-      a,
-      b,
-      scales_a,
-      scales_b,
-      problem_sizes,
-      expert_offsets);
+
+  const std::string H20_device_type_str("NVIDIA H20");
+  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
+  at::cuda::CUDAGuard device_guard{(char)a.get_device()};
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
+
+  if (output.dtype() == torch::kBFloat16) {
+    expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        layout_sfa,
+        layout_sfb,
+        lm_problem_sizes,
+        mm_problem_sizes,
+        hm_problem_sizes,
+        output,
+        a,
+        b,
+        scales_a,
+        scales_b,
+        problem_sizes,
+        expert_offsets,
+        is_h20_device,
+        stream);
+  } else if (output.dtype() == torch::kFloat16) {
+    expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        layout_sfa,
+        layout_sfb,
+        lm_problem_sizes,
+        mm_problem_sizes,
+        hm_problem_sizes,
+        output,
+        a,
+        b,
+        scales_a,
+        scales_b,
+        problem_sizes,
+        expert_offsets,
+        is_h20_device,
+        stream);
+  } else {
+    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
+  }
+
   if (output.dtype() == torch::kBFloat16) {
     expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         out_ptrs,
@@ -100,7 +134,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
         layout_sfb,
         lm_problem_sizes,
         mm_problem_sizes,
-        hm_problem_sizes);
+        hm_problem_sizes,
+        is_h20_device,
+        stream);
   } else if (output.dtype() == torch::kFloat16) {
     expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         out_ptrs,
@@ -115,7 +151,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
         layout_sfb,
         lm_problem_sizes,
         mm_problem_sizes,
-        hm_problem_sizes);
+        hm_problem_sizes,
+        is_h20_device,
+        stream);
   } else {
     TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
   }
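The hunks above make the host entry point resolve the device name, device guard, and CUDA stream once, then branch on `output.dtype()` so each path can hand a compile-time output type to the now-templated pre-compute. A minimal sketch of this runtime-dtype-to-template dispatch pattern, under assumed names (`Dtype`, `run_pre_compute`, the tag structs) rather than the actual sgl-kernel signatures:

    #include <stdexcept>
    #include <string>

    // Illustrative stand-ins for the real output element types.
    struct BF16Tag {};
    struct FP16Tag {};
    enum class Dtype { kBFloat16, kFloat16 };

    // Templated worker: OutT is fixed at compile time, so kernels and
    // CUTLASS traits can be instantiated on it.
    template <typename OutT>
    void run_pre_compute(bool is_h20_device /*, tensors, stream, ... */) {
      (void)is_h20_device;  // placeholder body
    }

    // The caller resolves runtime properties once, then maps the runtime
    // dtype onto a compile-time template parameter.
    void dispatch(Dtype out_dtype, const std::string& device_name) {
      const bool is_h20_device = (device_name == "NVIDIA H20");
      if (out_dtype == Dtype::kBFloat16) {
        run_pre_compute<BF16Tag>(is_h20_device);
      } else if (out_dtype == Dtype::kFloat16) {
        run_pre_compute<FP16Tag>(is_h20_device);
      } else {
        throw std::invalid_argument("output must be float16 or bfloat16");
      }
    }

Folding the H20 check into a `bool` parameter also means the `getCurrentDeviceProperties()` name comparison runs once per call instead of once in every callee, as the launcher changes below show.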
diff --git a/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh
index 3b0f4a8a2..19621b24f 100644
--- a/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh
+++ b/sgl-kernel/csrc/expert_specialization/es_fp8_blockwise_launcher.cuh
@@ -3,6 +3,7 @@
 
 #include
 #include
+#include
 #include
 #include
 
@@ -14,6 +15,7 @@ namespace expert_specialization {
 
 using namespace cute;
 
+template
 void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
     // Output
     torch::Tensor& out_ptrs,
@@ -33,15 +35,14 @@ void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
     torch::Tensor const& a_scales,
     torch::Tensor const& b_scales,
     torch::Tensor const& problem_sizes,
-    torch::Tensor const& expert_offsets) {
+    torch::Tensor const& expert_offsets,
+    bool is_h20_device,
+    cudaStream_t stream) {
   TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
   TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
   TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
 
-  const std::string H20_device_type_str("NVIDIA H20");
-  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
-
   // Creat Scale Factor Layout Functor
   using LayoutSFA = typename PerfConfigMiddleMH20::LayoutSFA;
   using LayoutSFB = typename PerfConfigMiddleMH20::LayoutSFB;
@@ -49,74 +50,38 @@ void es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
       reinterpret_cast(layout_sfa.data_ptr()),
       reinterpret_cast(layout_sfb.data_ptr()));
   int num_experts = (int)expert_offsets.size(0);
-  auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
-
-  // Dispatch
-  if (out_tensors.dtype() == torch::kBFloat16) {
-    struct Fp8BlockwiseGroupedGemmOffsetFunctor of(
-        static_cast(expert_offsets.data_ptr()),
-        static_cast(a_tensors.data_ptr()),
-        static_cast(b_tensors.data_ptr()),
-        static_cast(out_tensors.data_ptr()),
-        static_cast(a_scales.data_ptr()),
-        static_cast(b_scales.data_ptr()),
-        static_cast(a_ptrs.data_ptr()),
-        static_cast(b_ptrs.data_ptr()),
-        static_cast(a_scales_ptrs.data_ptr()),
-        static_cast(b_scales_ptrs.data_ptr()),
-        static_cast(out_ptrs.data_ptr()));
-    if (!is_h20_device) {
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor lm_psf(
-          static_cast(lm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor mm_psf(
-          static_cast(mm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor hm_psf(
-          static_cast(hm_problem_sizes.data_ptr()));
-      groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
-          static_cast(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
-    } else {
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor lm_psf(
-          static_cast(lm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor mm_psf(
-          static_cast(mm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor hm_psf(
-          static_cast(hm_problem_sizes.data_ptr()));
-      groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
-          static_cast(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
-    }
-  } else if (out_tensors.dtype() == torch::kFloat16) {
-    struct Fp8BlockwiseGroupedGemmOffsetFunctor of(
-        static_cast(expert_offsets.data_ptr()),
-        static_cast(a_tensors.data_ptr()),
-        static_cast(b_tensors.data_ptr()),
-        static_cast(out_tensors.data_ptr()),
-        static_cast(a_scales.data_ptr()),
-        static_cast(b_scales.data_ptr()),
-        static_cast(a_ptrs.data_ptr()),
-        static_cast(b_ptrs.data_ptr()),
-        static_cast(a_scales_ptrs.data_ptr()),
-        static_cast(b_scales_ptrs.data_ptr()),
-        static_cast(out_ptrs.data_ptr()));
-    if (!is_h20_device) {
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor lm_psf(
-          static_cast(lm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor mm_psf(
-          static_cast(mm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor hm_psf(
-          static_cast(hm_problem_sizes.data_ptr()));
-      groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
-          static_cast(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
-    } else {
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor lm_psf(
-          static_cast(lm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor mm_psf(
-          static_cast(mm_problem_sizes.data_ptr()));
-      struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor hm_psf(
-          static_cast(hm_problem_sizes.data_ptr()));
-      groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
-          static_cast(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
-    }
+  TORCH_CHECK(num_experts <= 1024, "Expert more than 1024");  // Max threads per block is 1024
+
+  struct Fp8BlockwiseGroupedGemmOffsetFunctor of(
+      static_cast(expert_offsets.data_ptr()),
+      static_cast(a_tensors.data_ptr()),
+      static_cast(b_tensors.data_ptr()),
+      static_cast(out_tensors.data_ptr()),
+      static_cast(a_scales.data_ptr()),
+      static_cast(b_scales.data_ptr()),
+      static_cast(a_ptrs.data_ptr()),
+      static_cast(b_ptrs.data_ptr()),
+      static_cast(a_scales_ptrs.data_ptr()),
+      static_cast(b_scales_ptrs.data_ptr()),
+      static_cast(out_ptrs.data_ptr()));
+  if (!is_h20_device) {
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor lm_psf(
+        static_cast(lm_problem_sizes.data_ptr()));
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor mm_psf(
+        static_cast(mm_problem_sizes.data_ptr()));
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor hm_psf(
+        static_cast(hm_problem_sizes.data_ptr()));
+    groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
+        static_cast(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
   } else {
-    TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor lm_psf(
+        static_cast(lm_problem_sizes.data_ptr()));
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor mm_psf(
+        static_cast(mm_problem_sizes.data_ptr()));
+    struct Fp8BlockwiseGroupedGemmProblemSizeFilterFunctor hm_psf(
+        static_cast(hm_problem_sizes.data_ptr()));
+    groupedGemmPreComputeKernel<<<1, num_experts, 0, stream>>>(
+        static_cast(problem_sizes.data_ptr()), of, sf_layout, lm_psf, mm_psf, hm_psf);
   }
 }
@@ -132,7 +97,8 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
     const torch::Tensor& stride_d,
     const torch::Tensor& layout_sfa,
     const torch::Tensor& layout_sfb,
-    const torch::Tensor& problem_sizes) {
+    const torch::Tensor& problem_sizes,
+    cudaStream_t stream) {
   using ElementA = typename GemmTraits::ElementA;
   using StrideA = typename GemmTraits::StrideA;
   using ElementB = typename GemmTraits::ElementB;
@@ -174,9 +140,6 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
       epilogue_args,
       hw_info};
 
-  at::cuda::CUDAGuard device_guard{(char)a_ptrs.get_device()};
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
-
   auto can_implement_status = gemm_op.can_implement(args);
   TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
 
@@ -205,7 +168,9 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
     const torch::Tensor& layout_sfb,
     const torch::Tensor& lm_problem_sizes,
     const torch::Tensor& mm_problem_sizes,
-    const torch::Tensor& hm_problem_sizes) {
+    const torch::Tensor& hm_problem_sizes,
+    bool is_h20_device,
+    cudaStream_t stream) {
   using LowMGemmH20Traits =
       ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits;
   using LowMGemmHx00Traits =
@@ -221,9 +186,6 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
   using HighMGemmHx00Traits =
       ExpertSpecializationSm90FP8BlockwiseGroupedGemmTraits;
 
-  const std::string H20_device_type_str("NVIDIA H20");
-  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
-
   if (!is_h20_device) {
     launch_sm90_fp8_blockwise_scaled_group_mm(
         out_ptrs,
@@ -236,7 +198,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfb,
         layout_sfa,
-        lm_problem_sizes);
+        lm_problem_sizes,
+        stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm(
         out_ptrs,
@@ -249,7 +212,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfb,
         layout_sfa,
-        lm_problem_sizes);
+        lm_problem_sizes,
+        stream);
   }
 
   if (!is_h20_device) {
     launch_sm90_fp8_blockwise_scaled_group_mm(
         out_ptrs,
@@ -264,7 +228,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfb,
         layout_sfa,
-        mm_problem_sizes);
+        mm_problem_sizes,
+        stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm(
         out_ptrs,
@@ -277,7 +242,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfa,
         layout_sfb,
-        mm_problem_sizes);
+        mm_problem_sizes,
+        stream);
   }
 
   if (!is_h20_device) {
     launch_sm90_fp8_blockwise_scaled_group_mm(
         out_ptrs,
@@ -292,7 +258,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfa,
         layout_sfb,
-        hm_problem_sizes);
+        hm_problem_sizes,
+        stream);
   } else {
     launch_sm90_fp8_blockwise_scaled_group_mm(
         out_ptrs,
@@ -305,7 +272,8 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
         stride_d,
         layout_sfa,
         layout_sfb,
-        hm_problem_sizes);
+        hm_problem_sizes,
+        stream);
   }
 }
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 6091ac41c..209f81434 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -244,7 +244,7 @@ from sgl_kernel.elementwise import (
     rmsnorm,
     silu_and_mul,
 )
-from sgl_kernel.expert_specilization import es_fp8_blockwise_scaled_grouped_mm
+from sgl_kernel.expert_specialization import es_fp8_blockwise_scaled_grouped_mm
 from sgl_kernel.fused_moe import fused_marlin_moe
 from sgl_kernel.gemm import (
     awq_dequantize,
diff --git a/sgl-kernel/python/sgl_kernel/expert_specilization.py b/sgl-kernel/python/sgl_kernel/expert_specialization.py
similarity index 100%
rename from sgl-kernel/python/sgl_kernel/expert_specilization.py
rename to sgl-kernel/python/sgl_kernel/expert_specialization.py
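For context on the `TORCH_CHECK(num_experts <= 1024, ...)` guard added to the launcher: `groupedGemmPreComputeKernel` is launched as `<<<1, num_experts, 0, stream>>>`, a single block with one thread per expert, and CUDA caps a block at 1024 threads. A minimal sketch of that one-thread-per-expert pointer pre-compute, with hypothetical names (`group_ptr_precompute`, `base_a`, `a_ptrs`) and a `float` element type rather than the patch's functor interface:

    #include <cstdint>
    #include <cuda_runtime.h>

    // One thread per expert: thread e derives expert e's slice pointer from
    // a shared base pointer plus a per-expert row offset.
    __global__ void group_ptr_precompute(
        const int32_t* expert_offsets,  // [num_experts] starting row of each expert
        const float* base_a,            // stacked A tensor, row-major [total_rows, k]
        const float** a_ptrs,           // out: per-expert A pointer for the grouped GEMM
        int k) {                        // inner dimension (row stride of A)
      int e = threadIdx.x;              // blockDim.x == num_experts, single block
      a_ptrs[e] = base_a + static_cast<int64_t>(expert_offsets[e]) * k;
    }

    // Host launch mirrors the patch: one block, one thread per expert, on the
    // caller-provided stream; num_experts must not exceed 1024 (the maximum
    // number of threads per block), hence the TORCH_CHECK in the launcher.
    void launch_group_ptr_precompute(
        const int32_t* expert_offsets, const float* base_a, const float** a_ptrs,
        int k, int num_experts, cudaStream_t stream) {
      group_ptr_precompute<<<1, num_experts, 0, stream>>>(expert_offsets, base_a, a_ptrs, k);
    }

The real kernel additionally fills the scale-factor layouts and splits the per-expert problem sizes into the low-/middle-/high-M buckets (`lm`/`mm`/`hm`) consumed by the three `launch_sm90_fp8_blockwise_scaled_group_mm` calls above.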