[ROCm] enable moe topk softmax in amd (#4448)
Commit 9b8333d992 (parent f5bbf6037d), committed by GitHub.
@@ -61,6 +61,10 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
       "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
+  m.def(
+      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
+      "token_expert_indices, Tensor gating_output) -> ()");
+  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
 }

 REGISTER_EXTENSION(common_ops)
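The hunk above registers topk_softmax in the sgl_kernel torch library for the ROCm build, so it becomes callable as torch.ops.sgl_kernel.topk_softmax. Below is a minimal sketch of invoking it from Python; the shapes and dtypes follow the vLLM-style topk_softmax kernel and are assumptions here, since the schema only fixes the argument order and which tensors are mutated, and the import that loads the extension is likewise assumed.

import torch
import sgl_kernel  # assumed entry point that loads the common_ops extension

num_tokens, num_experts, topk = 4, 8, 2
# Per-token router logits over the experts.
gating_output = torch.randn(num_tokens, num_experts, device="cuda", dtype=torch.float32)

# Output buffers; "Tensor!" in the schema marks these as written in place.
topk_weights = torch.empty(num_tokens, topk, device="cuda", dtype=torch.float32)
topk_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)
token_expert_indices = torch.empty(num_tokens, topk, device="cuda", dtype=torch.int32)

# Softmax over gating_output per token, then top-k weights and expert indices.
# Returns nothing; results land in the three buffers above.
torch.ops.sgl_kernel.topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output)

On ROCm, HIP devices are exposed through torch's "cuda" device type, so the same call works unchanged on AMD hardware.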
@@ -41,6 +41,7 @@ include_dirs = [
 sources = [
     "csrc/allreduce/custom_all_reduce.hip",
     "csrc/moe/moe_align_kernel.cu",
+    "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/torch_extension_rocm.cc",
 ]
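This second hunk adds the kernel source to the ROCm build's sources list. Below is a minimal sketch of how such a list typically drives an extension build with torch.utils.cpp_extension; everything except the sources themselves (the package and extension names are taken from REGISTER_EXTENSION above) is an assumption, and the actual setup script containing this hunk may be organized differently.

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

sources = [
    "csrc/allreduce/custom_all_reduce.hip",
    "csrc/moe/moe_align_kernel.cu",
    "csrc/moe/moe_topk_softmax_kernels.cu",  # the file added by this commit
    "csrc/torch_extension_rocm.cc",
]

setup(
    name="sgl-kernel",
    # Against a ROCm build of PyTorch, CUDAExtension hipifies the .cu sources
    # automatically, so the same kernel files serve both NVIDIA and AMD.
    ext_modules=[CUDAExtension(name="common_ops", sources=sources)],
    cmdclass={"build_ext": BuildExtension},
)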