Add Speculative Decoding Eagle3 topk > 1 (#5318)
Co-authored-by: Stefan He <hebiaobuaa@gmail.com> Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -221,7 +221,16 @@ class ModelRunner:
|
||||
server_args = self.server_args
|
||||
|
||||
if server_args.attention_backend is None:
|
||||
# By default, use flashinfer for non-mla attention and triton for mla attention
|
||||
"""
|
||||
We auto select the fastest attention backend according to the current offering
|
||||
1. Models with MHA Architecture (e.g: Llama, QWen)
|
||||
1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
|
||||
1.2 In other cases, we will use flashinfer if available, otherwise use triton.
|
||||
2. Models with MLA Architecture and using FA3
|
||||
2.1 We will use FA3 backend on hopper.
|
||||
2.2 Otherwise, we will use triton backend.
|
||||
"""
|
||||
|
||||
if not self.use_mla_backend:
|
||||
if (
|
||||
is_hopper_with_cuda_12_3()
|
||||
@@ -234,9 +243,7 @@ class ModelRunner:
|
||||
"flashinfer" if is_flashinfer_available() else "triton"
|
||||
)
|
||||
else:
|
||||
if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
|
||||
server_args
|
||||
):
|
||||
if is_hopper_with_cuda_12_3():
|
||||
server_args.attention_backend = "fa3"
|
||||
else:
|
||||
server_args.attention_backend = "triton"
|
||||
|
||||
@@ -359,7 +359,18 @@ class ServerArgs:
|
||||
|
||||
if self.page_size > 1 and self.speculative_eagle_topk > 1:
|
||||
self.speculative_eagle_topk = 1
|
||||
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
|
||||
logger.info(
|
||||
"speculative_eagle_topk is adjusted to 1 when page_size > 1"
|
||||
)
|
||||
|
||||
if (
|
||||
self.speculative_eagle_topk == 1
|
||||
and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
|
||||
):
|
||||
logger.info(
|
||||
"speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
|
||||
)
|
||||
self.speculative_num_draft_tokens = self.speculative_num_steps + 1
|
||||
|
||||
# The token generated from the verify step is counted.
|
||||
# If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
|
||||
|
||||
@@ -1909,6 +1909,8 @@ def is_page_size_one(server_args):
|
||||
return server_args.page_size == 1
|
||||
|
||||
|
||||
# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
|
||||
# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
|
||||
def is_no_spec_infer_or_topk_one(server_args):
|
||||
return server_args.speculative_eagle_topk is None or (
|
||||
server_args.speculative_eagle_topk is not None
|
||||
|
||||
Reference in New Issue
Block a user