Support page size > 1 + eagle (#4908)

This commit is contained in:
Lianmin Zheng
2025-03-30 00:46:23 -07:00
committed by GitHub
parent 5ec5eaf760
commit b26bc86b36
16 changed files with 374 additions and 71 deletions

View File

@@ -26,7 +26,7 @@ suites = {
TestFile("test_abort.py", 51),
TestFile("test_block_int8.py", 22),
TestFile("test_chunked_prefill.py", 336),
TestFile("test_eagle_infer.py", 447),
TestFile("test_eagle_infer.py", 500),
TestFile("test_ebnf_constrained.py"),
TestFile("test_fp8_kernel.py", 2),
TestFile("test_embedding_openai_server.py", 36),

View File

@@ -298,10 +298,16 @@ class TestEAGLEServer(CustomTestCase):
print(f"{metrics=}")
self.assertGreater(metrics["accuracy"], 0.20)
server_info = requests.get(self.base_url + "/get_server_info")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
server_info = requests.get(self.base_url + "/get_server_info").json()
avg_spec_accept_length = server_info["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 3.5)
speculative_eagle_topk = server_info["speculative_eagle_topk"]
if speculative_eagle_topk == 1:
self.assertGreater(avg_spec_accept_length, 2.5)
else:
self.assertGreater(avg_spec_accept_length, 3.5)
# Wait a little bit so that the memory check happens.
time.sleep(4)
@@ -535,5 +541,36 @@ class TestEAGLEServerTriton(TestEAGLEServer):
)
class TestEAGLEServerPageSize(TestEAGLEServer):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
cls.process = popen_launch_server(
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[
"--speculative-algorithm",
"EAGLE",
"--speculative-draft-model-path",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
5,
"--speculative-eagle-topk",
1,
"--speculative-num-draft-tokens",
6,
"--mem-fraction-static",
0.7,
"--chunked-prefill-size",
128,
"--max-running-requests",
8,
"--page-size",
4,
],
)
if __name__ == "__main__":
unittest.main()

View File

@@ -157,6 +157,7 @@ class TestFlashinferMLAMTP(CustomTestCase):
self.assertGreater(metrics["accuracy"], 0.60)
server_info = requests.get(self.base_url + "/get_server_info")
print(f"{server_info=}")
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
print(f"{avg_spec_accept_length=}")
self.assertGreater(avg_spec_accept_length, 2.5)