Enables speculative decoding for the trtllm_mla attention backend (#9238)

This commit is contained in:
pranavm-nvidia
2025-08-21 01:18:21 -07:00
committed by GitHub
parent 18da2c96ec
commit 64574ef8c0
3 changed files with 60 additions and 21 deletions

View File

@@ -266,6 +266,27 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)
elif self.server_args.attention_backend == "trtllm_mla":
if not global_server_args_dict["use_mla_backend"]:
raise ValueError(
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
)
from sglang.srt.layers.attention.trtllm_mla_backend import (
TRTLLMMLABackend,
TRTLLMMLAMultiStepDraftBackend,
)
self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
self.draft_model_runner,
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = TRTLLMMLABackend(
self.draft_model_runner,
skip_prefill=False,
)
self.has_prefill_wrapper_verify = True
else:
raise ValueError(
f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"