[2/2] Add Python wrapper for CUTLASS FP8 Blockscale MoE Kernel. (#5694)

This commit is contained in:
Elfie Guo
2025-05-16 13:14:07 -07:00
committed by GitHub
parent 839fb31e5f
commit 6fc9357503
12 changed files with 896 additions and 41 deletions

View File

@@ -151,11 +151,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"(Tensor[])");
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
m.def(
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
"a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
"stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor "
"expert_offsets) -> ()");
"expert_offsets, Tensor workspace) -> ()");
m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);
// Declare the schema of the `prepare_moe_input` op: six Tensor arguments
// followed by three ints (num_experts, n, k); the schema return type is `()`.
// NOTE(review): since nothing is returned, results are presumably written into
// some of the tensor arguments, yet none carries a mutation annotation such as
// `Tensor(a!)` — confirm against the kernel implementation.
m.def(
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor problem_sizes1, Tensor problem_sizes2, Tensor "
"input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> ()");
// Register the function `prepare_moe_input` as the implementation used for the
// CUDA dispatch key.
m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input);
/*
* From csrc/speculative
*/