Fix ngram spec with page size > 1 (#11135)

This commit is contained in:
Liangsheng Yin
2025-10-02 12:34:23 +08:00
committed by GitHub
parent 0b2aa8a70c
commit 25e7dbe8af
5 changed files with 32 additions and 10 deletions

View File

@@ -31,7 +31,7 @@ DEFAULT_SERVER_ARGS = [
]
class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
class TestNgramSpeculativeDecodingBase(CustomTestCase):
model = DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST
base_url = DEFAULT_URL_FOR_TEST
@@ -88,20 +88,30 @@ class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
class TestStandaloneSpeculativeDecodingTriton(TestStandaloneSpeculativeDecodingBase):
class TestNgramSpeculativeDecodingTriton(TestNgramSpeculativeDecodingBase):
    """Ngram speculative decoding tests using the ``triton`` attention backend.

    Inherits everything from ``TestNgramSpeculativeDecodingBase`` and only
    overrides the server arguments.
    """

    @classmethod
    def get_server_args(cls):
        """Return the default server arguments plus the Triton backend flag."""
        backend_flags = ["--attention-backend", "triton"]
        return DEFAULT_SERVER_ARGS + backend_flags
class TestStandaloneSpeculativeDecodingFlashinfer(
TestStandaloneSpeculativeDecodingBase
):
class TestNgramSpeculativeDecodingFlashinfer(TestNgramSpeculativeDecodingBase):
    """Ngram speculative decoding tests using the ``flashinfer`` attention backend.

    Inherits everything from ``TestNgramSpeculativeDecodingBase`` and only
    overrides the server arguments.
    """

    @classmethod
    def get_server_args(cls):
        """Return the default server arguments plus the FlashInfer backend flag."""
        backend_flags = ["--attention-backend", "flashinfer"]
        return DEFAULT_SERVER_ARGS + backend_flags
class TestNgramSpeculativeDecodingPaged(TestNgramSpeculativeDecodingBase):
    """Ngram speculative decoding tests with a page size larger than one.

    Uses the ``flashinfer`` attention backend together with ``--page-size 64``;
    everything else is inherited from ``TestNgramSpeculativeDecodingBase``.
    """

    @classmethod
    def get_server_args(cls):
        """Return the default server arguments plus backend and page-size flags."""
        backend_flags = ["--attention-backend", "flashinfer"]
        paging_flags = ["--page-size", "64"]
        return DEFAULT_SERVER_ARGS + (backend_flags + paging_flags)
# Run the test suite only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    unittest.main()