From 9b8333d99204ea8582d5132dd48084decaa68d25 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team <89890040+yiakwy-xpu-ml-framework-team@users.noreply.github.com> Date: Mon, 17 Mar 2025 09:16:55 +0800 Subject: [PATCH] [ROCm] enable moe topk softmax in amd (#4448) --- sgl-kernel/csrc/torch_extension_rocm.cc | 4 ++++ sgl-kernel/setup_rocm.py | 1 + 2 files changed, 5 insertions(+) diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc index ade9a6d44..d424ce6d6 100644 --- a/sgl-kernel/csrc/torch_extension_rocm.cc +++ b/sgl-kernel/csrc/torch_extension_rocm.cc @@ -61,6 +61,10 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); + m.def( + "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " + "token_expert_indices, Tensor gating_output) -> ()"); + m.impl("topk_softmax", torch::kCUDA, &topk_softmax); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index 55bc37266..a9cc5edca 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -41,6 +41,7 @@ include_dirs = [ sources = [ "csrc/allreduce/custom_all_reduce.hip", "csrc/moe/moe_align_kernel.cu", + "csrc/moe/moe_topk_softmax_kernels.cu", "csrc/torch_extension_rocm.cc", ]