diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
index 0d43a2f8f..58831a5a6 100644
--- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
+++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
@@ -61,18 +61,15 @@ class MambaAttnBackend(AttentionBackend):
         self.forward_metadata: ForwardMetadata = None
         self.state_indices_list = []
         self.query_start_loc_list = []
-
-    @classmethod
-    @lru_cache(maxsize=128)
-    def _get_cached_arange(cls, bs: int, device_str: str) -> torch.Tensor:
-        """Cache torch.arange tensors for common batch sizes to avoid repeated allocation."""
-        device = torch.device(device_str)
-        return torch.arange(0, bs + 1, dtype=torch.int32, device=device)
+        self.cached_cuda_graph_decode_query_start_loc: torch.Tensor = None
+        self.cached_cuda_graph_verify_query_start_loc: torch.Tensor = None
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         bs = forward_batch.batch_size
         if forward_batch.forward_mode.is_decode_or_idle():
-            query_start_loc = self._get_cached_arange(bs, str(self.device))
+            query_start_loc = torch.arange(
+                0, bs + 1, dtype=torch.int32, device=self.device
+            )
         elif forward_batch.forward_mode.is_extend():
             if forward_batch.forward_mode.is_target_verify():
                 query_start_loc = torch.arange(
@@ -102,6 +99,10 @@ class MambaAttnBackend(AttentionBackend):
         )
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        assert (
+            max_num_tokens % max_bs == 0
+        ), f"max_num_tokens={max_num_tokens} must be divisible by max_bs={max_bs}"
+        verify_step = max_num_tokens / max_bs
         for i in range(max_bs):
             self.state_indices_list.append(
                 torch.full(
@@ -111,6 +112,16 @@ class MambaAttnBackend(AttentionBackend):
             self.query_start_loc_list.append(
                 torch.empty((i + 2,), dtype=torch.int32, device=self.device)
             )
+        self.cached_cuda_graph_decode_query_start_loc = torch.arange(
+            0, max_bs + 1, dtype=torch.int32, device=self.device
+        )
+        self.cached_cuda_graph_verify_query_start_loc = torch.arange(
+            0,
+            max_bs * verify_step + 1,
+            step=verify_step,
+            dtype=torch.int32,
+            device=self.device,
+        )
 
     def init_forward_metadata_capture_cuda_graph(
         self,
@@ -123,16 +134,12 @@ class MambaAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
         if forward_mode.is_decode_or_idle():
-            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
+            self.query_start_loc_list[bs - 1].copy_(
+                self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
+            )
         elif forward_mode.is_target_verify():
             self.query_start_loc_list[bs - 1].copy_(
-                torch.arange(
-                    0,
-                    bs * spec_info.draft_token_num + 1,
-                    step=spec_info.draft_token_num,
-                    dtype=torch.int32,
-                    device=self.device,
-                )
+                self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
             )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
@@ -163,23 +170,29 @@ class MambaAttnBackend(AttentionBackend):
         mamba_indices[bs - num_padding :] = -1
         self.state_indices_list[bs - 1][: len(mamba_indices)].copy_(mamba_indices)
         if forward_mode.is_decode_or_idle():
-            self.query_start_loc_list[bs - 1].copy_(self._get_cached_arange(bs, "cuda"))
-            if num_padding > 0:
-                self.query_start_loc_list[bs - 1][bs - num_padding :] = bs - num_padding
-        elif forward_mode.is_target_verify():
-            self.query_start_loc_list[bs - 1].copy_(
-                torch.arange(
-                    0,
-                    bs * spec_info.draft_token_num + 1,
-                    step=spec_info.draft_token_num,
-                    dtype=torch.int32,
-                    device=self.device,
+            if num_padding == 0:
+                self.query_start_loc_list[bs - 1].copy_(
+                    self.cached_cuda_graph_decode_query_start_loc[: bs + 1]
                 )
-            )
-            if num_padding > 0:
-                self.query_start_loc_list[bs - 1][bs - num_padding :] = (
+            else:
+                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
+                    self.cached_cuda_graph_decode_query_start_loc[: bs - num_padding]
+                )
+                self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
                     bs - num_padding
-                ) * spec_info.draft_token_num
+                )
+        elif forward_mode.is_target_verify():
+            if num_padding == 0:
+                self.query_start_loc_list[bs - 1].copy_(
+                    self.cached_cuda_graph_verify_query_start_loc[: bs + 1]
+                )
+            else:
+                self.query_start_loc_list[bs - 1][: bs - num_padding].copy_(
+                    self.cached_cuda_graph_verify_query_start_loc[: bs - num_padding]
+                )
+                self.query_start_loc_list[bs - 1][bs - num_padding :].copy_(
+                    (bs - num_padding) * spec_info.draft_token_num
+                )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
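
For reference, a minimal standalone sketch (not part of the patch; the helper names below are made up) of what the cached tensors buy: the decode/verify aranges are built once, and per-batch-size buffers are filled by slicing them instead of allocating a fresh torch.arange on every capture/replay.

import torch


def build_query_start_loc_caches(max_bs: int, max_num_tokens: int, device="cpu"):
    # Analogous to init_cuda_graph_state: one arange for decode (step 1) and one
    # for target-verify (step = draft tokens per request).
    assert max_num_tokens % max_bs == 0
    verify_step = max_num_tokens // max_bs
    decode = torch.arange(0, max_bs + 1, dtype=torch.int32, device=device)
    verify = torch.arange(
        0, max_bs * verify_step + 1, step=verify_step, dtype=torch.int32, device=device
    )
    return decode, verify


def fill_decode_query_start_loc(buf, decode_cache, bs, num_padding):
    # buf holds bs + 1 cumulative offsets; padded slots repeat the last real offset.
    if num_padding == 0:
        buf.copy_(decode_cache[: bs + 1])
    else:
        buf[: bs - num_padding].copy_(decode_cache[: bs - num_padding])
        buf[bs - num_padding :] = bs - num_padding
    return buf


if __name__ == "__main__":
    decode_cache, verify_cache = build_query_start_loc_caches(max_bs=8, max_num_tokens=32)
    buf = torch.empty(5, dtype=torch.int32)
    # bs=4 with one padded request: three real offsets 0,1,2,3; padding repeats 3.
    print(fill_decode_query_start_loc(buf, decode_cache, bs=4, num_padding=1))
    # tensor([0, 1, 2, 3, 3], dtype=torch.int32)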