From 96a2093ef021b7fb10cf727050e0c87494c5463a Mon Sep 17 00:00:00 2001 From: Ying Sheng Date: Wed, 14 Aug 2024 10:37:01 -0700 Subject: [PATCH] [Fix] Compatibility of window attention and cuda graph (#1090) --- python/sglang/srt/layers/radix_attention.py | 16 ++++-- .../srt/model_executor/cuda_graph_runner.py | 55 +++++++++++++++---- .../srt/model_executor/forward_batch_info.py | 10 +--- .../sglang/srt/model_executor/model_runner.py | 22 +++----- python/sglang/srt/server_args.py | 4 +- .../test/{long_prompt => long_prompt.txt} | 0 python/sglang/test/runners.py | 2 +- 7 files changed, 70 insertions(+), 39 deletions(-) rename python/sglang/test/{long_prompt => long_prompt.txt} (100%) diff --git a/python/sglang/srt/layers/radix_attention.py b/python/sglang/srt/layers/radix_attention.py index 49b86ad19..978a5d4c0 100644 --- a/python/sglang/srt/layers/radix_attention.py +++ b/python/sglang/srt/layers/radix_attention.py @@ -34,6 +34,7 @@ class RadixAttention(nn.Module): scaling: float, num_kv_heads: int, layer_id: int, + reuse: bool = False, sliding_window_size: int = -1, logit_cap: int = -1, v_head_dim: int = -1, @@ -47,6 +48,7 @@ class RadixAttention(nn.Module): self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim self.scaling = scaling self.layer_id = layer_id + self.reuse = reuse self.sliding_window_size = sliding_window_size if ( @@ -127,8 +129,9 @@ class RadixAttention(nn.Module): if isinstance(prefill_wrapper_paged, list): prefill_wrapper_paged = prefill_wrapper_paged[1] - if not input_metadata.flashinfer_use_ragged: - self.store_kv_cache(k, v, input_metadata) + if not input_metadata.flashinfer_use_ragged or self.reuse: + if not self.reuse: + self.store_kv_cache(k, v, input_metadata) o = prefill_wrapper_paged.forward( q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), @@ -179,7 +182,8 @@ class RadixAttention(nn.Module): if isinstance(decode_wrapper, list): decode_wrapper = decode_wrapper[1] - self.store_kv_cache(k, v, input_metadata) + if not self.reuse: + self.store_kv_cache(k, v, input_metadata) o = decode_wrapper.forward( q.contiguous().view(-1, self.tp_q_head_num, self.head_dim), @@ -191,8 +195,10 @@ class RadixAttention(nn.Module): return o.view(-1, self.tp_q_head_num * self.head_dim) def forward(self, q, k, v, input_metadata: InputMetadata): - k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) - v = v.view(-1, self.tp_v_head_num, self.v_head_dim) + if k is not None: + assert v is not None + k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) + v = v.view(-1, self.tp_v_head_num, self.v_head_dim) if input_metadata.forward_mode == ForwardMode.EXTEND: return self.extend_forward(q, k, v, input_metadata) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index a74e8eef7..ed26322c3 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -107,9 +107,6 @@ class CudaGraphRunner: ) # FlashInfer inputs - self.flashinfer_workspace_buffer = ( - self.model_runner.flashinfer_workspace_buffers[0] - ) self.flashinfer_kv_indptr = torch.zeros( (self.max_bs + 1,), dtype=torch.int32, device="cuda" ) @@ -121,6 +118,23 @@ class CudaGraphRunner: self.flashinfer_kv_last_page_len = torch.ones( (self.max_bs,), dtype=torch.int32, device="cuda" ) + if model_runner.sliding_window_size is None: + self.flashinfer_workspace_buffer = ( + self.model_runner.flashinfer_workspace_buffers[0] + ) + else: + self.flashinfer_workspace_buffers = [ + self.model_runner.flashinfer_workspace_buffers[0], + self.model_runner.flashinfer_workspace_buffers[2], + ] + self.flashinfer_kv_indptr = [ + self.flashinfer_kv_indptr, + self.flashinfer_kv_indptr.clone(), + ] + self.flashinfer_kv_indices = [ + self.flashinfer_kv_indices, + self.flashinfer_kv_indices.clone(), + ] self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else [] @@ -171,15 +185,32 @@ class CudaGraphRunner: use_tensor_cores = True else: use_tensor_cores = False - flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( - self.flashinfer_workspace_buffer, - "NHD", - use_cuda_graph=True, - use_tensor_cores=use_tensor_cores, - paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1], - paged_kv_indices_buffer=self.flashinfer_kv_indices, - paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs], - ) + if self.model_runner.sliding_window_size is None: + flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=use_tensor_cores, + paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1], + paged_kv_indices_buffer=self.flashinfer_kv_indices, + paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs], + ) + else: + flashinfer_decode_wrapper = [] + for i in range(2): + flashinfer_decode_wrapper.append( + BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_workspace_buffers[i], + "NHD", + use_cuda_graph=True, + use_tensor_cores=use_tensor_cores, + paged_kv_indptr_buffer=self.flashinfer_kv_indptr[i][: bs + 1], + paged_kv_indices_buffer=self.flashinfer_kv_indices[i], + paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[ + :bs + ], + ) + ) update_flashinfer_indices( ForwardMode.DECODE, self.model_runner, diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 3b2ee9de0..809b3329d 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -154,7 +154,6 @@ class InputMetadata: model_runner: "ModelRunner", batch: ScheduleBatch, forward_mode: ForwardMode, - sliding_window_size: Optional[int] = None, ): ret = cls( forward_mode=forward_mode, @@ -198,7 +197,7 @@ class InputMetadata: ): flashinfer_use_ragged = True ret.init_flashinfer_handlers( - model_runner, prefix_lens, flashinfer_use_ragged, sliding_window_size + model_runner, prefix_lens, flashinfer_use_ragged ) return ret @@ -221,7 +220,6 @@ class InputMetadata: model_runner, prefix_lens, flashinfer_use_ragged, - sliding_window_size=None, ): update_flashinfer_indices( self.forward_mode, @@ -230,7 +228,6 @@ class InputMetadata: self.seq_lens, prefix_lens, flashinfer_use_ragged=flashinfer_use_ragged, - sliding_window_size=sliding_window_size, ) ( @@ -254,7 +251,6 @@ def update_flashinfer_indices( prefix_lens, flashinfer_decode_wrapper=None, flashinfer_use_ragged=False, - sliding_window_size=None, ): """Init auxiliary variables for FlashInfer attention backend.""" num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size @@ -262,7 +258,7 @@ def update_flashinfer_indices( head_dim = model_runner.model_config.head_dim batch_size = len(req_pool_indices) - if sliding_window_size is None: + if model_runner.sliding_window_size is None: if flashinfer_use_ragged: paged_kernel_lens = prefix_lens else: @@ -335,7 +331,7 @@ def update_flashinfer_indices( if wrapper_id == 0 and forward_mode == ForwardMode.DECODE: paged_kernel_lens = torch.minimum( - paged_kernel_lens, torch.tensor(sliding_window_size) + paged_kernel_lens, torch.tensor(model_runner.sliding_window_size) ) kv_start_idx = seq_lens - paged_kernel_lens else: diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9da284da6..0a7483423 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -187,6 +187,11 @@ class ModelRunner: scheduler_config=None, cache_config=None, ) + self.sliding_window_size = ( + self.model.get_window_size() + if hasattr(self.model, "get_window_size") + else None + ) self.is_generation = is_generation_model( self.model_config.hf_config.architectures ) @@ -295,12 +300,6 @@ class ModelRunner: return c def init_flashinfer(self): - self.sliding_window_size = ( - self.model.get_window_size() - if hasattr(self.model, "get_window_size") - else None - ) - if self.server_args.disable_flashinfer: assert ( self.sliding_window_size is None @@ -339,7 +338,7 @@ class ModelRunner: use_tensor_cores=use_tensor_cores, ) else: - workspace_buffers = torch.empty( + self.flashinfer_workspace_buffers = torch.empty( 4, global_config.flashinfer_workspace_size, dtype=torch.uint8, @@ -351,17 +350,17 @@ class ModelRunner: for i in range(2): self.flashinfer_prefill_wrapper_ragged.append( BatchPrefillWithRaggedKVCacheWrapper( - workspace_buffers[2 * i + 0], "NHD" + self.flashinfer_workspace_buffers[2 * i + 0], "NHD" ) ) self.flashinfer_prefill_wrapper_paged.append( BatchPrefillWithPagedKVCacheWrapper( - workspace_buffers[2 * i + 1], "NHD" + self.flashinfer_workspace_buffers[2 * i + 1], "NHD" ) ) self.flashinfer_decode_wrapper.append( BatchDecodeWithPagedKVCacheWrapper( - workspace_buffers[2 * i + 0], + self.flashinfer_workspace_buffers[2 * i + 0], "NHD", use_tensor_cores=use_tensor_cores, ) @@ -404,7 +403,6 @@ class ModelRunner: self, batch, ForwardMode.DECODE, - sliding_window_size=self.sliding_window_size, ) return self.model.forward( @@ -417,7 +415,6 @@ class ModelRunner: self, batch, forward_mode=ForwardMode.EXTEND, - sliding_window_size=self.sliding_window_size, ) return self.model.forward( batch.input_ids, input_metadata.positions, input_metadata @@ -429,7 +426,6 @@ class ModelRunner: self, batch, forward_mode=ForwardMode.EXTEND, - sliding_window_size=self.sliding_window_size, ) return self.model.forward( batch.input_ids, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 5e7996b80..8ed66960b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -453,10 +453,12 @@ class ServerArgs: logger.info( f"When using sliding window in gemma-2, disable radix_cache, regex_jump_forward, and turn on flashinfer." ) + # FIXME: compatibility with radix attention self.disable_radix_cache = True + # FIXME: compatibility with jump forward self.disable_regex_jump_forward = True self.disable_flashinfer = False - self.disable_cuda_graph = True + # FIXME: compatibility with chunked prefill self.chunked_prefill_size = None diff --git a/python/sglang/test/long_prompt b/python/sglang/test/long_prompt.txt similarity index 100% rename from python/sglang/test/long_prompt rename to python/sglang/test/long_prompt.txt diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index c8357a16c..e325ecb71 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -36,7 +36,7 @@ DEFAULT_PROMPTS = [ ] dirpath = os.path.dirname(__file__) -with open(os.path.join(dirpath, "long_prompt"), "r") as f: +with open(os.path.join(dirpath, "long_prompt.txt"), "r") as f: long_prompt = f.read() DEFAULT_PROMPTS.append(long_prompt)