Add speculator attention backend switch (#9981)

This commit is contained in:
cicirori
2025-09-08 06:44:36 +02:00
committed by GitHub
parent 3b99f23c44
commit 8c5930f08a
6 changed files with 130 additions and 54 deletions

View File

@@ -132,5 +132,51 @@ class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
]
class TestHybridAttnBackendSpeculativeDecodingPrefillBackend(TestHybridAttnBackendBase):
speculative_decode = True
# This eagle test uses a very small model, so the accuracy is low.
accuracy_threshold = 0.2
@classmethod
def get_server_args(cls):
return DEFAULT_SERVER_ARGS + [
"--speculative-algorithm",
"EAGLE",
"--speculative-draft",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
"3",
"--speculative-eagle-topk",
"2",
"--speculative-num-draft-tokens",
"4",
"--speculative-attention-backend",
"prefill",
]
class TestHybridAttnBackendSpeculativeDecodingDecodeBackend(TestHybridAttnBackendBase):
speculative_decode = True
# This eagle test uses a very small model, so the accuracy is low.
accuracy_threshold = 0.2
@classmethod
def get_server_args(cls):
return DEFAULT_SERVER_ARGS + [
"--speculative-algorithm",
"EAGLE",
"--speculative-draft",
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
"--speculative-num-steps",
"3",
"--speculative-eagle-topk",
"2",
"--speculative-num-draft-tokens",
"4",
"--speculative-attention-backend",
"decode",
]
if __name__ == "__main__":
unittest.main()