Bug: Fix min_p sampling crash when using flashinfer backend (#3207)
Co-authored-by: zhaochenyang20 <zhaochen20@outlook.com>
@@ -85,7 +85,7 @@ class Sampler(nn.Module):
                 if sampling_info.need_min_p_sampling:
                     probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                     probs = top_p_renorm_prob(probs, sampling_info.top_ps)
-                    batch_next_token_ids, success = min_p_sampling_from_probs(
+                    batch_next_token_ids = min_p_sampling_from_probs(
                         probs, uniform_samples, sampling_info.min_ps
                     )
                 else:
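
The first hunk drops the success unpacking: in flashinfer releases where min_p_sampling_from_probs returns only the sampled token ids, the old two-value assignment raised an unpacking error at this call site. For reference, min-p keeps only tokens whose probability is at least min_p times the top probability. Below is a minimal PyTorch sketch of that filtering step (a hypothetical min_p_filter helper, not the flashinfer kernel):

import torch

def min_p_filter(probs: torch.Tensor, min_ps: torch.Tensor) -> torch.Tensor:
    # probs:  (batch, vocab) probabilities after softmax
    # min_ps: (batch,) per-request min-p thresholds in [0, 1]
    top_prob = probs.max(dim=-1, keepdim=True).values
    # Keep only tokens with probability >= min_p * top probability
    keep = probs >= min_ps.unsqueeze(-1) * top_prob
    filtered = probs * keep
    # Renormalize the surviving mass; a sampler then draws from this
    return filtered / filtered.sum(dim=-1, keepdim=True)

The top_k_renorm_prob and top_p_renorm_prob calls in the hunk play the analogous filter-and-renormalize role for top-k and top-p before the min-p draw.
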
@@ -97,9 +97,9 @@ class Sampler(nn.Module):
                         filter_apply_order="joint",
                     )
 
-                if self.use_nan_detectioin and not torch.all(success):
-                    logger.warning("Detected errors during sampling!")
-                    batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
+                    if self.use_nan_detectioin and not torch.all(success):
+                        logger.warning("Detected errors during sampling!")
+                        batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
 
             elif global_server_args_dict["sampling_backend"] == "pytorch":
                 # A slower fallback implementation with torch native operations.
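
The second hunk re-indents the NaN guard into the else branch, since after this change only top_k_top_p_sampling_from_probs still returns a per-request success flag (use_nan_detectioin is spelled that way in the source). A minimal sketch of the guard's behavior, written as a hypothetical guard_sampling_output helper using the names from the diff:

import logging

import torch

logger = logging.getLogger(__name__)

def guard_sampling_output(
    batch_next_token_ids: torch.Tensor, success: torch.Tensor
) -> torch.Tensor:
    # success holds one flag per request from the sampling kernel; a False
    # entry means that request's draw failed (e.g. NaN probabilities).
    if not torch.all(success):
        logger.warning("Detected errors during sampling!")
        # Defensive fallback, as in the diff: zero out the token ids
        # instead of propagating potentially invalid samples.
        batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
    return batch_next_token_ids

Zeroing the whole batch is deliberately blunt: it trades a few wasted tokens for never emitting ids sampled from NaN-contaminated distributions.
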