This commit is contained in:
root
2026-03-05 18:06:10 +08:00
commit 809cecae09
2569 changed files with 478204 additions and 0 deletions

View File

@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Literal
from torch import nn
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.bitsandbytes_loader import BitsAndBytesModelLoader
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
from vllm.model_executor.model_loader.runai_streamer_loader import (
RunaiModelStreamerLoader,
)
from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
from vllm.model_executor.model_loader.utils import (
get_architecture_class_name,
get_model_architecture,
get_model_cls,
)
from vllm.model_executor.model_loader.weight_utils import (
padding_weight_loader
)
logger = init_logger(__name__)
# Reminder: Please update docstring in `LoadConfig`
# if a new load format is added here
LoadFormats = Literal[
"auto",
"bitsandbytes",
"dummy",
"fastsafetensors",
"gguf",
"mistral",
"npcache",
"pt",
"runai_streamer",
"runai_streamer_sharded",
"safetensors",
"sharded_state",
"tensorizer",
]
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
"auto": DefaultModelLoader,
"bitsandbytes": BitsAndBytesModelLoader,
"dummy": DummyModelLoader,
"fastsafetensors": DefaultModelLoader,
"gguf": GGUFModelLoader,
"mistral": DefaultModelLoader,
"npcache": DefaultModelLoader,
"pt": DefaultModelLoader,
"runai_streamer": RunaiModelStreamerLoader,
"runai_streamer_sharded": ShardedStateLoader,
"safetensors": DefaultModelLoader,
"sharded_state": ShardedStateLoader,
"tensorizer": TensorizerLoader,
}
def register_model_loader(load_format: str):
"""Register a customized vllm model loader.
When a load format is not supported by vllm, you can register a customized
model loader to support it.
Args:
load_format (str): The model loader format name.
Examples:
>>> from vllm.config.load import LoadConfig
>>> from vllm.model_executor.model_loader import (
... get_model_loader,
... register_model_loader,
... )
>>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
>>>
>>> @register_model_loader("my_loader")
... class MyModelLoader(BaseModelLoader):
... def download_model(self):
... pass
...
... def load_weights(self):
... pass
>>>
>>> load_config = LoadConfig(load_format="my_loader")
>>> type(get_model_loader(load_config))
<class 'MyModelLoader'>
""" # noqa: E501
def _wrapper(model_loader_cls):
if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
logger.warning(
"Load format `%s` is already registered, and will be "
"overwritten by the new loader class `%s`.",
load_format,
model_loader_cls,
)
if not issubclass(model_loader_cls, BaseModelLoader):
raise ValueError(
"The model loader must be a subclass of `BaseModelLoader`."
)
_LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
logger.info(
"Registered model loader `%s` with load format `%s`",
model_loader_cls,
load_format,
)
return model_loader_cls
return _wrapper
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
"""Get a model loader based on the load format."""
load_format = load_config.load_format
if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER:
raise ValueError(f"Load format `{load_format}` is not supported")
return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config)
def get_model(
*, vllm_config: VllmConfig, model_config: ModelConfig | None = None
) -> nn.Module:
loader = get_model_loader(vllm_config.load_config)
if model_config is None:
model_config = vllm_config.model_config
return loader.load_model(vllm_config=vllm_config, model_config=model_config)
__all__ = [
"get_model",
"get_model_loader",
"get_architecture_class_name",
"get_model_architecture",
"get_model_cls",
"register_model_loader",
"BaseModelLoader",
"BitsAndBytesModelLoader",
"GGUFModelLoader",
"DefaultModelLoader",
"DummyModelLoader",
"RunaiModelStreamerLoader",
"ShardedStateLoader",
"TensorizerLoader",
]

View File

@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
class BaseModelLoader(ABC):
"""Base class for model loaders."""
def __init__(self, load_config: LoadConfig):
self.load_config = load_config
@abstractmethod
def download_model(self, model_config: ModelConfig) -> None:
"""Download a model so that it can be immediately loaded."""
raise NotImplementedError
@abstractmethod
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
"""Load weights into a model. This standalone API allows
inplace weights loading for an already-initialized model"""
raise NotImplementedError
def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig
) -> nn.Module:
"""Load a model with the given configurations."""
device_config = vllm_config.device_config
load_config = vllm_config.load_config
load_device = (
device_config.device if load_config.device is None else load_config.device
)
target_device = torch.device(load_device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(
vllm_config=vllm_config, model_config=model_config
)
logger.debug("Loading weights on %s ...", load_device)
# Quantization does not happen in `load_weights` but after it
self.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)
return model.eval()

View File

