[FA3 Attn Backend] Remove Unnecessary Device Sync for FA3 (#4745)
Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
@@ -29,11 +29,11 @@ class FlashAttentionMetadata:
     cu_seqlens_q: torch.Tensor = None
     cu_seqlens_k: torch.Tensor = None
+    max_seq_len_q: int = 0
     max_seq_len_k: int = 0
     window_size: tuple = (-1, -1)
     page_table: torch.Tensor = None
     cache_seqlens_int32: torch.Tensor = None
-    max_seq_len_q: int = 0


 class FlashAttentionBackend(AttentionBackend):
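For context (not part of the commit): the "device sync" in the title comes from reading scalar values out of CUDA tensors. A minimal standalone sketch of the pattern, assuming a CUDA device is available:

import torch

seq_lens = torch.randint(1, 4096, (256,), device="cuda")
seq_lens_cpu = seq_lens.cpu()  # one explicit transfer, done ahead of time

# Synchronizing: .item() blocks the host until all queued GPU work drains.
max_len_sync = seq_lens.max().item()

# Non-synchronizing: the max is taken on the host-side copy, so the CPU
# can keep enqueueing kernels while the GPU runs.
max_len_async = seq_lens_cpu.max().item()
assert max_len_sync == max_len_async

This is why the hunks below swap reads of forward_batch.seq_lens and friends for their precomputed *_cpu counterparts.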
@@ -63,7 +63,6 @@ class FlashAttentionBackend(AttentionBackend):
         # Create metadata based on forward mode
         metadata = FlashAttentionMetadata()

-        extend_seq_lens = forward_batch.extend_seq_lens
         # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
         # Precompute int32 version of sequence lengths
@@ -85,15 +84,16 @@ class FlashAttentionBackend(AttentionBackend):
                 0, batch_size + 1, dtype=torch.int32, device=device
             )
         else:
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens)
             # Precompute cumulative sequence lengths
-            if not extend_no_prefix:
+            if any(forward_batch.extend_prefix_lens_cpu):
+                extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
-                metadata.max_seq_len_q = seqlens_in_batch.max().item()
+                metadata.max_seq_len_q = metadata.max_seq_len_k
         self.forward_metadata = metadata

     def forward_extend(
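The replaced condition is the fix in miniature: any(forward_batch.extend_prefix_lens) iterates a CUDA tensor, syncing once per element, and seqlens_in_batch.max().item() stalls the host, whereas the *_cpu fields are already host-side, so any(...) and Python's built-in max(...) cost the GPU nothing. For illustration (not part of the commit), the layout produced by the pad(cumsum(...)) idiom used above:

import torch

extend_seq_lens = torch.tensor([3, 5, 2], dtype=torch.int32, device="cuda")
cu_seqlens_q = torch.nn.functional.pad(
    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
# cu_seqlens_q == [0, 3, 8, 10]: sequence i occupies tokens
# cu_seqlens_q[i] : cu_seqlens_q[i + 1] of the packed batch, the
# cumulative-length format FlashAttention kernels expect.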
@@ -274,20 +274,26 @@ class FlashAttentionBackend(AttentionBackend):
+        seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Initialize forward metadata for replaying CUDA graph."""
-        seqlens_in_batch = seq_lens[:bs]
         metadata = self.decode_cuda_graph_metadata[bs]
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+
+        # For CPU operations
+        max_len = seq_lens_cpu[:bs].max().item()
+        metadata.max_seq_len_k = max_len
+
+        # For GPU operations
+        seq_lens_in_batch = seq_lens[:bs]
+        metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
         metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
         )
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = seqlens_in_batch.max().item()

         # Only zero out the part out of max_len_k
         metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
+        # Then do the copy
         metadata.page_table[:, : metadata.max_seq_len_k].copy_(
             self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
         )

         self.forward_decode_metadata = metadata

     def get_cuda_graph_seq_len_fill_value(self):
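A note on the split this hunk draws, as a sketch (illustration only; the shapes and the arange fill are made up): max_seq_len_k is used as a Python slice bound, so it must live on the host. Deriving it from seq_lens_cpu rather than from the CUDA tensor is the sync removal; the int32 cast, cumsum, and page-table update stay on the GPU and never make the host wait.

import torch

bs, max_pages = 4, 128
page_table = torch.zeros(bs, max_pages, dtype=torch.int32, device="cuda")
req_to_token = torch.arange(
    bs * max_pages, dtype=torch.int32, device="cuda"
).view(bs, max_pages)
seq_lens_cpu = torch.tensor([7, 19, 3, 11])

max_seq_len_k = seq_lens_cpu.max().item()  # host-side read: no device sync
# Zero only the stale tail, then copy just the live prefix -- cheaper
# than rewriting the whole table on every graph replay.
page_table[:, max_seq_len_k:].fill_(0)
page_table[:, :max_seq_len_k].copy_(req_to_token[:, :max_seq_len_k])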
@@ -1376,6 +1376,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if (
             global_server_args_dict["enable_flashinfer_mla"]
             or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
         ):
             decode_seq_lens = self.seq_lens.cpu()
         else:
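This last hunk is what makes the host-side lengths available in the first place: the scheduler pays one explicit device-to-host copy per batch when the attention backend wants them. A sketch of the trade, with hypothetical names (only attention_backend == "fa3" and seq_lens come from the diff):

import torch

class Batch:
    def __init__(self, seq_lens: torch.Tensor, backend: str):
        self.seq_lens = seq_lens  # stays on the GPU for the kernels
        self.seq_lens_cpu = None
        if backend in ("fa3", "flashinfer_mla", "flashmla"):
            # One explicit sync here, amortized over every later read
            # that would otherwise call .item()/.max() on the CUDA tensor.
            self.seq_lens_cpu = seq_lens.cpu()

batch = Batch(torch.tensor([7, 19, 3], device="cuda"), backend="fa3")
print(batch.seq_lens_cpu.max().item())  # host-side; no further sync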