Support page size > 1 + eagle (#4908)

This commit is contained in:
Lianmin Zheng
2025-03-30 00:46:23 -07:00
committed by GitHub
parent 5ec5eaf760
commit b26bc86b36
16 changed files with 374 additions and 71 deletions

View File

@@ -15,6 +15,7 @@
import argparse
import dataclasses
import json
import logging
import os
import random
@@ -132,9 +133,9 @@ class ServerArgs:
# Speculative decoding
speculative_algorithm: Optional[str] = None
speculative_draft_model_path: Optional[str] = None
speculative_num_steps: int = 5
speculative_eagle_topk: int = 4
speculative_num_draft_tokens: int = 8
speculative_num_steps: Optional[int] = None
speculative_eagle_topk: Optional[int] = None
speculative_num_draft_tokens: Optional[int] = None
speculative_accept_threshold_single: float = 1.0
speculative_accept_threshold_acc: float = 1.0
speculative_token_map: Optional[str] = None
@@ -313,12 +314,29 @@ class ServerArgs:
or self.speculative_algorithm == "EAGLE3"
):
if self.max_running_requests is None:
self.max_running_requests = 32
self.max_running_requests = 48
self.disable_overlap_schedule = True
logger.info(
"Overlap scheduler is disabled because of using "
"eagle speculative decoding."
)
# Auto choose parameters
if self.speculative_num_steps is None:
assert (
self.speculative_eagle_topk is None
and self.speculative_num_draft_tokens is None
)
(
self.speculative_num_steps,
self.speculative_eagle_topk,
self.speculative_num_draft_tokens,
) = auto_choose_speculative_params(self)
if self.page_size > 1 and self.speculative_eagle_topk > 1:
self.speculative_eagle_topk = 1
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
# The token generated from the verify step is counted.
# If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
# assert self.speculative_num_steps < self.speculative_num_draft_tokens
@@ -1253,3 +1271,33 @@ class DeprecatedAction(argparse.Action):
def __call__(self, parser, namespace, values, option_string=None):
    # A deprecated CLI flag is always an error when supplied; surface the
    # stored deprecation message (self.help) to the user.
    raise ValueError(self.help)
def auto_choose_speculative_params(self: "ServerArgs"):
    """
    Automatically choose the parameters for speculative decoding based on
    the target model's architecture.

    You can tune them on your own models and prompts with
    scripts/playground/bench_speculative.py

    Args:
        self: The ServerArgs instance; reads ``decrypted_config_file`` and
            ``model_path`` to locate the model's config.json.

    Returns:
        A tuple of (speculative_num_steps, speculative_eagle_topk,
        speculative_num_draft_tokens).

    Raises:
        ValueError: If the model config file cannot be found.
    """
    if self.decrypted_config_file:
        config_path = self.decrypted_config_file
    else:
        config_path = os.path.join(self.model_path, "config.json")
    if not os.path.exists(config_path):
        raise ValueError(f"{config_path} is not found.")
    # Use a context manager so the file handle is closed deterministically
    # instead of relying on garbage collection.
    with open(config_path) as f:
        config = json.load(f)
    # `or` also guards against an "architectures" key that exists but is an
    # empty list, which would make [0] raise IndexError.
    arch = (config.get("architectures") or ["Unknown"])[0]
    if arch in ["LlamaForCausalLM"]:
        # The default value for llama
        return (5, 4, 8)
    elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
        # The default value for deepseek
        return (5, 4, 8)
    elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
        return (5, 4, 8)
    else:
        # The default value for all other models
        return (5, 4, 8)