Fix sampler nan check when calling top_k_top_p_sampling_from_probs (#5546)
This commit is contained in:
@@ -100,17 +100,16 @@ class Sampler(nn.Module):
|
||||
probs, sampling_info.min_ps
|
||||
)
|
||||
else:
|
||||
# Check Nan will throw exception, only check when crash_on_warnings is True
|
||||
check_nan = self.use_nan_detection and crash_on_warnings()
|
||||
batch_next_token_ids = top_k_top_p_sampling_from_probs(
|
||||
probs,
|
||||
sampling_info.top_ks,
|
||||
sampling_info.top_ps,
|
||||
filter_apply_order="joint",
|
||||
check_nan=check_nan,
|
||||
)
|
||||
|
||||
if self.use_nan_detection:
|
||||
logger.warning("Detected errors during sampling!")
|
||||
batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
|
||||
|
||||
elif global_server_args_dict["sampling_backend"] == "pytorch":
|
||||
# A slower fallback implementation with torch native operations.
|
||||
batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
|
||||
|
||||
Reference in New Issue
Block a user