Add Speculative Decoding Eagle3 topk > 1 (#5318)
Co-authored-by: Stefan He <hebiaobuaa@gmail.com>
Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
File diff suppressed because it is too large
python/sglang/srt/model_executor/model_runner.py
@@ -221,7 +221,16 @@ class ModelRunner:
         server_args = self.server_args

         if server_args.attention_backend is None:
-            # By default, use flashinfer for non-mla attention and triton for mla attention
+            """
+            We auto-select the fastest attention backend according to the current offering:
+            1. Models with MHA architecture (e.g. Llama, Qwen)
+                1.1 We will turn on FA3 on Hopper, unless the user runs spec decode with topk > 1 or page_size > 1.
+                1.2 In other cases, we will use flashinfer if available, otherwise triton.
+            2. Models with MLA architecture
+                2.1 We will use the FA3 backend on Hopper.
+                2.2 Otherwise, we will use the triton backend.
+            """
+
             if not self.use_mla_backend:
                 if (
                     is_hopper_with_cuda_12_3()
@@ -234,9 +243,7 @@ class ModelRunner:
                         "flashinfer" if is_flashinfer_available() else "triton"
                     )
             else:
-                if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
-                    server_args
-                ):
+                if is_hopper_with_cuda_12_3():
                     server_args.attention_backend = "fa3"
                 else:
                     server_args.attention_backend = "triton"
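Taken together with the unchanged lines between the two hunks, the policy in the new docstring reduces to a small decision table. Below is a minimal, self-contained sketch for illustration only: a plain dataclass stands in for ServerArgs, booleans replace the hardware and availability probes, and the helper bodies mirror the ones in server_args.py shown further down.

from dataclasses import dataclass
from typing import Optional

@dataclass
class Args:  # stand-in for ServerArgs, illustration only
    speculative_eagle_topk: Optional[int] = None
    page_size: int = 1

def is_page_size_one(args):
    return args.page_size == 1

def is_no_spec_infer_or_topk_one(args):
    # True when spec decode is off, or EAGLE topk == 1
    return args.speculative_eagle_topk is None or args.speculative_eagle_topk == 1

def select_backend(args, use_mla_backend, on_hopper_cuda_12_3, flashinfer_available):
    if not use_mla_backend:
        # MHA (e.g. Llama, Qwen): FA3 on Hopper, unless spec decode with
        # topk > 1 or page_size > 1 forces the flashinfer/triton fallback
        if (
            on_hopper_cuda_12_3
            and is_no_spec_infer_or_topk_one(args)
            and is_page_size_one(args)
        ):
            return "fa3"
        return "flashinfer" if flashinfer_available else "triton"
    # MLA: after this commit, FA3 on Hopper regardless of topk
    return "fa3" if on_hopper_cuda_12_3 else "triton"

# topk > 1 now routes MHA models away from FA3:
assert select_backend(Args(speculative_eagle_topk=4), False, True, True) == "flashinfer"

Note that the second hunk drops the is_no_spec_infer_or_topk_one check from the MLA branch, so FA3 is now selected on Hopper for MLA models even when spec decode runs with topk > 1, which is the point of this commit.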
python/sglang/srt/server_args.py
@@ -359,7 +359,18 @@ class ServerArgs:

         if self.page_size > 1 and self.speculative_eagle_topk > 1:
             self.speculative_eagle_topk = 1
-            logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
+            logger.info(
+                "speculative_eagle_topk is adjusted to 1 when page_size > 1"
+            )
+
+        if (
+            self.speculative_eagle_topk == 1
+            and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
+        ):
+            logger.info(
+                "speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
+            )
+            self.speculative_num_draft_tokens = self.speculative_num_steps + 1

         # The token generated from the verify step is counted.
         # If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
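The new normalization encodes a simple identity: with topk == 1 the draft is a single chain, so the verify step scores one token per draft step plus the token the verify step itself generates, i.e. num_steps + 1 tokens in total. A small illustrative check (the input values are hypothetical):

# Mirrors the normalization above, for illustration only
num_steps, topk, num_draft_tokens = 5, 1, 4  # hypothetical, inconsistent input
if topk == 1 and num_draft_tokens != num_steps + 1:
    num_draft_tokens = num_steps + 1  # one token per step, plus the verify-step token
assert num_draft_tokens == 6

With topk > 1 the draft is a tree rather than a chain, so num_draft_tokens is chosen independently of num_steps; the new test below uses 5 steps with 8 draft tokens.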
@@ -1909,6 +1909,8 @@ def is_page_size_one(server_args):
     return server_args.page_size == 1


+# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
+# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
 def is_no_spec_infer_or_topk_one(server_args):
     return server_args.speculative_eagle_topk is None or (
         server_args.speculative_eagle_topk is not None
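The hunk is cut off at the diff context boundary, mid-expression. Given the helper's name and how ModelRunner uses it, the return expression presumably completes with a topk == 1 comparison, as in this hedged sketch; the final clause is an assumption, not shown in the hunk:

def is_no_spec_infer_or_topk_one(server_args):
    return server_args.speculative_eagle_topk is None or (
        server_args.speculative_eagle_topk is not None
        and server_args.speculative_eagle_topk == 1  # assumed continuation
    )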
test/srt/run_suite.py
@@ -29,7 +29,7 @@ suites = {
         TestFile("test_chunked_prefill.py", 336),
         TestFile("test_eagle_infer.py", 500),
         TestFile("test_ebnf_constrained.py"),
-        TestFile("test_fa3.py", 5),
+        TestFile("test_fa3.py", 200),
         TestFile("test_fp8_kernel.py", 8),
         TestFile("test_embedding_openai_server.py", 36),
         TestFile("test_hidden_states.py", 55),
test/srt/test_fa3.py
@@ -173,6 +173,60 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
         self.assertGreater(avg_spec_accept_length, 1.5)


+class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
+    """Test FlashAttention3 with speculative decode enabled, topk > 1."""
+
+    model = "meta-llama/Llama-3.1-8B-Instruct"
+
+    @classmethod
+    def get_server_args(cls):
+        args = super().get_server_args()
+        args.extend(
+            [
+                "--cuda-graph-max-bs",
+                "2",
+                "--speculative-algorithm",
+                "EAGLE3",
+                "--speculative-draft",
+                "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
+                "--speculative-num-steps",
+                "5",
+                "--speculative-eagle-topk",
+                "4",
+                "--speculative-num-draft-tokens",
+                "8",
+                "--dtype",
+                "float16",
+            ]
+        )
+        return args
+
+    def test_gsm8k(self):
+        """
+        Override test_gsm8k to also check the average speculative accept length.
+        """
+        requests.get(self.base_url + "/flush_cache")
+
+        args = SimpleNamespace(
+            num_shots=5,
+            data_path=DATA_PATH,
+            num_questions=200,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(metrics)
+
+        self.assertGreater(metrics["accuracy"], 0.60)
+
+        server_info = requests.get(self.base_url + "/get_server_info")
+        avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
+        print(f"{avg_spec_accept_length=}")
+        self.assertGreater(avg_spec_accept_length, 1.8)
+
+
 class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
     """Test FlashAttention3 with speculative decode enabled."""
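For manual experimentation, the new test's configuration can be reproduced against a standalone server. The sketch below assumes the usual sglang launch entry point and default port, which are not part of this diff; the flags are copied from get_server_args above, and the two HTTP endpoints are the ones the test itself calls.

# Assumed launch command (not from the diff):
#   python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct \
#     --cuda-graph-max-bs 2 --speculative-algorithm EAGLE3 \
#     --speculative-draft jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B \
#     --speculative-num-steps 5 --speculative-eagle-topk 4 \
#     --speculative-num-draft-tokens 8 --dtype float16
import requests

base_url = "http://127.0.0.1:30000"      # assumed default port
requests.get(base_url + "/flush_cache")  # reset state, as the test does
# ... run a generation workload, e.g. the GSM8K few-shot eval ...
info = requests.get(base_url + "/get_server_info").json()
print(info["avg_spec_accept_length"])    # the test asserts this exceeds 1.8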