diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index 1486987dc..2543e9de6 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -1,10 +1,14 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Optional +from typing import TYPE_CHECKING, Optional import torch -from sglang.srt.layers.radix_attention import RadixAttention -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +if TYPE_CHECKING: + from sglang.srt.layers.radix_attention import RadixAttention + from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode + from sglang.srt.speculative.spec_info import SpecInfo class AttentionBackend(ABC): @@ -22,9 +26,12 @@ class AttentionBackend(ABC): def init_forward_metadata_capture_cuda_graph( self, bs: int, + num_token: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor] = None, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], ): """Init the metadata for a forward pass for capturing a cuda graph.""" raise NotImplementedError() @@ -35,7 +42,9 @@ class AttentionBackend(ABC): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor] = None, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], ): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index 856aa984c..a5e54f32d 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -3,7 +3,6 @@ from __future__ import annotations from typing import TYPE_CHECKING import torch -import torch.nn as nn from sglang.srt.layers.attention import AttentionBackend from sglang.srt.managers.schedule_batch import global_server_args_dict @@ -52,8 +51,6 @@ class DoubleSparseAttnBackend(AttentionBackend): self.forward_metadata = None - self.cuda_graph_max_seq_len = model_runner.model_config.context_len - def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" @@ -115,55 +112,6 @@ class DoubleSparseAttnBackend(AttentionBackend): ds_req_to_token, ) - def init_cuda_graph_state(self, max_bs: int): - # TODO(Andy): Support CUDA graph for double sparse attention - raise ValueError( - "Double sparse attention does not support CUDA graph for now. 
Please --disable-cuda-graph" - ) - self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len - - self.cuda_graph_start_loc = torch.zeros( - (max_bs,), dtype=torch.int32, device="cuda" - ) - self.cuda_graph_attn_logits = torch.empty( - ( - self.num_head, - self.cuda_graph_max_total_num_tokens, - ), - dtype=self.reduce_dtype, - device="cuda", - ) - - def init_forward_metadata_capture_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - encoder_lens=None, - ): - # NOTE: encoder_lens expected to be zeros or None - self.forward_metadata = ( - self.cuda_graph_start_loc, - self.cuda_graph_attn_logits, - self.cuda_graph_max_seq_len, - None, - ) - - def init_forward_metadata_replay_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - encoder_lens=None, - ): - # NOTE: encoder_lens expected to be zeros or None - self.cuda_graph_start_loc.zero_() - self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) - - def get_cuda_graph_seq_len_fill_value(self): - return 1 - def forward_extend( self, q, diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index db1e17b5e..1cd5c56cf 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -10,7 +10,7 @@ Each backend supports two operators: extend (i.e. prefill with cached prefix) an import os from dataclasses import dataclass from enum import Enum, auto -from typing import TYPE_CHECKING, List, Union +from typing import TYPE_CHECKING, List, Optional, Union import torch import triton @@ -18,12 +18,13 @@ import triton.language as tl from sglang.global_config import global_config from sglang.srt.layers.attention import AttentionBackend -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.utils import is_flashinfer_available if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInfo if is_flashinfer_available(): from flashinfer import ( @@ -113,11 +114,15 @@ class FlashInferAttnBackend(AttentionBackend): # Two wrappers: one for sliding window attention and one for full attention. 
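For orientation, the widened CUDA-graph hooks on AttentionBackend at the top of this patch (the added num_token, forward_mode and spec_info parameters) are the contract every backend touched below must accept. The following is a minimal sketch of a conforming backend that, like the Triton backend later in this patch, only supports plain decode capture; the class name and the kv-lens buffer are hypothetical and not part of the diff.

from typing import Optional

import torch

from sglang.srt.layers.attention import AttentionBackend
from sglang.srt.model_executor.forward_batch_info import ForwardMode


class DecodeOnlyAttnBackend(AttentionBackend):
    """Hypothetical sketch: adopts the new hook signatures, decode-only."""

    def init_cuda_graph_state(self, max_num_token: int):
        # The runner now passes max_bs * num_tokens_per_bs, so static buffers
        # are sized by tokens rather than by batch size.
        self.cuda_graph_kv_lens = torch.zeros(
            (max_num_token,), dtype=torch.int32, device="cuda"
        )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        num_token: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info,  # Optional[SpecInfo]
    ):
        # Mirror the Triton backend below: reject speculative capture modes.
        assert forward_mode.is_decode(), "speculative capture not supported here"
        assert spec_info is None and encoder_lens is None
        self.forward_metadata = (self.cuda_graph_kv_lens[:num_token],)

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        encoder_lens: Optional[torch.Tensor],
        forward_mode: ForwardMode,
        spec_info,  # Optional[SpecInfo]
    ):
        # Refresh the static buffer in place before the captured graph replays.
        self.cuda_graph_kv_lens[:bs].copy_(seq_lens[:bs])

    def get_cuda_graph_seq_len_fill_value(self):
        # Value used to pre-fill seq_lens for padded slots before replay.
        return 0
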
# Using two wrappers is unnecessary in the current PR, but are prepared for future PRs self.prefill_wrappers_paged = [] + self.prefill_wrappers_verify = [] self.decode_wrappers = [] for _ in range(self.num_wrappers): self.prefill_wrappers_paged.append( BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") ) + self.prefill_wrappers_verify.append( + BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") + ) self.decode_wrappers.append( BatchDecodeWithPagedKVCacheWrapper( self.workspace_buffer, @@ -135,6 +140,7 @@ class FlashInferAttnBackend(AttentionBackend): # Other metadata self.forward_metadata: Union[PrefillMetadata, DecodeMetadata] = None self.decode_cuda_graph_metadata = {} + self.prefill_cuda_graph_metadata = {} def init_forward_metadata(self, forward_batch: ForwardBatch): if forward_batch.forward_mode.is_decode(): @@ -144,8 +150,37 @@ class FlashInferAttnBackend(AttentionBackend): forward_batch.seq_lens_sum, decode_wrappers=self.decode_wrappers, encoder_lens=forward_batch.encoder_lens, + spec_info=forward_batch.spec_info, ) self.forward_metadata = DecodeMetadata(self.decode_wrappers) + elif forward_batch.forward_mode.is_draft_extend(): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + prefill_wrappers=self.prefill_wrappers_paged, + use_ragged=False, + encoder_lens=forward_batch.encoder_lens, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = PrefillMetadata( + self.prefill_wrappers_paged, False, False + ) + elif forward_batch.forward_mode.is_target_verify(): + self.indices_updater_prefill.update( + forward_batch.req_pool_indices, + forward_batch.seq_lens, + forward_batch.seq_lens_sum, + prefix_lens=None, + prefill_wrappers=self.prefill_wrappers_verify, + use_ragged=False, + encoder_lens=forward_batch.encoder_lens, + spec_info=forward_batch.spec_info, + ) + self.forward_metadata = PrefillMetadata( + self.prefill_wrappers_verify, False, False + ) else: prefix_lens = forward_batch.extend_prefix_lens @@ -165,6 +200,7 @@ class FlashInferAttnBackend(AttentionBackend): prefill_wrappers=self.prefill_wrappers_paged, use_ragged=use_ragged, encoder_lens=forward_batch.encoder_lens, + spec_info=None, ) self.forward_metadata = PrefillMetadata( self.prefill_wrappers_paged, use_ragged, extend_no_prefix @@ -180,37 +216,80 @@ class FlashInferAttnBackend(AttentionBackend): cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1) ] + self.cuda_graph_custom_mask = torch.zeros( + (max_bs * self.max_context_len), + dtype=torch.uint8, + device="cuda", + ) + self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr] + self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr] + def init_forward_metadata_capture_cuda_graph( self, bs: int, + num_token: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - encoder_lens: torch.Tensor = None, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], ): - decode_wrappers = [] - for i in range(self.num_wrappers): - decode_wrappers.append( - BatchDecodeWithPagedKVCacheWrapper( - self.workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=self.decode_use_tensor_cores, - paged_kv_indptr_buffer=self.kv_indptr[i][: bs + 1], - paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], - paged_kv_last_page_len_buffer=self.kv_last_page_len[:bs], + if forward_mode.is_decode(): + decode_wrappers = [] + for i in range(self.num_wrappers): + 
decode_wrappers.append( + BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=self.decode_use_tensor_cores, + paged_kv_indptr_buffer=self.kv_indptr[i][: num_token + 1], + paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], + paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_token], + ) ) + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_decode.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + decode_wrappers=decode_wrappers, + encoder_lens=encoder_lens, + spec_info=spec_info, ) - - seq_lens_sum = seq_lens.sum().item() - self.indices_updater_decode.update( - req_pool_indices, - seq_lens, - seq_lens_sum, - decode_wrappers=decode_wrappers, - encoder_lens=encoder_lens, - ) - self.decode_cuda_graph_metadata[bs] = decode_wrappers - self.forward_metadata = DecodeMetadata(decode_wrappers) + self.decode_cuda_graph_metadata[bs] = decode_wrappers + self.forward_metadata = DecodeMetadata(decode_wrappers) + elif forward_mode.is_target_verify(): + prefill_wrappers = [] + for i in range(self.num_wrappers): + prefill_wrappers.append( + BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1], + paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1], + paged_kv_indices_buf=self.cuda_graph_kv_indices[i], + paged_kv_last_page_len_buf=self.kv_last_page_len[:bs], + custom_mask_buf=self.cuda_graph_custom_mask, + qk_indptr_buf=self.cuda_graph_qk_indptr[i][: bs + 1], + ) + ) + seq_lens_sum = seq_lens.sum().item() + self.indices_updater_prefill.update( + req_pool_indices, + seq_lens, + seq_lens_sum, + prefix_lens=None, + prefill_wrappers=prefill_wrappers, + use_ragged=False, + encoder_lens=encoder_lens, + spec_info=spec_info, + ) + self.prefill_cuda_graph_metadata[bs] = prefill_wrappers + self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False) + else: + raise ValueError(f"Invalid mode: {forward_mode=}") def init_forward_metadata_replay_cuda_graph( self, @@ -218,24 +297,41 @@ class FlashInferAttnBackend(AttentionBackend): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - encoder_lens: torch.Tensor = None, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], ): - self.indices_updater_decode.update( - req_pool_indices[:bs], - seq_lens[:bs], - seq_lens_sum, - decode_wrappers=self.decode_cuda_graph_metadata[bs], - encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, - ) + if forward_mode.is_decode(): + self.indices_updater_decode.update( + req_pool_indices[:bs], + seq_lens[:bs], + seq_lens_sum, + decode_wrappers=self.decode_cuda_graph_metadata[bs], + encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, + spec_info=spec_info, + ) + elif forward_mode.is_target_verify(): + self.indices_updater_prefill.update( + req_pool_indices[:bs], + seq_lens[:bs], + seq_lens_sum, + prefix_lens=None, + prefill_wrappers=self.prefill_cuda_graph_metadata[bs], + use_ragged=False, + encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, + spec_info=spec_info, + ) + else: + raise ValueError("Invalid forward mode") def get_cuda_graph_seq_len_fill_value(self): return 0 def forward_extend( self, - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, @@ -293,9 +389,9 @@ class FlashInferAttnBackend(AttentionBackend): def forward_decode( self, - q, - 
k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, @@ -348,7 +444,6 @@ class FlashInferIndicesUpdaterDecode: self.data_type = model_runner.kv_cache_dtype self.q_data_type = model_runner.dtype self.sliding_window_size = model_runner.sliding_window_size - self.attn_backend = attn_backend # Buffers and wrappers @@ -371,7 +466,8 @@ class FlashInferIndicesUpdaterDecode: seq_lens: torch.Tensor, seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], - encoder_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): # Keep the signature for type checking. It will be assigned during runtime. raise NotImplementedError() @@ -382,7 +478,8 @@ class FlashInferIndicesUpdaterDecode: seq_lens: torch.Tensor, seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], - encoder_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): decode_wrappers = decode_wrappers or self.decode_wrappers self.call_begin_forward( @@ -392,6 +489,7 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum, self.kv_indptr[0], None, + spec_info, ) def update_sliding_window( @@ -400,7 +498,8 @@ class FlashInferIndicesUpdaterDecode: seq_lens: torch.Tensor, seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], - encoder_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): for wrapper_id in range(2): if wrapper_id == 0: @@ -424,6 +523,7 @@ class FlashInferIndicesUpdaterDecode: paged_kernel_lens_sum_tmp, self.kv_indptr[wrapper_id], kv_start_idx_tmp, + spec_info, ) def update_cross_attention( @@ -432,7 +532,8 @@ class FlashInferIndicesUpdaterDecode: seq_lens: torch.Tensor, seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], - encoder_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): for wrapper_id in range(2): if wrapper_id == 0: @@ -452,6 +553,7 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum, self.kv_indptr[wrapper_id], kv_start_idx, + spec_info, ) def call_begin_forward( @@ -462,23 +564,30 @@ class FlashInferIndicesUpdaterDecode: paged_kernel_lens_sum: int, kv_indptr: torch.Tensor, kv_start_idx: torch.Tensor, + spec_info: Optional[SpecInfo], ): - bs = len(req_pool_indices) - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device="cuda" - ) - - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - kv_start_idx, - kv_indices, - self.req_to_token.shape[1], - ) + if spec_info is None: + bs = len(req_pool_indices) + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + self.req_to_token.shape[1], + ) + else: + bs, kv_indices, kv_indptr = spec_info.generate_attn_arg_decode( + req_pool_indices, + paged_kernel_lens, + self.req_to_token, + ) wrapper.end_forward() wrapper.begin_forward( @@ -507,7 +616,6 @@ class FlashInferIndicesUpdaterPrefill: self.data_type = model_runner.kv_cache_dtype self.q_data_type = model_runner.dtype 
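The decode updater above and the prefill updater below only rely on two duck-typed methods of spec_info; the concrete class lives in sglang.srt.speculative.spec_info and is not shown in this diff. Below is a sketch of the implied surface, with the protocol name invented for illustration and the return shapes inferred from the call sites in this patch.

from typing import Protocol, Tuple

import torch


class SpecInfoLike(Protocol):
    """Hypothetical protocol describing what the index updaters call."""

    def generate_attn_arg_decode(
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        req_to_token: torch.Tensor,
    ) -> Tuple[int, torch.Tensor, torch.Tensor]:
        """Return (bs, kv_indices, kv_indptr) consumed by the decode wrappers."""
        ...

    def generate_attn_arg_prefill(
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        req_to_token: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return (kv_indices, kv_indptr, qo_indptr, custom_mask) for the
        target-verify / draft-extend prefill wrappers."""
        ...
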
self.sliding_window_size = model_runner.sliding_window_size - self.attn_backend = attn_backend # Buffers and wrappers @@ -534,7 +642,8 @@ class FlashInferIndicesUpdaterPrefill: prefix_lens: torch.Tensor, prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, - encoder_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): # Keep the signature for type checking. It will be assigned during runtime. raise NotImplementedError() @@ -547,7 +656,8 @@ class FlashInferIndicesUpdaterPrefill: prefix_lens: torch.Tensor, prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, - encoder_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): if use_ragged: paged_kernel_lens = prefix_lens @@ -568,6 +678,7 @@ class FlashInferIndicesUpdaterPrefill: self.kv_indptr[0], self.qo_indptr[0], use_ragged, + spec_info, ) def update_sliding_window( @@ -578,7 +689,8 @@ class FlashInferIndicesUpdaterPrefill: prefix_lens: torch.Tensor, prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, - encoder_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): for wrapper_id in range(2): if wrapper_id == 0: @@ -607,6 +719,7 @@ class FlashInferIndicesUpdaterPrefill: self.kv_indptr[wrapper_id], self.qo_indptr[wrapper_id], use_ragged, + spec_info, ) def update_cross_attention( @@ -617,7 +730,8 @@ class FlashInferIndicesUpdaterPrefill: prefix_lens: torch.Tensor, prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, - encoder_lens: torch.Tensor, + encoder_lens: Optional[torch.Tensor], + spec_info: Optional[SpecInfo], ): for wrapper_id in range(2): if wrapper_id == 0: @@ -643,6 +757,7 @@ class FlashInferIndicesUpdaterPrefill: self.kv_indptr[wrapper_id], self.qo_indptr[wrapper_id], use_ragged, + spec_info, ) def call_begin_forward( @@ -658,25 +773,37 @@ class FlashInferIndicesUpdaterPrefill: kv_indptr: torch.Tensor, qo_indptr: torch.Tensor, use_ragged: bool, + spec_info: Optional[SpecInfo], ): bs = len(req_pool_indices) - kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) - kv_indptr = kv_indptr[: bs + 1] - kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device="cuda" - ) - create_flashinfer_kv_indices_triton[(bs,)]( - self.req_to_token, - req_pool_indices, - paged_kernel_lens, - kv_indptr, - kv_start_idx, - kv_indices, - self.req_to_token.shape[1], - ) + if spec_info is None: + # Normal extend + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + req_pool_indices, + paged_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + self.req_to_token.shape[1], + ) - qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) - qo_indptr = qo_indptr[: bs + 1] + qo_indptr[1 : bs + 1] = torch.cumsum(seq_lens - prefix_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + custom_mask = None + else: + kv_indices, kv_indptr, qo_indptr, custom_mask = ( + spec_info.generate_attn_arg_prefill( + req_pool_indices, + paged_kernel_lens, + self.req_to_token, + ) + ) # extend part if use_ragged: @@ -702,6 +829,7 @@ class FlashInferIndicesUpdaterPrefill: self.head_dim, 1, q_data_type=self.q_data_type, + custom_mask=custom_mask, ) diff --git a/python/sglang/srt/layers/attention/torch_native_backend.py 
b/python/sglang/srt/layers/attention/torch_native_backend.py index 5e7e0e66e..f73cd168e 100644 --- a/python/sglang/srt/layers/attention/torch_native_backend.py +++ b/python/sglang/srt/layers/attention/torch_native_backend.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING import torch from torch.nn.functional import scaled_dot_product_attention @@ -23,43 +23,6 @@ class TorchNativeAttnBackend(AttentionBackend): """Init the metadata for a forward pass.""" pass - def init_cuda_graph_state(self, max_bs: int): - # TODO: Support CUDA graph - raise ValueError( - "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph" - ) - - def init_forward_metadata_capture_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - encoder_lens: Optional[torch.Tensor] = None, - ): - # TODO: Support CUDA graph - raise ValueError( - "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph" - ) - - def init_forward_metadata_replay_cuda_graph( - self, - bs: int, - req_pool_indices: torch.Tensor, - seq_lens: torch.Tensor, - seq_lens_sum: int, - encoder_lens: Optional[torch.Tensor] = None, - ): - # TODO: Support CUDA graph - raise ValueError( - "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph" - ) - - def get_cuda_graph_seq_len_fill_value(self): - # TODO: Support CUDA graph - raise ValueError( - "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph" - ) - def _run_sdpa_forward_extend( self, query: torch.Tensor, diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 6c8bce20a..a2bc50478 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -1,15 +1,16 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch from sglang.srt.layers.attention import AttentionBackend -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.speculative.spec_info import SpecInfo class TritonAttnBackend(AttentionBackend): @@ -80,11 +81,17 @@ class TritonAttnBackend(AttentionBackend): def init_forward_metadata_capture_cuda_graph( self, bs: int, + num_token: int, req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, - encoder_lens=None, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], ): - # NOTE: encoder_lens expected to be zeros or None + assert encoder_lens is None, "Not supported" + assert forward_mode.is_decode(), "Not supported" + assert spec_info is None, "Not supported" + self.forward_metadata = ( self.cuda_graph_attn_logits, None, @@ -96,7 +103,9 @@ class TritonAttnBackend(AttentionBackend): req_pool_indices: torch.Tensor, seq_lens: torch.Tensor, seq_lens_sum: int, - encoder_lens=None, + encoder_lens: Optional[torch.Tensor], + forward_mode: ForwardMode, + spec_info: Optional[SpecInfo], ): # NOTE: encoder_lens expected to be zeros or None self.cuda_graph_start_loc.zero_() @@ -107,9 +116,9 @@ class TritonAttnBackend(AttentionBackend): def forward_extend( self, - q, 
- k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, @@ -146,9 +155,9 @@ class TritonAttnBackend(AttentionBackend): def forward_decode( self, - q, - k, - v, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, layer: RadixAttention, forward_batch: ForwardBatch, save_kv_cache=True, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 3e9030742..c4e717ba2 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -25,14 +25,14 @@ from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed.parallel_state import graph_capture from vllm.model_executor.custom_op import CustomOp -from sglang.srt.layers.logits_processor import ( - LogitsMetadata, - LogitsProcessor, - LogitsProcessorOutput, -) +from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native from sglang.srt.layers.torchao_utils import save_gemlite_cache -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.forward_batch_info import ( + CaptureHiddenMode, + ForwardBatch, + ForwardMode, +) from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather if TYPE_CHECKING: @@ -153,6 +153,10 @@ class CudaGraphRunner: if bs <= model_runner.req_to_token_pool.size and bs <= model_runner.server_args.cuda_graph_max_bs ] + + self.capture_forward_mode = ForwardMode.DECODE + self.num_tokens_per_bs = 1 + self.compile_bs = ( [ bs @@ -165,8 +169,8 @@ class CudaGraphRunner: # Attention backend self.max_bs = max(self.capture_bs) - self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs) - + self.max_num_token = self.max_bs * self.num_tokens_per_bs + self.model_runner.attn_backend.init_cuda_graph_state(self.max_num_token) self.seq_len_fill_value = ( self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() ) @@ -179,12 +183,13 @@ class CudaGraphRunner: # Common inputs with torch.device("cuda"): - self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32) + self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32) self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32) self.seq_lens = torch.full( (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32 ) - self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32) + self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32) + self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64) self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32) if self.is_encoder_decoder: @@ -229,6 +234,9 @@ class CudaGraphRunner: self.model_runner.model.capture_mode = False def can_run(self, forward_batch: ForwardBatch): + if not forward_batch.forward_mode.is_cuda_graph(): + return False + if self.enable_dp_attention: min_num_tokens, max_num_tokens = min(forward_batch.global_num_tokens), max( forward_batch.global_num_tokens @@ -258,12 +266,12 @@ class CudaGraphRunner: def capture(self): with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream - capture_bs = ( + capture_range = ( tqdm.tqdm(self.capture_bs) if get_tensor_model_parallel_rank() == 0 else self.capture_bs ) - for bs in capture_bs: + for bs in capture_range: with patch_model( self.model_runner.model, bs in 
self.compile_bs, @@ -283,12 +291,15 @@ class CudaGraphRunner: def capture_one_batch_size(self, bs: int, forward: Callable): graph = torch.cuda.CUDAGraph() stream = self.stream + num_token = bs * self.num_tokens_per_bs # Common inputs - input_ids = self.input_ids[:bs] + input_ids = self.input_ids[:num_token] req_pool_indices = self.req_pool_indices[:bs] seq_lens = self.seq_lens[:bs] - out_cache_loc = self.out_cache_loc[:bs] + out_cache_loc = self.out_cache_loc[:num_token] + positions = self.positions[:num_token] + if self.is_encoder_decoder: encoder_lens = self.encoder_lens[:bs] else: @@ -304,37 +315,41 @@ class CudaGraphRunner: global_num_tokens = None gathered_buffer = None + forward_batch = ForwardBatch( + forward_mode=self.capture_forward_mode, + batch_size=bs, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + req_to_token_pool=self.model_runner.req_to_token_pool, + token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens_sum, + encoder_lens=encoder_lens, + return_logprob=False, + top_logprobs_nums=[0] * num_token, + positions=positions, + global_num_tokens=global_num_tokens, + mrope_positions=mrope_positions, + gathered_buffer=gathered_buffer, + ) + # Attention backend self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph( bs, + num_token, req_pool_indices, seq_lens, encoder_lens, + forward_batch.forward_mode, + forward_batch.spec_info, ) # Run and capture def run_once(): - forward_batch = ForwardBatch( - forward_mode=ForwardMode.DECODE, - batch_size=bs, - input_ids=input_ids, - req_pool_indices=req_pool_indices, - seq_lens=seq_lens, - req_to_token_pool=self.model_runner.req_to_token_pool, - token_to_kv_pool=self.model_runner.token_to_kv_pool, - attn_backend=self.model_runner.attn_backend, - out_cache_loc=out_cache_loc, - seq_lens_sum=seq_lens_sum, - encoder_lens=encoder_lens, - return_logprob=False, - top_logprobs_nums=[0] * bs, - positions=clamp_position(seq_lens), - mrope_positions=mrope_positions, - global_num_tokens=global_num_tokens, - gathered_buffer=gathered_buffer, - ) logits_output = forward(input_ids, forward_batch.positions, forward_batch) - return logits_output.next_token_logits + return logits_output.next_token_logits, logits_output.hidden_states for _ in range(2): torch.cuda.synchronize() @@ -360,6 +375,9 @@ class CudaGraphRunner: def replay(self, forward_batch: ForwardBatch): assert forward_batch.out_cache_loc is not None raw_bs = forward_batch.batch_size + # In normal decoding case, raw_bs == raw_num_token + # But in speculative decoding, raw_num_token is raw_bs * self.num_tokens_per_bs + raw_num_token = forward_batch.input_ids.numel() # Pad if self.enable_dp_attention: @@ -374,10 +392,13 @@ class CudaGraphRunner: self.out_cache_loc.zero_() # Common inputs - self.input_ids[:raw_bs].copy_(forward_batch.input_ids) + self.input_ids[:raw_num_token].copy_(forward_batch.input_ids) self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices) self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) - self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc) + self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) + positions = clamp_position(forward_batch.seq_lens) + self.positions[:raw_num_token].copy_(positions) + if self.is_encoder_decoder: self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens) if forward_batch.mrope_positions is not None: @@ -390,13 +411,18 @@ class CudaGraphRunner: self.seq_lens, 
forward_batch.seq_lens_sum + (bs - raw_bs), self.encoder_lens, + forward_batch.forward_mode, + forward_batch.spec_info, ) # Replay self.graphs[bs].replay() - next_token_logits = self.output_buffers[bs][:raw_bs] + next_token_logits, hidden_states = self.output_buffers[bs] logits_output = LogitsProcessorOutput( - next_token_logits=next_token_logits, + next_token_logits=next_token_logits[:raw_num_token], + hidden_states=( + hidden_states[:raw_num_token] if hidden_states is not None else None + ), ) return logits_output diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 4f77c8f79..140b0fe7c 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -96,7 +96,7 @@ class ForwardMode(IntEnum): return self == ForwardMode.DRAFT_EXTEND def is_cuda_graph(self): - return self in (ForwardMode.DECODE, ForwardMode.TARGET_VERIFY) + return self == ForwardMode.DECODE or self == ForwardMode.TARGET_VERIFY def is_dummy_first(self): return self == ForwardMode.DUMMY_FIRST
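For reference, the token-based sizing introduced in cuda_graph_runner.py composes as follows under speculative decoding. This is a worked example with illustrative numbers (the patch itself still captures with capture_forward_mode = ForwardMode.DECODE and num_tokens_per_bs = 1); the padding lookup at the end is a simplification of the runner's actual batch-size selection, which is not shown in this hunk.

# Illustrative numbers only; not taken from the diff.
capture_bs = [1, 2, 4, 8]            # batch sizes captured as CUDA graphs
num_tokens_per_bs = 4                # e.g. a TARGET_VERIFY graph checking 4 tokens per request
max_bs = max(capture_bs)             # 8
max_num_token = max_bs * num_tokens_per_bs   # 32: length of input_ids / out_cache_loc / positions

# During replay, token-shaped inputs are sliced by raw_num_token while
# request-shaped ones (req_pool_indices, seq_lens) are sliced by raw_bs.
raw_bs = 3
raw_num_token = raw_bs * num_tokens_per_bs               # 12 == forward_batch.input_ids.numel()
padded_bs = next(bs for bs in capture_bs if bs >= raw_bs)  # 4: graph actually replayed
print(max_num_token, raw_num_token, padded_bs)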