diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
index b737d96e7..a48cc9794 100644
--- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
@@ -10,13 +10,18 @@ from typing import TYPE_CHECKING, Optional
 
 import torch
 
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    FlashInferAttnBackend,
+    FlashInferMultiStepDraftBackend,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
 if is_flashinfer_available():
     import flashinfer
 
+from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
@@ -55,9 +60,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         model_runner: ModelRunner,
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
-        q_indptr_decode_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        speculative_step_id: int = 0,
     ):
-        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )
 
         config = model_runner.model_config
 
@@ -87,6 +95,16 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
 
+        # Speculative decoding
+        # Only support topk <= 1 for now.
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
+        self.speculative_step_id = speculative_step_id
+        self.target_verify_metadata = {}
+
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+
         # Forward metadata
         self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
 
@@ -97,11 +115,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MHA."""
+        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
             "page_table": torch.zeros(
                 max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
+                max_num_pages,
                 dtype=torch.int32,
                 device=self.device,
             ),
@@ -110,6 +129,70 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             ),
         }
 
+        if (
+            self.speculative_num_draft_tokens is not None
+            and self.speculative_num_draft_tokens > 0
+        ):
+            self.decode_cuda_graph_metadata["cu_seqlens_q"] = torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            )
+            self.decode_cuda_graph_metadata["cu_seqlens_k"] = torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            )
+            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
+                max_bs,
+                max_num_pages,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            self.target_verify_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    step=self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    max_num_pages,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
+            self.draft_extend_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.zeros(
+                    max_bs + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    max_num_pages,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -122,16 +205,105 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
     ):
         """Initialize metadata for CUDA graph capture."""
         metadata = TRTLLMMHAMetadata()
+        device = seq_lens.device
 
-        # Get sequence information
-        metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+        if forward_mode.is_decode_or_idle():
+            if spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                    "cache_seqlens"
+                ][:bs]
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
+                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
+                    : bs + 1
+                ]
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.page_table = self.decode_cuda_graph_metadata[
+                    "page_table_draft_decode"
+                ][:bs, :]
+                self.decode_cuda_graph_metadata[bs] = metadata
+            else:
+                # Normal Decode
+                # Get sequence information
+                metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+                batch_size = len(seq_lens)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
 
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = seq_lens[:bs].max().item()
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = seq_lens.max().item()
+                # Precompute cumulative sequence lengths
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                # Precompute page table
+                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+                    :bs, :
+                ]
+                self.decode_cuda_graph_metadata[bs] = metadata
+        elif forward_mode.is_target_verify():
+            # Target Verify
+            # Here we only support topk = 1 for now.
+            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + self.speculative_num_draft_tokens)
+            )
 
-        # Precompute page table
-        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
-        self.decode_cuda_graph_metadata[bs] = metadata
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )
+
+            metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
+
+            self.target_verify_metadata[bs] = metadata
+        elif forward_mode.is_draft_extend():
+            metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(seq_lens)
+            num_tokens_per_bs = num_tokens // bs
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                num_tokens_per_bs,
+                dtype=torch.int32,
+                device=device,
+            )
+
+            metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+            num_tokens_per_bs = num_tokens // bs
+            metadata.max_seq_len_q = num_tokens_per_bs
+            metadata.max_seq_len_k = seq_lens.max().item()
+
+            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
+
+            self.draft_extend_metadata[bs] = metadata
         self.forward_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -149,21 +321,91 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
-        device = seq_lens.device
         metadata = None
+        if forward_mode.is_decode_or_idle():
+            if spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
 
-        # Normal Decode
-        metadata = self.decode_cuda_graph_metadata[bs]
-        max_len = seq_lens_cpu.max().item()
-        max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-        metadata.max_seq_len_k = max_len
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
 
-        metadata.cache_seqlens_int32.copy_(seq_lens)
-        page_indices = self.req_to_token[
-            req_pool_indices[:, None],
-            self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
-        ]
-        metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+                metadata.cache_seqlens_int32.copy_(
+                    seq_lens + self.speculative_step_id + 1
+                )
+            else:
+                # Normal Decode
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                max_seq_pages = (max_len + self.page_size - 1) // self.page_size
+                metadata.max_seq_len_k = max_len
+
+                metadata.cache_seqlens_int32.copy_(seq_lens)
+
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                    None, :
+                ],
+            ]
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+        elif forward_mode.is_target_verify():
+            # Here we only support topk = 1 for now.
+            metadata = self.target_verify_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + self.speculative_num_draft_tokens)
+            )
+
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
+            max_len = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+            ]
+            page_indices //= self.page_size
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+        elif forward_mode.is_draft_extend():
+            metadata = self.draft_extend_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(seq_lens)
+
+            metadata.max_seq_len_k = seq_lens_cpu.max().item()
+            max_len = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            accept_length = spec_info.accept_length[:bs]
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
+
+            metadata.cu_seqlens_q[1:].copy_(
+                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
+            )
+
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.draft_extend_metadata["strided_indices"][:max_seq_pages],
+            ]
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
         self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self) -> int:
@@ -179,12 +421,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         device = seqlens_in_batch.device
 
         if forward_batch.forward_mode.is_decode_or_idle():
-            # Normal Decode
-            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
-            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            if forward_batch.spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+            else:
+                # Normal Decode
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+        elif forward_batch.forward_mode.is_target_verify():
+            # Only support topk = 1 for now.
+            metadata.cache_seqlens_int32 = (
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
+            ).to(torch.int32)
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
+            )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
            ]
+
         else:
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -195,7 +490,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
 
-            if any(forward_batch.extend_prefix_lens_cpu):
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -332,3 +630,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+
+class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
+    """Multi-step TRTLLM MHA attention kernel used by EAGLE."""
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+        for i in range(speculative_num_steps):
+            self.attn_backends[i] = TRTLLMHAAttnBackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                kv_last_page_len_buf=self.kv_last_page_len,
+                speculative_step_id=i,
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=forward_batch.encoder_lens,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=forward_batch.encoder_lens,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index f220770ba..27de75400 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -500,11 +500,6 @@ class ServerArgs:
             )
             self.page_size = 64
 
-            if self.speculative_algorithm is not None:
-                raise ValueError(
-                    "trtllm_mha backend does not support speculative decoding yet."
-                )
-
         if self.attention_backend == "dual_chunk_flash_attn":
             logger.warning(
                 "Mixed chunk, radix cache, and cuda graphs are disabled because of using dual chunk flash attention backend"
@@ -653,6 +648,16 @@ class ServerArgs:
                 self.speculative_num_draft_tokens,
             ) = auto_choose_speculative_params(self)
 
+            if (
+                self.attention_backend == "trtllm_mha"
+                or self.decode_attention_backend == "trtllm_mha"
+                or self.prefill_attention_backend == "trtllm_mha"
+            ):
+                if self.speculative_eagle_topk > 1:
+                    raise ValueError(
+                        "trtllm_mha backend only supports topk = 1 for speculative decoding."
+                    )
+
             if (
                 self.speculative_eagle_topk == 1
                 and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 71f3b15c9..cd26d3d04 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -266,6 +266,22 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
+        elif self.server_args.attention_backend == "trtllm_mha":
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+                TRTLLMHAAttnMultiStepDraftBackend,
+            )
+
+            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
+                self.draft_model_runner,
+                self.topk,
+                self.speculative_num_steps,
+            )
+            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
+            self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "trtllm_mla":
             if not global_server_args_dict["use_mla_backend"]:
                 raise ValueError(