Refine naming (#8868)
This commit is contained in:
@@ -686,7 +686,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
sk=None,
|
sinks=None,
|
||||||
):
|
):
|
||||||
# TODO: reuse the buffer across layers
|
# TODO: reuse the buffer across layers
|
||||||
if layer.qk_head_dim != layer.v_head_dim:
|
if layer.qk_head_dim != layer.v_head_dim:
|
||||||
@@ -731,7 +731,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
layer.scaling,
|
layer.scaling,
|
||||||
layer.logit_cap,
|
layer.logit_cap,
|
||||||
sliding_window_size=sliding_window_size,
|
sliding_window_size=sliding_window_size,
|
||||||
sk=sk,
|
sinks=sinks,
|
||||||
)
|
)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
@@ -743,7 +743,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
forward_batch: ForwardBatch,
|
forward_batch: ForwardBatch,
|
||||||
save_kv_cache=True,
|
save_kv_cache=True,
|
||||||
sk=None,
|
sinks=None,
|
||||||
):
|
):
|
||||||
# During torch.compile, there is a bug in rotary_emb that causes the
|
# During torch.compile, there is a bug in rotary_emb that causes the
|
||||||
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
# output value to have a 3D tensor shape. This reshapes the output correctly.
|
||||||
@@ -780,7 +780,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
self.max_kv_splits,
|
self.max_kv_splits,
|
||||||
layer.scaling,
|
layer.scaling,
|
||||||
layer.logit_cap,
|
layer.logit_cap,
|
||||||
sk=sk,
|
sinks=sinks,
|
||||||
)
|
)
|
||||||
return o
|
return o
|
||||||
|
|
||||||
|
|||||||
@@ -495,7 +495,7 @@ def _fwd_kernel_stage2(
|
|||||||
O,
|
O,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sk_ptr,
|
sink_ptr,
|
||||||
stride_mid_ob,
|
stride_mid_ob,
|
||||||
stride_mid_oh,
|
stride_mid_oh,
|
||||||
stride_mid_os,
|
stride_mid_os,
|
||||||
@@ -505,7 +505,7 @@ def _fwd_kernel_stage2(
|
|||||||
MIN_BLOCK_KV: tl.constexpr,
|
MIN_BLOCK_KV: tl.constexpr,
|
||||||
BLOCK_DV: tl.constexpr,
|
BLOCK_DV: tl.constexpr,
|
||||||
Lv: tl.constexpr,
|
Lv: tl.constexpr,
|
||||||
HAS_SK: tl.constexpr,
|
HAS_SINK: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -547,9 +547,9 @@ def _fwd_kernel_stage2(
|
|||||||
e_sum = e_sum * old_scale + exp_logic
|
e_sum = e_sum * old_scale + exp_logic
|
||||||
e_max = n_e_max
|
e_max = n_e_max
|
||||||
|
|
||||||
if HAS_SK:
|
if HAS_SINK:
|
||||||
cur_sk = tl.load(sk_ptr + cur_head)
|
cur_sink = tl.load(sink_ptr + cur_head)
|
||||||
e_sum += tl.exp(cur_sk - e_max)
|
e_sum += tl.exp(cur_sink - e_max)
|
||||||
|
|
||||||
tl.store(
|
tl.store(
|
||||||
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
|
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
|
||||||
@@ -567,14 +567,14 @@ def _decode_softmax_reducev_fwd(
|
|||||||
kv_indptr,
|
kv_indptr,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sk=None,
|
sinks=None,
|
||||||
):
|
):
|
||||||
batch, head_num = q.shape[0], q.shape[1]
|
batch, head_num = q.shape[0], q.shape[1]
|
||||||
Lv = v_buffer.shape[-1]
|
Lv = v_buffer.shape[-1]
|
||||||
BLOCK_DV = triton.next_power_of_2(Lv)
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
MAX_KV_SPLITS = max_kv_splits
|
MAX_KV_SPLITS = max_kv_splits
|
||||||
HAS_SK = sk is not None
|
HAS_SINK = sinks is not None
|
||||||
|
|
||||||
extra_kargs = {}
|
extra_kargs = {}
|
||||||
if _is_hip:
|
if _is_hip:
|
||||||
@@ -589,7 +589,7 @@ def _decode_softmax_reducev_fwd(
|
|||||||
o,
|
o,
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
sk,
|
sinks,
|
||||||
logits.stride(0),
|
logits.stride(0),
|
||||||
logits.stride(1),
|
logits.stride(1),
|
||||||
logits.stride(2),
|
logits.stride(2),
|
||||||
@@ -599,7 +599,7 @@ def _decode_softmax_reducev_fwd(
|
|||||||
MIN_BLOCK_KV=_MIN_BLOCK_KV,
|
MIN_BLOCK_KV=_MIN_BLOCK_KV,
|
||||||
BLOCK_DV=BLOCK_DV,
|
BLOCK_DV=BLOCK_DV,
|
||||||
Lv=Lv,
|
Lv=Lv,
|
||||||
HAS_SK=HAS_SK,
|
HAS_SINK=HAS_SINK,
|
||||||
num_warps=4,
|
num_warps=4,
|
||||||
num_stages=2,
|
num_stages=2,
|
||||||
**extra_kargs,
|
**extra_kargs,
|
||||||
@@ -619,7 +619,7 @@ def decode_attention_fwd_normal(
|
|||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
sk=None,
|
sinks=None,
|
||||||
):
|
):
|
||||||
_decode_att_m_fwd(
|
_decode_att_m_fwd(
|
||||||
q,
|
q,
|
||||||
@@ -643,7 +643,7 @@ def decode_attention_fwd_normal(
|
|||||||
kv_indptr,
|
kv_indptr,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sk,
|
sinks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -660,7 +660,7 @@ def decode_attention_fwd_grouped(
|
|||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
sk=None,
|
sinks=None,
|
||||||
):
|
):
|
||||||
_decode_grouped_att_m_fwd(
|
_decode_grouped_att_m_fwd(
|
||||||
q,
|
q,
|
||||||
@@ -684,7 +684,7 @@ def decode_attention_fwd_grouped(
|
|||||||
kv_indptr,
|
kv_indptr,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sk,
|
sinks,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -701,7 +701,7 @@ def decode_attention_fwd(
|
|||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
sk=None,
|
sinks=None,
|
||||||
):
|
):
|
||||||
assert max_kv_splits == attn_logits.shape[2]
|
assert max_kv_splits == attn_logits.shape[2]
|
||||||
assert q.shape[0] <= kv_indptr.shape[0] - 1
|
assert q.shape[0] <= kv_indptr.shape[0] - 1
|
||||||
@@ -724,7 +724,7 @@ def decode_attention_fwd(
|
|||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
sk=sk,
|
sinks=sinks,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# GQA/MQA/MLA
|
# GQA/MQA/MLA
|
||||||
@@ -741,5 +741,5 @@ def decode_attention_fwd(
|
|||||||
max_kv_splits,
|
max_kv_splits,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
sk=sk,
|
sinks=sinks,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ def _fwd_kernel(
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
mask_ptr,
|
mask_ptr,
|
||||||
mask_indptr,
|
mask_indptr,
|
||||||
sk_ptr,
|
sink_ptr,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
kv_group_num,
|
kv_group_num,
|
||||||
stride_qbs,
|
stride_qbs,
|
||||||
@@ -79,7 +79,7 @@ def _fwd_kernel(
|
|||||||
IS_CAUSAL: tl.constexpr,
|
IS_CAUSAL: tl.constexpr,
|
||||||
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
|
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
|
||||||
STORE_TRANSPOSE: tl.constexpr,
|
STORE_TRANSPOSE: tl.constexpr,
|
||||||
HAS_SK: tl.constexpr,
|
HAS_SINK: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_seq = tl.program_id(0)
|
cur_seq = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -302,9 +302,9 @@ def _fwd_kernel(
|
|||||||
|
|
||||||
e_max = n_e_max
|
e_max = n_e_max
|
||||||
|
|
||||||
if HAS_SK:
|
if HAS_SINK:
|
||||||
cur_sk = tl.load(sk_ptr + cur_head)
|
cur_sink = tl.load(sink_ptr + cur_head)
|
||||||
deno += tl.exp(cur_sk - e_max)
|
deno += tl.exp(cur_sink - e_max)
|
||||||
|
|
||||||
offs_o = (
|
offs_o = (
|
||||||
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
@@ -344,7 +344,7 @@ def extend_attention_fwd(
|
|||||||
logit_cap=0.0,
|
logit_cap=0.0,
|
||||||
skip_prefix_custom_mask=True,
|
skip_prefix_custom_mask=True,
|
||||||
sliding_window_size=-1,
|
sliding_window_size=-1,
|
||||||
sk=None,
|
sinks=None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||||
@@ -410,7 +410,7 @@ def extend_attention_fwd(
|
|||||||
# Skip custom mask for prefix part
|
# Skip custom mask for prefix part
|
||||||
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
|
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
|
||||||
|
|
||||||
HAS_SK = sk is not None
|
HAS_SINK = sinks is not None
|
||||||
|
|
||||||
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
|
||||||
num_stages = 1
|
num_stages = 1
|
||||||
@@ -431,7 +431,7 @@ def extend_attention_fwd(
|
|||||||
kv_indices,
|
kv_indices,
|
||||||
custom_mask,
|
custom_mask,
|
||||||
mask_indptr,
|
mask_indptr,
|
||||||
sk,
|
sinks,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
kv_group_num,
|
kv_group_num,
|
||||||
q_extend.stride(0),
|
q_extend.stride(0),
|
||||||
@@ -458,7 +458,7 @@ def extend_attention_fwd(
|
|||||||
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
|
||||||
IS_CAUSAL=is_causal,
|
IS_CAUSAL=is_causal,
|
||||||
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
|
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
|
||||||
HAS_SK=HAS_SK,
|
HAS_SINK=HAS_SINK,
|
||||||
STORE_TRANSPOSE=_is_hip,
|
STORE_TRANSPOSE=_is_hip,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
|
|||||||
@@ -301,7 +301,7 @@ class GptOssAttention(nn.Module):
|
|||||||
hidden_states, forward_batch, inner_state = intermediate_state
|
hidden_states, forward_batch, inner_state = intermediate_state
|
||||||
if inner_state is None:
|
if inner_state is None:
|
||||||
return hidden_states
|
return hidden_states
|
||||||
attn_output = self.attn(*inner_state, sk=self.sinks)
|
attn_output = self.attn(*inner_state, sinks=self.sinks)
|
||||||
output, _ = self.o_proj(attn_output)
|
output, _ = self.o_proj(attn_output)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user