[2/2] Add python wrapper for CUTLASS FP8 Blockscale MoE Kernel. (#5694)
This commit is contained in:
@@ -151,11 +151,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"(Tensor[])");
|
||||
m.impl("moe_fused_gate", torch::kCUDA, &moe_fused_gate);
|
||||
m.def(
|
||||
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
|
||||
"fp8_blockwise_scaled_grouped_mm(Tensor output, Tensor a_ptrs, Tensor b_ptrs, Tensor out_ptrs, Tensor "
|
||||
"a_scales_ptrs, Tensor b_scales_ptrs, Tensor a, Tensor b, Tensor scales_a, Tensor scales_b, Tensor "
|
||||
"stride_a, Tensor stride_b, Tensor stride_c, Tensor layout_sfa, Tensor layout_sfb, Tensor problem_sizes, Tensor "
|
||||
"expert_offsets) -> ()");
|
||||
"expert_offsets, Tensor workspace) -> ()");
|
||||
m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);
|
||||
|
||||
m.def(
|
||||
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor problem_sizes1, Tensor problem_sizes2, Tensor "
|
||||
"input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> ()");
|
||||
m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input);
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user