Minor cleanup of fa3 backend (#6999)
@@ -1469,7 +1469,7 @@ class FlashAttentionBackend(AttentionBackend):
                     "cache_seqlens"
                 ][:bs]
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                    (seq_lens + self.speculative_num_draft_tokens)
                 )
 
                 metadata.max_seq_len_q = self.speculative_num_draft_tokens
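Note on the change above: `torch.Tensor.copy_` casts the source to the destination's dtype on its own, so an explicit `.to(torch.int32)` inside a `copy_` call is redundant and only allocates a throwaway temporary. A minimal sketch (values are hypothetical):

```python
import torch

cache_seqlens_int32 = torch.zeros(4, dtype=torch.int32)
seq_lens = torch.tensor([7, 12, 3, 9], dtype=torch.int64)

# copy_ performs the int64 -> int32 conversion implicitly
cache_seqlens_int32.copy_(seq_lens + 4)
assert cache_seqlens_int32.dtype == torch.int32
```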
@@ -1536,7 +1536,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                 :bs
             ]
-            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+            metadata.cache_seqlens_int32.copy_(seq_lens)
 
             num_tokens_per_bs = num_tokens // bs
             metadata.max_seq_len_q = num_tokens_per_bs
@@ -1600,38 +1600,32 @@ class FlashAttentionBackend(AttentionBackend):
             if spec_info is not None:
                 # Draft Decode
                 if self.topk <= 1:
-                    metadata = self.decode_cuda_graph_metadata[bs]
                     # When topk = 1, we use the normal decode metadata
-                    metadata.cache_seqlens_int32.copy_(
-                        (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
-                    )
-
-                    metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-                        self.speculative_step_id + 1
-                    )
-                    metadata.cu_seqlens_k[1:].copy_(
-                        torch.cumsum(
-                            metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                        )
-                    )
+                    metadata = self.decode_cuda_graph_metadata[bs]
+                    max_len = seq_lens_cpu.max().item()
+                    metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
 
                     max_seq_pages = (
                         metadata.max_seq_len_k + self.page_size - 1
                     ) // self.page_size
-                    page_indices = self.req_to_token[
-                        req_pool_indices[:, None],
-                        self.decode_cuda_graph_metadata["strided_indices"][
-                            :max_seq_pages
-                        ],
-                    ]
 
-                    page_indices //= self.page_size
-                    metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                    normal_decode_set_medadata(
+                        metadata.cache_seqlens_int32,
+                        metadata.cu_seqlens_k,
+                        metadata.page_table,
+                        self.req_to_token,
+                        req_pool_indices,
+                        self.decode_cuda_graph_metadata["strided_indices"],
+                        max_seq_pages,
+                        seq_lens,
+                        self.speculative_step_id + 1,
+                        self.page_size,
+                    )
 
                 else:
                     # When top k > 1, we need two specific draft decode metadata, and then merge states
                     # 1. The first half of metadata for prefix tokens
                     metadata = self.draft_decode_metadata_topk_normal[bs]
-                    metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                    metadata.cache_seqlens_int32.copy_(seq_lens)
                     # metadata.max_seq_len_q = self.topk, already set in capture
                     metadata.max_seq_len_k = seq_lens_cpu.max().item()
                     # metadata.cu_seqlens_q already set in capture
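The replay path above rebuilds the key-length metadata in place. For varlen attention kernels, `cu_seqlens` is the exclusive prefix sum of per-request lengths: entry 0 is always 0, which is why only the `[1:]` slice is rewritten on each replay. A small self-contained sketch with hypothetical lengths:

```python
import torch

# cu_seqlens_k[i] is the offset where request i's keys start; the leading 0 is
# fixed at capture time, so replay only rewrites the tail in place.
cache_seqlens = torch.tensor([5, 8, 3], dtype=torch.int32)  # hypothetical lengths
cu_seqlens_k = torch.zeros(4, dtype=torch.int32)
cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32))
# cu_seqlens_k is now [0, 5, 13, 16]
```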
@@ -1654,7 +1648,7 @@ class FlashAttentionBackend(AttentionBackend):
                     self.speculative_num_steps, -1
                 ).T.contiguous()
                 metadata_expand.page_table[: cache_loc.shape[0]].copy_(
-                    cache_loc[:, :decode_length].contiguous().to(torch.int32)
+                    cache_loc[:, :decode_length]
                 )
             # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
         else:
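Same idea as the earlier dtype cleanup, with one addition: `copy_` also accepts non-contiguous sources, so the `.contiguous()` call can be dropped along with the cast. A minimal sketch:

```python
import torch

dst = torch.zeros(2, 3, dtype=torch.int32)
src = torch.arange(12, dtype=torch.int64).reshape(2, 6)[:, :3]  # non-contiguous view

dst.copy_(src)  # handles both the strided layout and the int64 -> int32 cast
assert dst[1, 2].item() == 8
```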
@@ -1665,12 +1659,15 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata.max_seq_len_k = max_len
 
                 normal_decode_set_medadata(
-                    metadata,
+                    metadata.cache_seqlens_int32,
+                    metadata.cu_seqlens_k,
+                    metadata.page_table,
                     self.req_to_token,
                     req_pool_indices,
                     self.decode_cuda_graph_metadata["strided_indices"],
                     max_seq_pages,
                     seq_lens,
+                    0,
                     self.page_size,
                 )
 
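With the hunk above, the normal-decode and draft-decode replay paths share one helper and differ only in the new `seq_len_delta` argument: normal decode passes `0`, while draft decode at `speculative_step_id = i` passes `i + 1` for the tokens appended so far. A minimal sketch of the shared update (the tensors are stand-ins for the captured metadata buffers):

```python
import torch

def set_cache_seqlens(cache_seqlens_int32, seq_lens, seq_len_delta):
    # normal decode: seq_len_delta = 0
    # draft decode at speculative_step_id = i: seq_len_delta = i + 1
    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)

buf = torch.zeros(3, dtype=torch.int32)
seq_lens = torch.tensor([4, 6, 5])
set_cache_seqlens(buf, seq_lens, 0)  # decode replay -> [4, 6, 5]
set_cache_seqlens(buf, seq_lens, 2)  # draft step 1  -> [6, 8, 7]
```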
@@ -1679,7 +1676,7 @@ class FlashAttentionBackend(AttentionBackend):
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                    (seq_lens + self.speculative_num_draft_tokens)
                 )
 
                 metadata.max_seq_len_k = (
@@ -1701,7 +1698,7 @@ class FlashAttentionBackend(AttentionBackend):
                 # When topk > 1, we need two specific target verify metadata, and then merge states
                 # 1. The first half of metadata for prefix tokens
                 metadata = self.target_verify_metadata_topk_normal[bs]
-                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                metadata.cache_seqlens_int32.copy_(seq_lens)
                 # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
@@ -1761,9 +1758,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.page_table.copy_(
                     non_masked_page_table.gather(1, sort_order)
                 )
-                metadata_expand.cache_seqlens_int32.copy_(
-                    mask.sum(dim=1).to(torch.int32)
-                )
+                metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))
                 metadata_expand.cu_seqlens_k[1:].copy_(
                     torch.cumsum(
                         metadata_expand.cache_seqlens_int32,
@@ -1776,14 +1771,14 @@ class FlashAttentionBackend(AttentionBackend):
             )
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
-            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+            metadata.cache_seqlens_int32.copy_(seq_lens)
 
             metadata.max_seq_len_k = seq_lens_cpu.max().item()
             metadata.cu_seqlens_k[1:].copy_(
                 torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
             )
             accept_length = spec_info.accept_length[:bs]
-            metadata.max_seq_len_q = accept_length.max().item()
+            metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
             metadata.cu_seqlens_q[1:].copy_(
                 torch.cumsum(accept_length, dim=0, dtype=torch.int32)
             )
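On the `max_seq_len_q` change above: `accept_length.max().item()` forces a device-to-host synchronization on the GPU tensor, while reading the CPU-side `accept_length_cpu` list that the spec-decode path already tracks costs nothing. The `+ 1` presumably accounts for the extra token beyond the accepted draft tokens; hypothetical numbers:

```python
# accept_length (GPU) and accept_length_cpu (host list) hold the same values;
# max() over the Python list needs no device synchronization.
accept_length_cpu = [2, 0, 3, 1]            # hypothetical per-request accepted tokens
max_seq_len_q = max(accept_length_cpu) + 1  # 4, computed without touching the GPU
```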
@@ -1795,8 +1790,7 @@ class FlashAttentionBackend(AttentionBackend):
                 req_pool_indices[:, None],
                 self.draft_extend_metadata["strided_indices"][:max_seq_pages],
             ]
-            page_indices //= self.page_size
-            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
 
         if encoder_lens is not None:
             # Only support encoder size 1 for now
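The hunk above folds the division into the `copy_`, so `page_indices` is no longer mutated in place first. The underlying mapping from token slots to KV-cache pages is unchanged; a self-contained sketch with made-up sizes:

```python
import torch

# req_to_token stores a flat token-slot id per (request, position); probing it
# once per page (every page_size positions) and dividing by page_size turns
# token slots into page ids.
page_size = 4
req_to_token = torch.arange(64, dtype=torch.int32).reshape(2, 32)  # toy slot table
req_pool_indices = torch.tensor([0, 1])
strided_indices = torch.arange(0, 32, page_size)  # one probe per page

page_indices = req_to_token[req_pool_indices[:, None], strided_indices[None, :]]
page_table = page_indices // page_size  # page id for each (request, page)
```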
@@ -2045,6 +2039,8 @@ class FlashAttentionMultiStepBackend:
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
 
         for i in range(self.speculative_num_steps - 1):
+            # TODO: incrementally update the metadata for the later steps,
+            # so that they do not need to recompute everything from scratch.
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
                 forward_batch.req_pool_indices,
@@ -2058,21 +2054,25 @@ class FlashAttentionMultiStepBackend:
             )
 
 
-@torch.compile(dynamic=True, backend=get_compiler_backend())
+# @torch.compile(dynamic=True, backend=get_compiler_backend())
+# TODO: fuse these kernels
+# NOTE: torch.compile makes it slower in speculative decoding
 def normal_decode_set_medadata(
-    metadata,
-    req_to_token,
-    req_pool_indices,
-    strided_indices,
-    max_seq_pages,
-    seq_lens,
-    page_size,
+    cache_seqlens_int32: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    page_table: torch.Tensor,
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    strided_indices: torch.Tensor,
+    max_seq_pages: torch.Tensor,
+    seq_lens: torch.Tensor,
+    seq_len_delta: int,
+    page_size: int,
 ):
-    metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-    metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
+    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
+    cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
     page_indices = req_to_token[
         req_pool_indices[:, None],
         strided_indices[:max_seq_pages][None, :],
     ]
-    metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
-    metadata.page_table[:, max_seq_pages:].fill_(0)
+    page_table[:, :max_seq_pages].copy_(page_indices // page_size)
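The signature change above is the heart of the cleanup: the old helper rebound `metadata.cache_seqlens_int32` to a freshly allocated tensor, but a replayed CUDA graph only observes writes into the buffers that were captured, not attribute rebinding. Passing the captured tensors explicitly and mutating them with `copy_` keeps the helper graph-safe (and, per the new TODO, easier to fuse later). A small illustration of the difference, under the assumption that `captured` plays the role of a graph-captured buffer:

```python
import torch

class Metadata:
    def __init__(self):
        self.cache_seqlens_int32 = torch.zeros(3, dtype=torch.int32)

meta = Metadata()
captured = meta.cache_seqlens_int32  # what a CUDA graph would capture
seq_lens = torch.tensor([4, 6, 5])

meta.cache_seqlens_int32 = seq_lens.to(torch.int32)  # old style: rebinds the attribute
print(captured)  # still [0, 0, 0] -- the captured buffer never changed

captured.copy_(seq_lens)  # new style: in-place write
print(captured)  # [4, 6, 5] -- visible to the replayed graph
```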
@@ -920,19 +920,18 @@ def fast_mla_decode_plan(
     self._page_size = page_size
     self._sm_scale = sm_scale
 
-    with self.device as device:
-        try:
-            # Standard version with just the required arguments (no use_profiler)
-            self._cached_module.plan.default(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                qo_indptr_cpu,
-                kv_indptr_cpu,
-                kv_len_arr_cpu,
-                num_heads,
-                head_dim_ckv,
-                causal,
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in alternate MLA plan: {e}")
+    try:
+        # Standard version with just the required arguments (no use_profiler)
+        self._cached_module.plan.default(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_cpu,
+            kv_indptr_cpu,
+            kv_len_arr_cpu,
+            num_heads,
+            head_dim_ckv,
+            causal,
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in alternate MLA plan: {e}")