Fix the perf regression due to additional_stop_token_ids (#1773)
This commit is contained in:
@@ -42,11 +42,11 @@ class Sampler(nn.Module):
|
||||
logits = logits.contiguous()
|
||||
|
||||
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
||||
exit(1) if crash_on_warning else None
|
||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||
logits = torch.where(
|
||||
torch.isnan(logits), torch.full_like(logits, -1e5), logits
|
||||
)
|
||||
exit(1) if crash_on_warning else None
|
||||
|
||||
if sampling_info.is_all_greedy:
|
||||
# Use torch.argmax if all requests use greedy sampling
|
||||
|
||||
Reference in New Issue
Block a user