Remove unused vars in the triton backend (#2401)
@@ -196,7 +196,6 @@ class TestTritonAttention(unittest.TestCase):
 
         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")
-        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
@@ -212,10 +211,8 @@ class TestTritonAttention(unittest.TestCase):
             o,
             req_to_token,
             b_req_idx,
-            b_start_loc,
             b_seq_len,
             attn_logits,
-            seq_len,
             num_kv_splits,
             sm_scale,
         )
@@ -255,7 +252,6 @@ class TestTritonAttention(unittest.TestCase):
 
         req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
         b_req_idx = torch.arange(B, device="cuda")
-        b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
         attn_logits = torch.empty(
@@ -273,7 +269,6 @@ class TestTritonAttention(unittest.TestCase):
             b_req_idx,
             b_seq_len,
             attn_logits,
-            seq_len,
             num_kv_splits,
             sm_scale,
         )
@@ -293,7 +288,6 @@ class TestTritonAttention(unittest.TestCase):
             b_req_idx,
             b_seq_len,
             attn_logits1,
-            seq_len,
             num_kv_splits,
             sm_scale,
         )
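With b_start_loc and the max-sequence-length argument dropped, the call site shrinks accordingly. Below is a minimal sketch of the resulting test setup, assuming the decode_attention_fwd signature implied by the diff; the tensor shapes, dtypes, and import path are illustrative assumptions, not taken from the commit:

import torch
from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd

# Illustrative sizes; the real test's values may differ.
B, seq_len, H, D = 4, 128, 8, 64
total_tokens = B * seq_len
num_kv_splits = 8
sm_scale = 1.0 / (D**0.5)

q = torch.randn(B, H, D, device="cuda")
k_buffer = torch.randn(total_tokens, H, D, device="cuda")
v_buffer = torch.randn(total_tokens, H, D, device="cuda")
o = torch.empty_like(q)

req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
b_req_idx = torch.arange(B, device="cuda")
b_seq_len = torch.full((B,), seq_len, device="cuda")
# Assumed layout: per-split partial outputs plus one extra slot, hence D + 1.
attn_logits = torch.empty(
    (B, H, num_kv_splits, D + 1), dtype=torch.float32, device="cuda"
)

# b_start_loc and the max-sequence-length argument are no longer passed.
decode_attention_fwd(
    q,
    k_buffer,
    v_buffer,
    o,
    req_to_token,
    b_req_idx,
    b_seq_len,
    attn_logits,
    num_kv_splits,
    sm_scale,
)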