[sgl-kernel][3/N] Support Expert Specialization Grouped GEMM (#11674)

This commit is contained in:
Qi Yuhang
2025-10-16 04:39:31 +08:00
committed by GitHub
parent f226d3da2a
commit 6c01844f45
7 changed files with 22 additions and 8 deletions

View File

@@ -98,6 +98,7 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& workspace,
cudaStream_t stream) {
using ElementA = typename GemmTraits::ElementA;
using StrideA = typename GemmTraits::StrideA;
@@ -143,10 +144,6 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
auto can_implement_status = gemm_op.can_implement(args);
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, "Failed to implement GEMM");
torch::TensorOptions options_uint8 = torch::TensorOptions().dtype(torch::kUInt8).device(out_ptrs.device());
size_t workspace_size = gemm_op.get_workspace_size(args);
torch::Tensor workspace = torch::empty(workspace_size, options_uint8);
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
@@ -169,6 +166,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
const torch::Tensor& lm_problem_sizes,
const torch::Tensor& mm_problem_sizes,
const torch::Tensor& hm_problem_sizes,
const torch::Tensor& workspace,
bool is_h20_device,
cudaStream_t stream) {
using LowMGemmH20Traits =
@@ -199,6 +197,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
layout_sfb,
layout_sfa,
lm_problem_sizes,
workspace,
stream);
} else {
launch_sm90_fp8_blockwise_scaled_group_mm<LowMGemmH20Traits>(
@@ -213,6 +212,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
layout_sfb,
layout_sfa,
lm_problem_sizes,
workspace,
stream);
}
@@ -229,6 +229,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
layout_sfb,
layout_sfa,
mm_problem_sizes,
workspace,
stream);
} else {
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmHx00Traits>(
@@ -243,6 +244,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
layout_sfa,
layout_sfb,
mm_problem_sizes,
workspace,
stream);
}
@@ -259,6 +261,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
layout_sfa,
layout_sfb,
hm_problem_sizes,
workspace,
stream);
} else {
launch_sm90_fp8_blockwise_scaled_group_mm<HighMGemmH20Traits>(
@@ -273,6 +276,7 @@ void es_sm90_fp8_blockwise_scaled_group_mm_distpatch_out_dtype(
layout_sfa,
layout_sfb,
hm_problem_sizes,
workspace,
stream);
}
}