perf: optimize local_block_table memory allocation (#6273)

This commit is contained in:
Chang Su
2025-05-13 17:18:38 -07:00
committed by GitHub
parent 0f75b907c6
commit 912788c095

View File

@@ -1165,7 +1165,6 @@ class FlashAttentionBackend(AttentionBackend):
max_virtual_batches = max_bs * (
(max_seq_len + attn_chunk_size - 1) // attn_chunk_size
)
- max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
self.decode_cuda_graph_local_attn_metadata = {
@@ -1177,7 +1176,7 @@ class FlashAttentionBackend(AttentionBackend):
),
"local_block_table": torch.zeros(
max_virtual_batches,
- max_blocks_per_seq * max_pages_per_block,
+ max_pages_per_block,
dtype=torch.int32,
device=self.device,
),