Add draft extend CUDA graph for flashinfer backend (#6805)

This commit is contained in:
Ke Bao
2025-06-02 16:51:26 +08:00
committed by GitHub
parent 55444ed667
commit a2cb5913a0
5 changed files with 170 additions and 3 deletions

View File

@@ -19,6 +19,7 @@ from sglang.test.few_shot_gsm8k import run_eval
from sglang.test.test_utils import (
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
@@ -602,6 +603,7 @@ class TestEAGLEDraftExtend(CustomTestCase):
"fa3",
],
)
cls.accept_len_threshold = 1.50
@classmethod
def tearDownClass(cls):
@@ -636,7 +638,89 @@ class TestEAGLEDraftExtend(CustomTestCase):
acc_length = 1.0
print(f"{acc_length=}")
self.assertGreater(acc_length, 1.50)
self.assertGreater(acc_length, self.accept_len_threshold)
class TestEAGLEDraftExtendFlashinfer(TestEAGLEDraftExtend):
    """Run the ``TestEAGLEDraftExtend`` suite with ``--attention-backend flashinfer``."""

    @classmethod
    def setUpClass(cls):
        """Launch the test server and record the accept-length threshold.

        Sets ``cls.base_url``, ``cls.process`` and ``cls.accept_len_threshold``
        for use by the inherited test and teardown methods.
        """
        cls.base_url = DEFAULT_URL_FOR_TEST

        # The command line is assembled from three groups; the concatenated
        # list is passed to the launcher as ``other_args``.
        eagle_flags = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
        ]
        draft_size_flags = [
            "--speculative-num-steps",
            1,
            "--speculative-eagle-topk",
            1,
            "--speculative-num-draft-tokens",
            2,
        ]
        serving_flags = [
            "--max-running-requests",
            4,
            "--attention-backend",
            "flashinfer",
        ]

        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=eagle_flags + draft_size_flags + serving_flags,
        )

        # Lower bound that the measured accept length is asserted to exceed.
        cls.accept_len_threshold = 1.50
class TestEAGLEDraftExtendTriton(TestEAGLEDraftExtend):
    """Run the ``TestEAGLEDraftExtend`` suite with ``--attention-backend triton``."""

    @classmethod
    def setUpClass(cls):
        """Launch the test server and record the accept-length threshold.

        Sets ``cls.base_url``, ``cls.process`` and ``cls.accept_len_threshold``
        for use by the inherited test and teardown methods.
        """
        cls.base_url = DEFAULT_URL_FOR_TEST

        # The command line is assembled from three groups; the concatenated
        # list is passed to the launcher as ``other_args``.
        eagle_flags = [
            "--speculative-algorithm",
            "EAGLE",
            "--speculative-draft-model-path",
            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
        ]
        draft_size_flags = [
            "--speculative-num-steps",
            1,
            "--speculative-eagle-topk",
            1,
            "--speculative-num-draft-tokens",
            2,
        ]
        serving_flags = [
            "--max-running-requests",
            4,
            "--attention-backend",
            "triton",
        ]

        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=eagle_flags + draft_size_flags + serving_flags,
        )

        # Lower bound that the measured accept length is asserted to exceed.
        cls.accept_len_threshold = 1.50
class TestEAGLEDraftExtendFlashinferMLA(TestEAGLEDraftExtend):
    """Run the ``TestEAGLEDraftExtend`` suite on the MLA test model with flashinfer."""

    @classmethod
    def setUpClass(cls):
        """Launch the test server and record the accept-length threshold.

        Sets ``cls.base_url``, ``cls.process`` and ``cls.accept_len_threshold``
        for use by the inherited test and teardown methods. No
        ``--speculative-draft-model-path`` flag is passed in this variant.
        """
        cls.base_url = DEFAULT_URL_FOR_TEST

        # The command line is assembled from three groups; the concatenated
        # list is passed to the launcher as ``other_args``.
        eagle_flags = [
            "--speculative-algorithm",
            "EAGLE",
        ]
        draft_size_flags = [
            "--speculative-num-steps",
            1,
            "--speculative-eagle-topk",
            1,
            "--speculative-num-draft-tokens",
            2,
        ]
        serving_flags = [
            "--max-running-requests",
            4,
            "--attention-backend",
            "flashinfer",
        ]

        cls.process = popen_launch_server(
            DEFAULT_MODEL_NAME_FOR_TEST_MLA,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=eagle_flags + draft_size_flags + serving_flags,
        )

        # Lower bound that the measured accept length is asserted to exceed.
        cls.accept_len_threshold = 1.85
if __name__ == "__main__":