Crash the server on warnings in CI (#1772)
This commit is contained in:
@@ -1,4 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -17,6 +18,11 @@ if is_flashinfer_available():
|
|||||||
top_p_renorm_prob,
|
top_p_renorm_prob,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Crash on warning if we are running CI tests
|
||||||
|
crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -36,6 +42,7 @@ class Sampler(nn.Module):
|
|||||||
logits = logits.contiguous()
|
logits = logits.contiguous()
|
||||||
|
|
||||||
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
if self.use_nan_detectioin and torch.any(torch.isnan(logits)):
|
||||||
|
exit(1) if crash_on_warning else None
|
||||||
logger.warning("Detected errors during sampling! NaN in the logits.")
|
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||||
logits = torch.where(
|
logits = torch.where(
|
||||||
torch.isnan(logits), torch.full_like(logits, -1e5), logits
|
torch.isnan(logits), torch.full_like(logits, -1e5), logits
|
||||||
|
|||||||
@@ -116,7 +116,7 @@ class CudaGraphRunner:
|
|||||||
if self.model_runner.server_args.disable_cuda_graph_padding:
|
if self.model_runner.server_args.disable_cuda_graph_padding:
|
||||||
self.capture_bs = list(range(1, 32)) + [64, 128]
|
self.capture_bs = list(range(1, 32)) + [64, 128]
|
||||||
else:
|
else:
|
||||||
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
|
||||||
self.capture_bs = [
|
self.capture_bs = [
|
||||||
bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
|
bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,3 +1,8 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
python -m unittest test_eval_accuracy_large.TestEvalAccuracyLarge.test_mmlu
|
||||||
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
@@ -32,12 +37,12 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
|||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
eval_name="mmlu",
|
eval_name="mmlu",
|
||||||
num_examples=3000,
|
num_examples=5000,
|
||||||
num_threads=1024,
|
num_threads=1024,
|
||||||
)
|
)
|
||||||
|
|
||||||
metrics = run_eval(args)
|
metrics = run_eval(args)
|
||||||
assert metrics["score"] >= 0.705, f"{metrics}"
|
assert metrics["score"] >= 0.71, f"{metrics}"
|
||||||
|
|
||||||
def test_human_eval(self):
|
def test_human_eval(self):
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
|
|||||||
@@ -1,3 +1,8 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
python -m unittest test_moe_eval_accuracy_large.TestMoEEvalAccuracyLarge.test_mmlu
|
||||||
|
"""
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
@@ -11,7 +16,7 @@ from sglang.test.test_utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestEvalAccuracyLarge(unittest.TestCase):
|
class TestMoEEvalAccuracyLarge(unittest.TestCase):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
cls.model = DEFAULT_MOE_MODEL_NAME_FOR_TEST
|
cls.model = DEFAULT_MOE_MODEL_NAME_FOR_TEST
|
||||||
@@ -37,7 +42,7 @@ class TestEvalAccuracyLarge(unittest.TestCase):
|
|||||||
base_url=self.base_url,
|
base_url=self.base_url,
|
||||||
model=self.model,
|
model=self.model,
|
||||||
eval_name="mmlu",
|
eval_name="mmlu",
|
||||||
num_examples=3000,
|
num_examples=5000,
|
||||||
num_threads=1024,
|
num_threads=1024,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import json
|
|
||||||
import unittest
|
import unittest
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user