Optimize triton attention custom mask (#3731)
This commit is contained in:
@@ -74,6 +74,7 @@ def _fwd_kernel(
|
|||||||
BLOCK_M: tl.constexpr,
|
BLOCK_M: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
USE_CUSTOM_MASK: tl.constexpr,
|
USE_CUSTOM_MASK: tl.constexpr,
|
||||||
|
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
|
||||||
STORE_TRANSPOSE: tl.constexpr,
|
STORE_TRANSPOSE: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_seq = tl.program_id(0)
|
cur_seq = tl.program_id(0)
|
||||||
@@ -160,7 +161,7 @@ def _fwd_kernel(
|
|||||||
if logit_cap > 0:
|
if logit_cap > 0:
|
||||||
qk = logit_cap * tanh(qk / logit_cap)
|
qk = logit_cap * tanh(qk / logit_cap)
|
||||||
|
|
||||||
if USE_CUSTOM_MASK:
|
if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
|
||||||
custom_mask = tl.load(
|
custom_mask = tl.load(
|
||||||
mask_ptr
|
mask_ptr
|
||||||
+ cur_seq_mask_start_idx
|
+ cur_seq_mask_start_idx
|
||||||
@@ -302,6 +303,7 @@ def extend_attention_fwd(
|
|||||||
max_len_extend,
|
max_len_extend,
|
||||||
sm_scale=None,
|
sm_scale=None,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
|
skip_prefix_custom_mask=True,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||||
@@ -355,6 +357,8 @@ def extend_attention_fwd(
|
|||||||
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
||||||
|
|
||||||
USE_CUSTOM_MASK = custom_mask is not None
|
USE_CUSTOM_MASK = custom_mask is not None
|
||||||
|
# Skip custom mask for prefix part
|
||||||
|
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
|
||||||
|
|
||||||
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
||||||
num_stages = 1
|
num_stages = 1
|
||||||
@@ -398,6 +402,7 @@ def extend_attention_fwd(
|
|||||||
Lq=Lq,
|
Lq=Lq,
|
||||||
Lv=Lv,
|
Lv=Lv,
|
||||||
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
||||||
|
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
|
||||||
STORE_TRANSPOSE=is_hip_,
|
STORE_TRANSPOSE=is_hip_,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
|
|||||||
Reference in New Issue
Block a user