Fix illegal memory access in the Triton attention backend with EAGLE speculative decoding (#9344)
This commit is contained in:
@@ -172,7 +172,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
- forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
|
||||
+ forward_batch.seq_lens_sum, dtype=torch.int64, device=self.device
|
||||
)
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
@@ -238,7 +238,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
- kv_indptr[-1], dtype=torch.int32, device=self.device
|
||||
+ kv_indptr[-1], dtype=torch.int64, device=self.device
|
||||
)
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
@@ -289,6 +289,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.req_to_token,
|
||||
)
|
||||
)
|
||||
+ kv_indices = kv_indices.to(torch.int64)
|
||||
mask_indptr = None
|
||||
# TODO(FIXME): This will trigger an invalid Eagle tree when using
|
||||
# `max(spec_info.accept_length_cpu)`.
|
||||
@@ -304,7 +305,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
kv_indptr = kv_indptr[: bs + 1]
|
||||
kv_indices = torch.empty(
|
||||
forward_batch.extend_prefix_lens.sum().item(),
|
||||
- dtype=torch.int32,
|
||||
+ dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
@@ -379,7 +380,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
if kv_indices_buf is None:
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(max_num_tokens * self.max_context_len),
|
||||
- dtype=torch.int32,
|
||||
+ dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
@@ -396,7 +397,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
if kv_indices_buf is None:
|
||||
self.cuda_graph_window_kv_indices = torch.zeros(
|
||||
(max_num_tokens * self.sliding_window_size),
|
||||
- dtype=torch.int32,
|
||||
+ dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
else:
|
||||
@@ -888,7 +889,7 @@ class TritonMultiStepDraftBackend:
|
||||
self.speculative_num_steps,
|
||||
forward_batch.batch_size * self.topk * self.max_context_len,
|
||||
),
|
||||
- dtype=torch.int32,
|
||||
+ dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
@@ -906,7 +907,7 @@ class TritonMultiStepDraftBackend:
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(self.speculative_num_steps, max_num_tokens * self.max_context_len),
|
||||
- dtype=torch.int32,
|
||||
+ dtype=torch.int64,
|
||||
device=self.device,
|
||||
)
|
||||
for i in range(self.speculative_num_steps):
|
||||
@@ -1015,7 +1016,7 @@ def update_sliding_window_buffer(
|
||||
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
|
||||
window_kv_indptr = window_kv_indptr[: bs + 1]
|
||||
window_kv_indices = torch.empty(
|
||||
- window_kv_indptr[-1], dtype=torch.int32, device=device
|
||||
+ window_kv_indptr[-1], dtype=torch.int64, device=device
|
||||
)
|
||||
window_kv_start_idx = seq_lens - window_kv_lens
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
|
||||
Reference in New Issue
Block a user