[triton] Support head_dim not 2^n in triton extend and decode attention (#1281)
This commit is contained in:
@@ -60,6 +60,7 @@ def _fwd_kernel_stage1(
|
|||||||
BLOCK_DMODEL: tl.constexpr,
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
logit_cap: tl.constexpr,
|
logit_cap: tl.constexpr,
|
||||||
|
Lk: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -97,7 +98,7 @@ def _fwd_kernel_stage1(
|
|||||||
)
|
)
|
||||||
k = tl.load(
|
k = tl.load(
|
||||||
K_Buffer + offs_buf_k,
|
K_Buffer + offs_buf_k,
|
||||||
mask=offs_n_new[:, None] < cur_batch_end_index,
|
mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk),
|
||||||
other=0.0,
|
other=0.0,
|
||||||
).to(REDUCE_TRITON_TYPE)
|
).to(REDUCE_TRITON_TYPE)
|
||||||
att_value = tl.sum(q[None, :] * k, 1)
|
att_value = tl.sum(q[None, :] * k, 1)
|
||||||
@@ -128,6 +129,7 @@ def _fwd_kernel_stage2(
|
|||||||
kv_group_num: tl.constexpr,
|
kv_group_num: tl.constexpr,
|
||||||
BLOCK_DMODEL: tl.constexpr,
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
|
Lv: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -170,14 +172,16 @@ def _fwd_kernel_stage2(
|
|||||||
old_scale = tl.exp(e_max - n_e_max)
|
old_scale = tl.exp(e_max - n_e_max)
|
||||||
p = tl.exp(qk - n_e_max)
|
p = tl.exp(qk - n_e_max)
|
||||||
e_sum = e_sum * old_scale + tl.sum(p, 0)
|
e_sum = e_sum * old_scale + tl.sum(p, 0)
|
||||||
v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
|
v = tl.load(
|
||||||
|
v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
|
||||||
|
)
|
||||||
acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
|
acc = acc * old_scale + tl.sum(p[:, None] * v, 0)
|
||||||
e_max = n_e_max
|
e_max = n_e_max
|
||||||
|
|
||||||
acc = acc / e_sum
|
acc = acc / e_sum
|
||||||
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d
|
off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d
|
||||||
out_ptrs = Out + off_o
|
out_ptrs = Out + off_o
|
||||||
tl.store(out_ptrs, acc)
|
tl.store(out_ptrs, acc, mask=(offs_d < Lv))
|
||||||
|
|
||||||
|
|
||||||
def _decode_att_m_fwd(
|
def _decode_att_m_fwd(
|
||||||
@@ -196,7 +200,7 @@ def _decode_att_m_fwd(
|
|||||||
# shape constraints
|
# shape constraints
|
||||||
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
||||||
assert Lq == Lk
|
assert Lq == Lk
|
||||||
assert Lk in {16, 32, 64, 128, 256}
|
assert Lk in {16, 32, 64, 96, 128, 256}
|
||||||
|
|
||||||
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
||||||
|
|
||||||
@@ -208,6 +212,8 @@ def _decode_att_m_fwd(
|
|||||||
else:
|
else:
|
||||||
num_warps = 2
|
num_warps = 2
|
||||||
|
|
||||||
|
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
||||||
|
|
||||||
_fwd_kernel_stage1[grid](
|
_fwd_kernel_stage1[grid](
|
||||||
q,
|
q,
|
||||||
k_buffer,
|
k_buffer,
|
||||||
@@ -224,11 +230,12 @@ def _decode_att_m_fwd(
|
|||||||
k_buffer.stride(1),
|
k_buffer.stride(1),
|
||||||
att_out.stride(0),
|
att_out.stride(0),
|
||||||
kv_group_num=kv_group_num,
|
kv_group_num=kv_group_num,
|
||||||
BLOCK_DMODEL=Lk,
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||||
BLOCK_N=BLOCK,
|
BLOCK_N=BLOCK,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
|
Lk=Lk,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -248,6 +255,9 @@ def _decode_softmax_reducev_fwd(
|
|||||||
|
|
||||||
num_warps = 1
|
num_warps = 1
|
||||||
|
|
||||||
|
Lv = v_buffer.shape[-1]
|
||||||
|
BLOCK_DMODEL = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
_fwd_kernel_stage2[grid](
|
_fwd_kernel_stage2[grid](
|
||||||
logics,
|
logics,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
@@ -263,10 +273,11 @@ def _decode_softmax_reducev_fwd(
|
|||||||
o.stride(1),
|
o.stride(1),
|
||||||
req_to_tokens.stride(0),
|
req_to_tokens.stride(0),
|
||||||
kv_group_num=kv_group_num,
|
kv_group_num=kv_group_num,
|
||||||
BLOCK_DMODEL=v_buffer.shape[-1],
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||||
BLOCK_N=BLOCK,
|
BLOCK_N=BLOCK,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=3,
|
num_stages=3,
|
||||||
|
Lv=Lv,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -293,6 +304,7 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
BLOCK_H: tl.constexpr,
|
BLOCK_H: tl.constexpr,
|
||||||
logit_cap: tl.constexpr,
|
logit_cap: tl.constexpr,
|
||||||
|
Lk: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
cur_kv_head = tl.program_id(1)
|
cur_kv_head = tl.program_id(1)
|
||||||
@@ -324,9 +336,9 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
|
block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)
|
||||||
|
|
||||||
for start_mark in range(0, block_mask, 1):
|
for start_mark in range(0, block_mask, 1):
|
||||||
q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to(
|
q = tl.load(
|
||||||
REDUCE_TRITON_TYPE
|
Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk)
|
||||||
)
|
).to(REDUCE_TRITON_TYPE)
|
||||||
offs_n_new = cur_batch_start_index + offs_n
|
offs_n_new = cur_batch_start_index + offs_n
|
||||||
k_loc = tl.load(
|
k_loc = tl.load(
|
||||||
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
|
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
|
||||||
@@ -340,7 +352,7 @@ def _fwd_grouped_kernel_stage1(
|
|||||||
)
|
)
|
||||||
k = tl.load(
|
k = tl.load(
|
||||||
K_Buffer + offs_buf_k,
|
K_Buffer + offs_buf_k,
|
||||||
mask=offs_n_new[None, :] < cur_batch_end_index,
|
mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk),
|
||||||
other=0.0,
|
other=0.0,
|
||||||
).to(REDUCE_TRITON_TYPE)
|
).to(REDUCE_TRITON_TYPE)
|
||||||
qk = tl.dot(q, k)
|
qk = tl.dot(q, k)
|
||||||
@@ -395,6 +407,7 @@ def _fwd_grouped_kernel_stage2(
|
|||||||
BLOCK_DMODEL: tl.constexpr,
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
BLOCK_H: tl.constexpr,
|
BLOCK_H: tl.constexpr,
|
||||||
|
Lv: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
cur_kv_head = tl.program_id(1)
|
cur_kv_head = tl.program_id(1)
|
||||||
@@ -441,7 +454,9 @@ def _fwd_grouped_kernel_stage2(
|
|||||||
old_scale = tl.exp(e_max - n_e_max)
|
old_scale = tl.exp(e_max - n_e_max)
|
||||||
p = tl.exp(qk - n_e_max[:, None])
|
p = tl.exp(qk - n_e_max[:, None])
|
||||||
e_sum = e_sum * old_scale + tl.sum(p, 1)
|
e_sum = e_sum * old_scale + tl.sum(p, 1)
|
||||||
v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs)
|
v = tl.load(
|
||||||
|
v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv)
|
||||||
|
)
|
||||||
p = p.to(v.dtype)
|
p = p.to(v.dtype)
|
||||||
acc = acc * old_scale[:, None] + tl.dot(p, v)
|
acc = acc * old_scale[:, None] + tl.dot(p, v)
|
||||||
e_max = n_e_max
|
e_max = n_e_max
|
||||||
@@ -449,7 +464,7 @@ def _fwd_grouped_kernel_stage2(
|
|||||||
acc = acc / e_sum[:, None]
|
acc = acc / e_sum[:, None]
|
||||||
off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
|
off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :]
|
||||||
out_ptrs = Out + off_o
|
out_ptrs = Out + off_o
|
||||||
tl.store(out_ptrs, acc, mask=mask_h[:, None])
|
tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv))
|
||||||
|
|
||||||
|
|
||||||
def _decode_grouped_att_m_fwd(
|
def _decode_grouped_att_m_fwd(
|
||||||
@@ -468,13 +483,13 @@ def _decode_grouped_att_m_fwd(
|
|||||||
# shape constraints
|
# shape constraints
|
||||||
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
|
||||||
assert Lq == Lk
|
assert Lq == Lk
|
||||||
assert Lk in {16, 32, 64, 128, 256, 576}
|
assert Lk in {16, 32, 64, 96, 128, 256, 576}
|
||||||
|
|
||||||
if Lk == 576:
|
if Lk == 576:
|
||||||
BLOCK_DMODEL = 512
|
BLOCK_DMODEL = 512
|
||||||
BLOCK_DPE = 64
|
BLOCK_DPE = 64
|
||||||
else:
|
else:
|
||||||
BLOCK_DMODEL = Lk
|
BLOCK_DMODEL = triton.next_power_of_2(Lk)
|
||||||
BLOCK_DPE = 0
|
BLOCK_DPE = 0
|
||||||
|
|
||||||
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
batch, head_num = B_req_idx.shape[0], q.shape[1]
|
||||||
@@ -513,6 +528,7 @@ def _decode_grouped_att_m_fwd(
|
|||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
|
Lk=Lk,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -533,6 +549,9 @@ def _decode_grouped_softmax_reducev_fwd(
|
|||||||
|
|
||||||
num_warps = 8
|
num_warps = 8
|
||||||
|
|
||||||
|
Lv = v_buffer.shape[-1]
|
||||||
|
BLOCK_DMODEL = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
_fwd_grouped_kernel_stage2[grid](
|
_fwd_grouped_kernel_stage2[grid](
|
||||||
logics,
|
logics,
|
||||||
v_buffer,
|
v_buffer,
|
||||||
@@ -549,11 +568,12 @@ def _decode_grouped_softmax_reducev_fwd(
|
|||||||
req_to_tokens.stride(0),
|
req_to_tokens.stride(0),
|
||||||
kv_group_num=kv_group_num,
|
kv_group_num=kv_group_num,
|
||||||
q_head_num=head_num,
|
q_head_num=head_num,
|
||||||
BLOCK_DMODEL=v_buffer.shape[-1],
|
BLOCK_DMODEL=BLOCK_DMODEL,
|
||||||
BLOCK_N=BLOCK,
|
BLOCK_N=BLOCK,
|
||||||
BLOCK_H=BLOCK_H,
|
BLOCK_H=BLOCK_H,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
|
Lv=Lv,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ limitations under the License.
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Memory-efficient attention for prefill.
|
Memory-efficient attention for prefill.
|
||||||
It supporst page size = 1 and prefill with KV cache (i.e. extend).
|
It supports page size = 1 and prefill with KV cache (i.e. extend).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -67,6 +67,8 @@ def _fwd_kernel(
|
|||||||
BLOCK_M: tl.constexpr,
|
BLOCK_M: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
logit_cap: tl.constexpr,
|
logit_cap: tl.constexpr,
|
||||||
|
Lq: tl.constexpr,
|
||||||
|
Lv: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_seq = tl.program_id(0)
|
cur_seq = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -86,13 +88,18 @@ def _fwd_kernel(
|
|||||||
offs_m = tl.arange(0, BLOCK_M)
|
offs_m = tl.arange(0, BLOCK_M)
|
||||||
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
|
mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend
|
||||||
|
|
||||||
|
mask_d = offs_d < Lq
|
||||||
|
mask_dv = offs_dv < Lv
|
||||||
|
|
||||||
offs_q = (
|
offs_q = (
|
||||||
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
(cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None])
|
||||||
* stride_qbs
|
* stride_qbs
|
||||||
+ cur_head * stride_qh
|
+ cur_head * stride_qh
|
||||||
+ offs_d[None, :]
|
+ offs_d[None, :]
|
||||||
)
|
)
|
||||||
q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0)
|
q = tl.load(
|
||||||
|
Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0
|
||||||
|
)
|
||||||
|
|
||||||
if BLOCK_DPE > 0:
|
if BLOCK_DPE > 0:
|
||||||
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE)
|
||||||
@@ -125,7 +132,9 @@ def _fwd_kernel(
|
|||||||
+ cur_kv_head * stride_buf_kh
|
+ cur_kv_head * stride_buf_kh
|
||||||
+ offs_d[:, None]
|
+ offs_d[:, None]
|
||||||
)
|
)
|
||||||
k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0)
|
k = tl.load(
|
||||||
|
K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
|
||||||
|
)
|
||||||
|
|
||||||
qk = tl.dot(q.to(k.dtype), k)
|
qk = tl.dot(q.to(k.dtype), k)
|
||||||
if BLOCK_DPE > 0:
|
if BLOCK_DPE > 0:
|
||||||
@@ -157,7 +166,9 @@ def _fwd_kernel(
|
|||||||
+ cur_kv_head * stride_buf_vh
|
+ cur_kv_head * stride_buf_vh
|
||||||
+ offs_dv[None, :]
|
+ offs_dv[None, :]
|
||||||
)
|
)
|
||||||
v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0)
|
v = tl.load(
|
||||||
|
V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
|
||||||
|
)
|
||||||
p = p.to(v.dtype)
|
p = p.to(v.dtype)
|
||||||
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||||
|
|
||||||
@@ -176,7 +187,9 @@ def _fwd_kernel(
|
|||||||
+ cur_kv_head * stride_kh
|
+ cur_kv_head * stride_kh
|
||||||
+ offs_d[:, None]
|
+ offs_d[:, None]
|
||||||
)
|
)
|
||||||
k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0)
|
k = tl.load(
|
||||||
|
K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0
|
||||||
|
)
|
||||||
|
|
||||||
qk = tl.dot(q, k, out_dtype=tl.float32)
|
qk = tl.dot(q, k, out_dtype=tl.float32)
|
||||||
if BLOCK_DPE > 0:
|
if BLOCK_DPE > 0:
|
||||||
@@ -214,7 +227,9 @@ def _fwd_kernel(
|
|||||||
+ cur_kv_head * stride_vh
|
+ cur_kv_head * stride_vh
|
||||||
+ offs_dv[None, :]
|
+ offs_dv[None, :]
|
||||||
)
|
)
|
||||||
v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0)
|
v = tl.load(
|
||||||
|
V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0
|
||||||
|
)
|
||||||
p = p.to(v.dtype)
|
p = p.to(v.dtype)
|
||||||
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
acc = acc * re_scale[:, None] + tl.dot(p, v)
|
||||||
|
|
||||||
@@ -226,7 +241,9 @@ def _fwd_kernel(
|
|||||||
+ cur_head * stride_oh
|
+ cur_head * stride_oh
|
||||||
+ offs_dv[None, :]
|
+ offs_dv[None, :]
|
||||||
)
|
)
|
||||||
tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None])
|
tl.store(
|
||||||
|
O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def extend_attention_fwd(
|
def extend_attention_fwd(
|
||||||
@@ -261,16 +278,18 @@ def extend_attention_fwd(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert Lq == Lk and Lv == Lo
|
assert Lq == Lk and Lv == Lo
|
||||||
assert Lq in {16, 32, 64, 128, 256, 576}
|
|
||||||
assert Lv in {16, 32, 64, 128, 256, 512}
|
# TODO: is the assertion necessary?
|
||||||
|
assert Lq in {16, 32, 64, 96, 128, 256, 576}
|
||||||
|
assert Lv in {16, 32, 64, 96, 128, 256, 512}
|
||||||
|
|
||||||
if Lq == 576:
|
if Lq == 576:
|
||||||
BLOCK_DMODEL = 512
|
BLOCK_DMODEL = 512
|
||||||
BLOCK_DPE = 64
|
BLOCK_DPE = 64
|
||||||
else:
|
else:
|
||||||
BLOCK_DMODEL = Lq
|
BLOCK_DMODEL = triton.next_power_of_2(Lq)
|
||||||
BLOCK_DPE = 0
|
BLOCK_DPE = 0
|
||||||
BLOCK_DV = Lv
|
BLOCK_DV = triton.next_power_of_2(Lv)
|
||||||
|
|
||||||
if CUDA_CAPABILITY[0] >= 9:
|
if CUDA_CAPABILITY[0] >= 9:
|
||||||
if Lq <= 256:
|
if Lq <= 256:
|
||||||
@@ -330,6 +349,8 @@ def extend_attention_fwd(
|
|||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=num_stages,
|
num_stages=num_stages,
|
||||||
logit_cap=logit_cap,
|
logit_cap=logit_cap,
|
||||||
|
Lq=Lq,
|
||||||
|
Lv=Lv,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -373,10 +394,7 @@ def redundant_attention(
|
|||||||
pt += cur_seq_len_extend
|
pt += cur_seq_len_extend
|
||||||
|
|
||||||
|
|
||||||
def test():
|
def test_once(B, N_CTX, H_Q, H_KV, D):
|
||||||
torch.manual_seed(0)
|
|
||||||
|
|
||||||
B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128
|
|
||||||
dtype = torch.float16
|
dtype = torch.float16
|
||||||
|
|
||||||
b_seq_len_prefix = torch.randint(
|
b_seq_len_prefix = torch.randint(
|
||||||
@@ -473,4 +491,5 @@ def test():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test()
|
test_once(19, 12331, 12, 4, 128)
|
||||||
|
test_once(19, 12331, 12, 4, 96)
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ def _fwd_kernel(
|
|||||||
BLOCK_M: tl.constexpr,
|
BLOCK_M: tl.constexpr,
|
||||||
BLOCK_DMODEL: tl.constexpr,
|
BLOCK_DMODEL: tl.constexpr,
|
||||||
BLOCK_N: tl.constexpr,
|
BLOCK_N: tl.constexpr,
|
||||||
|
Lk: tl.constexpr,
|
||||||
):
|
):
|
||||||
cur_batch = tl.program_id(0)
|
cur_batch = tl.program_id(0)
|
||||||
cur_head = tl.program_id(1)
|
cur_head = tl.program_id(1)
|
||||||
@@ -72,7 +73,11 @@ def _fwd_kernel(
|
|||||||
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
|
off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
|
||||||
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
|
off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :]
|
||||||
|
|
||||||
q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)
|
mask_d = offs_d < Lk
|
||||||
|
|
||||||
|
q = tl.load(
|
||||||
|
Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0
|
||||||
|
)
|
||||||
|
|
||||||
k_ptrs = K + off_k
|
k_ptrs = K + off_k
|
||||||
v_ptrs = V + off_v
|
v_ptrs = V + off_v
|
||||||
@@ -89,7 +94,7 @@ def _fwd_kernel(
|
|||||||
# -- compute qk ----
|
# -- compute qk ----
|
||||||
k = tl.load(
|
k = tl.load(
|
||||||
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,
|
||||||
mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,
|
mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]),
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
|
# mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)
|
||||||
@@ -118,7 +123,7 @@ def _fwd_kernel(
|
|||||||
# update acc
|
# update acc
|
||||||
v = tl.load(
|
v = tl.load(
|
||||||
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,
|
||||||
mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,
|
mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]),
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -134,7 +139,9 @@ def _fwd_kernel(
|
|||||||
+ offs_d[None, :]
|
+ offs_d[None, :]
|
||||||
)
|
)
|
||||||
out_ptrs = Out + off_o
|
out_ptrs = Out + off_o
|
||||||
tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)
|
tl.store(
|
||||||
|
out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :])
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
||||||
@@ -145,7 +152,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
|||||||
|
|
||||||
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
|
||||||
assert Lq == Lk and Lk == Lv
|
assert Lq == Lk and Lk == Lv
|
||||||
assert Lk in {16, 32, 64, 128, 256}
|
assert Lk in {16, 32, 64, 96, 128, 256}
|
||||||
|
|
||||||
sm_scale = 1.0 / (Lq**0.5)
|
sm_scale = 1.0 / (Lq**0.5)
|
||||||
batch, head = b_seq_len.shape[0], q.shape[1]
|
batch, head = b_seq_len.shape[0], q.shape[1]
|
||||||
@@ -172,8 +179,9 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
|
|||||||
o.stride(1),
|
o.stride(1),
|
||||||
kv_group_num=kv_group_num,
|
kv_group_num=kv_group_num,
|
||||||
BLOCK_M=BLOCK,
|
BLOCK_M=BLOCK,
|
||||||
BLOCK_DMODEL=Lk,
|
BLOCK_DMODEL=triton.next_power_of_2(Lk),
|
||||||
BLOCK_N=BLOCK,
|
BLOCK_N=BLOCK,
|
||||||
num_warps=num_warps,
|
num_warps=num_warps,
|
||||||
num_stages=1,
|
num_stages=1,
|
||||||
|
Lk=Lk,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user