Files

293 lines
11 KiB
Python
Raw Permalink Normal View History

2026-01-19 10:38:50 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2026-01-09 13:34:11 +08:00
"""Utilities for selecting and loading models."""
2026-01-19 10:38:50 +08:00
import inspect
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
2026-01-09 13:34:11 +08:00
import torch
from torch import nn
2026-01-19 10:38:50 +08:00
from typing_extensions import assert_never
from vllm.attention.layer import Attention, MLAAttention
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal
from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__)
def initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    model_class: type[nn.Module] | None = None,
    model_config: ModelConfig | None = None,
) -> nn.Module:
    """Instantiate a model class under the given vLLM configuration.

    Args:
        vllm_config: Top-level configuration; also made the "current" config
            while the model constructor runs.
        prefix: Name prefix forwarded to the model constructor.
        model_class: Class to instantiate. When ``None`` it is resolved with
            ``get_model_architecture(model_config)``.
        model_config: Model configuration. Defaults to
            ``vllm_config.model_config``.

    Returns:
        The constructed model.

    A class whose ``__init__`` accepts both ``vllm_config`` and ``prefix`` is
    treated as new-style and called with exactly those two keywords. Any other
    class triggers a ``DeprecationWarning`` and is called with whichever of the
    legacy keyword arguments its ``__init__`` declares.
    """
    if model_config is None:
        model_config = vllm_config.model_config
    if model_class is None:
        model_class, _arch_name = get_model_architecture(model_config)

    quant_config = vllm_config.quant_config
    if quant_config is not None:
        configure_quant_config(quant_config, model_class)

    # Mapping of constructor parameter name -> Parameter; membership tests
    # below are by parameter name.
    ctor_params = inspect.signature(model_class.__init__).parameters

    if "vllm_config" in ctor_params and "prefix" in ctor_params:
        # new-style model class
        with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = (
        "vLLM model class should accept `vllm_config` and `prefix` as "
        "input arguments. Possibly you have an old-style model class"
        " registered from out of tree and it is used for new vLLM version. "
        "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
        "for the design and update the model class accordingly."
    )
    # stacklevel=2 attributes the warning to the caller of initialize_model,
    # so this call must stay directly in this function's frame.
    warnings.warn(msg, DeprecationWarning, stacklevel=2)
    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )

    # try to be compatible with old-style model class: each legacy argument
    # is looked up lazily, and only when the constructor declares it.
    legacy_sources = {
        "prefix": lambda: prefix,
        "config": lambda: model_config.hf_config,
        "cache_config": lambda: vllm_config.cache_config,
        "quant_config": lambda: vllm_config.quant_config,
        "lora_config": lambda: vllm_config.lora_config,
        "scheduler_config": lambda: vllm_config.scheduler_config,
    }
    kwargs = {
        arg_name: fetch()
        for arg_name, fetch in legacy_sources.items()
        if arg_name in ctor_params
    }
    with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
        return model_class(**kwargs)
def process_weights_after_loading(
    model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
    """Run the post-load weight hooks of every submodule of ``model``.

    Returns early when ``model`` carries a truthy
    ``process_weights_after_loading_already_called`` attribute. Otherwise the
    work is done in two passes: first every module whose ``quant_method`` is a
    ``QuantizeMethodBase``, then every ``Attention`` / ``MLAAttention`` module
    that defines ``process_weights_after_loading``.
    """
    already_done = getattr(
        model, "process_weights_after_loading_already_called", False
    )
    if already_done:
        # In case `process_weights_after_loading` is called multiple times
        # we'll skip it at later times
        logger.debug_once(
            "process_weights_after_loading already called for model %s", model
        )
        return

    # to avoid circular dependency
    from vllm.model_executor.model_loader.online_quantization import (
        maybe_save_metadata_and_attributes_for_weight_reloading,
    )

    maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)

    # Pass 1: quantization methods.
    for submodule in model.modules():
        method = getattr(submodule, "quant_method", None)
        if not isinstance(method, QuantizeMethodBase):
            continue
        # When quant methods need to process weights after loading
        # (for repacking, quantizing, etc), they expect parameters
        # to be on the global target device. This scope is for the
        # case where cpu offloading is used, where we will move the
        # parameters onto device for processing and back off after.
        with device_loading_context(submodule, target_device):
            method.process_weights_after_loading(submodule)

    # Pass 2: initialize post-load attention weights for both Attention and
    # MLA.
    # NOTE: Happens after other modules so we can easily decompress weights.
    for submodule in model.modules():
        if not isinstance(submodule, (Attention, MLAAttention)):
            continue
        if not hasattr(submodule, "process_weights_after_loading"):
            continue
        # TODO(lucas): see if there is a way to unify the signatures
        # of process_weights_after_loading
        submodule.process_weights_after_loading(model_config.dtype)
