Sync from v0.13
This commit is contained in:
114
vllm/transformers_utils/configs/speculators/base.py
Normal file
114
vllm/transformers_utils/configs/speculators/base.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.transformers_utils.configs.speculators.algos import (
|
||||
SUPPORTED_SPECULATORS_TYPES,
|
||||
)
|
||||
|
||||
__all__ = ["SpeculatorsConfig"]
|
||||
|
||||
|
||||
class SpeculatorsConfig(PretrainedConfig):
|
||||
model_type = "speculators"
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str | os.PathLike,
|
||||
**kwargs,
|
||||
) -> "SpeculatorsConfig":
|
||||
"""Load speculators Eagle config and convert to vLLM format."""
|
||||
config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
vllm_config = cls.extract_vllm_speculative_config(config_dict)
|
||||
return cls(**vllm_config)
|
||||
|
||||
@classmethod
|
||||
def extract_vllm_speculative_config(
|
||||
cls, config_dict: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
speculators_model_type = config_dict.get("speculators_model_type")
|
||||
if speculators_model_type not in SUPPORTED_SPECULATORS_TYPES:
|
||||
raise ValueError(
|
||||
f"Expected one of: {SUPPORTED_SPECULATORS_TYPES}. "
|
||||
"Please ensure you're loading a speculators-format model."
|
||||
)
|
||||
|
||||
# validate fields
|
||||
# TODO: @dsikka - use speculators pydantic model to validate
|
||||
cls.validate_speculators_config(config_dict=config_dict)
|
||||
# Convert from speculators config -> format that can be ingested by vLLM
|
||||
vllm_config = cls.build_vllm_speculative_config(config_dict=config_dict)
|
||||
# Apply anything specific to the supported algorithm
|
||||
algo_updater = SUPPORTED_SPECULATORS_TYPES[speculators_model_type]
|
||||
algo_updater(config_dict=config_dict, vllm_config=vllm_config)
|
||||
return vllm_config
|
||||
|
||||
@classmethod
|
||||
def validate_speculators_config(cls, config_dict: dict[str, Any]) -> None:
|
||||
try:
|
||||
spec_config = config_dict["speculators_config"]
|
||||
methods = spec_config["proposal_methods"]
|
||||
first_method = methods[0]
|
||||
_ = first_method["speculative_tokens"]
|
||||
_ = spec_config["verifier"]["name_or_path"]
|
||||
_ = config_dict["speculators_model_type"]
|
||||
except (KeyError, IndexError, TypeError) as e:
|
||||
raise ValueError("Invalid speculators config structure") from e
|
||||
|
||||
if "transformer_layer_config" not in config_dict:
|
||||
raise ValueError("Must provide transformer_layer_config")
|
||||
|
||||
if not isinstance(config_dict["transformer_layer_config"], dict):
|
||||
raise TypeError(
|
||||
"'transformer_layer_config' must be a dictionary if provided"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def build_vllm_speculative_config(
|
||||
cls, config_dict: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build vLLM-compatible speculative configuration from speculators format.
|
||||
|
||||
This method extracts and transforms speculative configuration from the
|
||||
speculators format into the structure expected by vLLM.
|
||||
|
||||
Args:
|
||||
config_dict: Configuration dictionary in speculators format
|
||||
|
||||
Returns:
|
||||
Dictionary with vLLM-compatible speculative configuration
|
||||
"""
|
||||
# Extract speculators configuration
|
||||
spec_config = config_dict["speculators_config"]
|
||||
|
||||
# Currently we only support one proposal method
|
||||
proposal_methods = spec_config.get("proposal_methods")
|
||||
if not proposal_methods:
|
||||
raise ValueError("No proposal methods found in speculators config")
|
||||
|
||||
first_method = proposal_methods[0]
|
||||
num_speculative_tokens = first_method.get("speculative_tokens")
|
||||
|
||||
if num_speculative_tokens is None:
|
||||
raise ValueError(
|
||||
f"Missing 'speculative_tokens' in proposal method. Got: {first_method}"
|
||||
)
|
||||
|
||||
# Build base vLLM speculative configuration
|
||||
vllm_config = {
|
||||
"method": config_dict.get("speculators_model_type"),
|
||||
"num_speculative_tokens": num_speculative_tokens,
|
||||
"target_model": spec_config.get("verifier")["name_or_path"],
|
||||
}
|
||||
|
||||
# Merge transformer layer configuration if present
|
||||
transformer_config = config_dict.get("transformer_layer_config", {})
|
||||
vllm_config.update(transformer_config)
|
||||
|
||||
return vllm_config
|
||||
Reference in New Issue
Block a user