[sgl-kernel] Support FlashInfer top_k_top_p_sampling_from_logits (#9060)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -345,15 +345,19 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
|
||||
m.def("top_p_renorm_probs(Tensor probs, Tensor! renorm_probs, Tensor? maybe_top_p_arr, float top_p_val) -> ()");
|
||||
m.impl("top_p_renorm_probs", torch::kCUDA, &top_p_renorm_probs);
|
||||
|
||||
m.def(
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_k_top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? maybe_top_k_arr, "
|
||||
"float top_k_val, Tensor? maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("top_k_top_p_sampling_from_probs", torch::kCUDA, &top_k_top_p_sampling_from_probs);
|
||||
|
||||
m.def(
|
||||
"top_p_sampling_from_probs(Tensor probs, Tensor output, Tensor? maybe_indices, Tensor? "
|
||||
"maybe_top_p_arr, float top_p_val, bool deterministic, Generator? gen) -> ()");
|
||||
m.impl("top_p_sampling_from_probs", torch::kCUDA, &top_p_sampling_from_probs);
|
||||
m.def("top_k_mask_logits(Tensor logits, Tensor mask_logits, Tensor? maybe_top_k_arr, int top_k_val) -> ()");
|
||||
m.impl("top_k_mask_logits", torch::kCUDA, &top_k_mask_logits);
|
||||
|
||||
m.def(
|
||||
"moe_wna16_marlin_gemm(Tensor! a, Tensor? c_or_none,"
|
||||
"Tensor! b_q_weight, Tensor! b_scales, Tensor? b_zeros_or_none,"
|
||||
|
||||
Reference in New Issue
Block a user