Fix Triton decode kernel & ut (#1819)

Ke Bao
2024-10-28 01:54:38 +08:00
committed by GitHub
parent 51c81e339b
commit c77762d57f
5 changed files with 218 additions and 42 deletions


@@ -296,12 +296,18 @@ def _fwd_grouped_kernel_stage1(
     Lk: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
-    cur_kv_head = tl.program_id(1)
+    cur_head_id = tl.program_id(1)
+    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)
     start_n = tl.program_id(2)

     reduce_dtype = Att_Out.dtype.element_ty
-    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
-    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+
+    if BLOCK_H < kv_group_num:
+        VALID_BLOCK_H: tl.constexpr = BLOCK_H
+    else:
+        VALID_BLOCK_H: tl.constexpr = kv_group_num
+    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
     mask_h = mask_h & (cur_head < q_head_num)

     offs_d = tl.arange(0, BLOCK_DMODEL)
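The stage-1 fix above decouples the head tile from the kv head: previously one program had to cover an entire kv group, which forces BLOCK_H >= kv_group_num, while the new cur_head_id / VALID_BLOCK_H arithmetic lets several programs tile a single group once BLOCK_H is capped (see the _decode_grouped_att_m_fwd hunk further down). A pure-Python sketch of the new indexing, with NumPy standing in for tl.arange / tl.cdiv and illustrative MLA-like sizes:

import numpy as np

def cdiv(a, b):
    return -(-a // b)  # ceiling division, as tl.cdiv computes it

def head_tile(cur_head_id, kv_group_num, q_head_num, BLOCK_H):
    """Query heads handled by one program along grid axis 1 (sketch)."""
    VALID_BLOCK_H = min(BLOCK_H, kv_group_num)  # the if/else in the hunk
    cur_kv_head = cur_head_id // cdiv(kv_group_num, BLOCK_H)
    cur_head = cur_head_id * VALID_BLOCK_H + np.arange(BLOCK_H)
    mask_h = (cur_head < (cur_head_id + 1) * VALID_BLOCK_H) & (cur_head < q_head_num)
    return cur_kv_head, cur_head[mask_h]

# MLA-like case: 128 query heads sharing one kv head, BLOCK_H capped at 64.
# Two programs now tile the group instead of one oversized 128-lane tile:
print(head_tile(0, 128, 128, 64))  # (0, heads 0..63)
print(head_tile(1, 128, 128, 64))  # (0, heads 64..127)

Stage 2 (next hunk) mirrors the same arithmetic, so the reduce step reads back exactly the tiles stage 1 wrote.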
@@ -400,10 +406,15 @@ def _fwd_grouped_kernel_stage2(
     Lv: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
-    cur_kv_head = tl.program_id(1)
+    cur_head_id = tl.program_id(1)
+    cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H)

-    cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H)
-    mask_h = cur_head < (cur_kv_head + 1) * kv_group_num
+    if BLOCK_H < kv_group_num:
+        VALID_BLOCK_H: tl.constexpr = BLOCK_H
+    else:
+        VALID_BLOCK_H: tl.constexpr = kv_group_num
+    cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H)
+    mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H
     mask_h = mask_h & (cur_head < q_head_num)

     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
@@ -485,7 +496,7 @@ def _decode_grouped_att_m_fwd(
     batch, head_num = B_req_idx.shape[0], q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]

-    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
     grid = (
         batch,
         triton.cdiv(head_num, min(BLOCK_H, kv_group_num)),
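With the cap, very wide groups split across grid axis 1 instead of inflating BLOCK_H without bound. A hedged sketch of the resulting launch arithmetic; cdiv and next_power_of_2 are plain-Python stand-ins for the triton helpers, and the third axis over sequence blocks (consumed as tl.program_id(2) in stage 1) is an assumption here, since the hunk is cut off before it:

def cdiv(a, b):
    return -(-a // b)

def next_power_of_2(x):
    return 1 << max(0, x - 1).bit_length()

def stage1_grid(batch, head_num, kv_group_num, max_len_in_batch, BLOCK_N=64):
    # BLOCK_N (sequence block size) is an assumed illustrative value.
    BLOCK_H = max(16, min(64, next_power_of_2(kv_group_num)))
    return (
        batch,
        cdiv(head_num, min(BLOCK_H, kv_group_num)),  # head tiles, not kv heads
        cdiv(max_len_in_batch, BLOCK_N),             # assumed sequence axis
    )

# kv_group_num = 128 used to give BLOCK_H = 128, a single oversized tile per
# kv head; the capped version launches two 64-lane head tiles instead:
print(stage1_grid(4, 128, 128, 4096))  # (4, 2, 64)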
@@ -534,7 +545,7 @@ def _decode_grouped_softmax_reducev_fwd(
     BLOCK = 128
     batch, head_num = b_seq_len.shape[0], logits.shape[0]
     kv_group_num = logits.shape[0] // v_buffer.shape[1]

-    BLOCK_H = max(16, triton.next_power_of_2(kv_group_num))
+    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
     grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1)
     num_warps = 8
@@ -567,6 +578,80 @@ def _decode_grouped_softmax_reducev_fwd(
     )


+def decode_attention_fwd_normal(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    req_to_token,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+    attn_logits,
+    max_len_in_batch,
+    sm_scale,
+    logit_cap=0.0,
+):
+    _decode_att_m_fwd(
+        q,
+        k_buffer,
+        attn_logits,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        max_len_in_batch,
+        sm_scale,
+        logit_cap,
+    )
+    _decode_softmax_reducev_fwd(
+        attn_logits,
+        v_buffer,
+        o,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+    )
+
+
+def decode_attention_fwd_grouped(
+    q,
+    k_buffer,
+    v_buffer,
+    o,
+    req_to_token,
+    b_req_idx,
+    b_start_loc,
+    b_seq_len,
+    attn_logits,
+    max_len_in_batch,
+    sm_scale,
+    logit_cap=0.0,
+):
+    _decode_grouped_att_m_fwd(
+        q,
+        k_buffer,
+        attn_logits,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+        max_len_in_batch,
+        sm_scale,
+        logit_cap,
+    )
+    _decode_grouped_softmax_reducev_fwd(
+        attn_logits,
+        v_buffer,
+        o,
+        req_to_token,
+        b_req_idx,
+        b_start_loc,
+        b_seq_len,
+    )
+
+
 def decode_attention_fwd(
     q,
     k_buffer,
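The two wrappers added above make the MHA and grouped pipelines independently callable (and independently testable), so decode_attention_fwd below reduces to a shape-based dispatch. A sketch of that rule, with the signature copied from the wrappers and kv_group_num derived exactly as in _decode_grouped_att_m_fwd:

def decode_attention_fwd_sketch(
    q, k_buffer, v_buffer, o, req_to_token, b_req_idx, b_start_loc,
    b_seq_len, attn_logits, max_len_in_batch, sm_scale, logit_cap=0.0,
):
    args = (q, k_buffer, v_buffer, o, req_to_token, b_req_idx, b_start_loc,
            b_seq_len, attn_logits, max_len_in_batch, sm_scale, logit_cap)
    # q: [tokens, q_heads, dim]; k_buffer: [pool, kv_heads, dim] (assumed).
    kv_group_num = q.shape[1] // k_buffer.shape[1]
    if kv_group_num == 1:
        decode_attention_fwd_normal(*args)   # MHA
    else:
        decode_attention_fwd_grouped(*args)  # GQA/MQA/MLA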
@@ -585,47 +670,33 @@ def decode_attention_fwd(
     if kv_group_num == 1:
         # MHA
-        _decode_att_m_fwd(
+        decode_attention_fwd_normal(
             q,
             k_buffer,
-            attn_logits,
-            req_to_token,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            max_len_in_batch,
-            sm_scale,
-            logit_cap,
-        )
-        _decode_softmax_reducev_fwd(
-            attn_logits,
             v_buffer,
             o,
             req_to_token,
             b_req_idx,
             b_start_loc,
             b_seq_len,
+            attn_logits,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
         )
     else:
         # GQA/MQA/MLA
-        _decode_grouped_att_m_fwd(
+        decode_attention_fwd_grouped(
             q,
             k_buffer,
-            attn_logits,
-            req_to_token,
-            b_req_idx,
-            b_start_loc,
-            b_seq_len,
-            max_len_in_batch,
-            sm_scale,
-            logit_cap,
-        )
-        _decode_grouped_softmax_reducev_fwd(
-            attn_logits,
             v_buffer,
             o,
             req_to_token,
             b_req_idx,
             b_start_loc,
             b_seq_len,
+            attn_logits,
+            max_len_in_batch,
+            sm_scale,
+            logit_cap,
         )
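For the "ut" half of the commit title, the tiling fix lends itself to a GPU-free property check: every query head should be visited exactly once and attributed to the correct kv head, for group sizes on both sides of the 64 cap. The test below is an illustrative sketch, not the unit test the commit actually adds:

def cdiv(a, b):
    return -(-a // b)

def next_power_of_2(x):
    return 1 << max(0, x - 1).bit_length()

def test_head_tiling_covers_all_query_heads():
    for q_heads, kv_heads in [(32, 32), (32, 8), (28, 4), (128, 1)]:
        kv_group_num = q_heads // kv_heads
        BLOCK_H = max(16, min(64, next_power_of_2(kv_group_num)))
        VALID_BLOCK_H = min(BLOCK_H, kv_group_num)
        seen = []
        # Grid axis 1 and the per-program mask, as in the kernels above.
        for cur_head_id in range(cdiv(q_heads, min(BLOCK_H, kv_group_num))):
            cur_kv_head = cur_head_id // cdiv(kv_group_num, BLOCK_H)
            for lane in range(BLOCK_H):
                h = cur_head_id * VALID_BLOCK_H + lane
                if h < (cur_head_id + 1) * VALID_BLOCK_H and h < q_heads:
                    assert h // kv_group_num == cur_kv_head  # right kv head
                    seen.append(h)
        assert sorted(seen) == list(range(q_heads))  # each head exactly once

test_head_tiling_covers_all_query_heads()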


@@ -168,7 +168,7 @@ def _fwd_kernel(
 def context_attention_fwd(
     q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
 ):
-    if is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
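In the prefill file, the capability gate is tightened so only devices with a major compute capability above 8 take the 128-wide block. A hedged sketch of where those module-level globals typically come from; the torch call is real, while treating Hopper/sm_90 as the main >8 case is an inference, and the names simply mirror the ones in the hunk:

import torch

is_cuda_available = torch.cuda.is_available()
if is_cuda_available:
    # (major, minor), e.g. (8, 6) for sm_86, (9, 0) for sm_90.
    CUDA_CAPABILITY = torch.cuda.get_device_capability()

def pick_block():
    # After the fix, Ampere-class sm_8x falls back to the 64-wide block;
    # only major version > 8 (e.g. Hopper) uses BLOCK = 128.
    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
        return 128
    return 64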