diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc index ade9a6d44..d424ce6d6 100644 --- a/sgl-kernel/csrc/torch_extension_rocm.cc +++ b/sgl-kernel/csrc/torch_extension_rocm.cc @@ -61,6 +61,10 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + m.def( + "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " + "token_expert_indices, Tensor gating_output) -> ()"); + m.impl("topk_softmax", torch::kCUDA, &topk_softmax); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index 55bc37266..a9cc5edca 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -41,6 +41,7 @@ include_dirs = [ sources = [ "csrc/allreduce/custom_all_reduce.hip", "csrc/moe/moe_align_kernel.cu", + "csrc/moe/moe_topk_softmax_kernels.cu", "csrc/torch_extension_rocm.cc", ]