Optimize a pad operation to save 25us (#5945)
This commit is contained in:
@@ -1587,8 +1587,9 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.max_seq_len_k = max_len
|
metadata.max_seq_len_k = max_len
|
||||||
|
|
||||||
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
|
metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
|
||||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
# Optimize cumulative sequence length calculation
|
||||||
torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
|
metadata.cu_seqlens_k[1:].copy_(
|
||||||
|
torch.cumsum(seq_lens, dim=0, dtype=torch.int32)
|
||||||
)
|
)
|
||||||
|
|
||||||
max_seq_pages = (
|
max_seq_pages = (
|
||||||
|
|||||||
Reference in New Issue
Block a user