[sgl-kernel][3/N] Support Expert Specialization Grouped GEMM (#11674)

This commit is contained in:
Qi Yuhang
2025-10-16 04:39:31 +08:00
committed by GitHub
parent f226d3da2a
commit 6c01844f45
7 changed files with 22 additions and 8 deletions

View File

@@ -40,7 +40,8 @@ void es_fp8_blockwise_scaled_grouped_mm(
const torch::Tensor& stride_b,
const torch::Tensor& stride_d,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets) {
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace) {
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
@@ -135,6 +136,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
lm_problem_sizes,
mm_problem_sizes,
hm_problem_sizes,
workspace,
is_h20_device,
stream);
} else if (output.dtype() == torch::kFloat16) {
@@ -152,6 +154,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
lm_problem_sizes,
mm_problem_sizes,
hm_problem_sizes,
workspace,
is_h20_device,
stream);
} else {