[optimize] fuse renormalize into moe_topk_softmax (#7744)

Co-authored-by: ispobock <ispobaoke@gmail.com>
This commit is contained in:
Yi Zhang
2025-07-04 03:42:44 +08:00
committed by GitHub
parent 6840a7bbb2
commit 2998c4bdf4
7 changed files with 254 additions and 101 deletions

View File

@@ -169,9 +169,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"pad_sorted_token_ids) -> ()");
m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
m.def(
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
"token_expert_indices, Tensor gating_output) -> ()");
m.def("topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor gating_output, bool renormalize) -> ()");
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
m.def(