[2/2] Add python wrapper for CUTLASS FP8 Blockscale MoE Kernel. (#5694)
This commit is contained in:
@@ -211,6 +211,11 @@ std::vector<at::Tensor> moe_fused_gate(
|
||||
|
||||
void fp8_blockwise_scaled_grouped_mm(
|
||||
torch::Tensor& output,
|
||||
torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs,
|
||||
torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs,
|
||||
torch::Tensor& b_scales_ptrs,
|
||||
const torch::Tensor& a,
|
||||
const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a,
|
||||
@@ -221,7 +226,19 @@ void fp8_blockwise_scaled_grouped_mm(
|
||||
const torch::Tensor& layout_sfa,
|
||||
const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes,
|
||||
const torch::Tensor& expert_offsets);
|
||||
const torch::Tensor& expert_offsets,
|
||||
const torch::Tensor& workspace);
|
||||
|
||||
void prepare_moe_input(
|
||||
const torch::Tensor& topk_ids,
|
||||
torch::Tensor& expert_offsets,
|
||||
torch::Tensor& problem_sizes1,
|
||||
torch::Tensor& problem_sizes2,
|
||||
torch::Tensor& input_permutation,
|
||||
torch::Tensor& output_permutation,
|
||||
const int64_t num_experts,
|
||||
const int64_t n,
|
||||
const int64_t k);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
|
||||
Reference in New Issue
Block a user