[sgl-kernel][2/N]Support Expert Specialization Grouped GEMM (#11534)

This commit is contained in:
Qi Yuhang
2025-10-14 07:24:48 +08:00
committed by GitHub
parent 6dc9ca8c85
commit dc48c4c0e3
4 changed files with 112 additions and 106 deletions

View File

@@ -68,24 +68,58 @@ void es_fp8_blockwise_scaled_grouped_mm(
torch::Tensor lm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
torch::Tensor mm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
torch::Tensor hm_problem_sizes = torch::empty({num_experts, 3}, options_int32);
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
layout_sfa,
layout_sfb,
lm_problem_sizes,
mm_problem_sizes,
hm_problem_sizes,
output,
a,
b,
scales_a,
scales_b,
problem_sizes,
expert_offsets);
const std::string H20_device_type_str("NVIDIA H20");
bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
at::cuda::CUDAGuard device_guard{(char)a.get_device()};
cudaStream_t stream = at::cuda::getCurrentCUDAStream(a.get_device());
if (output.dtype() == torch::kBFloat16) {
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::bfloat16_t>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
layout_sfa,
layout_sfb,
lm_problem_sizes,
mm_problem_sizes,
hm_problem_sizes,
output,
a,
b,
scales_a,
scales_b,
problem_sizes,
expert_offsets,
is_h20_device,
stream);
} else if (output.dtype() == torch::kFloat16) {
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_pre_compute<cutlass::half_t>(
out_ptrs,
a_ptrs,
b_ptrs,
a_scales_ptrs,
b_scales_ptrs,
layout_sfa,
layout_sfb,
lm_problem_sizes,
mm_problem_sizes,
hm_problem_sizes,
output,
a,
b,
scales_a,
scales_b,
problem_sizes,
expert_offsets,
is_h20_device,
stream);
} else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}
if (output.dtype() == torch::kBFloat16) {
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::bfloat16_t>(
out_ptrs,
@@ -100,7 +134,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
layout_sfb,
lm_problem_sizes,
mm_problem_sizes,
hm_problem_sizes);
hm_problem_sizes,
is_h20_device,
stream);
} else if (output.dtype() == torch::kFloat16) {
expert_specialization::es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype<cutlass::half_t>(
out_ptrs,
@@ -115,7 +151,9 @@ void es_fp8_blockwise_scaled_grouped_mm(
layout_sfb,
lm_problem_sizes,
mm_problem_sizes,
hm_problem_sizes);
hm_problem_sizes,
is_h20_device,
stream);
} else {
TORCH_CHECK(false, "Invalid output type (must be float16 or bfloat16)");
}