Performance: Enable cuda graph for dp idle batch (#7269)
Co-authored-by: austindeng <austindeng@tencent.com> Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -1704,14 +1704,15 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
||||
metadata_expand = self.target_verify_metadata_topk_expand[bs]
|
||||
|
||||
# metadata_expand.max_seq_len_q = 1, already set in capture
|
||||
# metadata_expand.cu_seqlens_q already set in capture
|
||||
|
||||
offsets = torch.arange(
|
||||
self.speculative_num_draft_tokens, device=device
|
||||
).unsqueeze(
|
||||
0
|
||||
) # shape: (1, self.speculative_num_draft_tokens)
|
||||
|
||||
cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
|
||||
cum_len = torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
@@ -1728,17 +1729,20 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
).view(1, -1)
|
||||
# avoid extracting padded seq indices which will be out of boundary
|
||||
mask_extraction_indices[
|
||||
:, spec_info.positions.numel() * self.speculative_num_draft_tokens :
|
||||
:,
|
||||
spec_info.positions.numel() * self.speculative_num_draft_tokens :,
|
||||
].fill_(0)
|
||||
|
||||
mask = spec_info.custom_mask[mask_extraction_indices].view(
|
||||
-1, self.speculative_num_draft_tokens
|
||||
) # (bsz * draft_num, draft_num)
|
||||
|
||||
col_indices = offsets.expand(
|
||||
mask.shape[0], self.speculative_num_draft_tokens
|
||||
)
|
||||
keys = torch.where(
|
||||
mask, col_indices, col_indices + self.speculative_num_draft_tokens
|
||||
mask,
|
||||
col_indices,
|
||||
col_indices + self.speculative_num_draft_tokens,
|
||||
)
|
||||
_, sort_order = torch.sort(keys, dim=1)
|
||||
|
||||
@@ -1747,6 +1751,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
.gather(1, cols)
|
||||
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
|
||||
) # (bsz, draft_num)
|
||||
|
||||
metadata_expand.page_table.copy_(
|
||||
non_masked_page_table.gather(1, sort_order)
|
||||
)
|
||||
@@ -1758,6 +1763,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
dtype=torch.int32,
|
||||
)
|
||||
)
|
||||
|
||||
elif forward_mode.is_draft_extend():
|
||||
metadata = self.draft_extend_metadata[bs]
|
||||
metadata.cache_seqlens_int32.copy_(seq_lens)
|
||||
@@ -1767,7 +1773,11 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
|
||||
)
|
||||
accept_length = spec_info.accept_length[:bs]
|
||||
metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
|
||||
if spec_info.accept_length_cpu:
|
||||
metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
|
||||
else:
|
||||
metadata.max_seq_len_q = 1
|
||||
|
||||
metadata.cu_seqlens_q[1:].copy_(
|
||||
torch.cumsum(accept_length, dim=0, dtype=torch.int32)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user