@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
    """Temporarily place CPU-resident parameters of ``module`` on a device.

    On entry, every parameter currently on CPU has its ``.data`` moved to
    ``target_device``. On exit (including on error), those same parameters,
    matched by name, are copied back to their recorded device; CPU copies are
    allocated pinned when ``is_pin_memory_available()`` says so. Parameters
    that were not on CPU at entry, or whose names only exist at exit, are left
    alone.

    When ``target_device`` is itself a CPU device nothing is moved.

    Yields:
        The ``module`` that was passed in.
    """
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    # Parameter name -> device it lived on before we moved it.
    offloaded_from: dict[str, torch.device] = {}

    # Store original device states and move parameters to GPU if they're on CPU
    for param_name, param in module.named_parameters():
        if param.device.type != "cpu":
            # Parameters already on target device are not touched
            continue
        offloaded_from[param_name] = param.device
        param.data = param.data.to(target_device)

    try:
        yield module
    finally:
        # Restore parameters to their original devices, ignoring new parameters
        use_pinned = is_pin_memory_available()
        for param_name, param in module.named_parameters():
            if param_name not in offloaded_from:
                # New parameters or parameters already on target device are
                # untouched
                continue

            home_device: torch.device = offloaded_from[param_name]
            if home_device.type != "cpu":
                param.data = param.data.to(home_device)
                continue

            # `torch.empty_like` does not support `pin_memory` argument
            staging = torch.empty_strided(
                size=param.data.size(),
                stride=param.data.stride(),
                dtype=param.data.dtype,
                layout=param.data.layout,
                device="cpu",
                pin_memory=use_pinned,
            )
            staging.copy_(param.data)
            param.data = staging
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
# Caches the outputs of `_get_model_architecture`.
_MODEL_ARCH_BY_HASH: dict[int, tuple[type[nn.Module], str]] = {}
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
from vllm.model_executor.models.adapters import (
as_embedding_model,
as_seq_cls_model,
try_create_mm_pooling_model_cls,
)
2026-01-09 13:34:11 +08:00
architectures = getattr(model_config.hf_config, "architectures", [])
2026-01-19 10:38:50 +08:00
model_cls, arch = model_config.registry.resolve_model_cls(
architectures,
model_config=model_config,
)
if arch == model_config._get_transformers_backend_cls():
assert model_config.model_impl != "vllm"
if model_config.model_impl == "auto":
logger.warning_once(
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.",
arch,
)
convert_type = model_config.convert_type
if convert_type != "none" and supports_multimodal(model_cls):
logger.debug_once("Detected conversion of Multi Modal model.")
converted = try_create_mm_pooling_model_cls(model_cls)
if converted is not None:
logger.debug_once("Creating wrapper class to forward pooler.")
return converted, arch
else:
logger.debug_once("Attempting direct conversion.")
if convert_type == "none":
pass
elif convert_type == "embed":
logger.debug_once("Converting to embedding model.")
model_cls = as_embedding_model(model_cls)
elif convert_type == "classify":
logger.debug_once("Converting to sequence classification model.")
model_cls = as_seq_cls_model(model_cls)
else:
assert_never(convert_type)
return model_cls, arch
def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Return ``(model_cls, arch)`` for ``model_config``, memoized.

    Results of ``_get_model_architecture`` are stored in
    ``_MODEL_ARCH_BY_HASH`` under the hash of the config fields listed below,
    so repeated calls with an equivalent configuration skip resolution.
    """
    identifying_fields = (
        model_config.model,
        model_config.convert_type,
        model_config.runner_type,
        model_config.trust_remote_code,
        model_config.model_impl,
        tuple(getattr(model_config.hf_config, "architectures", [])),
    )
    cache_key = hash(identifying_fields)

    try:
        return _MODEL_ARCH_BY_HASH[cache_key]
    except KeyError:
        # Cache miss: resolve below, outside the handler, so that errors from
        # resolution are not chained onto this KeyError.
        pass

    resolved = _get_model_architecture(model_config)
    _MODEL_ARCH_BY_HASH[cache_key] = resolved
    return resolved
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
    """Return only the model class from ``get_model_architecture``."""
    resolved = get_model_architecture(model_config)
    return resolved[0]
2026-01-09 13:34:11 +08:00
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return only the architecture name from ``get_model_architecture``."""
    resolved = get_model_architecture(model_config)
    return resolved[1]
2026-01-19 10:38:50 +08:00
@dataclass
class ParamMapping:
"""
A class to handle parameter mapping for model weight loading.
It creates a bidirectional mapping between packed parameters and their
constituent parts.
"""
packed_mapping: dict[str, list[str]]
inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)
def __post_init__(self):
for packed_name, sub_params in self.packed_mapping.items():
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
if len(sub_params) == 1 and sub_params[0] == packed_name:
continue
for index, param_name in enumerate(sub_params):
self.inverse_packed_mapping[param_name] = (
packed_name,
index,
)
def get_sub_modules(self, module_name: str) -> tuple[str, list[str]] | None:
for key, value in self.packed_mapping.items():
if module_name.endswith(key):
return key, value
return None
def configure_quant_config(
    quant_config: QuantizationConfig, model_class: type[nn.Module]
):
    """
    Hand the model class's name mappings to ``quant_config`` so that it can
    properly match fused modules.

    Does nothing for classes that subclass ``SupportsQuant``. For other
    classes, ``hf_to_vllm_mapper`` (if present) is applied through
    ``quant_config.apply_vllm_mapper`` and ``packed_modules_mapping`` (if
    present) is assigned onto ``quant_config``.

    Note that the mapping object itself is assigned, not a copy, enabling it
    to be updated by model_class.__new__ (ex. chatglm, qwen).

    Once the `SupportsQuant` mixin has been added to all models, this
    function can be removed.
    """
    if issubclass(model_class, SupportsQuant):
        return

    mapper = getattr(model_class, "hf_to_vllm_mapper", None)
    packed_modules = getattr(model_class, "packed_modules_mapping", None)

    # pass mappings by reference to quant_config
    if mapper is not None:
        quant_config.apply_vllm_mapper(mapper)
    if packed_modules is not None:
        quant_config.packed_modules_mapping = packed_modules