[sgl-kernel] Support FlashInfer top_k_top_p_sampling_from_logits (#9060)

Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
Yuan Luo
2025-08-15 01:56:36 +08:00
committed by GitHub
parent 432f2053dd
commit 53dcc750b6
6 changed files with 349 additions and 5 deletions

View File

@@ -593,6 +593,10 @@ void top_p_sampling_from_probs(
double top_p_val,
bool deterministic,
std::optional<at::Generator> gen);
void top_k_mask_logits(
at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
torch::Tensor moe_wna16_marlin_gemm(
torch::Tensor& a,
std::optional<torch::Tensor> const& c_or_none,