Support page size > 1 + eagle (#4908)
This commit is contained in:
@@ -15,6 +15,7 @@
|
||||
|
||||
import argparse
|
||||
import dataclasses
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
@@ -132,9 +133,9 @@ class ServerArgs:
|
||||
# Speculative decoding
|
||||
speculative_algorithm: Optional[str] = None
|
||||
speculative_draft_model_path: Optional[str] = None
|
||||
speculative_num_steps: int = 5
|
||||
speculative_eagle_topk: int = 4
|
||||
speculative_num_draft_tokens: int = 8
|
||||
speculative_num_steps: Optional[int] = None
|
||||
speculative_eagle_topk: Optional[int] = None
|
||||
speculative_num_draft_tokens: Optional[int] = None
|
||||
speculative_accept_threshold_single: float = 1.0
|
||||
speculative_accept_threshold_acc: float = 1.0
|
||||
speculative_token_map: Optional[str] = None
|
||||
@@ -313,12 +314,29 @@ class ServerArgs:
|
||||
or self.speculative_algorithm == "EAGLE3"
|
||||
):
|
||||
if self.max_running_requests is None:
|
||||
self.max_running_requests = 32
|
||||
self.max_running_requests = 48
|
||||
self.disable_overlap_schedule = True
|
||||
logger.info(
|
||||
"Overlap scheduler is disabled because of using "
|
||||
"eagle speculative decoding."
|
||||
)
|
||||
|
||||
# Auto choose parameters
|
||||
if self.speculative_num_steps is None:
|
||||
assert (
|
||||
self.speculative_eagle_topk is None
|
||||
and self.speculative_num_draft_tokens is None
|
||||
)
|
||||
(
|
||||
self.speculative_num_steps,
|
||||
self.speculative_eagle_topk,
|
||||
self.speculative_num_draft_tokens,
|
||||
) = auto_choose_speculative_params(self)
|
||||
|
||||
if self.page_size > 1 and self.speculative_eagle_topk > 1:
|
||||
self.speculative_eagle_topk = 1
|
||||
logger.info("speculative_eagle_topk is changed to 1 when page_size > 1")
|
||||
|
||||
# The token generated from the verify step is counted.
|
||||
# If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
|
||||
# assert self.speculative_num_steps < self.speculative_num_draft_tokens
|
||||
@@ -1253,3 +1271,33 @@ class DeprecatedAction(argparse.Action):
|
||||
|
||||
def __call__(self, parser, namespace, values, option_string=None):
    """Reject any use of this action by raising ``ValueError``.

    The exception message is the action's own ``help`` text; the parser,
    namespace, values, and option string are not consulted.
    """
    message = self.help
    raise ValueError(message)
|
||||
|
||||
|
||||
def auto_choose_speculative_params(self: "ServerArgs"):
    """
    Automatically choose the parameters for speculative decoding.

    The model config is read from ``self.decrypted_config_file`` when it is
    set, otherwise from ``config.json`` under ``self.model_path``. The first
    entry of its ``architectures`` list selects the defaults.

    You can tune them on your own models and prompts with scripts/playground/bench_speculative.py

    Returns:
        A tuple ``(speculative_num_steps, speculative_eagle_topk,
        speculative_num_draft_tokens)``.

    Raises:
        ValueError: If the config file does not exist.
    """
    if self.decrypted_config_file:
        config_path = self.decrypted_config_file
    else:
        config_path = os.path.join(self.model_path, "config.json")
    if not os.path.exists(config_path):
        raise ValueError(f"{config_path} is not found.")

    # Use a context manager so the file handle is closed deterministically.
    with open(config_path, encoding="utf-8") as config_file:
        config = json.load(config_file)

    # A missing, null, or empty "architectures" entry falls back to "Unknown"
    # instead of failing on the index below.
    architectures = config.get("architectures") or ["Unknown"]
    arch = architectures[0]

    if arch in ["LlamaForCausalLM"]:
        # The default value for llama
        return (5, 4, 8)
    elif arch in ["DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"]:
        # The default value for deepseek
        return (5, 4, 8)
    elif arch in ["Grok1ForCausalLM", "Grok1VForCausalLM"]:
        return (5, 4, 8)
    else:
        # The default value for all other models
        return (5, 4, 8)
|
||||
|
||||
Reference in New Issue
Block a user