Remove unused vars in the triton backend (#2401)

This commit is contained in:
Ke Bao
2024-12-08 19:37:03 +08:00
committed by GitHub
parent 96db0f666d
commit 61dec545b0
3 changed files with 14 additions and 33 deletions

View File

@@ -196,7 +196,6 @@ class TestTritonAttention(unittest.TestCase):
 req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
 b_req_idx = torch.arange(B, device="cuda")
-b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
 b_seq_len = torch.full((B,), seq_len, device="cuda")
 attn_logits = torch.empty(
@@ -212,10 +211,8 @@ class TestTritonAttention(unittest.TestCase):
 o,
 req_to_token,
 b_req_idx,
-b_start_loc,
 b_seq_len,
 attn_logits,
-seq_len,
 num_kv_splits,
 sm_scale,
 )
@@ -255,7 +252,6 @@ class TestTritonAttention(unittest.TestCase):
 req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
 b_req_idx = torch.arange(B, device="cuda")
-b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
 b_seq_len = torch.full((B,), seq_len, device="cuda")
 attn_logits = torch.empty(
@@ -273,7 +269,6 @@ class TestTritonAttention(unittest.TestCase):
 b_req_idx,
 b_seq_len,
 attn_logits,
-seq_len,
 num_kv_splits,
 sm_scale,
 )
@@ -293,7 +288,6 @@ class TestTritonAttention(unittest.TestCase):
 b_req_idx,
 b_seq_len,
 attn_logits1,
-seq_len,
 num_kv_splits,
 sm_scale,
 )