Support EAGLE draft extend CUDA graph (#6606)

Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>
This commit is contained in:
Ke Bao
2025-05-27 17:35:17 +08:00
committed by GitHub
parent a3d7f4b673
commit 631950280a
5 changed files with 406 additions and 5 deletions

View File

@@ -1268,6 +1268,29 @@ class FlashAttentionBackend(AttentionBackend):
),
}
# Preallocated, zero-initialised buffers for the draft-extend forward mode.
# Every per-request buffer is sized for the largest batch (max_bs); the
# draft-extend branches elsewhere in this class take slices of these tensors
# and overwrite them in place.
self.draft_extend_metadata = {
# One int32 slot per request.
"cache_seqlens": torch.zeros(
max_bs, dtype=torch.int32, device=self.device
),
# max_bs + 1 int32 slots (one more than the batch size).
# NOTE(review): the visible draft-extend branch that stores per-bs
# metadata builds cu_seqlens_q with torch.arange instead of slicing this
# buffer, so this entry may be unused -- confirm.
"cu_seqlens_q": torch.zeros(
max_bs + 1,
dtype=torch.int32,
device=self.device,
),
# max_bs + 1 int32 slots; the draft-extend branches slice [: bs + 1] and
# write a cumulative sum into [1:], leaving slot 0 at its initial zero.
"cu_seqlens_k": torch.zeros(
max_bs + 1, dtype=torch.int32, device=self.device
),
# max_bs rows by ceil(max_context_len / page_size) columns, int32.
"page_table": torch.zeros(
max_bs,
(self.max_context_len + self.page_size - 1) // self.page_size,
dtype=torch.int32,
device=self.device,
),
# 0, page_size, 2 * page_size, ... (all values below max_context_len):
# one position per page, used as a column index when gathering page
# entries.
"strided_indices": torch.arange(
0, self.max_context_len, self.page_size, device=self.device
),
}
if self.topk > 1:
self.target_verify_metadata_topk_normal = {
"cache_seqlens": torch.zeros(
@@ -1508,6 +1531,32 @@ class FlashAttentionBackend(AttentionBackend):
self.target_verify_metadata_topk_normal[bs] = metadata
self.target_verify_metadata_topk_expand[bs] = metadata_expand
elif forward_mode.is_draft_extend():
# Draft-extend mode: build the metadata for this batch size on top of the
# preallocated draft_extend_metadata buffers and remember it under `bs`.
# (Judging by the commit title this is the CUDA-graph capture path; the
# enclosing method is outside this excerpt -- confirm.)

# Slice of the first `bs` slots of the shared buffer, then filled in place
# with the current sequence lengths cast to int32.
metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
:bs
]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
# Integer division: assumes num_tokens is an exact multiple of bs, i.e. every
# request contributes the same number of query tokens -- TODO confirm.
num_tokens_per_bs = num_tokens // bs
metadata.max_seq_len_q = num_tokens_per_bs
metadata.max_seq_len_k = seq_lens.max().item()
# Uniform query offsets: 0, n, 2n, ..., bs * n with n = num_tokens_per_bs.
# NOTE(review): this allocates a new tensor rather than slicing the
# preallocated draft_extend_metadata["cu_seqlens_q"] buffer; the other
# draft-extend branch later overwrites metadata.cu_seqlens_q[1:] in place, so
# it is this tensor that gets updated -- confirm this is intended.
metadata.cu_seqlens_q = torch.arange(
0,
bs * num_tokens_per_bs + 1,
num_tokens_per_bs,
dtype=torch.int32,
device=device,
)
# First bs + 1 slots of the shared key-offset buffer.
metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
: (bs + 1)
]
# Rows of the shared page table selected by req_pool_indices.
# NOTE(review): if req_pool_indices is a tensor this is advanced indexing,
# which yields a copy rather than a view of the shared buffer, and it requires
# every index to be smaller than the buffer's max_bs rows -- confirm.
metadata.page_table = self.draft_extend_metadata["page_table"][
req_pool_indices, :
]
# Stored under the integer key `bs`, in the same dict that holds the
# string-keyed buffers.
self.draft_extend_metadata[bs] = metadata
if encoder_lens is not None:
encoder_bs = encoder_lens.numel()
@@ -1732,6 +1781,29 @@ class FlashAttentionBackend(AttentionBackend):
metadata_expand.max_seq_len_k = (
metadata_expand.cache_seqlens_int32.max().item()
)
elif forward_mode.is_draft_extend():
# Draft-extend mode: refresh the metadata previously stored for this batch
# size. Tensor contents are updated with in-place copy_ calls so the stored
# tensor objects stay the same; only the two max_seq_len_* scalars are
# reassigned. (Presumably the CUDA-graph replay path -- the enclosing method
# is outside this excerpt.)
metadata = self.draft_extend_metadata[bs]
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
# Taken from seq_lens_cpu rather than seq_lens.
metadata.max_seq_len_k = seq_lens_cpu.max().item()
# Inclusive running sum of the key lengths written into slots 1..bs; slot 0 is
# left untouched.
metadata.cu_seqlens_k[1:].copy_(
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
)
# Per-request accept lengths for the first `bs` requests determine the query
# side: their maximum and their running sum (again written into slots 1..bs).
accept_length = spec_info.accept_length[:bs]
metadata.max_seq_len_q = accept_length.max().item()
metadata.cu_seqlens_q[1:].copy_(
torch.cumsum(accept_length, dim=0, dtype=torch.int32)
)
# ceil(max_seq_len_k / page_size): number of page-table columns in use.
max_seq_pages = (
metadata.max_seq_len_k + self.page_size - 1
) // self.page_size
# For every request row, gather the req_to_token entries at the first
# max_seq_pages page-start positions (0, page_size, 2 * page_size, ...).
page_indices = self.req_to_token[
req_pool_indices[:, None],
self.draft_extend_metadata["strided_indices"][:max_seq_pages],
]
# In-place floor division of the gathered values by page_size (presumably
# turning token slot indices into page indices -- confirm).
page_indices //= self.page_size
# Only the first max_seq_pages columns are overwritten; columns beyond that
# keep whatever they held before.
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
if encoder_lens is not None:
# Only support encoder size 1 for now