Add minimal vLLM 0.16.1 build repo for BI-V150

This commit is contained in:
2026-04-18 10:56:22 +08:00
commit d69657327e
1895 changed files with 615301 additions and 0 deletions

View File

@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal
from torch import nn
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.bitsandbytes_loader import BitsAndBytesModelLoader
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
from vllm.model_executor.model_loader.runai_streamer_loader import (
RunaiModelStreamerLoader,
)
from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
from vllm.model_executor.model_loader.utils import (
get_architecture_class_name,
get_model_architecture,
get_model_cls,
)
logger = init_logger(__name__)
# Reminder: Please update docstring in `LoadConfig`
# if a new load format is added here
LoadFormats = Literal[
"auto",
"hf",
"bitsandbytes",
"dummy",
"fastsafetensors",
"gguf",
"mistral",
"npcache",
"pt",
"runai_streamer",
"runai_streamer_sharded",
"safetensors",
"sharded_state",
"tensorizer",
]
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
"auto": DefaultModelLoader,
"hf": DefaultModelLoader,
"bitsandbytes": BitsAndBytesModelLoader,
"dummy": DummyModelLoader,
"fastsafetensors": DefaultModelLoader,
"gguf": GGUFModelLoader,
"mistral": DefaultModelLoader,
"npcache": DefaultModelLoader,
"pt": DefaultModelLoader,
"runai_streamer": RunaiModelStreamerLoader,
"runai_streamer_sharded": ShardedStateLoader,
"safetensors": DefaultModelLoader,
"sharded_state": ShardedStateLoader,
"tensorizer": TensorizerLoader,
}
def register_model_loader(load_format: str):
"""Register a customized vllm model loader.
When a load format is not supported by vllm, you can register a customized
model loader to support it.
Args:
load_format (str): The model loader format name.
Examples:
>>> from vllm.config.load import LoadConfig
>>> from vllm.model_executor.model_loader import (
... get_model_loader,
... register_model_loader,
... )
>>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
>>>
>>> @register_model_loader("my_loader")
... class MyModelLoader(BaseModelLoader):
... def download_model(self):
... pass
...
... def load_weights(self):
... pass
>>>
>>> load_config = LoadConfig(load_format="my_loader")
>>> type(get_model_loader(load_config))
<class 'MyModelLoader'>
""" # noqa: E501
def _wrapper(model_loader_cls):
if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
logger.warning(
"Load format `%s` is already registered, and will be "
"overwritten by the new loader class `%s`.",
load_format,
model_loader_cls,
)
if not issubclass(model_loader_cls, BaseModelLoader):
raise ValueError(
"The model loader must be a subclass of `BaseModelLoader`."
)
_LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
logger.info(
"Registered model loader `%s` with load format `%s`",
model_loader_cls,
load_format,
)
return model_loader_cls
return _wrapper
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
load_format = load_config.load_format
if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER:
raise ValueError(f"Load format `{load_format}` is not supported")
return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config)
def get_model(
*,
vllm_config: VllmConfig,
model_config: ModelConfig | None = None,
prefix: str = "",
load_config: LoadConfig | None = None,
) -> nn.Module:
loader = get_model_loader(load_config or vllm_config.load_config)
if model_config is None:
model_config = vllm_config.model_config
return loader.load_model(
vllm_config=vllm_config, model_config=model_config, prefix=prefix
)
__all__ = [
"get_model",
"get_model_loader",
"get_architecture_class_name",
"get_model_architecture",
"get_model_cls",
"register_model_loader",
"BaseModelLoader",
"BitsAndBytesModelLoader",
"GGUFModelLoader",
"DefaultModelLoader",
"DummyModelLoader",
"RunaiModelStreamerLoader",
"ShardedStateLoader",
"TensorizerLoader",
]

View File

@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.platforms import current_platform
from vllm.tracing import instrument
from vllm.utils.mem_utils import format_gib
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
class BaseModelLoader(ABC):
"""Base class for model loaders."""
def __init__(self, load_config: LoadConfig):
self.load_config = load_config
@abstractmethod
def download_model(self, model_config: ModelConfig) -> None:
"""Download a model so that it can be immediately loaded."""
raise NotImplementedError
@abstractmethod
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
"""Load weights into a model. This standalone API allows
inplace weights loading for an already-initialized model"""
raise NotImplementedError
@instrument(span_name="Load model")
def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module:
"""Load a model with the given configurations."""
device_config = vllm_config.device_config
load_config = vllm_config.load_config
load_device = (
device_config.device if load_config.device is None else load_config.device
)
target_device = torch.device(load_device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(
vllm_config=vllm_config, model_config=model_config, prefix=prefix
)
log_model_inspection(model)
logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)
# Log peak GPU memory after loading weights. This is needed
# to have test coverage on peak memory for online quantization.
if current_platform.is_cuda():
peak_memory = torch.cuda.max_memory_allocated()
logger.debug_once(
"Peak GPU memory after loading weights: %s GiB",
format_gib(peak_memory),
scope="local",
)
process_weights_after_loading(model, model_config, target_device)
return model.eval()
def log_model_inspection(model: nn.Module) -> None:
"""Log model structure if VLLM_LOG_MODEL_INSPECTION=1."""
if not envs.VLLM_LOG_MODEL_INSPECTION:
return
from vllm.model_inspection import format_model_inspection
logger.info("vLLM model structure:\n%s", format_model_inspection(model))

View File

