Fix sampler nan check when calling top_k_top_p_sampling_from_probs (#5546)

This commit is contained in:
Yubo Wang
2025-04-19 21:47:23 -07:00
committed by GitHub
parent 613b197e57
commit 20f1c8e374
2 changed files with 8 additions and 9 deletions

View File

@@ -100,17 +100,16 @@ class Sampler(nn.Module):
probs, sampling_info.min_ps
)
else:
# Check Nan will throw exception, only check when crash_on_warnings is True
check_nan = self.use_nan_detection and crash_on_warnings()
batch_next_token_ids = top_k_top_p_sampling_from_probs(
probs,
sampling_info.top_ks,
sampling_info.top_ps,
filter_apply_order="joint",
check_nan=check_nan,
)
if self.use_nan_detection:
logger.warning("Detected errors during sampling!")
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
elif global_server_args_dict["sampling_backend"] == "pytorch":
# A slower fallback implementation with torch native operations.
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(