# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading models."""

import inspect
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field

import torch
from torch import nn
from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.model_loader.reload import (
    record_metadata_for_reloading,
    set_torchao_reload_attrs,
)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.tracing import instrument
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor

logger = init_logger(__name__)


@instrument(span_name="Initialize model")
def initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    model_class: type[nn.Module] | None = None,
    model_config: ModelConfig | None = None,
) -> nn.Module:
    """Initialize a model with the given configurations.

    Args:
        vllm_config: Global vLLM configuration passed to the model.
        prefix: Module-name prefix forwarded to the model constructor and to
            `set_current_vllm_config`.
        model_class: Class to instantiate. When None it is resolved from
            `model_config` via `get_model_architecture`.
        model_config: Model configuration; defaults to
            `vllm_config.model_config`.

    Returns:
        The constructed model, after `record_metadata_for_reloading` has been
        called on it.

    A "new-style" model class accepts both `vllm_config` and `prefix` in its
    `__init__`. Any other class triggers a `DeprecationWarning`; its
    constructor arguments are then guessed from the parameter names its
    `__init__` declares.
    """
    if model_config is None:
        model_config = vllm_config.model_config
    if model_class is None:
        model_class, _ = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        configure_quant_config(vllm_config.quant_config, model_class)

    signatures = inspect.signature(model_class.__init__)
    all_params = [param.name for param in signatures.parameters.values()]
    if "vllm_config" in all_params and "prefix" in all_params:
        # new-style model class
        with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
            model = model_class(vllm_config=vllm_config, prefix=prefix)
            record_metadata_for_reloading(model)
            return model

    msg = (
        "vLLM model class should accept `vllm_config` and `prefix` as "
        "input arguments. Possibly you have an old-style model class"
        " registered from out of tree and it is used for new vLLM version. "
        "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
        "for the design and update the model class accordingly."
    )
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )
    # try to be compatible with old-style model class: only pass the
    # arguments whose names appear in the constructor signature.
    kwargs = {}
    if "prefix" in all_params:
        kwargs["prefix"] = prefix
    if "config" in all_params:
        kwargs["config"] = model_config.hf_config
    if "cache_config" in all_params:
        kwargs["cache_config"] = vllm_config.cache_config
    if "quant_config" in all_params:
        kwargs["quant_config"] = vllm_config.quant_config
    if "lora_config" in all_params:
        kwargs["lora_config"] = vllm_config.lora_config
    if "scheduler_config" in all_params:
        kwargs["scheduler_config"] = vllm_config.scheduler_config
    with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
        model = model_class(**kwargs)
        record_metadata_for_reloading(model)
        return model


