Crash the server on warnings in CI (#1772)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
@@ -17,6 +18,11 @@ if is_flashinfer_available():
|
||||
top_p_renorm_prob,
|
||||
)
|
||||
|
||||
|
||||
# Crash on warning if we are running CI tests
|
||||
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -36,6 +42,7 @@ class Sampler(nn.Module):
|
||||
logits = logits.contiguous()
|
||||
|
||||
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
||||
exit(1) if crash_on_warning else None
|
||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||
logits = torch.where(
|
||||
torch.isnan(logits), torch.full_like(logits, -1e5), logits
|
||||
|
||||
Reference in New Issue
Block a user