From fc82f5a743f48d50c633a08e89eff3d6522fb4a3 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Thu, 24 Oct 2024 12:33:15 -0700 Subject: [PATCH] [Fix] Fix cuda graph padding for triton attention backend (#1782) --- python/sglang/srt/layers/attention/__init__.py | 4 ---- .../sglang/srt/layers/attention/double_sparsity_backend.py | 3 --- python/sglang/srt/layers/attention/flashinfer_backend.py | 3 --- python/sglang/srt/layers/attention/triton_backend.py | 3 --- python/sglang/srt/mem_cache/memory_pool.py | 2 +- python/sglang/srt/model_executor/cuda_graph_runner.py | 7 ++----- 6 files changed, 3 insertions(+), 19 deletions(-) diff --git a/python/sglang/srt/layers/attention/__init__.py b/python/sglang/srt/layers/attention/__init__.py index f5d573f5f..7759015aa 100644 --- a/python/sglang/srt/layers/attention/__init__.py +++ b/python/sglang/srt/layers/attention/__init__.py @@ -41,10 +41,6 @@ class AttentionBackend(ABC): """Init the metadata for a forward pass for replying a cuda graph.""" raise NotImplementedError() - def get_cuda_graph_seq_len_fill_value(self): - """Get the fill value for padded seq lens. Typically, it is 0 or 1.""" - raise NotImplementedError() - def forward( self, q: torch.Tensor, diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index 73c32df8f..af8c8a71c 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -161,9 +161,6 @@ class DoubleSparseAttnBackend(AttentionBackend): self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) - def get_cuda_graph_seq_len_fill_value(self): - return 1 - def forward_extend( self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch ): diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index c6b5393ee..3475d8b7e 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -210,9 +210,6 @@ class FlashInferAttnBackend(AttentionBackend): encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None, ) - def get_cuda_graph_seq_len_fill_value(self): - return 0 - def forward_extend( self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch ): diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 47b8d3cd5..9fc3c9751 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -108,9 +108,6 @@ class TritonAttnBackend(AttentionBackend): self.cuda_graph_start_loc.zero_() self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0) - def get_cuda_graph_seq_len_fill_value(self): - return 1 - def forward_extend( self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch ): diff --git a/python/sglang/srt/mem_cache/memory_pool.py b/python/sglang/srt/mem_cache/memory_pool.py index 07f3d454e..b028309c7 100644 --- a/python/sglang/srt/mem_cache/memory_pool.py +++ b/python/sglang/srt/mem_cache/memory_pool.py @@ -38,7 +38,7 @@ class ReqToTokenPool: self.size = size self.max_context_len = max_context_len self.device = device - self.req_to_token = torch.empty( + self.req_to_token = torch.zeros( (size, max_context_len), dtype=torch.int32, device=device ) self.free_slots = list(range(size)) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 23090688d..d83b337f1 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -133,11 +133,8 @@ class CudaGraphRunner: # Attention backend self.max_bs = max(self.capture_bs) self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs) - self.seq_len_fill_value = ( - self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() - ) - # FIXME(lsyin): leave it here for now, I don't know whether it is necessary + self.seq_len_fill_value = 1 self.encoder_len_fill_value = 0 if self.use_torch_compile: @@ -290,7 +287,7 @@ class CudaGraphRunner: index = bisect.bisect_left(self.capture_bs, raw_bs) bs = self.capture_bs[index] if bs != raw_bs: - self.seq_lens.fill_(1) + self.seq_lens.fill_(self.seq_len_fill_value) self.out_cache_loc.zero_() # Common inputs