Support Eagle2 for Triton backend (#3466)

This commit is contained in:
Ke Bao
2025-02-10 20:00:42 +08:00
committed by GitHub
parent cddb1cdf8f
commit 2d61132374
5 changed files with 285 additions and 41 deletions

View File

@@ -102,7 +102,7 @@ class TestTritonAttention(unittest.TestCase):
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
custom_mask = None
-mask_offsets = None
+mask_indptr = None
extend_attention_fwd(
q_extend,
@@ -115,7 +115,7 @@ class TestTritonAttention(unittest.TestCase):
kv_indptr,
kv_indices,
custom_mask,
-mask_offsets,
+mask_indptr,
max_len_extend,
)
@@ -123,8 +123,8 @@ class TestTritonAttention(unittest.TestCase):
custom_mask = torch.ones(
(b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
)
-mask_offsets = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
-mask_offsets[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
+mask_indptr = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
+mask_indptr[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
for i in range(B):
causal_mask = (
torch.tril(
@@ -136,7 +136,7 @@ class TestTritonAttention(unittest.TestCase):
b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool
)
mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten()
-custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = mask_flatten
+custom_mask[mask_indptr[i] : mask_indptr[i + 1]] = mask_flatten
extend_attention_fwd(
q_extend,
@@ -149,7 +149,7 @@ class TestTritonAttention(unittest.TestCase):
kv_indptr,
kv_indices,
custom_mask,
-mask_offsets,
+mask_indptr,
max_len_extend,
)