Enables speculative decoding for the trtllm_mla attention backend (#9238)
This commit is contained in:
@@ -11,7 +11,10 @@ from typing import TYPE_CHECKING, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
import triton
|
import triton
|
||||||
|
|
||||||
from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
|
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||||
|
FlashInferMLAAttnBackend,
|
||||||
|
FlashInferMLAMultiStepDraftBackend,
|
||||||
|
)
|
||||||
from sglang.srt.layers.attention.utils import (
|
from sglang.srt.layers.attention.utils import (
|
||||||
TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
TRITON_PAD_NUM_PAGE_PER_BLOCK,
|
||||||
create_flashmla_kv_indices_triton,
|
create_flashmla_kv_indices_triton,
|
||||||
@@ -96,7 +99,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
|
|
||||||
# CUDA graph state
|
# CUDA graph state
|
||||||
self.decode_cuda_graph_metadata = {}
|
self.decode_cuda_graph_metadata = {}
|
||||||
self.cuda_graph_kv_indices = None
|
self.decode_cuda_graph_kv_indices = None
|
||||||
self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
|
self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
|
||||||
|
|
||||||
def _calc_padded_blocks(self, max_seq_len: int) -> int:
|
def _calc_padded_blocks(self, max_seq_len: int) -> int:
|
||||||
@@ -167,15 +170,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
kv_indices_buf: Optional[torch.Tensor] = None,
|
kv_indices_buf: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
"""Initialize CUDA graph state for TRTLLM MLA."""
|
"""Initialize CUDA graph state for TRTLLM MLA."""
|
||||||
|
|
||||||
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
|
max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
|
||||||
|
|
||||||
self.cuda_graph_kv_indices = torch.full(
|
self.decode_cuda_graph_kv_indices = torch.full(
|
||||||
(max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
|
(max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
|
||||||
)
|
)
|
||||||
self.cuda_graph_workspace = torch.empty(
|
self.decode_cuda_graph_workspace = torch.empty(
|
||||||
self.workspace_size, dtype=torch.int8, device=self.device
|
self.workspace_size, dtype=torch.int8, device=self.device
|
||||||
)
|
)
|
||||||
|
|
||||||
|
super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
@@ -187,8 +193,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[SpecInfo],
|
||||||
):
|
):
|
||||||
"""Initialize metadata for CUDA graph capture."""
|
"""Initialize metadata for CUDA graph capture."""
|
||||||
# Delegate to parent for non-decode modes or when speculative execution is used.
|
|
||||||
if not (forward_mode.is_decode_or_idle() and spec_info is None):
|
# Delegate to parent for non-decode modes.
|
||||||
|
if not forward_mode.is_decode_or_idle():
|
||||||
return super().init_forward_metadata_capture_cuda_graph(
|
return super().init_forward_metadata_capture_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
@@ -199,9 +206,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
spec_info,
|
spec_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Custom fast-path for decode/idle without speculative execution.
|
# Custom fast-path for decode/idle.
|
||||||
max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
|
max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
|
||||||
block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
|
block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
|
||||||
|
|
||||||
create_flashmla_kv_indices_triton[(bs,)](
|
create_flashmla_kv_indices_triton[(bs,)](
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
@@ -215,7 +222,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
PAGED_SIZE=self.page_size,
|
PAGED_SIZE=self.page_size,
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
|
metadata = TRTLLMMLADecodeMetadata(
|
||||||
|
self.decode_cuda_graph_workspace, block_kv_indices
|
||||||
|
)
|
||||||
self.decode_cuda_graph_metadata[bs] = metadata
|
self.decode_cuda_graph_metadata[bs] = metadata
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
@@ -231,8 +240,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
seq_lens_cpu: Optional[torch.Tensor],
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
):
|
):
|
||||||
"""Replay CUDA graph with new inputs."""
|
"""Replay CUDA graph with new inputs."""
|
||||||
# Delegate to parent for non-decode modes or when speculative execution is used.
|
# Delegate to parent for non-decode modes.
|
||||||
if not (forward_mode.is_decode_or_idle() and spec_info is None):
|
if not forward_mode.is_decode_or_idle():
|
||||||
return super().init_forward_metadata_replay_cuda_graph(
|
return super().init_forward_metadata_replay_cuda_graph(
|
||||||
bs,
|
bs,
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
@@ -265,11 +274,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
"""Initialize the metadata for a forward pass."""
|
"""Initialize the metadata for a forward pass."""
|
||||||
# Delegate to parent for non-decode modes or when speculative execution is used.
|
# Delegate to parent for non-decode modes.
|
||||||
if not (
|
if not forward_batch.forward_mode.is_decode_or_idle():
|
||||||
forward_batch.forward_mode.is_decode_or_idle()
|
|
||||||
and forward_batch.spec_info is None
|
|
||||||
):
|
|
||||||
return super().init_forward_metadata(forward_batch)
|
return super().init_forward_metadata(forward_batch)
|
||||||
|
|
||||||
bs = forward_batch.batch_size
|
bs = forward_batch.batch_size
|
||||||
@@ -474,3 +480,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
|||||||
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
    """Multi-step draft attention backend for EAGLE built on TRT-LLM MLA.

    Construction is delegated to ``FlashInferMLAMultiStepDraftBackend``;
    afterwards each per-step slot in ``self.attn_backends`` is overwritten
    with a ``TRTLLMMLABackend`` instance.
    """

    def __init__(
        self,
        model_runner: "ModelRunner",
        topk: int,
        speculative_num_steps: int,
    ):
        """Set up one TRT-LLM MLA backend per speculative step.

        Args:
            model_runner: Runner handed to the parent constructor and to
                every per-step backend.
            topk: Forwarded unchanged to the parent constructor.
            speculative_num_steps: Forwarded unchanged to the parent
                constructor.
        """
        super().__init__(model_runner, topk, speculative_num_steps)

        # NOTE(review): relies on the parent having set
        # ``self.speculative_num_steps``, ``self.attn_backends``,
        # ``self.kv_indptr`` and ``self.q_indptr_decode`` — confirm against
        # FlashInferMLAMultiStepDraftBackend.
        for step in range(self.speculative_num_steps):
            # Each step gets its own row of ``kv_indptr``; the decode
            # q_indptr buffer is the same object for every step.
            step_backend = TRTLLMMLABackend(
                model_runner,
                skip_prefill=True,
                kv_indptr_buf=self.kv_indptr[step],
                q_indptr_decode_buf=self.q_indptr_decode,
            )
            self.attn_backends[step] = step_backend
|
||||||
|
|||||||
@@ -479,11 +479,6 @@ class ServerArgs:
|
|||||||
)
|
)
|
||||||
self.page_size = 64
|
self.page_size = 64
|
||||||
|
|
||||||
if self.speculative_algorithm is not None:
|
|
||||||
raise ValueError(
|
|
||||||
"trtllm_mla backend does not support speculative decoding yet."
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
|
if self.kv_cache_dtype not in ["fp8_e4m3", "auto"]:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
|
"TensorRT-LLM MLA backend only supports kv-cache-dtype of fp8_e4m3 or auto."
|
||||||
|
|||||||
@@ -266,6 +266,27 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.topk,
|
self.topk,
|
||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
)
|
)
|
||||||
|
elif self.server_args.attention_backend == "trtllm_mla":
|
||||||
|
if not global_server_args_dict["use_mla_backend"]:
|
||||||
|
raise ValueError(
|
||||||
|
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
|
||||||
|
)
|
||||||
|
|
||||||
|
from sglang.srt.layers.attention.trtllm_mla_backend import (
|
||||||
|
TRTLLMMLABackend,
|
||||||
|
TRTLLMMLAMultiStepDraftBackend,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
|
||||||
|
self.draft_model_runner,
|
||||||
|
self.topk,
|
||||||
|
self.speculative_num_steps,
|
||||||
|
)
|
||||||
|
self.draft_extend_attn_backend = TRTLLMMLABackend(
|
||||||
|
self.draft_model_runner,
|
||||||
|
skip_prefill=False,
|
||||||
|
)
|
||||||
|
self.has_prefill_wrapper_verify = True
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
|
f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
|
||||||
|
|||||||
Reference in New Issue
Block a user