remove assertion in triton attention and add a unit test (#1385)

This commit is contained in:
Byron Hsu
2024-09-11 03:22:07 -07:00
committed by GitHub
parent 144bc70fcc
commit 8c0efa514d
4 changed files with 213 additions and 113 deletions

View File

@@ -199,8 +199,6 @@ def _decode_att_m_fwd(
BLOCK = 32
# shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 96, 128, 256}
batch, head_num = B_req_idx.shape[0], q.shape[1]
@@ -482,8 +480,6 @@ def _decode_grouped_att_m_fwd(
BLOCK = 32
# shape constraints
Lq, Lk = q.shape[-1], k_buffer.shape[-1]
assert Lq == Lk
assert Lk in {16, 32, 64, 96, 128, 256, 576, 288}
if Lk == 576:
BLOCK_DMODEL = 512

View File

@@ -277,12 +277,6 @@ def extend_attention_fwd(
o_extend.shape[-1],
)
assert Lq == Lk and Lv == Lo
# TODO: is the assertion necessary?
assert Lq in {16, 32, 64, 96, 128, 256, 576, 288}
assert Lv in {16, 32, 64, 96, 128, 256, 512}
if Lq == 576:
BLOCK_DMODEL = 512
BLOCK_DPE = 64
@@ -395,104 +389,3 @@ def redundant_attention(
pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i]
o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr]
pt += cur_seq_len_extend
def test_once(B, N_CTX, H_Q, H_KV, D):
    """Check ``extend_attention_fwd`` against ``redundant_attention`` on random data.

    Builds a random batch of ``B`` requests, each made of a prefix part and an
    extend part, stores their keys/values back to back in ``k_buffer`` /
    ``v_buffer``, runs both attention implementations and asserts that their
    outputs for the extend tokens agree. All tensors are created on the
    ``"cuda"`` device.

    Args:
        B: Number of requests in the batch.
        N_CTX: Length budget; the prefix and extend lengths of every request
            are each drawn from ``[1, N_CTX // 2)``.
        H_Q: Number of query heads.
        H_KV: Number of key/value heads.
        D: Head dimension shared by q, k and v.

    Raises:
        AssertionError: If the two outputs are not close with ``rtol=1e-2``.
    """
    dtype = torch.float16
    device = "cuda"

    # Per-request lengths: total length = cached prefix + newly extended part.
    b_seq_len_prefix = torch.randint(
        1, N_CTX // 2, (B,), dtype=torch.int32, device=device
    )
    b_seq_len_extend = torch.randint(
        1, N_CTX // 2, (B,), dtype=torch.int32, device=device
    )
    b_seq_len = b_seq_len_prefix + b_seq_len_extend
    max_len_in_batch = b_seq_len.max().item()
    max_len_extend = b_seq_len_extend.max().item()

    # Exclusive prefix sums give each request's first row in the full buffers
    # (b_start_loc) and in the extend-only tensors (b_start_loc_extend).
    b_req_idx = torch.arange(B, dtype=torch.int32, device=device)
    b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device)
    b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
    b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device)
    b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)

    # Read lengths/offsets back to the host once, so the Python loops below
    # slice with plain ints instead of synchronizing on a CUDA scalar each time.
    seq_lens = b_seq_len.tolist()
    prefix_lens = b_seq_len_prefix.tolist()
    extend_lens = b_seq_len_extend.tolist()
    starts = b_start_loc.tolist()
    extend_starts = b_start_loc_extend.tolist()

    # Row i maps request i's token positions to consecutive buffer rows; the
    # tail of each row beyond the request's length is left uninitialized.
    req_to_tokens = torch.empty(
        (B, max_len_in_batch), dtype=torch.int32, device=device
    )
    for i, (start, length) in enumerate(zip(starts, seq_lens)):
        req_to_tokens[i, :length] = torch.arange(
            start, start + length, dtype=torch.int32, device=device
        )

    total_token_num = sum(seq_lens)
    extend_token_num = sum(extend_lens)

    k_buffer = torch.empty(
        (total_token_num, H_KV, D), dtype=dtype, device=device
    ).normal_(mean=0.1, std=0.2)
    v_buffer = torch.empty(
        (total_token_num, H_KV, D), dtype=dtype, device=device
    ).normal_(mean=0.1, std=0.2)

    # The extend K/V are copies of the matching buffer rows, so both
    # implementations see the same keys and values.
    k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
    v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device=device)
    for start, prefix_len, seq_len, ext_start, ext_len in zip(
        starts, prefix_lens, seq_lens, extend_starts, extend_lens
    ):
        src = slice(start + prefix_len, start + seq_len)
        dst = slice(ext_start, ext_start + ext_len)
        k_extend[dst] = k_buffer[src]
        v_extend[dst] = v_buffer[src]

    # Queries are independent random data, so they are filled in one call.
    q_extend = torch.empty(
        (extend_token_num, H_Q, D), dtype=dtype, device=device
    ).normal_(mean=0.1, std=0.2)

    o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)
    o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device=device)

    extend_attention_fwd(
        q_extend,
        k_extend,
        v_extend,
        o_extend,
        k_buffer,
        v_buffer,
        req_to_tokens,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_seq_len_prefix,
        b_start_loc_extend,
        b_seq_len_extend,
        max_len_in_batch,
        max_len_extend,
    )

    redundant_attention(
        q_extend,
        k_extend,
        v_extend,
        o_redundant,
        k_buffer,
        v_buffer,
        req_to_tokens,
        b_req_idx,
        b_start_loc,
        b_seq_len,
        b_seq_len_prefix,
        max_len_in_batch,
    )

    abs_diff = torch.abs(o_extend - o_redundant)
    print("Mean: ", torch.mean(abs_diff))
    print("Max: ", torch.max(abs_diff))

    assert torch.allclose(o_extend, o_redundant, rtol=1e-2)
if __name__ == "__main__":
    # Run the comparison for each head dimension of interest, keeping the
    # batch size, context budget and head counts fixed.
    for head_dim in (128, 96):
        test_once(19, 12331, 12, 4, head_dim)

View File

@@ -151,8 +151,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
BLOCK = 64
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
assert Lq == Lk and Lk == Lv
assert Lk in {16, 32, 64, 96, 128, 256}
sm_scale = 1.0 / (Lq**0.5)
batch, head = b_seq_len.shape[0], q.shape[1]