[sgl-kernel][3/N]Support Expert Specialization Grouped GEMM (#11674)
This commit is contained in:
@@ -40,7 +40,8 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
||||
const torch::Tensor& stride_b,
|
||||
const torch::Tensor& stride_d,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets) {
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& workspace) {
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3, "problem_sizes must have shape (num_experts, 3)");
|
||||
@@ -135,6 +136,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
||||
lm_problem_sizes,
|
||||
mm_problem_sizes,
|
||||
hm_problem_sizes,
|
||||
workspace,
|
||||
is_h20_device,
|
||||
stream);
|
||||
} else if (output.dtype() == torch::kFloat16) {
|
||||
@@ -152,6 +154,7 @@ void es_fp8_blockwise_scaled_grouped_mm(
|
||||
lm_problem_sizes,
|
||||
mm_problem_sizes,
|
||||
hm_problem_sizes,
|
||||
workspace,
|
||||
is_h20_device,
|
||||
stream);
|
||||
} else {
|
||||
|
||||
Reference in New Issue
Block a user