Support page size > 1 + eagle (#4908)
This commit is contained in:
@@ -26,7 +26,7 @@ suites = {
|
||||
TestFile("test_abort.py", 51),
|
||||
TestFile("test_block_int8.py", 22),
|
||||
TestFile("test_chunked_prefill.py", 336),
|
||||
TestFile("test_eagle_infer.py", 447),
|
||||
TestFile("test_eagle_infer.py", 500),
|
||||
TestFile("test_ebnf_constrained.py"),
|
||||
TestFile("test_fp8_kernel.py", 2),
|
||||
TestFile("test_embedding_openai_server.py", 36),
|
||||
|
||||
@@ -298,10 +298,16 @@ class TestEAGLEServer(CustomTestCase):
|
||||
print(f"{metrics=}")
|
||||
self.assertGreater(metrics["accuracy"], 0.20)
|
||||
|
||||
server_info = requests.get(self.base_url + "/get_server_info")
|
||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
||||
server_info = requests.get(self.base_url + "/get_server_info").json()
|
||||
avg_spec_accept_length = server_info["avg_spec_accept_length"]
|
||||
print(f"{avg_spec_accept_length=}")
|
||||
self.assertGreater(avg_spec_accept_length, 3.5)
|
||||
|
||||
speculative_eagle_topk = server_info["speculative_eagle_topk"]
|
||||
|
||||
if speculative_eagle_topk == 1:
|
||||
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||
else:
|
||||
self.assertGreater(avg_spec_accept_length, 3.5)
|
||||
|
||||
# Wait a little bit so that the memory check happens.
|
||||
time.sleep(4)
|
||||
@@ -535,5 +541,36 @@ class TestEAGLEServerTriton(TestEAGLEServer):
|
||||
)
|
||||
|
||||
|
||||
class TestEAGLEServerPageSize(TestEAGLEServer):
    """EAGLE server test variant launched with ``--page-size 4``.

    Only the server launch configuration is overridden here; everything
    else comes from ``TestEAGLEServer``. The server is started with
    ``--speculative-eagle-topk 1`` together with a page size larger than 1.
    """

    @classmethod
    def setUpClass(cls):
        """Start the shared test server with EAGLE and a page size of 4.

        Stores the server URL on ``cls.base_url`` and the value returned by
        ``popen_launch_server`` on ``cls.process``.
        """
        # Each command-line flag paired with its value, in launch order.
        flag_value_pairs = (
            ("--speculative-algorithm", "EAGLE"),
            ("--speculative-draft-model-path", DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST),
            ("--speculative-num-steps", 5),
            ("--speculative-eagle-topk", 1),
            ("--speculative-num-draft-tokens", 6),
            ("--mem-fraction-static", 0.7),
            ("--chunked-prefill-size", 128),
            ("--max-running-requests", 8),
            ("--page-size", 4),
        )
        # Flatten to the [flag, value, flag, value, ...] layout the launcher receives.
        launch_args = [token for pair in flag_value_pairs for token in pair]

        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )
|
||||
|
||||
|
||||
# Run the tests in this module through unittest when the file is executed
# directly rather than imported.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
@@ -157,6 +157,7 @@ class TestFlashinferMLAMTP(CustomTestCase):
|
||||
self.assertGreater(metrics["accuracy"], 0.60)
|
||||
|
||||
server_info = requests.get(self.base_url + "/get_server_info")
|
||||
print(f"{server_info=}")
|
||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
||||
print(f"{avg_spec_accept_length=}")
|
||||
self.assertGreater(avg_spec_accept_length, 2.5)
|
||||
|
||||
Reference in New Issue
Block a user