perf: Optimize local attention memory allocation in FlashAttentionBackend (#6356)
@@ -1434,19 +1434,7 @@ class FlashAttentionBackend(AttentionBackend):
             self.decode_cuda_graph_metadata[bs] = metadata
 
             if self.attention_chunk_size is not None:
-                metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
-                    local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
-                        "local_query_start_loc"
-                    ],
-                    local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
-                        "local_seqused_k"
-                    ],
-                    local_block_table=self.decode_cuda_graph_local_attn_metadata[
-                        "local_block_table"
-                    ],
-                    local_max_query_len=1,
-                    local_max_seq_len=1,
-                )
+                self._update_local_attn_metadata_for_capture(metadata, batch_size)
 
         elif forward_mode.is_target_verify():
             if self.topk <= 1:
@@ -1807,6 +1795,62 @@ class FlashAttentionBackend(AttentionBackend):
             )
             metadata.local_attn_metadata = local_metadata
 
+    def _update_local_attn_metadata_for_capture(
+        self, metadata: FlashAttentionMetadata, bs: int
+    ):
+        """Update local attention metadata during CUDA graph capture phase.
+
+        This method calculates the exact buffer sizes needed for local attention metadata
+        during the CUDA graph capture phase, optimizing memory usage by creating views of
+        pre-allocated buffers with exactly the sizes needed.
+        """
+        seq_lens_capture = metadata.cache_seqlens_int32
+        max_seq_len = int(seq_lens_capture.max().item())
+        page_table_capture = metadata.page_table
+
+        cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
+        seqlens_np = seq_lens_capture.cpu().numpy()
+        (
+            seqlens_q_local_np,
+            cu_seqlens_q_local_np,
+            seqlens_k_local_np,
+            block_table_local_np,
+        ) = make_local_attention_virtual_batches(
+            self.attention_chunk_size,
+            cu_seqlens_q_np,
+            seqlens_np,
+            page_table_capture,
+            self.page_size,
+        )
+
+        # Get exact dimensions from the calculation
+        q_len = len(cu_seqlens_q_local_np)
+        k_len = len(seqlens_k_local_np)
+        b0 = block_table_local_np.shape[0] if block_table_local_np.shape[0] > 0 else bs
+        b1 = block_table_local_np.shape[1] if block_table_local_np.shape[1] > 0 else 1
+
+        # Create views of the pre-allocated buffers with exactly these sizes
+        # This is the key optimization - we only use the memory we actually need
+        local_query_start_loc = self.decode_cuda_graph_local_attn_metadata[
+            "local_query_start_loc"
+        ][:q_len]
+        local_seqused_k = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"][
+            :k_len
+        ]
+        local_block_table = self.decode_cuda_graph_local_attn_metadata[
+            "local_block_table"
+        ][:b0, :b1]
+
+        metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
+            local_query_start_loc=local_query_start_loc,
+            local_seqused_k=local_seqused_k,
+            local_block_table=local_block_table,
+            local_max_query_len=1,
+            local_max_seq_len=max_seq_len,
+        )
+
     def _update_local_attn_metadata_for_replay(
         self, metadata: FlashAttentionMetadata, bs: int
     ):
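
The optimization rests on a standard PyTorch property: slicing a pre-allocated tensor returns a view over the same storage, so the capture path can hand the kernel exactly-sized metadata without allocating new device memory for each captured batch size. Below is a minimal standalone sketch of that pattern; the dictionary keys mirror the diff, but MAX_VIRTUAL_BATCHES, BLOCK_TABLE_WIDTH, and capture_views are illustrative assumptions, not the backend's real API.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# Assumed worst-case bounds, standing in for the sizes the backend picks
# when it allocates decode_cuda_graph_local_attn_metadata up front.
MAX_VIRTUAL_BATCHES = 4096
BLOCK_TABLE_WIDTH = 64  # illustrative page-table width

# Allocate the worst-case buffers once, before CUDA graph capture.
buffers = {
    "local_query_start_loc": torch.zeros(
        MAX_VIRTUAL_BATCHES + 1, dtype=torch.int32, device=device
    ),
    "local_seqused_k": torch.zeros(
        MAX_VIRTUAL_BATCHES, dtype=torch.int32, device=device
    ),
    "local_block_table": torch.zeros(
        MAX_VIRTUAL_BATCHES, BLOCK_TABLE_WIDTH, dtype=torch.int32, device=device
    ),
}

def capture_views(q_len: int, k_len: int, b0: int, b1: int) -> dict:
    # Slicing returns views over the same storage: no new device memory
    # is allocated, and the captured graph records stable pointers that
    # replay can later refill in place.
    return {
        "local_query_start_loc": buffers["local_query_start_loc"][:q_len],
        "local_seqused_k": buffers["local_seqused_k"][:k_len],
        "local_block_table": buffers["local_block_table"][:b0, :b1],
    }

views = capture_views(q_len=9, k_len=8, b0=8, b1=16)
# The views share storage with the originals: same base pointer, exact shape.
assert views["local_seqused_k"].data_ptr() == buffers["local_seqused_k"].data_ptr()
assert views["local_block_table"].shape == (8, 16)

This is also why the diff pairs the capture helper with the (truncated) _update_local_attn_metadata_for_replay: replay only needs to copy fresh values into the prefixes of the same buffers, and the recorded graph reads them through the pointers fixed at capture time.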