Perormance: Enable cuda graph for dp idle batch (#7269)

Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
u4lr451
2025-06-24 08:34:13 +08:00
committed by GitHub
parent fa42e41962
commit ed0a0b692c
5 changed files with 51 additions and 50 deletions

View File

@@ -1704,14 +1704,15 @@ class FlashAttentionBackend(AttentionBackend):
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
metadata_expand = self.target_verify_metadata_topk_expand[bs]
# metadata_expand.max_seq_len_q = 1, already set in capture
# metadata_expand.cu_seqlens_q already set in capture
offsets = torch.arange(
self.speculative_num_draft_tokens, device=device
).unsqueeze(
0
) # shape: (1, self.speculative_num_draft_tokens)
cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
cum_len = torch.nn.functional.pad(
torch.cumsum(
@@ -1728,17 +1729,20 @@ class FlashAttentionBackend(AttentionBackend):
).view(1, -1)
# avoid extracting padded seq indices which will be out of boundary
mask_extraction_indices[
:, spec_info.positions.numel() * self.speculative_num_draft_tokens :
:,
spec_info.positions.numel() * self.speculative_num_draft_tokens :,
].fill_(0)
mask = spec_info.custom_mask[mask_extraction_indices].view(
-1, self.speculative_num_draft_tokens
) # (bsz * draft_num, draft_num)
col_indices = offsets.expand(
mask.shape[0], self.speculative_num_draft_tokens
)
keys = torch.where(
mask, col_indices, col_indices + self.speculative_num_draft_tokens
mask,
col_indices,
col_indices + self.speculative_num_draft_tokens,
)
_, sort_order = torch.sort(keys, dim=1)
@@ -1747,6 +1751,7 @@ class FlashAttentionBackend(AttentionBackend):
.gather(1, cols)
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
) # (bsz, draft_num)
metadata_expand.page_table.copy_(
non_masked_page_table.gather(1, sort_order)
)
@@ -1758,6 +1763,7 @@ class FlashAttentionBackend(AttentionBackend):
dtype=torch.int32,
)
)
elif forward_mode.is_draft_extend():
metadata = self.draft_extend_metadata[bs]
metadata.cache_seqlens_int32.copy_(seq_lens)
@@ -1767,7 +1773,11 @@ class FlashAttentionBackend(AttentionBackend):
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
)
accept_length = spec_info.accept_length[:bs]
metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
if spec_info.accept_length_cpu:
metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
else:
metadata.max_seq_len_q = 1
metadata.cu_seqlens_q[1:].copy_(
torch.cumsum(accept_length, dim=0, dtype=torch.int32)
)