[2/2] Add python wrapper for CUTLASS FP8 Blockscale MoE Kernel. (#5694)

This commit is contained in:
Elfie Guo
2025-05-16 13:14:07 -07:00
committed by GitHub
parent 839fb31e5f
commit 6fc9357503
12 changed files with 896 additions and 41 deletions

View File

@@ -211,6 +211,11 @@ std::vector<at::Tensor> moe_fused_gate(
void fp8_blockwise_scaled_grouped_mm(
torch::Tensor& output,
torch::Tensor& a_ptrs,
torch::Tensor& b_ptrs,
torch::Tensor& out_ptrs,
torch::Tensor& a_scales_ptrs,
torch::Tensor& b_scales_ptrs,
const torch::Tensor& a,
const torch::Tensor& b,
const torch::Tensor& scales_a,
@@ -221,7 +226,19 @@ void fp8_blockwise_scaled_grouped_mm(
const torch::Tensor& layout_sfa,
const torch::Tensor& layout_sfb,
const torch::Tensor& problem_sizes,
const torch::Tensor& expert_offsets);
const torch::Tensor& expert_offsets,
const torch::Tensor& workspace);
void prepare_moe_input(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
/*
* From csrc/speculative