Fix fa3 cuda graph page_size > 1 precision and page_size=1 speed (#4855)
This commit is contained in:
@@ -322,10 +322,13 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata.page_table = self.req_to_token[
|
max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
|
||||||
:, self.decode_cuda_graph_metadata["strided_indices"]
|
page_indices = self.req_to_token[
|
||||||
|
:, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
|
||||||
]
|
]
|
||||||
metadata.page_table = metadata.page_table[req_pool_indices[:bs]]
|
page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
|
||||||
|
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||||
|
metadata.page_table[:, max_seq_pages:].fill_(0)
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user