Support speculative decoding in the trtllm_mha attention backend (#9331)
Co-authored-by: ispobock <ispobaoke@gmail.com>
This commit is contained in:
@@ -266,6 +266,22 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
elif self.server_args.attention_backend == "trtllm_mha":
|
||||
from sglang.srt.layers.attention.trtllm_mha_backend import (
|
||||
TRTLLMHAAttnBackend,
|
||||
TRTLLMHAAttnMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
|
||||
self.draft_model_runner,
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.has_prefill_wrapper_verify = True
|
||||
elif self.server_args.attention_backend == "trtllm_mla":
|
||||
if not global_server_args_dict["use_mla_backend"]:
|
||||
raise ValueError(
|
||||
|
||||
Reference in New Issue
Block a user