Update Triton decode backend interface (#3292)

This commit is contained in:
Ke Bao
2025-02-04 23:26:04 +08:00
committed by GitHub
parent 2c1a695ff1
commit a07364ccc5
3 changed files with 129 additions and 77 deletions

View File

@@ -194,10 +194,12 @@ class TestTritonAttention(unittest.TestCase):
# o will have the same shape as q
o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
kv_indices = torch.arange(total_tokens, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D + 1),
dtype=torch.float32,
@@ -209,9 +211,8 @@ class TestTritonAttention(unittest.TestCase):
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_seq_len,
kv_indptr,
kv_indices,
attn_logits,
num_kv_splits,
sm_scale,
@@ -250,10 +251,12 @@ class TestTritonAttention(unittest.TestCase):
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
kv_indices = torch.arange(total_tokens, device="cuda")
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
@@ -265,9 +268,8 @@ class TestTritonAttention(unittest.TestCase):
k_buffer,
v_buffer,
o,
req_to_token,
b_req_idx,
b_seq_len,
kv_indptr,
kv_indices,
attn_logits,
num_kv_splits,
sm_scale,
@@ -284,9 +286,8 @@ class TestTritonAttention(unittest.TestCase):
k_buffer,
v_buffer,
o_grouped,
req_to_token,
b_req_idx,
b_seq_len,
kv_indptr,
kv_indices,
attn_logits1,
num_kv_splits,
sm_scale,