diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 7aeb00d6b..ee69c0cb9 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -11,7 +11,10 @@ from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 import triton
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import (
+    FlashInferMLAAttnBackend,
+    FlashInferMLAMultiStepDraftBackend,
+)
 from sglang.srt.layers.attention.utils import (
     TRITON_PAD_NUM_PAGE_PER_BLOCK,
     create_flashmla_kv_indices_triton,
@@ -96,7 +99,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
-        self.cuda_graph_kv_indices = None
+        self.decode_cuda_graph_kv_indices = None
         self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
 
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
@@ -167,15 +170,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MLA."""
+
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
 
-        self.cuda_graph_kv_indices = torch.full(
+        self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.cuda_graph_workspace = torch.empty(
+        self.decode_cuda_graph_workspace = torch.empty(
             self.workspace_size, dtype=torch.int8, device=self.device
         )
 
+        super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -187,8 +193,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         spec_info: Optional[SpecInfo],
     ):
         """Initialize metadata for CUDA graph capture."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -199,9 +206,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             spec_info,
         )
 
-        # Custom fast-path for decode/idle without speculative execution.
+        # Custom fast-path for decode/idle.
         max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-        block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
 
         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -215,7 +222,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
 
-        metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
+        metadata = TRTLLMMLADecodeMetadata(
+            self.decode_cuda_graph_workspace, block_kv_indices
+        )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
 
@@ -231,8 +240,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -265,11 +274,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (
-            forward_batch.forward_mode.is_decode_or_idle()
-            and forward_batch.spec_info is None
-        ):
+        # Delegate to parent for non-decode modes.
+        if not forward_batch.forward_mode.is_decode_or_idle():
             return super().init_forward_metadata(forward_batch)
 
         bs = forward_batch.batch_size
@@ -474,3 +480,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
         return output
+
+
+class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
+    """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
+
+    def __init__(
+        self, model_runner: "ModelRunner", topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i] = TRTLLMMLABackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                q_indptr_decode_buf=self.q_indptr_decode,
+            )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 326b67e37..150d02e77 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -479,11 +479,6 @@ class ServerArgs:
             )
             self.page_size = 64
 
-            if self.speculative_algorithm is not None:
-                raise ValueError(
-                    "trtllm_mla backend does not support speculative decoding yet."
-                )
-
             if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
                 raise ValueError(
                     "TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 9dc7438c9..71f3b15c9 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -266,6 +266,27 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "trtllm_mla":
+            if not global_server_args_dict["use_mla_backend"]:
+                raise ValueError(
+                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+                )
+
+            from sglang.srt.layers.attention.trtllm_mla_backend import (
+                TRTLLMMLABackend,
+                TRTLLMMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMMLABackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
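With the ValueError guard removed from server_args.py, EAGLE speculative decoding can now be enabled together with the trtllm_mla backend. Below is a minimal usage sketch of the resulting configuration via sglang's offline Engine API, not part of this PR: the model and draft paths are placeholders, the speculative parameters are illustrative rather than tuned, and the Engine kwargs are assumed to map one-to-one onto the ServerArgs fields touched by this diff.

# Hypothetical end-to-end sketch (not part of this PR). Assumes sgl.Engine
# forwards keyword arguments to ServerArgs; model paths are placeholders.
import sglang as sgl

llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V3",                 # placeholder MLA model
    speculative_draft_model_path="/path/to/eagle-draft",  # placeholder draft model
    attention_backend="trtllm_mla",   # previously rejected with speculative decoding
    speculative_algorithm="EAGLE",
    speculative_num_steps=3,          # illustrative values, not tuned
    speculative_eagle_topk=1,
    speculative_num_draft_tokens=4,
    kv_cache_dtype="fp8_e4m3",        # "fp8_e4m3" or "auto", per the ServerArgs check
)

out = llm.generate(
    "The capital of France is",
    {"temperature": 0.0, "max_new_tokens": 32},
)
print(out["text"])
llm.shutdown()

Note that page_size is not passed explicitly: per the existing ServerArgs logic, selecting trtllm_mla forces page_size to 64.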