optimize pad operations in fa3 to accelerate 100+us (#6077)
This commit is contained in:
@@ -1525,12 +1525,9 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
|
metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
|
||||||
self.speculative_step_id + 1
|
self.speculative_step_id + 1
|
||||||
)
|
)
|
||||||
metadata.cu_seqlens_k.copy_(
|
metadata.cu_seqlens_k[1:].copy_(
|
||||||
torch.nn.functional.pad(
|
torch.cumsum(
|
||||||
torch.cumsum(
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
|
||||||
),
|
|
||||||
(1, 0),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1554,12 +1551,9 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
# metadata.max_seq_len_q = self.topk, already set in capture
|
# metadata.max_seq_len_q = self.topk, already set in capture
|
||||||
metadata.max_seq_len_k = seq_lens_cpu.max().item()
|
metadata.max_seq_len_k = seq_lens_cpu.max().item()
|
||||||
# metadata.cu_seqlens_q already set in capture
|
# metadata.cu_seqlens_q already set in capture
|
||||||
metadata.cu_seqlens_k.copy_(
|
metadata.cu_seqlens_k[1:].copy_(
|
||||||
torch.nn.functional.pad(
|
torch.cumsum(
|
||||||
torch.cumsum(
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
|
||||||
),
|
|
||||||
(1, 0),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1616,13 +1610,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.max_seq_len_k = (
|
metadata.max_seq_len_k = (
|
||||||
seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
|
seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
|
||||||
)
|
)
|
||||||
metadata.cu_seqlens_k.copy_(
|
metadata.cu_seqlens_k[1:].copy_(
|
||||||
torch.nn.functional.pad(
|
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
|
||||||
torch.cumsum(
|
|
||||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
|
||||||
),
|
|
||||||
(1, 0),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
max_seq_pages = (
|
max_seq_pages = (
|
||||||
metadata.max_seq_len_k + self.page_size - 1
|
metadata.max_seq_len_k + self.page_size - 1
|
||||||
@@ -1641,13 +1630,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
# metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
|
# metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
|
||||||
metadata.max_seq_len_k = seq_lens_cpu.max().item()
|
metadata.max_seq_len_k = seq_lens_cpu.max().item()
|
||||||
# metadata.cu_seqlens_q already set in capture
|
# metadata.cu_seqlens_q already set in capture
|
||||||
metadata.cu_seqlens_k.copy_(
|
metadata.cu_seqlens_k[1:].copy_(
|
||||||
torch.nn.functional.pad(
|
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
|
||||||
torch.cumsum(
|
|
||||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
|
||||||
),
|
|
||||||
(1, 0),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
page_table = self.req_to_token[
|
page_table = self.req_to_token[
|
||||||
req_pool_indices, : metadata.max_seq_len_k
|
req_pool_indices, : metadata.max_seq_len_k
|
||||||
@@ -1705,14 +1689,11 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata_expand.cache_seqlens_int32.copy_(
|
metadata_expand.cache_seqlens_int32.copy_(
|
||||||
mask.sum(dim=1).to(torch.int32)
|
mask.sum(dim=1).to(torch.int32)
|
||||||
)
|
)
|
||||||
metadata_expand.cu_seqlens_k.copy_(
|
metadata_expand.cu_seqlens_k[1:].copy_(
|
||||||
torch.nn.functional.pad(
|
torch.cumsum(
|
||||||
torch.cumsum(
|
metadata_expand.cache_seqlens_int32,
|
||||||
metadata_expand.cache_seqlens_int32,
|
dim=0,
|
||||||
dim=0,
|
dtype=torch.int32,
|
||||||
dtype=torch.int32,
|
|
||||||
),
|
|
||||||
(1, 0),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
metadata_expand.max_seq_len_k = (
|
metadata_expand.max_seq_len_k = (
|
||||||
@@ -1723,11 +1704,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
# Only support encoder size 1 for now
|
# Only support encoder size 1 for now
|
||||||
metadata.encoder_max_seq_len_k = encoder_lens[0]
|
metadata.encoder_max_seq_len_k = encoder_lens[0]
|
||||||
metadata.encoder_lens_int32.copy_(encoder_lens[:1])
|
metadata.encoder_lens_int32.copy_(encoder_lens[:1])
|
||||||
metadata.encoder_cu_seqlens_k.copy_(
|
metadata.encoder_cu_seqlens_k[1:].copy_(
|
||||||
torch.nn.functional.pad(
|
torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32)
|
||||||
torch.cumsum(metadata.encoder_lens_int32, dim=0, dtype=torch.int32),
|
|
||||||
(1, 0),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
|
metadata.encoder_page_table[:, : metadata.encoder_max_seq_len_k].copy_(
|
||||||
|
|||||||
Reference in New Issue
Block a user