@@ -0,0 +1,817 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import fnmatch
import glob
import itertools
import math
import os
from collections.abc import Callable, Generator
from typing import Any
import numpy as np
import torch
from huggingface_hub import HfApi
from packaging import version
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.lora.utils import is_moe_model
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import ParamMapping
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
download_weights_from_hf,
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
pt_weights_iterator,
safetensors_weights_iterator,
)
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.utils import (
get_moe_expert_mapping,
get_packed_modules_mapping,
set_weight_attrs,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitsAndBytes quantization."""
possible_config_file_names = ["adapter_config.json"]
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
# Save the module names without sharding.
self.unsharded_weights_modules: list[str] = []
# Save the module names that are sharded by column.
self.column_sharded_weights_modules: list[str] = []
# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
self.maybe_fused_weights_modules: dict[str, list[int]] = {}
# Store all module names (from transformers) that support
# BNB quantization.
self.target_modules: list[str] = []
self.tp_disabled_modules: list[str] = []
# Store the mapping of expert parameters for MoE models.
self.expert_params_mapping: list[tuple[str, str, int, str]] = []
# mapping weight names from transformers to vllm.
self.weight_mapper: Callable = lambda name: name
self.pre_quant: bool = False
self.load_8bit: bool = False
self.is_pool_model: bool = False
def _get_weight_files(
self,
model_name_or_path: str,
allowed_patterns: list[str],
revision: str | None = None,
) -> tuple[str, list[str], str]:
"""Retrieve weight files. Download the files if necessary.
Return the weight files and the file pattern."""
is_local = os.path.isdir(model_name_or_path)
if is_local:
for pattern in allowed_patterns:
weight_files = glob.glob(os.path.join(model_name_or_path, pattern))
if weight_files:
return model_name_or_path, weight_files, pattern
else:
hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
for pattern in allowed_patterns:
matching_files = fnmatch.filter(repo_files, pattern)
if matching_files:
hf_folder = download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
[pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
return (
hf_folder,
glob.glob(os.path.join(hf_folder, pattern)),
pattern,
)
raise RuntimeError(f"No model weights found in: `{model_name_or_path}`")
def _prepare_weights(
self, model_name_or_path: str, revision: str | None
) -> tuple[list[str], bool]:
"""Prepare weight files for the model."""
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
model_name_or_path, allowed_patterns, revision
)
use_safetensors = matched_pattern == "*.safetensors"
is_local = os.path.isdir(model_name_or_path)
index_file = SAFE_WEIGHTS_INDEX_NAME
if use_safetensors:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path,
index_file,
self.load_config.download_dir,
revision,
)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file
)
else:
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`"
)
return hf_weights_files, use_safetensors
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
def _maybe_pool_model(module_name: str):
# For pool model, we need to add the prefix `model.`
# for the weight name if possible.
if (
self.is_pool_model
and self.target_modules[0].startswith("model.")
and not module_name.startswith("model.")
):
return "model." + module_name
return module_name
if use_safetensors:
iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
else:
iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
for org_name, param in iterator:
# mapping weight names from transformers to vllm while preserving
# original names.
mapped_name = self.weight_mapper(org_name)
mapped_name = _maybe_pool_model(mapped_name)
yield org_name, mapped_name, param
def _get_quantized_weights_iterator(
self,
model_name_or_path: str,
revision: str | None,
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
as well as the quantization state dictionary."""
# only load the bitsandbytes module when needed
try:
import bitsandbytes
if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
raise ImportError(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.46.1."
)
except ImportError as err:
raise ImportError(
"Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer."
) from err
hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision
)
quant_state_dict: dict[str, Any] = {}
if self.pre_quant:
if self.load_8bit:
return self._quantized_8bit_generator(
hf_weights_files, use_safetensors, quant_state_dict
), quant_state_dict
else:
return self._quantized_4bit_generator(
hf_weights_files, use_safetensors, quant_state_dict
), quant_state_dict
return self._unquantized_generator(
hf_weights_files, use_safetensors, quant_state_dict
), quant_state_dict
def _is_8bit_weight_name(self, weight_name: str):
quantized_suffix = {".scb", ".weight_format"}
return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
def _is_4bit_weight_name(self, weight_name: str):
quantized_suffix = {
"absmax",
"quant_map",
"nested_absmax",
"nested_quant_map",
"bitsandbytes",
}
suffix = weight_name.split(".")[-1]
return any(q_suffix in suffix for q_suffix in quantized_suffix)
def _quantized_8bit_generator(
self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if not mapped_weight_name.lower().endswith(".scb"):
continue
weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
quant_state_dict[weight_key] = weight_tensor
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_8bit_weight_name(mapped_weight_name):
continue
if mapped_weight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield org_weight_name, weight_tensor
else:
yield org_weight_name, weight_tensor
def _quantized_4bit_generator(
self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
from bitsandbytes.functional import QuantState
# First iterate over all quant state weights
weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
temp_state_dict = {}
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in weight_iterator:
if not self._is_4bit_weight_name(mapped_weight_name):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
if "quant_state.bitsandbytes" in mapped_weight_name:
temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
else:
temp_state_dict[mapped_weight_name] = weight_tensor
# Closure to parse quant_state for each prequant weight
def _parse_quant_state(param_name: str, temp_state_dict: dict) -> QuantState:
quant_state = {}
for k in temp_state_dict:
if param_name + "." in k:
quant_state[k] = temp_state_dict[k]
return QuantState.from_dict(
quant_state, device=current_platform.device_type
)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_4bit_weight_name(mapped_weight_name):
continue
if (
f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict
) or (
f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
):
quant_state = _parse_quant_state(mapped_weight_name, temp_state_dict)
quant_state_dict[mapped_weight_name] = quant_state
yield org_weight_name, weight_tensor
else:
yield org_weight_name, weight_tensor
def _unquantized_generator(
self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
from bitsandbytes.functional import quantize_4bit
global_tp_size = get_tensor_model_parallel_world_size()
global_tp_rank = get_tensor_model_parallel_rank()
check_match = (
lambda weight_name, module_name: weight_name.removesuffix(".weight")
== module_name
)
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
# override tp_size and tp_rank if the module has disabled TP
if any(
tp_disabled_module in mapped_weight_name
for tp_disabled_module in self.tp_disabled_modules
):
tp_size = 1
tp_rank = 0
else:
tp_size = global_tp_size
tp_rank = global_tp_rank
if any(
target_module in mapped_weight_name
for target_module in self.target_modules
) and mapped_weight_name.endswith(".weight"):
# Without sharding
if any(
check_match(mapped_weight_name, module)
for module in self.unsharded_weights_modules
):
weight_sub_tensor = weight_tensor
# Shard by column
elif any(
check_match(mapped_weight_name, module)
for module in self.column_sharded_weights_modules
):
total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[..., start_index:end_index]
# Weights have fused on disk. In this case, we assume that the
# weight and module use same name.
elif any(
check_match(mapped_weight_name, module)
for module in self.maybe_fused_weights_modules
):
# special case for fused weights
# get the size of each shard weight tensor
total_shard_sizes = next(
(
sizes
for module, sizes in self.maybe_fused_weights_modules.items() # noqa: E501
if check_match(mapped_weight_name, module)
)
)
total_size = weight_tensor.size(0)
assert total_size == sum(total_shard_sizes)
# get the start/end index of each shard weight tensor
total_start_index = list(
itertools.accumulate([0] + total_shard_sizes)
)[:-1]
shard_weights_index = [
(
idx + size // tp_size * tp_rank,
idx + size // tp_size * (tp_rank + 1),
)
for idx, size in zip(total_start_index, total_shard_sizes)
]
# slice and reorder the weight tensor
weight_tensor = [
weight_tensor[start_index:end_index, ...]
for start_index, end_index in shard_weights_index
]
weight_sub_tensor = torch.cat(weight_tensor, dim=0)
# Shard by row
else:
total_size = weight_tensor.size(0)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[start_index:end_index, ...]
# bitsandbytes requires data in GPU
if weight_sub_tensor.is_cuda:
loaded_weight = weight_sub_tensor
else:
loaded_weight = weight_sub_tensor.to(
device=current_platform.device_type
)
# remove the following after the issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
if loaded_weight.is_contiguous() is False:
loaded_weight = loaded_weight.contiguous()
with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit(
loaded_weight,
compress_statistics=True,
quant_type="nf4",
)
quant_state_dict[mapped_weight_name] = quant_state
else:
processed_weight = weight_tensor
yield org_weight_name, processed_weight
def _get_bnb_target_modules(self, model: nn.Module) -> None:
"""
Identify and collect all modules that support BitsAndBytes
quantization.
"""
for name, module in model.named_modules():
if isinstance(module, LinearBase) and hasattr(
module.quant_method, "quant_config"
):
if modules_info := self.modules_mapping.get_sub_modules(name):
# Map vllm's names to transformers's names.
rep_name, sub_modules = modules_info
for sub_name in sub_modules:
new_name = name.replace(rep_name, sub_name)
self.target_modules.append(new_name)
if module.disable_tp:
self.tp_disabled_modules.append(new_name)
# Add original module name even if the module has stacked map,
# in case model has a mixture of disk-merged and disk-split
# weights with same last name.
self.target_modules.append(name)
if module.disable_tp:
self.tp_disabled_modules.append(name)
elif isinstance(module, FusedMoE) and hasattr(
module.quant_method, "quant_config"
):
# TODO: support FusedMoE with prequant and 8bit.
if self.pre_quant and self.load_8bit:
raise ValueError(
"Prequant BitsAndBytes 8bit models with FusedMoE "
"is not supported yet."
)
# Get the corresponding weight name using module name and
# expert_params_mapping.
for exp in self.expert_params_mapping:
weight_name = exp[1]
rep_name = name.replace("experts", "") + weight_name.removesuffix(
"."
)
self.target_modules.append(rep_name)
assert self.target_modules, (
"vLLM currently does not support BNB quantization for"
)
f" {type(model).__name__}"
def _classify_module_sharding(self, model: nn.Module):
"""
Categorize modules based on their weight sharding requirements
for tensor parallelism.
"""
for name, module in model.named_modules():
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
# static variable in the model implementation.
if isinstance(module, (ReplicatedLinear,)):
self.unsharded_weights_modules.append(name)
# `QKVParallelLinear` and `MergedColumnParallelLinear` might have
# fused weights on disk. We need to use the output sizes of these
# modules to shard the weights correctly.
elif isinstance(module, (QKVParallelLinear, MergedColumnParallelLinear)):
self.maybe_fused_weights_modules[name] = module.output_sizes
# In TP, these weights are partitioned along the column
# dimension (dim=-1)
elif isinstance(module, (RowParallelLinear,)):
self.column_sharded_weights_modules.append(name)
elif isinstance(module, FusedMoE):
expert_mapping = self.expert_params_mapping
for exp in expert_mapping:
if exp[-1] == "w2":
weight_name = exp[1]
rep_name = name.replace(
"experts", ""
) + weight_name.removesuffix(".")
self.column_sharded_weights_modules.append(rep_name)
def _verify_model_compatibility(
self, model: nn.Module, model_config: ModelConfig
) -> None:
"""
Verify that the model is compatible with BitsAndBytes quantization.
"""
if not hasattr(model, "load_weights"):
raise AttributeError(
"The required method 'load_weights' is not defined in class"
f" {type(model).__name__}."
)
if not hasattr(model, "packed_modules_mapping"):
raise AttributeError(
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet. No 'packed_modules_mapping' found."
)
quant_config = getattr(model_config.hf_config, "quantization_config", None)
if quant_config and (quant_method := quant_config.get("quant_method")):
if quant_method == "bitsandbytes":
self.pre_quant = True
else:
raise ValueError(
f"BitsAndBytes loader does not support {quant_method} quantization"
)
# The quant_states in pre_quantized models cannot work with a split
# weight tensor. So TP does not work with pre_quantized bnb models.
if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
raise ValueError(
"Prequant BitsAndBytes models with tensor parallelism is not "
"supported. Please try with pipeline parallelism."
)
if quant_config and self.pre_quant:
self.load_8bit = quant_config.get("load_in_8bit", False)
def _initialize_loader_state(
self, model: nn.Module, model_config: ModelConfig
) -> None:
"""
Initialize the loader's internal state based on the model and
configuration.
"""
self.is_pool_model = is_pooling_model(model)
self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
if is_moe_model(model):
self.expert_params_mapping = get_moe_expert_mapping(model)
if not self.expert_params_mapping:
raise AttributeError(
f"MoE Model {type(model).__name__} does not support "
"BitsAndBytes quantization yet. Ensure this model has "
"'get_expert_mapping' method."
)
# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
self._get_bnb_target_modules(model)
self._classify_module_sharding(model)
def _dequantize_dq(self, quant_states: Any):
"""
When BNB employs Double Quantization, we perform the dequantization of
these constants during weight loading rather than at inference time,
thereby avoiding this computational overhead during inference. This
comes at the cost of increased memory usage.
"""
from bitsandbytes.functional import QuantState, dequantize_blockwise
def _dequantize_single_state(quant_state):
"""Helper function to dequantize a single QuantState object."""
if not (isinstance(quant_state, QuantState) and quant_state.nested):
return
# Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
# Ensure float32 dtype
if absmax.dtype != torch.float32:
absmax = absmax.float()
quant_state.absmax = absmax
quant_state.nested = False
quant_state.offset = None
quant_state.state2 = None
if isinstance(quant_states, dict):
for quant_state in quant_states.values():
_dequantize_single_state(quant_state)
else:
_dequantize_single_state(quant_states)
return quant_states
def _fuse_moe_quant_states(self, model: nn.Module, quant_states_dict: dict) -> dict:
"""
This function consolidates individual expert quantization states into
fused representations for w13 and w2.
"""
from bitsandbytes.functional import QuantState
if not self.expert_params_mapping:
return dict()
expert_mapping = self.expert_params_mapping
expert_qs_dict = {}
for name, module in model.named_modules():
if not isinstance(module, FusedMoE):
continue
w1_states_lst = []
w2_states_lst = []
w3_states_lst = []
for exp in expert_mapping:
shard_id = exp[-1]
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(
f"shard_id must be ['w1','w2','w3'] but got {shard_id}."
)
layer_prefix = name.split("experts")[0]
weight_qual_name = layer_prefix + exp[1] + "weight"
quant_state = self._dequantize_dq(quant_states_dict[weight_qual_name])
if shard_id == "w1":
w1_states_lst.append(quant_state)
elif shard_id == "w2":
w2_states_lst.append(quant_state)
else:
w3_states_lst.append(quant_state)
del quant_states_dict[weight_qual_name]
assert len(w1_states_lst) == len(w2_states_lst) == len(w3_states_lst)
w13_absmax_lst = []
w2_absmax_lst = []
w13_total_dim0 = 0
w2_total_dim0 = 0
for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst, w3_states_lst):
assert w1_qs.shape == w3_qs.shape
assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize
assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype
# w1 and w3 are interleaved in storage
w13_absmax_lst.append(w1_qs.absmax)
w13_absmax_lst.append(w3_qs.absmax)
w2_absmax_lst.append(w2_qs.absmax)
w13_total_dim0 += w1_qs.shape[0] + w3_qs.shape[0]
w2_total_dim0 += w2_qs.shape[0]
w13_absmax = torch.cat(w13_absmax_lst)
w2_absmax = torch.cat(w2_absmax_lst)
# Create fused quantization state for w13.
w13_qs = QuantState(
absmax=w13_absmax,
shape=(w13_total_dim0, w1_states_lst[0].shape[1]),
code=w1_states_lst[0].code,
blocksize=w1_states_lst[0].blocksize,
quant_type="nf4",
dtype=w1_states_lst[0].dtype,
)
# Create fused quantization state for w2.
w2_qs = QuantState(
absmax=w2_absmax,
shape=(w2_total_dim0, w2_states_lst[0].shape[1]),
code=w2_states_lst[0].code,
blocksize=w2_states_lst[0].blocksize,
quant_type="nf4",
dtype=w2_states_lst[0].dtype,
)
# The weight suffixes .w13_weight and .w2_weight are consistent
# with the param in BitsAndBytesMoEMethod.
w13_weight_name = name + ".w13_weight"
w2_weight_name = name + ".w2_weight"
expert_qs_dict[w13_weight_name] = w13_qs
expert_qs_dict[w2_weight_name] = w2_qs
return expert_qs_dict
def _stack_quantization_states(
self, model: nn.Module, quant_state_dict: dict
) -> dict[str, dict[int, Any]]:
stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
# TODO: Change this lazy import to normal import
# after the checks are updated to run on a new version
from vllm.model_executor.models.utils import is_pp_missing_parameter
param_dict = dict(model.named_parameters())
for quant_param_name in quant_state_dict:
if is_pp_missing_parameter(quant_param_name, model):
continue
non_stacked_param_name = quant_param_name
shard_index = 0
for shard_name, (
weight_name,
index,
) in self.modules_mapping.inverse_packed_mapping.items():
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
# from being incorrectly identified as being present in
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
shard_pos = quant_param_name.find(shard_name)
can_correct_rename = (shard_pos > 0) and (
quant_param_name[shard_pos - 1] == "."
)
# If the quant_param_name is packed, it won't occur in the
# param_dict before renaming.
new_quant_param_name = quant_param_name.replace(shard_name, weight_name)
need_rename = (quant_param_name not in param_dict) and (
new_quant_param_name in param_dict
)
if can_correct_rename and need_rename:
shard_index = index
quant_param_name = new_quant_param_name
break
# Models like Clip/Siglip may skip some layers in initialization,
# causing unused quant_param_name in state_dict.
if quant_param_name not in param_dict:
continue
if quant_param_name not in stacked_quant_state_dict:
stacked_quant_state_dict[quant_param_name] = {}
stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[
non_stacked_param_name
]
return stacked_quant_state_dict
def _bind_quant_states_to_params(
self, model: nn.Module, stacked_quant_state_dict: dict
) -> None:
# save quant_states and offsets as the attributes of the parameters
param_dict = dict(model.named_parameters())
for param_name, param in param_dict.items():
if param_name in stacked_quant_state_dict:
quant_states = stacked_quant_state_dict[param_name]
# Dequantize double quantized values during weight loading.
self._dequantize_dq(quant_states)
set_weight_attrs(param, {"bnb_quant_state": quant_states})
if not isinstance(quant_states, dict):
continue
pack_ratio = getattr(param, "pack_factor", -1)
if pack_ratio == -1:
raise ValueError(f"pack_factor not set for parameter {param_name}.")
num_elements = [0] * len(quant_states)
for seq, quant_state in quant_states.items():
num_elements[seq] = math.prod(quant_state.shape) // pack_ratio
offsets = np.concatenate(([0], np.cumsum(num_elements)))
# Make torch infer_schema happy
offsets = torch.tensor(offsets).cpu()
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
if self.load_8bit:
set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)}
)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
self._verify_model_compatibility(model, model_config)
self._initialize_loader_state(model, model_config)
logger.info(
"Loading weights with BitsAndBytes quantization. May take a while ..."
)
qweight_iterator, quant_state_dict = self._get_quantized_weights_iterator(
model_config.model,
model_config.revision,
)
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(qweight_iterator)
# Some models may have weights loading tracker unimplemented.
if loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)
expert_quant_state_dict = self._fuse_moe_quant_states(model, quant_state_dict)
stacked_quant_state_dict = self._stack_quantization_states(
model, quant_state_dict
)
stacked_quant_state_dict = {
**expert_quant_state_dict,
**stacked_quant_state_dict,
}
self._bind_quant_states_to_params(model, stacked_quant_state_dict)
torch.cuda.empty_cache()
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)

