diff --git a/python/sglang/srt/layers/attention_backend.py b/python/sglang/srt/layers/attention_backend.py new file mode 100644 index 000000000..35fe4ed92 --- /dev/null +++ b/python/sglang/srt/layers/attention_backend.py @@ -0,0 +1,383 @@ +from __future__ import annotations + +""" +Support different attention backends. +Now there are two backends: FlashInfer and Triton. +FlashInfer is faster and Triton is easier to customize. +Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode. +""" + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +import torch +import torch.nn as nn +from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, +) +from flashinfer.cascade import merge_state +from flashinfer.decode import _grouped_size_compiled_for_decode_kernels + +from sglang.global_config import global_config +from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices +from sglang.srt.managers.schedule_batch import ScheduleBatch +from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata + +if TYPE_CHECKING: + from sglang.srt.model_executor.model_runner import ModelRunner + + +class AttentionBackend(ABC): + """The base class of attention backends""" + + @abstractmethod + def init_forward_metadata( + self, batch: ScheduleBatch, input_metadata: InputMetadata + ): + pass + + def forward(self, q, k, v, layer, input_metadata: InputMetadata): + if input_metadata.forward_mode.is_decode(): + return self.forward_decode(q, k, v, layer, input_metadata) + else: + return self.forward_extend(q, k, v, layer, input_metadata) + + +class FlashInferAttnBackend(AttentionBackend): + """Flashinfer attention kernels.""" + + def __init__(self, model_runner: ModelRunner): + super().__init__() + self.model_runner = model_runner + + if not _grouped_size_compiled_for_decode_kernels( + model_runner.model_config.num_attention_heads // model_runner.tp_size, + model_runner.model_config.get_num_kv_heads(model_runner.tp_size), + ): + self.decode_use_tensor_cores = True + else: + self.decode_use_tensor_cores = False + + self.workspace_buffer = torch.empty( + global_config.flashinfer_workspace_size, + dtype=torch.uint8, + device="cuda", + ) + + if model_runner.sliding_window_size is None: + self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper( + self.workspace_buffer, "NHD" + ) + self.prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( + self.workspace_buffer, "NHD" + ) + self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_tensor_cores=self.decode_use_tensor_cores, + ) + else: + # Two wrappers: one for sliding window attention and one for full attention. 
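+            # Index 0 is used by layers with a sliding window (layer.sliding_window_size != -1),
+            # index 1 by full-attention layers; see forward_extend/forward_decode below.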
+ # Using two wrappers is unnecessary in the current PR, but are prepared for future PRs + self.prefill_wrapper_ragged = None + self.prefill_wrapper_paged = [] + self.decode_wrapper = [] + for _ in range(2): + self.prefill_wrapper_paged.append( + BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD") + ) + self.decode_wrapper.append( + BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_tensor_cores=self.decode_use_tensor_cores, + ) + ) + + self.forward_metadata = None + self.cuda_graph_metadata = {} + + def init_forward_metadata( + self, batch: ScheduleBatch, input_metadata: InputMetadata + ): + if input_metadata.forward_mode.is_decode(): + prefix_lens = None + use_ragged = False + total_num_tokens = None + else: + prefix_lens = input_metadata.extend_prefix_lens + + # Some heuristics to check whether to use ragged forward + use_ragged = False + if ( + int(torch.sum(input_metadata.seq_lens)) > 4096 + and self.model_runner.sliding_window_size is None + ): + use_ragged = True + + total_num_tokens = torch.sum(input_metadata.seq_lens).item() + + update_flashinfer_indices( + input_metadata.forward_mode, + self.model_runner, + input_metadata.req_pool_indices, + input_metadata.seq_lens, + prefix_lens, + use_ragged=use_ragged, + ) + + self.forward_metadata = (use_ragged, total_num_tokens, self.decode_wrapper) + + def init_cuda_graph_state(self, max_bs: int): + self.cuda_graph_kv_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device="cuda" + ) + self.cuda_graph_kv_indices = torch.zeros( + (max_bs * self.model_runner.model_config.context_len,), + dtype=torch.int32, + device="cuda", + ) + self.cuda_graph_kv_last_page_len = torch.ones( + (max_bs,), dtype=torch.int32, device="cuda" + ) + + if self.model_runner.sliding_window_size is not None: + self.cuda_graph_kv_indptr = [ + self.cuda_graph_kv_indptr, + self.cuda_graph_kv_indptr.clone(), + ] + self.cuda_graph_kv_indices = [ + self.cuda_graph_kv_indices, + self.cuda_graph_kv_indices.clone(), + ] + + def capture_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens): + if self.model_runner.sliding_window_size is None: + decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=self.decode_use_tensor_cores, + paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[: bs + 1], + paged_kv_indices_buffer=self.cuda_graph_kv_indices, + paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[:bs], + ) + else: + decode_wrapper = [] + for i in range(2): + decode_wrapper.append( + BatchDecodeWithPagedKVCacheWrapper( + self.workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=self.decode_use_tensor_cores, + paged_kv_indptr_buffer=self.cuda_graph_kv_indptr[i][: bs + 1], + paged_kv_indices_buffer=self.cuda_graph_kv_indices[i], + paged_kv_last_page_len_buffer=self.cuda_graph_kv_last_page_len[ + :bs + ], + ) + ) + + update_flashinfer_indices( + ForwardMode.DECODE, + self.model_runner, + req_pool_indices, + seq_lens, + None, + decode_wrapper, + ) + + self.cuda_graph_metadata[bs] = decode_wrapper + + self.forward_metadata = (False, None, decode_wrapper) + + def replay_cuda_graph_init(self, bs: int, req_pool_indices, seq_lens): + update_flashinfer_indices( + ForwardMode.DECODE, + self.model_runner, + req_pool_indices[:bs], + seq_lens[:bs], + None, + self.cuda_graph_metadata[bs], + ) + + def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + if not isinstance(self.prefill_wrapper_paged, list): + 
prefill_wrapper_paged = self.prefill_wrapper_paged + else: + if layer.sliding_window_size != -1: + prefill_wrapper_paged = self.prefill_wrapper_paged[0] + else: + prefill_wrapper_paged = self.prefill_wrapper_paged[1] + + use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata + + if not use_ragged: + if k is not None: + assert v is not None + input_metadata.token_to_kv_pool.set_kv_buffer( + layer.layer_id, input_metadata.out_cache_loc, k, v + ) + o = prefill_wrapper_paged.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=True, + sm_scale=layer.scaling, + window_left=layer.sliding_window_size, + logits_soft_cap=layer.logit_cap, + ) + else: + o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + k.contiguous().view(-1, layer.tp_k_head_num, layer.head_dim), + v.contiguous().view(-1, layer.tp_v_head_num, layer.head_dim), + causal=True, + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap, + ) + + if input_metadata.extend_no_prefix: + o = o1 + else: + o2, s2 = prefill_wrapper_paged.forward_return_lse( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id), + causal=False, + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap, + ) + + o, _ = merge_state(o1, s1, o2, s2) + + input_metadata.token_to_kv_pool.set_kv_buffer( + layer.layer_id, input_metadata.out_cache_loc, k, v + ) + + if total_num_tokens >= global_config.layer_sync_threshold: + torch.cuda.synchronize() + + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + use_ragged, total_num_tokens, decode_wrapper = self.forward_metadata + + if isinstance(decode_wrapper, list): + if layer.sliding_window_size != -1: + decode_wrapper = decode_wrapper[0] + else: + decode_wrapper = decode_wrapper[1] + + if k is not None: + assert v is not None + input_metadata.token_to_kv_pool.set_kv_buffer( + layer.layer_id, input_metadata.out_cache_loc, k, v + ) + + o = decode_wrapper.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + input_metadata.token_to_kv_pool.get_kv_buffer(layer.layer_id), + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap, + ) + + return o.view(-1, layer.tp_q_head_num * layer.head_dim) + + +class TritonAttnBackend(AttentionBackend): + def __init__(self, model_runner: ModelRunner): + # Lazy import to avoid the initialization of cuda context + from sglang.srt.layers.triton_attention.decode_attention import ( + decode_attention_fwd, + ) + from sglang.srt.layers.triton_attention.extend_attention import ( + extend_attention_fwd, + ) + + super().__init__() + + self.decode_attention_fwd = decode_attention_fwd + self.extend_attention_fwd = extend_attention_fwd + + self.forward_metadata = None + + def init_forward_metadata( + self, batch: ScheduleBatch, input_metadata: InputMetadata + ): + """Init auxiliary variables for triton attention backend.""" + + if input_metadata.forward_mode.is_decode(): + max_seq_len = torch.max(input_metadata.seq_lens).item() + start_loc = torch.zeros_like(input_metadata.seq_lens, dtype=torch.int32) + start_loc[1:] = torch.cumsum(input_metadata.seq_lens[:-1], dim=0) + + total_num_tokens = torch.sum(input_metadata.seq_lens).item() + max_extend_len = None + else: + start_loc = max_seq_len = total_num_tokens = None + prefix_lens = 
torch.tensor(batch.prefix_lens_cpu, device="cuda") + max_extend_len = torch.max(input_metadata.seq_lens - prefix_lens).item() + + self.forward_metadata = start_loc, max_seq_len, max_extend_len, total_num_tokens + + def forward_extend(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + input_metadata.token_to_kv_pool.set_kv_buffer( + layer.layer_id, input_metadata.out_cache_loc, k, v + ) + + start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata + + self.extend_attention_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + k.contiguous(), + v.contiguous(), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id), + input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id), + input_metadata.req_to_token_pool.req_to_token, + input_metadata.req_pool_indices, + input_metadata.seq_lens, + input_metadata.extend_seq_lens, + input_metadata.extend_start_loc, + max_extend_len, + layer.scaling, + layer.logit_cap, + ) + + return o + + def forward_decode(self, q, k, v, layer: nn.Module, input_metadata: InputMetadata): + if layer.qk_head_dim != layer.v_head_dim: + o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim)) + else: + o = torch.empty_like(q) + + start_loc, max_seq_len, max_extend_len, total_num_tokens = self.forward_metadata + + input_metadata.token_to_kv_pool.set_kv_buffer( + layer.layer_id, input_metadata.out_cache_loc, k, v + ) + + self.decode_attention_fwd( + q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), + input_metadata.token_to_kv_pool.get_key_buffer(layer.layer_id), + input_metadata.token_to_kv_pool.get_value_buffer(layer.layer_id), + o.view(-1, layer.tp_q_head_num, layer.v_head_dim), + input_metadata.req_to_token_pool.req_to_token, + input_metadata.req_pool_indices, + start_loc, + input_metadata.seq_lens, + max_seq_len, + total_num_tokens, + layer.scaling, + layer.logit_cap, + ) + + return o diff --git a/python/sglang/srt/layers/flashinfer_utils.py b/python/sglang/srt/layers/flashinfer_utils.py index 1f9ab1514..c473d6e45 100644 --- a/python/sglang/srt/layers/flashinfer_utils.py +++ b/python/sglang/srt/layers/flashinfer_utils.py @@ -10,8 +10,8 @@ def create_flashinfer_kv_indices_triton( page_kernel_lens_ptr, kv_indptr, kv_start_idx, - max_context_len, kv_indices_ptr, + max_context_len: tl.constexpr, ): BLOCK_SIZE: tl.constexpr = 512 pid = tl.program_id(axis=0) @@ -47,15 +47,15 @@ class FlashinferUpdater: req_pool_indices, seq_lens, prefix_lens, - flashinfer_decode_wrapper=None, - flashinfer_use_ragged=False, + decode_wrapper=None, + use_ragged=False, ): self.forward_mode = forward_mode self.model_runner = model_runner self.req_pool_indices = req_pool_indices self.seq_lens = seq_lens self.prefix_lens = prefix_lens - self.flashinfer_use_ragged = flashinfer_use_ragged + self.use_ragged = use_ragged self.num_qo_heads = ( model_runner.model_config.num_attention_heads // model_runner.tp_size @@ -71,20 +71,17 @@ class FlashinferUpdater: ) ( - self.flashinfer_decode_wrapper, - self.flashinfer_prefill_wrapper_ragged, - self.flashinfer_prefill_wrapper_paged, + self.decode_wrapper, + self.prefill_wrapper_ragged, + self.prefill_wrapper_paged, ) = ( - flashinfer_decode_wrapper, - self.model_runner.flashinfer_prefill_wrapper_ragged, - self.model_runner.flashinfer_prefill_wrapper_paged, + decode_wrapper or 
self.model_runner.attn_backend.decode_wrapper, + self.model_runner.attn_backend.prefill_wrapper_ragged, + self.model_runner.attn_backend.prefill_wrapper_paged, ) - # CUDA graph uses different flashinfer_decode_wrapper - if self.flashinfer_decode_wrapper is None: - self.flashinfer_decode_wrapper = self.model_runner.flashinfer_decode_wrapper - def _init_indices_no_window(self): - if self.flashinfer_use_ragged: + def _init_indices_no_sliding_window(self): + if self.use_ragged: paged_kernel_lens = self.prefix_lens else: paged_kernel_lens = self.seq_lens @@ -103,13 +100,13 @@ class FlashinferUpdater: paged_kernel_lens, self.kv_indptr, None, - self.model_runner.req_to_token_pool.req_to_token.size(1), self.kv_indices, + self.model_runner.req_to_token_pool.req_to_token.size(1), ) - def _init_indices_window(self, wrapper_id): - # window attention use paged only + def _init_indices_sliding_window(self, wrapper_id): if wrapper_id == 0: + # window attention use paged only if self.forward_mode.is_decode(): paged_kernel_lens = torch.minimum( self.seq_lens, @@ -123,6 +120,7 @@ class FlashinferUpdater: - self.prefix_lens, ) else: + # full attention paged_kernel_lens = self.seq_lens kv_start_idx = self.seq_lens - paged_kernel_lens @@ -139,8 +137,8 @@ class FlashinferUpdater: paged_kernel_lens, self.kv_indptr, kv_start_idx, - self.model_runner.req_to_token_pool.req_to_token.size(1), self.kv_indices, + self.model_runner.req_to_token_pool.req_to_token.size(1), ) def _update_decode_indices(self, decode_wrapper): @@ -164,7 +162,7 @@ class FlashinferUpdater: ) qo_indptr[1:] = torch.cumsum(self.seq_lens - self.prefix_lens, dim=0) - if self.flashinfer_use_ragged: + if self.use_ragged: ragged_wrapper.end_forward() ragged_wrapper.begin_forward( qo_indptr, @@ -187,28 +185,28 @@ class FlashinferUpdater: 1, ) - def update_indices_no_window(self): - self._init_indices_no_window() + def update_indices_no_sliding_window(self): + self._init_indices_no_sliding_window() if self.forward_mode.is_decode(): - self._update_decode_indices(self.flashinfer_decode_wrapper) + self._update_decode_indices(self.decode_wrapper) else: self._update_extend_indices( - self.flashinfer_prefill_wrapper_ragged, - self.flashinfer_prefill_wrapper_paged, + self.prefill_wrapper_ragged, + self.prefill_wrapper_paged, ) - def update_indices_window(self): - assert self.flashinfer_use_ragged is False + def update_indices_sliding_window(self): + assert self.use_ragged is False for wrapper_id in range(2): - self._init_indices_window(wrapper_id) + self._init_indices_sliding_window(wrapper_id) if self.forward_mode.is_decode(): - self._update_decode_indices(self.flashinfer_decode_wrapper[wrapper_id]) + self._update_decode_indices(self.decode_wrapper[wrapper_id]) else: self._update_extend_indices( None, - self.flashinfer_prefill_wrapper_paged[wrapper_id], + self.prefill_wrapper_paged[wrapper_id], ) @@ -218,20 +216,20 @@ def update_flashinfer_indices( req_pool_indices, seq_lens, prefix_lens, - flashinfer_decode_wrapper=None, - flashinfer_use_ragged=False, + decode_wrapper=None, + use_ragged=False, ): - flashinfer_updater = FlashinferUpdater( + updater = FlashinferUpdater( forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens, - flashinfer_decode_wrapper, - flashinfer_use_ragged, + decode_wrapper, + use_ragged, ) if model_runner.sliding_window_size is None: - flashinfer_updater.update_indices_no_window() + updater.update_indices_no_sliding_window() else: - flashinfer_updater.update_indices_window() + updater.update_indices_sliding_window() diff 
--git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 48567e43d..8454d2928 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -15,25 +15,14 @@ limitations under the License. """Radix attention.""" -from typing import Optional - -import torch -from flashinfer.cascade import merge_state from torch import nn -from sglang.global_config import global_config -from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd -from sglang.srt.layers.triton_attention.extend_attention import extend_attention_fwd -from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata -from sglang.srt.model_executor.model_runner import global_server_args_dict +from sglang.srt.model_executor.forward_batch_info import InputMetadata class RadixAttention(nn.Module): """ The attention layer implementation. - Now it has two backends: FlashInfer and Triton. - FlashInfer is faster and Triton is easier to customize. - It supports two operators: extend (i.e. prefill with cached prefix) and decode. """ def __init__( @@ -43,8 +32,8 @@ class RadixAttention(nn.Module): scaling: float, num_kv_heads: int, layer_id: int, - sliding_window_size: Optional[int] = None, - logit_cap: int = -1, + sliding_window_size: int = -1, + logit_cap: float = 0.0, v_head_dim: int = -1, ): super().__init__() @@ -56,164 +45,14 @@ class RadixAttention(nn.Module): self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim self.scaling = scaling self.layer_id = layer_id - self.logit_cap = logit_cap if logit_cap is not None and logit_cap > 0 else 0 - self.sliding_window_size = sliding_window_size if sliding_window_size else -1 - - # Choose backend - if ( - global_server_args_dict["attention_backend"] == "flashinfer" - and self.qk_head_dim == self.v_head_dim - ): - self.extend_forward = self.extend_forward_flashinfer - self.decode_forward = self.decode_forward_flashinfer - elif global_server_args_dict["attention_backend"] == "triton": - self.extend_forward = self.extend_forward_triton - self.decode_forward = self.decode_forward_triton - else: - raise ValueError( - f"Invalid attention backend: {global_server_args_dict['attention_backend']}" - ) - - def extend_forward_triton(self, q, k, v, input_metadata: InputMetadata): - if self.qk_head_dim != self.v_head_dim: - o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim)) - else: - o = torch.empty_like(q) - - self.store_kv_cache(k, v, input_metadata) - extend_attention_fwd( - q.view(-1, self.tp_q_head_num, self.qk_head_dim), - k.contiguous(), - v.contiguous(), - o.view(-1, self.tp_q_head_num, self.v_head_dim), - input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id), - input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id), - input_metadata.req_to_token_pool.req_to_token, - input_metadata.req_pool_indices, - input_metadata.triton_start_loc, - input_metadata.seq_lens, - input_metadata.triton_prefix_lens, - input_metadata.extend_start_loc, - input_metadata.extend_seq_lens, - input_metadata.triton_max_seq_len, - input_metadata.triton_max_extend_len, - sm_scale=self.scaling, - logit_cap=self.logit_cap, - ) - - return o - - def decode_forward_triton(self, q, k, v, input_metadata: InputMetadata): - if self.qk_head_dim != self.v_head_dim: - o = q.new_empty((q.shape[0], self.tp_q_head_num * self.v_head_dim)) - else: - o = torch.empty_like(q) - self.store_kv_cache(k, v, input_metadata) - - decode_attention_fwd( - q.view(-1, 
self.tp_q_head_num, self.qk_head_dim), - input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id), - input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id), - o.view(-1, self.tp_q_head_num, self.v_head_dim), - input_metadata.req_to_token_pool.req_to_token, - input_metadata.req_pool_indices, - input_metadata.triton_start_loc, - input_metadata.seq_lens, - input_metadata.triton_max_seq_len, - input_metadata.total_num_tokens, - sm_scale=self.scaling, - logit_cap=self.logit_cap, - ) - - return o - - def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): - # using two wrappers is unnecessary in the current PR, but are prepared for future PRs - prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged - if self.sliding_window_size != -1: - prefill_wrapper_paged = prefill_wrapper_paged[0] - else: - if isinstance(prefill_wrapper_paged, list): - prefill_wrapper_paged = prefill_wrapper_paged[1] - - if not input_metadata.flashinfer_use_ragged: - if k is not None: - assert v is not None - self.store_kv_cache(k, v, input_metadata) - - o = prefill_wrapper_paged.forward( - q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), - causal=True, - sm_scale=self.scaling, - window_left=self.sliding_window_size, - logits_soft_cap=self.logit_cap, - ) - else: - o1, s1 = ( - input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse( - q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - k.contiguous().view(-1, self.tp_k_head_num, self.head_dim), - v.contiguous().view(-1, self.tp_v_head_num, self.head_dim), - causal=True, - sm_scale=self.scaling, - logits_soft_cap=self.logit_cap, - ) - ) - - if input_metadata.extend_no_prefix: - o = o1 - else: - o2, s2 = prefill_wrapper_paged.forward_return_lse( - q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), - causal=False, - sm_scale=self.scaling, - logits_soft_cap=self.logit_cap, - ) - - o, _ = merge_state(o1, s1, o2, s2) - - self.store_kv_cache(k, v, input_metadata) - - if input_metadata.total_num_tokens >= global_config.layer_sync_threshold: - torch.cuda.synchronize() - - return o.view(-1, self.tp_q_head_num * self.head_dim) - - def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata): - decode_wrapper = input_metadata.flashinfer_decode_wrapper - if self.sliding_window_size != -1: - decode_wrapper = decode_wrapper[0] - else: - if isinstance(decode_wrapper, list): - decode_wrapper = decode_wrapper[1] - - if k is not None: - assert v is not None - self.store_kv_cache(k, v, input_metadata) - - o = decode_wrapper.forward( - q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), - input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id), - sm_scale=self.scaling, - logits_soft_cap=self.logit_cap, - ) - - return o.view(-1, self.tp_q_head_num * self.head_dim) + self.logit_cap = logit_cap + self.sliding_window_size = sliding_window_size or -1 def forward(self, q, k, v, input_metadata: InputMetadata): if k is not None: + # For cross-layer sharing, kv can be None assert v is not None k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) v = v.view(-1, self.tp_v_head_num, self.v_head_dim) - if input_metadata.forward_mode.is_extend(): - return self.extend_forward(q, k, v, input_metadata) - elif input_metadata.forward_mode.is_decode(): - return self.decode_forward(q, k, v, input_metadata) - - def store_kv_cache(self, cache_k, cache_v, input_metadata: 
InputMetadata): - input_metadata.token_to_kv_pool.set_kv_buffer( - self.layer_id, input_metadata.out_cache_loc, cache_k, cache_v - ) + return input_metadata.attn_backend.forward(q, k, v, self, input_metadata) diff --git a/python/sglang/srt/layers/triton_attention/decode_attention.py b/python/sglang/srt/layers/triton_attention/decode_attention.py index 82ce6efc5..adfa0d936 100644 --- a/python/sglang/srt/layers/triton_attention/decode_attention.py +++ b/python/sglang/srt/layers/triton_attention/decode_attention.py @@ -15,6 +15,7 @@ limitations under the License. """ Memory-efficient attention for decoding. +It supports page size = 1. """ # Adapted from @@ -197,7 +198,6 @@ def _decode_att_m_fwd( logit_cap, ): BLOCK = 32 - # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -478,7 +478,6 @@ def _decode_grouped_att_m_fwd( logit_cap, ): BLOCK = 32 - # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] if Lk == 576: @@ -570,9 +569,9 @@ def _decode_grouped_softmax_reducev_fwd( BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK, BLOCK_H=BLOCK_H, + Lv=Lv, num_warps=num_warps, num_stages=1, - Lv=Lv, ) @@ -588,7 +587,7 @@ def decode_attention_fwd( max_len_in_batch, total_num_tokens, sm_scale, - logit_cap=-1, + logit_cap=0.0, att_m=None, ): if att_m is None: diff --git a/python/sglang/srt/layers/triton_attention/extend_attention.py b/python/sglang/srt/layers/triton_attention/extend_attention.py index 1193c4124..3cf150d8d 100644 --- a/python/sglang/srt/layers/triton_attention/extend_attention.py +++ b/python/sglang/srt/layers/triton_attention/extend_attention.py @@ -61,14 +61,14 @@ def _fwd_kernel( stride_buf_vbs, stride_buf_vh, stride_req_to_tokens_b, + logit_cap: tl.constexpr, + Lq: tl.constexpr, + Lv: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DPE: tl.constexpr, BLOCK_DV: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, - logit_cap: tl.constexpr, - Lq: tl.constexpr, - Lv: tl.constexpr, ): cur_seq = tl.program_id(0) cur_head = tl.program_id(1) @@ -111,7 +111,7 @@ def _fwd_kernel( ) qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0) - # stage1: compute scores with prefix + # stage 1: compute scores with prefix offs_n = tl.arange(0, BLOCK_N) acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) @@ -174,7 +174,7 @@ def _fwd_kernel( e_max = n_e_max - # stage2: compute the trianlge part + # stage 2: compute the trianlge part cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M) for start_n in range(0, cur_block_m_end, BLOCK_N): @@ -255,26 +255,22 @@ def extend_attention_fwd( v_buffer, req_to_tokens, b_req_idx, - b_start_loc, b_seq_len, - b_seq_len_prefix, - b_start_loc_extend, b_seq_len_extend, - max_len_in_batch, + b_start_loc_extend, max_len_extend, sm_scale=None, - logit_cap=-1, + logit_cap=0.0, ): """ q_extend, k_extend, v_extend, o_extend: contiguous tensors k_buffer, v_buffer: (prefix + extend) tensors in mem_manager """ - Lq, Lk, Lv, Lo = ( + Lq, Lk, Lv = ( q_extend.shape[-1], k_extend.shape[-1], v_extend.shape[-1], - o_extend.shape[-1], ) if Lq == 576: @@ -303,7 +299,7 @@ def extend_attention_fwd( else: BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) - sm_scale = 1.0 / (Lq**0.5) if sm_scale is None else sm_scale + sm_scale = sm_scale or 1.0 / (Lq**0.5) batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1] kv_group_num = q_extend.shape[1] // k_extend.shape[1] @@ -338,27 +334,24 @@ def extend_attention_fwd( v_buffer.stride(0), v_buffer.stride(1), req_to_tokens.stride(0), + 
logit_cap=logit_cap, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DPE=BLOCK_DPE, BLOCK_DV=BLOCK_DV, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, - num_warps=num_warps, - num_stages=num_stages, - logit_cap=logit_cap, Lq=Lq, Lv=Lv, + num_warps=num_warps, + num_stages=num_stages, ) def redundant_attention( q_extend, - k_extend, - v_extend, o_extend, k_buffer, v_buffer, - req_to_tokens, b_req_idx, b_start_loc, b_seq_len, diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index b6000734a..0a5eb3cdf 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -368,7 +368,7 @@ class ScheduleBatch: ) def batch_size(self): - return len(self.reqs) if self.reqs is not None else 0 + return len(self.reqs) if self.reqs else 0 def is_empty(self): return len(self.reqs) == 0 diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index c24dd5084..ecaeb404c 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -13,15 +13,13 @@ See the License for the specific language governing permissions and limitations under the License. """ -"""Run the model with cuda graph.""" +"""Run the model with cuda graph and torch.compile.""" import bisect from contextlib import contextmanager -from typing import Callable, List +from typing import Callable import torch -from flashinfer import BatchDecodeWithPagedKVCacheWrapper -from flashinfer.decode import _grouped_size_compiled_for_decode_kernels from vllm.distributed.parallel_state import graph_capture from vllm.model_executor.custom_op import CustomOp @@ -55,6 +53,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False): def patch_model( model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator" ): + """Patch the model to make it compatible with with torch.compile""" backup_ca_comm = None try: @@ -86,23 +85,28 @@ def set_torch_compile_config(): class CudaGraphRunner: - def __init__( - self, - model_runner: "ModelRunner", - max_batch_size_to_capture: int, - use_torch_compile: bool, - disable_padding: bool, - ): + """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile.""" + + def __init__(self, model_runner: "ModelRunner"): + # Parse args self.model_runner = model_runner self.graphs = {} self.input_buffers = {} self.output_buffers = {} self.flashinfer_handlers = {} self.graph_memory_pool = None - self.disable_padding = disable_padding + self.use_torch_compile = model_runner.server_args.enable_torch_compile + self.disable_padding = model_runner.server_args.disable_cuda_graph_padding + + # Batch sizes to capture + if self.model_runner.server_args.disable_cuda_graph_padding: + self.capture_bs = list(range(1, 32)) + [64, 128] + else: + self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)] + self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else [] # Common inputs - self.max_bs = max_batch_size_to_capture + self.max_bs = max(self.capture_bs) self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda") self.req_pool_indices = torch.zeros( (self.max_bs,), dtype=torch.int32, device="cuda" @@ -115,56 +119,39 @@ class CudaGraphRunner: (self.max_bs,), dtype=torch.int32, device="cuda" ) - # FlashInfer inputs - self.flashinfer_kv_indptr = torch.zeros( - (self.max_bs + 1,), dtype=torch.int32, device="cuda" - ) - self.flashinfer_kv_indices = torch.zeros( - (self.max_bs * 
model_runner.model_config.context_len,), - dtype=torch.int32, - device="cuda", - ) - self.flashinfer_kv_last_page_len = torch.ones( - (self.max_bs,), dtype=torch.int32, device="cuda" - ) - if model_runner.sliding_window_size is None: - self.flashinfer_workspace_buffer = ( - self.model_runner.flashinfer_workspace_buffer - ) - else: - self.flashinfer_workspace_buffer = ( - self.model_runner.flashinfer_workspace_buffer - ) + # Attention backend + self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs) - self.flashinfer_kv_indptr = [ - self.flashinfer_kv_indptr, - self.flashinfer_kv_indptr.clone(), - ] - self.flashinfer_kv_indices = [ - self.flashinfer_kv_indices, - self.flashinfer_kv_indices.clone(), - ] - - # Sampling inputs + # Sampling info vocab_size = model_runner.model_config.vocab_size self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size) - self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else [] - - if use_torch_compile: + if self.use_torch_compile: set_torch_compile_config() + # Capture + try: + self.capture() + except RuntimeError as e: + raise Exception( + f"Capture cuda graph failed: {e}\n" + "Possible solutions:\n" + "1. disable cuda graph by --disable-cuda-graph\n" + "2. set --mem-fraction-static to a smaller value\n" + "3. disable torch compile by not using --enable-torch-compile\n" + "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" + ) + def can_run(self, batch_size: int): if self.disable_padding: return batch_size in self.graphs else: return batch_size <= self.max_bs - def capture(self, batch_size_list: List[int]): - self.batch_size_list = batch_size_list + def capture(self): with graph_capture() as graph_capture_context: self.stream = graph_capture_context.stream - for bs in batch_size_list: + for bs in self.capture_bs: with patch_model( self.model_runner.model, bs in self.compile_bs, @@ -172,14 +159,10 @@ class CudaGraphRunner: ) as forward: ( graph, - input_buffers, output_buffers, - flashinfer_handler, ) = self.capture_one_batch_size(bs, forward) self.graphs[bs] = graph - self.input_buffers[bs] = input_buffers self.output_buffers[bs] = output_buffers - self.flashinfer_handlers[bs] = flashinfer_handler def capture_one_batch_size(self, bs: int, forward: Callable): graph = torch.cuda.CUDAGraph() @@ -192,48 +175,9 @@ class CudaGraphRunner: position_ids_offsets = self.position_ids_offsets[:bs] out_cache_loc = self.out_cache_loc[:bs] - # FlashInfer inputs - if not _grouped_size_compiled_for_decode_kernels( - self.model_runner.model_config.num_attention_heads - // self.model_runner.tp_size, - self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size), - ): - use_tensor_cores = True - else: - use_tensor_cores = False - if self.model_runner.sliding_window_size is None: - flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=use_tensor_cores, - paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1], - paged_kv_indices_buffer=self.flashinfer_kv_indices, - paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs], - ) - else: - flashinfer_decode_wrapper = [] - for i in range(2): - flashinfer_decode_wrapper.append( - BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=use_tensor_cores, - paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1], - paged_kv_indices_buffer=self.flashinfer_kv_indices[i], - 
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[ - :bs - ], - ) - ) - update_flashinfer_indices( - ForwardMode.DECODE, - self.model_runner, - req_pool_indices, - seq_lens, - None, - flashinfer_decode_wrapper, + # Attention backend + self.model_runner.attn_backend.capture_cuda_graph_init( + bs, req_pool_indices, seq_lens ) # Run and capture @@ -246,13 +190,12 @@ class CudaGraphRunner: seq_lens=seq_lens, req_to_token_pool=self.model_runner.req_to_token_pool, token_to_kv_pool=self.model_runner.token_to_kv_pool, + attn_backend=self.model_runner.attn_backend, out_cache_loc=out_cache_loc, return_logprob=False, top_logprobs_nums=0, positions=(seq_lens - 1 + position_ids_offsets).to(torch.int64), - flashinfer_decode_wrapper=flashinfer_decode_wrapper, ) - return forward(input_ids, input_metadata.positions, input_metadata) for _ in range(2): @@ -274,15 +217,15 @@ class CudaGraphRunner: self.model_runner.tp_group.barrier() self.graph_memory_pool = graph.pool() - return graph, None, out, flashinfer_decode_wrapper + return graph, out def replay(self, batch: ScheduleBatch): assert batch.out_cache_loc is not None raw_bs = len(batch.reqs) # Pad - index = bisect.bisect_left(self.batch_size_list, raw_bs) - bs = self.batch_size_list[index] + index = bisect.bisect_left(self.capture_bs, raw_bs) + bs = self.capture_bs[index] if bs != raw_bs: self.seq_lens.zero_() self.position_ids_offsets.fill_(1) @@ -295,14 +238,9 @@ class CudaGraphRunner: self.position_ids_offsets[:raw_bs] = batch.position_ids_offsets self.out_cache_loc[:raw_bs] = batch.out_cache_loc - # FlashInfer inputs - update_flashinfer_indices( - ForwardMode.DECODE, - self.model_runner, - self.req_pool_indices[:bs], - self.seq_lens[:bs], - None, - self.flashinfer_handlers[bs], + # Attention backend + self.model_runner.attn_backend.replay_cuda_graph_init( + bs, self.req_pool_indices, self.seq_lens ) # Sampling inputs diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index f3bed6bcf..8542ced35 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -23,9 +23,8 @@ from typing import TYPE_CHECKING, List import numpy as np import torch -from sglang.srt.layers.flashinfer_utils import update_flashinfer_indices - if TYPE_CHECKING: + from sglang.srt.layers.attention_backend import AttentionBackend from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner @@ -66,12 +65,11 @@ class InputMetadata: seq_lens: torch.Tensor req_to_token_pool: ReqToTokenPool token_to_kv_pool: BaseTokenToKVPool + attn_backend: AttentionBackend # Output location of the KV cache out_cache_loc: torch.Tensor - total_num_tokens: int = None - # Position information positions: torch.Tensor = None @@ -93,18 +91,6 @@ class InputMetadata: image_offsets: List[List[int]] = None modalities: List[List[str]] = None - # Trition attention backend - triton_max_seq_len: int = 0 - triton_max_extend_len: int = 0 - triton_start_loc: torch.Tensor = None - triton_prefix_lens: torch.Tensor = None - - # FlashInfer attention backend - flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None - flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None - flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None - flashinfer_use_ragged: bool = False - def 
init_multimuldal_info(self, batch: ScheduleBatch): reqs = batch.reqs self.pixel_values = [r.pixel_values for r in reqs] @@ -154,32 +140,27 @@ class InputMetadata: self.positions = self.positions.to(torch.int64) def compute_extend_infos(self, batch: ScheduleBatch): - if self.forward_mode.is_decode(): - self.extend_seq_lens = self.extend_start_loc = self.extend_no_prefix = None - self.extend_seq_lens_cpu = self.logprob_start_lens_cpu = None - else: - extend_lens_cpu = [ - len(r.fill_ids) - batch.prefix_lens_cpu[i] - for i, r in enumerate(batch.reqs) - ] - self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda") - self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") - self.extend_start_loc = torch.zeros_like(self.seq_lens) - self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) - self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu) + extend_lens_cpu = [ + len(r.fill_ids) - batch.prefix_lens_cpu[i] for i, r in enumerate(batch.reqs) + ] + self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda") + self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") + self.extend_start_loc = torch.zeros_like(self.seq_lens) + self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0) + self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu) - self.extend_seq_lens_cpu = extend_lens_cpu - self.logprob_start_lens_cpu = [ - ( - min( - req.logprob_start_len - batch.prefix_lens_cpu[i], - extend_lens_cpu[i] - 1, - ) - if req.logprob_start_len >= batch.prefix_lens_cpu[i] - else extend_lens_cpu[i] - 1 # Fake extend, actually decode + self.extend_seq_lens_cpu = extend_lens_cpu + self.logprob_start_lens_cpu = [ + ( + min( + req.logprob_start_len - batch.prefix_lens_cpu[i], + extend_lens_cpu[i] - 1, ) - for i, req in enumerate(batch.reqs) - ] + if req.logprob_start_len >= batch.prefix_lens_cpu[i] + else extend_lens_cpu[i] - 1 # Fake extend, actually decode + ) + for i, req in enumerate(batch.reqs) + ] @classmethod def from_schedule_batch( @@ -195,6 +176,7 @@ class InputMetadata: seq_lens=batch.seq_lens, req_to_token_pool=model_runner.req_to_token_pool, token_to_kv_pool=model_runner.token_to_kv_pool, + attn_backend=model_runner.attn_backend, out_cache_loc=batch.out_cache_loc, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, @@ -202,76 +184,12 @@ class InputMetadata: ret.sampling_info.update_penalties() ret.sampling_info.update_regex_vocab_mask(batch) - ret.compute_positions(batch) - ret.compute_extend_infos(batch) - - fm = batch.forward_mode - if not fm.is_decode() or model_runner.server_args.attention_backend == "triton": - ret.total_num_tokens = int(torch.sum(ret.seq_lens)) - - if not fm.is_decode(): + if not batch.forward_mode.is_decode(): ret.init_multimuldal_info(batch) + ret.compute_extend_infos(batch) - if model_runner.server_args.attention_backend == "triton": - ret.init_triton_args(batch) - - flashinfer_use_ragged = False - if model_runner.server_args.attention_backend == "flashinfer": - if ( - not fm.is_decode() - and int(torch.sum(ret.seq_lens)) > 4096 - and model_runner.sliding_window_size is None - ): - flashinfer_use_ragged = True - ret.init_flashinfer_handlers( - model_runner, batch.prefix_lens_cpu, flashinfer_use_ragged - ) + model_runner.attn_backend.init_forward_metadata(batch, ret) return ret - - def init_triton_args(self, batch: ScheduleBatch): - """Init auxiliary variables for triton attention backend.""" - self.triton_max_seq_len = 
int(torch.max(self.seq_lens)) - self.triton_start_loc = torch.zeros_like(self.seq_lens, dtype=torch.int32) - self.triton_start_loc[1:] = torch.cumsum(self.seq_lens[:-1], dim=0) - - if self.forward_mode.is_decode(): - self.triton_max_extend_len = None - else: - self.triton_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda") - extend_seq_lens = self.seq_lens - self.triton_prefix_lens - self.triton_max_extend_len = int(torch.max(extend_seq_lens)) - - def init_flashinfer_handlers( - self, - model_runner, - prefix_lens_cpu, - flashinfer_use_ragged, - ): - if self.forward_mode.is_decode(): - prefix_lens = None - else: - prefix_lens = self.extend_prefix_lens - - update_flashinfer_indices( - self.forward_mode, - model_runner, - self.req_pool_indices, - self.seq_lens, - prefix_lens, - flashinfer_use_ragged=flashinfer_use_ragged, - ) - - ( - self.flashinfer_prefill_wrapper_ragged, - self.flashinfer_prefill_wrapper_paged, - self.flashinfer_decode_wrapper, - self.flashinfer_use_ragged, - ) = ( - model_runner.flashinfer_prefill_wrapper_ragged, - model_runner.flashinfer_prefill_wrapper_paged, - model_runner.flashinfer_decode_wrapper, - flashinfer_use_ragged, - ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index b04b0d7c0..80c741652 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type import torch import torch.nn as nn -from flashinfer import ( - BatchDecodeWithPagedKVCacheWrapper, - BatchPrefillWithPagedKVCacheWrapper, - BatchPrefillWithRaggedKVCacheWrapper, -) -from flashinfer.decode import _grouped_size_compiled_for_decode_kernels from vllm.config import DeviceConfig, LoadConfig from vllm.config import ModelConfig as VllmModelConfig from vllm.distributed import ( @@ -43,8 +37,8 @@ from vllm.distributed.parallel_state import in_the_same_node_as from vllm.model_executor.model_loader import get_model from vllm.model_executor.models import ModelRegistry -from sglang.global_config import global_config from sglang.srt.configs.model_config import AttentionArch, ModelConfig +from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict @@ -69,6 +63,8 @@ logger = logging.getLogger(__name__) class ModelRunner: + """ModelRunner runs the forward passes of the models.""" + def __init__( self, model_config: ModelConfig, @@ -100,6 +96,7 @@ class ModelRunner: } ) + # Model-specific adjustment if self.is_multimodal_model: logger.info( "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models." 
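For orientation, the net effect of this refactor is a single dispatch path: the model runner builds one attention backend, the forward-batch metadata carries a reference to it, and RadixAttention forwards every call to it. A minimal sketch of that call flow, using the names introduced in this patch (illustrative commentary, not part of the diff):

    # 1. ModelRunner.init_attention_backend() picks the backend once:
    #        self.attn_backend = FlashInferAttnBackend(self)   # or TritonAttnBackend(self)
    # 2. InputMetadata.from_schedule_batch() stores it and lets it precompute per-batch metadata:
    #        attn_backend=model_runner.attn_backend, then attn_backend.init_forward_metadata(batch, ret)
    # 3. RadixAttention.forward() dispatches without knowing which backend is active:
    #        return input_metadata.attn_backend.forward(q, k, v, self, input_metadata)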
@@ -107,6 +104,7 @@ class ModelRunner: server_args.chunked_prefill_size = None server_args.mem_fraction_static *= 0.95 + # Init componnets min_per_gpu_memory = self.init_torch_distributed() self.load_model() self.init_memory_pool( @@ -115,7 +113,7 @@ class ModelRunner: server_args.max_total_tokens, ) self.init_cublas() - self.init_flashinfer() + self.init_attention_backend() self.init_cuda_graphs() def init_torch_distributed(self): @@ -397,9 +395,6 @@ class ModelRunner: qk_rope_head_dim=self.model_config.qk_rope_head_dim, layer_num=self.model_config.num_hidden_layers, ) - logger.info("using MLA Triton implementaion, flashinfer is disabled") - # FIXME: temporarily only Triton MLA is supported - self.server_args.attention_backend = "triton" else: self.token_to_kv_pool = MHATokenToKVPool( self.max_total_num_tokens, @@ -422,106 +417,42 @@ class ModelRunner: c = a @ b return c - def init_flashinfer(self): - """Init flashinfer attention kernel wrappers.""" - if self.server_args.attention_backend != "flashinfer": - assert ( - self.sliding_window_size is None - ), "turn on flashinfer to support window attention" - self.flashinfer_prefill_wrapper_ragged = None - self.flashinfer_prefill_wrapper_paged = None - self.flashinfer_decode_wrapper = None - return - - if not _grouped_size_compiled_for_decode_kernels( - self.model_config.num_attention_heads // self.tp_size, - self.model_config.get_num_kv_heads(self.tp_size), - ): - use_tensor_cores = True + def init_attention_backend(self): + """Init attention kernel backend.""" + if self.server_args.attention_backend == "flashinfer": + self.attn_backend = FlashInferAttnBackend(self) + elif self.server_args.attention_backend == "triton": + assert self.sliding_window_size is None, ( + "Window attention is not supported in the triton attention backend. " + "Please use `--attention-backend flashinfer`." 
+ ) + self.attn_backend = TritonAttnBackend(self) else: - use_tensor_cores = False - - if self.sliding_window_size is None: - self.flashinfer_workspace_buffer = torch.empty( - global_config.flashinfer_workspace_size, - dtype=torch.uint8, - device="cuda", + raise ValueError( + f"Invalid attention backend: {self.server_args.attention_backend}" ) - self.flashinfer_prefill_wrapper_ragged = ( - BatchPrefillWithRaggedKVCacheWrapper( - self.flashinfer_workspace_buffer, "NHD" - ) - ) - self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, "NHD" - ) - self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, - "NHD", - use_tensor_cores=use_tensor_cores, - ) - else: - self.flashinfer_workspace_buffer = torch.empty( - global_config.flashinfer_workspace_size, - dtype=torch.uint8, - device="cuda", - ) - self.flashinfer_prefill_wrapper_ragged = None - self.flashinfer_prefill_wrapper_paged = [] - self.flashinfer_decode_wrapper = [] - for i in range(2): - self.flashinfer_prefill_wrapper_paged.append( - BatchPrefillWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, "NHD" - ) - ) - self.flashinfer_decode_wrapper.append( - BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, - "NHD", - use_tensor_cores=use_tensor_cores, - ) - ) def init_cuda_graphs(self): """Capture cuda graphs.""" + from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner + + self.cuda_graph_runner = None + if not self.is_generation: # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models return - from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner + if self.server_args.disable_cuda_graph: + return - if ( - self.server_args.disable_cuda_graph - or self.server_args.attention_backend != "flashinfer" - ): - self.cuda_graph_runner = None + if self.server_args.attention_backend != "flashinfer": + logger.warning( + f"Cuda graph is not supported for attention backend: {self.server_args.attention_backend}" + ) return logger.info("Capture cuda graph begin. This can take up to several minutes.") - - if self.server_args.disable_cuda_graph_padding: - batch_size_list = list(range(1, 32)) + [64, 128] - else: - batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)] - - self.cuda_graph_runner = CudaGraphRunner( - self, - max_batch_size_to_capture=max(batch_size_list), - use_torch_compile=self.server_args.enable_torch_compile, - disable_padding=self.server_args.disable_cuda_graph_padding, - ) - try: - self.cuda_graph_runner.capture(batch_size_list) - except RuntimeError as e: - raise Exception( - f"Capture cuda graph failed: {e}\n" - "Possible solutions:\n" - "1. disable cuda graph by --disable-cuda-graph\n" - "2. set --mem-fraction-static to a smaller value\n" - "3. 
disable torch compile by not using --enable-torch-compile\n" - "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n" - ) + self.cuda_graph_runner = CudaGraphRunner(self) @torch.inference_mode() def forward_decode(self, batch: ScheduleBatch): diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index 622f27df1..6f6bb6126 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -143,18 +143,16 @@ class SamplingBatchInfo: self.linear_penalties = penalizer.apply(self.linear_penalties) def update_regex_vocab_mask(self, batch: ScheduleBatch): - bs, reqs = batch.batch_size(), batch.reqs - device = "cuda" - has_regex = any(req.regex_fsm is not None for req in reqs) + has_regex = any(req.regex_fsm is not None for req in batch.reqs) # Reset the vocab mask self.vocab_mask = None if has_regex: self.vocab_mask = torch.zeros( - bs, self.vocab_size, dtype=torch.bool, device=device + batch.batch_size(), self.vocab_size, dtype=torch.bool, device="cuda" ) - for i, req in enumerate(reqs): + for i, req in enumerate(batch.reqs): if req.regex_fsm is not None: self.vocab_mask[i].fill_(1) self.vocab_mask[i][ diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 5bdee03de..52806daa9 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -335,23 +335,19 @@ def launch_server( return # Launch processes - tokenizer_manager = TokenizerManager(server_args, port_args) - if server_args.chat_template: - load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) - pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) if server_args.dp_size == 1: start_controller_process = start_controller_process_single else: start_controller_process = start_controller_process_multi - proc_controller = mp.Process( target=start_controller_process, args=(server_args, port_args, pipe_controller_writer), ) proc_controller.start() + pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False) proc_detoken = mp.Process( target=start_detokenizer_process, args=( @@ -362,6 +358,10 @@ def launch_server( ) proc_detoken.start() + tokenizer_manager = TokenizerManager(server_args, port_args) + if server_args.chat_template: + load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) + # Wait for the model to finish loading controller_init_state = pipe_controller_reader.recv() detoken_init_state = pipe_detoken_reader.recv() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 776f7bec3..36a16bc9f 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -83,8 +83,8 @@ class ServerArgs: json_model_override_args: str = "{}" # Optimization/debug options - attention_backend: str = "flashinfer" - sampling_backend: str = "flashinfer" + attention_backend: Optional[str] = None + sampling_backend: Optional[str] = None disable_flashinfer: bool = False disable_flashinfer_sampling: bool = False @@ -148,6 +148,17 @@ class ServerArgs: ) self.sampling_backend = "pytorch" + # Default kernel backends + if self.enable_mla: + logger.info("MLA optimization is tunred on. 
Use triton backend.") + self.attention_backend = "triton" + + if self.attention_backend is None: + self.attention_backend = "flashinfer" + + if self.sampling_backend is None: + self.sampling_backend = "flashinfer" + # Model-specific patches if "Alibaba-NLP/gte-Qwen2-1.5B-instruct" == self.model_path: logger.info( diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py index 2159cca95..8fb0231d8 100644 --- a/test/srt/test_create_kvindices.py +++ b/test/srt/test_create_kvindices.py @@ -55,8 +55,8 @@ class TestCreateKvIndices(unittest.TestCase): paged_kernel_lens, kv_indptr, None, - req_to_token.size(1), kv_indices_triton, + req_to_token.size(1), ) # Check diff --git a/test/srt/test_moe_serving_throughput.py b/test/srt/test_moe_serving_throughput.py index e0c851d4e..65b0b55b9 100644 --- a/test/srt/test_moe_serving_throughput.py +++ b/test/srt/test_moe_serving_throughput.py @@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase): other_args = [] if disable_radix_cache: other_args.append("--disable-radix-cache") - other_args.extend(["--attention-backend", attention_backend]) + if attention_backend: + other_args.extend(["--attention-backend", attention_backend]) other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) other_args.extend(["--tensor-parallel-size", "2"]) diff --git a/test/srt/test_serving_throughput.py b/test/srt/test_serving_throughput.py index 1b458e9e6..81aff3ed2 100644 --- a/test/srt/test_serving_throughput.py +++ b/test/srt/test_serving_throughput.py @@ -19,7 +19,8 @@ class TestServingThroughput(unittest.TestCase): other_args = [] if disable_radix_cache: other_args.append("--disable-radix-cache") - other_args.extend(["--attention-backend", attention_backend]) + if attention_backend: + other_args.extend(["--attention-backend", attention_backend]) other_args.extend(["--chunked-prefill-size", str(chunked_prefill_size)]) model = DEFAULT_MODEL_NAME_FOR_TEST diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 0d094b557..79b26f67a 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -96,23 +96,17 @@ class TestExtendAttention(unittest.TestCase): v_buffer, req_to_tokens, b_req_idx, - b_start_loc, b_seq_len, - b_seq_len_prefix, - b_start_loc_extend, b_seq_len_extend, - max_len_in_batch, + b_start_loc_extend, max_len_extend, ) redundant_attention( q_extend, - k_extend, - v_extend, o_redundant, k_buffer, v_buffer, - req_to_tokens, b_req_idx, b_start_loc, b_seq_len,