Simplify FA3 tests (#5779)
This commit is contained in:
@@ -10,12 +10,13 @@ from sglang.test.test_utils import (
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
CustomTestCase,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
|
||||
class TestFlashAttention3LocalAttn(unittest.TestCase):
|
||||
class TestFlashAttention3LocalAttn(CustomTestCase):
|
||||
model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
|
||||
base_url = DEFAULT_URL_FOR_TEST
|
||||
accuracy_threshold = 0.90
|
||||
@@ -23,7 +24,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
|
||||
@classmethod
|
||||
def get_server_args(cls):
|
||||
return [
|
||||
"--trust-remote-code",
|
||||
"--cuda-graph-max-bs",
|
||||
"2",
|
||||
"--attention-backend",
|
||||
@@ -36,8 +36,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# disable deep gemm precompile to make launch server faster
|
||||
# please don't do this if you want to make your inference workload faster
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
@@ -51,6 +49,8 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
|
||||
kill_process_tree(cls.process.pid)
|
||||
|
||||
def test_gsm8k(self):
|
||||
requests.get(self.base_url + "/flush_cache")
|
||||
|
||||
args = SimpleNamespace(
|
||||
num_shots=4,
|
||||
num_questions=100,
|
||||
|
||||
Reference in New Issue
Block a user