### What this PR does / why we need it?
This PR integrates suffix decoding (https://arxiv.org/abs/2411.04975)
from vllm (https://github.com/vllm-project/vllm/pull/25784)
#
Suffix Decoding is a dynamic n-gram matching method that:
1. Uses suffix trees to generate speculative tokens quickly using branch
frequency counts.
2. Can keep a history of prior model responses, which tends to work very
well with repetitive agentic use cases.
3. Can be dynamically updated with newly generated tokens, and FIFO
eviction of older requests.
#
### Does this PR introduce _any_ user-facing change?
This feature should be implemented as opt-in and remain seamless for
users who do not require suffix speculative decoding.
For users who wish to enable it, they must first install
arctic-inference:
`pip install arctic-inference`
After installation, the suffix speculative decoding feature can be
enabled using the following speculative config:
`--speculative_config '{"method": "suffix", "num_speculative_tokens": 5}'`
### How was this patch tested?
This PR is currently being tested on vLLM
main:83f478bb19
with PR https://github.com/vllm-project/vllm/pull/25784
In our previous testing, suffix decoding achieved a 13%-30% throughput
improvement over n-gram on the sonnet dataset, tested on vllm-ascend
v0.9.1 with concurrency ranging from 2 to 40.
- vLLM version: v0.11.2
---------
Signed-off-by: fluctlux <38945811+fluctlux@users.noreply.github.com>
235 lines
11 KiB
Python
import ast
|
|
|
|
from vllm.config.speculative import SpeculativeConfig
|
|
from vllm.logger import logger
|
|
|
|
|
|
def __post_init__(self):
    """Replacement for ``SpeculativeConfig.__post_init__`` that adds "suffix".

    Mirrors upstream vLLM's speculative-decoding post-init logic, extended so
    that ``method="suffix"`` (suffix decoding) is accepted alongside "ngram",
    the eagle/medusa variants, and the MTP model families. Installed by
    assigning this function onto ``SpeculativeConfig`` at the bottom of this
    module.
    """

    # Note: "method" is a new parameter that helps to extend the
    # configuration of non-model-based proposers, and the "model" parameter
    # will be used to set the draft model, eagle head, or additional weight
    # when needed. If users do not specify "method", the speculative method
    # will be detected automatically if possible. If the speculative method
    # can not be detected, it will be considered as the "draft_model" by
    # default.

    if self.model is None and self.num_speculative_tokens is not None:
        # TODO(Shangming): Refactor mtp configuration logic when supporting
        if (self.target_model_config
                and self.target_model_config.hf_text_config.model_type
                in ("deepseek_v3", "deepseek_v32", "mimo", "ernie4_5_moe",
                    "qwen3_next")):
            # use the draft model from the same model:
            self.model = self.target_model_config.model
            # Align the quantization of draft model for cases such as
            # --quantization fp8 with a bf16 checkpoint.
            if not self.quantization:
                self.quantization = self.target_model_config.quantization
        elif self.method in ("ngram", "[ngram]"):
            self.model = "ngram"
        elif self.method == "suffix":
            self.model = "suffix"
        else:
            raise ValueError("num_speculative_tokens was provided but without "
                             "speculative model.")

    # Automatically configure the method for ngram when "model" is used
    # instead of "method"
    if self.method is None and (self.model is not None
                                and self.model in ("ngram", "[ngram]")):
        self.method = "ngram"

    if self.method in ("ngram", "[ngram]"):
        # Unified to "ngram" internally
        self.method = "ngram"
        # Set default values if not provided
        if (self.prompt_lookup_min is None
                and self.prompt_lookup_max is None):
            # TODO(woosuk): Tune these values. They are arbitrarily chosen.
            self.prompt_lookup_min = 5
            self.prompt_lookup_max = 5
        elif self.prompt_lookup_min is None:
            assert self.prompt_lookup_max is not None
            self.prompt_lookup_min = self.prompt_lookup_max
        elif self.prompt_lookup_max is None:
            assert self.prompt_lookup_min is not None
            self.prompt_lookup_max = self.prompt_lookup_min

        # Validate values
        if self.prompt_lookup_min < 1:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
        if self.prompt_lookup_max < 1:
            raise ValueError(
                f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
        if self.prompt_lookup_min > self.prompt_lookup_max:
            raise ValueError(
                f"prompt_lookup_min={self.prompt_lookup_min} must "
                f"be <= prompt_lookup_max={self.prompt_lookup_max}")

        # TODO: current we still need extract vocab_size from target model
        # config, in future, we may try refactor it out, and set
        # draft related config as None here.
        self.draft_model_config = self.target_model_config
        self.draft_parallel_config = self.target_parallel_config
    elif self.method == "suffix":
        # Suffix decoding proposes tokens from a suffix tree over prior
        # responses; it needs no separate draft model, so reuse the target
        # model/parallel configs and run the suffix-specific validation.
        self.draft_model_config = self.target_model_config
        self.draft_parallel_config = self.target_parallel_config
        self._validate_suffix_decoding()
    else:
        self.prompt_lookup_max = 0
        self.prompt_lookup_min = 0

        if self.model is not None:
            # TODO: Move this import to the top once `ModelConfig`
            # lives in `vllm.config.model`.
            from vllm.config import ModelConfig
            self.draft_model_config = ModelConfig(
                model=self.model,
                runner="draft",
                tokenizer=self.target_model_config.tokenizer,
                tokenizer_mode=self.target_model_config.tokenizer_mode,
                trust_remote_code=self.target_model_config.trust_remote_code,
                allowed_local_media_path=self.target_model_config.
                allowed_local_media_path,
                allowed_media_domains=self.target_model_config.
                allowed_media_domains,
                dtype=self.target_model_config.dtype,
                seed=self.target_model_config.seed,
                revision=self.revision,
                code_revision=self.code_revision,
                tokenizer_revision=self.target_model_config.tokenizer_revision,
                spec_target_max_model_len=self.target_model_config.
                max_model_len,
                quantization=self.quantization,
                enforce_eager=self.target_model_config.enforce_eager,
                max_logprobs=self.target_model_config.max_logprobs,
                hf_overrides=SpeculativeConfig.hf_config_override,
            )

            # Automatically detect the method
            if self.method in ('eagle', 'eagle3'):
                pass
            # examples:
            # yuhuili/EAGLE-LLaMA3-Instruct-8B
            # yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
            # AngelSlim/Qwen3-8B_eagle3
            elif "eagle-" in self.draft_model_config.model.lower():
                self.method = "eagle"
            elif "eagle3" in self.draft_model_config.model.lower():
                self.method = "eagle3"
            elif self.draft_model_config.hf_config.model_type == "medusa":
                self.method = "medusa"
            elif (self.draft_model_config.hf_config.model_type ==
                  "mlp_speculator"):
                self.method = "mlp_speculator"
            elif (self.draft_model_config.hf_config.model_type
                  in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
                self.method = "deepseek_mtp"
                # NOTE(review): assumes num_speculative_tokens is not None
                # here — TODO confirm callers always set it with an MTP model.
                if self.num_speculative_tokens > 1:
                    logger.warning(
                        "All Deepseek MTP models only have " \
                        "one layer. Might need some code changes " \
                        "to support multiple layers."
                    )
            elif (self.draft_model_config.hf_config.model_type == "ernie_mtp"):
                self.method = "ernie_mtp"
                if self.num_speculative_tokens > 1:
                    logger.warning(
                        "All Ernie MTP models only have " \
                        "one layer. Might need some code changes " \
                        "to support multiple layers."
                    )
            elif (self.draft_model_config.hf_config.model_type ==
                  "qwen3_next_mtp"):
                self.method = "qwen3_next_mtp"
                if self.num_speculative_tokens > 1:
                    logger.warning(
                        "All Qwen3Next MTP models only have " \
                        "one layer. Might need some code changes " \
                        "to support multiple layers."
                    )
            # Fix: the original wrote `in ("longcat_flash_mtp")`, which is a
            # plain string (missing trailing comma), so `in` performed a
            # SUBSTRING test — e.g. model_type "flash" would match. Use an
            # exact equality test instead.
            elif (self.draft_model_config.hf_config.model_type ==
                  "longcat_flash_mtp"):
                self.method = "longcat_flash_mtp"
                if self.num_speculative_tokens > 1:
                    logger.warning(
                        "LongCat MTP models only have " \
                        "one layer. Might need some code changes " \
                        "to support multiple layers."
                    )
            else:
                self.method = "draft_model"
                raise NotImplementedError(
                    "Speculative decoding with draft model is not "
                    "supported yet. Please consider using other "
                    "speculative decoding methods such as ngram, medusa, "
                    "eagle, or deepseek_mtp.")

            # Replace hf_config for EAGLE draft_model
            if self.method in ("eagle", "eagle3"):
                from vllm.transformers_utils.configs import SpeculatorsConfig
                from vllm.transformers_utils.configs.eagle import EAGLEConfig

                if isinstance(self.draft_model_config.hf_config,
                              (EAGLEConfig, SpeculatorsConfig)):
                    pass
                else:
                    eagle_config = EAGLEConfig(
                        self.draft_model_config.hf_config,
                        method=self.method,
                        model_type="eagle")
                    self.draft_model_config.hf_config = eagle_config

            # Propagate the requested speculation depth into draft configs
            # that expose a "num_lookahead_tokens" field.
            if (self.num_speculative_tokens is not None
                    and hasattr(self.draft_model_config.hf_config,
                                "num_lookahead_tokens")):
                self.draft_model_config.hf_config.num_lookahead_tokens = \
                    self.num_speculative_tokens

            n_predict = getattr(self.draft_model_config.hf_config, "n_predict",
                                None)
            if n_predict is not None:
                if self.num_speculative_tokens is None:
                    # Default to max value defined in draft model config.
                    self.num_speculative_tokens = n_predict
                elif self.num_speculative_tokens > n_predict and \
                        self.num_speculative_tokens % n_predict != 0:
                    # Ensure divisibility for MTP module reuse.
                    raise ValueError(
                        f"num_speculative_tokens:{self.num_speculative_tokens}"
                        f" must be divisible by {n_predict=}")

            if self.speculative_token_tree is None:
                # Generate chain of tokens.
                self.speculative_token_tree = str([
                    (i + 1) * (0, ) for i in range(self.num_speculative_tokens)
                ])
            else:
                # Sort the token tree breadth-first.
                tree_choices = ast.literal_eval(self.speculative_token_tree)
                self.speculative_token_tree = str(
                    sorted(tree_choices, key=lambda t: (len(t), t)))

            self.draft_tensor_parallel_size = \
                SpeculativeConfig._verify_and_get_draft_tp(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size,
                    self.draft_model_config.hf_config
                )

            self.draft_model_config.max_model_len = (
                SpeculativeConfig._maybe_override_draft_max_model_len(
                    self.max_model_len,
                    self.draft_model_config.max_model_len,
                    self.target_model_config.max_model_len,
                ))

            self.draft_parallel_config = (
                SpeculativeConfig.create_draft_parallel_config(
                    self.target_parallel_config,
                    self.draft_tensor_parallel_size))
# Monkey-patch: replace vLLM's SpeculativeConfig.__post_init__ with the
# extended version defined in this module so the "suffix" speculative
# decoding method is recognized when the config is initialized.
SpeculativeConfig.__post_init__ = __post_init__