Crash the server on warnings in CI (#1772)

This commit is contained in:
Lianmin Zheng
2024-10-23 16:27:13 -07:00
committed by GitHub
parent 3f5ac88d02
commit 05b3bf5e8e
5 changed files with 22 additions and 6 deletions

View File

@@ -1,4 +1,5 @@
import logging
import os
from typing import Union
import torch
@@ -17,6 +18,11 @@ if is_flashinfer_available():
top_p_renorm_prob,
)
# Crash on warning if we are running CI tests
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
logger = logging.getLogger(__name__)
@@ -36,6 +42,7 @@ class Sampler(nn.Module):
logits = logits.contiguous()
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
exit(1) if crash_on_warning else None
logger.warning("Detected errors during sampling! NaN in the logits.")
logits = torch.where(
torch.isnan(logits), torch.full_like(logits, -1e5), logits

View File

@@ -116,7 +116,7 @@ class CudaGraphRunner:
if self.model_runner.server_args.disable_cuda_graph_padding:
self.capture_bs = list(range(1, 32)) + [64, 128]
else:
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
self.capture_bs = [
bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
]