Refine naming (#8868)
@@ -686,7 +686,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
-        sk=None,
+        sinks=None,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -731,7 +731,7 @@ class TritonAttnBackend(AttentionBackend):
             layer.scaling,
             layer.logit_cap,
             sliding_window_size=sliding_window_size,
-            sk=sk,
+            sinks=sinks,
         )
         return o

@@ -743,7 +743,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
-        sk=None,
+        sinks=None,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -780,7 +780,7 @@ class TritonAttnBackend(AttentionBackend):
             self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
-            sk=sk,
+            sinks=sinks,
         )
         return o
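For anything calling into this backend, the visible API change is just the keyword: `sk=` becomes `sinks=` on both forward_extend and forward_decode. A minimal, hypothetical call-site sketch follows (the positional q/k/v arguments and the per_head_sinks tensor are assumptions, not shown in these hunks):

    # Hypothetical caller after the rename; only the keyword changed.
    o = backend.forward_extend(
        q, k, v,               # query/key/value tensors (assumed positional args)
        layer,                 # RadixAttention layer
        forward_batch,         # ForwardBatch for this step
        save_kv_cache=True,
        sinks=per_head_sinks,  # was `sk=per_head_sinks` before this commit
    )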
@@ -495,7 +495,7 @@ def _fwd_kernel_stage2(
     O,
     kv_indptr,
     num_kv_splits,
-    sk_ptr,
+    sink_ptr,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -505,7 +505,7 @@ def _fwd_kernel_stage2(
     MIN_BLOCK_KV: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
-    HAS_SK: tl.constexpr,
+    HAS_SINK: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -547,9 +547,9 @@ def _fwd_kernel_stage2(
             e_sum = e_sum * old_scale + exp_logic
             e_max = n_e_max

-    if HAS_SK:
-        cur_sk = tl.load(sk_ptr + cur_head)
-        e_sum += tl.exp(cur_sk - e_max)
+    if HAS_SINK:
+        cur_sink = tl.load(sink_ptr + cur_head)
+        e_sum += tl.exp(cur_sink - e_max)

     tl.store(
         O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
@@ -567,14 +567,14 @@ def _decode_softmax_reducev_fwd(
     kv_indptr,
     num_kv_splits,
     max_kv_splits,
-    sk=None,
+    sinks=None,
 ):
     batch, head_num = q.shape[0], q.shape[1]
     Lv = v_buffer.shape[-1]
     BLOCK_DV = triton.next_power_of_2(Lv)

     MAX_KV_SPLITS = max_kv_splits
-    HAS_SK = sk is not None
+    HAS_SINK = sinks is not None

     extra_kargs = {}
     if _is_hip:
@@ -589,7 +589,7 @@ def _decode_softmax_reducev_fwd(
         o,
         kv_indptr,
         num_kv_splits,
-        sk,
+        sinks,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -599,7 +599,7 @@ def _decode_softmax_reducev_fwd(
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
-        HAS_SK=HAS_SK,
+        HAS_SINK=HAS_SINK,
         num_warps=4,
         num_stages=2,
         **extra_kargs,
@@ -619,7 +619,7 @@ def decode_attention_fwd_normal(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
-    sk=None,
+    sinks=None,
 ):
     _decode_att_m_fwd(
         q,
@@ -643,7 +643,7 @@ def decode_attention_fwd_normal(
         kv_indptr,
         num_kv_splits,
         max_kv_splits,
-        sk,
+        sinks,
     )


@@ -660,7 +660,7 @@ def decode_attention_fwd_grouped(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
-    sk=None,
+    sinks=None,
 ):
     _decode_grouped_att_m_fwd(
         q,
@@ -684,7 +684,7 @@ def decode_attention_fwd_grouped(
         kv_indptr,
         num_kv_splits,
         max_kv_splits,
-        sk,
+        sinks,
     )


@@ -701,7 +701,7 @@ def decode_attention_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
-    sk=None,
+    sinks=None,
 ):
     assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1
@@ -724,7 +724,7 @@ def decode_attention_fwd(
             max_kv_splits,
             sm_scale,
             logit_cap=logit_cap,
-            sk=sk,
+            sinks=sinks,
         )
     else:
         # GQA/MQA/MLA
@@ -741,5 +741,5 @@ def decode_attention_fwd(
             max_kv_splits,
             sm_scale,
             logit_cap=logit_cap,
-            sk=sk,
+            sinks=sinks,
         )
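The stage-2 reduce kernel above folds the per-head sink into the softmax denominator only: after the split-KV loop it adds exp(sink - e_max) to e_sum, while the value accumulator is left untouched, so the sink absorbs probability mass without contributing a value. A small PyTorch reference of that behavior for a single decode step (a sketch under assumed tensor shapes, not the Triton kernel itself):

    import torch

    def decode_attn_with_sink(q, k, v, sink, sm_scale):
        # q: (H, D); k, v: (T, H, D); sink: (H,) per-head sink logit.
        logits = torch.einsum("hd,thd->ht", q, k) * sm_scale  # (H, T)
        e_max = logits.max(dim=-1).values                      # max over real keys only, as in the kernel
        p = torch.exp(logits - e_max[:, None])                 # (H, T)
        denom = p.sum(dim=-1) + torch.exp(sink - e_max)        # sink enters the denominator
        acc = torch.einsum("ht,thd->hd", p, v)                 # but adds nothing to the numerator
        return acc / denom[:, None]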
@@ -51,7 +51,7 @@ def _fwd_kernel(
     kv_indices,
     mask_ptr,
     mask_indptr,
-    sk_ptr,
+    sink_ptr,
     sm_scale,
     kv_group_num,
     stride_qbs,
@@ -79,7 +79,7 @@ def _fwd_kernel(
     IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
-    HAS_SK: tl.constexpr,
+    HAS_SINK: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)
@@ -302,9 +302,9 @@ def _fwd_kernel(

         e_max = n_e_max

-    if HAS_SK:
-        cur_sk = tl.load(sk_ptr + cur_head)
-        deno += tl.exp(cur_sk - e_max)
+    if HAS_SINK:
+        cur_sink = tl.load(sink_ptr + cur_head)
+        deno += tl.exp(cur_sink - e_max)

     offs_o = (
         (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
@@ -344,7 +344,7 @@ def extend_attention_fwd(
     logit_cap=0.0,
     skip_prefix_custom_mask=True,
     sliding_window_size=-1,
-    sk=None,
+    sinks=None,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors
@@ -410,7 +410,7 @@ def extend_attention_fwd(
     # Skip custom mask for prefix part
     SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

-    HAS_SK = sk is not None
+    HAS_SINK = sinks is not None

     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1
@@ -431,7 +431,7 @@ def extend_attention_fwd(
         kv_indices,
         custom_mask,
         mask_indptr,
-        sk,
+        sinks,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),
@@ -458,7 +458,7 @@ def extend_attention_fwd(
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
-        HAS_SK=HAS_SK,
+        HAS_SINK=HAS_SINK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
         num_stages=num_stages,
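Both kernels add the sink term after the streaming softmax loop, against the final running max. Because the running rescaling keeps every partial sum expressed relative to that max, one extra exp(sink - e_max) in the denominator is equivalent to having treated the sink as one more value-less key from the start. A small self-check of that equivalence (plain Python, illustrative values only):

    import math

    def online_denominator(logits, sink):
        # Streaming pass over the real logits, as the kernels do.
        e_max, e_sum = float("-inf"), 0.0
        for x in logits:
            n_e_max = max(x, e_max)
            e_sum = e_sum * math.exp(e_max - n_e_max) + math.exp(x - n_e_max)
            e_max = n_e_max
        # Sink folded in once, after the loop (the HAS_SINK branch).
        return e_sum + math.exp(sink - e_max), e_max

    logits, sink = [0.3, -1.2, 2.5, 0.7], 1.1
    stream_sum, e_max = online_denominator(logits, sink)
    direct_sum = sum(math.exp(x - e_max) for x in logits + [sink])
    assert abs(stream_sum - direct_sum) < 1e-12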
@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
         hidden_states, forward_batch, inner_state = intermediate_state
         if inner_state is None:
             return hidden_states
-        attn_output = self.attn(*inner_state, sk=self.sinks)
+        attn_output = self.attn(*inner_state, sinks=self.sinks)
         output, _ = self.o_proj(attn_output)
         return output
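On the model side, only the keyword changes when GptOssAttention hands its per-head sink logits to the attention layer. A hedged sketch of how such a module might hold and forward them (the parameter shape, initialization, and forward signature here are assumptions, not taken from this diff):

    import torch
    from torch import nn

    class AttentionWithSinks(nn.Module):
        # Illustrative stand-in, not the actual GptOssAttention class.
        def __init__(self, attn, num_heads):
            super().__init__()
            self.attn = attn                                   # e.g., a RadixAttention instance
            self.sinks = nn.Parameter(torch.zeros(num_heads))  # one learnable sink logit per head

        def forward(self, q, k, v, forward_batch):
            # Renamed keyword: `sinks=` instead of the old `sk=`.
            return self.attn(q, k, v, forward_batch, sinks=self.sinks)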