Enables speculative decoding for the trtllm_mla attention backend (#9238)
@@ -11,7 +11,10 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 import triton
 
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import (
+    FlashInferMLAAttnBackend,
+    FlashInferMLAMultiStepDraftBackend,
+)
 from sglang.srt.layers.attention.utils import (
     TRITON_PAD_NUM_PAGE_PER_BLOCK,
     create_flashmla_kv_indices_triton,
@@ -96,7 +99,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
-        self.cuda_graph_kv_indices = None
+        self.decode_cuda_graph_kv_indices = None
         self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
 
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
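
For orientation: `_calc_padded_blocks`, together with the imported `TRITON_PAD_NUM_PAGE_PER_BLOCK`, suggests the per-sequence KV page count is rounded up to the Triton kernel's block granularity. A minimal sketch of that arithmetic (an assumption for illustration, not the backend's actual body):

    import triton

    # Hedged sketch of _calc_padded_blocks: count the KV pages needed for
    # max_seq_len, then round up to the kernel's padding granularity.
    def calc_padded_blocks(max_seq_len: int, page_size: int, pad: int) -> int:
        num_pages = triton.cdiv(max_seq_len, page_size)  # ceil-divide into pages
        return triton.cdiv(num_pages, pad) * pad         # pad to granularity
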
@@ -167,15 +170,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MLA."""
 
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
 
-        self.cuda_graph_kv_indices = torch.full(
+        self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.cuda_graph_workspace = torch.empty(
+        self.decode_cuda_graph_workspace = torch.empty(
             self.workspace_size, dtype=torch.int8, device=self.device
         )
 
         super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
 
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
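
The renamed buffers follow the usual CUDA-graph discipline: every tensor a captured graph touches must keep a fixed address, so `decode_cuda_graph_kv_indices` and `decode_cuda_graph_workspace` are allocated once at worst-case size and later sliced per batch. A standalone illustration of the pattern (not the backend's code; sizes are arbitrary and a CUDA device is assumed):

    import torch

    # Preallocate at maximum size once; captured graphs then reuse views into
    # this storage, so only the contents may change between replays, never the
    # address or shape of the underlying buffer.
    max_bs, max_blocks = 256, 1024
    kv_indices = torch.full(
        (max_bs, max_blocks), -1, dtype=torch.int32, device="cuda"
    )

    def view_for_batch(bs: int, blocks: int) -> torch.Tensor:
        return kv_indices[:bs, :blocks]  # a view: same storage, graph-safe
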
@@ -187,8 +193,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         spec_info: Optional[SpecInfo],
     ):
         """Initialize metadata for CUDA graph capture."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
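
This guard change is the heart of the commit. Previously any batch carrying `spec_info` was delegated to the FlashInfer parent, which kept speculative decoding off the TRT-LLM fast path entirely; now only the forward mode is checked, so decode/idle batches stay here whether or not they are speculative. Schematically:

    def takes_fast_path(forward_mode, spec_info) -> bool:
        # Before: forward_mode.is_decode_or_idle() and spec_info is None,
        # i.e. speculative batches fell back to the FlashInfer parent.
        # After: only the mode matters, so speculative decode batches are
        # handled by the TRT-LLM fast path as well.
        return forward_mode.is_decode_or_idle()
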
@@ -199,9 +206,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             spec_info,
         )
 
-        # Custom fast-path for decode/idle without speculative execution.
+        # Custom fast-path for decode/idle.
         max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-        block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
 
         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -215,7 +222,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
 
-        metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
+        metadata = TRTLLMMLADecodeMetadata(
+            self.decode_cuda_graph_workspace, block_kv_indices
+        )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
 
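
`create_flashmla_kv_indices_triton` is launched on a one-dimensional grid of `bs` programs, one per sequence, each filling that sequence's row of the block table from `req_to_token`. A rough pure-Python picture of what one row plausibly ends up holding (an assumption about the kernel's semantics, for orientation only):

    import torch

    # Hedged reference: req_to_token maps (request, token position) to a
    # global KV slot; sampling the first slot of each page and dividing by
    # page_size yields that sequence's page indices.
    def reference_kv_indices(
        req_to_token: torch.Tensor, req_idx: int, seq_len: int, page_size: int
    ) -> torch.Tensor:
        first_slots = req_to_token[req_idx, 0:seq_len:page_size]
        return first_slots // page_size
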
@@ -231,8 +240,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -265,11 +274,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (
-            forward_batch.forward_mode.is_decode_or_idle()
-            and forward_batch.spec_info is None
-        ):
+        # Delegate to parent for non-decode modes.
+        if not forward_batch.forward_mode.is_decode_or_idle():
             return super().init_forward_metadata(forward_batch)
 
         bs = forward_batch.batch_size
@@ -474,3 +480,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
         return output
+
+
+class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
+    """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
+
+    def __init__(
+        self, model_runner: "ModelRunner", topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i] = TRTLLMMLABackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                q_indptr_decode_buf=self.q_indptr_decode,
+            )
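
The new class reuses the FlashInfer multi-step scaffolding and swaps only the per-step attention backend: each of the `speculative_num_steps` draft steps gets its own `TRTLLMMLABackend` bound to a per-step `kv_indptr` buffer, with prefill skipped since draft steps only decode. A hedged sketch of standing one up (assumes a draft-model `ModelRunner` already exists; the numeric values are arbitrary examples):

    from sglang.srt.layers.attention.trtllm_mla_backend import (
        TRTLLMMLAMultiStepDraftBackend,
    )

    # draft_model_runner: hypothetical handle to EAGLE's draft ModelRunner.
    draft_backend = TRTLLMMLAMultiStepDraftBackend(
        draft_model_runner,  # model_runner
        1,                   # topk: draft tree width
        3,                   # speculative_num_steps: one backend per step
    )
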
@@ -479,11 +479,6 @@ class ServerArgs:
             )
             self.page_size = 64
 
-            if self.speculative_algorithm is not None:
-                raise ValueError(
-                    "trtllm_mla backend does not support speculative decoding yet."
-                )
-
             if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
                 raise ValueError(
                     "TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
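
With the hard error removed from `ServerArgs`, a server can now combine `--attention-backend trtllm_mla` with EAGLE speculative decoding. An illustrative launch (model paths are placeholders and the numeric values arbitrary; flag spellings follow sglang's standard server arguments):

    python -m sglang.launch_server \
        --model-path <mla-target-model> \
        --speculative-algorithm EAGLE \
        --speculative-draft-model-path <eagle-draft-model> \
        --speculative-num-steps 3 \
        --speculative-eagle-topk 1 \
        --speculative-num-draft-tokens 4 \
        --attention-backend trtllm_mla
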
@@ -266,6 +266,27 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "trtllm_mla":
+            if not global_server_args_dict["use_mla_backend"]:
+                raise ValueError(
+                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+                )
+
+            from sglang.srt.layers.attention.trtllm_mla_backend import (
+                TRTLLMMLABackend,
+                TRTLLMMLAMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMMLABackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
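
For orientation on the wiring above: `draft_attn_backend` drives the multi-step draft decode loop, while `draft_extend_attn_backend` (built with `skip_prefill=False`) appears to serve the draft-extend pass after verification, and `has_prefill_wrapper_verify = True` signals that the verify path has a prefill-capable wrapper, mirroring the other MLA branches in this if/elif chain.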