From e0ce171d7981e324ea3bb9def6079274e039c118 Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Wed, 20 Aug 2025 11:16:26 +0800
Subject: [PATCH] Fix triton backend eagle illegal memory access (#9344)

---
 .../srt/layers/attention/triton_backend.py | 17 +++++++++--------
 1 file changed, 9 insertions(+), 8 deletions(-)

diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 302907b67..2d9b42c8b 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -172,7 +172,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
+                forward_batch.seq_lens_sum, dtype=torch.int64, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -238,7 +238,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
-                kv_indptr[-1], dtype=torch.int32, device=self.device
+                kv_indptr[-1], dtype=torch.int64, device=self.device
             )
             create_flashinfer_kv_indices_triton[(bs,)](
                 self.req_to_token,
@@ -289,6 +289,7 @@ class TritonAttnBackend(AttentionBackend):
                         self.req_to_token,
                     )
                 )
+                kv_indices = kv_indices.to(torch.int64)
                 mask_indptr = None
                 # TODO(FIXME): This will trigger an invalid Eagle tree when using
                 # `max(spec_info.accept_length_cpu)`.
@@ -304,7 +305,7 @@ class TritonAttnBackend(AttentionBackend):
             kv_indptr = kv_indptr[: bs + 1]
             kv_indices = torch.empty(
                 forward_batch.extend_prefix_lens.sum().item(),
-                dtype=torch.int32,
+                dtype=torch.int64,
                 device=self.device,
             )
             create_flashinfer_kv_indices_triton[(bs,)](
@@ -379,7 +380,7 @@ class TritonAttnBackend(AttentionBackend):
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_num_tokens * self.max_context_len),
-                dtype=torch.int32,
+                dtype=torch.int64,
                 device=self.device,
             )
         else:
@@ -396,7 +397,7 @@ class TritonAttnBackend(AttentionBackend):
             if kv_indices_buf is None:
                 self.cuda_graph_window_kv_indices = torch.zeros(
                     (max_num_tokens * self.sliding_window_size),
-                    dtype=torch.int32,
+                    dtype=torch.int64,
                     device=self.device,
                 )
             else:
@@ -888,7 +889,7 @@ class TritonMultiStepDraftBackend:
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
             ),
-            dtype=torch.int32,
+            dtype=torch.int64,
             device=self.device,
         )
 
@@ -906,7 +907,7 @@ class TritonMultiStepDraftBackend:
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_num_tokens * self.max_context_len),
-            dtype=torch.int32,
+            dtype=torch.int64,
             device=self.device,
         )
         for i in range(self.speculative_num_steps):
@@ -1015,7 +1016,7 @@ def update_sliding_window_buffer(
     window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
     window_kv_indptr = window_kv_indptr[: bs + 1]
     window_kv_indices = torch.empty(
-        window_kv_indptr[-1], dtype=torch.int32, device=device
+        window_kv_indptr[-1], dtype=torch.int64, device=device
     )
     window_kv_start_idx = seq_lens - window_kv_lens
     create_flashinfer_kv_indices_triton[(bs,)](
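
Why the dtype widening matters, sketched under stated assumptions: the kv-indices
buffers address a flattened KV space sized on the order of
max_num_tokens * max_context_len, and int32 tops out at 2**31 - 1. The commit
only names the symptom (an illegal memory access), so the exact trigger is
inferred from the change itself; the sizes below are hypothetical, chosen to
land exactly one past the int32 maximum.

    import torch

    # Hypothetical sizes for illustration only; real values depend on the
    # server configuration and are not taken from the patch.
    max_num_tokens = 2048
    max_context_len = 2**20  # roughly a 1M-token maximum context

    flat_slots = max_num_tokens * max_context_len     # 2**31
    print(flat_slots > torch.iinfo(torch.int32).max)  # True

    # Cast to int32, the offset wraps negative: the kind of value that
    # becomes an out-of-bounds address inside a GPU kernel.
    wrapped = torch.tensor(flat_slots, dtype=torch.int64).to(torch.int32)
    print(wrapped.item())  # -2147483648

    # int64, the dtype the patch allocates, represents the offset exactly.
    print(torch.tensor(flat_slots, dtype=torch.int64).item())  # 2147483648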