Fix Triton decode kernel & ut (#1819)
@@ -296,12 +296,18 @@ def _fwd_grouped_kernel_stage1(
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
-    cur_kv_head = tl.program_id(1)
+    cur_head_id = tl.program_id(1)
+    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
     start_n = tl.program_id(2)

     reduce_dtype = Att_Out.dtype.element_ty
-    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
-    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+
+    if BLOCK_H < kv_group_num:
+        VALID_BLOCK_H: tl.constexpr = BLOCK_H
+    else:
+        VALID_BLOCK_H: tl.constexpr = kv_group_num
+    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
+    mask_h = mask_h & (cur_head < q_head_num)

     offs_d = tl.arange(0, BLOCK_DMODEL)
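
This stage-1 change reinterprets the second grid axis: it now indexes a block of query heads (cur_head_id) rather than a KV head directly, and the owning KV head is recovered by dividing by the number of head blocks per KV group. A minimal plain-Python sketch of the new indexing (not Triton); the sizes below (128 query heads sharing one KV head, as in MQA/MLA) are assumptions for illustration:

import math

q_head_num = 128     # assumed: total query heads
kv_group_num = 128   # assumed: query heads per KV head
BLOCK_H = 64         # max(16, min(64, next_power_of_2(kv_group_num)))

# VALID_BLOCK_H: how many query heads one program instance owns.
VALID_BLOCK_H = min(BLOCK_H, kv_group_num)

for cur_head_id in range(math.ceil(q_head_num / VALID_BLOCK_H)):
    # Same arithmetic as the kernel: several head blocks can map to
    # the same KV head when the group is wider than BLOCK_H.
    cur_kv_head = cur_head_id // math.ceil(kv_group_num / BLOCK_H)
    first = cur_head_id * VALID_BLOCK_H
    last = min(first + VALID_BLOCK_H, q_head_num) - 1
    print(f"head block {cur_head_id}: query heads {first}..{last} -> kv head {cur_kv_head}")

# head block 0: query heads 0..63   -> kv head 0
# head block 1: query heads 64..127 -> kv head 0
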
@@ -400,10 +406,15 @@ def _fwd_grouped_kernel_stage2(
     Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
-    cur_kv_head = tl.program_id(1)
+    cur_head_id = tl.program_id(1)
+    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)

-    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
-    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    if BLOCK_H < kv_group_num:
+        VALID_BLOCK_H: tl.constexpr = BLOCK_H
+    else:
+        VALID_BLOCK_H: tl.constexpr = kv_group_num
+    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
+    mask_h = mask_h & (cur_head < q_head_num)

     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
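
Stage 2 gets the identical head-block remapping, so the reduce kernel iterates exactly the query heads whose partial logits stage 1 produced.
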
@@ -485,7 +496,7 @@ def _decode_grouped_att_m_fwd(
     batch, head_num = B_req_idx.shape[0], q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]

-    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
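
Capping BLOCK_H at 64 bounds the per-program tile (and with it register and shared-memory pressure) for very wide groups; the second grid axis grows to compensate. A quick check of the launch arithmetic, with head_num = 128 assumed:

import triton

head_num = 128  # assumed query-head count
for kv_group_num in (1, 16, 64, 128):
    old = max(16, triton.next_power_of_2(kv_group_num))
    new = max(16, min(64, triton.next_power_of_2(kv_group_num)))
    # Second grid axis, exactly as computed in the hunk above.
    blocks_old = triton.cdiv(head_num, min(old, kv_group_num))
    blocks_new = triton.cdiv(head_num, min(new, kv_group_num))
    print(kv_group_num, old, new, blocks_old, blocks_new)

# Only kv_group_num = 128 changes: BLOCK_H drops 128 -> 64 and the
# head-block axis doubles 1 -> 2, which is what the stage-1/stage-2
# remapping above exists to handle.
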
@@ -534,7 +545,7 @@ def _decode_grouped_softmax_reducev_fwd(
     BLOCK = 128
     batch, head_num = b_seq_len.shape[0], logits.shape[0]
     kv_group_num = logits.shape[0] // v_buffer.shape[1]
-    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
     grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)

     num_warps = 8
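
The same cap is applied on the reduce side; the two launches have to compute identical (BLOCK_H, grid) pairs, or the head blocks would no longer line up between the stages.
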
@@ -567,6 +578,80 @@ def _decode_grouped_softmax_reducev_fwd(
     )


+def decode_attention_fwd_normal(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    req_to_token,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+    attn_logits,
+    max_len_in_batch,
+    sm_scale,
+    logit_cap=0.0,
+):
+    _decode_att_m_fwd(
+        q,
+        k_buffer,
+        attn_logits,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        max_len_in_batch,
+        sm_scale,
+        logit_cap,
+    )
+    _decode_softmax_reducev_fwd(
+        attn_logits,
+        v_buffer,
+        o,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+    )
+
+
+def decode_attention_fwd_grouped(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    req_to_token,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+    attn_logits,
+    max_len_in_batch,
+    sm_scale,
+    logit_cap=0.0,
+):
+    _decode_grouped_att_m_fwd(
+        q,
+        k_buffer,
+        attn_logits,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        max_len_in_batch,
+        sm_scale,
+        logit_cap,
+    )
+    _decode_grouped_softmax_reducev_fwd(
+        attn_logits,
+        v_buffer,
+        o,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+    )
+
+
 def decode_attention_fwd(
     q,
     k_buffer,
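
The two wrappers added above are what the "ut" in the title hooks into: presumably the updated unit test calls decode_attention_fwd_normal and decode_attention_fwd_grouped directly, so each path can be exercised regardless of how kv_group_num would route it in the dispatch hunk below.
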
@@ -585,47 +670,33 @@ def decode_attention_fwd(

     if kv_group_num == 1:
         # MHA
-        _decode_att_m_fwd(
+        decode_attention_fwd_normal(
             q,
             k_buffer,
-            attn_logits,
-            req_to_token,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            max_len_in_batch,
-            sm_scale,
-            logit_cap,
-        )
-        _decode_softmax_reducev_fwd(
-            attn_logits,
             v_buffer,
             o,
             req_to_token,
             b_req_idx,
             b_start_loc,
             b_seq_len,
+            attn_logits,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
         )
     else:
         # GQA/MQA/MLA
-        _decode_grouped_att_m_fwd(
+        decode_attention_fwd_grouped(
             q,
             k_buffer,
-            attn_logits,
-            req_to_token,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            max_len_in_batch,
-            sm_scale,
-            logit_cap,
-        )
-        _decode_grouped_softmax_reducev_fwd(
-            attn_logits,
             v_buffer,
             o,
             req_to_token,
             b_req_idx,
             b_start_loc,
             b_seq_len,
+            attn_logits,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
         )
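
After the refactor, decode_attention_fwd is a thin dispatcher: kv_group_num == 1 routes to the MHA wrapper, everything else to the grouped one, and both run the same two-stage pipeline (a score kernel, then softmax plus V reduction). For reference, a hedged pure-PyTorch sketch of what that pipeline computes for one sequence and one head; the tanh form of logit capping is an assumption about how the cap is applied, not something this diff shows:

import torch

def decode_attention_reference(q, k, v, sm_scale, logit_cap=0.0):
    # q: (D,) query of the single decode token; k, v: (S, D) KV cache.
    logits = (k @ q) * sm_scale              # stage 1: scaled scores
    if logit_cap > 0:
        # Assumed tanh-style capping; the diff only threads the value through.
        logits = logit_cap * torch.tanh(logits / logit_cap)
    probs = torch.softmax(logits, dim=-1)    # stage 2: softmax ...
    return probs @ v                         # ... and weighted V reduction

q = torch.randn(128)
k = torch.randn(256, 128)
v = torch.randn(256, 128)
out = decode_attention_reference(q, k, v, sm_scale=128 ** -0.5)
assert out.shape == (128,)
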
@@ -168,7 +168,7 @@ def _fwd_kernel(
 def context_attention_fwd(
     q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
 ):
-    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
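
The final hunk, in the context (prefill) attention path, tightens the block-size gate from compute capability >= 8 to > 8: Ampere (sm_80/86) and Ada (sm_89) now take BLOCK = 64, and only capability 9 and newer (Hopper) keeps 128. A plausible motivation is the smaller shared-memory budget of the 8.x parts, though the diff itself does not say. The equivalent capability check in plain PyTorch:

import torch

major, minor = torch.cuda.get_device_capability()  # e.g. (8, 9) on Ada
BLOCK = 128 if major > 8 else 64  # Hopper and newer: 128; Ampere/Ada: 64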