Add Speculative Decoding Eagle3 topk > 1 (#5318)
Co-authored-by: Stefan He <hebiaobuaa@gmail.com> Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
This commit is contained in:
@@ -16,6 +16,7 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
|
||||
from sgl_kernel import merge_state_v2
|
||||
from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
|
||||
|
||||
|
||||
@@ -30,7 +31,7 @@ class FlashAttentionMetadata:
|
||||
# Sequence lengths for the forward batch
|
||||
cache_seqlens_int32: torch.Tensor = None
|
||||
# Maximum sequence length for query
|
||||
max_seq_len_q: int = 0
|
||||
max_seq_len_q: int = 1
|
||||
# Maximum sequence length for key
|
||||
max_seq_len_k: int = 0
|
||||
# Cumulative sequence lengths for query
|
||||
@@ -267,6 +268,12 @@ def cdiv(a: int, b: int) -> int:
|
||||
return -(a // -b)
|
||||
|
||||
|
||||
# TODO(hebiao064): remove this once we have a better way to handle the merge_state_v2 torch.compile issue
|
||||
@torch._dynamo.disable()
|
||||
def merge_state_v2_wrapper(o, s_a, o_exp, s_b):
|
||||
return merge_state_v2(o, s_a, o_exp, s_b)
|
||||
|
||||
|
||||
class FlashAttentionBackend(AttentionBackend):
|
||||
"""FlashAttention backend implementation.
|
||||
|
||||
@@ -301,6 +308,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
), "Sliding window and cross attention are not supported together"
|
||||
|
||||
self.forward_metadata: FlashAttentionMetadata = None
|
||||
# extra metdata for handling speculative decoding topk > 1, extended draft decode and verify
|
||||
self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
|
||||
self.max_context_len = model_runner.model_config.context_len
|
||||
self.device = model_runner.device
|
||||
self.decode_cuda_graph_metadata = {}
|
||||
@@ -311,8 +320,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
self.page_size = model_runner.page_size
|
||||
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
|
||||
self.skip_prefill = skip_prefill
|
||||
|
||||
self.topk = topk
|
||||
self.topk = model_runner.server_args.speculative_eagle_topk or 0
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
self.speculative_num_draft_tokens = (
|
||||
model_runner.server_args.speculative_num_draft_tokens
|
||||
@@ -336,6 +344,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
if forward_batch.forward_mode.is_decode_or_idle():
|
||||
# Draft Decode
|
||||
if forward_batch.spec_info is not None:
|
||||
if self.topk <= 1:
|
||||
metadata.cache_seqlens_int32 = (
|
||||
seqlens_in_batch + (self.speculative_step_id + 1)
|
||||
).to(torch.int32)
|
||||
@@ -354,8 +363,57 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
else:
|
||||
metadata.cache_seqlens_int32 = (seqlens_in_batch).to(torch.int32)
|
||||
metadata.max_seq_len_q = self.topk
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
batch_size * self.topk + 1,
|
||||
step=self.topk,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
metadata_expand = FlashAttentionMetadata()
|
||||
decode_length = self.speculative_step_id + 1
|
||||
metadata_expand.cache_seqlens_int32 = torch.full(
|
||||
(seqlens_in_batch.numel() * self.topk,),
|
||||
decode_length,
|
||||
device=device,
|
||||
dtype=torch.int32,
|
||||
)
|
||||
metadata_expand.max_seq_len_q = 1
|
||||
metadata_expand.max_seq_len_k = self.speculative_step_id + 1
|
||||
metadata_expand.cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
metadata_expand.cache_seqlens_int32.numel() + 1,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
metadata_expand.cu_seqlens_k = torch.arange(
|
||||
0,
|
||||
metadata_expand.cache_seqlens_int32.numel() * decode_length + 1,
|
||||
step=decode_length,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
cache_loc = forward_batch.out_cache_loc.view(
|
||||
self.speculative_num_steps, -1
|
||||
).T.contiguous()
|
||||
metadata_expand.page_table = (
|
||||
cache_loc[:, :decode_length].contiguous().to(torch.int32)
|
||||
)
|
||||
self.forward_metadata_spec_decode_expand = metadata_expand
|
||||
else:
|
||||
# Normal Decode
|
||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||
@@ -369,9 +427,10 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
# TODO: we need to test this part for llama 4 eagle case
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
elif forward_batch.forward_mode.is_target_verify():
|
||||
if self.topk <= 1:
|
||||
metadata.cache_seqlens_int32 = (
|
||||
forward_batch.seq_lens + self.speculative_num_draft_tokens
|
||||
).to(torch.int32)
|
||||
@@ -388,13 +447,112 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
device=device,
|
||||
)
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
|
||||
torch.cumsum(
|
||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
else:
|
||||
metadata.cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
|
||||
metadata.max_seq_len_q = self.speculative_num_draft_tokens
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
metadata.cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
batch_size * self.speculative_num_draft_tokens + 1,
|
||||
step=self.speculative_num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
metadata_expand = FlashAttentionMetadata()
|
||||
|
||||
metadata_expand.max_seq_len_q = 1
|
||||
metadata_expand.cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
forward_batch.seq_lens.numel() * self.speculative_num_draft_tokens
|
||||
+ 1,
|
||||
dtype=torch.int32,
|
||||
device=device,
|
||||
)
|
||||
|
||||
# create expand page table
|
||||
offsets = torch.arange(
|
||||
self.speculative_num_draft_tokens, device=device
|
||||
).unsqueeze(
|
||||
0
|
||||
) # shape: (1, self.speculative_num_draft_tokens)
|
||||
cols = offsets.expand(
|
||||
forward_batch.seq_lens.numel(), -1
|
||||
) + forward_batch.seq_lens.unsqueeze(1)
|
||||
cum_len = torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
(
|
||||
forward_batch.seq_lens + self.speculative_num_draft_tokens
|
||||
).repeat_interleave(self.speculative_num_draft_tokens),
|
||||
dim=0,
|
||||
),
|
||||
(1, 0),
|
||||
)[:-1]
|
||||
mask_extraction_indices = (
|
||||
cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
|
||||
+ cum_len[:, None]
|
||||
).view(1, -1)
|
||||
mask = forward_batch.spec_info.custom_mask[
|
||||
mask_extraction_indices
|
||||
].view(
|
||||
-1, self.speculative_num_draft_tokens
|
||||
) # (bsz * draft_num, draft_num)
|
||||
|
||||
# shift table indices to avoid padding
|
||||
# non_masked_page_table [[8, 9, 10], mask (display with int format) [[1, 0, 0],
|
||||
# [8, 9, 10], [1, 1, 0],
|
||||
# [8, 9, 10]] [1, 0, 1]]
|
||||
# if masked with padding [[8, 0, 0], our mask without padding [[8, 9, 10],
|
||||
# [8, 9, 0], [8, 9, 10],
|
||||
# [8, 0, 10]] [8, 10, 9]]
|
||||
# note here cache_seqlens_int32 is [1, 2, 2] so extra page indices will be ignored in each row
|
||||
col_indices = offsets.expand(
|
||||
mask.shape[0], self.speculative_num_draft_tokens
|
||||
)
|
||||
# Build keys: if an entry is valid (mask==True), keep its original index;
|
||||
# if not, add self.speculative_num_draft_tokens so that it sorts after all valid entries.
|
||||
keys = torch.where(
|
||||
mask, col_indices, col_indices + self.speculative_num_draft_tokens
|
||||
)
|
||||
_, sort_order = torch.sort(keys, dim=1)
|
||||
non_masked_page_table = (
|
||||
forward_batch.req_to_token_pool.req_to_token[
|
||||
forward_batch.req_pool_indices, :
|
||||
]
|
||||
.gather(1, cols)
|
||||
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
|
||||
) # (bsz, draft_num)
|
||||
metadata_expand.page_table = non_masked_page_table.gather(1, sort_order)
|
||||
metadata_expand.cache_seqlens_int32 = mask.sum(dim=1).to(torch.int32)
|
||||
metadata_expand.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
metadata_expand.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
metadata_expand.max_seq_len_k = (
|
||||
metadata_expand.cache_seqlens_int32.max().item()
|
||||
)
|
||||
self.forward_metadata_spec_decode_expand = metadata_expand
|
||||
elif forward_batch.forward_mode.is_extend_or_draft_extend_or_mixed():
|
||||
metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
|
||||
metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
|
||||
@@ -514,6 +672,11 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
and (hasattr(layer, "use_irope") and layer.use_irope)
|
||||
)
|
||||
|
||||
# We do cascade attention for Target Verify with topk > 1
|
||||
use_cascade_attn = (
|
||||
forward_batch.forward_mode.is_target_verify() and self.topk > 1
|
||||
)
|
||||
|
||||
# Get the appropriate page table based on whether we're using local attention
|
||||
if use_local_attn:
|
||||
local_metadata = metadata.local_attn_metadata
|
||||
@@ -548,7 +711,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
cu_seqlens_k = metadata.encoder_cu_seqlens_k
|
||||
window_size = (-1, -1)
|
||||
|
||||
o = flash_attn_with_kvcache(
|
||||
result = flash_attn_with_kvcache(
|
||||
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
@@ -558,13 +721,41 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=causal,
|
||||
causal=False if use_cascade_attn else causal,
|
||||
window_size=window_size,
|
||||
softcap=layer.logit_cap,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
return_softmax_lse=use_cascade_attn,
|
||||
)
|
||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
|
||||
if use_cascade_attn:
|
||||
o, softmax_lse, *rest = result
|
||||
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
|
||||
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
page_table=self.forward_metadata_spec_decode_expand.page_table,
|
||||
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
|
||||
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
|
||||
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
|
||||
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=False,
|
||||
window_size=window_size,
|
||||
softcap=layer.logit_cap,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
return_softmax_lse=True,
|
||||
)
|
||||
o, _ = merge_state_v2_wrapper(
|
||||
o,
|
||||
softmax_lse.T.contiguous(),
|
||||
o_expand,
|
||||
softmax_lse_expand.T.contiguous(),
|
||||
)
|
||||
else:
|
||||
o = result
|
||||
else:
|
||||
if (
|
||||
not global_server_args_dict["disable_chunked_prefix_cache"]
|
||||
@@ -627,7 +818,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||
o = flash_attn_with_kvcache(
|
||||
|
||||
result = flash_attn_with_kvcache(
|
||||
q=q_rope,
|
||||
k_cache=k_rope_cache,
|
||||
v_cache=c_kv_cache,
|
||||
@@ -638,11 +830,42 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
cu_seqlens_k_new=cu_seqlens_k if not use_local_attn else None,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
causal=False if use_cascade_attn else causal,
|
||||
softcap=layer.logit_cap,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
return_softmax_lse=use_cascade_attn,
|
||||
)
|
||||
if use_cascade_attn:
|
||||
o, softmax_lse, *rest = result
|
||||
o_expand, softmax_lse_expand, *rest_expand = (
|
||||
flash_attn_with_kvcache(
|
||||
q=q_rope,
|
||||
k_cache=k_rope_cache,
|
||||
v_cache=c_kv_cache,
|
||||
qv=q_nope,
|
||||
page_table=self.forward_metadata_spec_decode_expand.page_table,
|
||||
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
|
||||
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
|
||||
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
|
||||
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=False,
|
||||
window_size=window_size,
|
||||
softcap=layer.logit_cap,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
return_softmax_lse=True,
|
||||
)
|
||||
)
|
||||
o, _ = merge_state_v2_wrapper(
|
||||
o,
|
||||
softmax_lse.T.contiguous(),
|
||||
o_expand,
|
||||
softmax_lse_expand.T.contiguous(),
|
||||
)
|
||||
else:
|
||||
o = result
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
|
||||
@@ -681,6 +904,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
use_local_attention = (
|
||||
self.attention_chunk_size is not None and local_attn_metadata is not None
|
||||
)
|
||||
# We do cascade attention for Draft Decode with topk > 1
|
||||
use_cascade_attn = self.topk > 1
|
||||
|
||||
# Calculate window size (can be moved to metadata if layer properties don't change)
|
||||
# we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
|
||||
@@ -752,23 +977,61 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
v_descale=v_descale,
|
||||
)
|
||||
else:
|
||||
page_table = metadata.page_table
|
||||
cache_seqlens = metadata.cache_seqlens_int32
|
||||
cu_seqlens_k = metadata.cu_seqlens_k
|
||||
max_seqlen_q = metadata.max_seq_len_q
|
||||
q_reshaped = q.contiguous().view(
|
||||
-1, layer.tp_q_head_num, layer.head_dim
|
||||
)
|
||||
|
||||
# Default: single-token self-attention
|
||||
o = flash_attn_with_kvcache(
|
||||
q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
|
||||
result = flash_attn_with_kvcache(
|
||||
q=q_reshaped,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
page_table=metadata.page_table,
|
||||
cache_seqlens=metadata.cache_seqlens_int32,
|
||||
page_table=page_table,
|
||||
cache_seqlens=cache_seqlens,
|
||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||
max_seqlen_q=1,
|
||||
cu_seqlens_k_new=cu_seqlens_k,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
causal=False if use_cascade_attn else causal,
|
||||
window_size=window_size,
|
||||
softcap=layer.logit_cap,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
return_softmax_lse=use_cascade_attn,
|
||||
)
|
||||
if use_cascade_attn:
|
||||
o, softmax_lse, *rest = result
|
||||
o_expand, softmax_lse_expand, *rest_expand = (
|
||||
flash_attn_with_kvcache(
|
||||
q=q_reshaped,
|
||||
k_cache=key_cache,
|
||||
v_cache=value_cache,
|
||||
page_table=self.forward_metadata_spec_decode_expand.page_table,
|
||||
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
|
||||
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
|
||||
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
|
||||
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=False,
|
||||
window_size=window_size,
|
||||
softcap=layer.logit_cap,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
return_softmax_lse=True,
|
||||
)
|
||||
)
|
||||
o, _ = merge_state_v2(
|
||||
o,
|
||||
softmax_lse.T.contiguous(),
|
||||
o_expand,
|
||||
softmax_lse_expand.T.contiguous(),
|
||||
)
|
||||
else:
|
||||
o = result
|
||||
else:
|
||||
# Do absorbed multi-latent attention
|
||||
kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
|
||||
@@ -787,8 +1050,9 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
|
||||
q_nope = q_all[:, :, : layer.v_head_dim]
|
||||
q_rope = q_all[:, :, layer.v_head_dim :]
|
||||
max_seqlen_q = metadata.max_seq_len_q
|
||||
|
||||
o = flash_attn_with_kvcache(
|
||||
result = flash_attn_with_kvcache(
|
||||
q=q_rope,
|
||||
k_cache=k_rope_cache,
|
||||
v_cache=c_kv_cache,
|
||||
@@ -797,13 +1061,43 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
cache_seqlens=metadata.cache_seqlens_int32,
|
||||
cu_seqlens_q=metadata.cu_seqlens_q,
|
||||
cu_seqlens_k_new=metadata.cu_seqlens_k,
|
||||
max_seqlen_q=1,
|
||||
max_seqlen_q=max_seqlen_q,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=True,
|
||||
causal=False if use_cascade_attn else causal,
|
||||
softcap=layer.logit_cap,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
return_softmax_lse=use_cascade_attn, # softmax_lse is needed for merge states
|
||||
)
|
||||
if use_cascade_attn:
|
||||
o, softmax_lse, *rest = result
|
||||
o_expand, softmax_lse_expand, *rest_expand = flash_attn_with_kvcache(
|
||||
q=q_rope,
|
||||
k_cache=k_rope_cache,
|
||||
v_cache=c_kv_cache,
|
||||
qv=q_nope,
|
||||
page_table=self.forward_metadata_spec_decode_expand.page_table,
|
||||
cache_seqlens=self.forward_metadata_spec_decode_expand.cache_seqlens_int32,
|
||||
cu_seqlens_q=self.forward_metadata_spec_decode_expand.cu_seqlens_q,
|
||||
cu_seqlens_k_new=self.forward_metadata_spec_decode_expand.cu_seqlens_k,
|
||||
max_seqlen_q=self.forward_metadata_spec_decode_expand.max_seq_len_q,
|
||||
softmax_scale=layer.scaling,
|
||||
causal=False,
|
||||
window_size=window_size,
|
||||
softcap=layer.logit_cap,
|
||||
k_descale=k_descale,
|
||||
v_descale=v_descale,
|
||||
return_softmax_lse=True,
|
||||
)
|
||||
o, _ = merge_state_v2(
|
||||
o,
|
||||
softmax_lse.T.contiguous(),
|
||||
o_expand,
|
||||
softmax_lse_expand.T.contiguous(),
|
||||
)
|
||||
else:
|
||||
o = result
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
@@ -815,6 +1109,8 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
This creates fixed-size tensors that will be reused during CUDA graph replay
|
||||
to avoid memory allocations.
|
||||
"""
|
||||
|
||||
# This is being used by normal decode and draft decode when topk == 1
|
||||
self.decode_cuda_graph_metadata = {
|
||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||
"cu_seqlens_q": torch.arange(
|
||||
@@ -840,11 +1136,75 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
),
|
||||
}
|
||||
|
||||
self.target_verify_metadata = {
|
||||
"cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
|
||||
"cu_seqlens_q": torch.zeros(
|
||||
# This is used by draft decode's first half of metadata when topk > 1
|
||||
if self.topk > 1:
|
||||
self.draft_decode_metadata_topk_normal = {
|
||||
"cache_seqlens": torch.zeros(
|
||||
max_bs, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"cu_seqlens_q": torch.arange(
|
||||
0,
|
||||
max_bs * self.topk + 1,
|
||||
step=self.topk,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"cu_seqlens_k": torch.zeros(
|
||||
max_bs + 1, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
self.max_context_len,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
|
||||
# This is used by draft decode's second half of metadata when topk > 1
|
||||
decode_length = self.speculative_step_id + 1
|
||||
self.draft_decode_metadata_topk_expand = {
|
||||
"cache_seqlens": torch.full(
|
||||
(max_bs * self.topk,),
|
||||
decode_length,
|
||||
device=self.device,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
"cu_seqlens_q": torch.arange(
|
||||
0,
|
||||
max_bs * self.topk + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"cu_seqlens_k": torch.arange(
|
||||
0,
|
||||
max_bs * self.topk * decode_length + 1,
|
||||
step=decode_length,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"page_table": torch.zeros(
|
||||
max_bs * self.topk,
|
||||
decode_length,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
|
||||
if (
|
||||
self.speculative_num_draft_tokens is not None
|
||||
and self.speculative_num_draft_tokens > 0
|
||||
):
|
||||
self.target_verify_metadata = {
|
||||
"cache_seqlens": torch.zeros(
|
||||
max_bs, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"cu_seqlens_q": torch.arange(
|
||||
0,
|
||||
max_bs * self.speculative_num_draft_tokens + 1,
|
||||
step=self.speculative_num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"cu_seqlens_k": torch.zeros(
|
||||
max_bs + 1, dtype=torch.int32, device=self.device
|
||||
),
|
||||
@@ -859,6 +1219,54 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
),
|
||||
}
|
||||
|
||||
if self.topk > 1:
|
||||
self.target_verify_metadata_topk_normal = {
|
||||
"cache_seqlens": torch.zeros(
|
||||
max_bs, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"cu_seqlens_q": torch.arange(
|
||||
0,
|
||||
max_bs * self.speculative_num_draft_tokens + 1,
|
||||
step=self.speculative_num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"cu_seqlens_k": torch.zeros(
|
||||
max_bs + 1, dtype=torch.int32, device=self.device
|
||||
),
|
||||
"page_table": torch.zeros(
|
||||
max_bs,
|
||||
self.max_context_len,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
|
||||
self.target_verify_metadata_topk_expand = {
|
||||
"cache_seqlens": torch.zeros(
|
||||
max_bs * self.speculative_num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"cu_seqlens_k": torch.zeros(
|
||||
max_bs * self.speculative_num_draft_tokens + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"cu_seqlens_q": torch.arange(
|
||||
0,
|
||||
max_bs * self.speculative_num_draft_tokens + 1,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
"page_table": torch.zeros(
|
||||
max_bs * self.speculative_num_draft_tokens,
|
||||
self.speculative_num_draft_tokens,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
),
|
||||
}
|
||||
|
||||
self.encoder_metadata = {
|
||||
"encoder_page_table": torch.zeros(
|
||||
max_bs,
|
||||
@@ -886,19 +1294,25 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
):
|
||||
"""Initialize forward metadata for capturing CUDA graph."""
|
||||
metadata = FlashAttentionMetadata()
|
||||
|
||||
# metadata_expand is needed for Spec Decoding when top k > 1
|
||||
metadata_expand = FlashAttentionMetadata()
|
||||
|
||||
device = seq_lens.device
|
||||
if forward_mode.is_decode_or_idle():
|
||||
if spec_info is not None:
|
||||
# Draft Decode
|
||||
if self.topk <= 1:
|
||||
# When topk = 1, we use the normal decode metadata
|
||||
metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
|
||||
"cache_seqlens"
|
||||
][:bs]
|
||||
metadata.max_seq_len_k = seq_lens.max().item() + (
|
||||
self.speculative_step_id + 1
|
||||
)
|
||||
metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
|
||||
: bs + 1
|
||||
]
|
||||
metadata.cu_seqlens_q = self.decode_cuda_graph_metadata[
|
||||
"cu_seqlens_q"
|
||||
][: bs + 1]
|
||||
metadata.cu_seqlens_k = torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||
@@ -908,6 +1322,50 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.page_table = self.decode_cuda_graph_metadata[
|
||||
"page_table_draft_decode"
|
||||
][req_pool_indices, :]
|
||||
self.decode_cuda_graph_metadata[bs] = metadata
|
||||
else:
|
||||
# When top k > 1, we need two specific draft decode metadata, and then merge states
|
||||
# 1. The first half of metadata for prefix tokens
|
||||
metadata.cache_seqlens_int32 = (
|
||||
self.draft_decode_metadata_topk_normal["cache_seqlens"][:bs]
|
||||
)
|
||||
metadata.max_seq_len_q = self.topk
|
||||
metadata.max_seq_len_k = seq_lens.max().item()
|
||||
metadata.cu_seqlens_q = self.draft_decode_metadata_topk_normal[
|
||||
"cu_seqlens_q"
|
||||
][: bs + 1]
|
||||
metadata.cu_seqlens_k = self.draft_decode_metadata_topk_normal[
|
||||
"cu_seqlens_k"
|
||||
][: bs + 1]
|
||||
metadata.page_table = self.draft_decode_metadata_topk_normal[
|
||||
"page_table"
|
||||
][req_pool_indices, :]
|
||||
|
||||
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
||||
metadata_expand.cache_seqlens_int32 = (
|
||||
self.draft_decode_metadata_topk_expand["cache_seqlens"][
|
||||
: bs * self.topk
|
||||
]
|
||||
)
|
||||
metadata_expand.max_seq_len_q = 1
|
||||
metadata_expand.max_seq_len_k = (
|
||||
self.speculative_step_id + 1
|
||||
) # , do this in replay
|
||||
metadata_expand.cu_seqlens_q = (
|
||||
self.draft_decode_metadata_topk_expand["cu_seqlens_q"][
|
||||
: bs * self.topk + 1
|
||||
]
|
||||
)
|
||||
metadata_expand.cu_seqlens_k = (
|
||||
self.draft_decode_metadata_topk_expand["cu_seqlens_k"][
|
||||
: bs * self.topk + 1
|
||||
]
|
||||
)
|
||||
metadata_expand.page_table = self.draft_decode_metadata_topk_expand[
|
||||
"page_table"
|
||||
][: bs * self.topk]
|
||||
self.draft_decode_metadata_topk_normal[bs] = metadata
|
||||
self.draft_decode_metadata_topk_expand[bs] = metadata_expand
|
||||
else:
|
||||
# Normal Decode
|
||||
# Get sequence information
|
||||
@@ -928,10 +1386,12 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
0, batch_size + 1, dtype=torch.int32, device=device
|
||||
)
|
||||
self.decode_cuda_graph_metadata[bs] = metadata
|
||||
|
||||
elif forward_mode.is_target_verify():
|
||||
metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
|
||||
:bs
|
||||
]
|
||||
if self.topk <= 1:
|
||||
metadata.cache_seqlens_int32 = self.target_verify_metadata[
|
||||
"cache_seqlens"
|
||||
][:bs]
|
||||
metadata.cache_seqlens_int32.copy_(
|
||||
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
|
||||
)
|
||||
@@ -958,6 +1418,44 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
]
|
||||
|
||||
self.target_verify_metadata[bs] = metadata
|
||||
else:
|
||||
# When topk > 1, we need two specific target verify metadata, and then merge states
|
||||
# 1. The first half of metadata for prefix tokens
|
||||
metadata.cache_seqlens_int32 = self.target_verify_metadata_topk_normal[
|
||||
"cache_seqlens"
|
||||
][:bs]
|
||||
metadata.max_seq_len_q = self.speculative_num_draft_tokens
|
||||
# metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item(), do this in replay
|
||||
metadata.cu_seqlens_q = self.target_verify_metadata_topk_normal[
|
||||
"cu_seqlens_q"
|
||||
][: bs + 1]
|
||||
metadata.cu_seqlens_k = self.target_verify_metadata_topk_normal[
|
||||
"cu_seqlens_k"
|
||||
][: bs + 1]
|
||||
metadata.page_table = self.target_verify_metadata_topk_normal[
|
||||
"page_table"
|
||||
][req_pool_indices, :]
|
||||
|
||||
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
||||
metadata_expand.cache_seqlens_int32 = (
|
||||
self.target_verify_metadata_topk_expand["cache_seqlens"][
|
||||
: bs * self.speculative_num_draft_tokens
|
||||
]
|
||||
)
|
||||
metadata_expand.max_seq_len_q = 1
|
||||
metadata_expand.cu_seqlens_q = self.target_verify_metadata_topk_expand[
|
||||
"cu_seqlens_q"
|
||||
][: bs * self.speculative_num_draft_tokens + 1]
|
||||
metadata_expand.cu_seqlens_k = self.target_verify_metadata_topk_expand[
|
||||
"cu_seqlens_k"
|
||||
][: bs * self.speculative_num_draft_tokens + 1]
|
||||
|
||||
metadata_expand.page_table = self.target_verify_metadata_topk_expand[
|
||||
"page_table"
|
||||
][: bs * self.speculative_num_draft_tokens]
|
||||
|
||||
self.target_verify_metadata_topk_normal[bs] = metadata
|
||||
self.target_verify_metadata_topk_expand[bs] = metadata_expand
|
||||
|
||||
if encoder_lens is not None:
|
||||
encoder_bs = encoder_lens.numel()
|
||||
@@ -973,6 +1471,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
]
|
||||
|
||||
self.forward_metadata = metadata
|
||||
self.forward_metadata_spec_decode_expand = metadata_expand
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self,
|
||||
@@ -986,17 +1485,21 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
out_cache_loc: torch.Tensor = None,
|
||||
):
|
||||
# """Initialize forward metadata for replaying CUDA graph."""
|
||||
"""Initialize forward metadata for replaying CUDA graph."""
|
||||
seq_lens = seq_lens[:bs]
|
||||
seq_lens_cpu = seq_lens_cpu[:bs]
|
||||
req_pool_indices = req_pool_indices[:bs]
|
||||
device = seq_lens.device
|
||||
metadata = None
|
||||
metadata_expand = None
|
||||
|
||||
if forward_mode.is_decode_or_idle():
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
|
||||
if spec_info is not None:
|
||||
# Draft Decode
|
||||
if self.topk <= 1:
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
# When topk = 1, we use the normal decode metadata
|
||||
metadata.cache_seqlens_int32.copy_(
|
||||
(seq_lens + (self.speculative_step_id + 1)).to(torch.int32)
|
||||
)
|
||||
@@ -1013,14 +1516,54 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
)
|
||||
)
|
||||
|
||||
max_seq_pages = (
|
||||
metadata.max_seq_len_k + self.page_size - 1
|
||||
) // self.page_size
|
||||
page_indices = self.req_to_token[
|
||||
req_pool_indices[:, None],
|
||||
self.decode_cuda_graph_metadata["strided_indices"][
|
||||
:max_seq_pages
|
||||
],
|
||||
]
|
||||
|
||||
page_indices //= self.page_size
|
||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||
else:
|
||||
# When top k > 1, we need two specific draft decode metadata, and then merge states
|
||||
# 1. The first half of metadata for prefix tokens
|
||||
metadata = self.draft_decode_metadata_topk_normal[bs]
|
||||
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
|
||||
# metadata.max_seq_len_q = self.topk, already set in capture
|
||||
metadata.max_seq_len_k = seq_lens_cpu.max().item()
|
||||
# metadata.cu_seqlens_q already set in capture
|
||||
metadata.cu_seqlens_k.copy_(
|
||||
torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
)
|
||||
|
||||
page_table = self.req_to_token[
|
||||
req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
|
||||
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
||||
|
||||
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
||||
metadata_expand = self.draft_decode_metadata_topk_expand[bs]
|
||||
decode_length = self.speculative_step_id + 1
|
||||
cache_loc = out_cache_loc.view(
|
||||
self.speculative_num_steps, -1
|
||||
).T.contiguous()
|
||||
metadata_expand.page_table[: cache_loc.shape[0]].copy_(
|
||||
cache_loc[:, :decode_length].contiguous().to(torch.int32)
|
||||
)
|
||||
# TODO: we need to test this part for llama 4 eagle case
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
else:
|
||||
metadata = self.decode_cuda_graph_metadata[bs]
|
||||
# Normal Decode
|
||||
max_len = seq_lens_cpu.max().item()
|
||||
metadata.max_seq_len_k = max_len
|
||||
@@ -1045,6 +1588,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
self._init_local_attn_metadata(metadata, device)
|
||||
elif forward_mode.is_target_verify():
|
||||
if self.topk <= 1:
|
||||
metadata = self.target_verify_metadata[bs]
|
||||
metadata.cache_seqlens_int32.copy_(
|
||||
(seq_lens + self.speculative_num_draft_tokens).to(torch.int32)
|
||||
@@ -1061,9 +1605,101 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
(1, 0),
|
||||
)
|
||||
)
|
||||
page_table = self.req_to_token[req_pool_indices, : metadata.max_seq_len_k]
|
||||
max_seq_pages = (
|
||||
metadata.max_seq_len_k + self.page_size - 1
|
||||
) // self.page_size
|
||||
page_indices = self.req_to_token[
|
||||
req_pool_indices[:, None],
|
||||
self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
|
||||
]
|
||||
page_indices //= self.page_size
|
||||
metadata.page_table[:, :max_seq_pages].copy_(page_indices)
|
||||
else:
|
||||
# When topk > 1, we need two specific target verify metadata, and then merge states
|
||||
# 1. The first half of metadata for prefix tokens
|
||||
metadata = self.target_verify_metadata_topk_normal[bs]
|
||||
metadata.cache_seqlens_int32.copy_(seq_lens.to(torch.int32))
|
||||
# metadata.max_seq_len_q = self.speculative_num_draft_tokens, already set in capture
|
||||
metadata.max_seq_len_k = seq_lens_cpu.max().item()
|
||||
# metadata.cu_seqlens_q already set in capture
|
||||
metadata.cu_seqlens_k.copy_(
|
||||
torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
)
|
||||
page_table = self.req_to_token[
|
||||
req_pool_indices, : metadata.max_seq_len_k
|
||||
]
|
||||
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
||||
|
||||
# 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
|
||||
metadata_expand = self.target_verify_metadata_topk_expand[bs]
|
||||
# metadata_expand.max_seq_len_q = 1, already set in capture
|
||||
# metadata_expand.cu_seqlens_q already set in capture
|
||||
|
||||
offsets = torch.arange(
|
||||
self.speculative_num_draft_tokens, device=device
|
||||
).unsqueeze(
|
||||
0
|
||||
) # shape: (1, self.speculative_num_draft_tokens)
|
||||
cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1)
|
||||
cum_len = torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
(
|
||||
seq_lens + self.speculative_num_draft_tokens
|
||||
).repeat_interleave(self.speculative_num_draft_tokens),
|
||||
dim=0,
|
||||
),
|
||||
(1, 0),
|
||||
)[:-1]
|
||||
mask_extraction_indices = (
|
||||
cols.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
|
||||
+ cum_len[:, None]
|
||||
).view(1, -1)
|
||||
# avoid extracting padded seq indices which will be out of boundary
|
||||
mask_extraction_indices[
|
||||
:, spec_info.positions.numel() * self.speculative_num_draft_tokens :
|
||||
].fill_(0)
|
||||
|
||||
mask = spec_info.custom_mask[mask_extraction_indices].view(
|
||||
-1, self.speculative_num_draft_tokens
|
||||
) # (bsz * draft_num, draft_num)
|
||||
col_indices = offsets.expand(
|
||||
mask.shape[0], self.speculative_num_draft_tokens
|
||||
)
|
||||
keys = torch.where(
|
||||
mask, col_indices, col_indices + self.speculative_num_draft_tokens
|
||||
)
|
||||
_, sort_order = torch.sort(keys, dim=1)
|
||||
|
||||
non_masked_page_table = (
|
||||
self.req_to_token[req_pool_indices, :]
|
||||
.gather(1, cols)
|
||||
.repeat_interleave(self.speculative_num_draft_tokens, dim=0)
|
||||
) # (bsz, draft_num)
|
||||
metadata_expand.page_table.copy_(
|
||||
non_masked_page_table.gather(1, sort_order)
|
||||
)
|
||||
metadata_expand.cache_seqlens_int32.copy_(
|
||||
mask.sum(dim=1).to(torch.int32)
|
||||
)
|
||||
metadata_expand.cu_seqlens_k.copy_(
|
||||
torch.nn.functional.pad(
|
||||
torch.cumsum(
|
||||
metadata_expand.cache_seqlens_int32,
|
||||
dim=0,
|
||||
dtype=torch.int32,
|
||||
),
|
||||
(1, 0),
|
||||
)
|
||||
)
|
||||
metadata_expand.max_seq_len_k = (
|
||||
metadata_expand.cache_seqlens_int32.max().item()
|
||||
)
|
||||
|
||||
if encoder_lens is not None:
|
||||
# Only support encoder size 1 for now
|
||||
metadata.encoder_max_seq_len_k = encoder_lens[0]
|
||||
@@ -1089,6 +1725,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
metadata.page_table[:, : metadata.max_seq_len_k].copy_(page_table)
|
||||
|
||||
self.forward_metadata = metadata
|
||||
self.forward_metadata_spec_decode_expand = metadata_expand
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
|
||||
"""Get the fill value for sequence length in CUDA graph."""
|
||||
@@ -1139,12 +1776,6 @@ class FlashAttentionMultiStepBackend:
|
||||
self.model_runner = model_runner
|
||||
self.topk = topk
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
|
||||
# TODO: Support Topk > 1 for FlashAttentionBackend Spec Decoding
|
||||
assert (
|
||||
self.topk == 1
|
||||
), "speculative_eagle_topk must be 1 for FlashAttentionMultiStepBackend"
|
||||
|
||||
self.attn_backends = []
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends.append(
|
||||
|
||||
@@ -221,7 +221,16 @@ class ModelRunner:
|
||||
server_args = self.server_args
|
||||
|
||||
if server_args.attention_backend is None:
|
||||
# By default, use flashinfer for non-mla attention and triton for mla attention
|
||||
"""
|
||||
We auto select the fastest attention backend according to the current offering
|
||||
1. Models with MHA Architecture (e.g: Llama, QWen)
|
||||
1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
|
||||
1.2 In other cases, we will use flashinfer if available, otherwise use triton.
|
||||
2. Models with MLA Architecture and using FA3
|
||||
2.1 We will use FA3 backend on hopper.
|
||||
2.2 Otherwise, we will use triton backend.
|
||||
"""
|
||||
|
||||
if not self.use_mla_backend:
|
||||
if (
|
||||
is_hopper_with_cuda_12_3()
|
||||
@@ -234,9 +243,7 @@ class ModelRunner:
|
||||
"flashinfer" if is_flashinfer_available() else "triton"
|
||||
)
|
||||
else:
|
||||
if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
|
||||
server_args
|
||||
):
|
||||
if is_hopper_with_cuda_12_3():
|
||||
server_args.attention_backend = "fa3"
|
||||
else:
|
||||
server_args.attention_backend = "triton"
|
||||
|
||||
@@ -359,7 +359,18 @@ class ServerArgs:
|
||||
|
||||
if self.page_size > 1 and self.speculative_eagle_topk > 1:
|
||||
self.speculative_eagle_topk = 1
|
||||
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
|
||||
logger.info(
|
||||
"speculative_eagle_topk is adjusted to 1 when page_size > 1"
|
||||
)
|
||||
|
||||
if (
|
||||
self.speculative_eagle_topk == 1
|
||||
and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
|
||||
):
|
||||
logger.info(
|
||||
"speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
|
||||
)
|
||||
self.speculative_num_draft_tokens = self.speculative_num_steps + 1
|
||||
|
||||
# The token generated from the verify step is counted.
|
||||
# If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
|
||||
|
||||
@@ -1909,6 +1909,8 @@ def is_page_size_one(server_args):
|
||||
return server_args.page_size == 1
|
||||
|
||||
|
||||
# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
|
||||
# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
|
||||
def is_no_spec_infer_or_topk_one(server_args):
|
||||
return server_args.speculative_eagle_topk is None or (
|
||||
server_args.speculative_eagle_topk is not None
|
||||
|
||||
@@ -29,7 +29,7 @@ suites = {
|
||||
TestFile("test_chunked_prefill.py", 336),
|
||||
TestFile("test_eagle_infer.py", 500),
|
||||
TestFile("test_ebnf_constrained.py"),
|
||||
TestFile("test_fa3.py", 5),
|
||||
TestFile("test_fa3.py", 200),
|
||||
TestFile("test_fp8_kernel.py", 8),
|
||||
TestFile("test_embedding_openai_server.py", 36),
|
||||
TestFile("test_hidden_states.py", 55),
|
||||
|
||||
@@ -173,6 +173,60 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
|
||||
self.assertGreater(avg_spec_accept_length, 1.5)
|
||||
|
||||
|
||||
class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
|
||||
"""Test FlashAttention3 with speculative decode enabled, topk > 1"""
|
||||
|
||||
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
|
||||
@classmethod
|
||||
def get_server_args(cls):
|
||||
args = super().get_server_args()
|
||||
args.extend(
|
||||
[
|
||||
"--cuda-graph-max-bs",
|
||||
"2",
|
||||
"--speculative-algorithm",
|
||||
"EAGLE3",
|
||||
"--speculative-draft",
|
||||
"jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
|
||||
"--speculative-num-steps",
|
||||
"5",
|
||||
"--speculative-eagle-topk",
|
||||
"4",
|
||||
"--speculative-num-draft-tokens",
|
||||
"8",
|
||||
"--dtype",
|
||||
"float16",
|
||||
]
|
||||
)
|
||||
return args
|
||||
|
||||
def test_gsm8k(self):
|
||||
"""
|
||||
Override the test_gsm8k to further test for average speculative accept length.
|
||||
"""
|
||||
requests.get(self.base_url + "/flush_cache")
|
||||
|
||||
args = SimpleNamespace(
|
||||
num_shots=5,
|
||||
data_path=DATA_PATH,
|
||||
num_questions=200,
|
||||
max_new_tokens=512,
|
||||
parallel=128,
|
||||
host="http://127.0.0.1",
|
||||
port=int(self.base_url.split(":")[-1]),
|
||||
)
|
||||
metrics = run_eval_few_shot_gsm8k(args)
|
||||
print(metrics)
|
||||
|
||||
self.assertGreater(metrics["accuracy"], 0.60)
|
||||
|
||||
server_info = requests.get(self.base_url + "/get_server_info")
|
||||
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
||||
print(f"{avg_spec_accept_length=}")
|
||||
self.assertGreater(avg_spec_accept_length, 1.8)
|
||||
|
||||
|
||||
class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
|
||||
"""Test FlashAttention3 with speculative decode enabled."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user