Update Triton extend backend interface (#3309)

This commit is contained in:
Ke Bao
2025-02-05 18:12:22 +08:00
committed by GitHub
parent 7aad8d1854
commit de5533341e
5 changed files with 427 additions and 69 deletions

View File

@@ -45,16 +45,20 @@ class TestTritonAttention(unittest.TestCase):
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda")
req_to_tokens = torch.empty(
(B, max_len_in_batch), dtype=torch.int32, device="cuda"
)
b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
kv_indices = torch.zeros(
(b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
)
for i in range(B):
req_to_tokens[i, : b_seq_len[i]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len[i]
kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
)
total_token_num = torch.sum(b_seq_len).item()
@@ -90,9 +94,10 @@ class TestTritonAttention(unittest.TestCase):
)
b_seq_len_extend = b_seq_len - b_seq_len_prefix
b_start_loc_extend = torch.zeros_like(b_seq_len)
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
extend_attention_fwd(
q_extend,
k_extend,
@@ -100,11 +105,9 @@ class TestTritonAttention(unittest.TestCase):
o_extend,
k_buffer,
v_buffer,
req_to_tokens,
b_req_idx,
b_seq_len,
b_seq_len_extend,
b_start_loc_extend,
qo_indptr,
kv_indptr,
kv_indices,
max_len_extend,
)