Add retry for flaky tests in CI (#4755)
This commit is contained in:
@@ -4,9 +4,10 @@ import unittest
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.activation import GeluAndMul
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestGeluAndMul(unittest.TestCase):
|
||||
class TestGeluAndMul(CustomTestCase):
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
NUM_TOKENS = [7, 83, 2048]
|
||||
D = [512, 4096, 5120, 13824]
|
||||
|
||||
@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
|
||||
static_quant_fp8,
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
_is_cuda = torch.cuda.is_available() and torch.version.cuda
|
||||
|
||||
@@ -44,7 +45,7 @@ def native_per_token_group_quant_fp8(
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
class TestPerTokenGroupQuantFP8(unittest.TestCase):
|
||||
class TestPerTokenGroupQuantFP8(CustomTestCase):
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float32]
|
||||
NUM_TOKENS = [7, 83, 2048]
|
||||
D = [512, 4096, 5120, 13824]
|
||||
@@ -111,7 +112,7 @@ def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
|
||||
return x_q, x_s
|
||||
|
||||
|
||||
class TestStaticQuantFP8(unittest.TestCase):
|
||||
class TestStaticQuantFP8(CustomTestCase):
|
||||
DTYPES = [torch.half, torch.bfloat16, torch.float32]
|
||||
NUM_TOKENS = [7, 83, 2048]
|
||||
D = [512, 4096, 5120, 13824]
|
||||
@@ -210,7 +211,7 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
|
||||
return C
|
||||
|
||||
|
||||
class TestW8A8BlockFP8Matmul(unittest.TestCase):
|
||||
class TestW8A8BlockFP8Matmul(CustomTestCase):
|
||||
|
||||
if not _is_cuda:
|
||||
OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
|
||||
@@ -331,7 +332,7 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
|
||||
).sum(dim=1)
|
||||
|
||||
|
||||
class TestW8A8BlockFP8FusedMoE(unittest.TestCase):
|
||||
class TestW8A8BlockFP8FusedMoE(CustomTestCase):
|
||||
DTYPES = [torch.float32, torch.half, torch.bfloat16]
|
||||
M = [1, 33, 64, 222, 1024 * 128]
|
||||
N = [128, 1024, 2048]
|
||||
|
||||
@@ -13,6 +13,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
silu_and_mul_triton_kernel,
|
||||
)
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
# For test
|
||||
@@ -232,7 +233,7 @@ def block_dequant(
|
||||
return x_dq_block
|
||||
|
||||
|
||||
class TestW8A8BlockFP8EPMoE(unittest.TestCase):
|
||||
class TestW8A8BlockFP8EPMoE(CustomTestCase):
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
M = [1, 222, 1024, 2048]
|
||||
N = [128, 1024, 2048]
|
||||
|
||||
@@ -3,9 +3,10 @@ import unittest
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import DynamicGradMode
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestDynamicGradMode(unittest.TestCase):
|
||||
class TestDynamicGradMode(CustomTestCase):
|
||||
def test_inference(self):
|
||||
# Test inference_mode
|
||||
DynamicGradMode.set_inference_mode(True)
|
||||
|
||||
@@ -4,9 +4,10 @@ import unittest
|
||||
import torch
|
||||
|
||||
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
|
||||
from sglang.test.test_utils import CustomTestCase
|
||||
|
||||
|
||||
class TestRMSNorm(unittest.TestCase):
|
||||
class TestRMSNorm(CustomTestCase):
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
NUM_TOKENS = [7, 83, 4096]
|
||||
HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
|
||||
@@ -56,7 +57,7 @@ class TestRMSNorm(unittest.TestCase):
|
||||
self._run_rms_norm_test(*params)
|
||||
|
||||
|
||||
class TestGemmaRMSNorm(unittest.TestCase):
|
||||
class TestGemmaRMSNorm(CustomTestCase):
|
||||
DTYPES = [torch.half, torch.bfloat16]
|
||||
NUM_TOKENS = [7, 83, 4096]
|
||||
HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
|
||||
|
||||
@@ -8,6 +8,7 @@ import random
|
||||
import subprocess
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
@@ -998,3 +999,30 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
|
||||
rank += 1
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
class CustomTestCase(unittest.TestCase):
    """TestCase base that reruns a failing test body via _retry_execution.

    Used to paper over flaky tests in CI: the retry budget comes from
    _get_max_retry(), which is 0 outside CI (no behavior change locally).
    """

    def _callTestMethod(self, method):
        # Bind the parent implementation first; the two-argument super()
        # form is required because zero-argument super() does not work
        # inside a lambda (no __class__ cell).
        run_once = super(CustomTestCase, self)._callTestMethod
        _retry_execution(lambda: run_once(method), max_retry=_get_max_retry())
|
||||
|
||||
|
||||
def _get_max_retry():
    """Return the retry budget for flaky tests.

    Reads SGLANG_TEST_MAX_RETRY from the environment; when unset, defaults
    to 2 retries in CI and 0 elsewhere.
    """
    fallback = "2" if is_in_ci() else "0"
    return int(os.environ.get("SGLANG_TEST_MAX_RETRY", fallback))
|
||||
|
||||
|
||||
def _retry_execution(fn, max_retry: int):
|
||||
if max_retry == 0:
|
||||
fn()
|
||||
return
|
||||
|
||||
try:
|
||||
fn()
|
||||
except Exception as e:
|
||||
print(
|
||||
f"retry_execution failed once and will retry. This may be an error or a flaky test. Error: {e}"
|
||||
)
|
||||
traceback.print_exc()
|
||||
_retry_execution(fn, max_retry=max_retry - 1)
|
||||
|
||||
Reference in New Issue
Block a user