Add retry for flaky tests in CI (#4755)

This commit is contained in:
fzyzcjy
2025-03-26 07:53:12 +08:00
committed by GitHub
parent 52029bd1e3
commit 15ddd84322
112 changed files with 273 additions and 152 deletions

View File

@@ -4,9 +4,10 @@ import unittest
import torch
from sglang.srt.layers.activation import GeluAndMul
from sglang.test.test_utils import CustomTestCase
class TestGeluAndMul(unittest.TestCase):
class TestGeluAndMul(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]

View File

@@ -11,6 +11,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
static_quant_fp8,
w8a8_block_fp8_matmul,
)
from sglang.test.test_utils import CustomTestCase
_is_cuda = torch.cuda.is_available() and torch.version.cuda
@@ -44,7 +45,7 @@ def native_per_token_group_quant_fp8(
return x_q, x_s
class TestPerTokenGroupQuantFP8(unittest.TestCase):
class TestPerTokenGroupQuantFP8(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
@@ -111,7 +112,7 @@ def native_static_quant_fp8(x, x_s, dtype=torch.float8_e4m3fn):
return x_q, x_s
class TestStaticQuantFP8(unittest.TestCase):
class TestStaticQuantFP8(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16, torch.float32]
NUM_TOKENS = [7, 83, 2048]
D = [512, 4096, 5120, 13824]
@@ -210,7 +211,7 @@ def native_w8a8_block_fp8_matmul(A, B, As, Bs, block_size, output_dtype=torch.fl
return C
class TestW8A8BlockFP8Matmul(unittest.TestCase):
class TestW8A8BlockFP8Matmul(CustomTestCase):
if not _is_cuda:
OUT_DTYPES = [torch.float32, torch.half, torch.bfloat16]
@@ -331,7 +332,7 @@ def torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score, topk, block_shape):
).sum(dim=1)
class TestW8A8BlockFP8FusedMoE(unittest.TestCase):
class TestW8A8BlockFP8FusedMoE(CustomTestCase):
DTYPES = [torch.float32, torch.half, torch.bfloat16]
M = [1, 33, 64, 222, 1024 * 128]
N = [128, 1024, 2048]

View File

@@ -13,6 +13,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
silu_and_mul_triton_kernel,
)
from sglang.srt.layers.moe.topk import select_experts
from sglang.test.test_utils import CustomTestCase
# For test
@@ -232,7 +233,7 @@ def block_dequant(
return x_dq_block
class TestW8A8BlockFP8EPMoE(unittest.TestCase):
class TestW8A8BlockFP8EPMoE(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
M = [1, 222, 1024, 2048]
N = [128, 1024, 2048]

View File

@@ -3,9 +3,10 @@ import unittest
import torch
from sglang.srt.utils import DynamicGradMode
from sglang.test.test_utils import CustomTestCase
class TestDynamicGradMode(unittest.TestCase):
class TestDynamicGradMode(CustomTestCase):
def test_inference(self):
# Test inference_mode
DynamicGradMode.set_inference_mode(True)

View File

@@ -4,9 +4,10 @@ import unittest
import torch
from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
from sglang.test.test_utils import CustomTestCase
class TestRMSNorm(unittest.TestCase):
class TestRMSNorm(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
NUM_TOKENS = [7, 83, 4096]
HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]
@@ -56,7 +57,7 @@ class TestRMSNorm(unittest.TestCase):
self._run_rms_norm_test(*params)
class TestGemmaRMSNorm(unittest.TestCase):
class TestGemmaRMSNorm(CustomTestCase):
DTYPES = [torch.half, torch.bfloat16]
NUM_TOKENS = [7, 83, 4096]
HIDDEN_SIZES = [768, 769, 770, 771, 5120, 5124, 5125, 5126, 8192, 8199]

View File

@@ -8,6 +8,7 @@ import random
import subprocess
import threading
import time
import traceback
import unittest
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
@@ -998,3 +999,30 @@ def run_logprob_check(self: unittest.TestCase, arg: Tuple):
rank += 1
else:
raise
class CustomTestCase(unittest.TestCase):
    """``unittest.TestCase`` variant that routes each test method through ``_retry_execution``."""

    def _callTestMethod(self, method):
        """Run ``method`` via the parent ``_callTestMethod``, wrapped in the retry helper.

        The retry budget is obtained from ``_get_max_retry()`` on every call.
        """
        retry_budget = _get_max_retry()
        # Bind the parent implementation here: zero-argument ``super()`` is not
        # usable inside the argument-less lambda below.
        parent_call = super()._callTestMethod
        _retry_execution(lambda: parent_call(method), max_retry=retry_budget)
def _get_max_retry():
    """Return the retry budget for test methods as an ``int``.

    The value comes from the ``SGLANG_TEST_MAX_RETRY`` environment variable.
    When the variable is unset, the fallback is ``2`` if ``is_in_ci()`` is
    truthy and ``0`` otherwise.
    """
    fallback = "2" if is_in_ci() else "0"
    configured = os.environ.get("SGLANG_TEST_MAX_RETRY", fallback)
    return int(configured)
def _retry_execution(fn, max_retry: int):
    """Call ``fn``, re-running it after a failure up to ``max_retry`` more times.

    Args:
        fn: Zero-argument callable to execute. Its return value is discarded.
        max_retry: Number of additional attempts allowed after a failed one.
            Zero or a negative value means ``fn`` is attempted exactly once.

    Raises:
        unittest.SkipTest: Re-raised immediately, without consuming a retry.
        Exception: Whatever ``fn`` raised on the last permitted attempt.
    """
    # Clamp the budget: a negative value (e.g. from a mistyped environment
    # variable) would otherwise never hit the "no retries left" condition.
    retries_left = max(max_retry, 0)

    # Loop instead of recursing from inside the ``except`` block so the final
    # error is not chained to every earlier failure and the stack stays flat.
    while True:
        try:
            fn()
            return
        except unittest.SkipTest:
            # A skip is a deliberate outcome, not a flaky failure; re-running
            # the test would only repeat the skip and log a misleading error.
            raise
        except Exception as e:
            if retries_left == 0:
                raise
            print(
                f"retry_execution failed once and will retry. This may be an error or a flaky test. Error: {e}"
            )
            traceback.print_exc()
            retries_left -= 1