perf: optimize local_block_table memory allocation (#6273)
This commit is contained in:
@@ -1165,7 +1165,6 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
max_virtual_batches = max_bs * (
|
||||
(max_seq_len + attn_chunk_size - 1) // attn_chunk_size
|
||||
)
|
||||
- max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
|
||||
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
|
||||
|
||||
self.decode_cuda_graph_local_attn_metadata = {
|
||||
@@ -1177,7 +1176,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
),
|
||||
"local_block_table": torch.zeros(
|
||||
max_virtual_batches,
|
||||
- max_blocks_per_seq * max_pages_per_block,
|
||||
+ max_pages_per_block,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user