Fix swa eagle verify accuracy for Triton backend (#9279)
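
The custom mask built for EAGLE verify covers the full KV length of each request, but with sliding-window attention (SWA) the Triton backend only gathers the KV entries inside the window, so the extend kernel read the mask at the wrong positions. Thread the per-request window start position (window_kv_offsets, now returned by update_sliding_window_buffer and update_sliding_window_buffer_cuda_graph) through ForwardMetadata and the CUDA-graph capture/replay paths into extend_attention_fwd, and use it to offset the custom-mask loads in _fwd_kernel. Also derive batch from q.shape[0] instead of kv_indptr.shape[0] - 1 in the decode kernels.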
@@ -35,6 +35,7 @@ class ForwardMetadata:
     window_kv_indptr: torch.Tensor
     window_kv_indices: torch.Tensor
     window_num_kv_splits: torch.Tensor
+    window_kv_offsets: torch.Tensor
 
 
 class TritonAttnBackend(AttentionBackend):
@@ -163,6 +164,7 @@ class TritonAttnBackend(AttentionBackend):
         window_kv_indptr = self.window_kv_indptr
         window_kv_indices = None
         window_num_kv_splits = None
+        window_kv_offsets = None
         spec_info = forward_batch.spec_info
 
         if forward_batch.forward_mode.is_decode_or_idle():
@@ -186,7 +188,7 @@ class TritonAttnBackend(AttentionBackend):
                 self.sliding_window_size is not None
                 and self.sliding_window_size > 0
             ):
-                window_kv_indptr, window_kv_indices, window_kv_lens = (
+                window_kv_indptr, window_kv_indices, window_kv_lens, _ = (
                     update_sliding_window_buffer(
                         self.window_kv_indptr,
                         self.req_to_token,
@@ -249,17 +251,21 @@ class TritonAttnBackend(AttentionBackend):
                 )
 
             if self.sliding_window_size is not None and self.sliding_window_size > 0:
-                window_kv_indptr, window_kv_indices, window_kv_lens = (
-                    update_sliding_window_buffer(
-                        self.window_kv_indptr,
-                        self.req_to_token,
-                        self.sliding_window_size,
-                        forward_batch.seq_lens,
-                        forward_batch.req_pool_indices,
-                        bs,
-                        self.device,
-                        self.token_to_kv_pool_allocator,
-                    )
-                )
+                # window_kv_offsets is used to calculate the start position in custom mask
+                (
+                    window_kv_indptr,
+                    window_kv_indices,
+                    window_kv_lens,
+                    window_kv_offsets,
+                ) = update_sliding_window_buffer(
+                    self.window_kv_indptr,
+                    self.req_to_token,
+                    self.sliding_window_size,
+                    forward_batch.seq_lens,
+                    forward_batch.req_pool_indices,
+                    bs,
+                    self.device,
+                    self.token_to_kv_pool_allocator,
+                )
 
             custom_mask = spec_info.custom_mask
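
Note: this diff does not show the body of update_sliding_window_buffer, but from its use here the fourth return value is plausibly the index of the first in-window token per request. A minimal illustrative sketch with a hypothetical window_start helper (not code from this change):

    import torch

    def window_start(seq_lens: torch.Tensor, sliding_window_size: int) -> torch.Tensor:
        # Tokens older than the window are evicted, so the window begins at
        # seq_len - window_size for long sequences and at 0 for short ones.
        return torch.clamp(seq_lens - sliding_window_size, min=0)

    # e.g. window of 1024: a 5-token request has offset 0, a 4096-token one 3072
    print(window_start(torch.tensor([5, 4096]), 1024))  # tensor([   0, 3072])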
@@ -312,15 +318,17 @@ class TritonAttnBackend(AttentionBackend):
             )
             # Sliding window
             if self.sliding_window_size is not None and self.sliding_window_size > 0:
-                window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
-                    self.window_kv_indptr,
-                    self.req_to_token,
-                    self.sliding_window_size,
-                    forward_batch.extend_prefix_lens,
-                    forward_batch.req_pool_indices,
-                    bs,
-                    self.device,
-                    self.token_to_kv_pool_allocator,
+                window_kv_indptr, window_kv_indices, _, _ = (
+                    update_sliding_window_buffer(
+                        self.window_kv_indptr,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        forward_batch.extend_prefix_lens,
+                        forward_batch.req_pool_indices,
+                        bs,
+                        self.device,
+                        self.token_to_kv_pool_allocator,
+                    )
                 )
 
             qo_indptr = self.qo_indptr
@@ -346,6 +354,7 @@ class TritonAttnBackend(AttentionBackend):
             window_kv_indptr,
             window_kv_indices,
             window_num_kv_splits,
+            window_kv_offsets,
         )
 
     def init_cuda_graph_state(
@@ -400,6 +409,12 @@ class TritonAttnBackend(AttentionBackend):
             device=self.device,
         )
 
+        self.cuda_graph_window_kv_offsets = torch.zeros(
+            (max_bs,),
+            dtype=torch.int32,
+            device=self.device,
+        )
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -414,6 +429,7 @@ class TritonAttnBackend(AttentionBackend):
         window_kv_indptr = self.window_kv_indptr
         window_kv_indices = None
         window_num_kv_splits = None
+        window_kv_offsets = None
 
         if forward_mode.is_decode_or_idle():
             if spec_info is None:
@@ -436,7 +452,7 @@ class TritonAttnBackend(AttentionBackend):
             ):
                 window_kv_indices = self.cuda_graph_window_kv_indices
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                window_kv_indptr, window_kv_indices, _ = (
+                window_kv_indptr, window_kv_indices, _, _ = (
                     update_sliding_window_buffer_cuda_graph(
                         self.window_kv_indptr,
                         window_kv_indices,
@@ -483,13 +499,14 @@ class TritonAttnBackend(AttentionBackend):
             if self.sliding_window_size is not None and self.sliding_window_size > 0:
                 window_kv_indices = self.cuda_graph_window_kv_indices
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                window_kv_indptr, window_kv_indices, _ = (
+                window_kv_offsets = self.cuda_graph_window_kv_offsets
+                window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (
                     update_sliding_window_buffer_cuda_graph(
                         self.window_kv_indptr,
                         window_kv_indices,
                         self.req_to_token,
                         self.sliding_window_size,
-                        seq_lens,
+                        seq_lens[:bs],
                         req_pool_indices,
                         bs,
                         self.token_to_kv_pool_allocator,
@@ -551,6 +568,7 @@ class TritonAttnBackend(AttentionBackend):
             window_kv_indptr,
             window_kv_indices,
             window_num_kv_splits,
+            window_kv_offsets,
         )
 
     def init_forward_metadata_replay_cuda_graph(
@@ -589,7 +607,7 @@ class TritonAttnBackend(AttentionBackend):
             ):
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                 window_kv_indices = self.cuda_graph_window_kv_indices
-                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                _, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(
                     self.window_kv_indptr,
                     window_kv_indices,
                     self.req_to_token,
@@ -635,15 +653,18 @@ class TritonAttnBackend(AttentionBackend):
             if self.sliding_window_size is not None and self.sliding_window_size > 0:
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                 window_kv_indices = self.cuda_graph_window_kv_indices
-                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
-                    self.window_kv_indptr,
-                    window_kv_indices,
-                    self.req_to_token,
-                    self.sliding_window_size,
-                    seq_lens,
-                    req_pool_indices,
-                    bs,
-                    self.token_to_kv_pool_allocator,
+                window_kv_offsets = self.cuda_graph_window_kv_offsets
+                _, _, window_kv_lens, window_kv_offsets[:bs] = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens[:bs],
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
                 )
             custom_mask = self.cuda_graph_custom_mask
             custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
@@ -706,10 +727,12 @@ class TritonAttnBackend(AttentionBackend):
             )  # Needed for sliding window mask
             kv_indptr = self.forward_metadata.window_kv_indptr
             kv_indices = self.forward_metadata.window_kv_indices
+            window_kv_offsets = self.forward_metadata.window_kv_offsets
         else:
             sliding_window_size = -1
             kv_indptr = self.forward_metadata.kv_indptr
             kv_indices = self.forward_metadata.kv_indices
+            window_kv_offsets = None
 
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
@@ -729,6 +752,7 @@ class TritonAttnBackend(AttentionBackend):
             layer.logit_cap,
             sliding_window_size=sliding_window_size,
             sinks=sinks,
+            window_kv_offsets=window_kv_offsets,
         )
         return o
 
@@ -1011,7 +1035,7 @@ def update_sliding_window_buffer(
             window_kv_indices[:kv_last_index]
         )
     )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
+    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
 
 
 def update_sliding_window_buffer_cuda_graph(
@@ -1048,4 +1072,4 @@ def update_sliding_window_buffer_cuda_graph(
             window_kv_indices[:kv_last_index]
         )
     )
-    return window_kv_indptr, window_kv_indices, window_kv_lens
+    return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
@@ -190,7 +190,7 @@ def _decode_att_m_fwd(
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
-    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
+    batch, head_num = q.shape[0], q.shape[1]
 
     grid = (batch, head_num, MAX_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd(
     BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
+    batch, head_num = q.shape[0], q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = 16
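
Note: in both decode kernels the launch grid is sized by batch. During decode, q has exactly one query row per request, while the kv_indptr passed in may be a slice of a preallocated buffer, so q.shape[0] is the reliable batch count. A minimal illustration with made-up shapes (the exact mismatch is not shown in this diff):

    import torch

    max_bs, bs = 8, 3
    kv_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)  # preallocated buffer
    q = torch.randn(bs, 32, 128)                            # one row per request

    print(kv_indptr.shape[0] - 1)  # 8 -> overcounts the live batch
    print(q.shape[0])              # 3 -> matches the requests actually decoding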
@@ -52,6 +52,7 @@ def _fwd_kernel(
    mask_ptr,
    mask_indptr,
    sink_ptr,
+   window_kv_offset_ptr,
    sm_scale,
    kv_group_num,
    stride_qbs,
@@ -95,6 +96,11 @@ def _fwd_kernel(
     if USE_CUSTOM_MASK:
         cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
 
+    # For SWA, we should only load the mask in the sliding window
+    window_kv_offset = 0
+    if USE_CUSTOM_MASK and SLIDING_WINDOW_SIZE > 0:
+        window_kv_offset = tl.load(window_kv_offset_ptr + cur_seq)
+
     offs_d = tl.arange(0, BLOCK_DMODEL)
     offs_dv = tl.arange(0, BLOCK_DV)
     offs_m = tl.arange(0, BLOCK_M)
@@ -139,7 +145,9 @@ def _fwd_kernel(
             custom_mask = tl.load(
                 mask_ptr
                 + cur_seq_mask_start_idx
-                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + (cur_block_m * BLOCK_M + offs_m[:, None])
+                * (cur_seq_len + window_kv_offset)
+                + window_kv_offset
                 + start_n
                 + offs_n[None, :],
                 mask=(mask_m[:, None] & mask_n[None, :]),
@@ -236,7 +244,9 @@ def _fwd_kernel(
             custom_mask = tl.load(
                 mask_ptr
                 + cur_seq_mask_start_idx
-                + (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
+                + (cur_block_m * BLOCK_M + offs_m[:, None])
+                * (cur_seq_len + window_kv_offset)
+                + window_kv_offset
                 + cur_seq_len_prefix
                 + start_n
                 + offs_n[None, :],
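
Note: cur_seq_len here is the windowed KV length, while the custom-mask rows are laid out over the full per-request KV length. The new index therefore uses a row stride of cur_seq_len + window_kv_offset and skips the first window_kv_offset columns. A host-side sketch of the arithmetic (illustrative only; variable names mirror the kernel):

    import numpy as np

    full_kv_len = 10                              # mask columns cover the whole sequence
    window_kv_offset = 6                          # first in-window KV position
    cur_seq_len = full_kv_len - window_kv_offset  # windowed KV length seen by the kernel

    mask = np.arange(3 * full_kv_len).reshape(3, full_kv_len)  # [qo_len, full_kv_len]
    flat = mask.reshape(-1)

    row, col = 1, 2  # query row, column within the window
    old_idx = row * cur_seq_len + col  # wrong: stride ignores the masked-out prefix
    new_idx = row * (cur_seq_len + window_kv_offset) + window_kv_offset + col
    assert flat[new_idx] == mask[row, window_kv_offset + col]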
@@ -362,6 +372,7 @@ def extend_attention_fwd(
     skip_prefix_custom_mask=True,
     sliding_window_size=-1,
     sinks=None,
+    window_kv_offsets=None,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -449,6 +460,7 @@ def extend_attention_fwd(
         custom_mask,
         mask_indptr,
         sinks,
+        window_kv_offsets,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),