Minor cleanup of fa3 backend (#6999)
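Removes redundant .to(torch.int32) casts in front of Tensor.copy_ calls, refactors normal_decode_set_medadata to take the metadata tensors plus a seq_len_delta argument explicitly so the topk <= 1 draft-decode replay path can reuse it, computes the draft-extend max_seq_len_q from the CPU-side accept lengths, disables torch.compile on the helper (per the NOTE below, it makes speculative decoding slower), and drops the with-device wrapper around the MLA plan call.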
@@ -1469,7 +1469,7 @@ class FlashAttentionBackend(AttentionBackend):
                 "cache_seqlens"
             ][:bs]
             metadata.cache_seqlens_int32.copy_(
-                (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                (seq_lens + self.speculative_num_draft_tokens)
             )

             metadata.max_seq_len_q = self.speculative_num_draft_tokens
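Note (editor, not part of the commit): the .to(torch.int32) removed above was redundant because torch.Tensor.copy_ casts across dtypes by itself, so the explicit conversion only materialized a temporary tensor. A minimal sketch:

    import torch

    dst = torch.empty(3, dtype=torch.int32)
    src = torch.tensor([7, 8, 9], dtype=torch.int64)
    dst.copy_(src)  # casts int64 -> int32 during the copy; no temporary tensor
    assert dst.dtype == torch.int32 and dst.tolist() == [7, 8, 9]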
@@ -1536,7 +1536,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
                 :bs
             ]
-            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+            metadata.cache_seqlens_int32.copy_(seq_lens)

             num_tokens_per_bs = num_tokens // bs
             metadata.max_seq_len_q = num_tokens_per_bs
@@ -1600,38 +1600,32 @@ class FlashAttentionBackend(AttentionBackend):
         if spec_info is not None:
             # Draft Decode
             if self.topk <= 1:
-                metadata = self.decode_cuda_graph_metadata[bs]
                 # When topk = 1, we use the normal decode metadata
-                metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
-                )
-
-                metadata.max_seq_len_k = seq_lens_cpu.max().item() + (
-                    self.speculative_step_id + 1
-                )
-                metadata.cu_seqlens_k[1:].copy_(
-                    torch.cumsum(
-                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
-                    )
-                )
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
                 max_seq_pages = (
                     metadata.max_seq_len_k + self.page_size - 1
                 ) // self.page_size
-                page_indices = self.req_to_token[
-                    req_pool_indices[:, None],
-                    self.decode_cuda_graph_metadata["strided_indices"][
-                        :max_seq_pages
-                    ],
-                ]
-
-                page_indices //= self.page_size
-                metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+                normal_decode_set_medadata(
+                    metadata.cache_seqlens_int32,
+                    metadata.cu_seqlens_k,
+                    metadata.page_table,
+                    self.req_to_token,
+                    req_pool_indices,
+                    self.decode_cuda_graph_metadata["strided_indices"],
+                    max_seq_pages,
+                    seq_lens,
+                    self.speculative_step_id + 1,
+                    self.page_size,
+                )

             else:
                 # When top k > 1, we need two specific draft decode metadata, and then merge states
                 # 1. The first half of metadata for prefix tokens
                 metadata = self.draft_decode_metadata_topk_normal[bs]
-                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                metadata.cache_seqlens_int32.copy_(seq_lens)
                 # metadata.max_seq_len_q = self.topk, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
@@ -1654,7 +1648,7 @@ class FlashAttentionBackend(AttentionBackend):
                     self.speculative_num_steps, -1
                 ).T.contiguous()
                 metadata_expand.page_table[: cache_loc.shape[0]].copy_(
-                    cache_loc[:, :decode_length].contiguous().to(torch.int32)
+                    cache_loc[:, :decode_length]
                 )
                 # TODO: Handle local attention metadata for draft decode when llama4 eagle is supported
         else:
@@ -1665,12 +1659,15 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = max_len

             normal_decode_set_medadata(
-                metadata,
+                metadata.cache_seqlens_int32,
+                metadata.cu_seqlens_k,
+                metadata.page_table,
                 self.req_to_token,
                 req_pool_indices,
                 self.decode_cuda_graph_metadata["strided_indices"],
                 max_seq_pages,
                 seq_lens,
+                0,
                 self.page_size,
             )

@@ -1679,7 +1676,7 @@ class FlashAttentionBackend(AttentionBackend):
             if self.topk <= 1:
                 metadata = self.target_verify_metadata[bs]
                 metadata.cache_seqlens_int32.copy_(
-                    (seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
+                    (seq_lens + self.speculative_num_draft_tokens)
                 )

                 metadata.max_seq_len_k = (
@@ -1701,7 +1698,7 @@ class FlashAttentionBackend(AttentionBackend):
                 # When topk > 1, we need two specific target verify metadata, and then merge states
                 # 1. The first half of metadata for prefix tokens
                 metadata = self.target_verify_metadata_topk_normal[bs]
-                metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+                metadata.cache_seqlens_int32.copy_(seq_lens)
                 # metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
                 metadata.max_seq_len_k = seq_lens_cpu.max().item()
                 # metadata.cu_seqlens_q already set in capture
@@ -1761,9 +1758,7 @@ class FlashAttentionBackend(AttentionBackend):
                 metadata_expand.page_table.copy_(
                     non_masked_page_table.gather(1, sort_order)
                 )
-                metadata_expand.cache_seqlens_int32.copy_(
-                    mask.sum(dim=1).to(torch.int32)
-                )
+                metadata_expand.cache_seqlens_int32.copy_(mask.sum(dim=1))
                 metadata_expand.cu_seqlens_k[1:].copy_(
                     torch.cumsum(
                         metadata_expand.cache_seqlens_int32,
@@ -1776,14 +1771,14 @@ class FlashAttentionBackend(AttentionBackend):
                 )
         elif forward_mode.is_draft_extend():
             metadata = self.draft_extend_metadata[bs]
-            metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
+            metadata.cache_seqlens_int32.copy_(seq_lens)

             metadata.max_seq_len_k = seq_lens_cpu.max().item()
             metadata.cu_seqlens_k[1:].copy_(
                 torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
             )
             accept_length = spec_info.accept_length[:bs]
-            metadata.max_seq_len_q = accept_length.max().item()
+            metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
             metadata.cu_seqlens_q[1:].copy_(
                 torch.cumsum(accept_length, dim=0, dtype=torch.int32)
             )
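Note (editor, not part of the commit): if accept_length lives on the GPU, accept_length.max().item() forces a device-to-host synchronization, whereas max(spec_info.accept_length_cpu) works on host data that is already available; the + 1 presumably accounts for the extra verified token on top of the accepted draft tokens (an assumption, the diff does not state it). A sketch of the host-side pattern, with accept_length_cpu assumed to be a plain Python list:

    # Per-request accepted-token counts, already on the host.
    accept_length_cpu = [2, 5, 3]
    max_seq_len_q = max(accept_length_cpu) + 1  # pure host work, no GPU sync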
@@ -1795,8 +1790,7 @@ class FlashAttentionBackend(AttentionBackend):
                 req_pool_indices[:, None],
                 self.draft_extend_metadata["strided_indices"][:max_seq_pages],
             ]
-            page_indices //= self.page_size
-            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)

         if encoder_lens is not None:
             # Only support encoder size 1 for now
@@ -2045,6 +2039,8 @@ class FlashAttentionMultiStepBackend:
         assert isinstance(forward_batch.spec_info, EagleDraftInput)

         for i in range(self.speculative_num_steps - 1):
+            # TODO: incrementally update the metadata for the later steps,
+            # so that they do not need to recompute everything from scratch.
             self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
                 bs,
                 forward_batch.req_pool_indices,
@@ -2058,21 +2054,25 @@ class FlashAttentionMultiStepBackend:
             )


-@torch.compile(dynamic=True, backend=get_compiler_backend())
+# @torch.compile(dynamic=True, backend=get_compiler_backend())
+# TODO: fuse these kernels
+# NOTE: torch.compile makes it slower in speculative decoding
 def normal_decode_set_medadata(
-    metadata,
-    req_to_token,
-    req_pool_indices,
-    strided_indices,
-    max_seq_pages,
-    seq_lens,
-    page_size,
+    cache_seqlens_int32: torch.Tensor,
+    cu_seqlens_k: torch.Tensor,
+    page_table: torch.Tensor,
+    req_to_token: torch.Tensor,
+    req_pool_indices: torch.Tensor,
+    strided_indices: torch.Tensor,
+    max_seq_pages: torch.Tensor,
+    seq_lens: torch.Tensor,
+    seq_len_delta: int,
+    page_size: int,
 ):
-    metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
-    metadata.cu_seqlens_k[1:].copy_(torch.cumsum(seq_lens, dim=0, dtype=torch.int32))
+    cache_seqlens_int32.copy_(seq_lens + seq_len_delta)
+    cu_seqlens_k[1:].copy_(torch.cumsum(cache_seqlens_int32, dim=0, dtype=torch.int32))
     page_indices = req_to_token[
         req_pool_indices[:, None],
         strided_indices[:max_seq_pages][None, :],
     ]
-    metadata.page_table[:, :max_seq_pages].copy_(page_indices // page_size)
-    metadata.page_table[:, max_seq_pages:].fill_(0)
+    page_table[:, :max_seq_pages].copy_(page_indices // page_size)
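For reference, a toy invocation of the refactored helper (editor's sketch; shapes and values are illustrative assumptions, and it assumes the new normal_decode_set_medadata definition above has been pasted into a Python session):

    import torch

    bs, page_size, max_seq_pages = 2, 16, 4
    seq_lens = torch.tensor([30, 45])
    cache_seqlens_int32 = torch.zeros(bs, dtype=torch.int32)
    cu_seqlens_k = torch.zeros(bs + 1, dtype=torch.int32)
    page_table = torch.zeros(bs, max_seq_pages, dtype=torch.int32)
    # Toy request-to-token map: two requests with 64 token slots each.
    req_to_token = torch.arange(2 * 64, dtype=torch.int32).reshape(2, 64)
    strided_indices = torch.arange(0, 64, page_size)  # first slot of every page

    normal_decode_set_medadata(
        cache_seqlens_int32, cu_seqlens_k, page_table,
        req_to_token, torch.tensor([0, 1]), strided_indices,
        max_seq_pages, seq_lens, 0, page_size,
    )
    assert cache_seqlens_int32.tolist() == [30, 45]
    assert cu_seqlens_k.tolist() == [0, 30, 75]
    assert page_table.tolist() == [[0, 1, 2, 3], [4, 5, 6, 7]]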
@@ -920,19 +920,18 @@ def fast_mla_decode_plan(
     self._page_size = page_size
     self._sm_scale = sm_scale

-    with self.device as device:
-        try:
-            # Standard version with just the required arguments (no use_profiler)
-            self._cached_module.plan.default(
-                self._float_workspace_buffer,
-                self._int_workspace_buffer,
-                self._pin_memory_int_workspace_buffer,
-                qo_indptr_cpu,
-                kv_indptr_cpu,
-                kv_len_arr_cpu,
-                num_heads,
-                head_dim_ckv,
-                causal,
-            )
-        except Exception as e:
-            raise RuntimeError(f"Error in alternate MLA plan: {e}")
+    try:
+        # Standard version with just the required arguments (no use_profiler)
+        self._cached_module.plan.default(
+            self._float_workspace_buffer,
+            self._int_workspace_buffer,
+            self._pin_memory_int_workspace_buffer,
+            qo_indptr_cpu,
+            kv_indptr_cpu,
+            kv_len_arr_cpu,
+            num_heads,
+            head_dim_ckv,
+            causal,
+        )
+    except Exception as e:
+        raise RuntimeError(f"Error in alternate MLA plan: {e}")
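Note (editor, not part of the commit): dropping the with self.device wrapper looks safe here because the plan call only passes pre-allocated workspace buffers and CPU-side arrays; if self.device is a torch.device, the context manager only changes where new tensors get created (an assumption, since the diff does not state the rationale). A minimal sketch of that context-manager behavior:

    import torch

    # torch.device as a context manager (PyTorch >= 2.0) redirects factory
    # calls (torch.empty, torch.zeros, ...) to that device; tensors passed
    # into a call stay where they already live.
    with torch.device("cpu"):
        t = torch.empty(4)  # created on the CPU because of the context
    assert t.device.type == "cpu"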