Add Speculative Decoding Eagle3 topk > 1 (#5318)
Co-authored-by: Stefan He <hebiaobuaa@gmail.com> Co-authored-by: Yubo Wang <yubowang2019@gmail.com>
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -221,7 +221,16 @@ class ModelRunner:
|
|||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
|
|
||||||
if server_args.attention_backend is None:
|
if server_args.attention_backend is None:
|
||||||
# By default, use flashinfer for non-mla attention and triton for mla attention
|
"""
|
||||||
|
We auto select the fastest attention backend according to the current offering
|
||||||
|
1. Models with MHA Architecture (e.g: Llama, QWen)
|
||||||
|
1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
|
||||||
|
1.2 In other cases, we will use flashinfer if available, otherwise use triton.
|
||||||
|
2. Models with MLA Architecture and using FA3
|
||||||
|
2.1 We will use FA3 backend on hopper.
|
||||||
|
2.2 Otherwise, we will use triton backend.
|
||||||
|
"""
|
||||||
|
|
||||||
if not self.use_mla_backend:
|
if not self.use_mla_backend:
|
||||||
if (
|
if (
|
||||||
is_hopper_with_cuda_12_3()
|
is_hopper_with_cuda_12_3()
|
||||||
@@ -234,9 +243,7 @@ class ModelRunner:
|
|||||||
"flashinfer" if is_flashinfer_available() else "triton"
|
"flashinfer" if is_flashinfer_available() else "triton"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if is_hopper_with_cuda_12_3() and is_no_spec_infer_or_topk_one(
|
if is_hopper_with_cuda_12_3():
|
||||||
server_args
|
|
||||||
):
|
|
||||||
server_args.attention_backend = "fa3"
|
server_args.attention_backend = "fa3"
|
||||||
else:
|
else:
|
||||||
server_args.attention_backend = "triton"
|
server_args.attention_backend = "triton"
|
||||||
|
|||||||
@@ -359,7 +359,18 @@ class ServerArgs:
|
|||||||
|
|
||||||
if self.page_size > 1 and self.speculative_eagle_topk > 1:
|
if self.page_size > 1 and self.speculative_eagle_topk > 1:
|
||||||
self.speculative_eagle_topk = 1
|
self.speculative_eagle_topk = 1
|
||||||
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
|
logger.info(
|
||||||
|
"speculative_eagle_topk is adjusted to 1 when page_size > 1"
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
self.speculative_eagle_topk == 1
|
||||||
|
and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
|
||||||
|
):
|
||||||
|
logger.info(
|
||||||
|
"speculative_num_draft_tokens is adjusted to speculative_num_steps + 1 when speculative_eagle_topk == 1"
|
||||||
|
)
|
||||||
|
self.speculative_num_draft_tokens = self.speculative_num_steps + 1
|
||||||
|
|
||||||
# The token generated from the verify step is counted.
|
# The token generated from the verify step is counted.
|
||||||
# If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
|
# If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
|
||||||
|
|||||||
@@ -1909,6 +1909,8 @@ def is_page_size_one(server_args):
|
|||||||
return server_args.page_size == 1
|
return server_args.page_size == 1
|
||||||
|
|
||||||
|
|
||||||
|
# TODO(hebiao064): Accelerate FA3 Spec Decode with topk > 1.
|
||||||
|
# TODO(hebiao064): Improve the acc rate for FA3 Spec Decode with topk == 1 and page_size > 1.
|
||||||
def is_no_spec_infer_or_topk_one(server_args):
|
def is_no_spec_infer_or_topk_one(server_args):
|
||||||
return server_args.speculative_eagle_topk is None or (
|
return server_args.speculative_eagle_topk is None or (
|
||||||
server_args.speculative_eagle_topk is not None
|
server_args.speculative_eagle_topk is not None
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ suites = {
|
|||||||
TestFile("test_chunked_prefill.py", 336),
|
TestFile("test_chunked_prefill.py", 336),
|
||||||
TestFile("test_eagle_infer.py", 500),
|
TestFile("test_eagle_infer.py", 500),
|
||||||
TestFile("test_ebnf_constrained.py"),
|
TestFile("test_ebnf_constrained.py"),
|
||||||
TestFile("test_fa3.py", 5),
|
TestFile("test_fa3.py", 200),
|
||||||
TestFile("test_fp8_kernel.py", 8),
|
TestFile("test_fp8_kernel.py", 8),
|
||||||
TestFile("test_embedding_openai_server.py", 36),
|
TestFile("test_embedding_openai_server.py", 36),
|
||||||
TestFile("test_hidden_states.py", 55),
|
TestFile("test_hidden_states.py", 55),
|
||||||
|
|||||||
@@ -173,6 +173,60 @@ class TestFlashAttention3SpeculativeDecode(BaseFlashAttentionTest):
|
|||||||
self.assertGreater(avg_spec_accept_length, 1.5)
|
self.assertGreater(avg_spec_accept_length, 1.5)
|
||||||
|
|
||||||
|
|
||||||
|
class TestFlashAttention3SpeculativeDecodeTopk(BaseFlashAttentionTest):
|
||||||
|
"""Test FlashAttention3 with speculative decode enabled, topk > 1"""
|
||||||
|
|
||||||
|
model = "meta-llama/Llama-3.1-8B-Instruct"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_server_args(cls):
|
||||||
|
args = super().get_server_args()
|
||||||
|
args.extend(
|
||||||
|
[
|
||||||
|
"--cuda-graph-max-bs",
|
||||||
|
"2",
|
||||||
|
"--speculative-algorithm",
|
||||||
|
"EAGLE3",
|
||||||
|
"--speculative-draft",
|
||||||
|
"jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
|
||||||
|
"--speculative-num-steps",
|
||||||
|
"5",
|
||||||
|
"--speculative-eagle-topk",
|
||||||
|
"4",
|
||||||
|
"--speculative-num-draft-tokens",
|
||||||
|
"8",
|
||||||
|
"--dtype",
|
||||||
|
"float16",
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return args
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
"""
|
||||||
|
Override the test_gsm8k to further test for average speculative accept length.
|
||||||
|
"""
|
||||||
|
requests.get(self.base_url + "/flush_cache")
|
||||||
|
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=DATA_PATH,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval_few_shot_gsm8k(args)
|
||||||
|
print(metrics)
|
||||||
|
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.60)
|
||||||
|
|
||||||
|
server_info = requests.get(self.base_url + "/get_server_info")
|
||||||
|
avg_spec_accept_length = server_info.json()["avg_spec_accept_length"]
|
||||||
|
print(f"{avg_spec_accept_length=}")
|
||||||
|
self.assertGreater(avg_spec_accept_length, 1.8)
|
||||||
|
|
||||||
|
|
||||||
class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
|
class TestFlashAttention3MLASpeculativeDecode(BaseFlashAttentionTest):
|
||||||
"""Test FlashAttention3 with speculative decode enabled."""
|
"""Test FlashAttention3 with speculative decode enabled."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user