Fix the perf regression due to additional_stop_token_ids (#1773)

This commit is contained in:
Lianmin Zheng
2024-10-23 16:45:21 -07:00
committed by GitHub
parent 05b3bf5e8e
commit 8f8f96a621
5 changed files with 20 additions and 16 deletions

View File

@@ -42,11 +42,11 @@ class Sampler(nn.Module):
logits = logits.contiguous()
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
exit(1) if crash_on_warning else None
logger.warning("Detected errors during sampling! NaN in the logits.")
logits = torch.where(
torch.isnan(logits), torch.full_like(logits, -1e5), logits
)
exit(1) if crash_on_warning else None
if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling