Update Triton decode backend interface (#3292)
This commit is contained in:
@@ -194,10 +194,12 @@ class TestTritonAttention(unittest.TestCase):
|
||||
# o will have the same shape as q
|
||||
o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
|
||||
|
||||
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
|
||||
b_req_idx = torch.arange(B, device="cuda")
|
||||
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
||||
|
||||
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
||||
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
|
||||
kv_indices = torch.arange(total_tokens, device="cuda")
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(B, H_Q, num_kv_splits, D + 1),
|
||||
dtype=torch.float32,
|
||||
@@ -209,9 +211,8 @@ class TestTritonAttention(unittest.TestCase):
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
@@ -250,10 +251,12 @@ class TestTritonAttention(unittest.TestCase):
|
||||
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
||||
o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
|
||||
|
||||
req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
|
||||
b_req_idx = torch.arange(B, device="cuda")
|
||||
b_seq_len = torch.full((B,), seq_len, device="cuda")
|
||||
|
||||
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
|
||||
kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
|
||||
kv_indices = torch.arange(total_tokens, device="cuda")
|
||||
|
||||
attn_logits = torch.empty(
|
||||
(B, H_Q, num_kv_splits, D_V + 1),
|
||||
dtype=torch.float32,
|
||||
@@ -265,9 +268,8 @@ class TestTritonAttention(unittest.TestCase):
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
attn_logits,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
@@ -284,9 +286,8 @@ class TestTritonAttention(unittest.TestCase):
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
o_grouped,
|
||||
req_to_token,
|
||||
b_req_idx,
|
||||
b_seq_len,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
attn_logits1,
|
||||
num_kv_splits,
|
||||
sm_scale,
|
||||
|
||||
Reference in New Issue
Block a user