[ROCm] enable moe topk softmax in amd (#4448)
commit 9b8333d992 (parent f5bbf6037d), committed by GitHub
@@ -61,6 +61,10 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
       "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
   m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
+  m.def(
+      "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
+      "token_expert_indices, Tensor gating_output) -> ()");
+  m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
 }

 REGISTER_EXTENSION(common_ops)
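For context, the new schema registers topk_softmax as a mutating op: the three Tensor! arguments are output buffers written in place, and gating_output carries the router logits. Below is a minimal usage sketch from Python, assuming the extension loads under torch.ops.sgl_kernel and the common [num_tokens, topk] MoE output layout; the shapes, dtypes, and sizes here are illustrative assumptions, not confirmed by this diff.

import torch
# The compiled common_ops extension must be loaded first
# (e.g. by importing its Python package) so the op is registered.

num_tokens, num_experts, topk = 16, 8, 2
gating_output = torch.randn(num_tokens, num_experts, device="cuda")  # router logits

# Output buffers are passed in and written in place (Tensor! in the schema).
topk_weights = torch.empty(num_tokens, topk, dtype=torch.float32, device="cuda")
topk_indices = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")
token_expert_indices = torch.empty(num_tokens, topk, dtype=torch.int32, device="cuda")

torch.ops.sgl_kernel.topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output)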
@@ -41,6 +41,7 @@ include_dirs = [
 sources = [
     "csrc/allreduce/custom_all_reduce.hip",
     "csrc/moe/moe_align_kernel.cu",
+    "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/torch_extension_rocm.cc",
 ]
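The build-side change simply adds the CUDA kernel file to the ROCm sources list; PyTorch's extension tooling hipifies .cu sources when building against a ROCm install of torch. A minimal sketch of how such a sources list could feed a torch.utils.cpp_extension build follows; the project's actual setup file is not shown in this diff, so the setup() structure and flags here are assumptions.

# Hypothetical build sketch: on a ROCm PyTorch install, CUDAExtension
# hipifies .cu sources automatically, so the new kernel file is listed
# alongside the existing ones. Paths mirror the diff; real flags omitted.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="common_ops",
    ext_modules=[
        CUDAExtension(
            name="common_ops",
            sources=[
                "csrc/allreduce/custom_all_reduce.hip",
                "csrc/moe/moe_align_kernel.cu",
                "csrc/moe/moe_topk_softmax_kernels.cu",
                "csrc/torch_extension_rocm.cc",
            ],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)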