Crash the CI jobs on model import errors (#2072)
This commit is contained in:
@@ -8,7 +8,7 @@ from torch import nn
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
-from sglang.srt.utils import is_flashinfer_available
|
||||
+from sglang.srt.utils import crash_on_warnings, is_flashinfer_available
|
||||
|
||||
if is_flashinfer_available():
|
||||
from flashinfer.sampling import (
|
||||
@@ -19,10 +19,6 @@ if is_flashinfer_available():
|
||||
)
|
||||
|
||||
|
||||
-# Crash on warning if we are running CI tests
|
||||
-crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -46,7 +42,8 @@ class Sampler(nn.Module):
|
||||
logits = torch.where(
|
||||
torch.isnan(logits), torch.full_like(logits, -1e5), logits
|
||||
)
|
||||
-exit(1) if crash_on_warning else None
|
||||
+if crash_on_warnings():
|
||||
+raise ValueError("Detected errors during sampling! NaN in the logits.")
|
||||
|
||||
if sampling_info.is_all_greedy:
|
||||
# Use torch.argmax if all requests use greedy sampling
|
||||
|
||||
Reference in New Issue
Block a user