@@ -0,0 +1,822 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117
import fnmatch
import glob
import itertools
import math
import os
from collections.abc import Callable, Generator
from typing import Any
import numpy as np
import torch
from huggingface_hub import HfApi
from packaging import version
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import ParamMapping
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
download_weights_from_hf,
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
pt_weights_iterator,
safetensors_weights_iterator,
)
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.utils import (
get_moe_expert_mapping,
get_packed_modules_mapping,
set_weight_attrs,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
def is_moe_model(model: torch.nn.Module) -> bool:
"""Checks if the model contains FusedMoE layers."""
return bool(any(isinstance(module, FusedMoE) for module in model.modules()))
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""
possible_config_file_names = ["adapter_config.json"]
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
# Save the module names without sharding.
self.unsharded_weights_modules: list[str] = []
# Save the module names that are sharded by column.
self.column_sharded_weights_modules: list[str] = []
# Modules whose weights might have fused on disk
# we need their output_sizes to make shard in flight correctly with TP
self.maybe_fused_weights_modules: dict[str, list[int]] = {}
# Store all module names (from transformers) that support
# BNB quantization.
self.target_modules: list[str] = []
self.tp_disabled_modules: list[str] = []
# Store the mapping of expert parameters for MoE models.
self.expert_params_mapping: list[tuple[str, str, int, str]] = []
# mapping weight names from transformers to vllm.
self.weight_mapper: Callable = lambda name: name
self.pre_quant: bool = False
self.load_8bit: bool = False
self.is_pool_model: bool = False
def _get_weight_files(
self,
model_name_or_path: str,
allowed_patterns: list[str],
revision: str | None = None,
) -> tuple[str, list[str], str]:
"""Retrieve weight files. Download the files if necessary.
Return the weight files and the file pattern."""
is_local = os.path.isdir(model_name_or_path)
if is_local:
for pattern in allowed_patterns:
weight_files = glob.glob(os.path.join(model_name_or_path, pattern))
if weight_files:
return model_name_or_path, weight_files, pattern
else:
hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
for pattern in allowed_patterns:
matching_files = fnmatch.filter(repo_files, pattern)
if matching_files:
hf_folder = download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
[pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
return (
hf_folder,
glob.glob(os.path.join(hf_folder, pattern)),
pattern,
)
raise RuntimeError(f"No model weights found in: `{model_name_or_path}`")
def _prepare_weights(
self, model_name_or_path: str, revision: str | None
) -> tuple[list[str], bool]:
"""Prepare weight files for the model."""
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
model_name_or_path, allowed_patterns, revision
)
use_safetensors = matched_pattern == "*.safetensors"
is_local = os.path.isdir(model_name_or_path)
index_file = SAFE_WEIGHTS_INDEX_NAME
if use_safetensors:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path,
index_file,
self.load_config.download_dir,
revision,
)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file
)
else:
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`"
)
return hf_weights_files, use_safetensors
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
def _maybe_pool_model(module_name: str):
# For pool model, we need to add the prefix `model.`
# for the weight name if possible.
if (
self.is_pool_model
and self.target_modules[0].startswith("model.")
and not module_name.startswith("model.")
):
return "model." + module_name
return module_name
if use_safetensors:
iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
else:
iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
for org_name, param in iterator:
# mapping weight names from transformers to vllm while preserving
# original names.
mapped_name = self.weight_mapper(org_name)
mapped_name = _maybe_pool_model(mapped_name)
yield org_name, mapped_name, param
def _get_quantized_weights_iterator(
self,
model_name_or_path: str,
revision: str | None,
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
as well as the quantization state dictionary."""
# only load the bitsandbytes module when needed
try:
import bitsandbytes
if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
raise ImportError(
"bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.46.1."
)
except ImportError as err:
raise ImportError(
"Please install bitsandbytes>=0.46.1 via "
"`pip install bitsandbytes>=0.46.1` to use "
"bitsandbytes quantizer."
) from err
hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision
)
quant_state_dict: dict[str, Any] = {}
if self.pre_quant:
if self.load_8bit:
return self._quantized_8bit_generator(
hf_weights_files, use_safetensors, quant_state_dict
), quant_state_dict
else:
return self._quantized_4bit_generator(
hf_weights_files, use_safetensors, quant_state_dict
), quant_state_dict
return self._unquantized_generator(
hf_weights_files, use_safetensors, quant_state_dict
), quant_state_dict
def _is_8bit_weight_name(self, weight_name: str):
quantized_suffix = {".scb", ".weight_format"}
return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
def _is_4bit_weight_name(self, weight_name: str):
quantized_suffix = {
"absmax",
"quant_map",
"nested_absmax",
"nested_quant_map",
"bitsandbytes",
}
suffix = weight_name.split(".")[-1]
return any(q_suffix in suffix for q_suffix in quantized_suffix)
def _quantized_8bit_generator(
self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if not mapped_weight_name.lower().endswith(".scb"):
continue
weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
quant_state_dict[weight_key] = weight_tensor
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_8bit_weight_name(mapped_weight_name):
continue
if mapped_weight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield org_weight_name, weight_tensor
else:
yield org_weight_name, weight_tensor
def _quantized_4bit_generator(
self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
from bitsandbytes.functional import QuantState
# First iterate over all quant state weights
weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
temp_state_dict = {}
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in weight_iterator:
if not self._is_4bit_weight_name(mapped_weight_name):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
if "quant_state.bitsandbytes" in mapped_weight_name:
temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
else:
temp_state_dict[mapped_weight_name] = weight_tensor
# Closure to parse quant_state for each prequant weight
def _parse_quant_state(param_name: str, temp_state_dict: dict) -> QuantState:
quant_state = {}
for k in temp_state_dict:
if param_name + "." in k:
quant_state[k] = temp_state_dict[k]
return QuantState.from_dict(
quant_state, device=current_platform.device_type
)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_4bit_weight_name(mapped_weight_name):
continue
if (
f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict
) or (
f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
):
quant_state = _parse_quant_state(mapped_weight_name, temp_state_dict)
quant_state_dict[mapped_weight_name] = quant_state
yield org_weight_name, weight_tensor
else:
yield org_weight_name, weight_tensor
def _unquantized_generator(
self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
from bitsandbytes.functional import quantize_4bit
global_tp_size = get_tensor_model_parallel_world_size()
global_tp_rank = get_tensor_model_parallel_rank()
check_match = (
lambda weight_name, module_name: weight_name.removesuffix(".weight")
== module_name
)
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
# override tp_size and tp_rank if the module has disabled TP
if any(
tp_disabled_module in mapped_weight_name
for tp_disabled_module in self.tp_disabled_modules
):
tp_size = 1
tp_rank = 0
else:
tp_size = global_tp_size
tp_rank = global_tp_rank
if any(
target_module in mapped_weight_name
for target_module in self.target_modules
) and mapped_weight_name.endswith(".weight"):
# Without sharding
if any(
check_match(mapped_weight_name, module)
for module in self.unsharded_weights_modules
):
weight_sub_tensor = weight_tensor
# Shard by column
elif any(
check_match(mapped_weight_name, module)
for module in self.column_sharded_weights_modules
):
total_size = weight_tensor.size(-1)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[..., start_index:end_index]
# Weights have fused on disk. In this case, we assume that the
# weight and module use same name.
elif any(
check_match(mapped_weight_name, module)
for module in self.maybe_fused_weights_modules
):
# special case for fused weights
# get the size of each shard weight tensor
total_shard_sizes = next(
(
sizes
for module, sizes in self.maybe_fused_weights_modules.items() # noqa: E501
if check_match(mapped_weight_name, module)
)
)
total_size = weight_tensor.size(0)
assert total_size == sum(total_shard_sizes)
# get the start/end index of each shard weight tensor
total_start_index = list(
itertools.accumulate([0] + total_shard_sizes)
)[:-1]
shard_weights_index = [
(
idx + size // tp_size * tp_rank,
idx + size // tp_size * (tp_rank + 1),
)
for idx, size in zip(total_start_index, total_shard_sizes)
]
# slice and reorder the weight tensor
weight_tensor = [
weight_tensor[start_index:end_index, ...]
for start_index, end_index in shard_weights_index
]
weight_sub_tensor = torch.cat(weight_tensor, dim=0)
# Shard by row
else:
total_size = weight_tensor.size(0)
start_index = total_size // tp_size * tp_rank
end_index = total_size // tp_size * (tp_rank + 1)
weight_sub_tensor = weight_tensor[start_index:end_index, ...]
# bitsandbytes requires data in GPU
if weight_sub_tensor.is_cuda:
loaded_weight = weight_sub_tensor
else:
loaded_weight = weight_sub_tensor.to(
device=current_platform.device_type
)
# remove the following after the issue is fixed:
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
if loaded_weight.is_contiguous() is False:
loaded_weight = loaded_weight.contiguous()
with set_default_torch_dtype(torch.float32):
processed_weight, quant_state = quantize_4bit(
loaded_weight,
compress_statistics=True,
quant_type="nf4",
)
quant_state_dict[mapped_weight_name] = quant_state
else:
processed_weight = weight_tensor
yield org_weight_name, processed_weight
def _get_bnb_target_modules(self, model: nn.Module) -> None:
"""
Identify and collect all modules that support BitsAndBytes
quantization.
"""
for name, module in model.named_modules():
if isinstance(module, LinearBase) and hasattr(
module.quant_method, "quant_config"
):
if modules_info := self.modules_mapping.get_sub_modules(name):
# Map vllm's names to transformers's names.
rep_name, sub_modules = modules_info
for sub_name in sub_modules:
new_name = name.replace(rep_name, sub_name)
self.target_modules.append(new_name)
if module.disable_tp:
self.tp_disabled_modules.append(new_name)
# Add original module name even if the module has stacked map,
# in case model has a mixture of disk-merged and disk-split
# weights with same last name.
self.target_modules.append(name)
if module.disable_tp:
self.tp_disabled_modules.append(name)
elif isinstance(module, FusedMoE) and hasattr(
module.quant_method, "quant_config"
):
# TODO: support FusedMoE with prequant and 8bit.
if self.pre_quant and self.load_8bit:
raise ValueError(
"Prequant BitsAndBytes 8bit models with FusedMoE "
"is not supported yet."
)
# Get the corresponding weight name using module name and
# expert_params_mapping.
for exp in self.expert_params_mapping:
weight_name = exp[1]
rep_name = name.replace("experts", "") + weight_name.removesuffix(
"."
)
self.target_modules.append(rep_name)
assert self.target_modules, (
"vLLM currently does not support BNB quantization for"
)
f" {type(model).__name__}"
def _classify_module_sharding(self, model: nn.Module):
"""
Categorize modules based on their weight sharding requirements
for tensor parallelism.
"""
for name, module in model.named_modules():
# Some modules like `ReplicatedLinear` should not have their weights
# sharded. The reason for implementing it this way is to avoid new
# static variable in the model implementation.
if isinstance(module, (ReplicatedLinear,)):
self.unsharded_weights_modules.append(name)
# `QKVParallelLinear` and `MergedColumnParallelLinear` might have
# fused weights on disk. We need to use the output sizes of these
# modules to shard the weights correctly.
elif isinstance(module, (QKVParallelLinear, MergedColumnParallelLinear)):
self.maybe_fused_weights_modules[name] = module.output_sizes
# In TP, these weights are partitioned along the column
# dimension (dim=-1)
elif isinstance(module, (RowParallelLinear,)):
self.column_sharded_weights_modules.append(name)
elif isinstance(module, FusedMoE):
expert_mapping = self.expert_params_mapping
for exp in expert_mapping:
if exp[-1] == "w2":
weight_name = exp[1]
rep_name = name.replace(
"experts", ""
) + weight_name.removesuffix(".")
self.column_sharded_weights_modules.append(rep_name)
def _verify_model_compatibility(
self, model: nn.Module, model_config: ModelConfig
) -> None:
"""
Verify that the model is compatible with BitsAndBytes quantization.
"""
if not hasattr(model, "load_weights"):
raise AttributeError(
"The required method 'load_weights' is not defined in class"
f" {type(model).__name__}."
)
if not hasattr(model, "packed_modules_mapping"):
raise AttributeError(
f"Model {type(model).__name__} does not support BitsAndBytes "
"quantization yet. No 'packed_modules_mapping' found."
)
quant_config = getattr(model_config.hf_config, "quantization_config", None)
if quant_config and (quant_method := quant_config.get("quant_method")):
if quant_method == "bitsandbytes":
self.pre_quant = True
else:
raise ValueError(
f"BitsAndBytes loader does not support {quant_method} quantization"
)
# The quant_states in pre_quantized models cannot work with a split
# weight tensor. So TP does not work with pre_quantized bnb models.
if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
raise ValueError(
"Prequant BitsAndBytes models with tensor parallelism is not "
"supported. Please try with pipeline parallelism."
)
if quant_config and self.pre_quant:
self.load_8bit = quant_config.get("load_in_8bit", False)
def _initialize_loader_state(
self, model: nn.Module, model_config: ModelConfig
) -> None:
"""
Initialize the loader's internal state based on the model and
configuration.
"""
self.is_pool_model = is_pooling_model(model)
self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
if is_moe_model(model):
self.expert_params_mapping = get_moe_expert_mapping(model)
if not self.expert_params_mapping:
raise AttributeError(
f"MoE Model {type(model).__name__} does not support "
"BitsAndBytes quantization yet. Ensure this model has "
"'get_expert_mapping' method."
)
# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
self._get_bnb_target_modules(model)
self._classify_module_sharding(model)
def _dequantize_dq(self, quant_states: Any):
"""
When BNB employs Double Quantization, we perform the dequantization of
these constants during weight loading rather than at inference time,
thereby avoiding this computational overhead during inference. This
comes at the cost of increased memory usage.
"""
from bitsandbytes.functional import QuantState, dequantize_blockwise
def _dequantize_single_state(quant_state):
"""Helper function to dequantize a single QuantState object."""
if not (isinstance(quant_state, QuantState) and quant_state.nested):
return
# Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
absmax += quant_state.offset
# Ensure float32 dtype
if absmax.dtype != torch.float32:
absmax = absmax.float()
quant_state.absmax = absmax
quant_state.nested = False
quant_state.offset = None
quant_state.state2 = None
if isinstance(quant_states, dict):
for quant_state in quant_states.values():
_dequantize_single_state(quant_state)
else:
_dequantize_single_state(quant_states)
return quant_states
def _fuse_moe_quant_states(self, model: nn.Module, quant_states_dict: dict) -> dict:
"""
This function consolidates individual expert quantization states into
fused representations for w13 and w2.
"""
from bitsandbytes.functional import QuantState
if not self.expert_params_mapping:
return dict()
expert_mapping = self.expert_params_mapping
expert_qs_dict = {}
for name, module in model.named_modules():
if not isinstance(module, FusedMoE):
continue
w1_states_lst = []
w2_states_lst = []
w3_states_lst = []
for exp in expert_mapping:
shard_id = exp[-1]
if shard_id not in ("w1", "w2", "w3"):
raise ValueError(
f"shard_id must be ['w1','w2','w3'] but got {shard_id}."
)
layer_prefix = name.split("experts")[0]
weight_qual_name = layer_prefix + exp[1] + "weight"
quant_state = self._dequantize_dq(quant_states_dict[weight_qual_name])
if shard_id == "w1":
w1_states_lst.append(quant_state)
elif shard_id == "w2":
w2_states_lst.append(quant_state)
else:
w3_states_lst.append(quant_state)
del quant_states_dict[weight_qual_name]
assert len(w1_states_lst) == len(w2_states_lst) == len(w3_states_lst)
w13_absmax_lst = []
w2_absmax_lst = []
w13_total_dim0 = 0
w2_total_dim0 = 0
for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst, w3_states_lst):
assert w1_qs.shape == w3_qs.shape
assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize
assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype
# w1 and w3 are interleaved in storage
w13_absmax_lst.append(w1_qs.absmax)
w13_absmax_lst.append(w3_qs.absmax)
w2_absmax_lst.append(w2_qs.absmax)
w13_total_dim0 += w1_qs.shape[0] + w3_qs.shape[0]
w2_total_dim0 += w2_qs.shape[0]
w13_absmax = torch.cat(w13_absmax_lst)
w2_absmax = torch.cat(w2_absmax_lst)
# Create fused quantization state for w13.
w13_qs = QuantState(
absmax=w13_absmax,
shape=(w13_total_dim0, w1_states_lst[0].shape[1]),
code=w1_states_lst[0].code,
blocksize=w1_states_lst[0].blocksize,
quant_type="nf4",
dtype=w1_states_lst[0].dtype,
)
# Create fused quantization state for w2.
w2_qs = QuantState(
absmax=w2_absmax,
shape=(w2_total_dim0, w2_states_lst[0].shape[1]),
code=w2_states_lst[0].code,
blocksize=w2_states_lst[0].blocksize,
quant_type="nf4",
dtype=w2_states_lst[0].dtype,
)
# The weight suffixes .w13_weight and .w2_weight are consistent
# with the param in BitsAndBytesMoEMethod.
w13_weight_name = name + ".w13_weight"
w2_weight_name = name + ".w2_weight"
expert_qs_dict[w13_weight_name] = w13_qs
expert_qs_dict[w2_weight_name] = w2_qs
return expert_qs_dict
def _stack_quantization_states(
self, model: nn.Module, quant_state_dict: dict
) -> dict[str, dict[int, Any]]:
stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
# TODO: Change this lazy import to normal import
# after the checks are updated to run on a new version
from vllm.model_executor.models.utils import is_pp_missing_parameter
param_dict = dict(model.named_parameters())
for quant_param_name in quant_state_dict:
if is_pp_missing_parameter(quant_param_name, model):
continue
non_stacked_param_name = quant_param_name
shard_index = 0
for shard_name, (
weight_name,
index,
) in self.modules_mapping.inverse_packed_mapping.items():
# Some models, such as MiniCPM V2.5/2.6, contain both
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
# from being incorrectly identified as being present in
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
shard_pos = quant_param_name.find(shard_name)
can_correct_rename = (shard_pos > 0) and (
quant_param_name[shard_pos - 1] == "."
)
# If the quant_param_name is packed, it won't occur in the
# param_dict before renaming.
new_quant_param_name = quant_param_name.replace(shard_name, weight_name)
need_rename = (quant_param_name not in param_dict) and (
new_quant_param_name in param_dict
)
if can_correct_rename and need_rename:
shard_index = index
quant_param_name = new_quant_param_name
break
# Models like Clip/Siglip may skip some layers in initialization,
# causing unused quant_param_name in state_dict.
if quant_param_name not in param_dict:
continue
if quant_param_name not in stacked_quant_state_dict:
stacked_quant_state_dict[quant_param_name] = {}
stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[
non_stacked_param_name
]
return stacked_quant_state_dict
def _bind_quant_states_to_params(
self, model: nn.Module, stacked_quant_state_dict: dict
) -> None:
# save quant_states and offsets as the attributes of the parameters
param_dict = dict(model.named_parameters())
for param_name, param in param_dict.items():
if param_name in stacked_quant_state_dict:
quant_states = stacked_quant_state_dict[param_name]
# Dequantize double quantized values during weight loading.
self._dequantize_dq(quant_states)
set_weight_attrs(param, {"bnb_quant_state": quant_states})
if not isinstance(quant_states, dict):
continue
pack_ratio = getattr(param, "pack_factor", -1)
if pack_ratio == -1:
raise ValueError(f"pack_factor not set for parameter {param_name}.")
num_elements = [0] * len(quant_states)
for seq, quant_state in quant_states.items():
num_elements[seq] = math.prod(quant_state.shape) // pack_ratio
offsets = np.concatenate(([0], np.cumsum(num_elements)))
# Make torch infer_schema happy
offsets = torch.tensor(offsets).cpu()
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
if self.load_8bit:
set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)}
)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
self._verify_model_compatibility(model, model_config)
self._initialize_loader_state(model, model_config)
logger.info(
"Loading weights with BitsAndBytes quantization. May take a while ..."
)
qweight_iterator, quant_state_dict = self._get_quantized_weights_iterator(
model_config.model,
model_config.revision,
)
weights_to_load = {name for name, _ in model.named_parameters()}
loaded_weights = model.load_weights(qweight_iterator)
# Some models may have weights loading tracker unimplemented.
if loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)
expert_quant_state_dict = self._fuse_moe_quant_states(model, quant_state_dict)
stacked_quant_state_dict = self._stack_quantization_states(
model, quant_state_dict
)
stacked_quant_state_dict = {
**expert_quant_state_dict,
**stacked_quant_state_dict,
}
self._bind_quant_states_to_params(model, stacked_quant_state_dict)
torch.cuda.empty_cache()
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)

View File

@@ -0,0 +1,329 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import glob
import os
import time
from collections.abc import Generator, Iterable
from typing import cast
import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
download_weights_from_hf,
fastsafetensors_weights_iterator,
filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference,
maybe_download_from_modelscope,
multi_thread_pt_weights_iterator,
multi_thread_safetensors_weights_iterator,
np_cache_weights_iterator,
pt_weights_iterator,
safetensors_weights_iterator,
)
from vllm.platforms import current_platform
logger = init_logger(__name__)
class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk."""
# default number of thread when enable multithread weight loading
DEFAULT_NUM_THREADS = 8
@dataclasses.dataclass
class Source:
"""A source for weights."""
model_or_path: str
"""The model ID or path."""
revision: str | None
"""The optional model revision."""
prefix: str = ""
"""A prefix to prepend to all weights."""
fall_back_to_pt: bool = True
"""Whether .pt weights can be used."""
allow_patterns_overrides: list[str] | None = None
"""If defined, weights will load exclusively using these patterns."""
counter_before_loading_weights: float = 0.0
counter_after_loading_weights: float = 0.0
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
extra_config = load_config.model_loader_extra_config
allowed_keys = {"enable_multithread_load", "num_threads"}
unexpected_keys = set(extra_config.keys()) - allowed_keys
if unexpected_keys:
raise ValueError(
f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{unexpected_keys}"
)
def _prepare_weights(
self,
model_name_or_path: str,
revision: str | None,
fall_back_to_pt: bool,
allow_patterns_overrides: list[str] | None,
) -> tuple[str, list[str], bool]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
model_name_or_path = (
maybe_download_from_modelscope(model_name_or_path, revision)
or model_name_or_path
)
is_local = os.path.isdir(model_name_or_path)
load_format = self.load_config.load_format
use_safetensors = False
index_file = SAFE_WEIGHTS_INDEX_NAME
# Some quantized models use .pt files for storing the weights.
if load_format == "auto":
allow_patterns = ["*.safetensors", "*.bin"]
elif load_format == "safetensors" or load_format == "fastsafetensors":
use_safetensors = True
allow_patterns = ["*.safetensors"]
elif load_format == "mistral":
use_safetensors = True
allow_patterns = ["consolidated*.safetensors"]
index_file = "consolidated.safetensors.index.json"
elif load_format == "pt":
allow_patterns = ["*.pt"]
elif load_format == "npcache":
allow_patterns = ["*.bin"]
else:
raise ValueError(f"Unknown load_format: {load_format}")
if fall_back_to_pt:
allow_patterns += ["*.pt"]
if allow_patterns_overrides is not None:
allow_patterns = allow_patterns_overrides
if not is_local:
hf_folder = download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
allow_patterns,
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
else:
hf_folder = model_name_or_path
hf_weights_files: list[str] = []
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if use_safetensors:
# For models like Mistral-7B-Instruct-v0.3
# there are both sharded safetensors files and a consolidated
# safetensors file. Using both breaks.
# Here, we download the `model.safetensors.index.json` and filter
# any files not found in the index.
if not is_local:
download_safetensors_index_file_from_hf(
model_name_or_path,
index_file,
self.load_config.download_dir,
revision,
)
hf_weights_files = filter_duplicate_safetensors_files(
hf_weights_files, hf_folder, index_file
)
else:
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
if len(hf_weights_files) == 0:
raise RuntimeError(
f"Cannot find any model weights with `{model_name_or_path}`"
)
return hf_folder, hf_weights_files, use_safetensors
def _get_weights_iterator(
self, source: "Source"
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
extra_config = self.load_config.model_loader_extra_config
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path,
source.revision,
source.fall_back_to_pt,
source.allow_patterns_overrides,
)
if self.load_config.load_format == "npcache":
# Currently np_cache only support *.bin checkpoints
assert use_safetensors is False
weights_iterator = np_cache_weights_iterator(
source.model_or_path,
self.load_config.download_dir,
hf_folder,
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
elif use_safetensors:
if self.load_config.load_format == "fastsafetensors":
weights_iterator = fastsafetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
else:
if extra_config.get("enable_multithread_load"):
weights_iterator = multi_thread_safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
max_workers=extra_config.get(
"num_threads", self.DEFAULT_NUM_THREADS
),
)
else:
weights_iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.safetensors_load_strategy,
)
else:
if extra_config.get("enable_multithread_load"):
weights_iterator = multi_thread_pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
max_workers=extra_config.get(
"num_threads", self.DEFAULT_NUM_THREADS
),
)
else:
weights_iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
if current_platform.is_tpu():
from vllm.platforms.tpu import USE_TPU_INFERENCE
if not USE_TPU_INFERENCE:
# In PyTorch XLA, we should call `torch_xla.sync`
# frequently so that not too many ops are accumulated
# in the XLA program.
import torch_xla
def _xla_weights_iterator(iterator: Generator):
for weights in iterator:
yield weights
torch_xla.sync(wait=False)
weights_iterator = _xla_weights_iterator(weights_iterator)
if self.counter_before_loading_weights == 0.0:
self.counter_before_loading_weights = time.perf_counter()
# Apply the prefix.
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
def get_all_weights(
self,
model_config: ModelConfig,
model: nn.Module,
) -> Generator[tuple[str, torch.Tensor], None, None]:
primary_weights = DefaultModelLoader.Source(
model_config.model,
model_config.revision,
prefix="",
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
)
yield from self._get_weights_iterator(primary_weights)
secondary_weights = cast(
Iterable[DefaultModelLoader.Source],
getattr(model, "secondary_weights", ()),
)
for source in secondary_weights:
yield from self._get_weights_iterator(source)
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(
model_config.model,
model_config.revision,
fall_back_to_pt=True,
allow_patterns_overrides=None,
)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
if model_config.quantization == "torchao" and torchao_version_at_least(
"0.14.0"
):
self.load_config.safetensors_load_strategy = "torchao"
weights_to_load = {name for name, _ in model.named_parameters()}
# if we don't have `model.weight_metadata_and_attr_saved` defined and
# set to True, it means that this is either offline quantization case
# or the first run of online quantization
# see online_quantization.py for detailed notes
offline_quantization_or_first_run_of_online_quantization = not getattr(
model, "weight_metadata_and_attr_saved", False
)
if model_config.quantization is None:
# model is not quantized
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model)
)
elif offline_quantization_or_first_run_of_online_quantization:
# case 1: offline quantized checkpoint
# case 2: Step I1 first run of weight loading with
# online quantization
# see online_quantization.py for detailed notes
loaded_weights = model.load_weights(
self.get_all_weights(model_config, model)
)
else:
# to avoid circular dependency
from vllm.model_executor.model_loader.online_quantization import (
load_weights_and_online_quantize,
)
# subsequent runs of weight loading with online
# quantization
loaded_weights = load_weights_and_online_quantize(self, model, model_config)
self.counter_after_loading_weights = time.perf_counter()
logger.info_once(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights - self.counter_before_loading_weights,
scope="local",
)
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
opt_flag = envs.VLLM_MOE_OPT_LEVEL != 0 or envs.VLLM_LINEAR_OPT_LEVEL != 0
if model_config.quantization is None and loaded_weights is not None and not opt_flag:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)

View File

@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch.nn as nn
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
class DummyModelLoader(BaseModelLoader):
"""Model loader that will set model weights to random values."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(
f"Model loader extra config is not supported for "
f"load format {load_config.load_format}"
)
def download_model(self, model_config: ModelConfig) -> None:
pass # Nothing to download
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
# NOTE(woosuk): For accurate performance evaluation, we assign
# random values to the weights.
initialize_dummy_weights(model)

View File

@@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Generator
import gguf
import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from transformers import AutoModelForCausalLM
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.model_executor.model_loader.weight_utils import (
get_gguf_extra_tensor_names,
get_gguf_weight_type_map,
gguf_quant_weights_iterator,
)
from vllm.utils.torch_utils import set_default_torch_dtype
class GGUFModelLoader(BaseModelLoader):
"""
Model loader that can load GGUF files. This is useful for loading models
that are quantized with GGUF and saved in the GGUF format. This loader
supports loading both full models and sharded models.
"""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(
f"Model loader extra config is not supported for "
f"load format {load_config.load_format}"
)
def _prepare_weights(self, model_name_or_path: str):
if os.path.isfile(model_name_or_path):
return model_name_or_path
# for raw HTTPS link
if model_name_or_path.startswith(
("http://", "https://")
) and model_name_or_path.endswith(".gguf"):
return hf_hub_download(url=model_name_or_path)
# repo id/filename.gguf
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
repo_id, filename = model_name_or_path.rsplit("/", 1)
return hf_hub_download(repo_id=repo_id, filename=filename)
else:
raise ValueError(
f"Unrecognised GGUF reference: {model_name_or_path} "
"(expected local file, raw URL, or <repo_id>/<filename>.gguf)"
)
def _get_gguf_weights_map(self, model_config: ModelConfig):
"""
GGUF uses this naming convention for their tensors from HF checkpoint:
`blk.N.BB.weight` and `blk.N.BB.bias`
where N signifies the block number of a layer, and BB signifies the
attention/mlp layer components.
See "Standardized tensor names" in
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
"""
config = model_config.hf_config
model_type = config.model_type
gguf_to_hf_name_map = {}
# hack: ggufs have a different name than transformers
if model_type == "cohere":
model_type = "command-r"
if model_type == "gemma3_text":
# Gemma3 models use "gemma3_text" in HuggingFace but
# "gemma3" in GGUF architecture naming
model_type = "gemma3"
if model_type in ("deepseek_v3", "deepseek_v2"):
model_type = "deepseek2"
# GGUF layer map assumes that we will have a merged expert weights
# so we need to map them manually
for idx in range(config.num_hidden_layers):
gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
)
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
)
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
)
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
)
if model_type in ("qwen2_moe", "qwen3_moe"):
model_type = model_type.replace("_", "")
# GGUF layer map assumes that we will have a merged expert weights
# so we need to map them manually
for idx in range(config.num_hidden_layers):
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
)
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
)
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
)
arch = None
for key, value in gguf.MODEL_ARCH_NAMES.items():
if value == model_type:
arch = key
break
if arch is None:
raise RuntimeError(f"Unknown gguf model_type: {model_type}")
num_layers = config.num_hidden_layers
name_map = gguf.get_tensor_name_map(arch, num_layers)
with torch.device("meta"):
dummy_model = AutoModelForCausalLM.from_config(
config, trust_remote_code=model_config.trust_remote_code
)
state_dict = dummy_model.state_dict()
for hf_name in state_dict:
name, suffix = hf_name.rsplit(".", 1)
gguf_name = name_map.get_name(name)
gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
return gguf_to_hf_name_map
def _get_weights_iterator(
self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str]
) -> Generator[tuple[str, torch.Tensor], None, None]:
return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map)
)
def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig
) -> nn.Module:
device_config = vllm_config.device_config
local_model_path = self._prepare_weights(model_config.model)
gguf_weights_map = self._get_gguf_weights_map(model_config)
# we can only know if tie word embeddings after mapping weights
if "lm_head.weight" in get_gguf_extra_tensor_names(
local_model_path, gguf_weights_map
):
model_config.hf_config.update({"tie_word_embeddings": True})
weight_type_map = get_gguf_weight_type_map(model_config.model, gguf_weights_map)
# filter out unquantized modules to skip
unquant_names = [
name.removesuffix(".weight")
for name, weight_type in weight_type_map.items()
if weight_type == "F32" and name.endswith(".weight")
]
vllm_config.quant_config.unquantized_modules.extend(unquant_names)
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config)
self.load_weights(model, model_config)
process_weights_after_loading(model, model_config, target_device)
return model

View File

@@ -0,0 +1,224 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import types
import torch
from torch import nn
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import process_weights_after_loading
logger = init_logger(__name__)
# Notes for Online Quantization
# In terms of state of checkpoints, quantization config and their
# correspondance to online quantization:
# | Use Case | Checkpoints | model_config.quantization |
# | no quant | high precision | None |
# | offline quant | quantized | fp8, torchao etc. |
# | online quant | high precision | torchao etc. |
#
# The process for loading non-quantized checkpoint
# 1. load non-quantized weights (load_weights)
# 2. do any additional post processing (process_weights_after_loading)
#
# The process for loading offline quantized checkpoint
# 1. load offline-quantized weights (load_weights)
# 2. do any additional post processing (process_weights_after_loading)
# The process for unquantized model reloading
# (repeated run in RL training loop)
# first run
# UI1. load_weights: load bfloat16 weights
# UI2. process_weights_after_loading: any additional post processing
# subsequent run
# UC1: load_weights: load bfloat16 weights
# (shouldn't be any issues since we didn't change any attributes
# of the weights)
# UC2: process_weights_after_loading: any additional post processing
# The process for weight reloading with online quantization
# (repeated run in RL training loop)
# first run
# I1. load_weights: load bfloat16 weights
# I2. process_weights_after_loading:
# record weight metadata and attributes for R1 and R2
# quantize weights to fp8
# subsequent run
# (beginning model weight is in fp8)
# load_weights:
# R1. restore bfloat16 model weight metadata
# R2. restore the model weight attributes
# R3. reload bfloat16 weights
# R4. quantize weights (by calling process_weights_after_loading),
# also set `process_weights_after_loading_already_called` to
# True to stop it from running again
# process_weights_after_loading (if called):
# this will be skipped since it's already ran in
# load_weights
def maybe_save_metadata_and_attributes_for_weight_reloading(
model: nn.Module, model_config: ModelConfig
):
# following is to support on the fly quantization, currently only supported
# for torchao
if model_config.quantization != "torchao":
return
if getattr(model, "process_weights_after_loading_already_called", False):
# In case `process_weights_after_loading` is called multiple times
# we'll skip it at later times
logger.warning(
"process_weights_after_loading already called for model %s", model
)
return
from vllm.model_executor.model_loader.weight_utils import get_quant_config
quant_config = get_quant_config(model_config, None)
# If checkpoint is already torchao serialized, this means it's
# pre-quantized quantization case, we'll skip saving the metadata
# Otherwise, this is Step I2 of initialization steps of
# online quantization
# This step record the weights metadata and weight attributes so we can
# restore the bfloat16 model weights during the relad step (R1 and R2)
# see Notes in online_quantization.py for more details
if not (
hasattr(quant_config, "is_checkpoint_torchao_serialized")
and not quant_config.is_checkpoint_torchao_serialized
):
return
# This is the I2 step of online quantiztion that saves
# metadata and attributes of weights so they can be used in R1 and
# R2 step, note that we only save these during initialization
# Includes two things
# 1. save floating point metadata (shape, dtype, device) for init
# 2. save weight attributes, e.g. `output_dim`, `weight_loader` for init
if getattr(model, "weight_metadata_and_attr_saved", False):
return
# save the dtype, shape and device for model parameter, used for
# restoring the model high precision parameters before
# reloading the weights
assert not hasattr(model, "original_weights_rebuild_keys")
model.original_weights_rebuild_keys = {}
for name, p in model.named_parameters():
model.original_weights_rebuild_keys[name] = {
"shape": p.shape,
"dtype": p.dtype,
"device": p.device,
}
# record the weight attributes (loader functions etc.)
# so these can be recovered later when we reload the weights
# structure: {"weight_name": {"weight_attr_key": attr}}
assert not hasattr(model, "recorded_weight_attr")
model.recorded_weight_attr = {}
for name, param in model.named_parameters():
model.recorded_weight_attr[name] = {}
for key in param.__dict__:
if hasattr(param, key):
attr = getattr(param, key)
if not callable(attr):
model.recorded_weight_attr[name][key] = attr
elif hasattr(attr, "__self__") and param is attr.__self__:
# if attr is a bonded method for an instance, and
# attr.__self__ points to the instance (param)
# we'll record the underlying function object
model.recorded_weight_attr[name][key] = attr.__func__
else:
model.recorded_weight_attr[name][key] = attr
# mark the metadata and attributes saved so we don't run it again
model.weight_metadata_and_attr_saved = True
def _bond_method_to_cls(func, obj):
if hasattr(func, "__self__") or not callable(func):
# If the function is already bound to an instance, return it as is
return func
else:
return types.MethodType(func, obj)
def load_weights_and_online_quantize(
model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig
) -> set[str]:
# online quantization, right now only enabled for
# torchao
# R1, R2, R3, R4 in the Notes
# TODO: Add fp8 support
assert model_config.quantization == "torchao", (
"online quantization is only enabled for torchao currently"
)
# TODO: use create_weights to restore the weights to original state
# Step R1: First restore the quantized weights to original bfloat16
# weights, with original metadata (shape, dtype, device)
# and attributes, so that bfloat16 weights can be loaded properly
existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
named_modules = dict(model.named_modules(remove_duplicate=False))
model_device = None
# Step R2: recover the parameter to the state before first loading
for name, d in model.original_weights_rebuild_keys.items():
_shape = d["shape"]
_dtype = d["dtype"]
_device = d["device"]
if model_device is not None:
assert model_device == _device, (
"Expecting all weights "
"to be in the same device for now, got both: "
f"{model_device} and {_device}"
)
else:
model_device = _device
if name in existing_param_names:
module_name, weight_name = name.rsplit(".", 1)
module = named_modules[module_name]
setattr(
module,
weight_name,
torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)),
)
# recorded_weight_attr is
# {"weight_name": {"weight_attr_key": attr}}
# e.g.
# {
# {
# "layer.0.weight": {
# "weight_loader": weight_loader_function_object,
# "input_dim": 0, ...
# },
# "layer.1.weight": ...,
# }
# }
for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
for attr_name, attr in weight_attr_dict.items():
module_name, weight_name = full_weight_name.rsplit(".", 1)
module = named_modules[module_name]
weight = getattr(module, weight_name)
if not hasattr(weight, attr_name):
setattr(weight, attr_name, _bond_method_to_cls(attr, weight))
# Step I1: reload bfloat16 / high precision weights
loaded_weights = model.load_weights(
model_loader.get_all_weights(model_config, model)
)
# Step I2: online quantize the weights
# manually process weights after loading
model.process_weights_after_loading_already_called = False
process_weights_after_loading(model, model_config, model_device)
model.process_weights_after_loading_already_called = True
return loaded_weights

View File

@@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117
import os
from collections.abc import Generator
import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf,
download_weights_from_hf,
runai_safetensors_weights_iterator,
)
from vllm.transformers_utils.runai_utils import is_runai_obj_uri, list_safetensors
class RunaiModelStreamerLoader(BaseModelLoader):
"""
Model loader that can load safetensors
files from local FS or S3 bucket.
"""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
self._is_distributed = False
if load_config.model_loader_extra_config:
extra_config = load_config.model_loader_extra_config
if "distributed" in extra_config and isinstance(
extra_config.get("distributed"), bool
):
self._is_distributed = extra_config.get("distributed")
if "concurrency" in extra_config and isinstance(
extra_config.get("concurrency"), int
):
os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
extra_config.get("concurrency")
)
if "memory_limit" in extra_config and isinstance(
extra_config.get("memory_limit"), int
):
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
extra_config.get("memory_limit")
)
runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
def _prepare_weights(
self, model_name_or_path: str, revision: str | None
) -> list[str]:
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
is_object_storage_path = is_runai_obj_uri(model_name_or_path)
is_local = os.path.isdir(model_name_or_path)
safetensors_pattern = "*.safetensors"
index_file = SAFE_WEIGHTS_INDEX_NAME
hf_folder = (
model_name_or_path
if (is_local or is_object_storage_path)
else download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
[safetensors_pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
)
hf_weights_files = list_safetensors(path=hf_folder)
if not is_local and not is_object_storage_path:
download_safetensors_index_file_from_hf(
model_name_or_path, index_file, self.load_config.download_dir, revision
)
if not hf_weights_files:
raise RuntimeError(
f"Cannot find any safetensors model weights with `{model_name_or_path}`"
)
return hf_weights_files
def _get_weights_iterator(
self, model_or_path: str, revision: str
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_weights_files = self._prepare_weights(model_or_path, revision)
return runai_safetensors_weights_iterator(
hf_weights_files, self.load_config.use_tqdm_on_load, self._is_distributed
)
def download_model(self, model_config: ModelConfig) -> None:
"""Download model if necessary"""
self._prepare_weights(model_config.model, model_config.revision)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
"""Load weights into a model."""
model_weights = model_config.model
if hasattr(model_config, "model_weights"):
model_weights = model_config.model_weights
model.load_weights(
self._get_weights_iterator(model_weights, model_config.revision)
)

View File

@@ -0,0 +1,206 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import collections
import glob
import os
from collections.abc import Generator
from typing import Any
import torch
from torch import nn
from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf,
runai_safetensors_weights_iterator,
)
from vllm.transformers_utils.s3_utils import glob as s3_glob
from vllm.transformers_utils.utils import is_s3
logger = init_logger(__name__)
class ShardedStateLoader(BaseModelLoader):
"""
Model loader that directly loads each worker's model state dict, which
enables a fast load path for large tensor-parallel models where each worker
only needs to read its own shard rather than the entire checkpoint. See
`examples/offline_inference/save_sharded_state.py` for creating a sharded
checkpoint.
"""
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
extra_config = (
{}
if load_config.model_loader_extra_config is None
else load_config.model_loader_extra_config.copy()
)
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
if extra_config:
raise ValueError(
f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{load_config.model_loader_extra_config.keys()}"
)
@staticmethod
def _filter_subtensors(
tensors: dict[str, torch.Tensor],
) -> dict[str, torch.Tensor]:
"""
Filter out all tensors that share the same memory or a subset of the
memory of another tensor.
"""
same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = (
collections.defaultdict(list)
)
for key, tensor in tensors.items():
if tensor.numel():
ptr = tensor.untyped_storage().data_ptr()
same_storage_groups[tensor.device, ptr].append((key, tensor))
def get_end_ptr(tensor: torch.Tensor) -> int:
return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
result: dict[str, torch.Tensor] = {}
for group in same_storage_groups.values():
for k, t in group:
a, b = t.data_ptr(), get_end_ptr(t)
for k2, t2 in group:
if not t2.is_contiguous():
continue
a2, b2 = t2.data_ptr(), get_end_ptr(t2)
if a < a2 or b2 < b:
continue
if a2 < a or b < b2 or not t.is_contiguous():
break # t2 covers strictly more memory than t.
if k2 < k:
# Same tensors, keep the one with the smaller key.
break
else:
result[k] = t
return result
def _prepare_weights(self, model_name_or_path: str, revision: str | None):
if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
return model_name_or_path
else:
allow_patterns = ["*.safetensors"]
return download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
allow_patterns,
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
def download_model(self, model_config: ModelConfig) -> None:
self._prepare_weights(model_config.model, model_config.revision)
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
from vllm.distributed import get_tensor_model_parallel_rank
model_weights = model_config.model
if hasattr(model_config, "model_weights"):
model_weights = model_config.model_weights
local_model_path = model_weights
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
local_model_path,
self.pattern.format(rank=rank, part="*"),
)
filepaths = []
if is_s3(local_model_path):
file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
filepaths = s3_glob(path=local_model_path, allow_pattern=[file_pattern])
else:
filepaths = glob.glob(pattern)
if not filepaths:
# TODO: support un-sharded checkpoints too
raise ValueError(
f"Could not find checkpoint files '{pattern}', only "
f"pre-sharded checkpoints are currently supported!"
)
state_dict = self._filter_subtensors(model.state_dict())
for key, tensor in self.iterate_over_files(filepaths):
# If loading with LoRA enabled, additional padding may
# be added to certain parameters. We only load into a
# narrowed view of the parameter data.
param_data = state_dict[key].data
param_shape = state_dict[key].shape
for dim, size in enumerate(tensor.shape):
if size < param_shape[dim]:
param_data = param_data.narrow(dim, 0, size)
if tensor.shape != param_shape:
logger.warning(
"loading tensor of shape %s into parameter '%s' of shape %s",
tensor.shape,
key,
param_shape,
)
param_data.copy_(tensor)
state_dict.pop(key)
if state_dict:
raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
def iterate_over_files(
self, paths
) -> Generator[tuple[str, torch.Tensor], None, None]:
if self.load_config.load_format == "runai_streamer_sharded":
yield from runai_safetensors_weights_iterator(paths, True)
else:
from safetensors.torch import safe_open
for path in paths:
with safe_open(path, framework="pt") as f:
for key in f.keys(): # noqa: SIM118
tensor = f.get_tensor(key)
yield key, tensor
@staticmethod
def save_model(
model: torch.nn.Module,
path: str,
pattern: str | None = None,
max_size: int | None = None,
) -> None:
from safetensors.torch import save_file
from vllm.distributed import get_tensor_model_parallel_rank
if pattern is None:
pattern = ShardedStateLoader.DEFAULT_PATTERN
rank = get_tensor_model_parallel_rank()
part_idx = 0
total_size = 0
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
state_dict_part: dict[str, torch.Tensor] = {}
for key, tensor in state_dict.items():
param_size = tensor.nelement() * tensor.element_size()
if max_size is not None and total_size + param_size > max_size:
filename = pattern.format(rank=rank, part=part_idx)
save_file(
state_dict_part,
os.path.join(path, filename),
)
part_idx += 1
total_size = 0
state_dict_part = {}
state_dict_part[key] = tensor
total_size += param_size
if len(state_dict_part) > 0:
filename = pattern.format(rank=rank, part=part_idx)
save_file(
state_dict_part,
os.path.join(path, filename),
)

View File

@@ -0,0 +1,790 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import contextvars
import dataclasses
import json
import os
import tempfile
import threading
import time
from collections.abc import Generator, MutableMapping
from dataclasses import asdict, dataclass, field, fields
from typing import TYPE_CHECKING, Any, ClassVar, Optional
import regex as re
import torch
from huggingface_hub import snapshot_download
from torch import nn
from torch.utils._python_dispatch import TorchDispatchMode
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import ModelConfig, ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.import_utils import PlaceholderModule
if TYPE_CHECKING:
from vllm.engine.arg_utils import EngineArgs
try:
from tensorizer import (
DecryptionParams,
EncryptionParams,
TensorDeserializer,
TensorSerializer,
)
from tensorizer.stream_io import open_stream
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
except ImportError:
tensorizer = PlaceholderModule("tensorizer")
DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer")
TensorSerializer = tensorizer.placeholder_attr("TensorSerializer")
open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
__all__ = [
"EncryptionParams",
"DecryptionParams",
"TensorDeserializer",
"TensorSerializer",
"open_stream",
"convert_bytes",
"get_mem_usage",
"no_init_or_tensor",
"TensorizerConfig",
]
logger = init_logger(__name__)
def is_valid_deserialization_uri(uri: str | None) -> bool:
if uri:
scheme = uri.lower().split("://")[0]
return scheme in {"s3", "http", "https"} or os.path.exists(uri)
return False
def tensorizer_kwargs_arg(value):
loaded = json.loads(value)
if not isinstance(loaded, dict):
raise argparse.ArgumentTypeError(
f"Not deserializable to dict: {value}. serialization_kwargs and "
f"deserialization_kwargs must be "
f"deserializable from a JSON string to a dictionary. "
)
return loaded
class MetaTensorMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs or {}
if func._schema.name == "aten::empty" and "device" not in kwargs:
kwargs["device"] = "meta"
return func(*args, **kwargs)
def meta_tensor_mode(
loading_code=None,
):
if loading_code is None:
return _NoInitOrTensorImpl.context_manager()
elif callable(loading_code):
with _NoInitOrTensorImpl.context_manager():
return loading_code()
else:
raise TypeError(
"expected a callable to evaluate,"
" or None if being used as a context manager;"
f' got an object of type "{type(loading_code).__name__}" instead.'
)
class _NoInitOrTensorImpl:
_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
_MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)
is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False)
_count_active: int = 0
_count_active_lock = threading.Lock()
@classmethod
@contextlib.contextmanager
def context_manager(cls):
if cls.is_active.get():
yield
return
with cls._count_active_lock:
cls._count_active += 1
if cls._count_active == 1:
for mod in cls._MODULES:
mod.reset_parameters = cls._disable(mod.reset_parameters)
reset_token = cls.is_active.set(True)
try:
with MetaTensorMode():
yield
finally:
cls.is_active.reset(reset_token)
with cls._count_active_lock:
cls._count_active -= 1
if cls._count_active == 0:
for mod, original in cls._MODULE_ORIGINALS:
mod.reset_parameters = original
@staticmethod
def _disable(func):
def wrapper(*args, **kwargs):
if not _NoInitOrTensorImpl.is_active.get():
return func(*args, **kwargs)
return wrapper
@dataclass
class TensorizerConfig(MutableMapping):
tensorizer_uri: str | None = None
tensorizer_dir: str | None = None
vllm_tensorized: bool | None = None
verify_hash: bool | None = None
num_readers: int | None = None
encryption_keyfile: str | None = None
s3_access_key_id: str | None = None
s3_secret_access_key: str | None = None
s3_endpoint: str | None = None
lora_dir: str | None = None
stream_kwargs: dict[str, Any] | None = None
serialization_kwargs: dict[str, Any] | None = None
deserialization_kwargs: dict[str, Any] | None = None
_extra_serialization_attrs: dict[str, Any] | None = field(init=False, default=None)
model_class: type[torch.nn.Module] | None = field(init=False, default=None)
hf_config: PretrainedConfig | None = field(init=False, default=None)
dtype: str | torch.dtype | None = field(init=False, default=None)
_is_sharded: bool = field(init=False, default=False)
_fields: ClassVar[tuple[str, ...]]
_keys: ClassVar[frozenset[str]]
"""Configuration class for Tensorizer settings.
These settings configure the behavior of model serialization and
deserialization using Tensorizer.
Attributes:
tensorizer_uri: Path to serialized model tensors. Can be a local file
path or a S3 URI. This is a required field unless lora_dir is
provided and the config is meant to be used for the
`tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
`lora_dir` is passed to this object's initializer, this is
a required argument.
tensorizer_dir: Path to a directory containing serialized model tensors,
and all other potential model artifacts to load the model, such as
configs and tokenizer files. Can be passed instead of
`tensorizer_uri` where the `model.tensors` file will be assumed
to be in this directory.
vllm_tensorized: If True, indicates that the serialized model is a
vLLM model. This is used to determine the behavior of the
TensorDeserializer when loading tensors from a serialized model.
It is far faster to deserialize a vLLM model as it utilizes
tensorizer's optimized GPU loading. Note that this is now
deprecated, as serialized vLLM models are now automatically
inferred as vLLM models.
verify_hash: If True, the hashes of each tensor will be verified
against the hashes stored in the metadata. A `HashMismatchError`
will be raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is `None`, which will dynamically set
the number of readers based on the number of available
resources and model size. This greatly increases performance.
encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means
no decryption. See the example script in
examples/others/tensorize_vllm_model.py.
s3_access_key_id: The access key for the S3 bucket. Can also be set via
the S3_ACCESS_KEY_ID environment variable.
s3_secret_access_key: The secret access key for the S3 bucket. Can also
be set via the S3_SECRET_ACCESS_KEY environment variable.
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
S3_ENDPOINT_URL environment variable.
lora_dir: Path to a directory containing LoRA adapter artifacts for
serialization or deserialization. When serializing LoRA adapters
this is the only necessary parameter to pass to this object's
initializer.
"""
def __post_init__(self):
# check if the configuration is for a sharded vLLM model
self._is_sharded = (
isinstance(self.tensorizer_uri, str)
and re.search(r"%0\dd", self.tensorizer_uri) is not None
)
if self.tensorizer_dir and self.lora_dir:
raise ValueError(
"Only one of tensorizer_dir or lora_dir may be specified. "
"Use lora_dir exclusively when serializing LoRA adapters, "
"and tensorizer_dir or tensorizer_uri otherwise."
)
if self.tensorizer_dir and self.tensorizer_uri:
logger.warning_once(
"Provided both tensorizer_dir and tensorizer_uri. "
"Inferring tensorizer_dir from tensorizer_uri as the "
"latter takes precedence."
)
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
if not self.tensorizer_uri:
if self.lora_dir:
self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
elif self.tensorizer_dir:
self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors"
else:
raise ValueError(
"Unable to resolve tensorizer_uri. "
"A valid tensorizer_uri or tensorizer_dir "
"must be provided for deserialization, and a "
"valid tensorizer_uri, tensorizer_uri, or "
"lora_dir for serialization."
)
else:
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
if not self.serialization_kwargs:
self.serialization_kwargs = {}
if not self.deserialization_kwargs:
self.deserialization_kwargs = {}
def to_serializable(self) -> dict[str, Any]:
# Due to TensorizerConfig needing to be msgpack-serializable, it needs
# support for morphing back and forth between itself and its dict
# representation
# TensorizerConfig's representation as a dictionary is meant to be
# linked to TensorizerConfig in such a way that the following is
# technically initializable:
# TensorizerConfig(**my_tensorizer_cfg.to_serializable())
# This means the dict must not retain non-initializable parameters
# and post-init attribute states
# Also don't want to retain private and unset parameters, so only retain
# not None values and public attributes
raw_tc_dict = asdict(self)
blacklisted = []
if "tensorizer_uri" in raw_tc_dict and "tensorizer_dir" in raw_tc_dict:
blacklisted.append("tensorizer_dir")
if "tensorizer_dir" in raw_tc_dict and "lora_dir" in raw_tc_dict:
blacklisted.append("tensorizer_dir")
tc_dict = {}
for k, v in raw_tc_dict.items():
if (
k not in blacklisted
and k not in tc_dict
and not k.startswith("_")
and v is not None
):
tc_dict[k] = v
return tc_dict
def _construct_tensorizer_args(self) -> "TensorizerArgs":
return TensorizerArgs(self) # type: ignore
def verify_with_parallel_config(
self,
parallel_config: "ParallelConfig",
) -> None:
if parallel_config.tensor_parallel_size > 1 and not self._is_sharded:
raise ValueError(
"For a sharded model, tensorizer_uri should include a"
" string format template like '%04d' to be formatted"
" with the rank of the shard"
)
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
if model_config.quantization is not None and self.tensorizer_uri is not None:
logger.warning(
"Loading a model using Tensorizer with quantization on vLLM"
" is unstable and may lead to errors."
)
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
if tensorizer_args is None:
tensorizer_args = self._construct_tensorizer_args()
return open_stream(self.tensorizer_uri, **tensorizer_args.stream_kwargs)
def keys(self):
return self._keys
def __len__(self):
return len(fields(self))
def __iter__(self):
return iter(self._fields)
def __getitem__(self, item: str) -> Any:
if item not in self.keys():
raise KeyError(item)
return getattr(self, item)
def __setitem__(self, key: str, value: Any) -> None:
if key not in self.keys():
# Disallow modifying invalid keys
raise KeyError(key)
setattr(self, key, value)
def __delitem__(self, key, /):
if key not in self.keys():
raise KeyError(key)
delattr(self, key)
TensorizerConfig._fields = tuple(f.name for f in fields(TensorizerConfig))
TensorizerConfig._keys = frozenset(TensorizerConfig._fields)
@dataclass
class TensorizerArgs:
tensorizer_uri: str | None = None
tensorizer_dir: str | None = None
encryption_keyfile: str | None = None
def __init__(self, tensorizer_config: TensorizerConfig):
for k, v in tensorizer_config.items():
setattr(self, k, v)
self.file_obj = tensorizer_config.tensorizer_uri
self.s3_access_key_id = (
tensorizer_config.s3_access_key_id or envs.S3_ACCESS_KEY_ID
)
self.s3_secret_access_key = (
tensorizer_config.s3_secret_access_key or envs.S3_SECRET_ACCESS_KEY
)
self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL
self.stream_kwargs = {
"s3_access_key_id": tensorizer_config.s3_access_key_id,
"s3_secret_access_key": tensorizer_config.s3_secret_access_key,
"s3_endpoint": tensorizer_config.s3_endpoint,
**(tensorizer_config.stream_kwargs or {}),
}
self.deserialization_kwargs = {
"verify_hash": tensorizer_config.verify_hash,
"encryption": tensorizer_config.encryption_keyfile,
"num_readers": tensorizer_config.num_readers,
**(tensorizer_config.deserialization_kwargs or {}),
}
if self.encryption_keyfile:
with open_stream(
tensorizer_config.encryption_keyfile,
**self.stream_kwargs,
) as stream:
key = stream.read()
decryption_params = DecryptionParams.from_key(key)
self.deserialization_kwargs["encryption"] = decryption_params
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
"""Tensorizer CLI arguments"""
# Tensorizer options arg group
group = parser.add_argument_group(
"tensorizer options",
description=(
"Options for configuring the behavior of the"
" tensorizer deserializer when "
"load_format=tensorizer is specified when "
"initializing an LLMEngine, either via the CLI "
"when running the vLLM OpenAI inference server "
"with a JSON string passed to "
"--model-loader-extra-config or as arguments given "
"to TensorizerConfig when passed to "
"model_loader_extra_config in the constructor "
"for LLMEngine."
),
)
group.add_argument(
"--tensorizer-uri",
type=str,
help="Path to serialized model tensors. Can be a local file path,"
" or an HTTP(S) or S3 URI.",
)
group.add_argument(
"--verify-hash",
action="store_true",
help="If enabled, the hashes of each tensor will be verified"
" against the hashes stored in the file metadata. An exception"
" will be raised if any of the hashes do not match.",
)
group.add_argument(
"--encryption-keyfile",
type=str,
default=None,
help="The file path to a binary file containing a binary key to "
"use for decryption. Can be a file path or S3 network URI.",
)
group.add_argument(
"--num-readers",
default=None,
type=int,
help="Controls how many threads are allowed to read concurrently "
"from the source file. Default is `None`, which will dynamically "
"set the number of readers based on the available resources "
"and model size. This greatly increases performance.",
)
group.add_argument(
"--s3-access-key-id",
type=str,
default=None,
help="The access key for the S3 bucket. Can also be set via the "
"S3_ACCESS_KEY_ID environment variable.",
)
group.add_argument(
"--s3-secret-access-key",
type=str,
default=None,
help="The secret access key for the S3 bucket. Can also be set via "
"the S3_SECRET_ACCESS_KEY environment variable.",
)
group.add_argument(
"--s3-endpoint",
type=str,
default=None,
help="The endpoint for the S3 bucket. Can also be set via the "
"S3_ENDPOINT_URL environment variable.",
)
return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
attrs = [attr.name for attr in dataclasses.fields(cls)]
tensorizer_args = cls(
**{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)}
)
return tensorizer_args
def _check_tensors_on_meta_device(model: nn.Module) -> None:
for tensor in model.state_dict().values():
if tensor.device.type == "meta":
raise ValueError(
"The serialized model contains tensors on the meta device,"
" indicating that some tensors were not loaded properly."
" Please check that the parameters of the model being"
" specified match that of the serialized model, such as"
" its quantization."
)
def _resize_lora_embeddings(model: nn.Module):
"""Modify LoRA embedding layers to use bigger tensors
to allow for adapter added tokens."""
for child in model.modules():
if (
isinstance(child, VocabParallelEmbedding)
and child.weight.shape[0] < child.num_embeddings_per_partition
):
new_weight = torch.empty(
child.num_embeddings_per_partition,
child.embedding_dim,
dtype=child.weight.dtype,
device=child.weight.device,
)
new_weight[: child.weight.shape[0]].copy_(child.weight.data)
new_weight[child.weight.shape[0] :].fill_(0)
child.weight.data = new_weight
def init_tensorizer_model(
tensorizer_config: TensorizerConfig, vllm_config: VllmConfig
) -> nn.Module:
assert tensorizer_config.hf_config is not None
model_args = tensorizer_config.hf_config
model_args.dtype = tensorizer_config.dtype
assert tensorizer_config.model_class is not None
# TODO: Do we need to consider old-style model class?
with meta_tensor_mode(), set_current_vllm_config(vllm_config, check_compile=True):
return tensorizer_config.model_class(vllm_config=vllm_config)
def deserialize_tensorizer_model(
model: nn.Module, tensorizer_config: TensorizerConfig
) -> None:
tensorizer_args = tensorizer_config._construct_tensorizer_args()
if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
raise ValueError(
f"{tensorizer_config.tensorizer_uri} is not a valid "
f"tensorizer URI. Please check that the URI is correct. "
f"It must either point to a local existing file, or have a "
f"S3, HTTP or HTTPS scheme."
)
before_mem = get_mem_usage()
start = time.perf_counter()
with (
open_stream(
tensorizer_config.tensorizer_uri, mode="rb", **tensorizer_args.stream_kwargs
) as stream,
TensorDeserializer(
stream,
dtype=tensorizer_config.dtype,
device=f"xpu:{torch.xpu.current_device()}"
if current_platform.is_xpu()
else f"cuda:{torch.cuda.current_device()}",
**tensorizer_args.deserialization_kwargs,
) as deserializer,
):
deserializer.load_into_module(model)
end = time.perf_counter()
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
duration = end - start
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
after_mem = get_mem_usage()
deserializer.close()
logger.info(
"Deserialized %s in %0.2fs, %s/s", total_bytes_str, end - start, per_second
)
logger.info("Memory usage before: %s", before_mem)
logger.info("Memory usage after: %s", after_mem)
_check_tensors_on_meta_device(model)
_resize_lora_embeddings(model)
del model.vllm_tensorized_marker
def tensorizer_weights_iterator(
tensorizer_args: "TensorizerArgs",
) -> Generator[tuple[str, torch.Tensor], None, None]:
logger.warning(
"Deserializing HuggingFace models is not optimized for "
"loading on vLLM, as tensorizer is forced to load to CPU. "
"Consider deserializing a vLLM model instead for faster "
"load times. See the "
"examples/others/tensorize_vllm_model.py example script "
"for serializing vLLM models."
)
deserializer_args = tensorizer_args.deserialization_kwargs
stream_kwargs = tensorizer_args.stream_kwargs
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs)
with TensorDeserializer(stream, **deserializer_args, device="cpu") as state:
yield from state.items()
del state
def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
"""
Infer if the model is a vLLM model by checking the weights for
a vLLM tensorized marker.
Args:
tensorizer_config: The TensorizerConfig object containing the
tensorizer_uri to the serialized model.
Returns:
bool: True if the model is a vLLM model, False otherwise.
"""
tensorizer_args = tensorizer_config._construct_tensorizer_args()
deserializer = TensorDeserializer(
open_stream(tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs),
**tensorizer_args.deserialization_kwargs,
lazy_load=True,
)
if tensorizer_config.vllm_tensorized:
logger.warning(
"Please note that newly serialized vLLM models are automatically "
"inferred as vLLM models, so setting vllm_tensorized=True is "
"only necessary for models serialized prior to this change."
)
return True
return ".vllm_tensorized_marker" in deserializer
def serialize_extra_artifacts(
tensorizer_args: TensorizerArgs, served_model_name: str | list[str] | None
) -> None:
if not isinstance(served_model_name, str):
raise ValueError(
f"served_model_name must be a str for serialize_extra_artifacts, "
f"not {type(served_model_name)}."
)
with tempfile.TemporaryDirectory() as tmpdir:
snapshot_download(
served_model_name,
local_dir=tmpdir,
ignore_patterns=[
"*.pt",
"*.safetensors",
"*.bin",
"*.cache",
"*.gitattributes",
"*.md",
],
)
for artifact in os.scandir(tmpdir):
if not artifact.is_file():
continue
with (
open(artifact.path, "rb") as f,
open_stream(
f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
mode="wb+",
**tensorizer_args.stream_kwargs,
) as stream,
):
logger.info("Writing artifact %s", artifact.name)
stream.write(f.read())
def serialize_vllm_model(
model: nn.Module,
tensorizer_config: TensorizerConfig,
model_config: "ModelConfig",
) -> nn.Module:
model.register_parameter(
"vllm_tensorized_marker",
nn.Parameter(torch.tensor((1,), device="meta"), requires_grad=False),
)
tensorizer_args = tensorizer_config._construct_tensorizer_args()
encryption_params = None
if (keyfile := tensorizer_config.encryption_keyfile) is not None:
with open(keyfile, "rb") as f:
key = f.read()
encryption_params = EncryptionParams(key=key)
output_file = tensorizer_args.tensorizer_uri
if tensorizer_config._is_sharded:
from vllm.distributed import get_tensor_model_parallel_rank
output_file = output_file % get_tensor_model_parallel_rank()
with open_stream(
output_file, mode="wb+", **tensorizer_args.stream_kwargs
) as stream:
serializer = TensorSerializer(
stream,
encryption=encryption_params,
**tensorizer_config.serialization_kwargs,
)
serializer.write_module(model)
serializer.close()
serialize_extra_artifacts(tensorizer_args, model_config.served_model_name)
logger.info("Successfully serialized model to %s", str(output_file))
return model
def tensorize_vllm_model(
engine_args: "EngineArgs",
tensorizer_config: TensorizerConfig,
generate_keyfile: bool = True,
):
"""Utility to load a model and then serialize it with Tensorizer
Intended to be used separately from running a vLLM server since it
creates its own Engine instance.
"""
engine_config = engine_args.create_engine_config()
tensorizer_config.verify_with_model_config(engine_config.model_config)
tensorizer_config.verify_with_parallel_config(engine_config.parallel_config)
# generate the encryption key before creating the engine to support sharding
if (
generate_keyfile
and (keyfile := tensorizer_config.encryption_keyfile) is not None
):
encryption_params = EncryptionParams.random()
with open_stream(
keyfile,
mode="wb+",
s3_access_key_id=tensorizer_config.s3_access_key_id,
s3_secret_access_key=tensorizer_config.s3_secret_access_key,
s3_endpoint=tensorizer_config.s3_endpoint,
) as stream:
stream.write(encryption_params.key)
from vllm.v1.engine.llm_engine import LLMEngine
engine = LLMEngine.from_vllm_config(engine_config)
engine.collective_rpc(
"save_tensorized_model",
kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
)
def tensorize_lora_adapter(lora_path: str, tensorizer_config: TensorizerConfig):
"""
Uses tensorizer to serialize a LoRA adapter. Assumes that the files
needed to load a LoRA adapter are a safetensors-format file called
adapter_model.safetensors and a json config file called adapter_config.json.
Serializes the files in the tensorizer_config.tensorizer_dir
"""
import safetensors
from vllm.lora.utils import get_adapter_absolute_path
lora_dir = get_adapter_absolute_path(lora_path)
tensor_path = config_path = ""
for file in os.listdir(lora_dir):
if file.startswith("adapter_model"):
tensor_path = lora_dir + "/" + file
if file.startswith("adapter_config"):
config_path = lora_dir + "/" + file
if tensor_path and config_path:
break
if tensor_path.endswith(".safetensors"):
tensors = safetensors.torch.load_file(tensor_path)
elif tensor_path.endswith(".bin"):
tensors = torch.load(tensor_path)
else:
raise ValueError("Unsupported file: %s", tensor_path)
with open(config_path) as f:
config = json.load(f)
tensorizer_args = tensorizer_config._construct_tensorizer_args()
with open_stream(
f"{tensorizer_config.tensorizer_dir}/adapter_config.json",
mode="wb+",
**tensorizer_args.stream_kwargs,
) as f:
f.write(json.dumps(config).encode("utf-8"))
lora_uri = f"{tensorizer_config.tensorizer_dir}/adapter_model.tensors"
with open_stream(lora_uri, mode="wb+", **tensorizer_args.stream_kwargs) as f:
serializer = TensorSerializer(f)
serializer.write_state_dict(tensors)
serializer.close()
logger.info(
"Successfully serialized LoRA files to %s",
str(tensorizer_config.tensorizer_dir),
)

View File

@@ -0,0 +1,151 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117
import copy
from collections.abc import Generator
import torch
from torch import nn
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig,
deserialize_tensorizer_model,
init_tensorizer_model,
is_vllm_tensorized,
serialize_vllm_model,
tensorizer_weights_iterator,
)
from vllm.model_executor.model_loader.utils import (
get_model_architecture,
initialize_model,
)
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
BLACKLISTED_TENSORIZER_ARGS = {
"device", # vLLM decides this
"dtype", # vLLM decides this
"mode", # Not meant to be configurable by the user
}
def validate_config(config: dict):
for k, v in config.items():
if v is not None and k in BLACKLISTED_TENSORIZER_ARGS:
raise ValueError(f"{k} is not an allowed Tensorizer argument.")
class TensorizerLoader(BaseModelLoader):
"""Model loader using CoreWeave's tensorizer library."""
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
self.tensorizer_config = load_config.model_loader_extra_config
else:
validate_config(load_config.model_loader_extra_config)
self.tensorizer_config = TensorizerConfig(
**load_config.model_loader_extra_config["tensorizer_config"]
)
def _verify_config(
self, model_config: ModelConfig, parallel_config: ParallelConfig
):
self.tensorizer_config.verify_with_model_config(model_config)
self.tensorizer_config.verify_with_parallel_config(parallel_config)
def _get_weights_iterator(
self,
) -> Generator[tuple[str, torch.Tensor], None, None]:
tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
return tensorizer_weights_iterator(tensorizer_args)
def _load_model_serialized_cpu(
self,
vllm_config: VllmConfig,
) -> nn.Module:
"""Load a serialized model with tensorizer to the CPU.
This is only necessary when the model isn't vLLM-tensorized (see
examples/others/tensorize_vllm_model.py) This should still
be faster than default HuggingFace loading, but will be slower than
loading a vLLM-tensorized model.
"""
device_config = vllm_config.device_config
model_config = vllm_config.model_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = initialize_model(vllm_config=vllm_config)
model.load_weights(self._get_weights_iterator())
return model.eval()
def download_model(self, model_config: ModelConfig) -> None:
self.tensorizer_config.verify_with_model_config(model_config)
with self.tensorizer_config.open_stream():
pass
def _patch_tensorizer_config(self, model_config: ModelConfig) -> TensorizerConfig:
model_class = get_model_architecture(model_config)[0]
tensorizer_config = copy.copy(self.tensorizer_config)
tensorizer_config.model_class = model_class
tensorizer_config.hf_config = model_config.hf_config
tensorizer_config.dtype = model_config.dtype
return tensorizer_config
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
"""Load serialized model weights with tensorizer.
Expects a vLLM-tensorized model. See the
examples/others/tensorize_vllm_model.py example script
for serializing vLLM models."""
if is_vllm_tensorized(self.tensorizer_config):
tensorizer_config = self._patch_tensorizer_config(model_config)
deserialize_tensorizer_model(model, tensorizer_config)
else:
model.load_weights(self._get_weights_iterator())
def load_model(
self, vllm_config: VllmConfig, model_config: ModelConfig
) -> nn.Module:
parallel_config = vllm_config.parallel_config
self._verify_config(model_config, parallel_config)
if parallel_config.tensor_parallel_size > 1:
from vllm.distributed import get_tensor_model_parallel_rank
self.tensorizer_config.tensorizer_uri = (
self.tensorizer_config.tensorizer_uri % get_tensor_model_parallel_rank()
)
if is_vllm_tensorized(self.tensorizer_config):
tensorizer_config = self._patch_tensorizer_config(model_config)
device_config = vllm_config.device_config
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = init_tensorizer_model(
tensorizer_config=tensorizer_config, vllm_config=vllm_config
)
self.load_weights(model, model_config)
return model
return self._load_model_serialized_cpu(vllm_config=vllm_config)
@staticmethod
def save_model(
model: torch.nn.Module,
tensorizer_config: TensorizerConfig | dict,
model_config: ModelConfig,
) -> None:
if isinstance(tensorizer_config, dict):
tensorizer_config = TensorizerConfig(**tensorizer_config)
serialize_vllm_model(
model=model,
tensorizer_config=tensorizer_config,
model_config=model_config,
)

