Optimize triton attention custom mask (#3731)
This commit is contained in:
@@ -74,6 +74,7 @@ def _fwd_kernel(
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
USE_CUSTOM_MASK: tl.constexpr,
|
||||
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
|
||||
STORE_TRANSPOSE: tl.constexpr,
|
||||
):
|
||||
cur_seq = tl.program_id(0)
|
||||
@@ -160,7 +161,7 @@ def _fwd_kernel(
|
||||
if logit_cap > 0:
|
||||
qk = logit_cap * tanh(qk / logit_cap)
|
||||
|
||||
if USE_CUSTOM_MASK:
|
||||
if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
|
||||
custom_mask = tl.load(
|
||||
mask_ptr
|
||||
+ cur_seq_mask_start_idx
|
||||
@@ -302,6 +303,7 @@ def extend_attention_fwd(
|
||||
max_len_extend,
|
||||
sm_scale=None,
|
||||
logit_cap=0.0,
|
||||
skip_prefix_custom_mask=True,
|
||||
):
|
||||
"""
|
||||
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||
@@ -355,6 +357,8 @@ def extend_attention_fwd(
|
||||
kv_group_num = q_extend.shape[1] // k_extend.shape[1]
|
||||
|
||||
USE_CUSTOM_MASK = custom_mask is not None
|
||||
# Skip custom mask for prefix part
|
||||
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
|
||||
|
||||
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
||||
num_stages = 1
|
||||
@@ -398,6 +402,7 @@ def extend_attention_fwd(
|
||||
Lq=Lq,
|
||||
Lv=Lv,
|
||||
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
||||
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
|
||||
STORE_TRANSPOSE=is_hip_,
|
||||
num_warps=num_warps,
|
||||
num_stages=num_stages,
|
||||
|
||||
Reference in New Issue
Block a user