Update Triton extend backend interface (#3309)
This commit is contained in:
@@ -45,16 +45,20 @@ class TestTritonAttention(unittest.TestCase):
|
||||
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
|
||||
|
||||
b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
|
||||
req_to_tokens = torch.empty(
|
||||
(B, max_len_in_batch), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
|
||||
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
|
||||
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
|
||||
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
||||
|
||||
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
||||
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
|
||||
kv_indices = torch.zeros(
|
||||
(b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
for i in range(B):
|
||||
req_to_tokens[i, : b_seq_len[i]] = torch.arange(
|
||||
b_start_loc[i], b_start_loc[i] + b_seq_len[i]
|
||||
kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
|
||||
b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
|
||||
)
|
||||
|
||||
total_token_num = torch.sum(b_seq_len).item()
|
||||
@@ -90,9 +94,10 @@ class TestTritonAttention(unittest.TestCase):
|
||||
)
|
||||
|
||||
b_seq_len_extend = b_seq_len - b_seq_len_prefix
|
||||
b_start_loc_extend = torch.zeros_like(b_seq_len)
|
||||
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
||||
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
|
||||
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
||||
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
|
||||
|
||||
extend_attention_fwd(
|
||||
q_extend,
|
||||
k_extend,
|
||||
@@ -100,11 +105,9 @@ class TestTritonAttention(unittest.TestCase):
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_tokens,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
b_seq_len_extend,
|
||||
b_start_loc_extend,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
max_len_extend,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user