Add Speculative Decoding Eagle3 topk > 1 (#5318)

Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
This commit is contained in:
Qingquan Song
2025-04-20 22:58:28 -07:00
committed by GitHub
parent eef9433b46
commit 188f0955fa
6 changed files with 872 additions and 167 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -221,7 +221,16 @@ class ModelRunner:
server_args = self.server_args
if server_args.attention_backend is None:
# By default, use flashinfer for non-mla attention and triton for mla attention
"""
We auto select the fastest attention backend according to the current offering
1. Models with MHA Architecture (e.g., Llama, Qwen)
1.1 We will turn on FA3 on hopper unless the user uses spec decode with topk > 1 or page_size > 1.
1.2 In other cases, we will use flashinfer if available, otherwise use triton.
2. Models with MLA Architecture and using FA3
2.1 We will use FA3 backend on hopper.
2.2 Otherwise, we will use triton backend.
"""
if not self.use_mla_backend:
if (
is_hopper_with_cuda_12_3()
@@ -234,9 +243,7 @@ class ModelRunner:
"flashinfer" if is_flashinfer_available() else "triton"
)
else:
if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
server_args
):
if is_hopper_with_cuda_12_3():
server_args.attention_backend = "fa3"
else:
server_args.attention_backend = "triton"

View File

@@ -359,7 +359,18 @@ class ServerArgs:
if self.page_size > 1 and self.speculative_eagle_topk > 1:
self.speculative_eagle_topk = 1
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
logger.info(
"speculative_eagle_topk is adjusted to 1 when page_size > 1"
)
if (
self.speculative_eagle_topk == 1
and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
):
logger.info(
"speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
)
self.speculative_num_draft_tokens = self.speculative_num_steps + 1
# The token generated from the verify step is counted.
# If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.

View File

@@ -1909,6 +1909,8 @@ def is_page_size_one(server_args):
return server_args.page_size == 1
# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
def is_no_spec_infer_or_topk_one(server_args):
return server_args.speculative_eagle_topk is None or (
server_args.speculative_eagle_topk is not None