Bug: Fix min_p sampling crash when using flashinfer backend (#3207)

Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
This commit is contained in:
zifeitong
2025-02-02 15:36:07 -08:00
committed by GitHub
parent 566d61d90f
commit 28b0a62bb3

View File

@@ -85,7 +85,7 @@ class Sampler(nn.Module):
if sampling_info.need_min_p_sampling:
probs = top_k_renorm_prob(probs, sampling_info.top_ks)
probs = top_p_renorm_prob(probs, sampling_info.top_ps)
- batch_next_token_ids, success = min_p_sampling_from_probs(
+ batch_next_token_ids = min_p_sampling_from_probs(
probs, uniform_samples, sampling_info.min_ps
)
else: