perf: Optimize local attention memory allocation in FlashAttentionBackend (#6356)
This commit is contained in:
@@ -1434,19 +1434,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self.decode_cuda_graph_metadata[bs] = metadata
|
self.decode_cuda_graph_metadata[bs] = metadata
|
||||||
|
|
||||||
if self.attention_chunk_size is not None:
|
if self.attention_chunk_size is not None:
|
||||||
metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
self._update_local_attn_metadata_for_capture(metadata, batch_size)
|
||||||
local_query_start_loc=self.decode_cuda_graph_local_attn_metadata[
|
|
||||||
"local_query_start_loc"
|
|
||||||
],
|
|
||||||
local_seqused_k=self.decode_cuda_graph_local_attn_metadata[
|
|
||||||
"local_seqused_k"
|
|
||||||
],
|
|
||||||
local_block_table=self.decode_cuda_graph_local_attn_metadata[
|
|
||||||
"local_block_table"
|
|
||||||
],
|
|
||||||
local_max_query_len=1,
|
|
||||||
local_max_seq_len=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
if self.topk <= 1:
|
if self.topk <= 1:
|
||||||
@@ -1807,6 +1795,62 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
metadata.local_attn_metadata = local_metadata
|
metadata.local_attn_metadata = local_metadata
|
||||||
|
|
||||||
|
def _update_local_attn_metadata_for_capture(
|
||||||
|
self, metadata: FlashAttentionMetadata, bs: int
|
||||||
|
):
|
||||||
|
"""Update local attention metadata during CUDA graph capture phase.
|
||||||
|
|
||||||
|
This method calculates the exact buffer sizes needed for local attention metadata
|
||||||
|
during the CUDA graph capture phase, optimizing memory usage by creating views of
|
||||||
|
pre-allocated buffers with exactly the sizes needed.
|
||||||
|
"""
|
||||||
|
seq_lens_capture = metadata.cache_seqlens_int32
|
||||||
|
max_seq_len = int(seq_lens_capture.max().item())
|
||||||
|
page_table_capture = metadata.page_table
|
||||||
|
|
||||||
|
cu_seqlens_q_np = metadata.cu_seqlens_q.cpu().numpy()
|
||||||
|
seqlens_np = seq_lens_capture.cpu().numpy()
|
||||||
|
(
|
||||||
|
seqlens_q_local_np,
|
||||||
|
cu_seqlens_q_local_np,
|
||||||
|
seqlens_k_local_np,
|
||||||
|
block_table_local_np,
|
||||||
|
) = make_local_attention_virtual_batches(
|
||||||
|
self.attention_chunk_size,
|
||||||
|
cu_seqlens_q_np,
|
||||||
|
seqlens_np,
|
||||||
|
page_table_capture,
|
||||||
|
self.page_size,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get exact dimensions from the calculation
|
||||||
|
q_len = len(cu_seqlens_q_local_np)
|
||||||
|
k_len = len(seqlens_k_local_np)
|
||||||
|
b0 = block_table_local_np.shape[0] if block_table_local_np.shape[0] > 0 else bs
|
||||||
|
b1 = block_table_local_np.shape[1] if block_table_local_np.shape[1] > 0 else 1
|
||||||
|
|
||||||
|
# Create views of the pre-allocated buffers with exactly these sizes
|
||||||
|
# This is the key optimization - we only use the memory we actually need
|
||||||
|
local_query_start_loc = self.decode_cuda_graph_local_attn_metadata[
|
||||||
|
"local_query_start_loc"
|
||||||
|
][:q_len]
|
||||||
|
|
||||||
|
local_seqused_k = self.decode_cuda_graph_local_attn_metadata["local_seqused_k"][
|
||||||
|
:k_len
|
||||||
|
]
|
||||||
|
|
||||||
|
local_block_table = self.decode_cuda_graph_local_attn_metadata[
|
||||||
|
"local_block_table"
|
||||||
|
][:b0, :b1]
|
||||||
|
|
||||||
|
metadata.local_attn_metadata = FlashAttentionMetadata.LocalAttentionMetadata(
|
||||||
|
local_query_start_loc=local_query_start_loc,
|
||||||
|
local_seqused_k=local_seqused_k,
|
||||||
|
local_block_table=local_block_table,
|
||||||
|
local_max_query_len=1,
|
||||||
|
local_max_seq_len=max_seq_len,
|
||||||
|
)
|
||||||
|
|
||||||
def _update_local_attn_metadata_for_replay(
|
def _update_local_attn_metadata_for_replay(
|
||||||
self, metadata: FlashAttentionMetadata, bs: int
|
self, metadata: FlashAttentionMetadata, bs: int
|
||||||
):
|
):
|
||||||
|
|||||||
Reference in New Issue
Block a user