View File

@@ -0,0 +1,307 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import glob
import os
import time
from collections.abc import Generator, Iterable
from typing import cast
import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
download_weights_from_hf,
fastsafetensors_weights_iterator,
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
get_quant_config,
maybe_download_from_modelscope,
multi_thread_pt_weights_iterator,
multi_thread_safetensors_weights_iterator,
np_cache_weights_iterator,
pt_weights_iterator,
safetensors_weights_iterator,
)
from vllm.tracing import instrument
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
logger = init_logger(__name__)
class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk."""
# default number of thread when enable multithread weight loading
DEFAULT_NUM_THREADS = 8
@dataclasses.dataclass
class Source:
"""A source for weights."""
model_or_path: str
"""The model ID or path."""
revision: str | None
"""The optional model revision."""
prefix: str = ""
"""A prefix to prepend to all weights."""
fall_back_to_pt: bool = True
"""Whether .pt weights can be used."""
allow_patterns_overrides: list[str] | None = None
"""If defined, weights will load exclusively using these patterns."""
counter_before_loading_weights: float = 0.0
counter_after_loading_weights: float = 0.0
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
extra_config = load_config.model_loader_extra_config
allowed_keys = {"enable_multithread_load", "num_threads"}
unexpected_keys = set(extra_config.keys()) - allowed_keys
if unexpected_keys:
raise ValueError(
f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{unexpected_keys}"
)
def _prepare_weights(
self,
model_name_or_path: str,
revision: str | None,
fall_back_to_pt: bool,
allow_patterns_overrides: list[str] | None,
) -> tuple[str, list[str], bool]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
model_name_or_path = (
maybe_download_from_modelscope(model_name_or_path, revision)
or model_name_or_path
)
is_local = os.path.isdir(model_name_or_path)
load_format = self.load_config.load_format
use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME
# First check for 'auto' format that mistral files format are present.
# This is to load mistral models with official format by default.
if load_format == "auto":
load_format = (
"mistral"
if len(
list_filtered_repo_files(
model_name_or_path=model_name_or_path,
allow_patterns=["consolidated*.safetensors"],
revision=revision,
)
)
> 0
else "hf"
)
# Some quantized models use .pt files for storing the weights.
if load_format == "hf":
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == "safetensors" or load_format == "fastsafetensors":
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == "mistral":
use_safetensors = True
allow_patterns = ["consolidated*.safetensors"]
index_file = "consolidated.safetensors.index.json"
elif load_format == "pt":
allow_patterns = ["*.pt"]
elif load_format == "npcache":
allow_patterns = ["*.bin"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
if fall_back_to_pt:
allow_patterns += ["*.pt"]
if allow_patterns_overrides is not None:
allow_patterns = allow_patterns_overrides
if not is_local:
hf_folder = download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
allow_patterns,
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
else:
hf_folder = model_name_or_path
hf_weights_files: list[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if use_safetensors:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path,
index_file,
self.load_config.download_dir,
revision,
)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file
)
else:
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`"
)
return hf_folder, hf_weights_files, use_safetensors
def _get_weights_iterator(
self, source: "Source"
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
extra_config = self.load_config.model_loader_extra_config
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path,
source.revision,
source.fall_back_to_pt,
source.allow_patterns_overrides,
)
if self.load_config.load_format == "npcache":
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
weights_iterator = np_cache_weights_iterator(
source.model_or_path,
self.load_config.download_dir,
hf_folder,
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
elif use_safetensors:
if self.load_config.load_format == "fastsafetensors":
weights_iterator = fastsafetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
else:
if extra_config.get("enable_multithread_load"):
weights_iterator = multi_thread_safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
max_workers=extra_config.get(
"num_threads", self.DEFAULT_NUM_THREADS
),
)
else:
weights_iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.safetensors_load_strategy,
)
else:
if extra_config.get("enable_multithread_load"):
weights_iterator = multi_thread_pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
max_workers=extra_config.get(
"num_threads", self.DEFAULT_NUM_THREADS
),
)
else:
weights_iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix.
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
def get_all_weights(
self,
model_config: ModelConfig,
model: nn.Module,
) -> Generator[tuple[str, torch.Tensor], None, None]:
primary_weights = DefaultModelLoader.Source(
model_config.model,
model_config.revision,
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
)
yield from self._get_weights_iterator(primary_weights)
secondary_weights = cast(
Iterable[DefaultModelLoader.Source],
getattr(model, "secondary_weights", ()),
)
for source in secondary_weights:
yield from self._get_weights_iterator(source)
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(
model_config.model,
model_config.revision,
fall_back_to_pt=True,
allow_patterns_overrides=None,
)
@instrument(span_name="Load weights")
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if model_config.quantization == "torchao":
quant_config = get_quant_config(model_config, self.load_config)
if (
hasattr(quant_config, "is_checkpoint_torchao_serialized")
and quant_config.is_checkpoint_torchao_serialized
and torchao_version_at_least("0.15.0")
):
self.load_config.safetensors_load_strategy = "torchao"
weights_to_load = {name for name, _ in model.named_parameters()}
all_weights = self.get_all_weights(model_config, model)
loaded_weights = model.load_weights(all_weights)
self.counter_after_loading_weights = time.perf_counter()
logger.info_once(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights - self.counter_before_loading_weights,
scope="local",
)
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)

View File

@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch.nn as nn
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
class DummyModelLoader(BaseModelLoader):
"""Model loader that will set model weights to random values."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(
f"Model loader extra config is not supported for "
f"load format {load_config.load_format}"
)
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model, model_config)

View File