def process_weights_after_loading(
    model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
    """Run post-load weight processing over every submodule of `model`.

    1. Every module whose `quant_method` attribute is a `QuantizeMethodBase`
       gets `quant_method.process_weights_after_loading(module)`.
    2. Every `Attention` / `MLAAttention` module that defines
       `process_weights_after_loading` is then called with
       `model_config.dtype`.
    3. For `model_config.quantization == "torchao"`,
       `set_torchao_reload_attrs` is applied to the model.

    Steps 1 and 2 run inside `device_loading_context`, so any CPU-resident
    parameters are on `target_device` while they are processed.
    """
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if isinstance(quant_method, QuantizeMethodBase):
            # When quant methods need to process weights after loading
            # (for repacking, quantizing, etc), they expect parameters
            # to be on the global target device. This scope is for the
            # case where cpu offloading is used, where we will move the
            # parameters onto device for processing and back off after.
            with device_loading_context(module, target_device):
                quant_method.process_weights_after_loading(module)

    # Initialize post-load attention weights for both Attention and MLA.
    # NOTE: Happens after other modules so we can easily decompress weights.
    for _, module in model.named_modules():
        if isinstance(module, (Attention, MLAAttention)) and hasattr(
            module, "process_weights_after_loading"
        ):
            # TODO(lucas): see if there is a way to unify the signatures
            # of process_weights_after_loading
            with device_loading_context(module, target_device):
                module.process_weights_after_loading(model_config.dtype)

    # Needed for torchao model reloading via model.reload_weights
    # @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights`
    if model_config.quantization == "torchao":
        set_torchao_reload_attrs(model, model_config)


@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
    """Temporarily move a module's CPU parameters onto `target_device`.

    If `target_device` is a CPU device, the module is yielded untouched.
    Otherwise, on entry:
      - every parameter whose device type is "cpu" is moved to
        `target_device` (parameters already elsewhere are not touched);
      - the names of parameters flagged `_vllm_is_uva_offloaded` are recorded.

    On exit (including when the body raises):
      - moved parameters are moved back to their original device. Parameter
        names that did not exist on entry are left where they are.
      - a recorded UVA parameter that no longer carries the
        `_vllm_is_uva_offloaded` flag (it was replaced inside the block) is
        copied to CPU. The copy is pinned when pinning is available and not
        disabled via `VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY`. It is then
        re-wrapped with `get_accelerator_view_from_cpu_tensor` and flagged
        again.

    Yields:
        The same `module` object.
    """
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    original_device_states: dict[str, torch.device] = {}
    # A set: it is only used for membership tests in the restore loop below.
    uva_offloaded_parameters: set[str] = set()

    # Store original device states and move parameters to GPU if they're on CPU
    for name, p in module.named_parameters():
        if p.device.type == "cpu":
            original_device_states[name] = p.device
            p.data = p.data.to(target_device)
        if getattr(p, "_vllm_is_uva_offloaded", False):
            uva_offloaded_parameters.add(name)
        # Parameters already on target device are not touched

    try:
        yield module

    finally:
        use_pin_memory = (
            is_pin_memory_available()
            and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
        )
        # Restore parameters to their original devices, ignoring new parameters
        for name, p in module.named_parameters():
            if name in original_device_states:
                original_device: torch.device = original_device_states[name]
                p.data = p.data.to(original_device)

            # parameter is UVA offloaded, but was replaced with a new device tensor
            # re-offload it to CPU using UVA
            if name in uva_offloaded_parameters and not getattr(
                p, "_vllm_is_uva_offloaded", False
            ):
                cpu_data = p.data.to(device="cpu")
                if use_pin_memory:
                    cpu_data = cpu_data.pin_memory()
                p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
                p._vllm_is_uva_offloaded = True


_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
"""Caches the outputs of `_get_model_architecture`."""


def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Resolve the model class and architecture name without caching.

    Resolution goes through `model_config.registry.resolve_model_cls`. A
    warning is logged once when `model_impl == "auto"` and the resolved
    architecture is the Transformers backend class. The resolved class is
    then wrapped according to `model_config.convert_type` ("none", "embed" or
    "classify").

    Returns:
        `(model_cls, arch)` as resolved (and possibly adapted).
    """
    from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model

    architectures = getattr(model_config.hf_config, "architectures", [])

    model_cls, arch = model_config.registry.resolve_model_cls(
        architectures,
        model_config=model_config,
    )

    if arch == model_config._get_transformers_backend_cls():
        assert model_config.model_impl != "vllm"
        if model_config.model_impl == "auto":
            logger.warning_once(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.",
                arch,
            )

    convert_type = model_config.convert_type
    if convert_type == "none":
        pass
    elif convert_type == "embed":
        logger.debug_once("Converting to embedding model.")
        model_cls = as_embedding_model(model_cls)
    elif convert_type == "classify":
        logger.debug_once("Converting to sequence classification model.")
        model_cls = as_seq_cls_model(model_cls)
    else:
        assert_never(convert_type)

    return model_cls, arch


def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Return `(model_cls, arch)` for `model_config`, memoized per process.

    Results are stored in `_MODEL_ARCH_BY_HASH`. The key covers the model
    name, convert type, runner type, `trust_remote_code`, `model_impl` and
    the HF config's `architectures`.
    """
    # `hf_config.architectures` may exist but be None (the default on
    # transformers' `PretrainedConfig`); `tuple(None)` would raise TypeError
    # here before the registry could report a meaningful error.
    architectures = getattr(model_config.hf_config, "architectures", None) or ()
    # NOTE(review): entries are keyed by hash() of the tuple rather than the
    # tuple itself, so two configs whose key tuples collide would share an
    # entry.
    key = hash(
        (
            model_config.model,
            model_config.convert_type,
            model_config.runner_type,
            model_config.trust_remote_code,
            model_config.model_impl,
            tuple(architectures),
        )
    )
    if key in _MODEL_ARCH_BY_HASH:
        return _MODEL_ARCH_BY_HASH[key]

    model_arch = _get_model_architecture(model_config)
    _MODEL_ARCH_BY_HASH[key] = model_arch
    return model_arch


def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
    """Return only the (possibly adapted) model class for `model_config`."""
    return get_model_architecture(model_config)[0]


def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return only the resolved architecture name for `model_config`."""
    return get_model_architecture(model_config)[1]


@dataclass
class ParamMapping:
    """
    A class to handle parameter mapping for model weight loading.
    It creates a bidirectional mapping between packed parameters and their
    constituent parts.
    """

    # packed parameter name -> ordered list of its constituent parameter names
    packed_mapping: dict[str, list[str]]
    # constituent parameter name -> (packed parameter name, index in that list)
    inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)

    def __post_init__(self):
        for packed_name, sub_params in self.packed_mapping.items():
            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
            if len(sub_params) == 1 and sub_params[0] == packed_name:
                continue
            for index, param_name in enumerate(sub_params):
                self.inverse_packed_mapping[param_name] = (
                    packed_name,
                    index,
                )

    def get_sub_modules(self, module_name: str) -> tuple[str, list[str]] | None:
        """Return the first `(packed_name, sub_params)` whose key is a suffix
        of `module_name`, or None when no key matches.

        Matching is a plain string-suffix test (`str.endswith`); it is not
        aligned to "." boundaries.
        """
        for key, value in self.packed_mapping.items():
            if module_name.endswith(key):
                return key, value
        return None


def configure_quant_config(
    quant_config: QuantizationConfig, model_class: type[nn.Module]
):
    """
    Pass packed_modules_mapping by reference to quant_config so that
    quant_config can properly match fused modules.

    Note that model attributes are passed by reference to quant_config,
    enabling them to be updated by model_class.__new__ (ex. chatglm, qwen).

    Once the `SupportsQuant` mixin has been added to all models, this
    function can be removed.
    """
    if not issubclass(model_class, SupportsQuant):
        hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
        packed_mapping = getattr(model_class, "packed_modules_mapping", None)

        # pass mappings by reference to quant_config
        if hf_to_vllm_mapper is not None:
            quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
        if packed_mapping is not None:
            quant_config.packed_modules_mapping = packed_mapping