[ROCm] Enable MTP (NextN) on AMD GPU (#4631)
This commit is contained in:
@@ -65,6 +65,18 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
|
||||
"token_expert_indices, Tensor gating_output) -> ()");
|
||||
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
||||
|
||||
m.def(
|
||||
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
||||
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
||||
"Tensor target_predict, int cuda_stream) -> ()");
|
||||
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
|
||||
|
||||
m.def(
|
||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
|
||||
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
|
||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
Reference in New Issue
Block a user