@@ -0,0 +1,366 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Generator
import gguf
import regex as re
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.model_executor.model_loader.weight_utils import (
download_gguf,
get_gguf_extra_tensor_names,
get_gguf_weight_type_map,
gguf_quant_weights_iterator,
)
from vllm.transformers_utils.gguf_utils import detect_gguf_multimodal
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
class GGUFModelLoader(BaseModelLoader):
"""
Model loader that can load GGUF files. This is useful for loading models
that are quantized with GGUF and saved in the GGUF format. This loader
supports loading both full models and sharded models.
"""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(
f"Model loader extra config is not supported for "
f"load format {load_config.load_format}"
)
def _prepare_weights(self, model_config: ModelConfig):
model_name_or_path = model_config.model
if os.path.isfile(model_name_or_path):
return model_name_or_path
# repo id/filename.gguf
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
repo_id, filename = model_name_or_path.rsplit("/", 1)
return hf_hub_download(repo_id=repo_id, filename=filename)
# repo_id:quant_type
elif "/" in model_name_or_path and ":" in model_name_or_path:
repo_id, quant_type = model_name_or_path.rsplit(":", 1)
return download_gguf(
repo_id,
quant_type,
cache_dir=self.load_config.download_dir,
revision=model_config.revision,
ignore_patterns=self.load_config.ignore_patterns,
)
raise ValueError(
f"Unrecognised GGUF reference: {model_name_or_path} "
"(expected local file, <repo_id>/<filename>.gguf, "
"or <repo_id>:<quant_type>)"
)
def _get_gguf_weights_map(self, model_config: ModelConfig):
"""
GGUF uses this naming convention for their tensors from HF checkpoint:
`blk.N.BB.weight` and `blk.N.BB.bias`
where N signifies the block number of a layer, and BB signifies the
attention/mlp layer components.
See "Standardized tensor names" in
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
"""
config = model_config.hf_config
# Get text config to handle both nested (multimodal) and flat
# (text-only) config structures. For multimodal models like
# Gemma3Config, this returns config.text_config. For text-only
# models, this returns config itself.
text_config = config.get_text_config()
model_type = config.model_type
is_multimodal = (
hasattr(config, "vision_config") and config.vision_config is not None
)
gguf_to_hf_name_map = {}
sideload_params: list[re.Pattern] = []
# hack: ggufs have a different name than transformers
if model_type == "cohere":
model_type = "command-r"
if model_type == "gemma3_text":
# Gemma3 models use "gemma3_text" in HuggingFace but
# "gemma3" in GGUF architecture naming
model_type = "gemma3"
if model_type in ("deepseek_v3", "deepseek_v2"):
model_type = "deepseek2"
# GGUF layer map assumes that we will have a merged expert weights
# so we need to map them manually
for idx in range(config.num_hidden_layers):
gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
)
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
)
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
)
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
)
sideload_params.append(
re.compile(
f"model\\.layers\\.{idx}"
r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
)
)
if model_type in ("qwen2_moe", "qwen3_moe"):
model_type = model_type.replace("_", "")
# GGUF layer map assumes that we will have a merged expert weights
# so we need to map them manually
for idx in range(config.num_hidden_layers):
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
)
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
)
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
)
sideload_params.append(
re.compile(
f"model\\.layers\\.{idx}"
r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
)
)
arch = None
for key, value in gguf.MODEL_ARCH_NAMES.items():
if value == model_type:
arch = key
break
if arch is None:
raise RuntimeError(f"Unknown gguf model_type: {model_type}")
text_num_layers = text_config.num_hidden_layers
text_name_map = gguf.get_tensor_name_map(arch, text_num_layers)
if is_multimodal:
mm_proj_arch = gguf.MODEL_ARCH.MMPROJ
vision_num_layers = config.vision_config.num_hidden_layers
vision_name_map = gguf.get_tensor_name_map(mm_proj_arch, vision_num_layers)
else:
vision_name_map = None
# Create dummy model to extract parameter names
# For multimodal: use AutoModelForImageTextToText to get
# language + vision + projector params
# For text-only: use AutoModelForCausalLM to get language model params
auto_cls = (
AutoModelForImageTextToText if is_multimodal else AutoModelForCausalLM
)
with torch.device("meta"):
dummy_model = auto_cls.from_config(
config, trust_remote_code=model_config.trust_remote_code
)
state_dict = dummy_model.state_dict()
if hf_checkpoint_map := getattr(
dummy_model, "_checkpoint_conversion_mapping", None
):
def revert_hf_rename(name: str) -> str:
for original_name, hf_name in hf_checkpoint_map.items():
if hf_name in name:
name = name.replace(hf_name, original_name).lstrip("^")
return name
state_dict = {
revert_hf_rename(name): tensor for name, tensor in state_dict.items()
}
def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
"""
Map HuggingFace parameter name to GGUF tensor name.
This function handles the mismatch between HF parameter naming
conventions and gguf-py's expected format:
1. Strips 'model.' prefix (common in multimodal models)
2. Converts '_weight' suffix to '.weight' (Gemma3 compatibility)
3. Searches vision_name_map for multimodal parameters
4. Falls back to text_name_map for language model parameters
Args:
hf_name: Full HuggingFace parameter name (e.g.,
'model.multi_modal_projector.mm_soft_emb_norm.weight')
Returns:
GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
or None if no mapping found
"""
# Strip 'language_model.' prefix for multimodal models - gguf-py
# tensor mappings expect parameter names without this prefix.
# Note: 'model.' prefix should be KEPT for text-only models as
# gguf-py expects it.
if hf_name.startswith("language_model."):
hf_name = hf_name[15:] # Remove 'language_model.'
# Parse parameter name and suffix
if hf_name.endswith((".weight", ".bias")):
base_name, suffix = hf_name.rsplit(".", 1)
else:
base_name, suffix = hf_name, ""
# Handle '_weight' suffix (Gemma3 naming: parameter ends with
# '_weight' instead of '.weight')
if base_name.endswith("_weight"):
base_name = base_name[:-7] # Remove '_weight'
suffix = "weight"
gguf_name = None
# Priority 1: Search vision/projector parameters for multimodal models
if vision_name_map is not None:
gguf_name = vision_name_map.get_name(base_name)
# Priority 2: Search text backbone parameters
if gguf_name is None:
gguf_name = text_name_map.get_name(base_name)
if gguf_name is None:
return None
return gguf_name + "." + suffix
# Build mapping and track unmapped parameters
unmapped_params = []
for hf_name in state_dict:
gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name)
# Track mapping success
if gguf_name_with_suffix is not None:
gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name
logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name)
elif hf_name not in gguf_to_hf_name_map.values():
# Parameter not in manual overrides either
unmapped_params.append(hf_name)
# All parameters (except those initialized by other means) must be mapped:
# both vision/projector and backbone
if unmapped_params:
unmapped_params = list(
filter(
lambda x: not any(re.fullmatch(p, x) for p in sideload_params),
unmapped_params,
)
)
if unmapped_params:
raise RuntimeError(
f"Failed to map GGUF parameters "
f"({len(unmapped_params)}): "
f"{unmapped_params}"
)
return gguf_to_hf_name_map
def _get_gguf_weight_type(
self,
model_config: ModelConfig,
model_name_or_path: str,
gguf_to_hf_name_map: dict[str, str],
) -> dict[str, str]:
weight_type_map = get_gguf_weight_type_map(
model_name_or_path, gguf_to_hf_name_map
)
is_multimodal = hasattr(model_config.hf_config, "vision_config")
if is_multimodal:
mmproj_file = detect_gguf_multimodal(model_name_or_path)
assert mmproj_file is not None, (
"Could not find mm_proj file for multimodal GGUF model"
)
logger.info("Loading extra mm_proj weights from %s...", mmproj_file)
mm_proj_weight_type_map = get_gguf_weight_type_map(
mmproj_file, gguf_to_hf_name_map
)
weight_type_map.update(mm_proj_weight_type_map)
return weight_type_map
def _get_weights_iterator(
self,
model_config: ModelConfig,
model_name_or_path: str,
gguf_to_hf_name_map: dict[str, str],
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""
Iterate over GGUF model weights, loading from both main model file and
mmproj.gguf for multimodal Gemma3 models.
For Gemma3 multimodal GGUF models:
- Main file (gemma-3-*.gguf): Language model weights (model.*)
- mmproj file (mmproj*.gguf): Vision tower + projector weights (v.*, mm.*)
Yields:
Tuples of (parameter_name, tensor) for all model weights
"""
hf_config = model_config.hf_config
is_multimodal = hasattr(hf_config, "vision_config")
if is_multimodal:
# Load mm_proj (mm_encoder + projector) for multimodal weights
mmproj_file = detect_gguf_multimodal(model_name_or_path)
assert mmproj_file is not None, (
"Could not find mm_proj file for multimodal GGUF model"
)
yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)
yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
local_model_path = self._prepare_weights(model_config)
gguf_weights_map = self._get_gguf_weights_map(model_config)
model.load_weights(
self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
)
def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module:
device_config = vllm_config.device_config
local_model_path = self._prepare_weights(model_config)
gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights
if "lm_head.weight" in get_gguf_extra_tensor_names(
local_model_path, gguf_weights_map
):
model_config.hf_config.update({"tie_word_embeddings": True})
weight_type_map = self._get_gguf_weight_type(
model_config, local_model_path, gguf_weights_map
)
# filter out unquantized modules to skip
unquant_names = [
name.removesuffix(".weight")
for name, weight_type in weight_type_map.items()
if weight_type in ("F32", "F16", "BF16") and name.endswith(".weight")
]
logger.debug(
"GGUF unquantized modules: %s",
unquant_names,
)
vllm_config.quant_config.unquantized_modules.extend(unquant_names)
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config, prefix=prefix)
self.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)
return model

View File

@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Layerwise weight reloading utilities for vLLM.
This module provides functionality to reload model weights layer-by-layer,
which is useful for weight updates without full model reconstruction.
Limitations:
1. Composition with CPU offloading has not been implemented
2. Reloading Attention/MLA weights (q_scale, k_scale, v_scale) has not been implemented
3. Tied parameters will only reflect processing from one of the parent layers (for
example, only processing from embed_tokens will have an effect)
4. This design assumes that the number of weights loaded from disk is the same as the
number of weights created at model init time. This is not true for quant methods
which (1) pad weights or (2) load qkv weights into the same parameter. Both of these
cases are non-issues for today's quant methods, but future quantizations may cause
reloading to fail
"""
__all__ = [
"record_metadata_for_reloading",
"initialize_layerwise_reload",
"finalize_layerwise_reload",
"set_torchao_reload_attrs",
"support_quantized_model_reload_from_hp_weights",
]
from .layerwise import (
finalize_layerwise_reload,
initialize_layerwise_reload,
record_metadata_for_reloading,
)
from .torchao_decorator import (
set_torchao_reload_attrs,
support_quantized_model_reload_from_hp_weights,
)

View File

@@ -0,0 +1,275 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable
from functools import wraps
from weakref import WeakKeyDictionary
import torch
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from .meta import (
capture_layer_to_meta,
get_numel_loaded,
materialize_layer,
restore_layer_on_meta,
)
from .types import LayerReloadingInfo
from .utils import get_layer_params_buffers, get_layer_size, get_layer_tensors
logger = init_logger(__name__)
__all__ = [
"get_layerwise_info",
"record_metadata_for_reloading",
"initialize_layerwise_reload",
"finalize_layerwise_reload",
]
# Global dict storing information used for layerwise restoring, loading, and processing.
# For more information regarding what info is stored when, see `LayerReloadingInfo`
#
# Use a weak ref dictionary so that modules can be freed when the model is freed.
# Values are sanitized from references to the layer key in order to avoid circular refs
LAYERWISE_INFO: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
WeakKeyDictionary()
)
def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
"""
Get information related to restoring and layerwise processing. If no previous
information existed, a new entry is constructed
"""
if layer not in LAYERWISE_INFO:
LAYERWISE_INFO[layer] = LayerReloadingInfo()
return LAYERWISE_INFO[layer]
def record_metadata_for_reloading(model: torch.nn.Module):
"""
Record layer metadata needed for later reloading.
Stores parameter and buffer metadata as meta tensors for restoration.
Must be called before `initialize_layerwise_reload`.
"""
for layer in model.modules():
info = get_layerwise_info(layer)
info.restore_metadata = capture_layer_to_meta(layer)
@torch.no_grad()
def initialize_layerwise_reload(model: torch.nn.Module):
"""
Set up layerwise weight loading with deferred processing.
Must be called after `record_metadata_for_reloading`. This function:
1. Saves current kernel tensors for later copying
2. Restores layer parameters/buffers from metadata (on meta device)
3. Wraps weight loaders to defer processing until all weights are loaded
When all weights for a layer are loaded, the wrapped loaders will:
1. Materialize the layer onto the target device
2. Load all cached weights
3. Run quantization processing if applicable
4. Copy processed values back to original tensor storage
"""
# disable torchao reloading to avoid infinite recursion
model._original_do_torchao_reload = getattr(model, "_do_torchao_reload", False)
model._do_torchao_reload = False
for layer in model.modules():
info = get_layerwise_info(layer)
# Skip if the layer has already been initialized
if info.can_process():
continue
# Save current tensors for later copying
info.kernel_tensors = get_layer_params_buffers(layer)
# Restore layer parameters/buffers onto meta device
restore_layer_on_meta(layer, info)
# Track loading progress to determine when to process/copy
info.load_numel = 0
info.load_numel_total = get_layer_size(layer)
# Wrap each parameter's weight loader
# Note that nested wrapping will occur for shared tensors
for name, tensor in get_layer_tensors(layer).items():
if _get_weight_loader(tensor).__name__ != "online_process_loader":
tensor.weight_loader = make_online_process_loader(layer, name)
def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
"""Create a wrapped weight loader that defers processing."""
info = get_layerwise_info(layer)
param = getattr(layer, param_name)
original_loader = _get_original_loader(param)
loader_signature = inspect.signature(original_loader)
@wraps(original_loader, assigned=("__doc__", "__annotations__"))
def online_process_loader(*args, **kwargs):
if not info.can_process():
# Unfortunately, some qconfigs are set up to load the same weight
# multiple times. For example, CT_WNA16 loads `weight_shape` for
# each of the qkv partitions. This results in layers loading extra
# weights (beyond load_numel_total) after it's already processed.
#
# Best solution is to ensure that `load_numel_total` reflects the
# actual number of weights loaded, either by modifying qconfigs to
# create as many weights as loaded (see padding issue as well)
# or maybe capturing how many weights are loaded on first pass
#
# For now, `load_numel_total` is still safe to use as long as
# there's no way to reach `load_numel_total` without loading all
# necessary weights. `weight_shape` is very small, so this is safe.
# see Limitations(4)
logger.debug("%s: Excessive loading", layer.__class__.__name__)
return
# Bind and normalize arguments
bound_args = loader_signature.bind(*args, **kwargs)
bound_args.apply_defaults()
# Cache loaded weights, track loading progress
info.loaded_weights.append((param_name, bound_args))
num_loaded, ret = get_numel_loaded(original_loader, bound_args)
info.load_numel += num_loaded
logger.debug(
"%s: %d / %d",
layer.__class__.__name__,
info.load_numel,
info.load_numel_total,
)
# Process and copy when all weights are loaded
if info.load_numel >= info.load_numel_total and not isinstance( # type: ignore[operator]
layer, (Attention, MLAAttention)
):
_layerwise_process(layer, info)
return ret
return online_process_loader
def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
"""
Remove the outermost layer of weight loading wrappers.
This function should be applied after `initialize_layerwise_reload` is applied
unwrap the layerwise weight loaders.
Also processes Attention/MLA layers, which must be processed after all other layers
"""
model._do_torchao_reload = model._original_do_torchao_reload
for layer in model.modules():
info = get_layerwise_info(layer)
# Attention/MLA layers are processed after all other layers
if isinstance(layer, (Attention, MLAAttention)):
if info.load_numel > 0:
raise NotImplementedError(
"Layerwise reloading of Q/K/V scale weights is not implemented yet"
)
else:
_place_kernel_tensors(layer, info)
layer.process_weights_after_loading(model_config.dtype)
# No weights were loaded, place kernel tensors back
elif info.can_process() and info.load_numel <= 0:
_place_kernel_tensors(layer, info)
# Process non-attention layers which did not load all elements. This can happen
# if the created weight has extra padding elements which are not loaded
# Having too many of these delayed layers can lead to execess memory usage
# see Limitations(4)
elif info.load_numel > 0 and info.load_numel < info.load_numel_total: # type: ignore[operator]
logger.debug("%s: Delayed processing", layer.__class__.__name__)
_layerwise_process(layer, info)
info.reset()
def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
"""
Finalize layer loading after all weights have been cached.
This function:
1. Materializes the layer onto the target device
2. Loads all cached weights
3. Runs quantization processing if applicable
4. Copies processed values back to original tensor storage
"""
# Materialize layer tensors onto device
materialize_layer(layer)
# Reset FP8 online quantization flag so process_weights_after_loading
# will run again during reload
if hasattr(layer, "_already_called_process_weights_after_loading"):
delattr(layer, "_already_called_process_weights_after_loading")
# Unwrap layerwise loading wrappers
for param in get_layer_tensors(layer).values():
param.weight_loader = _get_original_loader(param)
# Load all cached weights into materialized layer (using original loaders)
for name, args in info.loaded_weights:
param = getattr(layer, name)
args.arguments["param"] = param
param.weight_loader(*args.args, **args.kwargs)
# Process weights (quantization, repacking, etc.)
# Attention/MLA are processed in `finalize_layerwise_reload`
quant_method = getattr(layer, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
quant_method.process_weights_after_loading(layer)
# Copy processed values into original tensor storage (preserves cudagraph refs)
# this code is a no-op if not reloading (because kernel tensors is empty)
parameters, buffers = info.kernel_tensors
for name, param in parameters.items():
param.data.copy_(getattr(layer, name))
for name, buffer in buffers.items():
buffer.data.copy_(getattr(layer, name))
_place_kernel_tensors(layer, info)
info.reset()
logger.debug("%s: Processed", layer.__class__.__name__)
def _get_original_loader(tensor: torch.Tensor) -> Callable:
"""Return the weight loader with any layerwise wrappers removed"""
loader = _get_weight_loader(tensor)
while loader.__name__ == "online_process_loader":
loader = loader.__wrapped__ # type: ignore[union-attr]
return loader
def _get_weight_loader(tensor: torch.Tensor):
return getattr(tensor, "weight_loader", default_weight_loader)
def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
for name in get_layer_tensors(layer):
delattr(layer, name)
parameters, buffers = info.kernel_tensors
for name, param in parameters.items():
layer.register_parameter(name, param)
for name, buffer in buffers.items():
layer.register_buffer(name, buffer)

View File

@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable
import torch
from torch.utils._python_dispatch import TorchDispatchMode
from .sanitize import restore_layer_refs, sanitize_layer_refs
from .types import LayerReloadingInfo, LayerTensors
from .utils import get_layer_params_buffers, get_layer_tensors
__all__ = [
"to_meta_tensor",
"materialize_meta_tensor",
"capture_layer_to_meta",
"restore_layer_on_meta",
"materialize_layer",
"get_numel_loaded",
]
SKIP_MODULES: set[str] = {"HadamardTransform"}
SKIP_TENSORS: set[str] = {
"_expert_map",
"expert_mask",
"expert_global_to_physical",
"expert_physical_to_global",
"expert_local_to_global",
}
def to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor:
"""Convert a tensor to a meta tensor while preserving class and attributes."""
meta_tensor = tensor.data.to("meta")
meta_tensor.__class__ = tensor.__class__
meta_tensor.__dict__ = tensor.__dict__.copy()
return meta_tensor
def materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor:
"""
Materialize a meta tensor into an actual tensor on the current device.
Should be called within the torch device context for the given rank.
"""
tensor = torch.empty_strided(
size=tuple(meta_tensor.size()),
stride=tuple(meta_tensor.stride()),
dtype=meta_tensor.dtype,
requires_grad=False,
)
tensor.__class__ = meta_tensor.__class__
tensor.__dict__ = meta_tensor.__dict__.copy()
return tensor
def capture_layer_to_meta(layer: torch.nn.Module) -> LayerTensors:
if layer.__class__.__name__ in SKIP_MODULES:
return ({}, {})
params, buffers = get_layer_params_buffers(layer)
return (
{
name: sanitize_layer_refs(to_meta_tensor(param), layer)
for name, param in params.items()
if name not in SKIP_TENSORS
},
{
name: sanitize_layer_refs(to_meta_tensor(buffer), layer)
for name, buffer in buffers.items()
if name not in SKIP_TENSORS
},
)
def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
"""Restore a layer to model format with tensors on the meta device"""
if layer.__class__.__name__ in SKIP_MODULES:
return
for name in get_layer_tensors(layer):
if name not in SKIP_TENSORS:
delattr(layer, name)
restore_params, restore_buffers = info.restore_metadata
for name, param in restore_params.items():
if name not in SKIP_TENSORS:
param = restore_layer_refs(param, layer)
layer.register_parameter(name, param)
for name, buffer in restore_buffers.items():
if name not in SKIP_TENSORS:
buffer = restore_layer_refs(buffer, layer)
layer.register_buffer(name, buffer)
def materialize_layer(layer: torch.nn.Module) -> None:
"""Materialize all meta tensors in a layer to actual tensors."""
if layer.__class__.__name__ in SKIP_MODULES:
return
for name, tensor in get_layer_tensors(layer).items():
if name not in SKIP_TENSORS:
setattr(layer, name, materialize_meta_tensor(tensor))
class MetaCopyCounter(TorchDispatchMode):
"""
Tracks total number of elements modified with `copy_`.
Useful for keeping track of weight loading where underlying weights can be
arbitrarily transformed (such as with `narrow`) before calling copy.
Note: Assumes that copy kwargs are not used.
"""
def __init__(self):
super().__init__()
self.copied_numel = 0
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
if kwargs is None:
kwargs = {}
if func is torch.ops.aten.copy_.default and args[0].device.type == "meta":
assert args[0].numel() == args[1].numel()
self.copied_numel += args[0].numel()
return func(*args, **kwargs)
def get_numel_loaded(
weight_loader: Callable, args: inspect.BoundArguments
) -> tuple[int, object]:
"""
Determine how many elements would be loaded by a weight loader call.
:param weight loader: used to load weights
:param args: bound arguments to weight loader
:return: number of elements loaded by the weight loader, the return value of the
weight loader
"""
assert args.arguments["param"].device.type == "meta"
with MetaCopyCounter() as counter:
return_value = weight_loader(*args.args, **args.kwargs)
return counter.copied_numel, return_value

View File

@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import MethodType
import torch
__all__ = ["sanitize_layer_refs", "restore_layer_refs"]
layer_ref_sentinel = object()
def sanitize_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
"""
Removes references to layer held by tensor attributes. Specifically, removes the
`__self__` attribute of weight loader methods attached to the tensor.
Used by `capture_layer_to_meta` to avoid circular references to layers in
`LAYERWISE_INFO`, leading to modules never being cleaned up. Without sanitation,
tensors will reference layers, and the WeakKeyDictionary will never evict entries,
even when the model is deleted.
:param tensor: tensor to be sanitized
:param layer: layer whose references should be removed
:return: sanitized tensor
"""
for key, value in tensor.__dict__.items():
if isinstance(value, MethodType) and value.__self__ is layer:
tensor.__dict__[key] = value.__func__.__get__(layer_ref_sentinel)
return tensor
def restore_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
"""
Restores references to layer held by tensor attributes.
Used by `restore_layer_on_meta` to add back layer references, allowing for proper
weight loading.
:param tensor: tensor to be sanitized
:param layer: layer whose references should be removed
:return: sanitized tensor
"""
for key, value in tensor.__dict__.items():
if isinstance(value, MethodType) and value.__self__ is layer_ref_sentinel:
tensor.__dict__[key] = value.__func__.__get__(layer)
return tensor

View File

@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Iterable
from functools import wraps
from types import FunctionType
from typing import TYPE_CHECKING
import torch
from vllm.config import ModelConfig
from .layerwise import (
finalize_layerwise_reload,
initialize_layerwise_reload,
)
if TYPE_CHECKING:
from vllm.model_executor.models.utils import AutoWeightsLoader
__all__ = ["set_torchao_reload_attrs", "support_quantized_model_reload_from_hp_weights"]
def set_torchao_reload_attrs(model: torch.nn.Module, model_config: ModelConfig):
model._do_torchao_reload = True
model._model_config = model_config
def support_quantized_model_reload_from_hp_weights(original_load_weights: FunctionType):
"""
Decorator for `load_weights` method for AutoWeightsLoader.load_weights to support
reloading high precision (bfloat16/float16/float32) weight for an already quantized
model, this involves restoring the weights to a high precision weights and
then online quantize the weights.
Only applies to torchao quantized models. Assumes that all model weights are
loaded within a single weights iterator (cannot perform batched updates)
"""
@wraps(original_load_weights)
def patched_model_load_weights(
self: "AutoWeightsLoader",
weights: Iterable[tuple[str, torch.Tensor]],
*args,
**kwargs,
):
model = self.module
if not getattr(model, "_do_torchao_reload", False):
return original_load_weights(self, weights, *args, **kwargs)
initialize_layerwise_reload(model)
loaded_weights = original_load_weights(self, weights, *args, **kwargs)
finalize_layerwise_reload(model, model._model_config)
return loaded_weights
return patched_model_load_weights

View File

@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from inspect import BoundArguments
import torch
__all__ = ["LayerTensors", "LayerReloadingInfo"]
# encodes both parameters and buffers separately
LayerTensors = tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]
@dataclass
class LayerReloadingInfo:
# model format (meta), populated by `record_metadata_for_reloading`
restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))
# kernel format (device)
kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {}))
# track how many restored elements are ready for loading
load_numel: int = 0
load_numel_total: int | None = None
# stores arguments and tensors ready for loading
loaded_weights: list[tuple[str, BoundArguments]] = field(default_factory=list)
def reset(self):
self.__init__(restore_metadata=self.restore_metadata) # type: ignore[misc]
def can_process(self) -> bool:
return self.load_numel_total is not None

