Support speculative decoding in the trtllm_mha attention backend (#9331)

Co-authored-by: ispobock <ispobaoke@gmail.com>
This commit is contained in:
Qiaolin Yu
2025-08-21 23:53:35 -07:00
committed by GitHub
parent fedfe91c1a
commit 9ec314c6ac
3 changed files with 413 additions and 32 deletions

View File

@@ -266,6 +266,22 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)
elif self.server_args.attention_backend == "trtllm_mha":
from sglang.srt.layers.attention.trtllm_mha_backend import (
TRTLLMHAAttnBackend,
TRTLLMHAAttnMultiStepDraftBackend,
)
self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
self.draft_model_runner,
skip_prefill=False,
)
self.has_prefill_wrapper_verify = True
elif self.server_args.attention_backend == "trtllm_mla":
if not global_server_args_dict["use_mla_backend"]:
raise ValueError(