Support Eagle2 for Triton backend (#3466)

This commit is contained in:
Ke Bao
2025-02-10 20:00:42 +08:00
committed by GitHub
parent cddb1cdf8f
commit 2d61132374
5 changed files with 285 additions and 41 deletions

View File

@@ -193,5 +193,34 @@ class TestEAGLEServer(unittest.TestCase):
self.assertGreater(metrics["accuracy"], 0.20)
class TestEAGLEServerTriton(TestEAGLEServer):
    """EAGLE server test case launched with the Triton attention backend.

    Subclasses ``TestEAGLEServer`` and only overrides how the server
    process is started; everything else comes from the parent class.
    """

    @classmethod
    def setUpClass(cls):
        """Start the EAGLE speculative-decoding server for this test class.

        Stores the server URL on ``cls.base_url`` and the launched process
        handle on ``cls.process``.
        """
        # Flags that take a value, in the order they are passed to the server.
        valued_flags = {
            "--speculative-algorithm": "EAGLE",
            "--speculative-draft-model-path": DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
            "--speculative-num-steps": "5",
            "--speculative-eagle-topk": "8",
            "--speculative-num-draft-tokens": "64",
            "--mem-fraction-static": "0.7",
            "--attention-backend": "triton",
        }
        # Flatten {flag: value} into [flag, value, flag, value, ...].
        launch_args = [token for pair in valued_flags.items() for token in pair]
        # TODO: Support cuda graph
        launch_args.append("--disable-cuda-graph")

        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=launch_args,
        )
# Run the unittest runner when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

View File

@@ -102,7 +102,7 @@ class TestTritonAttention(unittest.TestCase):
qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
custom_mask = None
mask_offsets = None
mask_indptr = None
extend_attention_fwd(
q_extend,
@@ -115,7 +115,7 @@ class TestTritonAttention(unittest.TestCase):
kv_indptr,
kv_indices,
custom_mask,
mask_offsets,
mask_indptr,
max_len_extend,
)
@@ -123,8 +123,8 @@ class TestTritonAttention(unittest.TestCase):
custom_mask = torch.ones(
(b_seq_mask_len.sum().item(),), dtype=torch.bool, device="cuda"
)
mask_offsets = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
mask_offsets[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
mask_indptr = torch.zeros((B + 1,), dtype=torch.int64, device="cuda")
mask_indptr[1 : B + 1] = torch.cumsum(b_seq_mask_len[:B], dim=0)
for i in range(B):
causal_mask = (
torch.tril(
@@ -136,7 +136,7 @@ class TestTritonAttention(unittest.TestCase):
b_seq_len_extend[i], b_seq_len_prefix[i], dtype=torch.bool
)
mask_flatten = torch.cat([prefix_mask, causal_mask], dim=1).flatten()
custom_mask[mask_offsets[i] : mask_offsets[i + 1]] = mask_flatten
custom_mask[mask_indptr[i] : mask_indptr[i + 1]] = mask_flatten
extend_attention_fwd(
q_extend,
@@ -149,7 +149,7 @@ class TestTritonAttention(unittest.TestCase):
kv_indptr,
kv_indices,
custom_mask,
mask_offsets,
mask_indptr,
max_len_extend,
)