[1/2] Add Kernel support for Cutlass-based Fused FP4 MoE (#6093)
Signed-off-by: Pavani Majety <pmajety@nvidia.com>
This commit is contained in:
@@ -132,6 +132,20 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
" Tensor! output_scale, Tensor! input_scale) -> ()");
|
||||
m.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
|
||||
|
||||
// Compute NVFP4 experts quantization.
|
||||
m.def(
|
||||
"scaled_fp4_experts_quant(Tensor! output, Tensor! output_scale,"
|
||||
"Tensor input, Tensor input_global_scale, Tensor input_offset_by_experts,"
|
||||
"Tensor output_scale_offset_by_experts) -> ()");
|
||||
m.impl("scaled_fp4_experts_quant", torch::kCUDA, &scaled_fp4_experts_quant);
|
||||
|
||||
m.def(
|
||||
"cutlass_fp4_group_mm(Tensor! output, Tensor a, Tensor b,"
|
||||
"Tensor a_blockscale, Tensor b_blockscale, Tensor alphas,"
|
||||
"Tensor ab_strides, Tensor c_strides, Tensor problem_sizes,"
|
||||
" Tensor expert_offsets, Tensor sf_offsets) -> ()");
|
||||
m.impl("cutlass_fp4_group_mm", torch::kCUDA, &cutlass_fp4_group_mm);
|
||||
|
||||
/*
|
||||
* From csrc/moe
|
||||
*/
|
||||
@@ -161,9 +175,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
"expert_offsets, Tensor workspace) -> ()");
|
||||
m.impl("fp8_blockwise_scaled_grouped_mm", torch::kCUDA, &fp8_blockwise_scaled_grouped_mm);
|
||||
m.def(
|
||||
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor problem_sizes1, Tensor problem_sizes2, Tensor "
|
||||
"input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> ()");
|
||||
"prepare_moe_input(Tensor topk_ids, Tensor expert_offsets, Tensor? blockscale_offsets, Tensor problem_sizes1,"
|
||||
" Tensor problem_sizes2, Tensor input_permutation, Tensor output_permutation, int num_experts, int n, int k) -> "
|
||||
"()");
|
||||
m.impl("prepare_moe_input", torch::kCUDA, &prepare_moe_input);
|
||||
|
||||
m.def("shuffle_rows(Tensor input, Tensor dst2src_map, Tensor output) -> ()");
|
||||
m.impl("shuffle_rows", torch::kCUDA, &shuffle_rows);
|
||||
|
||||
/*
|
||||
* From csrc/speculative
|
||||
*/
|
||||
|
||||
Reference in New Issue
Block a user