[sgl-kernel] Support FlashInfer top_k_top_p_sampling_from_logits (#9060)
Co-authored-by: luoyuan.luo <luoyuan.luo@antgroup.com>
This commit is contained in:
@@ -593,6 +593,10 @@ void top_p_sampling_from_probs(
|
||||
double top_p_val,
|
||||
bool deterministic,
|
||||
std::optional<at::Generator> gen);
|
||||
|
||||
void top_k_mask_logits(
|
||||
at::Tensor logits, at::Tensor mask_logits, std::optional<at::Tensor> maybe_top_k_arr, int64_t top_k_val);
|
||||
|
||||
torch::Tensor moe_wna16_marlin_gemm(
|
||||
torch::Tensor& a,
|
||||
std::optional<torch::Tensor> const& c_or_none,
|
||||
|
||||
Reference in New Issue
Block a user