Sync from v0.13
This commit is contained in:
644
vllm/config/speculative.py
Normal file
644
vllm/config/speculative.py
Normal file
@@ -0,0 +1,644 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
from typing import TYPE_CHECKING, Any, Literal, get_args
|
||||
|
||||
from pydantic import Field, SkipValidation, model_validator
|
||||
from pydantic.dataclasses import dataclass
|
||||
from typing_extensions import Self
|
||||
|
||||
from vllm.config.model import ModelConfig
|
||||
from vllm.config.parallel import ParallelConfig
|
||||
from vllm.config.utils import config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils.hashing import safe_hash
|
||||
from vllm.utils.import_utils import LazyLoader, has_arctic_inference
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.model_executor.layers.quantization as me_quant
|
||||
else:
|
||||
PretrainedConfig = Any
|
||||
|
||||
me_quant = LazyLoader(
|
||||
"model_executor", globals(), "vllm.model_executor.layers.quantization"
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
MTPModelTypes = Literal[
|
||||
"deepseek_mtp",
|
||||
"mimo_mtp",
|
||||
"glm4_moe_mtp",
|
||||
"ernie_mtp",
|
||||
"qwen3_next_mtp",
|
||||
"longcat_flash_mtp",
|
||||
"mtp",
|
||||
"pangu_ultra_moe_mtp",
|
||||
]
|
||||
EagleModelTypes = Literal["eagle", "eagle3", MTPModelTypes]
|
||||
SpeculativeMethod = Literal[
|
||||
"ngram",
|
||||
"medusa",
|
||||
"mlp_speculator",
|
||||
"draft_model",
|
||||
"suffix",
|
||||
EagleModelTypes,
|
||||
]
|
||||
|
||||
|
||||
@config
|
||||
@dataclass
|
||||
class SpeculativeConfig:
|
||||
"""Configuration for speculative decoding."""
|
||||
|
||||
enforce_eager: bool | None = None
|
||||
"""Override the default enforce_eager from model_config"""
|
||||
# General speculative decoding control
|
||||
num_speculative_tokens: int = Field(default=None, gt=0)
|
||||
"""The number of speculative tokens, if provided. It will default to the
|
||||
number in the draft model config if present, otherwise, it is required."""
|
||||
model: str | None = None
|
||||
"""The name of the draft model, eagle head, or additional weights, if
|
||||
provided."""
|
||||
method: SpeculativeMethod | None = None
|
||||
"""The name of the speculative method to use. If users provide and set the
|
||||
`model` param, the speculative method type will be detected automatically
|
||||
if possible, if `model` param is not provided, the method name must be
|
||||
provided.
|
||||
|
||||
If using `ngram` method, the related configuration `prompt_lookup_max` and
|
||||
`prompt_lookup_min` should be considered."""
|
||||
draft_tensor_parallel_size: int | None = Field(default=None, ge=1)
|
||||
"""The degree of the tensor parallelism for the draft model. Can only be 1
|
||||
or the same as the target model's tensor parallel size."""
|
||||
|
||||
# Draft model configuration
|
||||
quantization: me_quant.QuantizationMethods | None = None
|
||||
"""Quantization method that was used to quantize the draft model weights.
|
||||
If `None`, we assume the model weights are not quantized. Note that it only
|
||||
takes effect when using the draft model-based speculative method."""
|
||||
max_model_len: int | None = Field(default=None, ge=1)
|
||||
"""The maximum model length of the draft model. Used when testing the
|
||||
ability to skip speculation for some sequences."""
|
||||
revision: str | None = None
|
||||
"""The specific model version to use for the draft model. It can be a
|
||||
branch name, a tag name, or a commit id. If unspecified, will use the
|
||||
default version."""
|
||||
code_revision: str | None = None
|
||||
"""The specific revision to use for the draft model code on Hugging Face
|
||||
Hub. It can be a branch name, a tag name, or a commit id. If unspecified,
|
||||
will use the default version."""
|
||||
|
||||
# Advanced control
|
||||
disable_by_batch_size: int | None = Field(default=None, ge=2)
|
||||
"""Disable speculative decoding for new incoming requests when the number
|
||||
of enqueued requests is larger than this value, if provided."""
|
||||
disable_padded_drafter_batch: bool = False
|
||||
"""Disable input padding for speculative decoding. If set to True,
|
||||
speculative input batches can contain sequences of different lengths,
|
||||
which may only be supported by certain attention backends. This currently
|
||||
only affects the EAGLE method of speculation."""
|
||||
|
||||
# Ngram proposer configuration
|
||||
prompt_lookup_max: int | None = Field(default=None, ge=1)
|
||||
"""Maximum size of ngram token window when using Ngram proposer, required
|
||||
when method is set to ngram."""
|
||||
prompt_lookup_min: int | None = Field(default=None, ge=1)
|
||||
"""Minimum size of ngram token window when using Ngram proposer, if
|
||||
provided. Defaults to 1."""
|
||||
|
||||
speculative_token_tree: str | None = None
|
||||
"""Specifies the tree structure for speculative token generation.
|
||||
"""
|
||||
# required configuration params passed from engine
|
||||
target_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
||||
"""The configuration of the target model."""
|
||||
target_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
|
||||
"""The parallel configuration for the target model."""
|
||||
|
||||
# params generated in the post-init stage
|
||||
draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
|
||||
"""The configuration of the draft model initialized internal."""
|
||||
draft_parallel_config: SkipValidation[ParallelConfig] = None # type: ignore
|
||||
"""The parallel configuration for the draft model initialized internal."""
|
||||
|
||||
# Suffix decoding configuration
|
||||
suffix_decoding_max_tree_depth: int = 24
|
||||
"""The maximum depth of the suffix decoding global and prompt trees. The
|
||||
tree depth limits the sum of the prefix match and speculation lengths."""
|
||||
|
||||
suffix_decoding_max_cached_requests: int = 10000
|
||||
"""The maximum number of requests to cache in the global suffix tree. If
|
||||
exceeded, will trigger eviction in FIFO order. If set to 0, the global
|
||||
suffix tree is disabled and past responses are not cached (prompt trees
|
||||
are still used)."""
|
||||
|
||||
suffix_decoding_max_spec_factor: float = 1.0
|
||||
"""The maximum spec factor for suffix decoding. The spec factor controls
|
||||
speculation lengths based on the prefix match length: max_spec_tokens =
|
||||
max_spec_factor * prefix_match_length."""
|
||||
|
||||
suffix_decoding_min_token_prob: float = 0.1
|
||||
"""The minimum token probability for suffix decoding. Will only speculate
|
||||
tokens with estimated probability (based on frequency counts) greater than
|
||||
or equal to this value."""
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
ensure that it is included in the factors list if
|
||||
it affects the computation graph.
|
||||
|
||||
Provide a hash that uniquely identifies all the configs
|
||||
that affect the structure of the computation
|
||||
graph from input ids/embeddings to the final hidden states,
|
||||
excluding anything before input ids/embeddings and after
|
||||
the final hidden states.
|
||||
"""
|
||||
factors: list[Any] = []
|
||||
# Eagle3 affects the computation graph because it returns intermediate
|
||||
# hidden states in addition to the final hidden state.
|
||||
factors.append(self.method == "eagle3")
|
||||
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
|
||||
return hash_str
|
||||
|
||||
@staticmethod
|
||||
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||
initial_architecture = hf_config.architectures[0]
|
||||
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
|
||||
hf_config.model_type = "deepseek_mtp"
|
||||
if hf_config.model_type == "deepseek_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["DeepSeekMTPModel"]}
|
||||
)
|
||||
if hf_config.model_type in ("pangu_ultra_moe"):
|
||||
hf_config.model_type = "pangu_ultra_moe_mtp"
|
||||
if hf_config.model_type == "pangu_ultra_moe_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["OpenPanguMTPModel"]}
|
||||
)
|
||||
|
||||
if hf_config.architectures[0] == "MiMoForCausalLM":
|
||||
hf_config.model_type = "mimo_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["MiMoMTPModel"],
|
||||
}
|
||||
)
|
||||
|
||||
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
hf_config.model_type = "glm4_moe_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Glm4MoeMTPModel"],
|
||||
}
|
||||
)
|
||||
|
||||
if hf_config.model_type == "ernie4_5_moe":
|
||||
hf_config.model_type = "ernie_mtp"
|
||||
if hf_config.model_type == "ernie_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["ErnieMTPModel"]}
|
||||
)
|
||||
|
||||
if hf_config.model_type == "qwen3_next":
|
||||
hf_config.model_type = "qwen3_next_mtp"
|
||||
if hf_config.model_type == "qwen3_next_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["Qwen3NextMTP"]}
|
||||
)
|
||||
if hf_config.model_type == "longcat_flash":
|
||||
hf_config.model_type = "longcat_flash_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
|
||||
hf_config.update(
|
||||
{"n_predict": n_predict, "architectures": ["LongCatFlashMTPModel"]}
|
||||
)
|
||||
|
||||
if initial_architecture == "MistralLarge3ForCausalLM":
|
||||
hf_config.update({"architectures": ["EagleMistralLarge3ForCausalLM"]})
|
||||
|
||||
return hf_config
|
||||
|
||||
def __post_init__(self):
|
||||
# Note: "method" is a new parameter that helps to extend the
|
||||
# configuration of non-model-based proposers, and the "model" parameter
|
||||
# will be used to set the draft model, eagle head, or additional weight
|
||||
# when needed. If users do not specify "method", the speculative method
|
||||
# will be detected automatically if possible. If the speculative method
|
||||
# can not be detected, it will be considered as the "draft_model" by
|
||||
# default.
|
||||
|
||||
if self.method in get_args(MTPModelTypes) and self.method != "mtp":
|
||||
logger.warning(
|
||||
"method `%s` is deprecated and replaced with mtp.", self.method
|
||||
)
|
||||
self.method = "mtp"
|
||||
|
||||
if self.model is None and self.num_speculative_tokens is not None:
|
||||
if self.method == "mtp":
|
||||
if self.target_model_config is None:
|
||||
raise ValueError("target_model_config must be present for mtp")
|
||||
if self.target_model_config.hf_text_config.model_type == "deepseek_v32":
|
||||
# FIXME(luccafong): cudgraph with v32 MTP is not supported,
|
||||
# remove this when the issue is fixed.
|
||||
self.enforce_eager = True
|
||||
# use the draft model from the same model:
|
||||
self.model = self.target_model_config.model
|
||||
# Align the quantization of draft model for cases such as
|
||||
# --quantization fp8 with a bf16 checkpoint.
|
||||
if not self.quantization:
|
||||
self.quantization = self.target_model_config.quantization
|
||||
elif self.method in ("ngram", "[ngram]"):
|
||||
self.model = "ngram"
|
||||
elif self.method == "suffix":
|
||||
self.model = "suffix"
|
||||
else:
|
||||
raise ValueError(
|
||||
"num_speculative_tokens was provided but without speculative model."
|
||||
)
|
||||
|
||||
# Automatically configure the method for ngram when "model" is used
|
||||
# instead of "method"
|
||||
if self.method is None and (
|
||||
self.model is not None and self.model in ("ngram", "[ngram]")
|
||||
):
|
||||
self.method = "ngram"
|
||||
|
||||
if self.method in ("ngram", "[ngram]"):
|
||||
# Unified to "ngram" internally
|
||||
self.method = "ngram"
|
||||
# Set default values if not provided
|
||||
if self.prompt_lookup_min is None and self.prompt_lookup_max is None:
|
||||
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
|
||||
self.prompt_lookup_min = 5
|
||||
self.prompt_lookup_max = 5
|
||||
elif self.prompt_lookup_min is None:
|
||||
if self.prompt_lookup_max is None:
|
||||
raise ValueError(
|
||||
"Either prompt_lookup_max or prompt_lookup_min must be "
|
||||
"provided when using the ngram method."
|
||||
)
|
||||
self.prompt_lookup_min = self.prompt_lookup_max
|
||||
elif self.prompt_lookup_max is None:
|
||||
if self.prompt_lookup_min is None:
|
||||
raise ValueError(
|
||||
"Either prompt_lookup_max or prompt_lookup_min must be "
|
||||
"provided when using the ngram method."
|
||||
)
|
||||
self.prompt_lookup_max = self.prompt_lookup_min
|
||||
|
||||
# Validate values
|
||||
if self.prompt_lookup_min > self.prompt_lookup_max:
|
||||
raise ValueError(
|
||||
f"prompt_lookup_min={self.prompt_lookup_min} must "
|
||||
f"be <= prompt_lookup_max={self.prompt_lookup_max}"
|
||||
)
|
||||
|
||||
# TODO: current we still need extract vocab_size from target model
|
||||
# config, in future, we may try refactor it out, and set
|
||||
# draft related config as None here.
|
||||
self.draft_model_config = self.target_model_config
|
||||
self.draft_parallel_config = self.target_parallel_config
|
||||
elif self.method == "suffix":
|
||||
self._validate_suffix_decoding()
|
||||
else:
|
||||
self.prompt_lookup_max = 0
|
||||
self.prompt_lookup_min = 0
|
||||
|
||||
if self.model is not None:
|
||||
self.draft_model_config = ModelConfig(
|
||||
model=self.model,
|
||||
runner="draft",
|
||||
tokenizer=self.target_model_config.tokenizer,
|
||||
tokenizer_mode=self.target_model_config.tokenizer_mode,
|
||||
trust_remote_code=self.target_model_config.trust_remote_code,
|
||||
allowed_local_media_path=self.target_model_config.allowed_local_media_path,
|
||||
allowed_media_domains=self.target_model_config.allowed_media_domains,
|
||||
dtype=self.target_model_config.dtype,
|
||||
seed=self.target_model_config.seed,
|
||||
revision=self.revision,
|
||||
code_revision=self.code_revision,
|
||||
tokenizer_revision=self.target_model_config.tokenizer_revision,
|
||||
spec_target_max_model_len=self.target_model_config.max_model_len,
|
||||
quantization=self.quantization,
|
||||
enforce_eager=self.target_model_config.enforce_eager,
|
||||
max_logprobs=self.target_model_config.max_logprobs,
|
||||
hf_overrides=SpeculativeConfig.hf_config_override,
|
||||
config_format=self.target_model_config.config_format,
|
||||
)
|
||||
|
||||
# Automatically detect the method
|
||||
if self.method in ("eagle", "eagle3"):
|
||||
pass
|
||||
# examples:
|
||||
# yuhuili/EAGLE-LLaMA3-Instruct-8B
|
||||
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
|
||||
# AngelSlim/Qwen3-8B_eagle3
|
||||
elif "eagle-" in self.draft_model_config.model.lower():
|
||||
self.method = "eagle"
|
||||
elif "eagle3" in self.draft_model_config.model.lower():
|
||||
self.method = "eagle3"
|
||||
elif self.draft_model_config.hf_config.model_type == "medusa":
|
||||
self.method = "medusa"
|
||||
elif self.draft_model_config.hf_config.model_type == "mlp_speculator":
|
||||
self.method = "mlp_speculator"
|
||||
elif self.draft_model_config.hf_config.model_type in get_args(
|
||||
MTPModelTypes
|
||||
):
|
||||
self.method = "mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"Enabling num_speculative_tokens > 1 will run"
|
||||
"multiple times of forward on same MTP layer"
|
||||
",which may result in lower acceptance rate"
|
||||
)
|
||||
elif self.draft_model_config.hf_config.model_type in (
|
||||
"longcat_flash_mtp"
|
||||
):
|
||||
self.method = "longcat_flash_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"LongCat MTP models only have "
|
||||
"one layer. Might need some code changes "
|
||||
"to support multiple layers."
|
||||
)
|
||||
else:
|
||||
self.method = "draft_model"
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding with draft model is not "
|
||||
"supported yet. Please consider using other "
|
||||
"speculative decoding methods such as ngram, medusa, "
|
||||
"eagle, or mtp."
|
||||
)
|
||||
|
||||
# Replace hf_config for EAGLE draft_model
|
||||
if self.method in ("eagle", "eagle3"):
|
||||
from vllm.transformers_utils.configs import SpeculatorsConfig
|
||||
from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
||||
|
||||
if isinstance(
|
||||
self.draft_model_config.hf_config,
|
||||
(EAGLEConfig, SpeculatorsConfig),
|
||||
):
|
||||
pass
|
||||
else:
|
||||
eagle_config = EAGLEConfig(
|
||||
self.draft_model_config.hf_config,
|
||||
method=self.method,
|
||||
model_type="eagle",
|
||||
)
|
||||
self.draft_model_config.hf_config = eagle_config
|
||||
|
||||
if self.num_speculative_tokens is not None and hasattr(
|
||||
self.draft_model_config.hf_config, "num_lookahead_tokens"
|
||||
):
|
||||
self.draft_model_config.hf_config.num_lookahead_tokens = (
|
||||
self.num_speculative_tokens
|
||||
)
|
||||
|
||||
n_predict = getattr(
|
||||
self.draft_model_config.hf_config, "n_predict", None
|
||||
)
|
||||
if n_predict is not None:
|
||||
if self.num_speculative_tokens is None:
|
||||
# Default to max value defined in draft model config.
|
||||
self.num_speculative_tokens = n_predict
|
||||
elif (
|
||||
self.num_speculative_tokens > n_predict
|
||||
and self.num_speculative_tokens % n_predict != 0
|
||||
):
|
||||
# Ensure divisibility for MTP module reuse.
|
||||
raise ValueError(
|
||||
f"num_speculative_tokens:{self.num_speculative_tokens}"
|
||||
f" must be divisible by {n_predict=}"
|
||||
)
|
||||
|
||||
if self.speculative_token_tree is None:
|
||||
# Generate chain of tokens.
|
||||
self.speculative_token_tree = str(
|
||||
[(i + 1) * (0,) for i in range(self.num_speculative_tokens)]
|
||||
)
|
||||
else:
|
||||
# Sort the token tree breadth-first.
|
||||
tree_choices = ast.literal_eval(self.speculative_token_tree)
|
||||
self.speculative_token_tree = str(
|
||||
sorted(tree_choices, key=lambda t: (len(t), t))
|
||||
)
|
||||
|
||||
self.draft_tensor_parallel_size = (
|
||||
SpeculativeConfig._verify_and_get_draft_tp(
|
||||
self.target_parallel_config,
|
||||
self.draft_tensor_parallel_size,
|
||||
self.draft_model_config.hf_config,
|
||||
)
|
||||
)
|
||||
|
||||
self.draft_model_config.max_model_len = (
|
||||
SpeculativeConfig._maybe_override_draft_max_model_len(
|
||||
self.max_model_len,
|
||||
self.draft_model_config.max_model_len,
|
||||
self.target_model_config.max_model_len,
|
||||
)
|
||||
)
|
||||
|
||||
self.draft_parallel_config = (
|
||||
SpeculativeConfig.create_draft_parallel_config(
|
||||
self.target_parallel_config, self.draft_tensor_parallel_size
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
def _validate_suffix_decoding(self):
|
||||
if not has_arctic_inference():
|
||||
raise ImportError(
|
||||
"Arctic Inference is required for suffix decoding. "
|
||||
"Install via `pip install arctic-inference==0.1.1`."
|
||||
)
|
||||
if self.num_speculative_tokens is None:
|
||||
# Suffix decoding decides the actual number of speculative tokens
|
||||
# dynamically and treats num_speculative_tokens as a maximum limit.
|
||||
self.num_speculative_tokens = self.suffix_decoding_max_tree_depth
|
||||
logger.warning(
|
||||
"Defaulted num_speculative_tokens to %s for suffix decoding.",
|
||||
self.num_speculative_tokens,
|
||||
)
|
||||
# Validate values
|
||||
if self.suffix_decoding_max_tree_depth < 1:
|
||||
raise ValueError(
|
||||
f"suffix_decoding_max_tree_depth="
|
||||
f"{self.suffix_decoding_max_tree_depth} must be >= 1"
|
||||
)
|
||||
if self.suffix_decoding_max_cached_requests < 0:
|
||||
raise ValueError(
|
||||
f"suffix_decoding_max_cached_requests="
|
||||
f"{self.suffix_decoding_max_cached_requests} must be >= 0"
|
||||
)
|
||||
if self.suffix_decoding_max_spec_factor < 0:
|
||||
raise ValueError(
|
||||
f"suffix_decoding_max_spec_factor="
|
||||
f"{self.suffix_decoding_max_spec_factor} must be >= 0"
|
||||
)
|
||||
if not 0 <= self.suffix_decoding_min_token_prob <= 1:
|
||||
raise ValueError(
|
||||
f"suffix_decoding_min_token_prob="
|
||||
f"{self.suffix_decoding_min_token_prob} must be in [0, 1]"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _maybe_override_draft_max_model_len(
|
||||
speculative_max_model_len: int | None,
|
||||
draft_max_model_len: int,
|
||||
target_max_model_len: int,
|
||||
) -> int:
|
||||
"""Determine the max sequence len for the draft model. This is usually
|
||||
the draft_max_model_len, but may be the target_max_model_len if it is
|
||||
less than the draft_max_model_len, or may be speculative_max_model_len
|
||||
if it is specified.
|
||||
|
||||
This is necessary so that sequences do not exceed the capacity of the
|
||||
draft model or the target model.
|
||||
|
||||
speculative_max_model_len is mainly used for testing that sequences can
|
||||
skip speculation.
|
||||
"""
|
||||
|
||||
if speculative_max_model_len is not None:
|
||||
if speculative_max_model_len > draft_max_model_len:
|
||||
raise ValueError(
|
||||
f"{speculative_max_model_len=} cannot be "
|
||||
f"larger than {draft_max_model_len=}"
|
||||
)
|
||||
|
||||
if speculative_max_model_len > target_max_model_len:
|
||||
raise ValueError(
|
||||
f"{speculative_max_model_len=} cannot be "
|
||||
f"larger than {target_max_model_len=}"
|
||||
)
|
||||
|
||||
return speculative_max_model_len
|
||||
|
||||
return min(
|
||||
draft_max_model_len,
|
||||
target_max_model_len,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _verify_and_get_draft_tp(
|
||||
target_parallel_config: ParallelConfig,
|
||||
speculative_draft_tensor_parallel_size: int | None,
|
||||
draft_hf_config: PretrainedConfig,
|
||||
) -> int:
|
||||
"""
|
||||
Verifies and adjusts the tensor parallel size for a draft model
|
||||
specified using speculative_draft_tensor_parallel_size.
|
||||
"""
|
||||
# If speculative_draft_tensor_parallel_size is unset then set it
|
||||
# appropriately else verify that it is set correctly.
|
||||
if speculative_draft_tensor_parallel_size is None:
|
||||
if draft_hf_config.model_type == "mlp_speculator":
|
||||
speculative_draft_tensor_parallel_size = 1
|
||||
if target_parallel_config.tensor_parallel_size > 1:
|
||||
logger.warning(
|
||||
"%s cannot currently be run with tp>1; "
|
||||
"setting speculative_draft_tensor_parallel_size=1",
|
||||
draft_hf_config.model_type,
|
||||
)
|
||||
else:
|
||||
speculative_draft_tensor_parallel_size = (
|
||||
target_parallel_config.tensor_parallel_size
|
||||
)
|
||||
elif speculative_draft_tensor_parallel_size not in (
|
||||
1,
|
||||
target_parallel_config.tensor_parallel_size,
|
||||
):
|
||||
raise ValueError(
|
||||
f"{speculative_draft_tensor_parallel_size=} cannot be "
|
||||
f"other value than 1 or target model tensor_parallel_size"
|
||||
)
|
||||
return speculative_draft_tensor_parallel_size
|
||||
|
||||
@staticmethod
|
||||
def create_draft_parallel_config(
|
||||
target_parallel_config: ParallelConfig,
|
||||
speculative_draft_tensor_parallel_size: int,
|
||||
) -> ParallelConfig:
|
||||
"""Create a parallel config for use by the draft worker.
|
||||
|
||||
This is mostly a copy of the target parallel config, except the tp_size.
|
||||
"""
|
||||
draft_parallel_config = ParallelConfig(
|
||||
pipeline_parallel_size=target_parallel_config.pipeline_parallel_size,
|
||||
tensor_parallel_size=speculative_draft_tensor_parallel_size,
|
||||
distributed_executor_backend=target_parallel_config.distributed_executor_backend,
|
||||
max_parallel_loading_workers=target_parallel_config.max_parallel_loading_workers,
|
||||
disable_custom_all_reduce=target_parallel_config.disable_custom_all_reduce,
|
||||
ray_workers_use_nsight=target_parallel_config.ray_workers_use_nsight,
|
||||
placement_group=target_parallel_config.placement_group,
|
||||
)
|
||||
|
||||
return draft_parallel_config
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _verify_args(self) -> Self:
|
||||
if self.num_speculative_tokens is None:
|
||||
raise ValueError(
|
||||
"num_speculative_tokens must be provided with "
|
||||
"speculative model unless the draft model config contains an "
|
||||
"n_predict parameter."
|
||||
)
|
||||
|
||||
if self.num_speculative_tokens <= 0:
|
||||
raise ValueError(
|
||||
"Expected num_speculative_tokens to be greater "
|
||||
f"than zero ({self.num_speculative_tokens})."
|
||||
)
|
||||
|
||||
if self.draft_model_config:
|
||||
self.draft_model_config.verify_with_parallel_config(
|
||||
self.draft_parallel_config
|
||||
)
|
||||
|
||||
if self.disable_by_batch_size is not None and self.disable_by_batch_size < 2:
|
||||
raise ValueError(
|
||||
"Expect the batch size threshold of disabling "
|
||||
"speculative decoding is > 1, but got "
|
||||
f"{self.disable_by_batch_size=}"
|
||||
)
|
||||
|
||||
eagle3_target_supported = ["llama", "qwen", "minicpm", "gpt_oss"]
|
||||
if (
|
||||
self.method == "eagle3"
|
||||
and self.target_model_config
|
||||
and not any(
|
||||
supported_model in self.target_model_config.hf_text_config.model_type
|
||||
for supported_model in eagle3_target_supported
|
||||
)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Eagle3 is only supported for {eagle3_target_supported} models. " # noqa: E501
|
||||
f"Got {self.target_model_config.hf_text_config.model_type=}"
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
def use_eagle(self) -> bool:
|
||||
return self.method in ("eagle", "eagle3", "mtp")
|
||||
|
||||
def __repr__(self) -> str:
|
||||
method = self.method
|
||||
model = None if method in ("ngram", "suffix") else self.draft_model_config.model
|
||||
num_spec_tokens = self.num_speculative_tokens
|
||||
return f"SpeculativeConfig({method=}, {model=}, {num_spec_tokens=})"
|
||||
Reference in New Issue
Block a user