Fix swa eagle verify accuracy for Triton backend (#9279)
This commit is contained in:
@@ -35,6 +35,7 @@ class ForwardMetadata:
|
|||||||
window_kv_indptr: torch.Tensor
|
window_kv_indptr: torch.Tensor
|
||||||
window_kv_indices: torch.Tensor
|
window_kv_indices: torch.Tensor
|
||||||
window_num_kv_splits: torch.Tensor
|
window_num_kv_splits: torch.Tensor
|
||||||
|
window_kv_offsets: torch.Tensor
|
||||||
|
|
||||||
|
|
||||||
class TritonAttnBackend(AttentionBackend):
|
class TritonAttnBackend(AttentionBackend):
|
||||||
@@ -163,6 +164,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
window_kv_indptr = self.window_kv_indptr
|
window_kv_indptr = self.window_kv_indptr
|
||||||
window_kv_indices = None
|
window_kv_indices = None
|
||||||
window_num_kv_splits = None
|
window_num_kv_splits = None
|
||||||
|
window_kv_offsets = None
|
||||||
spec_info = forward_batch.spec_info
|
spec_info = forward_batch.spec_info
|
||||||
|
|
||||||
if forward_batch.forward_mode.is_decode_or_idle():
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
@@ -186,7 +188,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.sliding_window_size is not None
|
self.sliding_window_size is not None
|
||||||
and self.sliding_window_size > 0
|
and self.sliding_window_size > 0
|
||||||
):
|
):
|
||||||
window_kv_indptr, window_kv_indices, window_kv_lens = (
|
window_kv_indptr, window_kv_indices, window_kv_lens, _ = (
|
||||||
update_sliding_window_buffer(
|
update_sliding_window_buffer(
|
||||||
self.window_kv_indptr,
|
self.window_kv_indptr,
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
@@ -249,17 +251,21 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||||
window_kv_indptr, window_kv_indices, window_kv_lens = (
|
# window_kv_offsets holds the per-sequence start offset of the sliding window; the extend kernel adds it when indexing into the custom mask
|
||||||
update_sliding_window_buffer(
|
(
|
||||||
self.window_kv_indptr,
|
window_kv_indptr,
|
||||||
self.req_to_token,
|
window_kv_indices,
|
||||||
self.sliding_window_size,
|
window_kv_lens,
|
||||||
forward_batch.seq_lens,
|
window_kv_offsets,
|
||||||
forward_batch.req_pool_indices,
|
) = update_sliding_window_buffer(
|
||||||
bs,
|
self.window_kv_indptr,
|
||||||
self.device,
|
self.req_to_token,
|
||||||
self.token_to_kv_pool_allocator,
|
self.sliding_window_size,
|
||||||
)
|
forward_batch.seq_lens,
|
||||||
|
forward_batch.req_pool_indices,
|
||||||
|
bs,
|
||||||
|
self.device,
|
||||||
|
self.token_to_kv_pool_allocator,
|
||||||
)
|
)
|
||||||
|
|
||||||
custom_mask = spec_info.custom_mask
|
custom_mask = spec_info.custom_mask
|
||||||
@@ -312,15 +318,17 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
# Sliding window
|
# Sliding window
|
||||||
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||||
window_kv_indptr, window_kv_indices, _ = update_sliding_window_buffer(
|
window_kv_indptr, window_kv_indices, _, _ = (
|
||||||
self.window_kv_indptr,
|
update_sliding_window_buffer(
|
||||||
self.req_to_token,
|
self.window_kv_indptr,
|
||||||
self.sliding_window_size,
|
self.req_to_token,
|
||||||
forward_batch.extend_prefix_lens,
|
self.sliding_window_size,
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.extend_prefix_lens,
|
||||||
bs,
|
forward_batch.req_pool_indices,
|
||||||
self.device,
|
bs,
|
||||||
self.token_to_kv_pool_allocator,
|
self.device,
|
||||||
|
self.token_to_kv_pool_allocator,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
qo_indptr = self.qo_indptr
|
qo_indptr = self.qo_indptr
|
||||||
@@ -346,6 +354,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
window_kv_indptr,
|
window_kv_indptr,
|
||||||
window_kv_indices,
|
window_kv_indices,
|
||||||
window_num_kv_splits,
|
window_num_kv_splits,
|
||||||
|
window_kv_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_cuda_graph_state(
|
def init_cuda_graph_state(
|
||||||
@@ -400,6 +409,12 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.cuda_graph_window_kv_offsets = torch.zeros(
|
||||||
|
(max_bs,),
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
|
||||||
def init_forward_metadata_capture_cuda_graph(
|
def init_forward_metadata_capture_cuda_graph(
|
||||||
self,
|
self,
|
||||||
bs: int,
|
bs: int,
|
||||||
@@ -414,6 +429,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
window_kv_indptr = self.window_kv_indptr
|
window_kv_indptr = self.window_kv_indptr
|
||||||
window_kv_indices = None
|
window_kv_indices = None
|
||||||
window_num_kv_splits = None
|
window_num_kv_splits = None
|
||||||
|
window_kv_offsets = None
|
||||||
|
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
@@ -436,7 +452,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
):
|
):
|
||||||
window_kv_indices = self.cuda_graph_window_kv_indices
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||||
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||||
window_kv_indptr, window_kv_indices, _ = (
|
window_kv_indptr, window_kv_indices, _, _ = (
|
||||||
update_sliding_window_buffer_cuda_graph(
|
update_sliding_window_buffer_cuda_graph(
|
||||||
self.window_kv_indptr,
|
self.window_kv_indptr,
|
||||||
window_kv_indices,
|
window_kv_indices,
|
||||||
@@ -483,13 +499,14 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||||
window_kv_indices = self.cuda_graph_window_kv_indices
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||||
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||||
window_kv_indptr, window_kv_indices, _ = (
|
window_kv_offsets = self.cuda_graph_window_kv_offsets
|
||||||
|
window_kv_indptr, window_kv_indices, _, window_kv_offsets[:bs] = (
|
||||||
update_sliding_window_buffer_cuda_graph(
|
update_sliding_window_buffer_cuda_graph(
|
||||||
self.window_kv_indptr,
|
self.window_kv_indptr,
|
||||||
window_kv_indices,
|
window_kv_indices,
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
self.sliding_window_size,
|
self.sliding_window_size,
|
||||||
seq_lens,
|
seq_lens[:bs],
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
bs,
|
bs,
|
||||||
self.token_to_kv_pool_allocator,
|
self.token_to_kv_pool_allocator,
|
||||||
@@ -551,6 +568,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
window_kv_indptr,
|
window_kv_indptr,
|
||||||
window_kv_indices,
|
window_kv_indices,
|
||||||
window_num_kv_splits,
|
window_num_kv_splits,
|
||||||
|
window_kv_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
@@ -589,7 +607,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
):
|
):
|
||||||
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||||
window_kv_indices = self.cuda_graph_window_kv_indices
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||||
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
|
_, _, window_kv_lens, _ = update_sliding_window_buffer_cuda_graph(
|
||||||
self.window_kv_indptr,
|
self.window_kv_indptr,
|
||||||
window_kv_indices,
|
window_kv_indices,
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
@@ -635,15 +653,18 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||||
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
|
||||||
window_kv_indices = self.cuda_graph_window_kv_indices
|
window_kv_indices = self.cuda_graph_window_kv_indices
|
||||||
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
|
window_kv_offsets = self.cuda_graph_window_kv_offsets
|
||||||
self.window_kv_indptr,
|
_, _, window_kv_lens, window_kv_offsets[:bs] = (
|
||||||
window_kv_indices,
|
update_sliding_window_buffer_cuda_graph(
|
||||||
self.req_to_token,
|
self.window_kv_indptr,
|
||||||
self.sliding_window_size,
|
window_kv_indices,
|
||||||
seq_lens,
|
self.req_to_token,
|
||||||
req_pool_indices,
|
self.sliding_window_size,
|
||||||
bs,
|
seq_lens[:bs],
|
||||||
self.token_to_kv_pool_allocator,
|
req_pool_indices,
|
||||||
|
bs,
|
||||||
|
self.token_to_kv_pool_allocator,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
custom_mask = self.cuda_graph_custom_mask
|
custom_mask = self.cuda_graph_custom_mask
|
||||||
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
|
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
|
||||||
@@ -706,10 +727,12 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
) # Needed for sliding window mask
|
) # Needed for sliding window mask
|
||||||
kv_indptr = self.forward_metadata.window_kv_indptr
|
kv_indptr = self.forward_metadata.window_kv_indptr
|
||||||
kv_indices = self.forward_metadata.window_kv_indices
|
kv_indices = self.forward_metadata.window_kv_indices
|
||||||
|
window_kv_offsets = self.forward_metadata.window_kv_offsets
|
||||||
else:
|
else:
|
||||||
sliding_window_size = -1
|
sliding_window_size = -1
|
||||||
kv_indptr = self.forward_metadata.kv_indptr
|
kv_indptr = self.forward_metadata.kv_indptr
|
||||||
kv_indices = self.forward_metadata.kv_indices
|
kv_indices = self.forward_metadata.kv_indices
|
||||||
|
window_kv_offsets = None
|
||||||
|
|
||||||
self.extend_attention_fwd(
|
self.extend_attention_fwd(
|
||||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
@@ -729,6 +752,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
layer.logit_cap,
|
layer.logit_cap,
|
||||||
sliding_window_size=sliding_window_size,
|
sliding_window_size=sliding_window_size,
|
||||||
sinks=sinks,
|
sinks=sinks,
|
||||||
|
window_kv_offsets=window_kv_offsets,
|
||||||
)
|
)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
@@ -1011,7 +1035,7 @@ def update_sliding_window_buffer(
|
|||||||
window_kv_indices[:kv_last_index]
|
window_kv_indices[:kv_last_index]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return window_kv_indptr, window_kv_indices, window_kv_lens
|
return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
|
||||||
|
|
||||||
|
|
||||||
def update_sliding_window_buffer_cuda_graph(
|
def update_sliding_window_buffer_cuda_graph(
|
||||||
@@ -1048,4 +1072,4 @@ def update_sliding_window_buffer_cuda_graph(
|
|||||||
window_kv_indices[:kv_last_index]
|
window_kv_indices[:kv_last_index]
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
return window_kv_indptr, window_kv_indices, window_kv_lens
|
return window_kv_indptr, window_kv_indices, window_kv_lens, window_kv_start_idx
|
||||||
|
|||||||
@@ -190,7 +190,7 @@ def _decode_att_m_fwd(
|
|||||||
Lk = k_buffer.shape[-1]
|
Lk = k_buffer.shape[-1]
|
||||||
Lv = v_buffer.shape[-1]
|
Lv = v_buffer.shape[-1]
|
||||||
|
|
||||||
batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
|
batch, head_num = q.shape[0], q.shape[1]
|
||||||
|
|
||||||
grid = (batch, head_num, MAX_KV_SPLITS)
|
grid = (batch, head_num, MAX_KV_SPLITS)
|
||||||
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
||||||
@@ -433,7 +433,7 @@ def _decode_grouped_att_m_fwd(
|
|||||||
BLOCK_DPE = 0
|
BLOCK_DPE = 0
|
||||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
|
batch, head_num = q.shape[0], q.shape[1]
|
||||||
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
kv_group_num = q.shape[1] // k_buffer.shape[1]
|
||||||
|
|
||||||
BLOCK_H = 16
|
BLOCK_H = 16
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ def _fwd_kernel(
|
|||||||
mask_ptr,
|
mask_ptr,
|
||||||
mask_indptr,
|
mask_indptr,
|
||||||
sink_ptr,
|
sink_ptr,
|
||||||
|
window_kv_offset_ptr,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
kv_group_num,
|
kv_group_num,
|
||||||
stride_qbs,
|
stride_qbs,
|
||||||
@@ -95,6 +96,11 @@ def _fwd_kernel(
|
|||||||
if USE_CUSTOM_MASK:
|
if USE_CUSTOM_MASK:
|
||||||
cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
|
cur_seq_mask_start_idx = tl.load(mask_indptr + cur_seq)
|
||||||
|
|
||||||
|
# For SWA, load only the part of the custom mask that falls inside the sliding window
|
||||||
|
window_kv_offset = 0
|
||||||
|
if USE_CUSTOM_MASK and SLIDING_WINDOW_SIZE > 0:
|
||||||
|
window_kv_offset = tl.load(window_kv_offset_ptr + cur_seq)
|
||||||
|
|
||||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||||
offs_dv = tl.arange(0, BLOCK_DV)
|
offs_dv = tl.arange(0, BLOCK_DV)
|
||||||
offs_m = tl.arange(0, BLOCK_M)
|
offs_m = tl.arange(0, BLOCK_M)
|
||||||
@@ -139,7 +145,9 @@ def _fwd_kernel(
|
|||||||
custom_mask = tl.load(
|
custom_mask = tl.load(
|
||||||
mask_ptr
|
mask_ptr
|
||||||
+ cur_seq_mask_start_idx
|
+ cur_seq_mask_start_idx
|
||||||
+ (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
|
+ (cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
|
* (cur_seq_len + window_kv_offset)
|
||||||
|
+ window_kv_offset
|
||||||
+ start_n
|
+ start_n
|
||||||
+ offs_n[None, :],
|
+ offs_n[None, :],
|
||||||
mask=(mask_m[:, None] & mask_n[None, :]),
|
mask=(mask_m[:, None] & mask_n[None, :]),
|
||||||
@@ -236,7 +244,9 @@ def _fwd_kernel(
|
|||||||
custom_mask = tl.load(
|
custom_mask = tl.load(
|
||||||
mask_ptr
|
mask_ptr
|
||||||
+ cur_seq_mask_start_idx
|
+ cur_seq_mask_start_idx
|
||||||
+ (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
|
+ (cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
|
* (cur_seq_len + window_kv_offset)
|
||||||
|
+ window_kv_offset
|
||||||
+ cur_seq_len_prefix
|
+ cur_seq_len_prefix
|
||||||
+ start_n
|
+ start_n
|
||||||
+ offs_n[None, :],
|
+ offs_n[None, :],
|
||||||
@@ -362,6 +372,7 @@ def extend_attention_fwd(
|
|||||||
skip_prefix_custom_mask=True,
|
skip_prefix_custom_mask=True,
|
||||||
sliding_window_size=-1,
|
sliding_window_size=-1,
|
||||||
sinks=None,
|
sinks=None,
|
||||||
|
window_kv_offsets=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||||
@@ -449,6 +460,7 @@ def extend_attention_fwd(
|
|||||||
custom_mask,
|
custom_mask,
|
||||||
mask_indptr,
|
mask_indptr,
|
||||||
sinks,
|
sinks,
|
||||||
|
window_kv_offsets,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
kv_group_num,
|
kv_group_num,
|
||||||
q_extend.stride(0),
|
q_extend.stride(0),
|
||||||
|
|||||||
Reference in New Issue
Block a user