diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py
index a7474326f..a02673dc3 100644
--- a/python/sglang/srt/layers/radix_attention.py
+++ b/python/sglang/srt/layers/radix_attention.py
@@ -15,6 +15,8 @@ limitations under the License.
 
 """Radix attention."""
 
+from typing import Optional
+
 import torch
 from flashinfer.cascade import merge_state
 from torch import nn
@@ -34,8 +36,7 @@ class RadixAttention(nn.Module):
         scaling: float,
         num_kv_heads: int,
         layer_id: int,
-        reuse: bool = False,
-        sliding_window_size: int = -1,
+        sliding_window_size: Optional[int] = None,
         logit_cap: int = -1,
         v_head_dim: int = -1,
     ):
@@ -48,8 +49,7 @@ class RadixAttention(nn.Module):
         self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
         self.scaling = scaling
         self.layer_id = layer_id
-        self.reuse = reuse
-        self.sliding_window_size = sliding_window_size
+        self.sliding_window_size = sliding_window_size if sliding_window_size else -1
 
         if (
             not global_server_args_dict.get("disable_flashinfer", False)
@@ -118,16 +118,16 @@ class RadixAttention(nn.Module):
 
     def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
         # using two wrappers is unnecessary in the current PR, but are prepared for future PRs
-        prefill_wrapper_ragged = input_metadata.flashinfer_prefill_wrapper_ragged
         prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
-        if self.sliding_window_size != -1 or self.reuse:
+        if self.sliding_window_size != -1:
             prefill_wrapper_paged = prefill_wrapper_paged[0]
         else:
             if isinstance(prefill_wrapper_paged, list):
                 prefill_wrapper_paged = prefill_wrapper_paged[1]
 
-        if not input_metadata.flashinfer_use_ragged or self.reuse:
-            if not self.reuse:
+        if not input_metadata.flashinfer_use_ragged:
+            if k is not None:
+                assert v is not None
                 self.store_kv_cache(k, v, input_metadata)
 
             o = prefill_wrapper_paged.forward(
@@ -139,21 +139,20 @@ class RadixAttention(nn.Module):
                 logits_soft_cap=self.logit_cap,
             )
         else:
-            o1, s1 = prefill_wrapper_ragged.forward_return_lse(
-                q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
-                v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
-                causal=True,
-                sm_scale=self.scaling,
-                window_left=self.sliding_window_size,
-                logits_soft_cap=self.logit_cap,
+            o1, s1 = (
+                input_metadata.flashinfer_prefill_wrapper_ragged.forward_return_lse(
+                    q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                    k.contiguous().view(-1, self.tp_k_head_num, self.head_dim),
+                    v.contiguous().view(-1, self.tp_v_head_num, self.head_dim),
+                    causal=True,
+                    sm_scale=self.scaling,
+                    logits_soft_cap=self.logit_cap,
+                )
             )
 
             if input_metadata.extend_no_prefix:
                 o = o1
             else:
-                # TODO window attention + radix attention will come up in next PR
-                assert self.sliding_window_size == -1
                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
                     q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                     input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
@@ -179,7 +178,8 @@
         if isinstance(decode_wrapper, list):
             decode_wrapper = decode_wrapper[1]
 
-        if not self.reuse:
+        if k is not None:
+            assert v is not None
             self.store_kv_cache(k, v, input_metadata)
 
         o = decode_wrapper.forward(
diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py
index 66479b255..ce5ea25ea 100644
--- a/python/sglang/srt/model_executor/forward_batch_info.py
+++ b/python/sglang/srt/model_executor/forward_batch_info.py
@@ -194,6 +194,7 @@ class InputMetadata:
         if (
             forward_mode != ForwardMode.DECODE
             and int(torch.sum(ret.seq_lens)) > 4096
+            and model_runner.sliding_window_size is None
         ):
             flashinfer_use_ragged = True
         ret.init_flashinfer_handlers(
@@ -322,22 +323,25 @@ def update_flashinfer_indices(
             1,
         )
     else:
+        # window attention use paged only
        kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
         for wrapper_id in range(2):
-            if flashinfer_use_ragged and wrapper_id == 1:
-                # full attention use ragged+paged
-                paged_kernel_lens = prefix_lens
+            if wrapper_id == 0:
+                if forward_mode == ForwardMode.DECODE:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens, torch.tensor(model_runner.sliding_window_size + 1)
+                    )
+                else:
+                    paged_kernel_lens = torch.minimum(
+                        seq_lens,
+                        torch.tensor(model_runner.sliding_window_size)
+                        + seq_lens
+                        - prefix_lens,
+                    )
             else:
-                # window attention use paged only
                 paged_kernel_lens = seq_lens
 
-            if wrapper_id == 0 and forward_mode == ForwardMode.DECODE:
-                paged_kernel_lens = torch.minimum(
-                    paged_kernel_lens, torch.tensor(model_runner.sliding_window_size)
-                )
-                kv_start_idx = seq_lens - paged_kernel_lens
-            else:
-                kv_start_idx = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+            kv_start_idx = seq_lens - paged_kernel_lens
 
             kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
             kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
@@ -376,17 +380,6 @@ def update_flashinfer_indices(
             )
             qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
-            if flashinfer_use_ragged and wrapper_id == 1:
-                model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
-                model_runner.flashinfer_prefill_wrapper_ragged.begin_forward(
-                    qo_indptr,
-                    qo_indptr,
-                    num_qo_heads,
-                    num_kv_heads,
-                    head_dim,
-                )
-
-            # cached part
             model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].end_forward()
             model_runner.flashinfer_prefill_wrapper_paged[wrapper_id].begin_forward(
                 qo_indptr,
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index d3ed96fe0..7af4ec2dd 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -334,11 +334,7 @@ class ModelRunner:
                 dtype=torch.uint8,
                 device="cuda",
             )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffer, "NHD"
-                )
-            )
+            self.flashinfer_prefill_wrapper_ragged = None
             self.flashinfer_prefill_wrapper_paged = []
             self.flashinfer_decode_wrapper = []
             for i in range(2):
diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py
index 463d5e505..80b99742e 100644
--- a/python/sglang/srt/models/gemma2.py
+++ b/python/sglang/srt/models/gemma2.py
@@ -213,7 +213,7 @@ class Gemma2Attention(nn.Module):
             self.scaling,
             num_kv_heads=self.num_kv_heads,
             layer_id=layer_idx,
-            sliding_window_size=get_window_size(config) if use_sliding_window else -1,
+            sliding_window_size=get_window_size(config) if use_sliding_window else None,
             logit_cap=self.config.attn_logit_softcapping,
         )
 
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 6512e1b6e..738ab7d1a 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -450,16 +450,8 @@ class ServerArgs:
             self.dp_size > 1 and self.node_rank is not None
         ), "multi-node data parallel is not supported"
         if "gemma-2" in self.model_path.lower():
-            logger.info(
-                f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer."
-            )
-            # FIXME: compatibility with radix attention
-            self.disable_radix_cache = True
-            # FIXME: compatibility with jump forward
-            self.disable_regex_jump_forward = True
+            logger.info(f"When using sliding window in gemma-2, turn on flashinfer.")
             self.disable_flashinfer = False
-            # FIXME: compatibility with chunked prefill
-            self.chunked_prefill_size = None
 
 
 @dataclasses.dataclass
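
Two notes on the patch, each with a small self-contained sketch. These are illustrations written for this review rather than code from the tree: the helper names (`window_kv_ranges`, `merge_state_ref`) are hypothetical, and the snippets use CPU tensors where the real code keeps everything on `device="cuda"`.

First, the heart of the `update_flashinfer_indices` change is the KV range handed to the sliding-window wrapper (`wrapper_id == 0`):

```python
import torch


def window_kv_ranges(seq_lens, prefix_lens, window, decode):
    """Sketch of the per-request KV range used by the window wrapper."""
    if decode:
        # Decode adds one token per step, so it needs at most the last
        # window + 1 cached tokens: the window plus the token being decoded.
        paged_kernel_lens = torch.minimum(seq_lens, torch.tensor(window + 1))
    else:
        # Extend appends (seq_lens - prefix_lens) new tokens; the earliest new
        # token still needs `window` tokens behind it, so the kernel may read
        # up to window + new_tokens entries, clipped to the full length.
        paged_kernel_lens = torch.minimum(
            seq_lens, torch.tensor(window) + seq_lens - prefix_lens
        )
    # Both cases attend to the *last* paged_kernel_lens tokens of each
    # sequence, hence the nonzero start offsets.
    kv_start_idx = seq_lens - paged_kernel_lens
    return paged_kernel_lens, kv_start_idx


seq_lens = torch.tensor([10, 6000])
prefix_lens = torch.tensor([0, 5900])
print(window_kv_ranges(seq_lens, prefix_lens, window=4096, decode=False))
# lens [10, 4196], starts [0, 1804]: the short request keeps everything,
# the long one skips tokens that fell out of the window
print(window_kv_ranges(seq_lens, prefix_lens, window=4096, decode=True))
# lens [10, 4097], starts [0, 1903]
```

Second, the ragged-plus-paged extend path still combines its two partial results with `flashinfer.cascade.merge_state` (imported at the top of `radix_attention.py`). As I understand the semantics, given outputs computed over disjoint key sets together with per-(token, head) log-sum-exp values, the merge is a numerically stable weighted average:

```python
import torch


def merge_state_ref(o1, s1, o2, s2):
    """Reference-semantics sketch for merging two attention partials.

    o1, o2: (tokens, heads, head_dim) partial attention outputs
    s1, s2: (tokens, heads) log-sum-exp of the corresponding score sets
    """
    s_max = torch.maximum(s1, s2)
    w1 = torch.exp(s1 - s_max)  # shift by the max for numerical stability
    w2 = torch.exp(s2 - s_max)
    o = (w1.unsqueeze(-1) * o1 + w2.unsqueeze(-1) * o2) / (w1 + w2).unsqueeze(-1)
    s = s_max + torch.log(w1 + w2)  # LSE over the union of both key sets
    return o, s
```

The merge only needs each part's output and LSE, which is why the new-token (ragged) pass and the cached-prefix (paged) pass can run as separate kernels. The old `assert self.sliding_window_size == -1` in that branch becomes dead code here because `flashinfer_use_ragged` is now only set when `model_runner.sliding_window_size is None`, so window layers never reach the ragged path.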