From 384d85ba358a6a097090f9d7dbe0f621c8c47829 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 24 Oct 2024 13:30:11 -0700 Subject: [PATCH] Re-introduce `get_cuda_graph_seq_len_fill_value` (#1783) --- python/sglang/srt/layers/attention/__init__.py | 4 ++++ .../srt/layers/attention/double_sparsity_backend.py | 3 +++ python/sglang/srt/layers/attention/flashinfer_backend.py | 3 +++ python/sglang/srt/layers/attention/triton_backend.py | 3 +++ python/sglang/srt/model_executor/cuda_graph_runner.py | 8 ++++++-- 5 files changed, 19 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index 7759015aa..f5d573f5f 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -41,6 +41,10 @@ class AttentionBackend(ABC): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() + def get_cuda_graph_seq_len_fill_value(self): + """Get the fill value for padded seq lens. 
Typically, it is 0 or 1.""" + raise NotImplementedError() + def forward( self, q: torch.Tensor, diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index af8c8a71c..73c32df8f 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -161,6 +161,9 @@ class DoubleSparseAttnBackend(AttentionBackend): self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) + def get_cuda_graph_seq_len_fill_value(self): + return 1 + def forward_extend( self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch ): diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 3475d8b7e..c6b5393ee 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -210,6 +210,9 @@ class FlashInferAttnBackend(AttentionBackend): encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, ) + def get_cuda_graph_seq_len_fill_value(self): + return 0 + def forward_extend( self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch ): diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 9fc3c9751..47b8d3cd5 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -108,6 +108,9 @@ class TritonAttnBackend(AttentionBackend): self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) + def get_cuda_graph_seq_len_fill_value(self): + return 1 + def forward_extend( self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch ): diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py 
b/python/sglang/srt/model_executor/cuda_graph_runner.py index d83b337f1..22ed6cc2b 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -134,7 +134,11 @@ class CudaGraphRunner: self.max_bs = max(self.capture_bs) self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs) - self.seq_len_fill_value = 1 + self.seq_len_fill_value = ( + self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() + ) + + # FIXME(lsyin): leave it here for now, I don't know whether it is necessary self.encoder_len_fill_value = 0 if self.use_torch_compile: @@ -287,7 +291,7 @@ class CudaGraphRunner: index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] if bs != raw_bs: - self.seq_lens.fill_(1) + self.seq_lens.fill_(self.seq_len_fill_value) self.out_cache_loc.zero_() # Common inputs