Support custom mask for Triton attention (#3317)
This commit is contained in:
@@ -91,6 +91,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
|
|
||||||
qo_indptr = None
|
qo_indptr = None
|
||||||
custom_mask = None
|
custom_mask = None
|
||||||
|
mask_offsets = None
|
||||||
else:
|
else:
|
||||||
kv_indptr[1 : bs + 1] = torch.cumsum(
|
kv_indptr[1 : bs + 1] = torch.cumsum(
|
||||||
forward_batch.extend_prefix_lens, dim=0
|
forward_batch.extend_prefix_lens, dim=0
|
||||||
@@ -115,6 +116,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
|
qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0)
|
||||||
qo_indptr = qo_indptr[: bs + 1]
|
qo_indptr = qo_indptr[: bs + 1]
|
||||||
custom_mask = None
|
custom_mask = None
|
||||||
|
mask_offsets = None
|
||||||
|
|
||||||
attn_logits = None
|
attn_logits = None
|
||||||
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
|
||||||
@@ -126,6 +128,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
qo_indptr,
|
qo_indptr,
|
||||||
custom_mask,
|
custom_mask,
|
||||||
|
mask_offsets,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
@@ -180,6 +183,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
None,
|
None,
|
||||||
None,
|
None,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
@@ -233,9 +237,15 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
layer, forward_batch.out_cache_loc, k, v
|
layer, forward_batch.out_cache_loc, k, v
|
||||||
)
|
)
|
||||||
|
|
||||||
_, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = (
|
(
|
||||||
self.forward_metadata
|
_,
|
||||||
)
|
max_extend_len,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
qo_indptr,
|
||||||
|
custom_mask,
|
||||||
|
mask_offsets,
|
||||||
|
) = self.forward_metadata
|
||||||
self.extend_attention_fwd(
|
self.extend_attention_fwd(
|
||||||
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
|
||||||
k.contiguous(),
|
k.contiguous(),
|
||||||
@@ -246,6 +256,8 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
qo_indptr,
|
qo_indptr,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
|
custom_mask,
|
||||||
|
mask_offsets,
|
||||||
max_extend_len,
|
max_extend_len,
|
||||||
layer.scaling,
|
layer.scaling,
|
||||||
layer.logit_cap,
|
layer.logit_cap,
|
||||||
@@ -271,7 +283,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
else:
|
else:
|
||||||
o = torch.empty_like(q)
|
o = torch.empty_like(q)
|
||||||
|
|
||||||
attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata
|
attn_logits, _, kv_indptr, kv_indices, _, _, _ = self.forward_metadata
|
||||||
|
|
||||||
if save_kv_cache:
|
if save_kv_cache:
|
||||||
forward_batch.token_to_kv_pool.set_kv_buffer(
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
||||||
|
|||||||
@@ -49,6 +49,8 @@ def _fwd_kernel(
|
|||||||
qo_indptr,
|
qo_indptr,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
|
mask_ptr,
|
||||||
|
mask_offsets,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
kv_group_num,
|
kv_group_num,
|
||||||
stride_qbs,
|
stride_qbs,
|
||||||
@@ -71,6 +73,7 @@ def _fwd_kernel(
|
|||||||
BLOCK_DV: tl.constexpr,
|
BLOCK_DV: tl.constexpr,
|
||||||
BLOCK_M: tl.constexpr,
|
BLOCK_M: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
|
USE_CUSTOM_MASK: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_seq = tl.program_id(0)
|
cur_seq = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -81,6 +84,10 @@ def _fwd_kernel(
|
|||||||
cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
|
cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx
|
||||||
cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
|
cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq)
|
||||||
cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
|
cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx
|
||||||
|
cur_seq_len = cur_seq_len_prefix + cur_seq_len_extend
|
||||||
|
|
||||||
|
if USE_CUSTOM_MASK:
|
||||||
|
cur_seq_mask_start_idx = tl.load(mask_offsets + cur_seq)
|
||||||
|
|
||||||
offs_d = tl.arange(0, BLOCK_DMODEL)
|
offs_d = tl.arange(0, BLOCK_DMODEL)
|
||||||
offs_dv = tl.arange(0, BLOCK_DV)
|
offs_dv = tl.arange(0, BLOCK_DV)
|
||||||
@@ -152,7 +159,20 @@ def _fwd_kernel(
|
|||||||
if logit_cap > 0:
|
if logit_cap > 0:
|
||||||
qk = logit_cap * tanh(qk / logit_cap)
|
qk = logit_cap * tanh(qk / logit_cap)
|
||||||
|
|
||||||
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
|
if USE_CUSTOM_MASK:
|
||||||
|
custom_mask = tl.load(
|
||||||
|
mask_ptr
|
||||||
|
+ cur_seq_mask_start_idx
|
||||||
|
+ (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
|
||||||
|
+ start_n
|
||||||
|
+ offs_n[None, :],
|
||||||
|
mask=(mask_m[:, None] & mask_n[None, :]),
|
||||||
|
other=0,
|
||||||
|
)
|
||||||
|
custom_mask &= mask_m[:, None] & mask_n[None, :]
|
||||||
|
qk = tl.where(custom_mask, qk, float("-inf"))
|
||||||
|
else:
|
||||||
|
qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf"))
|
||||||
|
|
||||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||||
re_scale = tl.exp(e_max - n_e_max)
|
re_scale = tl.exp(e_max - n_e_max)
|
||||||
@@ -172,7 +192,7 @@ def _fwd_kernel(
|
|||||||
|
|
||||||
e_max = n_e_max
|
e_max = n_e_max
|
||||||
|
|
||||||
# stage 2: compute the trianlge part
|
# stage 2: compute the triangle part
|
||||||
|
|
||||||
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
|
cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
|
||||||
for start_n in range(0, cur_block_m_end, BLOCK_N):
|
for start_n in range(0, cur_block_m_end, BLOCK_N):
|
||||||
@@ -208,11 +228,25 @@ def _fwd_kernel(
|
|||||||
if logit_cap > 0:
|
if logit_cap > 0:
|
||||||
qk = logit_cap * tanh(qk / logit_cap)
|
qk = logit_cap * tanh(qk / logit_cap)
|
||||||
|
|
||||||
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
|
if USE_CUSTOM_MASK:
|
||||||
start_n + offs_n[None, :]
|
custom_mask = tl.load(
|
||||||
)
|
mask_ptr
|
||||||
mask_causual &= mask_m[:, None] & mask_n[None, :]
|
+ cur_seq_mask_start_idx
|
||||||
qk = tl.where(mask_causual, qk, float("-inf"))
|
+ (cur_block_m * BLOCK_M + offs_m[:, None]) * cur_seq_len
|
||||||
|
+ cur_seq_len_prefix
|
||||||
|
+ start_n
|
||||||
|
+ offs_n[None, :],
|
||||||
|
mask=(mask_m[:, None] & mask_n[None, :]),
|
||||||
|
other=0,
|
||||||
|
)
|
||||||
|
custom_mask &= mask_m[:, None] & mask_n[None, :]
|
||||||
|
qk = tl.where(custom_mask, qk, float("-inf"))
|
||||||
|
else:
|
||||||
|
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
|
||||||
|
start_n + offs_n[None, :]
|
||||||
|
)
|
||||||
|
mask_causual &= mask_m[:, None] & mask_n[None, :]
|
||||||
|
qk = tl.where(mask_causual, qk, float("-inf"))
|
||||||
|
|
||||||
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
|
||||||
re_scale = tl.exp(e_max - n_e_max)
|
re_scale = tl.exp(e_max - n_e_max)
|
||||||
@@ -253,6 +287,8 @@ def extend_attention_fwd(
|
|||||||
qo_indptr,
|
qo_indptr,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
|
custom_mask,
|
||||||
|
mask_offsets,
|
||||||
max_len_extend,
|
max_len_extend,
|
||||||
sm_scale=None,
|
sm_scale=None,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
@@ -308,6 +344,8 @@ def extend_attention_fwd(
|
|||||||
batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
|
batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1]
|
||||||
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
||||||
|
|
||||||
|
USE_CUSTOM_MASK = custom_mask is not None
|
||||||
|
|
||||||
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
||||||
num_stages = 1
|
num_stages = 1
|
||||||
|
|
||||||
@@ -325,6 +363,8 @@ def extend_attention_fwd(
|
|||||||
qo_indptr,
|
qo_indptr,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
|
custom_mask,
|
||||||
|
mask_offsets,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
kv_group_num,
|
kv_group_num,
|
||||||
q_extend.stride(0),
|
q_extend.stride(0),
|
||||||
@@ -347,6 +387,7 @@ def extend_attention_fwd(
|
|||||||
BLOCK_N=BLOCK_N,
|
BLOCK_N=BLOCK_N,
|
||||||
Lq=Lq,
|
Lq=Lq,
|
||||||
Lv=Lv,
|
Lv=Lv,
|
||||||
|
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
**extra_kargs,
|
**extra_kargs,
|
||||||
|
|||||||
@@ -89,6 +89,9 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
).normal_(mean=0.1, std=0.2)
|
).normal_(mean=0.1, std=0.2)
|
||||||
|
|
||||||
o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
|
o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
|
||||||
|
o_extend_mask = torch.empty(
|
||||||
|
(extend_token_num, H_Q, D), dtype=dtype, device="cuda"
|
||||||
|
)
|
||||||
o_redundant = torch.empty(
|
o_redundant = torch.empty(
|
||||||
(extend_token_num, H_Q, D), dtype=dtype, device="cuda"
|
(extend_token_num, H_Q, D), dtype=dtype, device="cuda"
|
||||||
)
|
)
|
||||||
@@ -98,6 +101,9 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
||||||
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
|
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
|
||||||
|
|
||||||
|
custom_mask = None
|
||||||
|
mask_offsets = None
|
||||||
|
|
||||||
extend_attention_fwd(
|
extend_attention_fwd(
|
||||||
q_extend,
|
q_extend,
|
||||||
k_extend,
|
k_extend,
|
||||||
@@ -108,6 +114,42 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
qo_indptr,
|
qo_indptr,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
|
custom_mask,
|
||||||
|
mask_offsets,
|
||||||
|
max_len_extend,
|
||||||
|
)
|
||||||
|
|
||||||
|
b_seq_mask_len = b_seq_len_extend * b_seq_len
|
||||||
|
custom_mask = torch.ones(
|
||||||
|
(b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
|
||||||
|
)
|
||||||
|
mask_offsets = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
|
||||||
|
mask_offsets[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
|
||||||
|
for i in range(B):
|
||||||
|
causal_mask = (
|
||||||
|
torch.tril(
|
||||||
|
torch.ones(b_seq_len_extend[i], b_seq_len_extend[i]), diagonal=0
|
||||||
|
)
|
||||||
|
== 1
|
||||||
|
)
|
||||||
|
prefix_mask = torch.ones(
|
||||||
|
b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool
|
||||||
|
)
|
||||||
|
mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten()
|
||||||
|
custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = mask_flatten
|
||||||
|
|
||||||
|
extend_attention_fwd(
|
||||||
|
q_extend,
|
||||||
|
k_extend,
|
||||||
|
v_extend,
|
||||||
|
o_extend_mask,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
qo_indptr,
|
||||||
|
kv_indptr,
|
||||||
|
kv_indices,
|
||||||
|
custom_mask,
|
||||||
|
mask_offsets,
|
||||||
max_len_extend,
|
max_len_extend,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -124,6 +166,7 @@ class TestTritonAttention(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2))
|
self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2))
|
||||||
|
self.assertTrue(torch.allclose(o_extend_mask, o_redundant, rtol=1e-2))
|
||||||
|
|
||||||
def test_extend_attention(self):
|
def test_extend_attention(self):
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user