View File

@@ -0,0 +1,118 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import time
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
from vllm.logger import init_logger
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model,
process_weights_after_loading,
)
from vllm.utils.torch_utils import set_default_torch_dtype
logger = init_logger(__name__)
class TPUModelLoader(DefaultModelLoader):
"""
A TPU model loader for model loading under SPMD mode.
"""
def load_model(
self,
vllm_config: VllmConfig,
model_config: ModelConfig,
mesh: xs.Mesh | None = None,
) -> nn.Module:
# Initialize model and load weights on CPU. Then, during SPMD partition,
# weights are sharded and transferred to TPUs.
self.counter_before_loading_weights = time.perf_counter()
model_config = vllm_config.model_config
assert model_config.quantization is None, "Quantization not supported"
target_device = torch.device("cpu")
with set_default_torch_dtype(model_config.dtype):
with target_device:
model = initialize_model(vllm_config=vllm_config)
load_format = vllm_config.load_config.load_format
if load_format != "dummy":
weights_to_load = {name for name, _ in model.named_parameters()}
all_weights = self.get_all_weights(model_config, model)
loaded_weights = model.load_weights(all_weights)
self.counter_after_loading_weights = time.perf_counter()
logger.info(
"Loading weights took %.2f seconds",
self.counter_after_loading_weights
- self.counter_before_loading_weights,
)
# We only enable strict check for non-quantized models
# that have loaded weights tracking currently.
if model_config.quantization is None and loaded_weights is not None:
weights_not_loaded = weights_to_load - loaded_weights
if weights_not_loaded:
raise ValueError(
"Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}"
)
else:
logger.info("Use dummy weight during weight loading.")
process_weights_after_loading(model, model_config, target_device)
counter_before_partition = time.perf_counter()
model = model.eval()
model = model.to("xla")
shard_model(model, mesh)
counter_after_partition = time.perf_counter()
logger.info(
"Partition model took %.2f seconds",
counter_after_partition - counter_before_partition,
)
# Ensure the model is properly loaded.
self._check_model_is_loaded(mesh, model)
# Need to torch compile after model sharding are done. Because the
# compiler hints ('xs.mark_sharding') are torch ops.
if not model_config.is_multimodal_model:
model.model = torch.compile(model.model, backend="openxla")
else:
model.language_model.model = torch.compile(
model.language_model.model, backend="openxla"
)
return model
def _check_model_is_loaded(self, mesh: xs.Mesh | None, model: nn.Module) -> None:
"""
Ensure the model is properly loaded.
1. All model parameters and buffers are on XLA device.
2. Non-SPMD friendly layers are replaced as expected.
"""
device = xm.xla_device()
device_type = str(device.type)
# Check parameters
for name, param in model.named_parameters():
assert param.device.type == device_type, (
f"Parameter {name} is on {param.device.type} instead of {device_type}"
)
# Check buffers
for name, buffer in model.named_buffers():
assert buffer.device.type == device_type, (
f"Buffer {name} is on {buffer.device.type} instead of {device_type}"
)
for module in model.modules():
if (mesh is not None) and (get_fqn(module) == "QKVParallelLinear"):
raise AssertionError(
"QKVParallelLinear should be replaced by \
XlaQKVParallelLinear under SPMD mode."
)

