Support EAGLE draft extend CUDA graph (#6606)
Co-authored-by: Sehoon Kim <sehoonkim@berkeley.edu>
This commit is contained in:
@@ -1268,6 +1268,29 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
),
|
||||
}
|
||||
|
||||
self.draft_extend_metadata = {
|
||||
"cache_seqlens": torch.zeros(
|
||||
max_bs, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"cu_seqlens_q": torch.zeros(
|
||||
max_bs + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"cu_seqlens_k": torch.zeros(
|
||||
max_bs + 1, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
(self.max_context_len + self.page_size - 1) // self.page_size,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"strided_indices": torch.arange(
|
||||
0, self.max_context_len, self.page_size, device=self.device
|
||||
),
|
||||
}
|
||||
|
||||
if self.topk > 1:
|
||||
self.target_verify_metadata_topk_normal = {
|
||||
"cache_seqlens": torch.zeros(
|
||||
@@ -1508,6 +1531,32 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
self.target_verify_metadata_topk_normal[bs] = metadata
|
||||
self.target_verify_metadata_topk_expand[bs] = metadata_expand
|
||||
elif forward_mode.is_draft_extend():
|
||||
metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
|
||||
:bs
|
||||
]
|
||||
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
|
||||
|
||||
num_tokens_per_bs = num_tokens // bs
|
||||
metadata.max_seq_len_q = num_tokens_per_bs
|
||||
metadata.max_seq_len_k = seq_lens.max().item()
|
||||
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
bs * num_tokens_per_bs + 1,
|
||||
num_tokens_per_bs,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
|
||||
: (bs + 1)
|
||||
]
|
||||
metadata.page_table = self.draft_extend_metadata["page_table"][
|
||||
req_pool_indices, :
|
||||
]
|
||||
|
||||
self.draft_extend_metadata[bs] = metadata
|
||||
|
||||
if encoder_lens is not None:
|
||||
encoder_bs = encoder_lens.numel()
|
||||
@@ -1732,6 +1781,29 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata_expand.max_seq_len_k = (
|
||||
metadata_expand.cache_seqlens_int32.max().item()
|
||||
)
|
||||
elif forward_mode.is_draft_extend():
|
||||
metadata = self.draft_extend_metadata[bs]
|
||||
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
|
||||
|
||||
metadata.max_seq_len_k = seq_lens_cpu.max().item()
|
||||
metadata.cu_seqlens_k[1:].copy_(
|
||||
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
|
||||
)
|
||||
accept_length = spec_info.accept_length[:bs]
|
||||
metadata.max_seq_len_q = accept_length.max().item()
|
||||
metadata.cu_seqlens_q[1:].copy_(
|
||||
torch.cumsum(accept_length, dim=0, dtype=torch.int32)
|
||||
)
|
||||
|
||||
max_seq_pages = (
|
||||
metadata.max_seq_len_k + self.page_size - 1
|
||||
) // self.page_size
|
||||
page_indices = self.req_to_token[
|
||||
req_pool_indices[:, None],
|
||||
self.draft_extend_metadata["strided_indices"][:max_seq_pages],
|
||||
]
|
||||
page_indices //= self.page_size
|
||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||
|
||||
if encoder_lens is not None:
|
||||
# Only support encoder size 1 for now
|
||||
|
||||
Reference in New Issue
Block a user