[FA3 Attn Backend] Remove Unnecessary Device Sync for FA3 (#4745)
Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
This commit is contained in:
@@ -29,11 +29,11 @@ class FlashAttentionMetadata:
|
|||||||
|
|
||||||
cu_seqlens_q: torch.Tensor = None
|
cu_seqlens_q: torch.Tensor = None
|
||||||
cu_seqlens_k: torch.Tensor = None
|
cu_seqlens_k: torch.Tensor = None
|
||||||
|
max_seq_len_q: int = 0
|
||||||
max_seq_len_k: int = 0
|
max_seq_len_k: int = 0
|
||||||
window_size: tuple = (-1, -1)
|
window_size: tuple = (-1, -1)
|
||||||
page_table: torch.Tensor = None
|
page_table: torch.Tensor = None
|
||||||
cache_seqlens_int32: torch.Tensor = None
|
cache_seqlens_int32: torch.Tensor = None
|
||||||
max_seq_len_q: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionBackend(AttentionBackend):
|
class FlashAttentionBackend(AttentionBackend):
|
||||||
@@ -63,7 +63,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
# Create metadata based on forward mode
|
# Create metadata based on forward mode
|
||||||
metadata = FlashAttentionMetadata()
|
metadata = FlashAttentionMetadata()
|
||||||
|
|
||||||
extend_seq_lens = forward_batch.extend_seq_lens
|
|
||||||
# Get sequence information
|
# Get sequence information
|
||||||
seqlens_in_batch = forward_batch.seq_lens
|
seqlens_in_batch = forward_batch.seq_lens
|
||||||
# Precompute int32 version of sequence lengths
|
# Precompute int32 version of sequence lengths
|
||||||
@@ -85,15 +84,16 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
0, batch_size + 1, dtype=torch.int32, device=device
|
0, batch_size + 1, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
extend_no_prefix = not any(forward_batch.extend_prefix_lens)
|
|
||||||
# Precompute cumulative sequence lengths
|
# Precompute cumulative sequence lengths
|
||||||
if not extend_no_prefix:
|
if any(forward_batch.extend_prefix_lens_cpu):
|
||||||
|
extend_seq_lens = forward_batch.extend_seq_lens
|
||||||
metadata.cu_seqlens_q = torch.nn.functional.pad(
|
metadata.cu_seqlens_q = torch.nn.functional.pad(
|
||||||
torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
||||||
)
|
)
|
||||||
|
metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
|
||||||
else:
|
else:
|
||||||
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
metadata.cu_seqlens_q = metadata.cu_seqlens_k
|
||||||
metadata.max_seq_len_q = seqlens_in_batch.max().item()
|
metadata.max_seq_len_q = metadata.max_seq_len_k
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
def forward_extend(
|
def forward_extend(
|
||||||
@@ -274,20 +274,26 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
seq_lens_cpu: Optional[torch.Tensor],
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
):
|
):
|
||||||
# Initialize forward metadata for replaying the CUDA graph.
|
# Initialize forward metadata for replaying the CUDA graph.
|
||||||
seqlens_in_batch = seq_lens[:bs]
|
|
||||||
metadata = self.decode_cuda_graph_metadata[bs]
|
metadata = self.decode_cuda_graph_metadata[bs]
|
||||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
|
||||||
|
# Host side: take the max length from seq_lens_cpu (presumably a CPU tensor, so .item() should not force a device sync)
|
||||||
|
max_len = seq_lens_cpu[:bs].max().item()
|
||||||
|
metadata.max_seq_len_k = max_len
|
||||||
|
|
||||||
|
# For GPU operations
|
||||||
|
seq_lens_in_batch = seq_lens[:bs]
|
||||||
|
metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
|
||||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||||
)
|
)
|
||||||
# Precompute maximum sequence length
|
|
||||||
metadata.max_seq_len_k = seqlens_in_batch.max().item()
|
|
||||||
# Zero out only the page_table columns at or beyond max_seq_len_k
|
# Zero out only the page_table columns at or beyond max_seq_len_k
|
||||||
metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
|
metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
|
||||||
# Then do the copy
|
# Then do the copy
|
||||||
metadata.page_table[:, : metadata.max_seq_len_k].copy_(
|
metadata.page_table[:, : metadata.max_seq_len_k].copy_(
|
||||||
self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
|
self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.forward_decode_metadata = metadata
|
self.forward_decode_metadata = metadata
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
|||||||
@@ -1376,6 +1376,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
if (
|
if (
|
||||||
global_server_args_dict["enable_flashinfer_mla"]
|
global_server_args_dict["enable_flashinfer_mla"]
|
||||||
or global_server_args_dict["enable_flashmla"]
|
or global_server_args_dict["enable_flashmla"]
|
||||||
|
or global_server_args_dict["attention_backend"] == "fa3"
|
||||||
):
|
):
|
||||||
decode_seq_lens = self.seq_lens.cpu()
|
decode_seq_lens = self.seq_lens.cpu()
|
||||||
else:
|
else:
|
||||||
|
|||||||
Reference in New Issue
Block a user