View File

@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from .types import LayerTensors
__all__ = [
"get_layer_tensors",
"get_layer_params_buffers",
"get_layer_size",
]
def get_layer_tensors(layer: torch.nn.Module) -> dict[str, torch.Tensor]:
"""Get all parameters and buffers from a module as a dict."""
params, buffers = get_layer_params_buffers(layer)
return params | buffers
def get_layer_params_buffers(layer: torch.nn.Module) -> LayerTensors:
"""Get all parameters and buffers of a module as a tuple of dicts."""
return (
{name: param for name, param in layer._parameters.items() if param is not None},
{name: buffer for name, buffer in layer._buffers.items() if buffer is not None},
)
def get_layer_size(layer: torch.nn.Module) -> int:
"""Calculate total number of elements across all tensors in a layer."""
return sum(tensor.numel() for tensor in get_layer_tensors(layer).values())

View File

@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Generator
import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
download_weights_from_hf,
runai_safetensors_weights_iterator,
)
from vllm.transformers_utils.runai_utils import is_runai_obj_uri, list_safetensors
class RunaiModelStreamerLoader(BaseModelLoader):
"""
Model loader that can load safetensors
files from local FS or S3 bucket.
"""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
self._is_distributed = False
if load_config.model_loader_extra_config:
extra_config = load_config.model_loader_extra_config
if "distributed" in extra_config and isinstance(
extra_config.get("distributed"), bool
):
self._is_distributed = extra_config.get("distributed")
if "concurrency" in extra_config and isinstance(
extra_config.get("concurrency"), int
):
os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
extra_config.get("concurrency")
)
if "memory_limit" in extra_config and isinstance(
extra_config.get("memory_limit"), int
):
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
extra_config.get("memory_limit")
)
runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
def _prepare_weights(
self, model_name_or_path: str, revision: str | None
) -> list[str]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
is_object_storage_path = is_runai_obj_uri(model_name_or_path)
is_local = os.path.isdir(model_name_or_path)
safetensors_pattern = "*.safetensors"
index_file = SAFE_WEIGHTS_INDEX_NAME
hf_folder = (
model_name_or_path
if (is_local or is_object_storage_path)
else download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
[safetensors_pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
)
hf_weights_files = list_safetensors(path=hf_folder)
if not is_local and not is_object_storage_path:
download_safetensors_index_file_from_hf(
model_name_or_path, index_file, self.load_config.download_dir, revision
)
if not hf_weights_files:
raise RuntimeError(
f"Cannot find any safetensors model weights with `{model_name_or_path}`"
)
return hf_weights_files
def _get_weights_iterator(
self, model_or_path: str, revision: str
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_weights_files = self._prepare_weights(model_or_path, revision)
return runai_safetensors_weights_iterator(
hf_weights_files, self.load_config.use_tqdm_on_load, self._is_distributed
)
def download_model(self, model_config: ModelConfig) -> None:
"""Download model if necessary"""
self._prepare_weights(model_config.model, model_config.revision)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
"""Load weights into a model."""
model_weights = model_config.model
if model_weights_override := model_config.model_weights:
model_weights = model_weights_override
model.load_weights(
self._get_weights_iterator(model_weights, model_config.revision)
)

View File

@@ -0,0 +1,214 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import collections
import glob
import os
import time
from collections.abc import Generator
from typing import Any
import torch
from torch import nn
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf,
runai_safetensors_weights_iterator,
)
from vllm.transformers_utils.s3_utils import glob as s3_glob
from vllm.transformers_utils.utils import is_s3
logger = init_logger(__name__)
class ShardedStateLoader(BaseModelLoader):
"""
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/offline_inference/save_sharded_state.py` for creating a sharded
checkpoint.
"""
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
extra_config = (
{}
if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy()
)
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
if extra_config:
raise ValueError(
f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{load_config.model_loader_extra_config.keys()}"
)
@staticmethod
def _filter_subtensors(
tensors: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = (
collections.defaultdict(list)
)
for key, tensor in tensors.items():
if tensor.numel():
ptr = tensor.untyped_storage().data_ptr()
same_storage_groups[tensor.device, ptr].append((key, tensor))
def get_end_ptr(tensor: torch.Tensor) -> int:
return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
result: dict[str, torch.Tensor] = {}
for group in same_storage_groups.values():
for k, t in group:
a, b = t.data_ptr(), get_end_ptr(t)
for k2, t2 in group:
if not t2.is_contiguous():
continue
a2, b2 = t2.data_ptr(), get_end_ptr(t2)
if a < a2 or b2 < b:
continue
if a2 < a or b < b2 or not t.is_contiguous():
break # t2 covers strictly more memory than t.
if k2 < k:
# Same tensors, keep the one with the smaller key.
break
else:
result[k] = t
return result
def _prepare_weights(self, model_name_or_path: str, revision: str | None):
if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
return model_name_or_path
else:
allow_patterns = ["*.safetensors"]
return download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
allow_patterns,
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
from vllm.distributed import get_tensor_model_parallel_rank
model_weights = model_config.model
if model_weights_override := model_config.model_weights:
model_weights = model_weights_override
local_model_path = model_weights
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
local_model_path,
self.pattern.format(rank=rank, part="*"),
)
filepaths = []
if is_s3(local_model_path):
file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
filepaths = s3_glob(path=local_model_path, allow_pattern=[file_pattern])
else:
filepaths = glob.glob(pattern)
if not filepaths:
# TODO: support un-sharded checkpoints too
raise ValueError(
f"Could not find checkpoint files '{pattern}', only "
f"pre-sharded checkpoints are currently supported!"
)
state_dict = self._filter_subtensors(model.state_dict())
counter_before_loading_weights = time.perf_counter()
for key, tensor in self.iterate_over_files(filepaths):
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data = state_dict[key].data
param_shape = state_dict[key].shape
for dim, size in enumerate(tensor.shape):
if size < param_shape[dim]:
param_data = param_data.narrow(dim, 0, size)
if tensor.shape != param_shape:
logger.warning(
"loading tensor of shape %s into parameter '%s' of shape %s",
tensor.shape,
key,
param_shape,
)
param_data.copy_(tensor)
state_dict.pop(key)
counter_after_loading_weights = time.perf_counter()
logger.info_once(
"Loading weights took %.2f seconds",
counter_after_loading_weights - counter_before_loading_weights,
scope="local",
)
if state_dict:
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
def iterate_over_files(
self, paths
) -> Generator[tuple[str, torch.Tensor], None, None]:
if self.load_config.load_format == "runai_streamer_sharded":
yield from runai_safetensors_weights_iterator(paths, True)
else:
from safetensors.torch import safe_open
for path in paths:
with safe_open(path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
tensor = f.get_tensor(key)
yield key, tensor
@staticmethod
def save_model(
model: torch.nn.Module,
path: str,
pattern: str | None = None,
max_size: int | None = None,
) -> None:
from safetensors.torch import save_file
from vllm.distributed import get_tensor_model_parallel_rank
if pattern is None:
pattern = ShardedStateLoader.DEFAULT_PATTERN
rank = get_tensor_model_parallel_rank()
part_idx = 0
total_size = 0
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
state_dict_part: dict[str, torch.Tensor] = {}
for key, tensor in state_dict.items():
param_size = tensor.nelement() * tensor.element_size()
if max_size is not None and total_size + param_size > max_size:
filename = pattern.format(rank=rank, part=part_idx)
save_file(
state_dict_part,
os.path.join(path, filename),
)
part_idx += 1
total_size = 0
state_dict_part = {}
state_dict_part[key] = tensor
total_size += param_size
if len(state_dict_part) > 0:
filename = pattern.format(rank=rank, part=part_idx)
save_file(
state_dict_part,
os.path.join(path, filename),
)

View File

@@ -0,0 +1,793 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import contextvars
import dataclasses
import json
import os
import tempfile
import threading
import time
from collections.abc import Generator, MutableMapping
from dataclasses import asdict, dataclass, field, fields
from typing import TYPE_CHECKING, Any, ClassVar
import regex as re
import torch
from huggingface_hub import snapshot_download
from torch import nn
from torch.utils._python_dispatch import TorchDispatchMode
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import ModelConfig, ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.import_utils import PlaceholderModule
if TYPE_CHECKING:
from vllm.engine.arg_utils import EngineArgs
try:
from tensorizer import (
DecryptionParams,
EncryptionParams,
TensorDeserializer,
TensorSerializer,
)
from tensorizer.stream_io import open_stream
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
except ImportError:
tensorizer = PlaceholderModule("tensorizer")
DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer")
TensorSerializer = tensorizer.placeholder_attr("TensorSerializer")
open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
__all__ = [
"EncryptionParams",
"DecryptionParams",
"TensorDeserializer",
"TensorSerializer",
"open_stream",
"convert_bytes",
"get_mem_usage",
"no_init_or_tensor",
"TensorizerConfig",
]
logger = init_logger(__name__)
def is_valid_deserialization_uri(uri: str | None) -> bool:
if uri:
scheme = uri.lower().split("://")[0]
return scheme in {"s3", "http", "https"} or os.path.exists(uri)
return False
def tensorizer_kwargs_arg(value):
loaded = json.loads(value)
if not isinstance(loaded, dict):
raise argparse.ArgumentTypeError(
f"Not deserializable to dict: {value}. serialization_kwargs and "
f"deserialization_kwargs must be "
f"deserializable from a JSON string to a dictionary. "
)
return loaded
class MetaTensorMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if func._schema.name == "aten::empty" and "device" not in kwargs:
kwargs["device"] = "meta"
return func(*args, **kwargs)
def meta_tensor_mode(
loading_code=None,
):
if loading_code is None:
return _NoInitOrTensorImpl.context_manager()
elif callable(loading_code):
with _NoInitOrTensorImpl.context_manager():
return loading_code()
else:
raise TypeError(
"expected a callable to evaluate,"
" or None if being used as a context manager;"
f' got an object of type "{type(loading_code).__name__}" instead.'
)
class _NoInitOrTensorImpl:
_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
_MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)
is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False)
_count_active: int = 0
_count_active_lock = threading.Lock()
@classmethod
@contextlib.contextmanager
def context_manager(cls):
if cls.is_active.get():
yield
return
with cls._count_active_lock:
cls._count_active += 1
if cls._count_active == 1:
for mod in cls._MODULES:
mod.reset_parameters = cls._disable(mod.reset_parameters)
reset_token = cls.is_active.set(True)
try:
with MetaTensorMode():
yield
finally:
cls.is_active.reset(reset_token)
with cls._count_active_lock:
cls._count_active -= 1
if cls._count_active == 0:
for mod, original in cls._MODULE_ORIGINALS:
mod.reset_parameters = original
@staticmethod
def _disable(func):
def wrapper(*args, **kwargs):
if not _NoInitOrTensorImpl.is_active.get():
return func(*args, **kwargs)
return wrapper
@dataclass
class TensorizerConfig(MutableMapping):
tensorizer_uri: str | None = None
tensorizer_dir: str | None = None
vllm_tensorized: bool | None = None
verify_hash: bool | None = None
num_readers: int | None = None
encryption_keyfile: str | None = None
s3_access_key_id: str | None = None
s3_secret_access_key: str | None = None
s3_endpoint: str | None = None
lora_dir: str | None = None
stream_kwargs: dict[str, Any] | None = None
serialization_kwargs: dict[str, Any] | None = None
deserialization_kwargs: dict[str, Any] | None = None
_extra_serialization_attrs: dict[str, Any] | None = field(init=False, default=None)
model_class: type[torch.nn.Module] | None = field(init=False, default=None)
hf_config: PretrainedConfig | None = field(init=False, default=None)
dtype: str | torch.dtype | None = field(init=False, default=None)
_is_sharded: bool = field(init=False, default=False)
_fields: ClassVar[tuple[str, ...]]
_keys: ClassVar[frozenset[str]]
"""Configuration class for Tensorizer settings.
These settings configure the behavior of model serialization and
deserialization using Tensorizer.
Attributes:
tensorizer_uri: Path to serialized model tensors. Can be a local file
path or a S3 URI. This is a required field unless lora_dir is
provided and the config is meant to be used for the
`tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
`lora_dir` is passed to this object's initializer, this is
a required argument.
tensorizer_dir: Path to a directory containing serialized model tensors,
and all other potential model artifacts to load the model, such as
configs and tokenizer files. Can be passed instead of
`tensorizer_uri` where the `model.tensors` file will be assumed
to be in this directory.
vllm_tensorized: If True, indicates that the serialized model is a
vLLM model. This is used to determine the behavior of the
TensorDeserializer when loading tensors from a serialized model.
It is far faster to deserialize a vLLM model as it utilizes
tensorizer's optimized GPU loading. Note that this is now
deprecated, as serialized vLLM models are now automatically
inferred as vLLM models.
verify_hash: If True, the hashes of each tensor will be verified
against the hashes stored in the metadata. A `HashMismatchError`
will be raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is `None`, which will dynamically set
the number of readers based on the number of available
resources and model size. This greatly increases performance.
encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means
no decryption. See the example script in
examples/others/tensorize_vllm_model.py.
s3_access_key_id: The access key for the S3 bucket. Can also be set via
the S3_ACCESS_KEY_ID environment variable.
s3_secret_access_key: The secret access key for the S3 bucket. Can also
be set via the S3_SECRET_ACCESS_KEY environment variable.
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
S3_ENDPOINT_URL environment variable.
lora_dir: Path to a directory containing LoRA adapter artifacts for
serialization or deserialization. When serializing LoRA adapters
this is the only necessary parameter to pass to this object's
initializer.
"""
def __post_init__(self):
# check if the configuration is for a sharded vLLM model
self._is_sharded = (
isinstance(self.tensorizer_uri, str)
and re.search(r"%0\dd", self.tensorizer_uri) is not None
)
if self.tensorizer_dir and self.lora_dir:
raise ValueError(
"Only one of tensorizer_dir or lora_dir may be specified. "
"Use lora_dir exclusively when serializing LoRA adapters, "
"and tensorizer_dir or tensorizer_uri otherwise."
)
if self.tensorizer_dir and self.tensorizer_uri:
logger.warning_once(
"Provided both tensorizer_dir and tensorizer_uri. "
"Inferring tensorizer_dir from tensorizer_uri as the "
"latter takes precedence."
)
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
if not self.tensorizer_uri:
if self.lora_dir:
self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
elif self.tensorizer_dir:
self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors"
else:
raise ValueError(
"Unable to resolve tensorizer_uri. "
"A valid tensorizer_uri or tensorizer_dir "
"must be provided for deserialization, and a "
"valid tensorizer_uri, tensorizer_uri, or "
"lora_dir for serialization."
)
else:
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
if not self.serialization_kwargs:
self.serialization_kwargs = {}
if not self.deserialization_kwargs:
self.deserialization_kwargs = {}
def to_serializable(self) -> dict[str, Any]:
# Due to TensorizerConfig needing to be msgpack-serializable, it needs
# support for morphing back and forth between itself and its dict
# representation
# TensorizerConfig's representation as a dictionary is meant to be
# linked to TensorizerConfig in such a way that the following is
# technically initializable:
# TensorizerConfig(**my_tensorizer_cfg.to_serializable())
# This means the dict must not retain non-initializable parameters
# and post-init attribute states
# Also don't want to retain private and unset parameters, so only retain
# not None values and public attributes
raw_tc_dict = asdict(self)
blacklisted = []
if "tensorizer_uri" in raw_tc_dict and "tensorizer_dir" in raw_tc_dict:
blacklisted.append("tensorizer_dir")
if "tensorizer_dir" in raw_tc_dict and "lora_dir" in raw_tc_dict:
blacklisted.append("tensorizer_dir")
tc_dict = {}
for k, v in raw_tc_dict.items():
if (
k not in blacklisted
and k not in tc_dict
and not k.startswith("_")
and v is not None
):
tc_dict[k] = v
return tc_dict
def _construct_tensorizer_args(self) -> "TensorizerArgs":
return TensorizerArgs(self) # type: ignore
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
if parallel_config.tensor_parallel_size > 1 and not self._is_sharded:
raise ValueError(
"For a sharded model, tensorizer_uri should include a"
" string format template like '%04d' to be formatted"
" with the rank of the shard"
)
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
if model_config.quantization is not None and self.tensorizer_uri is not None:
logger.warning(
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors."
)
def open_stream(self, tensorizer_args: "TensorizerArgs | None" = None):
if tensorizer_args is None:
tensorizer_args = self._construct_tensorizer_args()
return open_stream(self.tensorizer_uri, **tensorizer_args.stream_kwargs)
def keys(self):
return self._keys
def __len__(self):
return len(fields(self))
def __iter__(self):
return iter(self._fields)
def __getitem__(self, item: str) -> Any:
if item not in self.keys():
raise KeyError(item)
return getattr(self, item)
def __setitem__(self, key: str, value: Any) -> None:
if key not in self.keys():
# Disallow modifying invalid keys
raise KeyError(key)
setattr(self, key, value)
def __delitem__(self, key, /):
if key not in self.keys():
raise KeyError(key)
delattr(self, key)
TensorizerConfig._fields = tuple(f.name for f in fields(TensorizerConfig))
TensorizerConfig._keys = frozenset(TensorizerConfig._fields)
@dataclass
class TensorizerArgs:
tensorizer_uri: str | None = None
tensorizer_dir: str | None = None
encryption_keyfile: str | None = None
def __init__(self, tensorizer_config: TensorizerConfig):
for k, v in tensorizer_config.items():
setattr(self, k, v)
self.file_obj = tensorizer_config.tensorizer_uri
self.s3_access_key_id = (
tensorizer_config.s3_access_key_id or envs.S3_ACCESS_KEY_ID
)
self.s3_secret_access_key = (
tensorizer_config.s3_secret_access_key or envs.S3_SECRET_ACCESS_KEY
)
self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL
self.stream_kwargs = {
"s3_access_key_id": tensorizer_config.s3_access_key_id,
"s3_secret_access_key": tensorizer_config.s3_secret_access_key,
"s3_endpoint": tensorizer_config.s3_endpoint,
**(tensorizer_config.stream_kwargs or {}),
}
self.deserialization_kwargs = {
"verify_hash": tensorizer_config.verify_hash,
"encryption": tensorizer_config.encryption_keyfile,
"num_readers": tensorizer_config.num_readers,
**(tensorizer_config.deserialization_kwargs or {}),
}
if self.encryption_keyfile:
with open_stream(
tensorizer_config.encryption_keyfile,
**self.stream_kwargs,
) as stream:
key = stream.read()
decryption_params = DecryptionParams.from_key(key)
self.deserialization_kwargs["encryption"] = decryption_params
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Tensorizer CLI arguments"""
# Tensorizer options arg group
group = parser.add_argument_group(
"tensorizer options",
description=(
"Options for configuring the behavior of the"
" tensorizer deserializer when "
"load_format=tensorizer is specified when "
"initializing an LLMEngine, either via the CLI "
"when running the vLLM OpenAI inference server "
"with a JSON string passed to "
"--model-loader-extra-config or as arguments given "
"to TensorizerConfig when passed to "
"model_loader_extra_config in the constructor "
"for LLMEngine."
),
)
group.add_argument(
"--tensorizer-uri",
type=str,
help="Path to serialized model tensors. Can be a local file path,"
" or an HTTP(S) or S3 URI.",
)
group.add_argument(
"--verify-hash",
action="store_true",
help="If enabled, the hashes of each tensor will be verified"
" against the hashes stored in the file metadata. An exception"
" will be raised if any of the hashes do not match.",
)
group.add_argument(
"--encryption-keyfile",
type=str,
default=None,
help="The file path to a binary file containing a binary key to "
"use for decryption. Can be a file path or S3 network URI.",
)
group.add_argument(
"--num-readers",
default=None,
type=int,
help="Controls how many threads are allowed to read concurrently "
"from the source file. Default is `None`, which will dynamically "
"set the number of readers based on the available resources "
"and model size. This greatly increases performance.",
)
group.add_argument(
"--s3-access-key-id",
type=str,
default=None,
help="The access key for the S3 bucket. Can also be set via the "
"S3_ACCESS_KEY_ID environment variable.",
)
group.add_argument(
"--s3-secret-access-key",
type=str,
default=None,
help="The secret access key for the S3 bucket. Can also be set via "
"the S3_SECRET_ACCESS_KEY environment variable.",
)
group.add_argument(
"--s3-endpoint",
type=str,
default=None,
help="The endpoint for the S3 bucket. Can also be set via the "
"S3_ENDPOINT_URL environment variable.",
)
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
attrs = [attr.name for attr in dataclasses.fields(cls)]
tensorizer_args = cls(
**{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)}
)
return tensorizer_args
def _check_tensors_on_meta_device(model: nn.Module) -> None:
for tensor in model.state_dict().values():
if tensor.device.type == "meta":
raise ValueError(
"The serialized model contains tensors on the meta device,"
" indicating that some tensors were not loaded properly."
" Please check that the parameters of the model being"
" specified match that of the serialized model, such as"
" its quantization."
)
def _resize_lora_embeddings(model: nn.Module):
"""Modify LoRA embedding layers to use bigger tensors
to allow for adapter added tokens."""
for child in model.modules():
if (
isinstance(child, VocabParallelEmbedding)
and child.weight.shape[0] < child.num_embeddings_per_partition
):
new_weight = torch.empty(
child.num_embeddings_per_partition,
child.embedding_dim,
dtype=child.weight.dtype,
device=child.weight.device,
)
new_weight[: child.weight.shape[0]].copy_(child.weight.data)
new_weight[child.weight.shape[0] :].fill_(0)
child.weight.data = new_weight
def init_tensorizer_model(
tensorizer_config: TensorizerConfig, vllm_config: VllmConfig
) -> nn.Module:
assert tensorizer_config.hf_config is not None
model_args = tensorizer_config.hf_config
model_args.dtype = tensorizer_config.dtype
assert tensorizer_config.model_class is not None
# TODO: Do we need to consider old-style model class?
with meta_tensor_mode(), set_current_vllm_config(vllm_config, check_compile=True):
return tensorizer_config.model_class(vllm_config=vllm_config)
def deserialize_tensorizer_model(
model: nn.Module, tensorizer_config: TensorizerConfig
) -> None:
tensorizer_args = tensorizer_config._construct_tensorizer_args()
if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
raise ValueError(
f"{tensorizer_config.tensorizer_uri} is not a valid "
f"tensorizer URI. Please check that the URI is correct. "
f"It must either point to a local existing file, or have a "
f"S3, HTTP or HTTPS scheme."
)
before_mem = get_mem_usage()
start = time.perf_counter()
with (
open_stream(
tensorizer_config.tensorizer_uri, mode="rb", **tensorizer_args.stream_kwargs
) as stream,
TensorDeserializer(
stream,
dtype=tensorizer_config.dtype,
device=f"xpu:{torch.xpu.current_device()}"
if current_platform.is_xpu()
else f"cuda:{torch.cuda.current_device()}",
**tensorizer_args.deserialization_kwargs,
) as deserializer,
):
deserializer.load_into_module(model)
end = time.perf_counter()
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
deserializer.close()
logger.info(
"Deserialized %s in %0.2fs, %s/s", total_bytes_str, end - start, per_second
)
logger.info("Memory usage before: %s", before_mem)
logger.info("Memory usage after: %s", after_mem)
_check_tensors_on_meta_device(model)
_resize_lora_embeddings(model)
del model.vllm_tensorized_marker
def tensorizer_weights_iterator(
tensorizer_args: "TensorizerArgs",
) -> Generator[tuple[str, torch.Tensor], None, None]:
logger.warning(
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"load times. See the "
"examples/others/tensorize_vllm_model.py example script "
"for serializing vLLM models."
)
deserializer_args = tensorizer_args.deserialization_kwargs
stream_kwargs = tensorizer_args.stream_kwargs
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs)
with TensorDeserializer(stream, **deserializer_args, device="cpu") as state:
yield from state.items()
del state
def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
"""
Infer if the model is a vLLM model by checking the weights for
a vLLM tensorized marker.
Args:
tensorizer_config: The TensorizerConfig object containing the
tensorizer_uri to the serialized model.
Returns:
bool: True if the model is a vLLM model, False otherwise.
"""
tensorizer_args = tensorizer_config._construct_tensorizer_args()
deserializer = TensorDeserializer(
open_stream(tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs),
**tensorizer_args.deserialization_kwargs,
lazy_load=True,
)
if tensorizer_config.vllm_tensorized:
logger.warning(
"Please note that newly serialized vLLM models are automatically "
"inferred as vLLM models, so setting vllm_tensorized=True is "
"only necessary for models serialized prior to this change."
)
return True
return ".vllm_tensorized_marker" in deserializer
def serialize_extra_artifacts(
tensorizer_args: TensorizerArgs, served_model_name: str | list[str] | None
) -> None:
if not isinstance(served_model_name, str):
raise ValueError(
f"served_model_name must be a str for serialize_extra_artifacts, "
f"not {type(served_model_name)}."
)
with tempfile.TemporaryDirectory() as tmpdir:
snapshot_download(
served_model_name,
local_dir=tmpdir,
ignore_patterns=[
"*.pt",
"*.safetensors",
"*.bin",
"*.cache",
"*.gitattributes",
"*.md",
],
)
for artifact in os.scandir(tmpdir):
if not artifact.is_file():
continue
with (
open(artifact.path, "rb") as f,
open_stream(
f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
mode="wb+",
**tensorizer_args.stream_kwargs,
) as stream,
):
logger.info("Writing artifact %s", artifact.name)
stream.write(f.read())
def serialize_vllm_model(
model: nn.Module,
tensorizer_config: TensorizerConfig,
model_config: "ModelConfig",
) -> nn.Module:
model.register_parameter(
"vllm_tensorized_marker",
nn.Parameter(torch.tensor((1,), device="meta"), requires_grad=False),
)
tensorizer_args = tensorizer_config._construct_tensorizer_args()
encryption_params = None
if (keyfile := tensorizer_config.encryption_keyfile) is not None:
with open(keyfile, "rb") as f:
key = f.read()
encryption_params = EncryptionParams(key=key)
output_file = tensorizer_args.tensorizer_uri
if tensorizer_config._is_sharded:
from vllm.distributed import get_tensor_model_parallel_rank
output_file = output_file % get_tensor_model_parallel_rank()
with open_stream(
output_file, mode="wb+", **tensorizer_args.stream_kwargs
) as stream:
serializer = TensorSerializer(
stream,
encryption=encryption_params,
**tensorizer_config.serialization_kwargs,
)
serializer.write_module(model)
serializer.close()
serialize_extra_artifacts(tensorizer_args, model_config.served_model_name)
logger.info("Successfully serialized model to %s", str(output_file))
return model
def tensorize_vllm_model(
engine_args: "EngineArgs",
tensorizer_config: TensorizerConfig,
generate_keyfile: bool = True,
):
"""Utility to load a model and then serialize it with Tensorizer
Intended to be used separately from running a vLLM server since it
creates its own Engine instance.
"""
engine_config = engine_args.create_engine_config()
tensorizer_config.verify_with_model_config(engine_config.model_config)
tensorizer_config.verify_with_parallel_config(engine_config.parallel_config)
# generate the encryption key before creating the engine to support sharding
if (
generate_keyfile
and (keyfile := tensorizer_config.encryption_keyfile) is not None
):
encryption_params = EncryptionParams.random()
with open_stream(
keyfile,
mode="wb+",
s3_access_key_id=tensorizer_config.s3_access_key_id,
s3_secret_access_key=tensorizer_config.s3_secret_access_key,
s3_endpoint=tensorizer_config.s3_endpoint,
) as stream:
stream.write(encryption_params.key)
from vllm.v1.engine.llm_engine import LLMEngine
engine = LLMEngine.from_vllm_config(engine_config)
engine.collective_rpc(
"save_tensorized_model",
kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
)
def tensorize_lora_adapter(lora_path: str, tensorizer_config: TensorizerConfig):
"""
Uses tensorizer to serialize a LoRA adapter. Assumes that the files
needed to load a LoRA adapter are a safetensors-format file called
adapter_model.safetensors and a json config file called adapter_config.json.
Serializes the files in the tensorizer_config.tensorizer_dir
"""
import safetensors
from vllm.lora.utils import get_adapter_absolute_path
lora_dir = get_adapter_absolute_path(lora_path)
tensor_path = config_path = ""
for file in os.listdir(lora_dir):
if file.startswith("adapter_model"):
tensor_path = lora_dir + "/" + file
if file.startswith("adapter_config"):
config_path = lora_dir + "/" + file
if tensor_path and config_path:
break
if tensor_path.endswith(".safetensors"):
tensors = safetensors.torch.load_file(tensor_path)
elif tensor_path.endswith(".bin"):
tensors = torch.load(tensor_path, weights_only=True)
else:
raise ValueError(
f"Unsupported adapter model file: {tensor_path}. "
f"Must be a .safetensors or .bin file."
)
with open(config_path) as f:
config = json.load(f)
tensorizer_args = tensorizer_config._construct_tensorizer_args()
with open_stream(
f"{tensorizer_config.tensorizer_dir}/adapter_config.json",
mode="wb+",
**tensorizer_args.stream_kwargs,
) as f:
f.write(json.dumps(config).encode("utf-8"))
lora_uri = f"{tensorizer_config.tensorizer_dir}/adapter_model.tensors"
with open_stream(lora_uri, mode="wb+", **tensorizer_args.stream_kwargs) as f:
serializer = TensorSerializer(f)
serializer.write_state_dict(tensors)
serializer.close()
logger.info(
"Successfully serialized LoRA files to %s",
str(tensorizer_config.tensorizer_dir),
)

View File

@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117
import copy
from collections.abc import Generator
import torch
from torch import nn
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig,
deserialize_tensorizer_model,
init_tensorizer_model,
is_vllm_tensorized,
serialize_vllm_model,
tensorizer_weights_iterator,
)
from vllm.model_executor.model_loader.utils import (
get_model_architecture,
initialize_model,
)
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
BLACKLISTED_TENSORIZER_ARGS = {
"device", # vLLM decides this
"dtype", # vLLM decides this
"mode", # Not meant to be configurable by the user
}
def validate_config(config: dict):
for k, v in config.items():
if v is not None and k in BLACKLISTED_TENSORIZER_ARGS:
raise ValueError(f"{k} is not an allowed Tensorizer argument.")
class TensorizerLoader(BaseModelLoader):
"""Model loader using CoreWeave's tensorizer library."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
self.tensorizer_config = load_config.model_loader_extra_config
else:
validate_config(load_config.model_loader_extra_config)
self.tensorizer_config = TensorizerConfig(
**load_config.model_loader_extra_config["tensorizer_config"]
)
def _verify_config(
self, model_config: ModelConfig, parallel_config: ParallelConfig
):
self.tensorizer_config.verify_with_model_config(model_config)
self.tensorizer_config.verify_with_parallel_config(parallel_config)
def _get_weights_iterator(
self,
) -> Generator[tuple[str, torch.Tensor], None, None]:
tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
return tensorizer_weights_iterator(tensorizer_args)
def _load_model_serialized_cpu(
self,
vllm_config: VllmConfig,
prefix: str = "",
) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU.
This is only necessary when the model isn't vLLM-tensorized (see
examples/others/tensorize_vllm_model.py) This should still
be faster than default HuggingFace loading, but will be slower than
loading a vLLM-tensorized model.
"""
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = initialize_model(vllm_config=vllm_config, prefix=prefix)
model.load_weights(self._get_weights_iterator())
return model.eval()
def download_model(self, model_config: ModelConfig) -> None:
self.tensorizer_config.verify_with_model_config(model_config)
with self.tensorizer_config.open_stream():
pass
def _patch_tensorizer_config(self, model_config: ModelConfig) -> TensorizerConfig:
model_class = get_model_architecture(model_config)[0]
tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class
tensorizer_config.hf_config = model_config.hf_config
tensorizer_config.dtype = model_config.dtype
return tensorizer_config
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
"""Load serialized model weights with tensorizer.
Expects a vLLM-tensorized model. See the
examples/others/tensorize_vllm_model.py example script
for serializing vLLM models."""
if is_vllm_tensorized(self.tensorizer_config):
tensorizer_config = self._patch_tensorizer_config(model_config)
deserialize_tensorizer_model(model, tensorizer_config)
else:
model.load_weights(self._get_weights_iterator())
def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
) -> nn.Module:
parallel_config = vllm_config.parallel_config
self._verify_config(model_config, parallel_config)
if parallel_config.tensor_parallel_size > 1:
from vllm.distributed import get_tensor_model_parallel_rank
self.tensorizer_config.tensorizer_uri = (
self.tensorizer_config.tensorizer_uri % get_tensor_model_parallel_rank()
)
if is_vllm_tensorized(self.tensorizer_config):
tensorizer_config = self._patch_tensorizer_config(model_config)
device_config = vllm_config.device_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = init_tensorizer_model(
tensorizer_config=tensorizer_config, vllm_config=vllm_config
)
self.load_weights(model, model_config)
return model
return self._load_model_serialized_cpu(vllm_config=vllm_config, prefix=prefix)
@staticmethod
def save_model(
model: torch.nn.Module,
tensorizer_config: TensorizerConfig | dict,
model_config: ModelConfig,
) -> None:
if isinstance(tensorizer_config, dict):
tensorizer_config = TensorizerConfig(**tensorizer_config)
serialize_vllm_model(
model=model,
tensorizer_config=tensorizer_config,
model_config=model_config,
)

View File

@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading models."""
import inspect
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
import torch
from torch import nn
from typing_extensions import assert_never
import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.model_loader.reload import (
record_metadata_for_reloading,
set_torchao_reload_attrs,
)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.tracing import instrument
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor
logger = init_logger(__name__)
@instrument(span_name="Initialize model")
def initialize_model(
vllm_config: VllmConfig,
*,
prefix: str = "",
model_class: type[nn.Module] | None = None,
model_config: ModelConfig | None = None,
) -> nn.Module:
"""Initialize a model with the given configurations."""
if model_config is None:
model_config = vllm_config.model_config
if model_class is None:
model_class, _ = get_model_architecture(model_config)
if vllm_config.quant_config is not None:
configure_quant_config(vllm_config.quant_config, model_class)
signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
model = model_class(vllm_config=vllm_config, prefix=prefix)
record_metadata_for_reloading(model)
return model
msg = (
"vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"for the design and update the model class accordingly."
)
warnings.warn(msg, DeprecationWarning, stacklevel=2)
logger.warning(
"Trying to guess the arguments for old-style model class %s",
model_class,
)
# try to be compatible with old-style model class
kwargs = {}
if "prefix" in all_params:
kwargs["prefix"] = prefix
if "config" in all_params:
kwargs["config"] = model_config.hf_config
if "cache_config" in all_params:
kwargs["cache_config"] = vllm_config.cache_config
if "quant_config" in all_params:
kwargs["quant_config"] = vllm_config.quant_config
if "lora_config" in all_params:
kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config
with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
model = model_class(**kwargs)
record_metadata_for_reloading(model)
return model
def process_weights_after_loading(
model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
# Initialize post-load attention weights for both Attention and MLA.
# NOTE: Happens after other modules so we can easily decompress weights.
for _, module in model.named_modules():
if isinstance(module, (Attention, MLAAttention)) and hasattr(
module, "process_weights_after_loading"
):
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
with device_loading_context(module, target_device):
module.process_weights_after_loading(model_config.dtype)
# Needed for torchao model reloading via model.reload_weights
# @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights`
if model_config.quantization == "torchao":
set_torchao_reload_attrs(model, model_config)
@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
if target_device.type == "cpu":
# If target is CPU, no need to move anything
yield module
return
original_device_states: dict[str, torch.device] = {}
uva_offloaded_parameters: list[str] = []
# Store original device states and move parameters to GPU if they're on CPU
for name, p in module.named_parameters():
if p.device.type == "cpu":
original_device_states[name] = p.device
p.data = p.data.to(target_device)
if getattr(p, "_vllm_is_uva_offloaded", False):
uva_offloaded_parameters.append(name)
# Parameters already on target device are not touched
try:
yield module
finally:
use_pin_memory = (
is_pin_memory_available()
and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
)
# Restore parameters to their original devices, ignoring new parameters
for name, p in module.named_parameters():
if name in original_device_states:
original_device: torch.device = original_device_states[name]
p.data = p.data.to(original_device)
# parameter is UVA offloaded, but was replaced with a new device tensor
# re-offload it to CPU using UVA
if name in uva_offloaded_parameters and not getattr(
p, "_vllm_is_uva_offloaded", False
):
cpu_data = p.data.to(device="cpu")
if use_pin_memory:
cpu_data = cpu_data.pin_memory()
p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
p._vllm_is_uva_offloaded = True
_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
"""Caches the outputs of `_get_model_architecture`."""
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model
architectures = getattr(model_config.hf_config, "architectures", [])
model_cls, arch = model_config.registry.resolve_model_cls(
architectures,
model_config=model_config,
)
if arch == model_config._get_transformers_backend_cls():
assert model_config.model_impl != "vllm"
if model_config.model_impl == "auto":
logger.warning_once(
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.",
arch,
)
convert_type = model_config.convert_type
if convert_type == "none":
pass
elif convert_type == "embed":
logger.debug_once("Converting to embedding model.")
model_cls = as_embedding_model(model_cls)
elif convert_type == "classify":
logger.debug_once("Converting to sequence classification model.")
model_cls = as_seq_cls_model(model_cls)
else:
assert_never(convert_type)
return model_cls, arch
def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
key = hash(
(
model_config.model,
model_config.convert_type,
model_config.runner_type,
model_config.trust_remote_code,
model_config.model_impl,
tuple(getattr(model_config.hf_config, "architectures", [])),
)
)
if key in _MODEL_ARCH_BY_HASH:
return _MODEL_ARCH_BY_HASH[key]
model_arch = _get_model_architecture(model_config)
_MODEL_ARCH_BY_HASH[key] = model_arch
return model_arch
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
return get_model_architecture(model_config)[0]
def get_architecture_class_name(model_config: ModelConfig) -> str:
return get_model_architecture(model_config)[1]
@dataclass
class ParamMapping:
"""
A class to handle parameter mapping for model weight loading.
It creates a bidirectional mapping between packed parameters and their
constituent parts.
"""
packed_mapping: dict[str, list[str]]
inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)
def __post_init__(self):
for packed_name, sub_params in self.packed_mapping.items():
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
if len(sub_params) == 1 and sub_params[0] == packed_name:
continue
for index, param_name in enumerate(sub_params):
self.inverse_packed_mapping[param_name] = (
packed_name,
index,
)
def get_sub_modules(self, module_name: str) -> tuple[str, list[str]] | None:
for key, value in self.packed_mapping.items():
if module_name.endswith(key):
return key, value
return None
def configure_quant_config(
quant_config: QuantizationConfig, model_class: type[nn.Module]
):
"""
Pass packed_modules_mapping by reference to quant_config so that
quant_config can properly match fused modules
Note that model attributes are passed by reference to quant_config,
enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
Once the `SupportsQuant` mixin has been added to all models, this
function can be removed
"""
if not issubclass(model_class, SupportsQuant):
hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
packed_mapping = getattr(model_class, "packed_modules_mapping", None)
# pass mappings by reference to quant_config
if hf_to_vllm_mapper is not None:
quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
if packed_mapping is not None:
quant_config.packed_modules_mapping = packed_mapping

File diff suppressed because it is too large Load Diff