Crash the CI jobs on model import errors (#2072)

This commit is contained in:
Lianmin Zheng
2024-11-17 22:18:11 -08:00
committed by GitHub
parent a7164b620f
commit df7fe4521a
5 changed files with 30 additions and 25 deletions

View File

@@ -8,7 +8,7 @@ from torch import nn
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import is_flashinfer_available
from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
if is_flashinfer_available():
from flashinfer.sampling import (
@@ -19,10 +19,6 @@ if is_flashinfer_available():
)
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
logger = logging.getLogger(__name__)
@@ -46,7 +42,8 @@ class Sampler(nn.Module):
logits = torch.where(
torch.isnan(logits), torch.full_like(logits, -1e5), logits
)
exit(1) if crash_on_warning else None
if crash_on_warnings():
raise ValueError("Detected errors during sampling! NaN in the logits.")
if sampling_info.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling