diff --git a/.github/workflows/execute-notebook.yml b/.github/workflows/execute-notebook.yml
index 9f3ca4181..affed72bf 100644
--- a/.github/workflows/execute-notebook.yml
+++ b/.github/workflows/execute-notebook.yml
@@ -33,7 +33,7 @@ jobs:
pip install -r docs/requirements.txt
apt-get update
apt-get install -y pandoc
- apt-get update && apt-get install -y parallel
+ apt-get update && apt-get install -y parallel retry
- name: Setup Jupyter Kernel
run: |
diff --git a/docs/Makefile b/docs/Makefile
index 9f4052639..0d6cad10e 100644
--- a/docs/Makefile
+++ b/docs/Makefile
@@ -23,7 +23,8 @@ compile:
parallel -0 -j3 --halt soon,fail=1 ' \
NB_NAME=$$(basename {}); \
START_TIME=$$(date +%s); \
- jupyter nbconvert --to notebook --execute --inplace "{}" \
+ retry --delay=0 --times=3 -- \
+ jupyter nbconvert --to notebook --execute --inplace "{}" \
--ExecutePreprocessor.timeout=600 \
--ExecutePreprocessor.kernel_name=python3; \
RET_CODE=$$?; \
diff --git a/python/sglang/test/test_activation.py b/python/sglang/test/test_activation.py
index 357a23319..38366e92b 100644
--- a/python/sglang/test/test_activation.py
+++ b/python/sglang/test/test_activation.py
@@ -4,9 +4,10 @@ import unittest
import torch
from sglang.srt.layers.activation import GeluAndMul
+from sglang.test.test_utils import CustomTestCase
-class TestGeluAndMul(unittest.TestCase):
+class TestGeluAndMul(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
diff --git a/python/sglang/test/test_block_fp8.py b/python/sglang/test/test_block_fp8.py
index 25aaf498a..a7c068ac2 100644
--- a/python/sglang/test/test_block_fp8.py
+++ b/python/sglang/test/test_block_fp8.py
@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
static_quant_fp8,
w8a8_block_fp8_matmul,
)
+from sglang.test.test_utils import CustomTestCase
_is_cuda = torch.cuda.is_available() and torch.version.cuda
@@ -44,7 +45,7 @@ def native_per_token_group_quant_fp8(
return x_q, x_s
-class TestPerTokenGroupQuantFP8(unittest.TestCase):
+class TestPerTokenGroupQuantFP8(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
@@ -111,7 +112,7 @@ def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
return x_q, x_s
-class TestStaticQuantFP8(unittest.TestCase):
+class TestStaticQuantFP8(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
@@ -210,7 +211,7 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
return C
-class TestW8A8BlockFP8Matmul(unittest.TestCase):
+class TestW8A8BlockFP8Matmul(CustomTestCase):
if not _is_cuda:
OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
@@ -331,7 +332,7 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
).sum(dim=1)
-class TestW8A8BlockFP8FusedMoE(unittest.TestCase):
+class TestW8A8BlockFP8FusedMoE(CustomTestCase):
DTYPES = [torch.float32, torch.half, torch.bfloat16]
M = [1, 33, 64, 222, 1024 * 128]
N = [128, 1024, 2048]
diff --git a/python/sglang/test/test_block_fp8_ep.py b/python/sglang/test/test_block_fp8_ep.py
index c077d0c45..ad8a1694d 100644
--- a/python/sglang/test/test_block_fp8_ep.py
+++ b/python/sglang/test/test_block_fp8_ep.py
@@ -13,6 +13,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.topk import select_experts
+from sglang.test.test_utils import CustomTestCase
# For test
@@ -232,7 +233,7 @@ def block_dequant(
return x_dq_block
-class TestW8A8BlockFP8EPMoE(unittest.TestCase):
+class TestW8A8BlockFP8EPMoE(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
M = [1, 222, 1024, 2048]
N = [128, 1024, 2048]
diff --git a/python/sglang/test/test_dynamic_grad_mode.py b/python/sglang/test/test_dynamic_grad_mode.py
index c0287ec3d..078c3139d 100644
--- a/python/sglang/test/test_dynamic_grad_mode.py
+++ b/python/sglang/test/test_dynamic_grad_mode.py
@@ -3,9 +3,10 @@ import unittest
import torch
from sglang.srt.utils import DynamicGradMode
+from sglang.test.test_utils import CustomTestCase
-class TestDynamicGradMode(unittest.TestCase):
+class TestDynamicGradMode(CustomTestCase):
def test_inference(self):
# Test inference_mode
DynamicGradMode.set_inference_mode(True)
diff --git a/python/sglang/test/test_layernorm.py b/python/sglang/test/test_layernorm.py
index 770e69733..05b6593eb 100644
--- a/python/sglang/test/test_layernorm.py
+++ b/python/sglang/test/test_layernorm.py
@@ -4,9 +4,10 @@ import unittest
import torch
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+from sglang.test.test_utils import CustomTestCase
-class TestRMSNorm(unittest.TestCase):
+class TestRMSNorm(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
NUM_TOKENS = [7, 83, 4096]
HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
@@ -56,7 +57,7 @@ class TestRMSNorm(unittest.TestCase):
self._run_rms_norm_test(*params)
-class TestGemmaRMSNorm(unittest.TestCase):
+class TestGemmaRMSNorm(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
NUM_TOKENS = [7, 83, 4096]
HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 87426729a..f13986e61 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -8,6 +8,7 @@ import random
import subprocess
import threading
import time
+import traceback
import unittest
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
@@ -998,3 +999,30 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
rank += 1
else:
raise
+
+
+class CustomTestCase(unittest.TestCase):
+ def _callTestMethod(self, method):
+ _retry_execution(
+ lambda: super(CustomTestCase, self)._callTestMethod(method),
+ max_retry=_get_max_retry(),
+ )
+
+
+def _get_max_retry():
+ return int(os.environ.get("SGLANG_TEST_MAX_RETRY", "2" if is_in_ci() else "0"))
+
+
+def _retry_execution(fn, max_retry: int):
+ if max_retry == 0:
+ fn()
+ return
+
+ try:
+ fn()
+    except Exception as e:
+        print(
+            f"retry_execution failed ({max_retry - 1} retries remaining after this one). This may be a real error or a flaky test. Error: {e}"
+        )
+        traceback.print_exc()
+        _retry_execution(fn, max_retry=max_retry - 1)
diff --git a/test/lang/test_anthropic_backend.py b/test/lang/test_anthropic_backend.py
index 03911449d..dc8f0b17e 100644
--- a/test/lang/test_anthropic_backend.py
+++ b/test/lang/test_anthropic_backend.py
@@ -3,9 +3,10 @@ import unittest
from sglang import Anthropic, set_default_backend
from sglang.test.test_programs import test_mt_bench, test_stream
+from sglang.test.test_utils import CustomTestCase
-class TestAnthropicBackend(unittest.TestCase):
+class TestAnthropicBackend(CustomTestCase):
backend = None
@classmethod
diff --git a/test/lang/test_bind_cache.py b/test/lang/test_bind_cache.py
index 5ed68ff45..d5beefae6 100644
--- a/test/lang/test_bind_cache.py
+++ b/test/lang/test_bind_cache.py
@@ -1,10 +1,10 @@
import unittest
import sglang as sgl
-from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase
-class TestBind(unittest.TestCase):
+class TestBind(CustomTestCase):
backend = None
@classmethod
diff --git a/test/lang/test_choices.py b/test/lang/test_choices.py
index 88cd22dfb..89e7ca1c7 100644
--- a/test/lang/test_choices.py
+++ b/test/lang/test_choices.py
@@ -7,6 +7,7 @@ from sglang.lang.choices import (
token_length_normalized,
unconditional_likelihood_normalized,
)
+from sglang.test.test_utils import CustomTestCase
MOCK_CHOICES_INPUT_DATA = {
"choices": [
@@ -51,7 +52,7 @@ MOCK_CHOICES_INPUT_DATA = {
}
-class TestChoices(unittest.TestCase):
+class TestChoices(CustomTestCase):
def test_token_length_normalized(self):
"""Confirm 'antidisestablishmentarianism' is selected due to high confidences for
diff --git a/test/lang/test_litellm_backend.py b/test/lang/test_litellm_backend.py
index 649e2e4d3..74c3a187a 100644
--- a/test/lang/test_litellm_backend.py
+++ b/test/lang/test_litellm_backend.py
@@ -3,9 +3,10 @@ import unittest
from sglang import LiteLLM, set_default_backend
from sglang.test.test_programs import test_mt_bench, test_stream
+from sglang.test.test_utils import CustomTestCase
-class TestAnthropicBackend(unittest.TestCase):
+class TestLiteLLMBackend(CustomTestCase):
chat_backend = None
@classmethod
diff --git a/test/lang/test_openai_backend.py b/test/lang/test_openai_backend.py
index 220784ab3..dbef7d450 100644
--- a/test/lang/test_openai_backend.py
+++ b/test/lang/test_openai_backend.py
@@ -17,9 +17,10 @@ from sglang.test.test_programs import (
test_stream,
test_tool_use,
)
+from sglang.test.test_utils import CustomTestCase
-class TestOpenAIBackend(unittest.TestCase):
+class TestOpenAIBackend(CustomTestCase):
instruct_backend = None
chat_backend = None
chat_vision_backend = None
diff --git a/test/lang/test_srt_backend.py b/test/lang/test_srt_backend.py
index 29f7a12a2..0e05eb906 100644
--- a/test/lang/test_srt_backend.py
+++ b/test/lang/test_srt_backend.py
@@ -22,10 +22,10 @@ from sglang.test.test_programs import (
test_stream,
test_tool_use,
)
-from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST
+from sglang.test.test_utils import DEFAULT_MODEL_NAME_FOR_TEST, CustomTestCase
-class TestSRTBackend(unittest.TestCase):
+class TestSRTBackend(CustomTestCase):
backend = None
@classmethod
diff --git a/test/lang/test_tracing.py b/test/lang/test_tracing.py
index 7c3af071b..3f02ac52b 100644
--- a/test/lang/test_tracing.py
+++ b/test/lang/test_tracing.py
@@ -3,9 +3,10 @@ import unittest
import sglang as sgl
from sglang.lang.backend.base_backend import BaseBackend
from sglang.lang.chat_template import get_chat_template
+from sglang.test.test_utils import CustomTestCase
-class TestTracing(unittest.TestCase):
+class TestTracing(CustomTestCase):
def test_few_shot_qa(self):
@sgl.function
def few_shot_qa(s, question):
diff --git a/test/lang/test_vertexai_backend.py b/test/lang/test_vertexai_backend.py
index da229854e..83ce7fc0b 100644
--- a/test/lang/test_vertexai_backend.py
+++ b/test/lang/test_vertexai_backend.py
@@ -10,9 +10,10 @@ from sglang.test.test_programs import (
test_parallel_encoding,
test_stream,
)
+from sglang.test.test_utils import CustomTestCase
-class TestVertexAIBackend(unittest.TestCase):
+class TestVertexAIBackend(CustomTestCase):
backend = None
@classmethod
diff --git a/test/srt/models/lora/test_lora.py b/test/srt/models/lora/test_lora.py
index 042038efe..16a7f0d48 100644
--- a/test/srt/models/lora/test_lora.py
+++ b/test/srt/models/lora/test_lora.py
@@ -18,6 +18,7 @@ import unittest
import torch
from sglang.test.runners import HFRunner, SRTRunner
+from sglang.test.test_utils import CustomTestCase
LORA_SETS = [
# {
@@ -70,7 +71,7 @@ What do you know about llamas?
# PROMPTS.append(sample[0]["content"][:2000])
-class TestLoRA(unittest.TestCase):
+class TestLoRA(CustomTestCase):
def inference(self, prompts, lora_set, tp_size, torch_dtype, max_new_tokens):
print("=================== testing inference =======================")
diff --git a/test/srt/models/lora/test_lora_backend.py b/test/srt/models/lora/test_lora_backend.py
index 4c7fe8074..79d7dbebd 100644
--- a/test/srt/models/lora/test_lora_backend.py
+++ b/test/srt/models/lora/test_lora_backend.py
@@ -21,7 +21,7 @@ import torch
from utils import BACKENDS, TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
from sglang.test.runners import HFRunner, SRTRunner
-from sglang.test.test_utils import calculate_rouge_l, is_in_ci
+from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
CI_LORA_MODELS = [
LoRAModelCase(
@@ -67,7 +67,7 @@ PROMPTS = [
]
-class TestLoRABackend(unittest.TestCase):
+class TestLoRABackend(CustomTestCase):
def run_backend(
self,
prompt: str,
diff --git a/test/srt/models/lora/test_lora_tp.py b/test/srt/models/lora/test_lora_tp.py
index 2da0d9e02..925a34bc3 100644
--- a/test/srt/models/lora/test_lora_tp.py
+++ b/test/srt/models/lora/test_lora_tp.py
@@ -21,7 +21,7 @@ import torch
from utils import TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
from sglang.test.runners import HFRunner, SRTRunner
-from sglang.test.test_utils import calculate_rouge_l, is_in_ci
+from sglang.test.test_utils import CustomTestCase, calculate_rouge_l, is_in_ci
CI_LORA_MODELS = [
LoRAModelCase(
@@ -69,7 +69,7 @@ PROMPTS = [
BACKEND = "triton"
-class TestLoRATP(unittest.TestCase):
+class TestLoRATP(CustomTestCase):
def run_tp(
self,
prompt: str,
diff --git a/test/srt/models/lora/test_multi_lora_backend.py b/test/srt/models/lora/test_multi_lora_backend.py
index 8d0047df3..7fca18a8d 100644
--- a/test/srt/models/lora/test_multi_lora_backend.py
+++ b/test/srt/models/lora/test_multi_lora_backend.py
@@ -19,7 +19,7 @@ from typing import List
import torch
from utils import BACKENDS, TORCH_DTYPES, LoRAAdaptor, LoRAModelCase
-from sglang.test.test_utils import is_in_ci
+from sglang.test.test_utils import CustomTestCase, is_in_ci
MULTI_LORA_MODELS = [
LoRAModelCase(
@@ -51,7 +51,7 @@ PROMPTS = [
]
-class TestMultiLoRABackend(unittest.TestCase):
+class TestMultiLoRABackend(CustomTestCase):
def run_backend_batch(
self,
prompts: List[str],
diff --git a/test/srt/models/test_embedding_models.py b/test/srt/models/test_embedding_models.py
index d93a65859..517ee831f 100644
--- a/test/srt/models/test_embedding_models.py
+++ b/test/srt/models/test_embedding_models.py
@@ -20,7 +20,7 @@ import torch
from transformers import AutoConfig, AutoTokenizer
from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
-from sglang.test.test_utils import get_similarities, is_in_ci
+from sglang.test.test_utils import CustomTestCase, get_similarities, is_in_ci
MODELS = [
("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, 1e-5),
@@ -31,7 +31,7 @@ MODELS = [
TORCH_DTYPES = [torch.float16]
-class TestEmbeddingModels(unittest.TestCase):
+class TestEmbeddingModels(CustomTestCase):
@classmethod
def setUpClass(cls):
diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py
index b486c8a96..237af60cd 100644
--- a/test/srt/models/test_generation_models.py
+++ b/test/srt/models/test_generation_models.py
@@ -33,7 +33,7 @@ from sglang.test.runners import (
SRTRunner,
check_close_model_outputs,
)
-from sglang.test.test_utils import is_in_ci
+from sglang.test.test_utils import CustomTestCase, is_in_ci
@dataclasses.dataclass
@@ -71,7 +71,7 @@ ALL_OTHER_MODELS = [
TORCH_DTYPES = [torch.float16]
-class TestGenerationModels(unittest.TestCase):
+class TestGenerationModels(CustomTestCase):
@classmethod
def setUpClass(cls):
diff --git a/test/srt/models/test_gme_qwen_models.py b/test/srt/models/test_gme_qwen_models.py
index 82f56adb3..265bf8a39 100644
--- a/test/srt/models/test_gme_qwen_models.py
+++ b/test/srt/models/test_gme_qwen_models.py
@@ -19,7 +19,7 @@ import unittest
import torch
from sglang.test.runners import HFRunner, SRTRunner
-from sglang.test.test_utils import get_similarities
+from sglang.test.test_utils import CustomTestCase, get_similarities
TEXTS = "two Subway Series sandwiches with meats, cheese, lettuce, tomatoes, and onions on a black background, accompanied by the Subway Series logo, highlighting a new sandwich series."
IMAGES = "https://huggingface.co/datasets/liuhaotian/llava-bench-in-the-wild/resolve/main/images/023.jpg"
@@ -31,7 +31,7 @@ MODELS = [
TORCH_DTYPES = [torch.float16]
-class TestQmeQwenModels(unittest.TestCase):
+class TestGmeQwenModels(CustomTestCase):
@classmethod
def setUpClass(cls):
mp.set_start_method("spawn", force=True)
diff --git a/test/srt/models/test_grok_models.py b/test/srt/models/test_grok_models.py
index 814f09a00..625fa1a65 100644
--- a/test/srt/models/test_grok_models.py
+++ b/test/srt/models/test_grok_models.py
@@ -6,11 +6,12 @@ from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestGrok(unittest.TestCase):
+class TestGrok(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmzheng/grok-1"
diff --git a/test/srt/models/test_qwen_models.py b/test/srt/models/test_qwen_models.py
index d2a418ccf..567e19f08 100644
--- a/test/srt/models/test_qwen_models.py
+++ b/test/srt/models/test_qwen_models.py
@@ -6,11 +6,12 @@ from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestQwen2(unittest.TestCase):
+class TestQwen2(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen2-7B-Instruct"
@@ -41,7 +42,7 @@ class TestQwen2(unittest.TestCase):
self.assertGreater(metrics["accuracy"], 0.78)
-class TestQwen2FP8(unittest.TestCase):
+class TestQwen2FP8(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "neuralmagic/Qwen2-7B-Instruct-FP8"
diff --git a/test/srt/models/test_reward_models.py b/test/srt/models/test_reward_models.py
index 69ad56367..5592ce223 100644
--- a/test/srt/models/test_reward_models.py
+++ b/test/srt/models/test_reward_models.py
@@ -18,6 +18,7 @@ import unittest
import torch
from sglang.test.runners import HFRunner, SRTRunner
+from sglang.test.test_utils import CustomTestCase
MODELS = [
("LxzGordon/URM-LLaMa-3.1-8B", 1, 4e-2),
@@ -41,7 +42,7 @@ CONVS = [
]
-class TestRewardModels(unittest.TestCase):
+class TestRewardModels(CustomTestCase):
@classmethod
def setUpClass(cls):
diff --git a/test/srt/test_abort.py b/test/srt/test_abort.py
index ae27d83a8..d2ab4d034 100644
--- a/test/srt/test_abort.py
+++ b/test/srt/test_abort.py
@@ -5,10 +5,10 @@ from concurrent.futures import ThreadPoolExecutor
import requests
-from sglang.test.test_utils import run_and_check_memory_leak
+from sglang.test.test_utils import CustomTestCase, run_and_check_memory_leak
-class TestAbort(unittest.TestCase):
+class TestAbort(CustomTestCase):
def workload_func(self, base_url, model):
def process_func():
def run_one(_):
diff --git a/test/srt/test_awq.py b/test/srt/test_awq.py
index 30493634d..1b461194a 100644
--- a/test/srt/test_awq.py
+++ b/test/srt/test_awq.py
@@ -7,11 +7,12 @@ from sglang.test.test_utils import (
DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestAWQ(unittest.TestCase):
+class TestAWQ(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_AWQ_MOE_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_bench_one_batch.py b/test/srt/test_bench_one_batch.py
index 65c894b57..9973b1fa9 100644
--- a/test/srt/test_bench_one_batch.py
+++ b/test/srt/test_bench_one_batch.py
@@ -3,6 +3,7 @@ import unittest
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
+ CustomTestCase,
get_bool_env_var,
is_in_ci,
run_bench_one_batch,
@@ -10,7 +11,7 @@ from sglang.test.test_utils import (
)
-class TestBenchOneBatch(unittest.TestCase):
+class TestBenchOneBatch(CustomTestCase):
def test_bs1(self):
output_throughput = run_bench_one_batch(
DEFAULT_MODEL_NAME_FOR_TEST, ["--cuda-graph-max-bs", "2"]
diff --git a/test/srt/test_bench_serving.py b/test/srt/test_bench_serving.py
index f8b4b1f9a..811b5d739 100644
--- a/test/srt/test_bench_serving.py
+++ b/test/srt/test_bench_serving.py
@@ -6,13 +6,14 @@ from sglang.test.test_utils import (
DEFAULT_FP8_MODEL_NAME_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
+ CustomTestCase,
is_in_ci,
run_bench_serving,
write_github_step_summary,
)
-class TestBenchServing(unittest.TestCase):
+class TestBenchServing(CustomTestCase):
def test_offline_throughput_default(self):
res = run_bench_serving(
diff --git a/test/srt/test_block_int8.py b/test/srt/test_block_int8.py
index 1c6c7656c..2b8b841f0 100644
--- a/test/srt/test_block_int8.py
+++ b/test/srt/test_block_int8.py
@@ -5,6 +5,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.test.test_utils import CustomTestCase
# For test
@@ -121,7 +122,7 @@ def torch_w8a8_block_int8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
).sum(dim=1)
-class TestW8A8BlockINT8FusedMoE(unittest.TestCase):
+class TestW8A8BlockINT8FusedMoE(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
M = [1, 33, 64, 222]
N = [128, 1024]
diff --git a/test/srt/test_cache_report.py b/test/srt/test_cache_report.py
index f128aa147..2dd5619b3 100644
--- a/test/srt/test_cache_report.py
+++ b/test/srt/test_cache_report.py
@@ -8,11 +8,12 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestCacheReport(unittest.TestCase):
+class TestCacheReport(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_chunked_prefill.py b/test/srt/test_chunked_prefill.py
index cafd99931..fbdc87c55 100644
--- a/test/srt/test_chunked_prefill.py
+++ b/test/srt/test_chunked_prefill.py
@@ -4,10 +4,10 @@ python3 -m unittest test_chunked_prefill.TestChunkedPrefill.test_mixed_chunked_p
import unittest
-from sglang.test.test_utils import run_mmlu_test, run_mulit_request_test
+from sglang.test.test_utils import CustomTestCase, run_mmlu_test, run_mulit_request_test
-class TestChunkedPrefill(unittest.TestCase):
+class TestChunkedPrefill(CustomTestCase):
def test_chunked_prefill(self):
run_mmlu_test(disable_radix_cache=False, enable_mixed_chunk=False)
diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py
index e604e5755..4196eb290 100644
--- a/test/srt/test_create_kvindices.py
+++ b/test/srt/test_create_kvindices.py
@@ -5,9 +5,10 @@ import numpy as np
import torch
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.test.test_utils import CustomTestCase
-class TestCreateKvIndices(unittest.TestCase):
+class TestCreateKvIndices(CustomTestCase):
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
diff --git a/test/srt/test_custom_allreduce.py b/test/srt/test_custom_allreduce.py
index 7ac0c8ffc..38600aeab 100644
--- a/test/srt/test_custom_allreduce.py
+++ b/test/srt/test_custom_allreduce.py
@@ -17,6 +17,7 @@ from sglang.srt.distributed.parallel_state import (
graph_capture,
initialize_model_parallel,
)
+from sglang.test.test_utils import CustomTestCase
def get_open_port() -> int:
@@ -54,7 +55,7 @@ def multi_process_parallel(
ray.shutdown()
-class TestCustomAllReduce(unittest.TestCase):
+class TestCustomAllReduce(CustomTestCase):
@classmethod
def setUpClass(cls):
random.seed(42)
diff --git a/test/srt/test_data_parallelism.py b/test/srt/test_data_parallelism.py
index 1c674c327..e54e3a5cf 100644
--- a/test/srt/test_data_parallelism.py
+++ b/test/srt/test_data_parallelism.py
@@ -10,11 +10,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestDataParallelism(unittest.TestCase):
+class TestDataParallelism(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_double_sparsity.py b/test/srt/test_double_sparsity.py
index 060a7926f..c936e79bb 100644
--- a/test/srt/test_double_sparsity.py
+++ b/test/srt/test_double_sparsity.py
@@ -8,11 +8,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestDoubleSparsity(unittest.TestCase):
+class TestDoubleSparsity(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_dp_attention.py b/test/srt/test_dp_attention.py
index f7811911f..622f9b633 100644
--- a/test/srt/test_dp_attention.py
+++ b/test/srt/test_dp_attention.py
@@ -7,11 +7,12 @@ from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestDPAttentionDP2TP2(unittest.TestCase):
+class TestDPAttentionDP2TP2(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_eagle_infer.py b/test/srt/test_eagle_infer.py
index 30c846353..dd15b1af9 100644
--- a/test/srt/test_eagle_infer.py
+++ b/test/srt/test_eagle_infer.py
@@ -24,6 +24,7 @@ from sglang.test.test_utils import (
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
run_logprob_check,
)
@@ -33,7 +34,7 @@ prefill_tolerance = 5e-2
decode_tolerance: float = 5e-2
-class TestEAGLEEngine(unittest.TestCase):
+class TestEAGLEEngine(CustomTestCase):
BASE_CONFIG = {
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
"speculative_draft_model_path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
@@ -179,7 +180,7 @@ class TestEAGLE3Engine(TestEAGLEEngine):
NUM_CONFIGS = 1
-class TestEAGLEServer(unittest.TestCase):
+class TestEAGLEServer(CustomTestCase):
PROMPTS = [
"[INST] <>\\nYou are a helpful assistant.\\n<>\\nToday is a sunny day and I like[/INST]"
'[INST] <>\\nYou are a helpful assistant.\\n<>\\nWhat are the mental triggers in Jeff Walker\'s Product Launch Formula and "Launch" book?[/INST]',
diff --git a/test/srt/test_ebnf_constrained.py b/test/srt/test_ebnf_constrained.py
index 799373c59..e0771b97d 100644
--- a/test/srt/test_ebnf_constrained.py
+++ b/test/srt/test_ebnf_constrained.py
@@ -15,6 +15,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
@@ -42,7 +43,7 @@ def setup_class(cls, backend: str, disable_overlap: bool):
)
-class TestEBNFConstrained(unittest.TestCase):
+class TestEBNFConstrained(CustomTestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, "xgrammar", disable_overlap=False)
diff --git a/test/srt/test_embedding_openai_server.py b/test/srt/test_embedding_openai_server.py
index 8097bf42c..e7d19d451 100644
--- a/test/srt/test_embedding_openai_server.py
+++ b/test/srt/test_embedding_openai_server.py
@@ -7,11 +7,12 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestOpenAIServer(unittest.TestCase):
+class TestOpenAIServer(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "intfloat/e5-mistral-7b-instruct"
diff --git a/test/srt/test_eval_accuracy_large.py b/test/srt/test_eval_accuracy_large.py
index 6b43f5aa8..efb202463 100644
--- a/test/srt/test_eval_accuracy_large.py
+++ b/test/srt/test_eval_accuracy_large.py
@@ -12,13 +12,14 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
is_in_ci,
popen_launch_server,
write_github_step_summary,
)
-class TestEvalAccuracyLarge(unittest.TestCase):
+class TestEvalAccuracyLarge(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_eval_fp8_accuracy.py b/test/srt/test_eval_fp8_accuracy.py
index 9431d14d3..12b2499d9 100644
--- a/test/srt/test_eval_fp8_accuracy.py
+++ b/test/srt/test_eval_fp8_accuracy.py
@@ -13,11 +13,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestEvalFP8Accuracy(unittest.TestCase):
+class TestEvalFP8Accuracy(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_FP8_MODEL_NAME_FOR_ACCURACY_TEST
@@ -44,7 +45,7 @@ class TestEvalFP8Accuracy(unittest.TestCase):
self.assertGreaterEqual(metrics["score"], 0.61)
-class TestEvalFP8DynamicQuantAccuracy(unittest.TestCase):
+class TestEvalFP8DynamicQuantAccuracy(CustomTestCase):
def _run_test(self, model, other_args, expected_score):
base_url = DEFAULT_URL_FOR_TEST
@@ -109,7 +110,7 @@ class TestEvalFP8DynamicQuantAccuracy(unittest.TestCase):
)
-class TestEvalFP8ModelOptQuantAccuracy(unittest.TestCase):
+class TestEvalFP8ModelOptQuantAccuracy(CustomTestCase):
def _run_test(self, model, other_args, expected_score):
base_url = DEFAULT_URL_FOR_TEST
diff --git a/test/srt/test_expert_distribution.py b/test/srt/test_expert_distribution.py
index cd8fb7e3e..e3826303d 100755
--- a/test/srt/test_expert_distribution.py
+++ b/test/srt/test_expert_distribution.py
@@ -10,11 +10,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestExpertDistribution(unittest.TestCase):
+class TestExpertDistribution(CustomTestCase):
def setUp(self):
# Clean up any existing expert distribution files before each test
for f in glob.glob("expert_distribution_*.csv"):
diff --git a/test/srt/test_fim_completion.py b/test/srt/test_fim_completion.py
index 132911e65..09db1d4bc 100644
--- a/test/srt/test_fim_completion.py
+++ b/test/srt/test_fim_completion.py
@@ -7,11 +7,12 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestFimCompletion(unittest.TestCase):
+class TestFimCompletion(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "deepseek-ai/deepseek-coder-1.3b-base"
diff --git a/test/srt/test_fp8_kernel.py b/test/srt/test_fp8_kernel.py
index dcc5d4274..1f8d94b3a 100644
--- a/test/srt/test_fp8_kernel.py
+++ b/test/srt/test_fp8_kernel.py
@@ -6,9 +6,10 @@ from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
w8a8_block_fp8_matmul,
)
+from sglang.test.test_utils import CustomTestCase
-class TestFP8Base(unittest.TestCase):
+class TestFP8Base(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.M = 256
diff --git a/test/srt/test_fp8_kvcache.py b/test/srt/test_fp8_kvcache.py
index 4a8a24346..669926f5d 100644
--- a/test/srt/test_fp8_kvcache.py
+++ b/test/srt/test_fp8_kvcache.py
@@ -9,11 +9,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST_QWEN,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestFp8KvcacheBase(unittest.TestCase):
+class TestFp8KvcacheBase(CustomTestCase):
model_config = None
@classmethod
diff --git a/test/srt/test_function_calling.py b/test/srt/test_function_calling.py
index f422d5ea5..9eaa93027 100644
--- a/test/srt/test_function_calling.py
+++ b/test/srt/test_function_calling.py
@@ -10,11 +10,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestOpenAIServerFunctionCalling(unittest.TestCase):
+class TestOpenAIServerFunctionCalling(CustomTestCase):
@classmethod
def setUpClass(cls):
# Replace with the model name needed for testing; if not required, reuse DEFAULT_SMALL_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py
index 6534a4a60..fcff74d62 100644
--- a/test/srt/test_fused_moe.py
+++ b/test/srt/test_fused_moe.py
@@ -7,9 +7,10 @@ from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
+from sglang.test.test_utils import CustomTestCase
-class TestFusedMOE(unittest.TestCase):
+class TestFusedMOE(CustomTestCase):
NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]
diff --git a/test/srt/test_get_weights_by_name.py b/test/srt/test_get_weights_by_name.py
index 6dcb1d249..3d404df10 100644
--- a/test/srt/test_get_weights_by_name.py
+++ b/test/srt/test_get_weights_by_name.py
@@ -12,6 +12,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
is_in_ci,
popen_launch_server,
)
@@ -26,7 +27,7 @@ def _process_return(ret):
return np.array(ret)
-class TestGetWeightsByName(unittest.TestCase):
+class TestGetWeightsByName(CustomTestCase):
def init_hf_model(self, model_name, tie_word_embeddings):
self.hf_model = AutoModelForCausalLM.from_pretrained(
diff --git a/test/srt/test_gguf.py b/test/srt/test_gguf.py
index b4072c775..e9776067c 100644
--- a/test/srt/test_gguf.py
+++ b/test/srt/test_gguf.py
@@ -3,9 +3,10 @@ import unittest
from huggingface_hub import hf_hub_download
import sglang as sgl
+from sglang.test.test_utils import CustomTestCase
-class TestGGUF(unittest.TestCase):
+class TestGGUF(CustomTestCase):
def test_models(self):
prompt = "Today is a sunny day and I like"
sampling_params = {"temperature": 0, "max_new_tokens": 8}
diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py
index f0ee63e6b..54dbaf496 100644
--- a/test/srt/test_gptqmodel_dynamic.py
+++ b/test/srt/test_gptqmodel_dynamic.py
@@ -8,6 +8,7 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
@@ -102,7 +103,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool):
# GPTQ with Dynamic Per/Module Quantization Control
# Leverages GPTQModel (pypi) to produce the `dynamic` models
# Test GPTQ fallback kernel that is not Marlin
-class TestGPTQModelDynamic(unittest.TestCase):
+class TestGPTQModelDynamic(CustomTestCase):
MODEL_PATH = (
"ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symFalse"
)
@@ -157,7 +158,7 @@ class TestGPTQModelDynamic(unittest.TestCase):
# GPTQ with Dynamic Per/Module Quantization Control
# Leverages GPTQModel (pypi) to produce the `dynamic` models
# Test Marlin kernel
-class TestGPTQModelDynamicWithMarlin(unittest.TestCase):
+class TestGPTQModelDynamicWithMarlin(CustomTestCase):
MODEL_PATH = (
"ModelCloud/Qwen1.5-1.8B-Chat-GPTQ-4bits-dynamic-cfg-with-lm_head-symTrue"
)
diff --git a/test/srt/test_health_check.py b/test/srt/test_health_check.py
index 708230ffd..1f101e43b 100644
--- a/test/srt/test_health_check.py
+++ b/test/srt/test_health_check.py
@@ -3,11 +3,12 @@ import unittest
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestHealthCheck(unittest.TestCase):
+class TestHealthCheck(CustomTestCase):
def test_health_check(self):
"""Test that metrics endpoint returns data when enabled"""
with self.assertRaises(TimeoutError):
diff --git a/test/srt/test_hicache.py b/test/srt/test_hicache.py
index 0b1d91366..ac7a27f04 100644
--- a/test/srt/test_hicache.py
+++ b/test/srt/test_hicache.py
@@ -7,11 +7,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestPageSize(unittest.TestCase):
+class TestPageSize(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_hicache_mla.py b/test/srt/test_hicache_mla.py
index 8250395a0..8396615f3 100644
--- a/test/srt/test_hicache_mla.py
+++ b/test/srt/test_hicache_mla.py
@@ -7,11 +7,12 @@ from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestHierarchicalMLA(unittest.TestCase):
+class TestHierarchicalMLA(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_hidden_states.py b/test/srt/test_hidden_states.py
index 87676c0ad..81e42f7b1 100644
--- a/test/srt/test_hidden_states.py
+++ b/test/srt/test_hidden_states.py
@@ -4,10 +4,10 @@ import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import sglang as sgl
-from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
-class TestHiddenState(unittest.TestCase):
+class TestHiddenState(CustomTestCase):
def test_return_hidden_states(self):
prompts = ["Today is", "Today is a sunny day and I like"]
model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_input_embeddings.py b/test/srt/test_input_embeddings.py
index 92b643fd3..7efd66565 100644
--- a/test/srt/test_input_embeddings.py
+++ b/test/srt/test_input_embeddings.py
@@ -11,11 +11,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestInputEmbeds(unittest.TestCase):
+class TestInputEmbeds(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_int8_kernel.py b/test/srt/test_int8_kernel.py
index b1c757996..959aab900 100644
--- a/test/srt/test_int8_kernel.py
+++ b/test/srt/test_int8_kernel.py
@@ -6,6 +6,7 @@ import torch
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.test.test_utils import CustomTestCase
def native_w8a8_per_token_matmul(A, B, As, Bs, output_dtype=torch.float16):
@@ -71,7 +72,7 @@ def torch_w8a8_per_column_moe(a, w1, w2, w1_s, w2_s, score, topk):
).sum(dim=1)
-class TestW8A8Int8FusedMoE(unittest.TestCase):
+class TestW8A8Int8FusedMoE(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
M = [1, 33]
N = [128, 1024]
diff --git a/test/srt/test_json_constrained.py b/test/srt/test_json_constrained.py
index f9295cba2..7b42319bf 100644
--- a/test/srt/test_json_constrained.py
+++ b/test/srt/test_json_constrained.py
@@ -16,6 +16,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
@@ -50,7 +51,7 @@ def setup_class(cls, backend: str):
)
-class TestJSONConstrainedOutlinesBackend(unittest.TestCase):
+class TestJSONConstrainedOutlinesBackend(CustomTestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, backend="outlines")
diff --git a/test/srt/test_large_max_new_tokens.py b/test/srt/test_large_max_new_tokens.py
index dcaeef5aa..a6d6baee8 100644
--- a/test/srt/test_large_max_new_tokens.py
+++ b/test/srt/test_large_max_new_tokens.py
@@ -17,11 +17,12 @@ from sglang.test.test_utils import (
DEFAULT_URL_FOR_TEST,
STDERR_FILENAME,
STDOUT_FILENAME,
+ CustomTestCase,
popen_launch_server,
)
-class TestLargeMaxNewTokens(unittest.TestCase):
+class TestLargeMaxNewTokens(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_matched_stop.py b/test/srt/test_matched_stop.py
index 7b09a6d35..357b07f31 100644
--- a/test/srt/test_matched_stop.py
+++ b/test/srt/test_matched_stop.py
@@ -7,6 +7,7 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
@@ -18,7 +19,7 @@ The story should span multiple events, challenges, and character developments ov
"""
-class TestMatchedStop(unittest.TestCase):
+class TestMatchedStop(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_metrics.py b/test/srt/test_metrics.py
index 03dbf48c8..fadd427b7 100644
--- a/test/srt/test_metrics.py
+++ b/test/srt/test_metrics.py
@@ -7,11 +7,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestEnableMetrics(unittest.TestCase):
+class TestEnableMetrics(CustomTestCase):
def test_metrics_enabled(self):
"""Test that metrics endpoint returns data when enabled"""
process = popen_launch_server(
diff --git a/test/srt/test_mla.py b/test/srt/test_mla.py
index 42037be7c..40f82e087 100644
--- a/test/srt/test_mla.py
+++ b/test/srt/test_mla.py
@@ -7,11 +7,12 @@ from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestMLA(unittest.TestCase):
+class TestMLA(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_mla_deepseek_v3.py b/test/srt/test_mla_deepseek_v3.py
index 42a7df59b..fa950f713 100644
--- a/test/srt/test_mla_deepseek_v3.py
+++ b/test/srt/test_mla_deepseek_v3.py
@@ -9,11 +9,12 @@ from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestMLADeepseekV3(unittest.TestCase):
+class TestMLADeepseekV3(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmsys/sglang-ci-dsv3-test"
@@ -48,7 +49,7 @@ class TestMLADeepseekV3(unittest.TestCase):
self.assertGreater(metrics["accuracy"], 0.62)
-class TestDeepseekV3MTP(unittest.TestCase):
+class TestDeepseekV3MTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmsys/sglang-ci-dsv3-test"
diff --git a/test/srt/test_mla_flashinfer.py b/test/srt/test_mla_flashinfer.py
index 68d1749ab..4fa3eb58f 100644
--- a/test/srt/test_mla_flashinfer.py
+++ b/test/srt/test_mla_flashinfer.py
@@ -9,11 +9,12 @@ from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestFlashinferMLA(unittest.TestCase):
+class TestFlashinferMLA(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmsys/sglang-ci-dsv3-test"
@@ -55,7 +56,7 @@ class TestFlashinferMLA(unittest.TestCase):
self.assertGreater(metrics["accuracy"], 0.62)
-class TestFlashinferMLANoRagged(unittest.TestCase):
+class TestFlashinferMLANoRagged(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmsys/sglang-ci-dsv3-test"
@@ -99,7 +100,7 @@ class TestFlashinferMLANoRagged(unittest.TestCase):
self.assertGreater(metrics["accuracy"], 0.62)
-class TestFlashinferMLAMTP(unittest.TestCase):
+class TestFlashinferMLAMTP(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmsys/sglang-ci-dsv3-test"
diff --git a/test/srt/test_mla_fp8.py b/test/srt/test_mla_fp8.py
index 4fe18b526..a2fac9883 100644
--- a/test/srt/test_mla_fp8.py
+++ b/test/srt/test_mla_fp8.py
@@ -7,11 +7,12 @@ from sglang.test.test_utils import (
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestMLA(unittest.TestCase):
+class TestMLA(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_mla_int8_deepseek_v3.py b/test/srt/test_mla_int8_deepseek_v3.py
index 3fda82139..27b7e0af4 100644
--- a/test/srt/test_mla_int8_deepseek_v3.py
+++ b/test/srt/test_mla_int8_deepseek_v3.py
@@ -9,11 +9,12 @@ from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestMLADeepseekV3ChannelInt8(unittest.TestCase):
+class TestMLADeepseekV3ChannelInt8(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "sgl-project/sglang-ci-dsv3-channel-int8-test"
@@ -48,7 +49,7 @@ class TestMLADeepseekV3ChannelInt8(unittest.TestCase):
self.assertGreater(metrics["accuracy"], 0.62)
-class TestDeepseekV3MTPChannelInt8(unittest.TestCase):
+class TestDeepseekV3MTPChannelInt8(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "sgl-project/sglang-ci-dsv3-channel-int8-test"
@@ -109,7 +110,7 @@ class TestDeepseekV3MTPChannelInt8(unittest.TestCase):
self.assertGreater(avg_spec_accept_length, 2.5)
-class TestMLADeepseekV3BlockInt8(unittest.TestCase):
+class TestMLADeepseekV3BlockInt8(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "sgl-project/sglang-ci-dsv3-block-int8-test"
@@ -144,7 +145,7 @@ class TestMLADeepseekV3BlockInt8(unittest.TestCase):
self.assertGreater(metrics["accuracy"], 0.62)
-class TestDeepseekV3MTPBlockInt8(unittest.TestCase):
+class TestDeepseekV3MTPBlockInt8(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "sgl-project/sglang-ci-dsv3-block-int8-test"
diff --git a/test/srt/test_mla_tp.py b/test/srt/test_mla_tp.py
index 777be82e0..e957cf2de 100644
--- a/test/srt/test_mla_tp.py
+++ b/test/srt/test_mla_tp.py
@@ -8,11 +8,12 @@ from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestDeepseekTP2(unittest.TestCase):
+class TestDeepseekTP2(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmsys/sglang-ci-dsv3-test"
diff --git a/test/srt/test_modelopt_fp8kvcache.py b/test/srt/test_modelopt_fp8kvcache.py
index da6bb3651..a4704c239 100644
--- a/test/srt/test_modelopt_fp8kvcache.py
+++ b/test/srt/test_modelopt_fp8kvcache.py
@@ -6,9 +6,10 @@ from sglang.srt.layers.quantization.modelopt_quant import (
ModelOptFp8Config,
ModelOptFp8KVCacheMethod,
)
+from sglang.test.test_utils import CustomTestCase
-class TestModelOptFp8KVCacheMethod(unittest.TestCase):
+class TestModelOptFp8KVCacheMethod(CustomTestCase):
def test_kv_cache_method_initialization(self):
"""Test that ModelOptFp8KVCacheMethod can be instantiated and
inherits from BaseKVCacheMethod."""
diff --git a/test/srt/test_models_from_modelscope.py b/test/srt/test_models_from_modelscope.py
index 3440e5591..dacaae30a 100644
--- a/test/srt/test_models_from_modelscope.py
+++ b/test/srt/test_models_from_modelscope.py
@@ -5,9 +5,10 @@ import unittest
from unittest import mock
from sglang.srt.utils import prepare_model_and_tokenizer
+from sglang.test.test_utils import CustomTestCase
-class TestDownloadFromModelScope(unittest.TestCase):
+class TestDownloadFromModelScope(CustomTestCase):
@classmethod
def setUpClass(cls):
diff --git a/test/srt/test_moe_deepep.py b/test/srt/test_moe_deepep.py
index f89f49810..9c4194823 100644
--- a/test/srt/test_moe_deepep.py
+++ b/test/srt/test_moe_deepep.py
@@ -7,11 +7,12 @@ from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestDeepEPMoE(unittest.TestCase):
+class TestDeepEPMoE(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_moe_ep.py b/test/srt/test_moe_ep.py
index 054866e76..284dcba9f 100644
--- a/test/srt/test_moe_ep.py
+++ b/test/srt/test_moe_ep.py
@@ -7,11 +7,12 @@ from sglang.test.test_utils import (
DEFAULT_MLA_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestEpMoE(unittest.TestCase):
+class TestEpMoE(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
@@ -59,7 +60,7 @@ class TestEpMoE(unittest.TestCase):
self.assertGreater(metrics["score"], 0.8)
-class TestEpMoEFP8(unittest.TestCase):
+class TestEpMoEFP8(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MLA_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py
index 144c99bb9..26bbd247e 100644
--- a/test/srt/test_moe_eval_accuracy_large.py
+++ b/test/srt/test_moe_eval_accuracy_large.py
@@ -12,13 +12,14 @@ from sglang.test.test_utils import (
DEFAULT_MOE_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
is_in_ci,
popen_launch_server,
write_github_step_summary,
)
-class TestMoEEvalAccuracyLarge(unittest.TestCase):
+class TestMoEEvalAccuracyLarge(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MOE_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_nightly_gsm8k_eval.py b/test/srt/test_nightly_gsm8k_eval.py
index a8d750029..5caa076b1 100644
--- a/test/srt/test_nightly_gsm8k_eval.py
+++ b/test/srt/test_nightly_gsm8k_eval.py
@@ -15,6 +15,7 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
is_in_ci,
popen_launch_server,
write_github_step_summary,
@@ -129,7 +130,7 @@ def check_model_scores(results):
raise AssertionError("\n".join(failed_models))
-class TestNightlyGsm8KEval(unittest.TestCase):
+class TestNightlyGsm8KEval(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model_groups = [
diff --git a/test/srt/test_nightly_human_eval.py b/test/srt/test_nightly_human_eval.py
index 6558b9eff..2a1ea3b27 100644
--- a/test/srt/test_nightly_human_eval.py
+++ b/test/srt/test_nightly_human_eval.py
@@ -14,11 +14,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2,
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
is_in_ci,
)
-class TestNightlyHumanEval(unittest.TestCase):
+class TestNightlyHumanEval(CustomTestCase):
@classmethod
def setUpClass(cls):
if is_in_ci():
diff --git a/test/srt/test_nightly_math_eval.py b/test/srt/test_nightly_math_eval.py
index 3a4eb0adf..20db12454 100644
--- a/test/srt/test_nightly_math_eval.py
+++ b/test/srt/test_nightly_math_eval.py
@@ -7,11 +7,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestEvalAccuracyLarge(unittest.TestCase):
+class TestEvalAccuracyLarge(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_no_chunked_prefill.py b/test/srt/test_no_chunked_prefill.py
index 8252c9ae0..59869ff59 100644
--- a/test/srt/test_no_chunked_prefill.py
+++ b/test/srt/test_no_chunked_prefill.py
@@ -2,12 +2,13 @@ import unittest
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
+ CustomTestCase,
run_bench_serving,
run_mmlu_test,
)
-class TestNoChunkedPrefill(unittest.TestCase):
+class TestNoChunkedPrefill(CustomTestCase):
def test_no_chunked_prefill(self):
run_mmlu_test(
diff --git a/test/srt/test_no_overlap_scheduler.py b/test/srt/test_no_overlap_scheduler.py
index 341207148..d236819bc 100644
--- a/test/srt/test_no_overlap_scheduler.py
+++ b/test/srt/test_no_overlap_scheduler.py
@@ -6,10 +6,10 @@ python3 test_overlap_schedule.py
import unittest
-from sglang.test.test_utils import run_mmlu_test
+from sglang.test.test_utils import CustomTestCase, run_mmlu_test
-class TestOverlapSchedule(unittest.TestCase):
+class TestOverlapSchedule(CustomTestCase):
def test_no_radix_attention_chunked_prefill(self):
run_mmlu_test(
disable_radix_cache=True, chunked_prefill_size=32, disable_overlap=True
diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py
index e9adf617f..e18168eb7 100644
--- a/test/srt/test_openai_server.py
+++ b/test/srt/test_openai_server.py
@@ -18,11 +18,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestOpenAIServer(unittest.TestCase):
+class TestOpenAIServer(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -541,7 +542,7 @@ The SmartHome Mini is a compact smart home assistant available in black or white
# EBNF Test Class: TestOpenAIServerEBNF
# Launches the server with xgrammar, has only EBNF tests
# -------------------------------------------------------------------------
-class TestOpenAIServerEBNF(unittest.TestCase):
+class TestOpenAIServerEBNF(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -624,7 +625,7 @@ class TestOpenAIServerEBNF(unittest.TestCase):
)
-class TestOpenAIEmbedding(unittest.TestCase):
+class TestOpenAIEmbedding(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_page_size.py b/test/srt/test_page_size.py
index 0fabfbfa4..f56e1ed6b 100644
--- a/test/srt/test_page_size.py
+++ b/test/srt/test_page_size.py
@@ -8,11 +8,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestPageSize(unittest.TestCase):
+class TestPageSize(CustomTestCase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_DEBUG_MEMORY_POOL"] = "1"
diff --git a/test/srt/test_penalty.py b/test/srt/test_penalty.py
index e1d11a9ac..bfbd2777a 100644
--- a/test/srt/test_penalty.py
+++ b/test/srt/test_penalty.py
@@ -10,11 +10,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestPenalty(unittest.TestCase):
+class TestPenalty(CustomTestCase):
@classmethod
def setUpClass(cls):
diff --git a/test/srt/test_pytorch_sampling_backend.py b/test/srt/test_pytorch_sampling_backend.py
index 4f1403e0a..8a368310e 100644
--- a/test/srt/test_pytorch_sampling_backend.py
+++ b/test/srt/test_pytorch_sampling_backend.py
@@ -9,11 +9,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestPyTorchSamplingBackend(unittest.TestCase):
+class TestPyTorchSamplingBackend(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_radix_attention.py b/test/srt/test_radix_attention.py
index 207303c8c..554801dc0 100644
--- a/test/srt/test_radix_attention.py
+++ b/test/srt/test_radix_attention.py
@@ -8,6 +8,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
kill_process_tree,
popen_launch_server,
)
@@ -59,7 +60,7 @@ def run_test(base_url, nodes):
assert res.status_code == 200
-class TestRadixCacheFCFS(unittest.TestCase):
+class TestRadixCacheFCFS(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_reasoning_content.py b/test/srt/test_reasoning_content.py
index f07dd6339..695dabfdb 100644
--- a/test/srt/test_reasoning_content.py
+++ b/test/srt/test_reasoning_content.py
@@ -20,11 +20,12 @@ from sglang.test.test_utils import (
DEFAULT_REASONING_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestReasoningContentAPI(unittest.TestCase):
+class TestReasoningContentAPI(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_REASONING_MODEL_NAME_FOR_TEST
@@ -181,7 +182,7 @@ class TestReasoningContentAPI(unittest.TestCase):
assert len(response.choices[0].message.content) > 0
-class TestReasoningContentWithoutParser(unittest.TestCase):
+class TestReasoningContentWithoutParser(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_REASONING_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_regex_constrained.py b/test/srt/test_regex_constrained.py
index d6e448b4e..b0fb49796 100644
--- a/test/srt/test_regex_constrained.py
+++ b/test/srt/test_regex_constrained.py
@@ -15,6 +15,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
@@ -41,7 +42,7 @@ def setup_class(cls, backend: str, disable_overlap: bool):
)
-class TestRegexConstrained(unittest.TestCase):
+class TestRegexConstrained(CustomTestCase):
@classmethod
def setUpClass(cls):
setup_class(cls, "xgrammar", disable_overlap=False)
diff --git a/test/srt/test_release_memory_occupation.py b/test/srt/test_release_memory_occupation.py
index c84b64e77..7ccd9f1f7 100644
--- a/test/srt/test_release_memory_occupation.py
+++ b/test/srt/test_release_memory_occupation.py
@@ -5,13 +5,13 @@ import torch
from transformers import AutoModelForCausalLM
import sglang as sgl
-from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
# (temporarily) set to true to observe memory usage in nvidia-smi more clearly
_DEBUG_EXTRA = True
-class TestReleaseMemoryOccupation(unittest.TestCase):
+class TestReleaseMemoryOccupation(CustomTestCase):
def test_release_and_resume_occupation(self):
prompt = "Today is a sunny day and I like"
sampling_params = {"temperature": 0, "max_new_tokens": 8}
diff --git a/test/srt/test_request_length_validation.py b/test/srt/test_request_length_validation.py
index 7f057cb20..b3c202f64 100644
--- a/test/srt/test_request_length_validation.py
+++ b/test/srt/test_request_length_validation.py
@@ -7,11 +7,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestRequestLengthValidation(unittest.TestCase):
+class TestRequestLengthValidation(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
diff --git a/test/srt/test_retract_decode.py b/test/srt/test_retract_decode.py
index 3ca8620be..92f5ab915 100644
--- a/test/srt/test_retract_decode.py
+++ b/test/srt/test_retract_decode.py
@@ -8,11 +8,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestRetractDecode(unittest.TestCase):
+class TestRetractDecode(CustomTestCase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_TEST_RETRACT"] = "1"
@@ -40,7 +41,7 @@ class TestRetractDecode(unittest.TestCase):
self.assertGreaterEqual(metrics["score"], 0.65)
-class TestRetractDecodeChunkCache(unittest.TestCase):
+class TestRetractDecodeChunkCache(CustomTestCase):
@classmethod
def setUpClass(cls):
os.environ["SGLANG_TEST_RETRACT"] = "1"
diff --git a/test/srt/test_sagemaker_server.py b/test/srt/test_sagemaker_server.py
index fab7ca4dc..68688c112 100644
--- a/test/srt/test_sagemaker_server.py
+++ b/test/srt/test_sagemaker_server.py
@@ -13,11 +13,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestSageMakerServer(unittest.TestCase):
+class TestSageMakerServer(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_schedule_policy.py b/test/srt/test_schedule_policy.py
index 52c5b8289..305f0ca94 100644
--- a/test/srt/test_schedule_policy.py
+++ b/test/srt/test_schedule_policy.py
@@ -8,9 +8,10 @@ from sglang.srt.managers.schedule_policy import (
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.test.test_utils import CustomTestCase
-class TestSchedulePolicy(unittest.TestCase):
+class TestSchedulePolicy(CustomTestCase):
def setUp(self):
self.tree_cache = RadixCache(None, None, False)
diff --git a/test/srt/test_server_args.py b/test/srt/test_server_args.py
index de5c0f1bc..bb5618a15 100644
--- a/test/srt/test_server_args.py
+++ b/test/srt/test_server_args.py
@@ -2,9 +2,10 @@ import json
import unittest
from sglang.srt.server_args import prepare_server_args
+from sglang.test.test_utils import CustomTestCase
-class TestPrepareServerArgs(unittest.TestCase):
+class TestPrepareServerArgs(CustomTestCase):
def test_prepare_server_args(self):
server_args = prepare_server_args(
[
diff --git a/test/srt/test_session_control.py b/test/srt/test_session_control.py
index 9a3de8d13..7a68b2b17 100644
--- a/test/srt/test_session_control.py
+++ b/test/srt/test_session_control.py
@@ -19,6 +19,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
@@ -27,7 +28,7 @@ def remove_prefix(text: str, prefix: str) -> str:
return text[len(prefix) :] if text.startswith(prefix) else text
-class TestSessionControl(unittest.TestCase):
+class TestSessionControl(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -560,7 +561,7 @@ class TestSessionControl(unittest.TestCase):
)
-class TestSessionControlVision(unittest.TestCase):
+class TestSessionControlVision(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmms-lab/llava-onevision-qwen2-7b-ov"
diff --git a/test/srt/test_skip_tokenizer_init.py b/test/srt/test_skip_tokenizer_init.py
index 41787e2c1..a9d27242f 100644
--- a/test/srt/test_skip_tokenizer_init.py
+++ b/test/srt/test_skip_tokenizer_init.py
@@ -19,11 +19,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_VLM_MODEL_NAME,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestSkipTokenizerInit(unittest.TestCase):
+class TestSkipTokenizerInit(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_srt_endpoint.py b/test/srt/test_srt_endpoint.py
index 9673b19c5..81671e4a0 100644
--- a/test/srt/test_srt_endpoint.py
+++ b/test/srt/test_srt_endpoint.py
@@ -20,12 +20,13 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
run_logprob_check,
)
-class TestSRTEndpoint(unittest.TestCase):
+class TestSRTEndpoint(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_srt_engine.py b/test/srt/test_srt_engine.py
index c535d5c06..672344c63 100644
--- a/test/srt/test_srt_engine.py
+++ b/test/srt/test_srt_engine.py
@@ -18,10 +18,11 @@ from sglang.test.few_shot_gsm8k_engine import run_eval
from sglang.test.test_utils import (
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
+ CustomTestCase,
)
-class TestSRTEngine(unittest.TestCase):
+class TestSRTEngine(CustomTestCase):
def test_1_engine_runtime_consistency(self):
prompt = "Today is a sunny day and I like"
diff --git a/test/srt/test_srt_engine_with_quant_args.py b/test/srt/test_srt_engine_with_quant_args.py
index 3851ab41a..e3b30ea39 100644
--- a/test/srt/test_srt_engine_with_quant_args.py
+++ b/test/srt/test_srt_engine_with_quant_args.py
@@ -1,10 +1,10 @@
import unittest
import sglang as sgl
-from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
-class TestSRTEngineWithQuantArgs(unittest.TestCase):
+class TestSRTEngineWithQuantArgs(CustomTestCase):
def test_1_quantization_args(self):
diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py
index e71de3391..760cec84b 100644
--- a/test/srt/test_torch_compile.py
+++ b/test/srt/test_torch_compile.py
@@ -10,11 +10,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestTorchCompile(unittest.TestCase):
+class TestTorchCompile(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_torch_compile_moe.py b/test/srt/test_torch_compile_moe.py
index fb78dd7f4..34c80d450 100644
--- a/test/srt/test_torch_compile_moe.py
+++ b/test/srt/test_torch_compile_moe.py
@@ -10,11 +10,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestTorchCompileMoe(unittest.TestCase):
+class TestTorchCompileMoe(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_torch_native_attention_backend.py b/test/srt/test_torch_native_attention_backend.py
index 512aa5597..3af0557d0 100644
--- a/test/srt/test_torch_native_attention_backend.py
+++ b/test/srt/test_torch_native_attention_backend.py
@@ -12,13 +12,14 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_one_batch,
)
-class TestTorchNativeAttnBackend(unittest.TestCase):
+class TestTorchNativeAttnBackend(CustomTestCase):
def test_latency(self):
output_throughput = run_bench_one_batch(
DEFAULT_MODEL_NAME_FOR_TEST,
diff --git a/test/srt/test_torch_tp.py b/test/srt/test_torch_tp.py
index e17b212f6..8349cdf29 100644
--- a/test/srt/test_torch_tp.py
+++ b/test/srt/test_torch_tp.py
@@ -1,9 +1,9 @@
import unittest
-from sglang.test.test_utils import is_in_ci, run_bench_one_batch
+from sglang.test.test_utils import CustomTestCase, is_in_ci, run_bench_one_batch
-class TestTorchTP(unittest.TestCase):
+class TestTorchTP(CustomTestCase):
def test_torch_native_llama(self):
output_throughput = run_bench_one_batch(
"meta-llama/Meta-Llama-3-8B",
diff --git a/test/srt/test_torchao.py b/test/srt/test_torchao.py
index a6414c60b..77ec0a570 100644
--- a/test/srt/test_torchao.py
+++ b/test/srt/test_torchao.py
@@ -9,11 +9,12 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestTorchAO(unittest.TestCase):
+class TestTorchAO(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_triton_attention_backend.py b/test/srt/test_triton_attention_backend.py
index 4e479c809..c4cdf5e5b 100644
--- a/test/srt/test_triton_attention_backend.py
+++ b/test/srt/test_triton_attention_backend.py
@@ -12,13 +12,14 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
is_in_ci,
popen_launch_server,
run_bench_one_batch,
)
-class TestTritonAttnBackend(unittest.TestCase):
+class TestTritonAttnBackend(CustomTestCase):
def test_latency(self):
output_throughput = run_bench_one_batch(
DEFAULT_MODEL_NAME_FOR_TEST,
diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py
index 2b90ce81b..184733e7f 100644
--- a/test/srt/test_triton_attention_kernels.py
+++ b/test/srt/test_triton_attention_kernels.py
@@ -15,9 +15,10 @@ from sglang.srt.layers.attention.triton_ops.extend_attention import (
from sglang.srt.layers.attention.triton_ops.prefill_attention import (
context_attention_fwd,
)
+from sglang.test.test_utils import CustomTestCase
-class TestTritonAttention(unittest.TestCase):
+class TestTritonAttention(CustomTestCase):
def _set_all_seeds(self, seed):
"""Set all random seeds for reproducibility."""
diff --git a/test/srt/test_triton_attention_rocm_mla.py b/test/srt/test_triton_attention_rocm_mla.py
index c2a11f979..f3074aeeb 100644
--- a/test/srt/test_triton_attention_rocm_mla.py
+++ b/test/srt/test_triton_attention_rocm_mla.py
@@ -10,9 +10,10 @@ from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
decode_attention_fwd_grouped_rope,
)
from sglang.srt.layers.rotary_embedding import DeepseekScalingRotaryEmbedding
+from sglang.test.test_utils import CustomTestCase
-class TestTritonAttentionMLA(unittest.TestCase):
+class TestTritonAttentionMLA(CustomTestCase):
def _set_all_seeds(self, seed):
"""Set all random seeds for reproducibility."""
diff --git a/test/srt/test_update_weights_from_disk.py b/test/srt/test_update_weights_from_disk.py
index 248525048..11b7e678a 100644
--- a/test/srt/test_update_weights_from_disk.py
+++ b/test/srt/test_update_weights_from_disk.py
@@ -10,6 +10,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
is_in_ci,
popen_launch_server,
)
@@ -18,7 +19,7 @@ from sglang.test.test_utils import (
###############################################################################
# Engine Mode Tests (Single-configuration)
###############################################################################
-class TestEngineUpdateWeightsFromDisk(unittest.TestCase):
+class TestEngineUpdateWeightsFromDisk(CustomTestCase):
def setUp(self):
self.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
# Initialize the engine in offline (direct) mode.
@@ -70,7 +71,7 @@ class TestEngineUpdateWeightsFromDisk(unittest.TestCase):
###############################################################################
# HTTP Server Mode Tests (Single-configuration)
###############################################################################
-class TestServerUpdateWeightsFromDisk(unittest.TestCase):
+class TestServerUpdateWeightsFromDisk(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
@@ -159,7 +160,7 @@ class TestServerUpdateWeightsFromDisk(unittest.TestCase):
# - In a non-CI environment: test both Engine and Server modes, and enumerate all combinations
# with tp and dp ranging from 1 to 2.
###############################################################################
-class TestUpdateWeightsFromDiskParameterized(unittest.TestCase):
+class TestUpdateWeightsFromDiskParameterized(CustomTestCase):
def run_common_test(self, mode, tp, dp):
"""
Common test procedure for update_weights_from_disk.
diff --git a/test/srt/test_update_weights_from_distributed.py b/test/srt/test_update_weights_from_distributed.py
index fc15efcfe..7352e757a 100644
--- a/test/srt/test_update_weights_from_distributed.py
+++ b/test/srt/test_update_weights_from_distributed.py
@@ -33,6 +33,7 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
is_in_ci,
popen_launch_server,
)
@@ -523,7 +524,7 @@ def test_update_weights_from_distributed(
torch.cuda.empty_cache()
-class TestUpdateWeightsFromDistributed(unittest.TestCase):
+class TestUpdateWeightsFromDistributed(CustomTestCase):
def test_update_weights_from_distributed(self):
diff --git a/test/srt/test_update_weights_from_tensor.py b/test/srt/test_update_weights_from_tensor.py
index 1e0134715..1f3592447 100644
--- a/test/srt/test_update_weights_from_tensor.py
+++ b/test/srt/test_update_weights_from_tensor.py
@@ -5,7 +5,7 @@ import unittest
import torch
import sglang as sgl
-from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST
+from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase
def test_update_weights_from_tensor(tp_size):
@@ -40,7 +40,7 @@ def test_update_weights_from_tensor(tp_size):
), f"Memory leak detected: {memory_after - memory_before} bytes"
-class TestUpdateWeightsFromTensor(unittest.TestCase):
+class TestUpdateWeightsFromTensor(CustomTestCase):
def test_update_weights_from_tensor(self):
tp_sizes = [1, 2]
for tp_size in tp_sizes:
diff --git a/test/srt/test_verl_engine.py b/test/srt/test_verl_engine.py
index 7f76bc654..72c0d5225 100644
--- a/test/srt/test_verl_engine.py
+++ b/test/srt/test_verl_engine.py
@@ -27,7 +27,7 @@ from sglang.test.runners import (
check_close_model_outputs,
get_dtype_str,
)
-from sglang.test.test_utils import is_in_ci
+from sglang.test.test_utils import CustomTestCase, is_in_ci
_MAX_NEW_TOKENS = 8
_PROMPTS = ["1+1=2, 1+2=3, 1+3=4, 1+4=5, 1+5=", "1*1=1, 1*2=2, 1*3=3, 1*4=4, 1*5="]
@@ -73,7 +73,7 @@ ALL_OTHER_MODELS = [
]
-class TestVerlEngine(unittest.TestCase):
+class TestVerlEngine(CustomTestCase):
@classmethod
def setUpClass(cls):
multiprocessing.set_start_method("spawn")
diff --git a/test/srt/test_vertex_endpoint.py b/test/srt/test_vertex_endpoint.py
index b20dc8fda..a899d6251 100644
--- a/test/srt/test_vertex_endpoint.py
+++ b/test/srt/test_vertex_endpoint.py
@@ -11,11 +11,12 @@ from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestVertexEndpoint(unittest.TestCase):
+class TestVertexEndpoint(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
diff --git a/test/srt/test_vision_chunked_prefill.py b/test/srt/test_vision_chunked_prefill.py
index d0db034fb..cb5c132ca 100644
--- a/test/srt/test_vision_chunked_prefill.py
+++ b/test/srt/test_vision_chunked_prefill.py
@@ -18,11 +18,12 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestVisionChunkedPrefill(unittest.TestCase):
+class TestVisionChunkedPrefill(CustomTestCase):
def prepare_video_messages(self, video_path, max_frames_num=8):
# We import decord here to avoid a strange Segmentation fault (core dumped) issue.
# The following import order will cause Segmentation fault.
diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py
index 3ca6330bb..4ca61d448 100644
--- a/test/srt/test_vision_openai_server.py
+++ b/test/srt/test_vision_openai_server.py
@@ -20,6 +20,7 @@ from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
@@ -35,7 +36,7 @@ AUDIO_TRUMP_SPEECH_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test
AUDIO_BIRD_SONG_URL = "https://raw.githubusercontent.com/sgl-project/sgl-test-files/refs/heads/main/audios/bird_song.mp3"
-class TestOpenAIVisionServer(unittest.TestCase):
+class TestOpenAIVisionServer(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "lmms-lab/llava-onevision-qwen2-0.5b-ov"
@@ -507,7 +508,7 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
cls.base_url += "/v1"
-class TestVLMContextLengthIssue(unittest.TestCase):
+class TestVLMContextLengthIssue(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "Qwen/Qwen2-VL-7B-Instruct"
diff --git a/test/srt/test_w8a8_quantization.py b/test/srt/test_w8a8_quantization.py
index 78579d5e2..2cb2fa073 100644
--- a/test/srt/test_w8a8_quantization.py
+++ b/test/srt/test_w8a8_quantization.py
@@ -9,11 +9,12 @@ from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
+ CustomTestCase,
popen_launch_server,
)
-class TestW8A8(unittest.TestCase):
+class TestW8A8(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.model = "neuralmagic/Meta-Llama-3-8B-Instruct-quantized.w8a8"