Simplify FA3 tests (#5779)

This commit is contained in:
Lianmin Zheng
2025-04-27 01:30:17 -07:00
committed by GitHub
parent 6e313c1b8b
commit 4d23ba08f5
3 changed files with 14 additions and 67 deletions

View File

@@ -30,7 +30,7 @@ suites = {
TestFile("test_chunked_prefill.py", 336), TestFile("test_chunked_prefill.py", 336),
TestFile("test_eagle_infer.py", 500), TestFile("test_eagle_infer.py", 500),
TestFile("test_ebnf_constrained.py"), TestFile("test_ebnf_constrained.py"),
TestFile("test_fa3.py", 500), TestFile("test_fa3.py", 400),
TestFile("test_fp8_kernel.py", 8), TestFile("test_fp8_kernel.py", 8),
TestFile("test_embedding_openai_server.py", 36), TestFile("test_embedding_openai_server.py", 36),
TestFile("test_hidden_states.py", 55), TestFile("test_hidden_states.py", 55),
@@ -92,7 +92,7 @@ suites = {
TestFile("test_verl_engine.py", 100), TestFile("test_verl_engine.py", 100),
], ],
"per-commit-8-gpu": [ "per-commit-8-gpu": [
TestFile("test_local_attn.py", 100), TestFile("test_local_attn.py", 250),
], ],
"nightly": [ "nightly": [
TestFile("test_nightly_gsm8k_eval.py"), TestFile("test_nightly_gsm8k_eval.py"),

View File

@@ -3,7 +3,6 @@ import unittest
from types import SimpleNamespace from types import SimpleNamespace
import requests import requests
import torch
from sglang.srt.utils import get_device_sm, kill_process_tree from sglang.srt.utils import get_device_sm, kill_process_tree
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
@@ -14,6 +13,7 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN, DEFAULT_MODEL_NAME_FOR_TEST_MLA_NEXTN,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server, popen_launch_server,
) )
@@ -47,9 +47,8 @@ if OFFLINE_MODE:
# Default server arguments shared across all tests # Default server arguments shared across all tests
DEFAULT_SERVER_ARGS = [ DEFAULT_SERVER_ARGS = [
"--trust-remote-code", "--trust-remote-code",
"--enable-torch-compile",
"--cuda-graph-max-bs", "--cuda-graph-max-bs",
"2", "4",
"--attention-backend", "--attention-backend",
"fa3", "fa3",
] ]
@@ -60,7 +59,7 @@ Integration test for python/sglang/srt/layers/attention/flashattention_backend.p
@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher") @unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class BaseFlashAttentionTest(unittest.TestCase): class BaseFlashAttentionTest(CustomTestCase):
"""Base class for testing FlashAttention3.""" """Base class for testing FlashAttention3."""
model = DEFAULT_MODEL_NAME_FOR_TEST model = DEFAULT_MODEL_NAME_FOR_TEST
@@ -78,13 +77,13 @@ class BaseFlashAttentionTest(unittest.TestCase):
def setUpClass(cls): def setUpClass(cls):
# disable deep gemm precompile to make launch server faster # disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster # please don't do this if you want to make your inference workload faster
os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "False" os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
cls.process = popen_launch_server( cls.process = popen_launch_server(
cls.model, cls.model,
cls.base_url, cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=cls.get_server_args(), other_args=cls.get_server_args(),
env=os.environ,
) )
@classmethod @classmethod
@@ -92,6 +91,8 @@ class BaseFlashAttentionTest(unittest.TestCase):
kill_process_tree(cls.process.pid) kill_process_tree(cls.process.pid)
def test_gsm8k(self): def test_gsm8k(self):
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace( args = SimpleNamespace(
num_shots=4, num_shots=4,
num_questions=100, num_questions=100,
@@ -102,7 +103,7 @@ class BaseFlashAttentionTest(unittest.TestCase):
data_path=GSM_DATASET_PATH, data_path=GSM_DATASET_PATH,
) )
metrics = run_eval_few_shot_gsm8k(args) metrics = run_eval_few_shot_gsm8k(args)
print(metrics) print(f"{metrics=}")
# Use the appropriate metric key based on the test class # Use the appropriate metric key based on the test class
metric_key = "accuracy" metric_key = "accuracy"
@@ -192,60 +193,6 @@ class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
return args return args
class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
"""Test FlashAttention3 with speculative decode enabled, topk > 1"""
model = DEFAULT_MODEL_NAME_FOR_TEST
@classmethod
def get_server_args(cls):
args = super().get_server_args()
args.extend(
[
"--cuda-graph-max-bs",
"2",
"--speculative-algorithm",
"EAGLE3",
"--speculative-draft",
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
"--speculative-num-steps",
"5",
"--speculative-eagle-topk",
"4",
"--speculative-num-draft-tokens",
"8",
"--dtype",
"float16",
]
)
return args
def test_gsm8k(self):
"""
Override the test_gsm8k to further test for average speculative accept length.
"""
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace(
num_shots=5,
data_path=GSM_DATASET_PATH,
num_questions=200,
max_new_tokens=512,
parallel=128,
host="http://127.0.0.1",
port=int(self.base_url.split(":")[-1]),
)
metrics = run_eval_few_shot_gsm8k(args)
print(metrics)
self.assertGreater(metrics["accuracy"], 0.60)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 1.8)
class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest): class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
"""Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model""" """Test FlashAttention3 with speculative decode enabled with deepseek v3 test model and its nextN model"""

View File

@@ -10,12 +10,13 @@ from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION, DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST, DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server, popen_launch_server,
) )
@unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher") @unittest.skipIf(get_device_sm() < 90, "Test requires CUDA SM 90 or higher")
class TestFlashAttention3LocalAttn(unittest.TestCase): class TestFlashAttention3LocalAttn(CustomTestCase):
model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION model = DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION
base_url = DEFAULT_URL_FOR_TEST base_url = DEFAULT_URL_FOR_TEST
accuracy_threshold = 0.90 accuracy_threshold = 0.90
@@ -23,7 +24,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
@classmethod @classmethod
def get_server_args(cls): def get_server_args(cls):
return [ return [
"--trust-remote-code",
"--cuda-graph-max-bs", "--cuda-graph-max-bs",
"2", "2",
"--attention-backend", "--attention-backend",
@@ -36,8 +36,6 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
@classmethod @classmethod
def setUpClass(cls): def setUpClass(cls):
# disable deep gemm precompile to make launch server faster
# please don't do this if you want to make your inference workload faster
cls.process = popen_launch_server( cls.process = popen_launch_server(
cls.model, cls.model,
cls.base_url, cls.base_url,
@@ -51,6 +49,8 @@ class TestFlashAttention3LocalAttn(unittest.TestCase):
kill_process_tree(cls.process.pid) kill_process_tree(cls.process.pid)
def test_gsm8k(self): def test_gsm8k(self):
requests.get(self.base_url + "/flush_cache")
args = SimpleNamespace( args = SimpleNamespace(
num_shots=4, num_shots=4,
num_questions=100, num_questions=100,