perf: optimize local_block_table memory allocation (#6273)
This commit is contained in:
@@ -1165,7 +1165,6 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
max_virtual_batches = max_bs * (
|
max_virtual_batches = max_bs * (
|
||||||
(max_seq_len + attn_chunk_size - 1) // attn_chunk_size
|
(max_seq_len + attn_chunk_size - 1) // attn_chunk_size
|
||||||
)
|
)
|
||||||
max_blocks_per_seq = (max_seq_len + attn_chunk_size - 1) // attn_chunk_size
|
|
||||||
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
|
max_pages_per_block = (attn_chunk_size + page_size - 1) // page_size
|
||||||
|
|
||||||
self.decode_cuda_graph_local_attn_metadata = {
|
self.decode_cuda_graph_local_attn_metadata = {
|
||||||
@@ -1177,7 +1176,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
),
|
),
|
||||||
"local_block_table": torch.zeros(
|
"local_block_table": torch.zeros(
|
||||||
max_virtual_batches,
|
max_virtual_batches,
|
||||||
max_blocks_per_seq * max_pages_per_block,
|
max_pages_per_block,
|
||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
),
|
),
|
||||||
|
|||||||
Reference in New Issue
Block a user