diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
index 608f9bab0..079c8cfd9 100644
--- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py
@@ -74,6 +74,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     USE_CUSTOM_MASK: tl.constexpr,
+    SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
@@ -160,7 +161,7 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
-        if USE_CUSTOM_MASK:
+        if USE_CUSTOM_MASK and not SKIP_PREFIX_CUSTOM_MASK:
             custom_mask = tl.load(
                 mask_ptr
                 + cur_seq_mask_start_idx
@@ -302,6 +303,7 @@ def extend_attention_fwd(
     max_len_extend,
     sm_scale=None,
     logit_cap=0.0,
+    skip_prefix_custom_mask=True,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -355,6 +357,8 @@ def extend_attention_fwd(
     kv_group_num = q_extend.shape[1] // k_extend.shape[1]
 
     USE_CUSTOM_MASK = custom_mask is not None
+    # Skip custom mask for prefix part
+    SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
 
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1
@@ -398,6 +402,7 @@ def extend_attention_fwd(
         Lq=Lq,
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
        STORE_TRANSPOSE=is_hip_,
         num_warps=num_warps,
         num_stages=num_stages,
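For context, the change above skips loading and applying the custom mask over the shared-prefix columns when SKIP_PREFIX_CUSTOM_MASK is set, so the prefix is treated as fully visible and the custom mask is only consulted for the extend-vs-extend scores. Below is a minimal, non-Triton NumPy sketch of that behavior; the names (reference_extend_scores, scores, custom_mask, prefix_len) are illustrative stand-ins, not the kernel's actual buffers or API.

import numpy as np

def reference_extend_scores(scores, custom_mask, prefix_len, skip_prefix_custom_mask=True):
    """Illustrative sketch of the masking toggled by SKIP_PREFIX_CUSTOM_MASK.

    scores:      (extend_len, prefix_len + extend_len) attention scores for one sequence/head
    custom_mask: boolean array of the same shape (True = keep, False = mask out)
    """
    masked = scores.copy()

    if not skip_prefix_custom_mask:
        # Original behavior: the custom mask is also applied to the prefix columns.
        masked[:, :prefix_len] = np.where(
            custom_mask[:, :prefix_len], masked[:, :prefix_len], -np.inf
        )
    # With skip_prefix_custom_mask=True, the prefix columns stay unmasked:
    # every extend-token query attends to the whole shared prefix, so the
    # per-row mask loads for that region can be skipped in the kernel.

    # The extend-vs-extend portion is masked either way.
    masked[:, prefix_len:] = np.where(
        custom_mask[:, prefix_len:], masked[:, prefix_len:], -np.inf
    )
    return masked

This sketch assumes the common case (e.g. speculative-decoding verification) where the custom mask is all ones over the prefix, which is what makes skipping it safe by default.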