Add Speculative Decoding Eagle3 topk > 1 (#5318)
Co-authored-by: Stefan He <hebiaobuaa@gmail.com> Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
|
|
||||||
|
from sgl_kernel import merge_state_v2
|
||||||
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||||
|
|
||||||
|
|
||||||
@@ -30,7 +31,7 @@ class FlashAttentionMetadata:
|
|||||||
# Sequence lengths for the forward batch
|
# Sequence lengths for the forward batch
|
||||||
cache_seqlens_int32: torch.Tensor = None
|
cache_seqlens_int32: torch.Tensor = None
|
||||||
# Maximum sequence length for query
|
# Maximum sequence length for query
|
||||||
max_seq_len_q: int = 0
|
max_seq_len_q: int = 1
|
||||||
# Maximum sequence length for key
|
# Maximum sequence length for key
|
||||||
max_seq_len_k: int = 0
|
max_seq_len_k: int = 0
|
||||||
# Cumulative sequence lengths for query
|
# Cumulative sequence lengths for query
|
||||||
@@ -267,6 +268,12 @@ def cdiv(a: int, b: int) -> int:
|
|||||||
return -(a // -b)
|
return -(a // -b)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue
|
||||||
|
@torch._dynamo.disable()
|
||||||
|
def merge_state_v2_wrapper(o, s_a, o_exp, s_b):
|
||||||
|
return merge_state_v2(o, s_a, o_exp, s_b)
|
||||||
|
|
||||||
|
|
||||||
class FlashAttentionBackend(AttentionBackend):
|
class FlashAttentionBackend(AttentionBackend):
|
||||||
"""FlashAttention backend implementation.
|
"""FlashAttention backend implementation.
|
||||||
|
|
||||||
@@ -301,6 +308,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
), "Sliding window and cross attention are not supported together"
|
), "Sliding window and cross attention are not supported together"
|
||||||
|
|
||||||
self.forward_metadata: FlashAttentionMetadata = None
|
self.forward_metadata: FlashAttentionMetadata = None
|
||||||
|
# extra metdata for handling speculative decoding topk > 1, extended draft decode and verify
|
||||||
|
self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
|
||||||
self.max_context_len = model_runner.model_config.context_len
|
self.max_context_len = model_runner.model_config.context_len
|
||||||
self.device = model_runner.device
|
self.device = model_runner.device
|
||||||
self.decode_cuda_graph_metadata = {}
|
self.decode_cuda_graph_metadata = {}
|
||||||
@@ -311,8 +320,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
self.page_size = model_runner.page_size
|
self.page_size = model_runner.page_size
|
||||||
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
|
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
|
||||||
self.skip_prefill = skip_prefill
|
self.skip_prefill = skip_prefill
|
||||||
|
self.topk = model_runner.server_args.speculative_eagle_topk or 0
|
||||||
self.topk = topk
|
|
||||||
self.speculative_num_steps = speculative_num_steps
|
self.speculative_num_steps = speculative_num_steps
|
||||||
self.speculative_num_draft_tokens = (
|
self.speculative_num_draft_tokens = (
|
||||||
model_runner.server_args.speculative_num_draft_tokens
|
model_runner.server_args.speculative_num_draft_tokens
|
||||||
@@ -336,6 +344,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
if forward_batch.forward_mode.is_decode_or_idle():
|
if forward_batch.forward_mode.is_decode_or_idle():
|
||||||
# Draft Decode
|
# Draft Decode
|
||||||
if forward_batch.spec_info is not None:
|
if forward_batch.spec_info is not None:
|
||||||
|
if self.topk <= 1:
|
||||||
metadata.cache_seqlens_int32 = (
|
metadata.cache_seqlens_int32 = (
|
||||||
seqlens_in_batch + (self.speculative_step_id + 1)
|
seqlens_in_batch + (self.speculative_step_id + 1)
|
||||||
).to(torch.int32)
|
).to(torch.int32)
|
||||||
@@ -354,8 +363,57 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
]
|
]
|
||||||
|
else:
|
||||||
|
metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
|
||||||
|
metadata.max_seq_len_q = self.topk
|
||||||
|
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||||
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
batch_size * self.topk + 1,
|
||||||
|
step=self.topk,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
|
]
|
||||||
|
|
||||||
self._init_local_attn_metadata(metadata, device)
|
metadata_expand = FlashAttentionMetadata()
|
||||||
|
decode_length = self.speculative_step_id + 1
|
||||||
|
metadata_expand.cache_seqlens_int32 = torch.full(
|
||||||
|
(seqlens_in_batch.numel() * self.topk,),
|
||||||
|
decode_length,
|
||||||
|
device=device,
|
||||||
|
dtype=torch.int32,
|
||||||
|
)
|
||||||
|
metadata_expand.max_seq_len_q = 1
|
||||||
|
metadata_expand.max_seq_len_k = self.speculative_step_id + 1
|
||||||
|
metadata_expand.cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
metadata_expand.cache_seqlens_int32.numel() + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
metadata_expand.cu_seqlens_k = torch.arange(
|
||||||
|
0,
|
||||||
|
metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
|
||||||
|
step=decode_length,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
cache_loc = forward_batch.out_cache_loc.view(
|
||||||
|
self.speculative_num_steps, -1
|
||||||
|
).T.contiguous()
|
||||||
|
metadata_expand.page_table = (
|
||||||
|
cache_loc[:, :decode_length].contiguous().to(torch.int32)
|
||||||
|
)
|
||||||
|
self.forward_metadata_spec_decode_expand = metadata_expand
|
||||||
else:
|
else:
|
||||||
# Normal Decode
|
# Normal Decode
|
||||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||||
@@ -369,9 +427,10 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
]
|
]
|
||||||
|
# TODO: we need to test this part for llama 4 eagle case
|
||||||
self._init_local_attn_metadata(metadata, device)
|
self._init_local_attn_metadata(metadata, device)
|
||||||
elif forward_batch.forward_mode.is_target_verify():
|
elif forward_batch.forward_mode.is_target_verify():
|
||||||
|
if self.topk <= 1:
|
||||||
metadata.cache_seqlens_int32 = (
|
metadata.cache_seqlens_int32 = (
|
||||||
forward_batch.seq_lens + self.speculative_num_draft_tokens
|
forward_batch.seq_lens + self.speculative_num_draft_tokens
|
||||||
).to(torch.int32)
|
).to(torch.int32)
|
||||||
@@ -388,13 +447,112 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
|
torch.cumsum(
|
||||||
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
|
),
|
||||||
(1, 0),
|
(1, 0),
|
||||||
)
|
)
|
||||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
]
|
]
|
||||||
|
|
||||||
|
self._init_local_attn_metadata(metadata, device)
|
||||||
|
else:
|
||||||
|
metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
|
||||||
|
metadata.max_seq_len_q = self.speculative_num_draft_tokens
|
||||||
|
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||||
|
metadata.cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
batch_size * self.speculative_num_draft_tokens + 1,
|
||||||
|
step=self.speculative_num_draft_tokens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||||
|
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||||
|
]
|
||||||
|
|
||||||
|
metadata_expand = FlashAttentionMetadata()
|
||||||
|
|
||||||
|
metadata_expand.max_seq_len_q = 1
|
||||||
|
metadata_expand.cu_seqlens_q = torch.arange(
|
||||||
|
0,
|
||||||
|
forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
|
||||||
|
+ 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# create expand page table
|
||||||
|
offsets = torch.arange(
|
||||||
|
self.speculative_num_draft_tokens, device=device
|
||||||
|
).unsqueeze(
|
||||||
|
0
|
||||||
|
) # shape: (1, self.speculative_num_draft_tokens)
|
||||||
|
cols = offsets.expand(
|
||||||
|
forward_batch.seq_lens.numel(), -1
|
||||||
|
) + forward_batch.seq_lens.unsqueeze(1)
|
||||||
|
cum_len = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
(
|
||||||
|
forward_batch.seq_lens + self.speculative_num_draft_tokens
|
||||||
|
).repeat_interleave(self.speculative_num_draft_tokens),
|
||||||
|
dim=0,
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)[:-1]
|
||||||
|
mask_extraction_indices = (
|
||||||
|
cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
|
||||||
|
+ cum_len[:, None]
|
||||||
|
).view(1, -1)
|
||||||
|
mask = forward_batch.spec_info.custom_mask[
|
||||||
|
mask_extraction_indices
|
||||||
|
].view(
|
||||||
|
-1, self.speculative_num_draft_tokens
|
||||||
|
) # (bsz * draft_num, draft_num)
|
||||||
|
|
||||||
|
# shift table indices to avoid padding
|
||||||
|
# non_masked_page_table [[8, 9, 10], mask (display with int format) [[1, 0, 0],
|
||||||
|
# [8, 9, 10], [1, 1, 0],
|
||||||
|
# [8, 9, 10]] [1, 0, 1]]
|
||||||
|
# if masked with padding [[8, 0, 0], our mask without padding [[8, 9, 10],
|
||||||
|
# [8, 9, 0], [8, 9, 10],
|
||||||
|
# [8, 0, 10]] [8, 10, 9]]
|
||||||
|
# note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row
|
||||||
|
col_indices = offsets.expand(
|
||||||
|
mask.shape[0], self.speculative_num_draft_tokens
|
||||||
|
)
|
||||||
|
# Build keys: if an entry is valid (mask==True), keep its original index;
|
||||||
|
# if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
|
||||||
|
keys = torch.where(
|
||||||
|
mask, col_indices, col_indices + self.speculative_num_draft_tokens
|
||||||
|
)
|
||||||
|
_, sort_order = torch.sort(keys, dim=1)
|
||||||
|
non_masked_page_table = (
|
||||||
|
forward_batch.req_to_token_pool.req_to_token[
|
||||||
|
forward_batch.req_pool_indices, :
|
||||||
|
]
|
||||||
|
.gather(1, cols)
|
||||||
|
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
|
||||||
|
) # (bsz, draft_num)
|
||||||
|
metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
|
||||||
|
metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
|
||||||
|
metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
metadata_expand.max_seq_len_k = (
|
||||||
|
metadata_expand.cache_seqlens_int32.max().item()
|
||||||
|
)
|
||||||
|
self.forward_metadata_spec_decode_expand = metadata_expand
|
||||||
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
|
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
|
||||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||||
@@ -514,6 +672,11 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
and (hasattr(layer, "use_irope") and layer.use_irope)
|
and (hasattr(layer, "use_irope") and layer.use_irope)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# We do cascade attention for Target Verify with topk > 1
|
||||||
|
use_cascade_attn = (
|
||||||
|
forward_batch.forward_mode.is_target_verify() and self.topk > 1
|
||||||
|
)
|
||||||
|
|
||||||
# Get the appropriate page table based on whether we're using local attention
|
# Get the appropriate page table based on whether we're using local attention
|
||||||
if use_local_attn:
|
if use_local_attn:
|
||||||
local_metadata = metadata.local_attn_metadata
|
local_metadata = metadata.local_attn_metadata
|
||||||
@@ -548,7 +711,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
cu_seqlens_k = metadata.encoder_cu_seqlens_k
|
cu_seqlens_k = metadata.encoder_cu_seqlens_k
|
||||||
window_size = (-1, -1)
|
window_size = (-1, -1)
|
||||||
|
|
||||||
o = flash_attn_with_kvcache(
|
result = flash_attn_with_kvcache(
|
||||||
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
v_cache=value_cache,
|
v_cache=value_cache,
|
||||||
@@ -558,13 +721,41 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
||||||
max_seqlen_q=max_seqlen_q,
|
max_seqlen_q=max_seqlen_q,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=causal,
|
causal=False if use_cascade_attn else causal,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
|
return_softmax_lse=use_cascade_attn,
|
||||||
)
|
)
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
|
||||||
|
if use_cascade_attn:
|
||||||
|
o, softmax_lse, *rest = result
|
||||||
|
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
|
||||||
|
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||||
|
k_cache=key_cache,
|
||||||
|
v_cache=value_cache,
|
||||||
|
page_table=self.forward_metadata_spec_decode_expand.page_table,
|
||||||
|
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
|
||||||
|
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
|
||||||
|
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
|
||||||
|
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=False,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=layer.logit_cap,
|
||||||
|
k_descale=k_descale,
|
||||||
|
v_descale=v_descale,
|
||||||
|
return_softmax_lse=True,
|
||||||
|
)
|
||||||
|
o, _ = merge_state_v2_wrapper(
|
||||||
|
o,
|
||||||
|
softmax_lse.T.contiguous(),
|
||||||
|
o_expand,
|
||||||
|
softmax_lse_expand.T.contiguous(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
o = result
|
||||||
else:
|
else:
|
||||||
if (
|
if (
|
||||||
not global_server_args_dict["disable_chunked_prefix_cache"]
|
not global_server_args_dict["disable_chunked_prefix_cache"]
|
||||||
@@ -627,7 +818,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
q_nope = q_all[:, :, : layer.v_head_dim]
|
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||||
q_rope = q_all[:, :, layer.v_head_dim :]
|
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||||
o = flash_attn_with_kvcache(
|
|
||||||
|
result = flash_attn_with_kvcache(
|
||||||
q=q_rope,
|
q=q_rope,
|
||||||
k_cache=k_rope_cache,
|
k_cache=k_rope_cache,
|
||||||
v_cache=c_kv_cache,
|
v_cache=c_kv_cache,
|
||||||
@@ -638,11 +830,42 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
||||||
max_seqlen_q=max_seqlen_q,
|
max_seqlen_q=max_seqlen_q,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=False if use_cascade_attn else causal,
|
||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
|
return_softmax_lse=use_cascade_attn,
|
||||||
)
|
)
|
||||||
|
if use_cascade_attn:
|
||||||
|
o, softmax_lse, *rest = result
|
||||||
|
o_expand, softmax_lse_expand, *rest_expand = (
|
||||||
|
flash_attn_with_kvcache(
|
||||||
|
q=q_rope,
|
||||||
|
k_cache=k_rope_cache,
|
||||||
|
v_cache=c_kv_cache,
|
||||||
|
qv=q_nope,
|
||||||
|
page_table=self.forward_metadata_spec_decode_expand.page_table,
|
||||||
|
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
|
||||||
|
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
|
||||||
|
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
|
||||||
|
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=False,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=layer.logit_cap,
|
||||||
|
k_descale=k_descale,
|
||||||
|
v_descale=v_descale,
|
||||||
|
return_softmax_lse=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
o, _ = merge_state_v2_wrapper(
|
||||||
|
o,
|
||||||
|
softmax_lse.T.contiguous(),
|
||||||
|
o_expand,
|
||||||
|
softmax_lse_expand.T.contiguous(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
o = result
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
|
||||||
@@ -681,6 +904,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
use_local_attention = (
|
use_local_attention = (
|
||||||
self.attention_chunk_size is not None and local_attn_metadata is not None
|
self.attention_chunk_size is not None and local_attn_metadata is not None
|
||||||
)
|
)
|
||||||
|
# We do cascade attention for Draft Decode with topk > 1
|
||||||
|
use_cascade_attn = self.topk > 1
|
||||||
|
|
||||||
# Calculate window size (can be moved to metadata if layer properties don't change)
|
# Calculate window size (can be moved to metadata if layer properties don't change)
|
||||||
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
||||||
@@ -752,23 +977,61 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
page_table = metadata.page_table
|
||||||
|
cache_seqlens = metadata.cache_seqlens_int32
|
||||||
|
cu_seqlens_k = metadata.cu_seqlens_k
|
||||||
|
max_seqlen_q = metadata.max_seq_len_q
|
||||||
|
q_reshaped = q.contiguous().view(
|
||||||
|
-1, layer.tp_q_head_num, layer.head_dim
|
||||||
|
)
|
||||||
|
|
||||||
# Default: single-token self-attention
|
# Default: single-token self-attention
|
||||||
o = flash_attn_with_kvcache(
|
result = flash_attn_with_kvcache(
|
||||||
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
q=q_reshaped,
|
||||||
k_cache=key_cache,
|
k_cache=key_cache,
|
||||||
v_cache=value_cache,
|
v_cache=value_cache,
|
||||||
page_table=metadata.page_table,
|
page_table=page_table,
|
||||||
cache_seqlens=metadata.cache_seqlens_int32,
|
cache_seqlens=cache_seqlens,
|
||||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
cu_seqlens_k_new=cu_seqlens_k,
|
||||||
max_seqlen_q=1,
|
max_seqlen_q=max_seqlen_q,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=False if use_cascade_attn else causal,
|
||||||
window_size=window_size,
|
window_size=window_size,
|
||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
|
return_softmax_lse=use_cascade_attn,
|
||||||
)
|
)
|
||||||
|
if use_cascade_attn:
|
||||||
|
o, softmax_lse, *rest = result
|
||||||
|
o_expand, softmax_lse_expand, *rest_expand = (
|
||||||
|
flash_attn_with_kvcache(
|
||||||
|
q=q_reshaped,
|
||||||
|
k_cache=key_cache,
|
||||||
|
v_cache=value_cache,
|
||||||
|
page_table=self.forward_metadata_spec_decode_expand.page_table,
|
||||||
|
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
|
||||||
|
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
|
||||||
|
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
|
||||||
|
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=False,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=layer.logit_cap,
|
||||||
|
k_descale=k_descale,
|
||||||
|
v_descale=v_descale,
|
||||||
|
return_softmax_lse=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
o, _ = merge_state_v2(
|
||||||
|
o,
|
||||||
|
softmax_lse.T.contiguous(),
|
||||||
|
o_expand,
|
||||||
|
softmax_lse_expand.T.contiguous(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
o = result
|
||||||
else:
|
else:
|
||||||
# Do absorbed multi-latent attention
|
# Do absorbed multi-latent attention
|
||||||
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||||
@@ -787,8 +1050,9 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||||
q_nope = q_all[:, :, : layer.v_head_dim]
|
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||||
q_rope = q_all[:, :, layer.v_head_dim :]
|
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||||
|
max_seqlen_q = metadata.max_seq_len_q
|
||||||
|
|
||||||
o = flash_attn_with_kvcache(
|
result = flash_attn_with_kvcache(
|
||||||
q=q_rope,
|
q=q_rope,
|
||||||
k_cache=k_rope_cache,
|
k_cache=k_rope_cache,
|
||||||
v_cache=c_kv_cache,
|
v_cache=c_kv_cache,
|
||||||
@@ -797,13 +1061,43 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
cache_seqlens=metadata.cache_seqlens_int32,
|
cache_seqlens=metadata.cache_seqlens_int32,
|
||||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||||
max_seqlen_q=1,
|
max_seqlen_q=max_seqlen_q,
|
||||||
softmax_scale=layer.scaling,
|
softmax_scale=layer.scaling,
|
||||||
causal=True,
|
causal=False if use_cascade_attn else causal,
|
||||||
softcap=layer.logit_cap,
|
softcap=layer.logit_cap,
|
||||||
k_descale=k_descale,
|
k_descale=k_descale,
|
||||||
v_descale=v_descale,
|
v_descale=v_descale,
|
||||||
|
return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
|
||||||
)
|
)
|
||||||
|
if use_cascade_attn:
|
||||||
|
o, softmax_lse, *rest = result
|
||||||
|
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
|
||||||
|
q=q_rope,
|
||||||
|
k_cache=k_rope_cache,
|
||||||
|
v_cache=c_kv_cache,
|
||||||
|
qv=q_nope,
|
||||||
|
page_table=self.forward_metadata_spec_decode_expand.page_table,
|
||||||
|
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
|
||||||
|
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
|
||||||
|
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
|
||||||
|
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
|
||||||
|
softmax_scale=layer.scaling,
|
||||||
|
causal=False,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=layer.logit_cap,
|
||||||
|
k_descale=k_descale,
|
||||||
|
v_descale=v_descale,
|
||||||
|
return_softmax_lse=True,
|
||||||
|
)
|
||||||
|
o, _ = merge_state_v2(
|
||||||
|
o,
|
||||||
|
softmax_lse.T.contiguous(),
|
||||||
|
o_expand,
|
||||||
|
softmax_lse_expand.T.contiguous(),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
o = result
|
||||||
|
|
||||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||||
|
|
||||||
def init_cuda_graph_state(self, max_bs: int):
|
def init_cuda_graph_state(self, max_bs: int):
|
||||||
@@ -815,6 +1109,8 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
This creates fixed-size tensors that will be reused during CUDA graph replay
|
This creates fixed-size tensors that will be reused during CUDA graph replay
|
||||||
to avoid memory allocations.
|
to avoid memory allocations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# This is being used by normal decode and draft decode when topk == 1
|
||||||
self.decode_cuda_graph_metadata = {
|
self.decode_cuda_graph_metadata = {
|
||||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||||
"cu_seqlens_q": torch.arange(
|
"cu_seqlens_q": torch.arange(
|
||||||
@@ -840,11 +1136,75 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
self.target_verify_metadata = {
|
# This is used by draft decode's first half of metadata when topk > 1
|
||||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
if self.topk > 1:
|
||||||
"cu_seqlens_q": torch.zeros(
|
self.draft_decode_metadata_topk_normal = {
|
||||||
|
"cache_seqlens": torch.zeros(
|
||||||
|
max_bs, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
"cu_seqlens_q": torch.arange(
|
||||||
|
0,
|
||||||
|
max_bs * self.topk + 1,
|
||||||
|
step=self.topk,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"cu_seqlens_k": torch.zeros(
|
||||||
max_bs + 1, dtype=torch.int32, device=self.device
|
max_bs + 1, dtype=torch.int32, device=self.device
|
||||||
),
|
),
|
||||||
|
"page_table": torch.zeros(
|
||||||
|
max_bs,
|
||||||
|
self.max_context_len,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
# This is used by draft decode's second half of metadata when topk > 1
|
||||||
|
decode_length = self.speculative_step_id + 1
|
||||||
|
self.draft_decode_metadata_topk_expand = {
|
||||||
|
"cache_seqlens": torch.full(
|
||||||
|
(max_bs * self.topk,),
|
||||||
|
decode_length,
|
||||||
|
device=self.device,
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
"cu_seqlens_q": torch.arange(
|
||||||
|
0,
|
||||||
|
max_bs * self.topk + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"cu_seqlens_k": torch.arange(
|
||||||
|
0,
|
||||||
|
max_bs * self.topk * decode_length + 1,
|
||||||
|
step=decode_length,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"page_table": torch.zeros(
|
||||||
|
max_bs * self.topk,
|
||||||
|
decode_length,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.speculative_num_draft_tokens is not None
|
||||||
|
and self.speculative_num_draft_tokens > 0
|
||||||
|
):
|
||||||
|
self.target_verify_metadata = {
|
||||||
|
"cache_seqlens": torch.zeros(
|
||||||
|
max_bs, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
"cu_seqlens_q": torch.arange(
|
||||||
|
0,
|
||||||
|
max_bs * self.speculative_num_draft_tokens + 1,
|
||||||
|
step=self.speculative_num_draft_tokens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
"cu_seqlens_k": torch.zeros(
|
"cu_seqlens_k": torch.zeros(
|
||||||
max_bs + 1, dtype=torch.int32, device=self.device
|
max_bs + 1, dtype=torch.int32, device=self.device
|
||||||
),
|
),
|
||||||
@@ -859,6 +1219,54 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.topk > 1:
|
||||||
|
self.target_verify_metadata_topk_normal = {
|
||||||
|
"cache_seqlens": torch.zeros(
|
||||||
|
max_bs, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
"cu_seqlens_q": torch.arange(
|
||||||
|
0,
|
||||||
|
max_bs * self.speculative_num_draft_tokens + 1,
|
||||||
|
step=self.speculative_num_draft_tokens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"cu_seqlens_k": torch.zeros(
|
||||||
|
max_bs + 1, dtype=torch.int32, device=self.device
|
||||||
|
),
|
||||||
|
"page_table": torch.zeros(
|
||||||
|
max_bs,
|
||||||
|
self.max_context_len,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
self.target_verify_metadata_topk_expand = {
|
||||||
|
"cache_seqlens": torch.zeros(
|
||||||
|
max_bs * self.speculative_num_draft_tokens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"cu_seqlens_k": torch.zeros(
|
||||||
|
max_bs * self.speculative_num_draft_tokens + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"cu_seqlens_q": torch.arange(
|
||||||
|
0,
|
||||||
|
max_bs * self.speculative_num_draft_tokens + 1,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
"page_table": torch.zeros(
|
||||||
|
max_bs * self.speculative_num_draft_tokens,
|
||||||
|
self.speculative_num_draft_tokens,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
self.encoder_metadata = {
|
self.encoder_metadata = {
|
||||||
"encoder_page_table": torch.zeros(
|
"encoder_page_table": torch.zeros(
|
||||||
max_bs,
|
max_bs,
|
||||||
@@ -886,19 +1294,25 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
):
|
):
|
||||||
"""Initialize forward metadata for capturing CUDA graph."""
|
"""Initialize forward metadata for capturing CUDA graph."""
|
||||||
metadata = FlashAttentionMetadata()
|
metadata = FlashAttentionMetadata()
|
||||||
|
|
||||||
|
# metadata_expand is needed for Spec Decoding when top k > 1
|
||||||
|
metadata_expand = FlashAttentionMetadata()
|
||||||
|
|
||||||
device = seq_lens.device
|
device = seq_lens.device
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
if spec_info is not None:
|
if spec_info is not None:
|
||||||
# Draft Decode
|
# Draft Decode
|
||||||
|
if self.topk <= 1:
|
||||||
|
# When topk = 1, we use the normal decode metadata
|
||||||
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
|
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
|
||||||
"cache_seqlens"
|
"cache_seqlens"
|
||||||
][:bs]
|
][:bs]
|
||||||
metadata.max_seq_len_k = seq_lens.max().item() + (
|
metadata.max_seq_len_k = seq_lens.max().item() + (
|
||||||
self.speculative_step_id + 1
|
self.speculative_step_id + 1
|
||||||
)
|
)
|
||||||
metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
|
metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
|
||||||
: bs + 1
|
"cu_seqlens_q"
|
||||||
]
|
][: bs + 1]
|
||||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||||
torch.cumsum(
|
torch.cumsum(
|
||||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
@@ -908,6 +1322,50 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table = self.decode_cuda_graph_metadata[
|
metadata.page_table = self.decode_cuda_graph_metadata[
|
||||||
"page_table_draft_decode"
|
"page_table_draft_decode"
|
||||||
][req_pool_indices, :]
|
][req_pool_indices, :]
|
||||||
|
self.decode_cuda_graph_metadata[bs] = metadata
|
||||||
|
else:
|
||||||
|
# When top k > 1, we need two specific draft decode metadata, and then merge states
|
||||||
|
# 1. The first half of metadata for prefix tokens
|
||||||
|
metadata.cache_seqlens_int32 = (
|
||||||
|
self.draft_decode_metadata_topk_normal["cache_seqlens"][:bs]
|
||||||
|
)
|
||||||
|
metadata.max_seq_len_q = self.topk
|
||||||
|
metadata.max_seq_len_k = seq_lens.max().item()
|
||||||
|
metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[
|
||||||
|
"cu_seqlens_q"
|
||||||
|
][: bs + 1]
|
||||||
|
metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[
|
||||||
|
"cu_seqlens_k"
|
||||||
|
][: bs + 1]
|
||||||
|
metadata.page_table = self.draft_decode_metadata_topk_normal[
|
||||||
|
"page_table"
|
||||||
|
][req_pool_indices, :]
|
||||||
|
|
||||||
|
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
||||||
|
metadata_expand.cache_seqlens_int32 = (
|
||||||
|
self.draft_decode_metadata_topk_expand["cache_seqlens"][
|
||||||
|
: bs * self.topk
|
||||||
|
]
|
||||||
|
)
|
||||||
|
metadata_expand.max_seq_len_q = 1
|
||||||
|
metadata_expand.max_seq_len_k = (
|
||||||
|
self.speculative_step_id + 1
|
||||||
|
) # , do this in replay
|
||||||
|
metadata_expand.cu_seqlens_q = (
|
||||||
|
self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
|
||||||
|
: bs * self.topk + 1
|
||||||
|
]
|
||||||
|
)
|
||||||
|
metadata_expand.cu_seqlens_k = (
|
||||||
|
self.draft_decode_metadata_topk_expand["cu_seqlens_k"][
|
||||||
|
: bs * self.topk + 1
|
||||||
|
]
|
||||||
|
)
|
||||||
|
metadata_expand.page_table = self.draft_decode_metadata_topk_expand[
|
||||||
|
"page_table"
|
||||||
|
][: bs * self.topk]
|
||||||
|
self.draft_decode_metadata_topk_normal[bs] = metadata
|
||||||
|
self.draft_decode_metadata_topk_expand[bs] = metadata_expand
|
||||||
else:
|
else:
|
||||||
# Normal Decode
|
# Normal Decode
|
||||||
# Get sequence information
|
# Get sequence information
|
||||||
@@ -928,10 +1386,12 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
0, batch_size + 1, dtype=torch.int32, device=device
|
0, batch_size + 1, dtype=torch.int32, device=device
|
||||||
)
|
)
|
||||||
self.decode_cuda_graph_metadata[bs] = metadata
|
self.decode_cuda_graph_metadata[bs] = metadata
|
||||||
|
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
|
if self.topk <= 1:
|
||||||
:bs
|
metadata.cache_seqlens_int32 = self.target_verify_metadata[
|
||||||
]
|
"cache_seqlens"
|
||||||
|
][:bs]
|
||||||
metadata.cache_seqlens_int32.copy_(
|
metadata.cache_seqlens_int32.copy_(
|
||||||
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
|
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
|
||||||
)
|
)
|
||||||
@@ -958,6 +1418,44 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.target_verify_metadata[bs] = metadata
|
self.target_verify_metadata[bs] = metadata
|
||||||
|
else:
|
||||||
|
# When topk > 1, we need two specific target verify metadata, and then merge states
|
||||||
|
# 1. The first half of metadata for prefix tokens
|
||||||
|
metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[
|
||||||
|
"cache_seqlens"
|
||||||
|
][:bs]
|
||||||
|
metadata.max_seq_len_q = self.speculative_num_draft_tokens
|
||||||
|
# metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay
|
||||||
|
metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[
|
||||||
|
"cu_seqlens_q"
|
||||||
|
][: bs + 1]
|
||||||
|
metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[
|
||||||
|
"cu_seqlens_k"
|
||||||
|
][: bs + 1]
|
||||||
|
metadata.page_table = self.target_verify_metadata_topk_normal[
|
||||||
|
"page_table"
|
||||||
|
][req_pool_indices, :]
|
||||||
|
|
||||||
|
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
||||||
|
metadata_expand.cache_seqlens_int32 = (
|
||||||
|
self.target_verify_metadata_topk_expand["cache_seqlens"][
|
||||||
|
: bs * self.speculative_num_draft_tokens
|
||||||
|
]
|
||||||
|
)
|
||||||
|
metadata_expand.max_seq_len_q = 1
|
||||||
|
metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[
|
||||||
|
"cu_seqlens_q"
|
||||||
|
][: bs * self.speculative_num_draft_tokens + 1]
|
||||||
|
metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[
|
||||||
|
"cu_seqlens_k"
|
||||||
|
][: bs * self.speculative_num_draft_tokens + 1]
|
||||||
|
|
||||||
|
metadata_expand.page_table = self.target_verify_metadata_topk_expand[
|
||||||
|
"page_table"
|
||||||
|
][: bs * self.speculative_num_draft_tokens]
|
||||||
|
|
||||||
|
self.target_verify_metadata_topk_normal[bs] = metadata
|
||||||
|
self.target_verify_metadata_topk_expand[bs] = metadata_expand
|
||||||
|
|
||||||
if encoder_lens is not None:
|
if encoder_lens is not None:
|
||||||
encoder_bs = encoder_lens.numel()
|
encoder_bs = encoder_lens.numel()
|
||||||
@@ -973,6 +1471,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
]
|
]
|
||||||
|
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
self.forward_metadata_spec_decode_expand = metadata_expand
|
||||||
|
|
||||||
def init_forward_metadata_replay_cuda_graph(
|
def init_forward_metadata_replay_cuda_graph(
|
||||||
self,
|
self,
|
||||||
@@ -986,17 +1485,21 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
seq_lens_cpu: Optional[torch.Tensor],
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
out_cache_loc: torch.Tensor = None,
|
out_cache_loc: torch.Tensor = None,
|
||||||
):
|
):
|
||||||
# """Initialize forward metadata for replaying CUDA graph."""
|
"""Initialize forward metadata for replaying CUDA graph."""
|
||||||
seq_lens = seq_lens[:bs]
|
seq_lens = seq_lens[:bs]
|
||||||
seq_lens_cpu = seq_lens_cpu[:bs]
|
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||||
req_pool_indices = req_pool_indices[:bs]
|
req_pool_indices = req_pool_indices[:bs]
|
||||||
device = seq_lens.device
|
device = seq_lens.device
|
||||||
|
metadata = None
|
||||||
|
metadata_expand = None
|
||||||
|
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
metadata = self.decode_cuda_graph_metadata[bs]
|
|
||||||
|
|
||||||
if spec_info is not None:
|
if spec_info is not None:
|
||||||
# Draft Decode
|
# Draft Decode
|
||||||
|
if self.topk <= 1:
|
||||||
|
metadata = self.decode_cuda_graph_metadata[bs]
|
||||||
|
# When topk = 1, we use the normal decode metadata
|
||||||
metadata.cache_seqlens_int32.copy_(
|
metadata.cache_seqlens_int32.copy_(
|
||||||
(seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
|
(seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
|
||||||
)
|
)
|
||||||
@@ -1013,14 +1516,54 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
max_seq_pages = (
|
||||||
|
metadata.max_seq_len_k + self.page_size - 1
|
||||||
|
) // self.page_size
|
||||||
|
page_indices = self.req_to_token[
|
||||||
|
req_pool_indices[:, None],
|
||||||
|
self.decode_cuda_graph_metadata["strided_indices"][
|
||||||
|
:max_seq_pages
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
page_indices //= self.page_size
|
||||||
|
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||||
|
else:
|
||||||
|
# When top k > 1, we need two specific draft decode metadata, and then merge states
|
||||||
|
# 1. The first half of metadata for prefix tokens
|
||||||
|
metadata = self.draft_decode_metadata_topk_normal[bs]
|
||||||
|
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
|
||||||
|
# metadata.max_seq_len_q = self.topk, already set in capture
|
||||||
|
metadata.max_seq_len_k = seq_lens_cpu.max().item()
|
||||||
|
# metadata.cu_seqlens_q already set in capture
|
||||||
|
metadata.cu_seqlens_k.copy_(
|
||||||
|
torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
page_table = self.req_to_token[
|
page_table = self.req_to_token[
|
||||||
req_pool_indices, : metadata.max_seq_len_k
|
req_pool_indices, : metadata.max_seq_len_k
|
||||||
]
|
]
|
||||||
|
|
||||||
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
||||||
|
|
||||||
|
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
||||||
|
metadata_expand = self.draft_decode_metadata_topk_expand[bs]
|
||||||
|
decode_length = self.speculative_step_id + 1
|
||||||
|
cache_loc = out_cache_loc.view(
|
||||||
|
self.speculative_num_steps, -1
|
||||||
|
).T.contiguous()
|
||||||
|
metadata_expand.page_table[: cache_loc.shape[0]].copy_(
|
||||||
|
cache_loc[:, :decode_length].contiguous().to(torch.int32)
|
||||||
|
)
|
||||||
|
# TODO: we need to test this part for llama 4 eagle case
|
||||||
self._init_local_attn_metadata(metadata, device)
|
self._init_local_attn_metadata(metadata, device)
|
||||||
else:
|
else:
|
||||||
|
metadata = self.decode_cuda_graph_metadata[bs]
|
||||||
# Normal Decode
|
# Normal Decode
|
||||||
max_len = seq_lens_cpu.max().item()
|
max_len = seq_lens_cpu.max().item()
|
||||||
metadata.max_seq_len_k = max_len
|
metadata.max_seq_len_k = max_len
|
||||||
@@ -1045,6 +1588,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
|
|
||||||
self._init_local_attn_metadata(metadata, device)
|
self._init_local_attn_metadata(metadata, device)
|
||||||
elif forward_mode.is_target_verify():
|
elif forward_mode.is_target_verify():
|
||||||
|
if self.topk <= 1:
|
||||||
metadata = self.target_verify_metadata[bs]
|
metadata = self.target_verify_metadata[bs]
|
||||||
metadata.cache_seqlens_int32.copy_(
|
metadata.cache_seqlens_int32.copy_(
|
||||||
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
|
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
|
||||||
@@ -1061,9 +1605,101 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
(1, 0),
|
(1, 0),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
|
max_seq_pages = (
|
||||||
|
metadata.max_seq_len_k + self.page_size - 1
|
||||||
|
) // self.page_size
|
||||||
|
page_indices = self.req_to_token[
|
||||||
|
req_pool_indices[:, None],
|
||||||
|
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
|
||||||
|
]
|
||||||
|
page_indices //= self.page_size
|
||||||
|
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||||
|
else:
|
||||||
|
# When topk > 1, we need two specific target verify metadata, and then merge states
|
||||||
|
# 1. The first half of metadata for prefix tokens
|
||||||
|
metadata = self.target_verify_metadata_topk_normal[bs]
|
||||||
|
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
|
||||||
|
# metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
|
||||||
|
metadata.max_seq_len_k = seq_lens_cpu.max().item()
|
||||||
|
# metadata.cu_seqlens_q already set in capture
|
||||||
|
metadata.cu_seqlens_k.copy_(
|
||||||
|
torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
page_table = self.req_to_token[
|
||||||
|
req_pool_indices, : metadata.max_seq_len_k
|
||||||
|
]
|
||||||
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
||||||
|
|
||||||
|
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
||||||
|
metadata_expand = self.target_verify_metadata_topk_expand[bs]
|
||||||
|
# metadata_expand.max_seq_len_q = 1, already set in capture
|
||||||
|
# metadata_expand.cu_seqlens_q already set in capture
|
||||||
|
|
||||||
|
offsets = torch.arange(
|
||||||
|
self.speculative_num_draft_tokens, device=device
|
||||||
|
).unsqueeze(
|
||||||
|
0
|
||||||
|
) # shape: (1, self.speculative_num_draft_tokens)
|
||||||
|
cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
|
||||||
|
cum_len = torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
(
|
||||||
|
seq_lens + self.speculative_num_draft_tokens
|
||||||
|
).repeat_interleave(self.speculative_num_draft_tokens),
|
||||||
|
dim=0,
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)[:-1]
|
||||||
|
mask_extraction_indices = (
|
||||||
|
cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
|
||||||
|
+ cum_len[:, None]
|
||||||
|
).view(1, -1)
|
||||||
|
# avoid extracting padded seq indices which will be out of boundary
|
||||||
|
mask_extraction_indices[
|
||||||
|
:, spec_info.positions.numel() * self.speculative_num_draft_tokens :
|
||||||
|
].fill_(0)
|
||||||
|
|
||||||
|
mask = spec_info.custom_mask[mask_extraction_indices].view(
|
||||||
|
-1, self.speculative_num_draft_tokens
|
||||||
|
) # (bsz * draft_num, draft_num)
|
||||||
|
col_indices = offsets.expand(
|
||||||
|
mask.shape[0], self.speculative_num_draft_tokens
|
||||||
|
)
|
||||||
|
keys = torch.where(
|
||||||
|
mask, col_indices, col_indices + self.speculative_num_draft_tokens
|
||||||
|
)
|
||||||
|
_, sort_order = torch.sort(keys, dim=1)
|
||||||
|
|
||||||
|
non_masked_page_table = (
|
||||||
|
self.req_to_token[req_pool_indices, :]
|
||||||
|
.gather(1, cols)
|
||||||
|
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
|
||||||
|
) # (bsz, draft_num)
|
||||||
|
metadata_expand.page_table.copy_(
|
||||||
|
non_masked_page_table.gather(1, sort_order)
|
||||||
|
)
|
||||||
|
metadata_expand.cache_seqlens_int32.copy_(
|
||||||
|
mask.sum(dim=1).to(torch.int32)
|
||||||
|
)
|
||||||
|
metadata_expand.cu_seqlens_k.copy_(
|
||||||
|
torch.nn.functional.pad(
|
||||||
|
torch.cumsum(
|
||||||
|
metadata_expand.cache_seqlens_int32,
|
||||||
|
dim=0,
|
||||||
|
dtype=torch.int32,
|
||||||
|
),
|
||||||
|
(1, 0),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
metadata_expand.max_seq_len_k = (
|
||||||
|
metadata_expand.cache_seqlens_int32.max().item()
|
||||||
|
)
|
||||||
|
|
||||||
if encoder_lens is not None:
|
if encoder_lens is not None:
|
||||||
# Only support encoder size 1 for now
|
# Only support encoder size 1 for now
|
||||||
metadata.encoder_max_seq_len_k = encoder_lens[0]
|
metadata.encoder_max_seq_len_k = encoder_lens[0]
|
||||||
@@ -1089,6 +1725,7 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
||||||
|
|
||||||
self.forward_metadata = metadata
|
self.forward_metadata = metadata
|
||||||
|
self.forward_metadata_spec_decode_expand = metadata_expand
|
||||||
|
|
||||||
def get_cuda_graph_seq_len_fill_value(self):
|
def get_cuda_graph_seq_len_fill_value(self):
|
||||||
"""Get the fill value for sequence length in CUDA graph."""
|
"""Get the fill value for sequence length in CUDA graph."""
|
||||||
@@ -1139,12 +1776,6 @@ class FlashAttentionMultiStepBackend:
|
|||||||
self.model_runner = model_runner
|
self.model_runner = model_runner
|
||||||
self.topk = topk
|
self.topk = topk
|
||||||
self.speculative_num_steps = speculative_num_steps
|
self.speculative_num_steps = speculative_num_steps
|
||||||
|
|
||||||
# TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
|
|
||||||
assert (
|
|
||||||
self.topk == 1
|
|
||||||
), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
|
|
||||||
|
|
||||||
self.attn_backends = []
|
self.attn_backends = []
|
||||||
for i in range(self.speculative_num_steps):
|
for i in range(self.speculative_num_steps):
|
||||||
self.attn_backends.append(
|
self.attn_backends.append(
|
||||||
|
|||||||
@@ -221,7 +221,16 @@ class ModelRunner:
|
|||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
|
|
||||||
if server_args.attention_backend is None:
|
if server_args.attention_backend is None:
|
||||||
# By default, use flashinfer for non-mla attention and triton for mla attention
|
"""
|
||||||
|
We auto select the fastest attention backend according to the current offering
|
||||||
|
1. Models with MHA Architecture (e.g: Llama, QWen)
|
||||||
|
1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
|
||||||
|
1.2 In other cases, we will use flashinfer if available, otherwise use triton.
|
||||||
|
2. Models with MLA Architecture and using FA3
|
||||||
|
2.1 We will use FA3 backend on hopper.
|
||||||
|
2.2 Otherwise, we will use triton backend.
|
||||||
|
"""
|
||||||
|
|
||||||
if not self.use_mla_backend:
|
if not self.use_mla_backend:
|
||||||
if (
|
if (
|
||||||
is_hopper_with_cuda_12_3()
|
is_hopper_with_cuda_12_3()
|
||||||
@@ -234,9 +243,7 @@ class ModelRunner:
|
|||||||
"flashinfer" if is_flashinfer_available() else "triton"
|
"flashinfer" if is_flashinfer_available() else "triton"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
|
if is_hopper_with_cuda_12_3():
|
||||||
server_args
|
|
||||||
):
|
|
||||||
server_args.attention_backend = "fa3"
|
server_args.attention_backend = "fa3"
|
||||||
else:
|
else:
|
||||||
server_args.attention_backend = "triton"
|
server_args.attention_backend = "triton"
|
||||||
|
|||||||
@@ -359,7 +359,18 @@ class ServerArgs:
|
|||||||
|
|
||||||
if self.page_size > 1 and self.speculative_eagle_topk > 1:
|
if self.page_size > 1 and self.speculative_eagle_topk > 1:
|
||||||
self.speculative_eagle_topk = 1
|
self.speculative_eagle_topk = 1
|
||||||
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
|
logger.info(
|
||||||
|
"speculative_eagle_topk is adjusted to 1 when page_size > 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.speculative_eagle_topk == 1
|
||||||
|
and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
|
||||||
|
)
|
||||||
|
self.speculative_num_draft_tokens = self.speculative_num_steps + 1
|
||||||
|
|
||||||
# The token generated from the verify step is counted.
|
# The token generated from the verify step is counted.
|
||||||
# If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
|
# If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
|
||||||
|
|||||||
@@ -1909,6 +1909,8 @@ def is_page_size_one(server_args):
|
|||||||
return server_args.page_size == 1
|
return server_args.page_size == 1
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
|
||||||
|
# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
|
||||||
def is_no_spec_infer_or_topk_one(server_args):
|
def is_no_spec_infer_or_topk_one(server_args):
|
||||||
return server_args.speculative_eagle_topk is None or (
|
return server_args.speculative_eagle_topk is None or (
|
||||||
server_args.speculative_eagle_topk is not None
|
server_args.speculative_eagle_topk is not None
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ suites = {
|
|||||||
TestFile("test_chunked_prefill.py", 336),
|
TestFile("test_chunked_prefill.py", 336),
|
||||||
TestFile("test_eagle_infer.py", 500),
|
TestFile("test_eagle_infer.py", 500),
|
||||||
TestFile("test_ebnf_constrained.py"),
|
TestFile("test_ebnf_constrained.py"),
|
||||||
TestFile("test_fa3.py", 5),
|
TestFile("test_fa3.py", 200),
|
||||||
TestFile("test_fp8_kernel.py", 8),
|
TestFile("test_fp8_kernel.py", 8),
|
||||||
TestFile("test_embedding_openai_server.py", 36),
|
TestFile("test_embedding_openai_server.py", 36),
|
||||||
TestFile("test_hidden_states.py", 55),
|
TestFile("test_hidden_states.py", 55),
|
||||||
|
|||||||
@@ -173,6 +173,60 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
|
|||||||
self.assertGreater(avg_spec_accept_length, 1.5)
|
self.assertGreater(avg_spec_accept_length, 1.5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
|
||||||
|
"""Test FlashAttention3 with speculative decode enabled, topk > 1"""
|
||||||
|
|
||||||
|
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
args = super().get_server_args()
|
||||||
|
args.extend(
|
||||||
|
[
|
||||||
|
"--cuda-graph-max-bs",
|
||||||
|
"2",
|
||||||
|
"--speculative-algorithm",
|
||||||
|
"EAGLE3",
|
||||||
|
"--speculative-draft",
|
||||||
|
"jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
|
||||||
|
"--speculative-num-steps",
|
||||||
|
"5",
|
||||||
|
"--speculative-eagle-topk",
|
||||||
|
"4",
|
||||||
|
"--speculative-num-draft-tokens",
|
||||||
|
"8",
|
||||||
|
"--dtype",
|
||||||
|
"float16",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return args
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
"""
|
||||||
|
Override the test_gsm8k to further test for average speculative accept length.
|
||||||
|
"""
|
||||||
|
requests.get(self.base_url + "/flush_cache")
|
||||||
|
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=DATA_PATH,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
print(metrics)
|
||||||
|
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
|
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
||||||
|
print(f"{avg_spec_accept_length=}")
|
||||||
|
self.assertGreater(avg_spec_accept_length, 1.8)
|
||||||
|
|
||||||
|
|
||||||
class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
|
class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
|
||||||
"""Test FlashAttention3 with speculative decode enabled."""
|
"""Test FlashAttention3 with speculative decode enabled."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user