[1/n]: add cutlass W4A8 moe kernel for hopper architecture (#7772)

Signed-off-by: yangsijia.614 <yangsijia.614@bytedance.com>
Co-authored-by: yicwang <yichen.wang@bytedance.com>
This commit is contained in:
SijiaYang
2025-07-05 11:50:12 +08:00
committed by GitHub
parent cb432f1770
commit da3890e82a
16 changed files with 3602 additions and 0 deletions

View File

@@ -467,6 +467,35 @@ void transfer_kv_all_layer_mla_direct(
int64_t page_size,
int64_t num_layers);
/*
* From csrc/moe/cutlass_moe/w4a8
*/
void get_cutlass_w4a8_moe_mm_data(
const torch::Tensor& topk_ids,
torch::Tensor& expert_offsets,
torch::Tensor& problem_sizes1,
torch::Tensor& problem_sizes2,
torch::Tensor& input_permutation,
torch::Tensor& output_permutation,
const int64_t num_experts,
const int64_t n,
const int64_t k);
void cutlass_w4a8_moe_mm(
torch::Tensor& d_tensors,
torch::Tensor const& a_tensors,
torch::Tensor const& b_tensors,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& expert_offsets,
torch::Tensor const& problem_sizes,
torch::Tensor const& a_strides,
torch::Tensor const& b_strides,
torch::Tensor const& d_strides,
torch::Tensor const& s_strides,
int64_t chunk_size,
int64_t topk);
/*
* From FlashInfer
*/