View File

@@ -0,0 +1,288 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading models."""
import inspect
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
import torch
from torch import nn
from typing_extensions import assert_never
from vllm.attention import Attention
from vllm.attention.layer import MLAAttention
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from vllm.model_executor.models.adapters import (
as_embedding_model,
as_reward_model,
as_seq_cls_model,
try_create_mm_pooling_model_cls,
)
from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal
from vllm.utils.platform_utils import is_pin_memory_available
logger = init_logger(__name__)
def initialize_model(
vllm_config: VllmConfig,
*,
prefix: str = "",
model_class: type[nn.Module] | None = None,
model_config: ModelConfig | None = None,
) -> nn.Module:
"""Initialize a model with the given configurations."""
if model_config is None:
model_config = vllm_config.model_config
if model_class is None:
model_class, _ = get_model_architecture(model_config)
if vllm_config.quant_config is not None:
configure_quant_config(vllm_config.quant_config, model_class)
signatures = inspect.signature(model_class.__init__)
all_params = [param.name for param in signatures.parameters.values()]
if "vllm_config" in all_params and "prefix" in all_params:
# new-style model class
with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
return model_class(vllm_config=vllm_config, prefix=prefix)
msg = (
"vLLM model class should accept `vllm_config` and `prefix` as "
"input arguments. Possibly you have an old-style model class"
" registered from out of tree and it is used for new vLLM version. "
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
"for the design and update the model class accordingly."
)
warnings.warn(msg, DeprecationWarning, stacklevel=2)
logger.warning(
"Trying to guess the arguments for old-style model class %s",
model_class,
)
# try to be compatible with old-style model class
kwargs = {}
if "prefix" in all_params:
kwargs["prefix"] = prefix
if "config" in all_params:
kwargs["config"] = model_config.hf_config
if "cache_config" in all_params:
kwargs["cache_config"] = vllm_config.cache_config
if "quant_config" in all_params:
kwargs["quant_config"] = vllm_config.quant_config
if "lora_config" in all_params:
kwargs["lora_config"] = vllm_config.lora_config
if "scheduler_config" in all_params:
kwargs["scheduler_config"] = vllm_config.scheduler_config
with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
return model_class(**kwargs)
def process_weights_after_loading(
model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
# to avoid circular dependency
from vllm.model_executor.model_loader.online_quantization import (
maybe_save_metadata_and_attributes_for_weight_reloading,
)
maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if isinstance(quant_method, QuantizeMethodBase):
# When quant methods need to process weights after loading
# (for repacking, quantizing, etc), they expect parameters
# to be on the global target device. This scope is for the
# case where cpu offloading is used, where we will move the
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
# Initialize post-load attention weights for both Attention and MLA.
# NOTE: Happens after other modules so we can easily decompress weights.
for _, module in model.named_modules():
if isinstance(module, (Attention, MLAAttention)) and hasattr(
module, "process_weights_after_loading"
):
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module.process_weights_after_loading(model_config.dtype)
@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
if target_device.type == "cpu":
# If target is CPU, no need to move anything
yield module
return
original_device_states: dict[str, torch.device] = {}
# Store original device states and move parameters to GPU if they're on CPU
for name, p in module.named_parameters():
if p.device.type == "cpu":
original_device_states[name] = p.device
p.data = p.data.to(target_device)
# Parameters already on target device are not touched
try:
yield module
finally:
# Restore parameters to their original devices, ignoring new parameters
pin_memory = is_pin_memory_available()
for name, p in module.named_parameters():
if name in original_device_states:
original_device: torch.device = original_device_states[name]
if original_device.type == "cpu":
# `torch.empty_like` does not support `pin_memory` argument
cpu_data = torch.empty_strided(
size=p.data.size(),
stride=p.data.stride(),
dtype=p.data.dtype,
layout=p.data.layout,
device="cpu",
pin_memory=pin_memory,
)
cpu_data.copy_(p.data)
p.data = cpu_data
else:
p.data = p.data.to(original_device)
# New parameters or parameters already on target device are untouched
_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
"""Caches the outputs of `_get_model_architecture`."""
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
model_cls, arch = model_config.registry.resolve_model_cls(
architectures,
model_config=model_config,
)
if arch == model_config._get_transformers_backend_cls():
assert model_config.model_impl != "vllm"
if model_config.model_impl == "auto":
logger.warning_once(
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.",
arch,
)
convert_type = model_config.convert_type
if convert_type != "none" and supports_multimodal(model_cls):
logger.debug_once("Detected conversion of Multi Modal model.")
converted = try_create_mm_pooling_model_cls(model_cls)
if converted is not None:
logger.debug_once("Creating wrapper class to forward pooler.")
return converted, arch
else:
logger.debug_once("Attempting direct conversion.")
if convert_type == "none":
pass
elif convert_type == "embed":
logger.debug_once("Converting to embedding model.")
model_cls = as_embedding_model(model_cls)
elif convert_type == "classify":
logger.debug_once("Converting to sequence classification model.")
model_cls = as_seq_cls_model(model_cls)
elif convert_type == "reward":
logger.debug_once("Converting to reward model.")
model_cls = as_reward_model(model_cls)
else:
assert_never(convert_type)
return model_cls, arch
def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
key = hash(
(
model_config.model,
model_config.convert_type,
model_config.runner_type,
model_config.trust_remote_code,
model_config.model_impl,
tuple(getattr(model_config.hf_config, "architectures", [])),
)
)
if key in _MODEL_ARCH_BY_HASH:
return _MODEL_ARCH_BY_HASH[key]
model_arch = _get_model_architecture(model_config)
_MODEL_ARCH_BY_HASH[key] = model_arch
return model_arch
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
return get_model_architecture(model_config)[0]
def get_architecture_class_name(model_config: ModelConfig) -> str:
return get_model_architecture(model_config)[1]
@dataclass
class ParamMapping:
"""
A class to handle parameter mapping for model weight loading.
It creates a bidirectional mapping between packed parameters and their
constituent parts.
"""
packed_mapping: dict[str, list[str]]
inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)
def __post_init__(self):
for packed_name, sub_params in self.packed_mapping.items():
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
if len(sub_params) == 1 and sub_params[0] == packed_name:
continue
for index, param_name in enumerate(sub_params):
self.inverse_packed_mapping[param_name] = (
packed_name,
index,
)
def get_sub_modules(self, module_name: str) -> tuple[str, list[str]] | None:
for key, value in self.packed_mapping.items():
if module_name.endswith(key):
return key, value
return None
def configure_quant_config(
quant_config: QuantizationConfig, model_class: type[nn.Module]
):
"""
Pass packed_modules_mapping by reference to quant_config so that
quant_config can properly match fused modules
Note that model attributes are passed by reference to quant_config,
enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
Once the `SupportsQuant` mixin has been added to all models, this
function can be removed
"""
if not issubclass(model_class, SupportsQuant):
hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
packed_mapping = getattr(model_class, "packed_modules_mapping", None)
# pass mappings by reference to quant_config
if hf_to_vllm_mapper is not None:
quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
if packed_mapping is not None:
quant_config.packed_modules_mapping = packed_mapping

File diff suppressed because it is too large Load Diff