Add minimal vLLM 0.16.1 build repo for BI-V150
156
vllm/model_executor/model_loader/__init__.py
Normal file
@@ -0,0 +1,156 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from typing import Literal

from torch import nn

from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.bitsandbytes_loader import BitsAndBytesModelLoader
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
from vllm.model_executor.model_loader.runai_streamer_loader import (
    RunaiModelStreamerLoader,
)
from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
from vllm.model_executor.model_loader.utils import (
    get_architecture_class_name,
    get_model_architecture,
    get_model_cls,
)

logger = init_logger(__name__)

# Reminder: Please update docstring in `LoadConfig`
# if a new load format is added here
LoadFormats = Literal[
    "auto",
    "hf",
    "bitsandbytes",
    "dummy",
    "fastsafetensors",
    "gguf",
    "mistral",
    "npcache",
    "pt",
    "runai_streamer",
    "runai_streamer_sharded",
    "safetensors",
    "sharded_state",
    "tensorizer",
]
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
    "auto": DefaultModelLoader,
    "hf": DefaultModelLoader,
    "bitsandbytes": BitsAndBytesModelLoader,
    "dummy": DummyModelLoader,
    "fastsafetensors": DefaultModelLoader,
    "gguf": GGUFModelLoader,
    "mistral": DefaultModelLoader,
    "npcache": DefaultModelLoader,
    "pt": DefaultModelLoader,
    "runai_streamer": RunaiModelStreamerLoader,
    "runai_streamer_sharded": ShardedStateLoader,
    "safetensors": DefaultModelLoader,
    "sharded_state": ShardedStateLoader,
    "tensorizer": TensorizerLoader,
}


def register_model_loader(load_format: str):
    """Register a customized vllm model loader.

    When a load format is not supported by vllm, you can register a customized
    model loader to support it.

    Args:
        load_format (str): The model loader format name.

    Examples:
        >>> from vllm.config.load import LoadConfig
        >>> from vllm.model_executor.model_loader import (
        ...     get_model_loader,
        ...     register_model_loader,
        ... )
        >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
        >>>
        >>> @register_model_loader("my_loader")
        ... class MyModelLoader(BaseModelLoader):
        ...     def download_model(self):
        ...         pass
        ...
        ...     def load_weights(self):
        ...         pass
        >>>
        >>> load_config = LoadConfig(load_format="my_loader")
        >>> type(get_model_loader(load_config))
        <class 'MyModelLoader'>
    """  # noqa: E501

    def _wrapper(model_loader_cls):
        if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
            logger.warning(
                "Load format `%s` is already registered, and will be "
                "overwritten by the new loader class `%s`.",
                load_format,
                model_loader_cls,
            )
        if not issubclass(model_loader_cls, BaseModelLoader):
            raise ValueError(
                "The model loader must be a subclass of `BaseModelLoader`."
            )
        _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
        logger.info(
            "Registered model loader `%s` with load format `%s`",
            model_loader_cls,
            load_format,
        )
        return model_loader_cls

    return _wrapper


def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format."""
    load_format = load_config.load_format
    if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER:
        raise ValueError(f"Load format `{load_format}` is not supported")
    return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config)


def get_model(
    *,
    vllm_config: VllmConfig,
    model_config: ModelConfig | None = None,
    prefix: str = "",
    load_config: LoadConfig | None = None,
) -> nn.Module:
    loader = get_model_loader(load_config or vllm_config.load_config)
    if model_config is None:
        model_config = vllm_config.model_config
    return loader.load_model(
        vllm_config=vllm_config, model_config=model_config, prefix=prefix
    )


__all__ = [
    "get_model",
    "get_model_loader",
    "get_architecture_class_name",
    "get_model_architecture",
    "get_model_cls",
    "register_model_loader",
    "BaseModelLoader",
    "BitsAndBytesModelLoader",
    "GGUFModelLoader",
    "DefaultModelLoader",
    "DummyModelLoader",
    "RunaiModelStreamerLoader",
    "ShardedStateLoader",
    "TensorizerLoader",
]
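The registry above is a plain dict keyed by load format, so loader resolution is a single lookup. A minimal usage sketch (hypothetical, not part of this commit; assumes `LoadConfig` defaults are sufficient):

    from vllm.config.load import LoadConfig
    from vllm.model_executor.model_loader import get_model_loader

    # "dummy" maps to DummyModelLoader in _LOAD_FORMAT_TO_MODEL_LOADER.
    load_config = LoadConfig(load_format="dummy")
    loader = get_model_loader(load_config)
    print(type(loader).__name__)  # -> DummyModelLoader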
86
vllm/model_executor/model_loader/base_loader.py
Normal file
@@ -0,0 +1,86 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod

import torch
import torch.nn as nn

import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.utils import (
    initialize_model,
    process_weights_after_loading,
)
from vllm.platforms import current_platform
from vllm.tracing import instrument
from vllm.utils.mem_utils import format_gib
from vllm.utils.torch_utils import set_default_torch_dtype

logger = init_logger(__name__)


class BaseModelLoader(ABC):
    """Base class for model loaders."""

    def __init__(self, load_config: LoadConfig):
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Download a model so that it can be immediately loaded."""
        raise NotImplementedError

    @abstractmethod
    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load weights into a model. This standalone API allows
        inplace weights loading for an already-initialized model"""
        raise NotImplementedError

    @instrument(span_name="Load model")
    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
    ) -> nn.Module:
        """Load a model with the given configurations."""
        device_config = vllm_config.device_config
        load_config = vllm_config.load_config
        load_device = (
            device_config.device if load_config.device is None else load_config.device
        )
        target_device = torch.device(load_device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(
                    vllm_config=vllm_config, model_config=model_config, prefix=prefix
                )

            log_model_inspection(model)

            logger.debug("Loading weights on %s ...", load_device)
            # Quantization does not happen in `load_weights` but after it
            self.load_weights(model, model_config)

            # Log peak GPU memory after loading weights. This is needed
            # to have test coverage on peak memory for online quantization.
            if current_platform.is_cuda():
                peak_memory = torch.cuda.max_memory_allocated()
                logger.debug_once(
                    "Peak GPU memory after loading weights: %s GiB",
                    format_gib(peak_memory),
                    scope="local",
                )

            process_weights_after_loading(model, model_config, target_device)

        return model.eval()


def log_model_inspection(model: nn.Module) -> None:
    """Log model structure if VLLM_LOG_MODEL_INSPECTION=1."""
    if not envs.VLLM_LOG_MODEL_INSPECTION:
        return

    from vllm.model_inspection import format_model_inspection

    logger.info("vLLM model structure:\n%s", format_model_inspection(model))
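`load_model` is a template method: subclasses implement only `download_model` and `load_weights`, and inherit the device/dtype setup, logging, and post-load weight processing. A minimal sketch of the contract (hypothetical `NoopLoader`, not part of this commit):

    import torch.nn as nn

    from vllm.config import ModelConfig
    from vllm.config.load import LoadConfig
    from vllm.model_executor.model_loader.base_loader import BaseModelLoader


    class NoopLoader(BaseModelLoader):
        def download_model(self, model_config: ModelConfig) -> None:
            pass  # nothing to fetch

        def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
            pass  # keep the initialized weights as-is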
817
vllm/model_executor/model_loader/bitsandbytes_loader.py
Normal file
@@ -0,0 +1,817 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import fnmatch
import glob
import itertools
import math
import os
from collections.abc import Callable, Generator
from typing import Any

import numpy as np
import torch
from huggingface_hub import HfApi
from packaging import version
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.lora.utils import is_moe_model
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import (
    LinearBase,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import ParamMapping
from vllm.model_executor.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf,
    download_weights_from_hf,
    filter_duplicate_safetensors_files,
    filter_files_not_needed_for_inference,
    pt_weights_iterator,
    safetensors_weights_iterator,
)
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.utils import (
    get_moe_expert_mapping,
    get_packed_modules_mapping,
    set_weight_attrs,
)
from vllm.platforms import current_platform
from vllm.utils.torch_utils import set_default_torch_dtype

logger = init_logger(__name__)


class BitsAndBytesModelLoader(BaseModelLoader):
    """Model loader to load model weights with BitsAndBytes quantization."""

    possible_config_file_names = ["adapter_config.json"]

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        # Save the module names without sharding.
        self.unsharded_weights_modules: list[str] = []
        # Save the module names that are sharded by column.
        self.column_sharded_weights_modules: list[str] = []
        # Modules whose weights might be fused on disk;
        # we need their output_sizes to shard them in flight correctly with TP.
        self.maybe_fused_weights_modules: dict[str, list[int]] = {}
        # Store all module names (from transformers) that support
        # BNB quantization.
        self.target_modules: list[str] = []
        self.tp_disabled_modules: list[str] = []
        # Store the mapping of expert parameters for MoE models.
        self.expert_params_mapping: list[tuple[str, str, int, str]] = []
        # Mapping of weight names from transformers to vllm.
        self.weight_mapper: Callable = lambda name: name
        self.pre_quant: bool = False
        self.load_8bit: bool = False
        self.is_pool_model: bool = False

    def _get_weight_files(
        self,
        model_name_or_path: str,
        allowed_patterns: list[str],
        revision: str | None = None,
    ) -> tuple[str, list[str], str]:
        """Retrieve weight files. Download the files if necessary.

        Return the weight files and the file pattern."""
        is_local = os.path.isdir(model_name_or_path)

        if is_local:
            for pattern in allowed_patterns:
                weight_files = glob.glob(os.path.join(model_name_or_path, pattern))
                if weight_files:
                    return model_name_or_path, weight_files, pattern
        else:
            hf_api = HfApi()
            repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
            for pattern in allowed_patterns:
                matching_files = fnmatch.filter(repo_files, pattern)
                if matching_files:
                    hf_folder = download_weights_from_hf(
                        model_name_or_path,
                        self.load_config.download_dir,
                        [pattern],
                        revision,
                        ignore_patterns=self.load_config.ignore_patterns,
                    )
                    return (
                        hf_folder,
                        glob.glob(os.path.join(hf_folder, pattern)),
                        pattern,
                    )

        raise RuntimeError(f"No model weights found in: `{model_name_or_path}`")

    def _prepare_weights(
        self, model_name_or_path: str, revision: str | None
    ) -> tuple[list[str], bool]:
        """Prepare weight files for the model."""

        allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

        hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
            model_name_or_path, allowed_patterns, revision
        )

        use_safetensors = matched_pattern == "*.safetensors"
        is_local = os.path.isdir(model_name_or_path)
        index_file = SAFE_WEIGHTS_INDEX_NAME
        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file
            )
        else:
            hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`"
            )

        return hf_weights_files, use_safetensors

    def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
        def _maybe_pool_model(module_name: str):
            # For a pooling model, we need to add the prefix `model.`
            # to the weight name if possible.
            if (
                self.is_pool_model
                and self.target_modules[0].startswith("model.")
                and not module_name.startswith("model.")
            ):
                return "model." + module_name

            return module_name

        if use_safetensors:
            iterator = safetensors_weights_iterator(
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )
        else:
            iterator = pt_weights_iterator(
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
                self.load_config.pt_load_map_location,
            )
        for org_name, param in iterator:
            # Map weight names from transformers to vllm while preserving
            # the original names.
            mapped_name = self.weight_mapper(org_name)
            mapped_name = _maybe_pool_model(mapped_name)

            yield org_name, mapped_name, param

    def _get_quantized_weights_iterator(
        self,
        model_name_or_path: str,
        revision: str | None,
    ) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, Any]]:
        """Get an iterator to the model weights with bitsandbytes quantization,
        as well as the quantization state dictionary."""

        # Only load the bitsandbytes module when needed.
        try:
            import bitsandbytes

            if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
                raise ImportError(
                    "bitsandbytes version is wrong. Please "
                    "install bitsandbytes>=0.46.1."
                )
        except ImportError as err:
            raise ImportError(
                "Please install bitsandbytes>=0.46.1 via "
                "`pip install bitsandbytes>=0.46.1` to use "
                "bitsandbytes quantizer."
            ) from err

        hf_weights_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision
        )

        quant_state_dict: dict[str, Any] = {}

        if self.pre_quant:
            if self.load_8bit:
                return self._quantized_8bit_generator(
                    hf_weights_files, use_safetensors, quant_state_dict
                ), quant_state_dict
            else:
                return self._quantized_4bit_generator(
                    hf_weights_files, use_safetensors, quant_state_dict
                ), quant_state_dict

        return self._unquantized_generator(
            hf_weights_files, use_safetensors, quant_state_dict
        ), quant_state_dict

    def _is_8bit_weight_name(self, weight_name: str):
        quantized_suffix = {".scb", ".weight_format"}
        return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)

    def _is_4bit_weight_name(self, weight_name: str):
        quantized_suffix = {
            "absmax",
            "quant_map",
            "nested_absmax",
            "nested_quant_map",
            "bitsandbytes",
        }
        suffix = weight_name.split(".")[-1]
        return any(q_suffix in suffix for q_suffix in quantized_suffix)

    def _quantized_8bit_generator(
        self, hf_weights_files, use_safetensors, quant_state_dict
    ) -> Generator:
        for (
            org_weight_name,
            mapped_weight_name,
            weight_tensor,
        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
            if not mapped_weight_name.lower().endswith(".scb"):
                continue

            weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
            quant_state_dict[weight_key] = weight_tensor

        for (
            org_weight_name,
            mapped_weight_name,
            weight_tensor,
        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
            if self._is_8bit_weight_name(mapped_weight_name):
                continue

            if mapped_weight_name in quant_state_dict:
                set_weight_attrs(weight_tensor, {"load_in_8bit": True})
                yield org_weight_name, weight_tensor
            else:
                yield org_weight_name, weight_tensor

    def _quantized_4bit_generator(
        self, hf_weights_files, use_safetensors, quant_state_dict
    ) -> Generator:
        from bitsandbytes.functional import QuantState

        # First iterate over all quant state weights.
        weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
        temp_state_dict = {}
        for (
            org_weight_name,
            mapped_weight_name,
            weight_tensor,
        ) in weight_iterator:
            if not self._is_4bit_weight_name(mapped_weight_name):
                continue
            # The bitsandbytes library requires
            # weight.quant_state.bitsandbytes__* to be on CPU.
            if "quant_state.bitsandbytes" in mapped_weight_name:
                temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
            else:
                temp_state_dict[mapped_weight_name] = weight_tensor

        # Closure to parse quant_state for each prequant weight.
        def _parse_quant_state(param_name: str, temp_state_dict: dict) -> QuantState:
            quant_state = {}
            for k in temp_state_dict:
                if param_name + "." in k:
                    quant_state[k] = temp_state_dict[k]

            return QuantState.from_dict(
                quant_state, device=current_platform.device_type
            )

        # Second, iterate over all prequant and normal weights;
        # pre-quantized weights have a quant_state.
        for (
            org_weight_name,
            mapped_weight_name,
            weight_tensor,
        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
            if self._is_4bit_weight_name(mapped_weight_name):
                continue

            if (
                f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict
            ) or (
                f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
            ):
                quant_state = _parse_quant_state(mapped_weight_name, temp_state_dict)
                quant_state_dict[mapped_weight_name] = quant_state
                yield org_weight_name, weight_tensor
            else:
                yield org_weight_name, weight_tensor

    def _unquantized_generator(
        self, hf_weights_files, use_safetensors, quant_state_dict
    ) -> Generator:
        from bitsandbytes.functional import quantize_4bit

        global_tp_size = get_tensor_model_parallel_world_size()
        global_tp_rank = get_tensor_model_parallel_rank()
        check_match = (
            lambda weight_name, module_name: weight_name.removesuffix(".weight")
            == module_name
        )
        for (
            org_weight_name,
            mapped_weight_name,
            weight_tensor,
        ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
            # Override tp_size and tp_rank if the module has disabled TP.
            if any(
                tp_disabled_module in mapped_weight_name
                for tp_disabled_module in self.tp_disabled_modules
            ):
                tp_size = 1
                tp_rank = 0
            else:
                tp_size = global_tp_size
                tp_rank = global_tp_rank

            if any(
                target_module in mapped_weight_name
                for target_module in self.target_modules
            ) and mapped_weight_name.endswith(".weight"):
                # Without sharding
                if any(
                    check_match(mapped_weight_name, module)
                    for module in self.unsharded_weights_modules
                ):
                    weight_sub_tensor = weight_tensor
                # Shard by column
                elif any(
                    check_match(mapped_weight_name, module)
                    for module in self.column_sharded_weights_modules
                ):
                    total_size = weight_tensor.size(-1)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[..., start_index:end_index]
                # Weights are fused on disk. In this case, we assume that the
                # weight and module use the same name.
                elif any(
                    check_match(mapped_weight_name, module)
                    for module in self.maybe_fused_weights_modules
                ):
                    # Special case for fused weights:
                    # get the size of each shard weight tensor.
                    total_shard_sizes = next(
                        (
                            sizes
                            for module, sizes in self.maybe_fused_weights_modules.items()  # noqa: E501
                            if check_match(mapped_weight_name, module)
                        )
                    )
                    total_size = weight_tensor.size(0)
                    assert total_size == sum(total_shard_sizes)
                    # Get the start/end index of each shard weight tensor.
                    total_start_index = list(
                        itertools.accumulate([0] + total_shard_sizes)
                    )[:-1]
                    shard_weights_index = [
                        (
                            idx + size // tp_size * tp_rank,
                            idx + size // tp_size * (tp_rank + 1),
                        )
                        for idx, size in zip(total_start_index, total_shard_sizes)
                    ]
                    # Slice and reorder the weight tensor.
                    weight_tensor = [
                        weight_tensor[start_index:end_index, ...]
                        for start_index, end_index in shard_weights_index
                    ]
                    weight_sub_tensor = torch.cat(weight_tensor, dim=0)
                # Shard by row
                else:
                    total_size = weight_tensor.size(0)
                    start_index = total_size // tp_size * tp_rank
                    end_index = total_size // tp_size * (tp_rank + 1)
                    weight_sub_tensor = weight_tensor[start_index:end_index, ...]

                # bitsandbytes requires data on GPU.
                if weight_sub_tensor.is_cuda:
                    loaded_weight = weight_sub_tensor
                else:
                    loaded_weight = weight_sub_tensor.to(
                        device=current_platform.device_type
                    )

                # Remove the following after the issue is fixed:
                # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
                if loaded_weight.is_contiguous() is False:
                    loaded_weight = loaded_weight.contiguous()

                with set_default_torch_dtype(torch.float32):
                    processed_weight, quant_state = quantize_4bit(
                        loaded_weight,
                        compress_statistics=True,
                        quant_type="nf4",
                    )

                quant_state_dict[mapped_weight_name] = quant_state
            else:
                processed_weight = weight_tensor
            yield org_weight_name, processed_weight

    def _get_bnb_target_modules(self, model: nn.Module) -> None:
        """
        Identify and collect all modules that support BitsAndBytes
        quantization.
        """
        for name, module in model.named_modules():
            if isinstance(module, LinearBase) and hasattr(
                module.quant_method, "quant_config"
            ):
                if modules_info := self.modules_mapping.get_sub_modules(name):
                    # Map vllm's names to transformers's names.
                    rep_name, sub_modules = modules_info
                    for sub_name in sub_modules:
                        new_name = name.replace(rep_name, sub_name)
                        self.target_modules.append(new_name)
                        if module.disable_tp:
                            self.tp_disabled_modules.append(new_name)
                # Add the original module name even if the module has a stacked
                # map, in case the model has a mixture of disk-merged and
                # disk-split weights with the same last name.
                self.target_modules.append(name)
                if module.disable_tp:
                    self.tp_disabled_modules.append(name)
            elif isinstance(module, FusedMoE) and hasattr(
                module.quant_method, "quant_config"
            ):
                # TODO: support FusedMoE with prequant and 8bit.
                if self.pre_quant and self.load_8bit:
                    raise ValueError(
                        "Prequant BitsAndBytes 8bit models with FusedMoE "
                        "is not supported yet."
                    )
                # Get the corresponding weight name using module name and
                # expert_params_mapping.

                for exp in self.expert_params_mapping:
                    weight_name = exp[1]
                    rep_name = name.replace("experts", "") + weight_name.removesuffix(
                        "."
                    )
                    self.target_modules.append(rep_name)

        assert self.target_modules, (
            "vLLM currently does not support BNB quantization for"
            f" {type(model).__name__}"
        )

    def _classify_module_sharding(self, model: nn.Module):
        """
        Categorize modules based on their weight sharding requirements
        for tensor parallelism.
        """
        for name, module in model.named_modules():
            # Some modules like `ReplicatedLinear` should not have their weights
            # sharded. The reason for implementing it this way is to avoid a new
            # static variable in the model implementation.
            if isinstance(module, (ReplicatedLinear,)):
                self.unsharded_weights_modules.append(name)
            # `QKVParallelLinear` and `MergedColumnParallelLinear` might have
            # fused weights on disk. We need to use the output sizes of these
            # modules to shard the weights correctly.
            elif isinstance(module, (QKVParallelLinear, MergedColumnParallelLinear)):
                self.maybe_fused_weights_modules[name] = module.output_sizes
            # In TP, these weights are partitioned along the column
            # dimension (dim=-1).
            elif isinstance(module, (RowParallelLinear,)):
                self.column_sharded_weights_modules.append(name)
            elif isinstance(module, FusedMoE):
                expert_mapping = self.expert_params_mapping
                for exp in expert_mapping:
                    if exp[-1] == "w2":
                        weight_name = exp[1]
                        rep_name = name.replace(
                            "experts", ""
                        ) + weight_name.removesuffix(".")
                        self.column_sharded_weights_modules.append(rep_name)

    def _verify_model_compatibility(
        self, model: nn.Module, model_config: ModelConfig
    ) -> None:
        """
        Verify that the model is compatible with BitsAndBytes quantization.
        """
        if not hasattr(model, "load_weights"):
            raise AttributeError(
                "The required method 'load_weights' is not defined in class"
                f" {type(model).__name__}."
            )

        if not hasattr(model, "packed_modules_mapping"):
            raise AttributeError(
                f"Model {type(model).__name__} does not support BitsAndBytes "
                "quantization yet. No 'packed_modules_mapping' found."
            )

        quant_config = getattr(model_config.hf_config, "quantization_config", None)
        if quant_config and (quant_method := quant_config.get("quant_method")):
            if quant_method == "bitsandbytes":
                self.pre_quant = True
            else:
                raise ValueError(
                    f"BitsAndBytes loader does not support {quant_method} quantization"
                )

        # The quant_states in pre_quantized models cannot work with a split
        # weight tensor. So TP does not work with pre_quantized bnb models.
        if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
            raise ValueError(
                "Prequant BitsAndBytes models with tensor parallelism is not "
                "supported. Please try with pipeline parallelism."
            )
        if quant_config and self.pre_quant:
            self.load_8bit = quant_config.get("load_in_8bit", False)

    def _initialize_loader_state(
        self, model: nn.Module, model_config: ModelConfig
    ) -> None:
        """
        Initialize the loader's internal state based on the model and
        configuration.
        """
        self.is_pool_model = is_pooling_model(model)
        self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))

        if is_moe_model(model):
            self.expert_params_mapping = get_moe_expert_mapping(model)
            if not self.expert_params_mapping:
                raise AttributeError(
                    f"MoE Model {type(model).__name__} does not support "
                    "BitsAndBytes quantization yet. Ensure this model has "
                    "'get_expert_mapping' method."
                )
        # For some models like Molmo, we need to use hf_to_vllm_mapper
        # to ensure correct loading of weights.
        if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
            self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)

        self._get_bnb_target_modules(model)
        self._classify_module_sharding(model)

    def _dequantize_dq(self, quant_states: Any):
        """
        When BNB employs Double Quantization, we perform the dequantization of
        these constants during weight loading rather than at inference time,
        thereby avoiding this computational overhead during inference. This
        comes at the cost of increased memory usage.
        """
        from bitsandbytes.functional import QuantState, dequantize_blockwise

        def _dequantize_single_state(quant_state):
            """Helper function to dequantize a single QuantState object."""
            if not (isinstance(quant_state, QuantState) and quant_state.nested):
                return

            # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
            absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
            absmax += quant_state.offset

            # Ensure float32 dtype
            if absmax.dtype != torch.float32:
                absmax = absmax.float()

            quant_state.absmax = absmax
            quant_state.nested = False
            quant_state.offset = None
            quant_state.state2 = None

        if isinstance(quant_states, dict):
            for quant_state in quant_states.values():
                _dequantize_single_state(quant_state)
        else:
            _dequantize_single_state(quant_states)
        return quant_states

    def _fuse_moe_quant_states(self, model: nn.Module, quant_states_dict: dict) -> dict:
        """
        This function consolidates individual expert quantization states into
        fused representations for w13 and w2.
        """
        from bitsandbytes.functional import QuantState

        if not self.expert_params_mapping:
            return dict()

        expert_mapping = self.expert_params_mapping
        expert_qs_dict = {}
        for name, module in model.named_modules():
            if not isinstance(module, FusedMoE):
                continue
            w1_states_lst = []
            w2_states_lst = []
            w3_states_lst = []
            for exp in expert_mapping:
                shard_id = exp[-1]
                if shard_id not in ("w1", "w2", "w3"):
                    raise ValueError(
                        f"shard_id must be ['w1','w2','w3'] but got {shard_id}."
                    )
                layer_prefix = name.split("experts")[0]
                weight_qual_name = layer_prefix + exp[1] + "weight"
                quant_state = self._dequantize_dq(quant_states_dict[weight_qual_name])
                if shard_id == "w1":
                    w1_states_lst.append(quant_state)
                elif shard_id == "w2":
                    w2_states_lst.append(quant_state)
                else:
                    w3_states_lst.append(quant_state)
                del quant_states_dict[weight_qual_name]
            assert len(w1_states_lst) == len(w2_states_lst) == len(w3_states_lst)
            w13_absmax_lst = []
            w2_absmax_lst = []
            w13_total_dim0 = 0
            w2_total_dim0 = 0
            for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst, w3_states_lst):
                assert w1_qs.shape == w3_qs.shape
                assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize
                assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype
                # w1 and w3 are interleaved in storage
                w13_absmax_lst.append(w1_qs.absmax)
                w13_absmax_lst.append(w3_qs.absmax)
                w2_absmax_lst.append(w2_qs.absmax)
                w13_total_dim0 += w1_qs.shape[0] + w3_qs.shape[0]
                w2_total_dim0 += w2_qs.shape[0]

            w13_absmax = torch.cat(w13_absmax_lst)
            w2_absmax = torch.cat(w2_absmax_lst)
            # Create fused quantization state for w13.
            w13_qs = QuantState(
                absmax=w13_absmax,
                shape=(w13_total_dim0, w1_states_lst[0].shape[1]),
                code=w1_states_lst[0].code,
                blocksize=w1_states_lst[0].blocksize,
                quant_type="nf4",
                dtype=w1_states_lst[0].dtype,
            )
            # Create fused quantization state for w2.
            w2_qs = QuantState(
                absmax=w2_absmax,
                shape=(w2_total_dim0, w2_states_lst[0].shape[1]),
                code=w2_states_lst[0].code,
                blocksize=w2_states_lst[0].blocksize,
                quant_type="nf4",
                dtype=w2_states_lst[0].dtype,
            )
            # The weight suffixes .w13_weight and .w2_weight are consistent
            # with the param in BitsAndBytesMoEMethod.
            w13_weight_name = name + ".w13_weight"
            w2_weight_name = name + ".w2_weight"
            expert_qs_dict[w13_weight_name] = w13_qs
            expert_qs_dict[w2_weight_name] = w2_qs
        return expert_qs_dict

    def _stack_quantization_states(
        self, model: nn.Module, quant_state_dict: dict
    ) -> dict[str, dict[int, Any]]:
        stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
        # TODO: Change this lazy import to a normal import
        # after the checks are updated to run on a new version.
        from vllm.model_executor.models.utils import is_pp_missing_parameter

        param_dict = dict(model.named_parameters())
        for quant_param_name in quant_state_dict:
            if is_pp_missing_parameter(quant_param_name, model):
                continue

            non_stacked_param_name = quant_param_name

            shard_index = 0
            for shard_name, (
                weight_name,
                index,
            ) in self.modules_mapping.inverse_packed_mapping.items():
                # Some models, such as MiniCPM V2.5/2.6, contain both
                # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
                # from being incorrectly identified as being present in
                # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight'.
                shard_pos = quant_param_name.find(shard_name)
                can_correct_rename = (shard_pos > 0) and (
                    quant_param_name[shard_pos - 1] == "."
                )
                # If the quant_param_name is packed, it won't occur in the
                # param_dict before renaming.
                new_quant_param_name = quant_param_name.replace(shard_name, weight_name)
                need_rename = (quant_param_name not in param_dict) and (
                    new_quant_param_name in param_dict
                )
                if can_correct_rename and need_rename:
                    shard_index = index
                    quant_param_name = new_quant_param_name
                    break

            # Models like Clip/Siglip may skip some layers in initialization,
            # causing unused quant_param_name in state_dict.
            if quant_param_name not in param_dict:
                continue

            if quant_param_name not in stacked_quant_state_dict:
                stacked_quant_state_dict[quant_param_name] = {}

            stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[
                non_stacked_param_name
            ]
        return stacked_quant_state_dict

    def _bind_quant_states_to_params(
        self, model: nn.Module, stacked_quant_state_dict: dict
    ) -> None:
        # Save quant_states and offsets as attributes of the parameters.
        param_dict = dict(model.named_parameters())
        for param_name, param in param_dict.items():
            if param_name in stacked_quant_state_dict:
                quant_states = stacked_quant_state_dict[param_name]
                # Dequantize double quantized values during weight loading.
                self._dequantize_dq(quant_states)
                set_weight_attrs(param, {"bnb_quant_state": quant_states})
                if not isinstance(quant_states, dict):
                    continue

                pack_ratio = getattr(param, "pack_factor", -1)
                if pack_ratio == -1:
                    raise ValueError(f"pack_factor not set for parameter {param_name}.")

                num_elements = [0] * len(quant_states)
                for seq, quant_state in quant_states.items():
                    num_elements[seq] = math.prod(quant_state.shape) // pack_ratio

                offsets = np.concatenate(([0], np.cumsum(num_elements)))
                # Make torch infer_schema happy
                offsets = torch.tensor(offsets).cpu()
                set_weight_attrs(param, {"bnb_shard_offsets": offsets})

                if self.load_8bit:
                    set_weight_attrs(
                        param, {"matmul_state": [None] * len(quant_states)}
                    )

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        self._verify_model_compatibility(model, model_config)
        self._initialize_loader_state(model, model_config)

        logger.info(
            "Loading weights with BitsAndBytes quantization. May take a while ..."
        )
        qweight_iterator, quant_state_dict = self._get_quantized_weights_iterator(
            model_config.model,
            model_config.revision,
        )
        weights_to_load = {name for name, _ in model.named_parameters()}
        loaded_weights = model.load_weights(qweight_iterator)
        # Some models may have the weights-loading tracker unimplemented.
        if loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError(
                    "Following weights were not initialized from "
                    f"checkpoint: {weights_not_loaded}"
                )
        expert_quant_state_dict = self._fuse_moe_quant_states(model, quant_state_dict)

        stacked_quant_state_dict = self._stack_quantization_states(
            model, quant_state_dict
        )

        stacked_quant_state_dict = {
            **expert_quant_state_dict,
            **stacked_quant_state_dict,
        }
        self._bind_quant_states_to_params(model, stacked_quant_state_dict)
        torch.cuda.empty_cache()

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model, model_config.revision)
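The row/column sharding in `_unquantized_generator` gives each tensor-parallel rank a contiguous 1/tp_size slice via integer arithmetic. A worked sketch with assumed sizes:

    # Assumed example: total_size=4096, tp_size=4.
    total_size, tp_size = 4096, 4
    for tp_rank in range(tp_size):
        start_index = total_size // tp_size * tp_rank
        end_index = total_size // tp_size * (tp_rank + 1)
        print(tp_rank, (start_index, end_index))
    # rank 0 -> (0, 1024), rank 1 -> (1024, 2048),
    # rank 2 -> (2048, 3072), rank 3 -> (3072, 4096)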
307
vllm/model_executor/model_loader/default_loader.py
Normal file
@@ -0,0 +1,307 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import glob
import os
import time
from collections.abc import Generator, Iterable
from typing import cast

import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf,
    download_weights_from_hf,
    fastsafetensors_weights_iterator,
    filter_duplicate_safetensors_files,
    filter_files_not_needed_for_inference,
    get_quant_config,
    maybe_download_from_modelscope,
    multi_thread_pt_weights_iterator,
    multi_thread_safetensors_weights_iterator,
    np_cache_weights_iterator,
    pt_weights_iterator,
    safetensors_weights_iterator,
)
from vllm.tracing import instrument
from vllm.transformers_utils.repo_utils import list_filtered_repo_files

logger = init_logger(__name__)


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    # Default number of threads when multithreaded weight loading is enabled.
    DEFAULT_NUM_THREADS = 8

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: str | None
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: list[str] | None = None
        """If defined, weights will load exclusively using these patterns."""

    counter_before_loading_weights: float = 0.0
    counter_after_loading_weights: float = 0.0

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        extra_config = load_config.model_loader_extra_config
        allowed_keys = {"enable_multithread_load", "num_threads"}
        unexpected_keys = set(extra_config.keys()) - allowed_keys

        if unexpected_keys:
            raise ValueError(
                f"Unexpected extra config keys for load format "
                f"{load_config.load_format}: "
                f"{unexpected_keys}"
            )

    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: str | None,
        fall_back_to_pt: bool,
        allow_patterns_overrides: list[str] | None,
    ) -> tuple[str, list[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""
        model_name_or_path = (
            maybe_download_from_modelscope(model_name_or_path, revision)
            or model_name_or_path
        )

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME

        # For the 'auto' format, first check whether Mistral-format files are
        # present, so that Mistral models load their official format by default.
        if load_format == "auto":
            load_format = (
                "mistral"
                if len(
                    list_filtered_repo_files(
                        model_name_or_path=model_name_or_path,
                        allow_patterns=["consolidated*.safetensors"],
                        revision=revision,
                    )
                )
                > 0
                else "hf"
            )

        # Some quantized models use .pt files for storing the weights.
        if load_format == "hf":
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == "safetensors" or load_format == "fastsafetensors":
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == "mistral":
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == "pt":
            allow_patterns = ["*.pt"]
        elif load_format == "npcache":
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        hf_weights_files: list[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file
            )
        else:
            hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`"
            )

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
        self, source: "Source"
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        extra_config = self.load_config.model_loader_extra_config
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path,
            source.revision,
            source.fall_back_to_pt,
            source.allow_patterns_overrides,
        )
        if self.load_config.load_format == "npcache":
            # Currently np_cache only supports *.bin checkpoints.
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )
        elif use_safetensors:
            if self.load_config.load_format == "fastsafetensors":
                weights_iterator = fastsafetensors_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                )
            else:
                if extra_config.get("enable_multithread_load"):
                    weights_iterator = multi_thread_safetensors_weights_iterator(
                        hf_weights_files,
                        self.load_config.use_tqdm_on_load,
                        max_workers=extra_config.get(
                            "num_threads", self.DEFAULT_NUM_THREADS
                        ),
                    )
                else:
                    weights_iterator = safetensors_weights_iterator(
                        hf_weights_files,
                        self.load_config.use_tqdm_on_load,
                        self.load_config.safetensors_load_strategy,
                    )
        else:
            if extra_config.get("enable_multithread_load"):
                weights_iterator = multi_thread_pt_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.pt_load_map_location,
                    max_workers=extra_config.get(
                        "num_threads", self.DEFAULT_NUM_THREADS
                    ),
                )
            else:
                weights_iterator = pt_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.pt_load_map_location,
                )

        if self.counter_before_loading_weights == 0.0:
            self.counter_before_loading_weights = time.perf_counter()
        # Apply the prefix.
        return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)

    def get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(
            model_config.model,
            model_config.revision,
            fall_back_to_pt=True,
            allow_patterns_overrides=None,
        )

    @instrument(span_name="Load weights")
    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        if model_config.quantization == "torchao":
            quant_config = get_quant_config(model_config, self.load_config)
            if (
                hasattr(quant_config, "is_checkpoint_torchao_serialized")
                and quant_config.is_checkpoint_torchao_serialized
                and torchao_version_at_least("0.15.0")
            ):
                self.load_config.safetensors_load_strategy = "torchao"

        weights_to_load = {name for name, _ in model.named_parameters()}
        all_weights = self.get_all_weights(model_config, model)
        loaded_weights = model.load_weights(all_weights)

        self.counter_after_loading_weights = time.perf_counter()
        logger.info_once(
            "Loading weights took %.2f seconds",
            self.counter_after_loading_weights - self.counter_before_loading_weights,
            scope="local",
        )
        # We only enable the strict check for non-quantized models
        # that have loaded-weights tracking currently.
        if model_config.quantization is None and loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError(
                    "Following weights were not initialized from "
                    f"checkpoint: {weights_not_loaded}"
                )
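`DefaultModelLoader.Source` lets a model pull weights from more than one place: `get_all_weights` always yields the primary source and then any `secondary_weights` the model exposes. A hedged construction sketch (the model ID is an arbitrary example):

    from vllm.model_executor.model_loader.default_loader import DefaultModelLoader

    primary = DefaultModelLoader.Source(
        model_or_path="facebook/opt-125m",  # example model ID
        revision=None,
        prefix="",
        fall_back_to_pt=True,
        allow_patterns_overrides=None,
    )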
28
vllm/model_executor/model_loader/dummy_loader.py
Normal file
@@ -0,0 +1,28 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch.nn as nn

from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights


class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError(
                f"Model loader extra config is not supported for "
                f"load format {load_config.load_format}"
            )

    def download_model(self, model_config: ModelConfig) -> None:
        pass  # Nothing to download

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        # NOTE(woosuk): For accurate performance evaluation, we assign
        # random values to the weights.
        initialize_dummy_weights(model, model_config)
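Since `initialize_dummy_weights` randomizes weights instead of reading a checkpoint, the dummy loader is useful for performance profiling without downloading anything. A usage sketch, assuming the standard `LLM` entrypoint (the model ID is an arbitrary example):

    from vllm import LLM

    llm = LLM(model="facebook/opt-125m", load_format="dummy")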
366
vllm/model_executor/model_loader/gguf_loader.py
Normal file
@@ -0,0 +1,366 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import gguf
|
||||
import regex as re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model,
|
||||
process_weights_after_loading,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_gguf,
|
||||
get_gguf_extra_tensor_names,
|
||||
get_gguf_weight_type_map,
|
||||
gguf_quant_weights_iterator,
|
||||
)
|
||||
from vllm.transformers_utils.gguf_utils import detect_gguf_multimodal
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GGUFModelLoader(BaseModelLoader):
|
||||
"""
|
||||
Model loader that can load GGUF files. This is useful for loading models
|
||||
that are quantized with GGUF and saved in the GGUF format. This loader
|
||||
supports loading both full models and sharded models.
|
||||
"""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
raise ValueError(
|
||||
f"Model loader extra config is not supported for "
|
||||
f"load format {load_config.load_format}"
|
||||
)
|
||||
|
||||
def _prepare_weights(self, model_config: ModelConfig):
|
||||
model_name_or_path = model_config.model
|
||||
if os.path.isfile(model_name_or_path):
|
||||
return model_name_or_path
|
||||
# repo id/filename.gguf
|
||||
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
|
||||
repo_id, filename = model_name_or_path.rsplit("/", 1)
|
||||
return hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
# repo_id:quant_type
|
||||
elif "/" in model_name_or_path and ":" in model_name_or_path:
|
||||
repo_id, quant_type = model_name_or_path.rsplit(":", 1)
|
||||
return download_gguf(
|
||||
repo_id,
|
||||
quant_type,
|
||||
cache_dir=self.load_config.download_dir,
|
||||
revision=model_config.revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Unrecognised GGUF reference: {model_name_or_path} "
|
||||
"(expected local file, <repo_id>/<filename>.gguf, "
|
||||
"or <repo_id>:<quant_type>)"
|
||||
)
|
||||
|
||||
def _get_gguf_weights_map(self, model_config: ModelConfig):
|
||||
"""
|
||||
GGUF uses this naming convention for their tensors from HF checkpoint:
|
||||
`blk.N.BB.weight` and `blk.N.BB.bias`
|
||||
where N signifies the block number of a layer, and BB signifies the
|
||||
attention/mlp layer components.
|
||||
See "Standardized tensor names" in
|
||||
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
|
||||
"""
|
||||
        config = model_config.hf_config
        # Get text config to handle both nested (multimodal) and flat
        # (text-only) config structures. For multimodal models like
        # Gemma3Config, this returns config.text_config. For text-only
        # models, this returns config itself.
        text_config = config.get_text_config()
        model_type = config.model_type
        is_multimodal = (
            hasattr(config, "vision_config") and config.vision_config is not None
        )
        gguf_to_hf_name_map = {}
        sideload_params: list[re.Pattern] = []
        # hack: ggufs have a different name than transformers
        if model_type == "cohere":
            model_type = "command-r"
        if model_type == "gemma3_text":
            # Gemma3 models use "gemma3_text" in HuggingFace but
            # "gemma3" in GGUF architecture naming
            model_type = "gemma3"
        if model_type in ("deepseek_v3", "deepseek_v2"):
            model_type = "deepseek2"
            # The GGUF layer map assumes merged expert weights,
            # so we need to map them manually
            for idx in range(config.num_hidden_layers):
                gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
                    f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
                )
                sideload_params.append(
                    re.compile(
                        f"model\\.layers\\.{idx}"
                        r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
                    )
                )
        if model_type in ("qwen2_moe", "qwen3_moe"):
            model_type = model_type.replace("_", "")
            # The GGUF layer map assumes merged expert weights,
            # so we need to map them manually
            for idx in range(config.num_hidden_layers):
                gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
                )
                sideload_params.append(
                    re.compile(
                        f"model\\.layers\\.{idx}"
                        r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
                    )
                )

        arch = None
        for key, value in gguf.MODEL_ARCH_NAMES.items():
            if value == model_type:
                arch = key
                break
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")
        text_num_layers = text_config.num_hidden_layers
        text_name_map = gguf.get_tensor_name_map(arch, text_num_layers)

        if is_multimodal:
            mm_proj_arch = gguf.MODEL_ARCH.MMPROJ
            vision_num_layers = config.vision_config.num_hidden_layers
            vision_name_map = gguf.get_tensor_name_map(mm_proj_arch, vision_num_layers)
        else:
            vision_name_map = None

        # Create a dummy model to extract parameter names.
        # For multimodal: use AutoModelForImageTextToText to get
        # language + vision + projector params.
        # For text-only: use AutoModelForCausalLM to get language model params.
        auto_cls = (
            AutoModelForImageTextToText if is_multimodal else AutoModelForCausalLM
        )
        with torch.device("meta"):
            dummy_model = auto_cls.from_config(
                config, trust_remote_code=model_config.trust_remote_code
            )

        state_dict = dummy_model.state_dict()
        if hf_checkpoint_map := getattr(
            dummy_model, "_checkpoint_conversion_mapping", None
        ):

            def revert_hf_rename(name: str) -> str:
                for original_name, hf_name in hf_checkpoint_map.items():
                    if hf_name in name:
                        name = name.replace(hf_name, original_name).lstrip("^")
                return name

            state_dict = {
                revert_hf_rename(name): tensor for name, tensor in state_dict.items()
            }

        def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
            """
            Map a HuggingFace parameter name to a GGUF tensor name.

            This function handles the mismatch between HF parameter naming
            conventions and gguf-py's expected format:
            1. Strips the 'language_model.' prefix (common in multimodal models)
            2. Converts a '_weight' suffix to '.weight' (Gemma3 compatibility)
            3. Searches vision_name_map for multimodal parameters
            4. Falls back to text_name_map for language model parameters

            Args:
                hf_name: Full HuggingFace parameter name (e.g.,
                    'model.multi_modal_projector.mm_soft_emb_norm.weight')

            Returns:
                GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
                or None if no mapping is found
            """
            # Strip the 'language_model.' prefix for multimodal models -
            # gguf-py tensor mappings expect parameter names without this
            # prefix. Note: the 'model.' prefix should be KEPT for text-only
            # models, as gguf-py expects it.
            if hf_name.startswith("language_model."):
                hf_name = hf_name[15:]  # Remove 'language_model.'

            # Parse parameter name and suffix
            if hf_name.endswith((".weight", ".bias")):
                base_name, suffix = hf_name.rsplit(".", 1)
            else:
                base_name, suffix = hf_name, ""
            # Handle the '_weight' suffix (Gemma3 naming: parameter ends with
            # '_weight' instead of '.weight')
            if base_name.endswith("_weight"):
                base_name = base_name[:-7]  # Remove '_weight'
                suffix = "weight"

            gguf_name = None
            # Priority 1: Search vision/projector parameters for multimodal models
            if vision_name_map is not None:
                gguf_name = vision_name_map.get_name(base_name)

            # Priority 2: Search text backbone parameters
            if gguf_name is None:
                gguf_name = text_name_map.get_name(base_name)

            if gguf_name is None:
                return None

            return gguf_name + "." + suffix

        # Build the mapping and track unmapped parameters
        unmapped_params = []
        for hf_name in state_dict:
            gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name)

            # Track mapping success
            if gguf_name_with_suffix is not None:
                gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name
                logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name)
            elif hf_name not in gguf_to_hf_name_map.values():
                # Parameter not in the manual overrides either
                unmapped_params.append(hf_name)

        # All parameters (except those initialized by other means) must be
        # mapped: both vision/projector and backbone
        if unmapped_params:
            unmapped_params = list(
                filter(
                    lambda x: not any(re.fullmatch(p, x) for p in sideload_params),
                    unmapped_params,
                )
            )
            if unmapped_params:
                raise RuntimeError(
                    f"Failed to map GGUF parameters "
                    f"({len(unmapped_params)}): "
                    f"{unmapped_params}"
                )
        return gguf_to_hf_name_map

    def _get_gguf_weight_type(
        self,
        model_config: ModelConfig,
        model_name_or_path: str,
        gguf_to_hf_name_map: dict[str, str],
    ) -> dict[str, str]:
        weight_type_map = get_gguf_weight_type_map(
            model_name_or_path, gguf_to_hf_name_map
        )
        is_multimodal = hasattr(model_config.hf_config, "vision_config")
        if is_multimodal:
            mmproj_file = detect_gguf_multimodal(model_name_or_path)
            assert mmproj_file is not None, (
                "Could not find mm_proj file for multimodal GGUF model"
            )
            logger.info("Loading extra mm_proj weights from %s...", mmproj_file)
            mm_proj_weight_type_map = get_gguf_weight_type_map(
                mmproj_file, gguf_to_hf_name_map
            )
            weight_type_map.update(mm_proj_weight_type_map)
        return weight_type_map

    def _get_weights_iterator(
        self,
        model_config: ModelConfig,
        model_name_or_path: str,
        gguf_to_hf_name_map: dict[str, str],
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """
        Iterate over GGUF model weights, loading from both the main model file
        and mmproj.gguf for multimodal Gemma3 models.

        For Gemma3 multimodal GGUF models:
        - Main file (gemma-3-*.gguf): Language model weights (model.*)
        - mmproj file (mmproj*.gguf): Vision tower + projector weights (v.*, mm.*)

        Yields:
            Tuples of (parameter_name, tensor) for all model weights
        """
        hf_config = model_config.hf_config
        is_multimodal = hasattr(hf_config, "vision_config")

        if is_multimodal:
            # Load mm_proj (mm_encoder + projector) for multimodal weights
            mmproj_file = detect_gguf_multimodal(model_name_or_path)
            assert mmproj_file is not None, (
                "Could not find mm_proj file for multimodal GGUF model"
            )
            yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map)

        yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        local_model_path = self._prepare_weights(model_config)
        gguf_weights_map = self._get_gguf_weights_map(model_config)
        model.load_weights(
            self._get_weights_iterator(model_config, local_model_path, gguf_weights_map)
        )

    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
    ) -> nn.Module:
        device_config = vllm_config.device_config
        local_model_path = self._prepare_weights(model_config)
        gguf_weights_map = self._get_gguf_weights_map(model_config)
        # We can only know whether word embeddings are tied after mapping
        # the weights
        if "lm_head.weight" in get_gguf_extra_tensor_names(
            local_model_path, gguf_weights_map
        ):
            model_config.hf_config.update({"tie_word_embeddings": True})

        weight_type_map = self._get_gguf_weight_type(
            model_config, local_model_path, gguf_weights_map
        )
        # Filter out unquantized modules to skip
        unquant_names = [
            name.removesuffix(".weight")
            for name, weight_type in weight_type_map.items()
            if weight_type in ("F32", "F16", "BF16") and name.endswith(".weight")
        ]
        logger.debug(
            "GGUF unquantized modules: %s",
            unquant_names,
        )
        vllm_config.quant_config.unquantized_modules.extend(unquant_names)

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config, prefix=prefix)
            self.load_weights(model, model_config)

        process_weights_after_loading(model, model_config, target_device)
        return model
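# A minimal usage sketch (assuming a GGUF checkpoint reference as described
# in `_prepare_weights` above):
#
#   >>> from vllm.config.load import LoadConfig
#   >>> from vllm.model_executor.model_loader import get_model_loader
#   >>> loader = get_model_loader(LoadConfig(load_format="gguf"))
#   >>> type(loader).__name__
#   'GGUFModelLoader'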
37
vllm/model_executor/model_loader/reload/__init__.py
Normal file
@@ -0,0 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Layerwise weight reloading utilities for vLLM.

This module provides functionality to reload model weights layer-by-layer,
which is useful for weight updates without full model reconstruction.

Limitations:
1. Composition with CPU offloading has not been implemented
2. Reloading Attention/MLA weights (q_scale, k_scale, v_scale) has not been
   implemented
3. Tied parameters will only reflect processing from one of the parent layers
   (for example, only processing from embed_tokens will have an effect)
4. This design assumes that the number of weights loaded from disk is the same
   as the number of weights created at model init time. This is not true for
   quant methods which (1) pad weights or (2) load qkv weights into the same
   parameter. Both of these cases are non-issues for today's quant methods,
   but future quantizations may cause reloading to fail
"""

__all__ = [
    "record_metadata_for_reloading",
    "initialize_layerwise_reload",
    "finalize_layerwise_reload",
    "set_torchao_reload_attrs",
    "support_quantized_model_reload_from_hp_weights",
]

from .layerwise import (
    finalize_layerwise_reload,
    initialize_layerwise_reload,
    record_metadata_for_reloading,
)
from .torchao_decorator import (
    set_torchao_reload_attrs,
    support_quantized_model_reload_from_hp_weights,
)
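# Intended call order, as a sketch (assumptions: `model` is an initialized
# vLLM model and `weights` is an iterator of (name, tensor) pairs):
#
#   record_metadata_for_reloading(model)   # once, right after model init
#   ...
#   initialize_layerwise_reload(model)     # before each reload
#   model.load_weights(weights)            # layers are processed as they fill
#   finalize_layerwise_reload(model, model_config)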
275
vllm/model_executor/model_loader/reload/layerwise.py
Normal file
@@ -0,0 +1,275 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable
from functools import wraps
from weakref import WeakKeyDictionary

import torch

from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from .meta import (
    capture_layer_to_meta,
    get_numel_loaded,
    materialize_layer,
    restore_layer_on_meta,
)
from .types import LayerReloadingInfo
from .utils import get_layer_params_buffers, get_layer_size, get_layer_tensors

logger = init_logger(__name__)

__all__ = [
    "get_layerwise_info",
    "record_metadata_for_reloading",
    "initialize_layerwise_reload",
    "finalize_layerwise_reload",
]


# Global dict storing information used for layerwise restoring, loading, and
# processing. For more information regarding what info is stored when, see
# `LayerReloadingInfo`
#
# Use a weak-ref dictionary so that modules can be freed when the model is
# freed. Values are sanitized of references to the layer key in order to
# avoid circular refs
LAYERWISE_INFO: WeakKeyDictionary[torch.nn.Module, LayerReloadingInfo] = (
    WeakKeyDictionary()
)


def get_layerwise_info(layer: torch.nn.Module) -> LayerReloadingInfo:
    """
    Get information related to restoring and layerwise processing. If no
    previous information existed, a new entry is constructed.
    """
    if layer not in LAYERWISE_INFO:
        LAYERWISE_INFO[layer] = LayerReloadingInfo()

    return LAYERWISE_INFO[layer]


def record_metadata_for_reloading(model: torch.nn.Module):
    """
    Record layer metadata needed for later reloading.

    Stores parameter and buffer metadata as meta tensors for restoration.
    Must be called before `initialize_layerwise_reload`.
    """
    for layer in model.modules():
        info = get_layerwise_info(layer)
        info.restore_metadata = capture_layer_to_meta(layer)


@torch.no_grad()
def initialize_layerwise_reload(model: torch.nn.Module):
    """
    Set up layerwise weight loading with deferred processing.

    Must be called after `record_metadata_for_reloading`. This function:
    1. Saves current kernel tensors for later copying
    2. Restores layer parameters/buffers from metadata (on the meta device)
    3. Wraps weight loaders to defer processing until all weights are loaded

    When all weights for a layer are loaded, the wrapped loaders will:
    1. Materialize the layer onto the target device
    2. Load all cached weights
    3. Run quantization processing if applicable
    4. Copy processed values back to the original tensor storage
    """
    # Disable torchao reloading to avoid infinite recursion
    model._original_do_torchao_reload = getattr(model, "_do_torchao_reload", False)
    model._do_torchao_reload = False

    for layer in model.modules():
        info = get_layerwise_info(layer)

        # Skip if the layer has already been initialized
        if info.can_process():
            continue

        # Save current tensors for later copying
        info.kernel_tensors = get_layer_params_buffers(layer)

        # Restore layer parameters/buffers onto the meta device
        restore_layer_on_meta(layer, info)

        # Track loading progress to determine when to process/copy
        info.load_numel = 0
        info.load_numel_total = get_layer_size(layer)

        # Wrap each parameter's weight loader.
        # Note that nested wrapping will occur for shared tensors
        for name, tensor in get_layer_tensors(layer).items():
            if _get_weight_loader(tensor).__name__ != "online_process_loader":
                tensor.weight_loader = make_online_process_loader(layer, name)
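# Numel accounting sketch: for a Linear layer with weight (4096, 4096) and
# bias (4096,), load_numel_total == 4096 * 4096 + 4096; processing fires on
# the wrapped-loader call that brings load_numel up to that total (shapes
# here are illustrative, not tied to any particular model).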
def make_online_process_loader(layer: torch.nn.Module, param_name: str) -> Callable:
    """Create a wrapped weight loader that defers processing."""
    info = get_layerwise_info(layer)
    param = getattr(layer, param_name)
    original_loader = _get_original_loader(param)
    loader_signature = inspect.signature(original_loader)

    @wraps(original_loader, assigned=("__doc__", "__annotations__"))
    def online_process_loader(*args, **kwargs):
        if not info.can_process():
            # Unfortunately, some qconfigs are set up to load the same weight
            # multiple times. For example, CT_WNA16 loads `weight_shape` for
            # each of the qkv partitions. This results in layers loading extra
            # weights (beyond load_numel_total) after they have already been
            # processed.
            #
            # The best solution is to ensure that `load_numel_total` reflects
            # the actual number of weights loaded, either by modifying qconfigs
            # to create as many weights as are loaded (see the padding issue as
            # well) or maybe by capturing how many weights are loaded on the
            # first pass.
            #
            # For now, `load_numel_total` is still safe to use as long as
            # there's no way to reach `load_numel_total` without loading all
            # necessary weights. `weight_shape` is very small, so this is safe.
            # See Limitations(4)
            logger.debug("%s: Excessive loading", layer.__class__.__name__)
            return

        # Bind and normalize arguments
        bound_args = loader_signature.bind(*args, **kwargs)
        bound_args.apply_defaults()

        # Cache loaded weights, track loading progress
        info.loaded_weights.append((param_name, bound_args))
        num_loaded, ret = get_numel_loaded(original_loader, bound_args)
        info.load_numel += num_loaded

        logger.debug(
            "%s: %d / %d",
            layer.__class__.__name__,
            info.load_numel,
            info.load_numel_total,
        )

        # Process and copy when all weights are loaded
        if info.load_numel >= info.load_numel_total and not isinstance(  # type: ignore[operator]
            layer, (Attention, MLAAttention)
        ):
            _layerwise_process(layer, info)

        return ret

    return online_process_loader


def finalize_layerwise_reload(model: torch.nn.Module, model_config: ModelConfig):
    """
    Remove the outermost layer of weight loading wrappers.

    This function should be applied after `initialize_layerwise_reload` in
    order to unwrap the layerwise weight loaders.

    Also processes Attention/MLA layers, which must be processed after all
    other layers.
    """
    model._do_torchao_reload = model._original_do_torchao_reload

    for layer in model.modules():
        info = get_layerwise_info(layer)

        # Attention/MLA layers are processed after all other layers
        if isinstance(layer, (Attention, MLAAttention)):
            if info.load_numel > 0:
                raise NotImplementedError(
                    "Layerwise reloading of Q/K/V scale weights is not implemented yet"
                )

            else:
                _place_kernel_tensors(layer, info)
                layer.process_weights_after_loading(model_config.dtype)

        # No weights were loaded, place kernel tensors back
        elif info.can_process() and info.load_numel <= 0:
            _place_kernel_tensors(layer, info)

        # Process non-attention layers which did not load all elements. This
        # can happen if the created weight has extra padding elements which
        # are not loaded. Having too many of these delayed layers can lead to
        # excess memory usage. See Limitations(4)
        elif info.load_numel > 0 and info.load_numel < info.load_numel_total:  # type: ignore[operator]
            logger.debug("%s: Delayed processing", layer.__class__.__name__)
            _layerwise_process(layer, info)

        info.reset()


def _layerwise_process(layer: torch.nn.Module, info: LayerReloadingInfo):
    """
    Finalize layer loading after all weights have been cached.

    This function:
    1. Materializes the layer onto the target device
    2. Loads all cached weights
    3. Runs quantization processing if applicable
    4. Copies processed values back to the original tensor storage
    """
    # Materialize layer tensors onto the device
    materialize_layer(layer)

    # Reset the FP8 online quantization flag so process_weights_after_loading
    # will run again during reload
    if hasattr(layer, "_already_called_process_weights_after_loading"):
        delattr(layer, "_already_called_process_weights_after_loading")

    # Unwrap layerwise loading wrappers
    for param in get_layer_tensors(layer).values():
        param.weight_loader = _get_original_loader(param)

    # Load all cached weights into the materialized layer (using the original
    # loaders)
    for name, args in info.loaded_weights:
        param = getattr(layer, name)
        args.arguments["param"] = param
        param.weight_loader(*args.args, **args.kwargs)

    # Process weights (quantization, repacking, etc.)
    # Attention/MLA are processed in `finalize_layerwise_reload`
    quant_method = getattr(layer, "quant_method", None)
    if isinstance(quant_method, QuantizeMethodBase):
        quant_method.process_weights_after_loading(layer)

    # Copy processed values into the original tensor storage (preserves
    # cudagraph refs). This code is a no-op if not reloading (because
    # kernel_tensors is empty)
    parameters, buffers = info.kernel_tensors
    for name, param in parameters.items():
        param.data.copy_(getattr(layer, name))
    for name, buffer in buffers.items():
        buffer.data.copy_(getattr(layer, name))

    _place_kernel_tensors(layer, info)

    info.reset()
    logger.debug("%s: Processed", layer.__class__.__name__)


def _get_original_loader(tensor: torch.Tensor) -> Callable:
    """Return the weight loader with any layerwise wrappers removed."""
    loader = _get_weight_loader(tensor)
    while loader.__name__ == "online_process_loader":
        loader = loader.__wrapped__  # type: ignore[union-attr]

    return loader


def _get_weight_loader(tensor: torch.Tensor):
    return getattr(tensor, "weight_loader", default_weight_loader)


def _place_kernel_tensors(layer: torch.nn.Module, info: LayerReloadingInfo):
    for name in get_layer_tensors(layer):
        delattr(layer, name)

    parameters, buffers = info.kernel_tensors
    for name, param in parameters.items():
        layer.register_parameter(name, param)
    for name, buffer in buffers.items():
        layer.register_buffer(name, buffer)
146
vllm/model_executor/model_loader/reload/meta.py
Normal file
@@ -0,0 +1,146 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import inspect
from collections.abc import Callable

import torch
from torch.utils._python_dispatch import TorchDispatchMode

from .sanitize import restore_layer_refs, sanitize_layer_refs
from .types import LayerReloadingInfo, LayerTensors
from .utils import get_layer_params_buffers, get_layer_tensors

__all__ = [
    "to_meta_tensor",
    "materialize_meta_tensor",
    "capture_layer_to_meta",
    "restore_layer_on_meta",
    "materialize_layer",
    "get_numel_loaded",
]

SKIP_MODULES: set[str] = {"HadamardTransform"}

SKIP_TENSORS: set[str] = {
    "_expert_map",
    "expert_mask",
    "expert_global_to_physical",
    "expert_physical_to_global",
    "expert_local_to_global",
}


def to_meta_tensor(tensor: torch.Tensor) -> torch.Tensor:
    """Convert a tensor to a meta tensor while preserving class and attributes."""
    meta_tensor = tensor.data.to("meta")
    meta_tensor.__class__ = tensor.__class__
    meta_tensor.__dict__ = tensor.__dict__.copy()
    return meta_tensor


def materialize_meta_tensor(meta_tensor: torch.Tensor) -> torch.Tensor:
    """
    Materialize a meta tensor into an actual tensor on the current device.
    Should be called within the torch device context for the given rank.
    """
    tensor = torch.empty_strided(
        size=tuple(meta_tensor.size()),
        stride=tuple(meta_tensor.stride()),
        dtype=meta_tensor.dtype,
        requires_grad=False,
    )
    tensor.__class__ = meta_tensor.__class__
    tensor.__dict__ = meta_tensor.__dict__.copy()
    return tensor


def capture_layer_to_meta(layer: torch.nn.Module) -> LayerTensors:
    if layer.__class__.__name__ in SKIP_MODULES:
        return ({}, {})

    params, buffers = get_layer_params_buffers(layer)
    return (
        {
            name: sanitize_layer_refs(to_meta_tensor(param), layer)
            for name, param in params.items()
            if name not in SKIP_TENSORS
        },
        {
            name: sanitize_layer_refs(to_meta_tensor(buffer), layer)
            for name, buffer in buffers.items()
            if name not in SKIP_TENSORS
        },
    )


def restore_layer_on_meta(layer: torch.nn.Module, info: LayerReloadingInfo):
    """Restore a layer to model format with tensors on the meta device."""
    if layer.__class__.__name__ in SKIP_MODULES:
        return

    for name in get_layer_tensors(layer):
        if name not in SKIP_TENSORS:
            delattr(layer, name)

    restore_params, restore_buffers = info.restore_metadata
    for name, param in restore_params.items():
        if name not in SKIP_TENSORS:
            param = restore_layer_refs(param, layer)
            layer.register_parameter(name, param)

    for name, buffer in restore_buffers.items():
        if name not in SKIP_TENSORS:
            buffer = restore_layer_refs(buffer, layer)
            layer.register_buffer(name, buffer)


def materialize_layer(layer: torch.nn.Module) -> None:
    """Materialize all meta tensors in a layer to actual tensors."""
    if layer.__class__.__name__ in SKIP_MODULES:
        return

    for name, tensor in get_layer_tensors(layer).items():
        if name not in SKIP_TENSORS:
            setattr(layer, name, materialize_meta_tensor(tensor))


class MetaCopyCounter(TorchDispatchMode):
    """
    Tracks the total number of elements modified with `copy_`.

    Useful for keeping track of weight loading where the underlying weights
    can be arbitrarily transformed (such as with `narrow`) before calling
    copy.

    Note: Assumes that copy kwargs are not used.
    """

    def __init__(self):
        super().__init__()
        self.copied_numel = 0

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        if func is torch.ops.aten.copy_.default and args[0].device.type == "meta":
            assert args[0].numel() == args[1].numel()
            self.copied_numel += args[0].numel()

        return func(*args, **kwargs)
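# A minimal sketch of the counter in isolation (shapes are illustrative):
#
#   >>> src = torch.ones(4, 4)
#   >>> dst = torch.empty(4, 4, device="meta")
#   >>> with MetaCopyCounter() as counter:
#   ...     _ = dst.copy_(src)
#   >>> counter.copied_numel
#   16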
def get_numel_loaded(
    weight_loader: Callable, args: inspect.BoundArguments
) -> tuple[int, object]:
    """
    Determine how many elements would be loaded by a weight loader call.

    :param weight_loader: loader used to load weights
    :param args: bound arguments to the weight loader
    :return: number of elements loaded by the weight loader, and the return
        value of the weight loader
    """
    assert args.arguments["param"].device.type == "meta"
    with MetaCopyCounter() as counter:
        return_value = weight_loader(*args.args, **args.kwargs)
    return counter.copied_numel, return_value
50
vllm/model_executor/model_loader/reload/sanitize.py
Normal file
@@ -0,0 +1,50 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from types import MethodType

import torch

__all__ = ["sanitize_layer_refs", "restore_layer_refs"]


layer_ref_sentinel = object()


def sanitize_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
    """
    Removes references to the layer held by tensor attributes. Specifically,
    removes the `__self__` attribute of weight loader methods attached to the
    tensor.

    Used by `capture_layer_to_meta` to avoid circular references to layers in
    `LAYERWISE_INFO`, which would lead to modules never being cleaned up.
    Without sanitization, tensors will reference layers, and the
    WeakKeyDictionary will never evict entries, even when the model is deleted.

    :param tensor: tensor to be sanitized
    :param layer: layer whose references should be removed
    :return: sanitized tensor
    """
    for key, value in tensor.__dict__.items():
        if isinstance(value, MethodType) and value.__self__ is layer:
            tensor.__dict__[key] = value.__func__.__get__(layer_ref_sentinel)

    return tensor
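# Sketch of the rebinding trick used above, standalone and illustrative
# (`Layer` here is a hypothetical stand-in for a real module):
#
#   >>> class Layer:
#   ...     def loader(self):
#   ...         return self
#   >>> layer = Layer()
#   >>> rebound = layer.loader.__func__.__get__(layer_ref_sentinel)
#   >>> rebound.__self__ is layer_ref_sentinel
#   True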
def restore_layer_refs(tensor: torch.Tensor, layer: torch.nn.Module) -> torch.Tensor:
    """
    Restores references to the layer held by tensor attributes.

    Used by `restore_layer_on_meta` to add back layer references, allowing for
    proper weight loading.

    :param tensor: tensor to be restored
    :param layer: layer whose references should be added back
    :return: restored tensor
    """
    for key, value in tensor.__dict__.items():
        if isinstance(value, MethodType) and value.__self__ is layer_ref_sentinel:
            tensor.__dict__[key] = value.__func__.__get__(layer)

    return tensor
58
vllm/model_executor/model_loader/reload/torchao_decorator.py
Normal file
@@ -0,0 +1,58 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import Iterable
from functools import wraps
from types import FunctionType
from typing import TYPE_CHECKING

import torch

from vllm.config import ModelConfig

from .layerwise import (
    finalize_layerwise_reload,
    initialize_layerwise_reload,
)

if TYPE_CHECKING:
    from vllm.model_executor.models.utils import AutoWeightsLoader

__all__ = ["set_torchao_reload_attrs", "support_quantized_model_reload_from_hp_weights"]


def set_torchao_reload_attrs(model: torch.nn.Module, model_config: ModelConfig):
    model._do_torchao_reload = True
    model._model_config = model_config


def support_quantized_model_reload_from_hp_weights(original_load_weights: FunctionType):
    """
    Decorator for the `load_weights` method (`AutoWeightsLoader.load_weights`)
    to support reloading high-precision (bfloat16/float16/float32) weights
    into an already quantized model. This involves restoring the weights to
    high precision and then online-quantizing them.

    Only applies to torchao quantized models. Assumes that all model weights
    are loaded within a single weights iterator (cannot perform batched
    updates).
    """

    @wraps(original_load_weights)
    def patched_model_load_weights(
        self: "AutoWeightsLoader",
        weights: Iterable[tuple[str, torch.Tensor]],
        *args,
        **kwargs,
    ):
        model = self.module

        if not getattr(model, "_do_torchao_reload", False):
            return original_load_weights(self, weights, *args, **kwargs)

        initialize_layerwise_reload(model)
        loaded_weights = original_load_weights(self, weights, *args, **kwargs)
        finalize_layerwise_reload(model, model._model_config)

        return loaded_weights

    return patched_model_load_weights
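# Intended wiring, as a sketch (the subclass name below is hypothetical;
# AutoWeightsLoader is the usual call site):
#
#   class MyAutoWeightsLoader(AutoWeightsLoader):
#       @support_quantized_model_reload_from_hp_weights
#       def load_weights(self, weights, *args, **kwargs):
#           ...
#
#   set_torchao_reload_attrs(model, model_config)  # opts the model in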
33
vllm/model_executor/model_loader/reload/types.py
Normal file
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from inspect import BoundArguments

import torch

__all__ = ["LayerTensors", "LayerReloadingInfo"]

# encodes both parameters and buffers separately
LayerTensors = tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]


@dataclass
class LayerReloadingInfo:
    # model format (meta), populated by `record_metadata_for_reloading`
    restore_metadata: LayerTensors = field(default_factory=lambda: ({}, {}))

    # kernel format (device)
    kernel_tensors: LayerTensors = field(default_factory=lambda: ({}, {}))

    # track how many restored elements are ready for loading
    load_numel: int = 0
    load_numel_total: int | None = None

    # stores arguments and tensors ready for loading
    loaded_weights: list[tuple[str, BoundArguments]] = field(default_factory=list)

    def reset(self):
        self.__init__(restore_metadata=self.restore_metadata)  # type: ignore[misc]

    def can_process(self) -> bool:
        return self.load_numel_total is not None
31
vllm/model_executor/model_loader/reload/utils.py
Normal file
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from .types import LayerTensors

__all__ = [
    "get_layer_tensors",
    "get_layer_params_buffers",
    "get_layer_size",
]


def get_layer_tensors(layer: torch.nn.Module) -> dict[str, torch.Tensor]:
    """Get all parameters and buffers from a module as a dict."""
    params, buffers = get_layer_params_buffers(layer)
    return params | buffers


def get_layer_params_buffers(layer: torch.nn.Module) -> LayerTensors:
    """Get all parameters and buffers of a module as a tuple of dicts."""
    return (
        {name: param for name, param in layer._parameters.items() if param is not None},
        {name: buffer for name, buffer in layer._buffers.items() if buffer is not None},
    )


def get_layer_size(layer: torch.nn.Module) -> int:
    """Calculate the total number of elements across all tensors in a layer."""
    return sum(tensor.numel() for tensor in get_layer_tensors(layer).values())
115
vllm/model_executor/model_loader/runai_streamer_loader.py
Normal file
@@ -0,0 +1,115 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Generator

import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf,
    download_weights_from_hf,
    runai_safetensors_weights_iterator,
)
from vllm.transformers_utils.runai_utils import is_runai_obj_uri, list_safetensors


class RunaiModelStreamerLoader(BaseModelLoader):
    """
    Model loader that can load safetensors
    files from a local FS or an S3 bucket.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        self._is_distributed = False
        if load_config.model_loader_extra_config:
            extra_config = load_config.model_loader_extra_config

            if "distributed" in extra_config and isinstance(
                extra_config.get("distributed"), bool
            ):
                self._is_distributed = extra_config.get("distributed")

            if "concurrency" in extra_config and isinstance(
                extra_config.get("concurrency"), int
            ):
                os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
                    extra_config.get("concurrency")
                )

            if "memory_limit" in extra_config and isinstance(
                extra_config.get("memory_limit"), int
            ):
                os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
                    extra_config.get("memory_limit")
                )

        runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
        aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
        if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
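    # Extra-config keys recognized above, as a sketch (values illustrative):
    #
    #   --load-format runai_streamer \
    #   --model-loader-extra-config '{"concurrency": 16, "distributed": true}'
    #
    # "distributed" toggles the distributed streaming path, while
    # "concurrency" and "memory_limit" are exported to the Run:ai streamer
    # through environment variables.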
    def _prepare_weights(
        self, model_name_or_path: str, revision: str | None
    ) -> list[str]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""

        is_object_storage_path = is_runai_obj_uri(model_name_or_path)
        is_local = os.path.isdir(model_name_or_path)
        safetensors_pattern = "*.safetensors"
        index_file = SAFE_WEIGHTS_INDEX_NAME

        hf_folder = (
            model_name_or_path
            if (is_local or is_object_storage_path)
            else download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                [safetensors_pattern],
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        )
        hf_weights_files = list_safetensors(path=hf_folder)

        if not is_local and not is_object_storage_path:
            download_safetensors_index_file_from_hf(
                model_name_or_path, index_file, self.load_config.download_dir, revision
            )

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any safetensors model weights with `{model_name_or_path}`"
            )

        return hf_weights_files

    def _get_weights_iterator(
        self, model_or_path: str, revision: str
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_weights_files = self._prepare_weights(model_or_path, revision)
        return runai_safetensors_weights_iterator(
            hf_weights_files, self.load_config.use_tqdm_on_load, self._is_distributed
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """Download the model if necessary."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load weights into a model."""
        model_weights = model_config.model
        if model_weights_override := model_config.model_weights:
            model_weights = model_weights_override
        model.load_weights(
            self._get_weights_iterator(model_weights, model_config.revision)
        )
214
vllm/model_executor/model_loader/sharded_state_loader.py
Normal file
@@ -0,0 +1,214 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import collections
import glob
import os
import time
from collections.abc import Generator
from typing import Any

import torch
from torch import nn

from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
    download_weights_from_hf,
    runai_safetensors_weights_iterator,
)
from vllm.transformers_utils.s3_utils import glob as s3_glob
from vllm.transformers_utils.utils import is_s3

logger = init_logger(__name__)


class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/offline_inference/save_sharded_state.py` for creating a sharded
    checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
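    # With the default pattern, a TP=2 checkpoint split into two parts per
    # rank would look like this on disk (illustrative):
    #
    #   model-rank-0-part-0.safetensors
    #   model-rank-0-part-1.safetensors
    #   model-rank-1-part-0.safetensors
    #   model-rank-1-part-1.safetensors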
    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        extra_config = (
            {}
            if load_config.model_loader_extra_config is None
            else load_config.model_loader_extra_config.copy()
        )
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            raise ValueError(
                f"Unexpected extra config keys for load format "
                f"{load_config.load_format}: "
                f"{load_config.model_loader_extra_config.keys()}"
            )

    @staticmethod
    def _filter_subtensors(
        tensors: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = (
            collections.defaultdict(list)
        )
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    result[k] = t
        return result
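    # Why this filter exists, as a sketch: with tied embeddings, two state
    # dict entries can alias the same storage and only one copy should be
    # serialized. Standalone illustration (names hypothetical):
    #
    #   >>> w = torch.zeros(8, 8)
    #   >>> tied = {"embed.weight": w, "lm_head.weight": w}
    #   >>> sorted(ShardedStateLoader._filter_subtensors(tied))
    #   ['embed.weight']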
    def _prepare_weights(self, model_name_or_path: str, revision: str | None):
        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

    def download_model(self, model_config: ModelConfig) -> None:
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        from vllm.distributed import get_tensor_model_parallel_rank

        model_weights = model_config.model
        if model_weights_override := model_config.model_weights:
            model_weights = model_weights_override
        local_model_path = model_weights

        rank = get_tensor_model_parallel_rank()
        pattern = os.path.join(
            local_model_path,
            self.pattern.format(rank=rank, part="*"),
        )

        filepaths = []
        if is_s3(local_model_path):
            file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
            filepaths = s3_glob(path=local_model_path, allow_pattern=[file_pattern])
        else:
            filepaths = glob.glob(pattern)
        if not filepaths:
            # TODO: support un-sharded checkpoints too
            raise ValueError(
                f"Could not find checkpoint files '{pattern}', only "
                f"pre-sharded checkpoints are currently supported!"
            )
        state_dict = self._filter_subtensors(model.state_dict())
        counter_before_loading_weights = time.perf_counter()
        for key, tensor in self.iterate_over_files(filepaths):
            # If loading with LoRA enabled, additional padding may
            # be added to certain parameters. We only load into a
            # narrowed view of the parameter data.
            param_data = state_dict[key].data
            param_shape = state_dict[key].shape
            for dim, size in enumerate(tensor.shape):
                if size < param_shape[dim]:
                    param_data = param_data.narrow(dim, 0, size)
            if tensor.shape != param_shape:
                logger.warning(
                    "loading tensor of shape %s into parameter '%s' of shape %s",
                    tensor.shape,
                    key,
                    param_shape,
                )
            param_data.copy_(tensor)
            state_dict.pop(key)
        counter_after_loading_weights = time.perf_counter()
        logger.info_once(
            "Loading weights took %.2f seconds",
            counter_after_loading_weights - counter_before_loading_weights,
            scope="local",
        )
        if state_dict:
            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")

    def iterate_over_files(
        self, paths
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        if self.load_config.load_format == "runai_streamer_sharded":
            yield from runai_safetensors_weights_iterator(paths, True)
        else:
            from safetensors.torch import safe_open

            for path in paths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        yield key, tensor

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            if max_size is not None and total_size + param_size > max_size:
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
793
vllm/model_executor/model_loader/tensorizer.py
Normal file
@@ -0,0 +1,793 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import argparse
import contextlib
import contextvars
import dataclasses
import json
import os
import tempfile
import threading
import time
from collections.abc import Generator, MutableMapping
from dataclasses import asdict, dataclass, field, fields
from typing import TYPE_CHECKING, Any, ClassVar

import regex as re
import torch
from huggingface_hub import snapshot_download
from torch import nn
from torch.utils._python_dispatch import TorchDispatchMode
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.config import ModelConfig, ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.import_utils import PlaceholderModule

if TYPE_CHECKING:
    from vllm.engine.arg_utils import EngineArgs

try:
    from tensorizer import (
        DecryptionParams,
        EncryptionParams,
        TensorDeserializer,
        TensorSerializer,
    )
    from tensorizer.stream_io import open_stream
    from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor

except ImportError:
    tensorizer = PlaceholderModule("tensorizer")
    DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
    EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
    TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer")
    TensorSerializer = tensorizer.placeholder_attr("TensorSerializer")
    open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
    convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
    get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
    no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")

__all__ = [
    "EncryptionParams",
    "DecryptionParams",
    "TensorDeserializer",
    "TensorSerializer",
    "open_stream",
    "convert_bytes",
    "get_mem_usage",
    "no_init_or_tensor",
    "TensorizerConfig",
]

logger = init_logger(__name__)


def is_valid_deserialization_uri(uri: str | None) -> bool:
    if uri:
        scheme = uri.lower().split("://")[0]
        return scheme in {"s3", "http", "https"} or os.path.exists(uri)
    return False


def tensorizer_kwargs_arg(value):
    loaded = json.loads(value)
    if not isinstance(loaded, dict):
        raise argparse.ArgumentTypeError(
            f"Not deserializable to dict: {value}. serialization_kwargs and "
            f"deserialization_kwargs must be "
            f"deserializable from a JSON string to a dictionary."
        )
    return loaded


class MetaTensorMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        if func._schema.name == "aten::empty" and "device" not in kwargs:
            kwargs["device"] = "meta"

        return func(*args, **kwargs)


def meta_tensor_mode(
    loading_code=None,
):
    if loading_code is None:
        return _NoInitOrTensorImpl.context_manager()
    elif callable(loading_code):
        with _NoInitOrTensorImpl.context_manager():
            return loading_code()
    else:
        raise TypeError(
            "expected a callable to evaluate,"
            " or None if being used as a context manager;"
            f' got an object of type "{type(loading_code).__name__}" instead.'
        )
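# Usage sketch for meta_tensor_mode, covering both forms accepted above
# (the Linear module here is illustrative):
#
#   >>> with meta_tensor_mode():
#   ...     m = torch.nn.Linear(16, 16)
#   >>> m.weight.device.type
#   'meta'
#   >>> m2 = meta_tensor_mode(lambda: torch.nn.Linear(16, 16))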
class _NoInitOrTensorImpl:
    _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
    _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)

    is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False)
    _count_active: int = 0
    _count_active_lock = threading.Lock()

    @classmethod
    @contextlib.contextmanager
    def context_manager(cls):
        if cls.is_active.get():
            yield
            return

        with cls._count_active_lock:
            cls._count_active += 1
            if cls._count_active == 1:
                for mod in cls._MODULES:
                    mod.reset_parameters = cls._disable(mod.reset_parameters)

        reset_token = cls.is_active.set(True)

        try:
            with MetaTensorMode():
                yield
        finally:
            cls.is_active.reset(reset_token)
            with cls._count_active_lock:
                cls._count_active -= 1
                if cls._count_active == 0:
                    for mod, original in cls._MODULE_ORIGINALS:
                        mod.reset_parameters = original

    @staticmethod
    def _disable(func):
        def wrapper(*args, **kwargs):
            if not _NoInitOrTensorImpl.is_active.get():
                return func(*args, **kwargs)

        return wrapper


@dataclass
class TensorizerConfig(MutableMapping):
    tensorizer_uri: str | None = None
    tensorizer_dir: str | None = None
    vllm_tensorized: bool | None = None
    verify_hash: bool | None = None
    num_readers: int | None = None
    encryption_keyfile: str | None = None
    s3_access_key_id: str | None = None
    s3_secret_access_key: str | None = None
    s3_endpoint: str | None = None
    lora_dir: str | None = None
    stream_kwargs: dict[str, Any] | None = None
    serialization_kwargs: dict[str, Any] | None = None
    deserialization_kwargs: dict[str, Any] | None = None
    _extra_serialization_attrs: dict[str, Any] | None = field(init=False, default=None)
    model_class: type[torch.nn.Module] | None = field(init=False, default=None)
    hf_config: PretrainedConfig | None = field(init=False, default=None)
    dtype: str | torch.dtype | None = field(init=False, default=None)
    _is_sharded: bool = field(init=False, default=False)
    _fields: ClassVar[tuple[str, ...]]
    _keys: ClassVar[frozenset[str]]
    """Configuration class for Tensorizer settings.

    These settings configure the behavior of model serialization and
    deserialization using Tensorizer.

    Attributes:
        tensorizer_uri: Path to serialized model tensors. Can be a local file
            path or an S3 URI. This is a required field unless `lora_dir` is
            provided (when the config is meant to be used by the
            `tensorize_lora_adapter` function) or a `tensorizer_dir` is passed
            to this object's initializer.
        tensorizer_dir: Path to a directory containing serialized model
            tensors, and all other potential model artifacts needed to load
            the model, such as configs and tokenizer files. Can be passed
            instead of `tensorizer_uri`, in which case the `model.tensors`
            file is assumed to be in this directory.
        vllm_tensorized: If True, indicates that the serialized model is a
            vLLM model. This is used to determine the behavior of the
            TensorDeserializer when loading tensors from a serialized model.
            It is far faster to deserialize a vLLM model as it utilizes
            tensorizer's optimized GPU loading. Note that this is now
            deprecated, as serialized vLLM models are now automatically
            inferred as vLLM models.
        verify_hash: If True, the hashes of each tensor will be verified
            against the hashes stored in the metadata. A `HashMismatchError`
            will be raised if any of the hashes do not match.
        num_readers: Controls how many threads are allowed to read concurrently
            from the source file. Default is `None`, which will dynamically set
            the number of readers based on the number of available
            resources and model size. This greatly increases performance.
        encryption_keyfile: File path to a binary file containing a
            binary key to use for decryption. `None` (the default) means
            no decryption. See the example script in
            examples/others/tensorize_vllm_model.py.
        s3_access_key_id: The access key for the S3 bucket. Can also be set via
            the S3_ACCESS_KEY_ID environment variable.
        s3_secret_access_key: The secret access key for the S3 bucket. Can also
            be set via the S3_SECRET_ACCESS_KEY environment variable.
        s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
            S3_ENDPOINT_URL environment variable.
        lora_dir: Path to a directory containing LoRA adapter artifacts for
            serialization or deserialization. When serializing LoRA adapters,
            this is the only parameter that needs to be passed to this object's
            initializer.
    """
def __post_init__(self):
|
||||
# check if the configuration is for a sharded vLLM model
|
||||
self._is_sharded = (
|
||||
isinstance(self.tensorizer_uri, str)
|
||||
and re.search(r"%0\dd", self.tensorizer_uri) is not None
|
||||
)
|
||||
|
||||
if self.tensorizer_dir and self.lora_dir:
|
||||
raise ValueError(
|
||||
"Only one of tensorizer_dir or lora_dir may be specified. "
|
||||
"Use lora_dir exclusively when serializing LoRA adapters, "
|
||||
"and tensorizer_dir or tensorizer_uri otherwise."
|
||||
)
|
||||
if self.tensorizer_dir and self.tensorizer_uri:
|
||||
logger.warning_once(
|
||||
"Provided both tensorizer_dir and tensorizer_uri. "
|
||||
"Inferring tensorizer_dir from tensorizer_uri as the "
|
||||
"latter takes precedence."
|
||||
)
|
||||
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
|
||||
if not self.tensorizer_uri:
|
||||
if self.lora_dir:
|
||||
self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
|
||||
elif self.tensorizer_dir:
|
||||
self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unable to resolve tensorizer_uri. "
|
||||
"A valid tensorizer_uri or tensorizer_dir "
|
||||
"must be provided for deserialization, and a "
|
||||
"valid tensorizer_uri, tensorizer_uri, or "
|
||||
"lora_dir for serialization."
|
||||
)
|
||||
else:
|
||||
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
|
||||
|
||||
if not self.serialization_kwargs:
|
||||
self.serialization_kwargs = {}
|
||||
if not self.deserialization_kwargs:
|
||||
self.deserialization_kwargs = {}
|
||||
|
||||
def to_serializable(self) -> dict[str, Any]:
|
||||
# Due to TensorizerConfig needing to be msgpack-serializable, it needs
|
||||
# support for morphing back and forth between itself and its dict
|
||||
# representation
|
||||
|
||||
# TensorizerConfig's representation as a dictionary is meant to be
|
||||
# linked to TensorizerConfig in such a way that the following is
|
||||
# technically initializable:
|
||||
# TensorizerConfig(**my_tensorizer_cfg.to_serializable())
|
||||
|
||||
# This means the dict must not retain non-initializable parameters
|
||||
# and post-init attribute states
|
||||
|
||||
# Also don't want to retain private and unset parameters, so only retain
|
||||
# not None values and public attributes
|
||||
|
||||
raw_tc_dict = asdict(self)
|
||||
blacklisted = []
|
||||
|
||||
if "tensorizer_uri" in raw_tc_dict and "tensorizer_dir" in raw_tc_dict:
|
||||
blacklisted.append("tensorizer_dir")
|
||||
|
||||
if "tensorizer_dir" in raw_tc_dict and "lora_dir" in raw_tc_dict:
|
||||
blacklisted.append("tensorizer_dir")
|
||||
|
||||
tc_dict = {}
|
||||
for k, v in raw_tc_dict.items():
|
||||
if (
|
||||
k not in blacklisted
|
||||
and k not in tc_dict
|
||||
and not k.startswith("_")
|
||||
and v is not None
|
||||
):
|
||||
tc_dict[k] = v
|
||||
|
||||
return tc_dict
|
||||
|
||||
def _construct_tensorizer_args(self) -> "TensorizerArgs":
|
||||
return TensorizerArgs(self) # type: ignore
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
) -> None:
|
||||
if parallel_config.tensor_parallel_size > 1 and not self._is_sharded:
|
||||
raise ValueError(
|
||||
"For a sharded model, tensorizer_uri should include a"
|
||||
" string format template like '%04d' to be formatted"
|
||||
" with the rank of the shard"
|
||||
)
|
||||
|
||||
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
|
||||
if model_config.quantization is not None and self.tensorizer_uri is not None:
|
||||
logger.warning(
|
||||
"Loading a model using Tensorizer with quantization on vLLM"
|
||||
" is unstable and may lead to errors."
|
||||
)
|
||||
|
||||
def open_stream(self, tensorizer_args: "TensorizerArgs | None" = None):
|
||||
if tensorizer_args is None:
|
||||
tensorizer_args = self._construct_tensorizer_args()
|
||||
|
||||
return open_stream(self.tensorizer_uri, **tensorizer_args.stream_kwargs)
|
||||
|
||||
def keys(self):
|
||||
return self._keys
|
||||
|
||||
def __len__(self):
|
||||
return len(fields(self))
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._fields)
|
||||
|
||||
def __getitem__(self, item: str) -> Any:
|
||||
if item not in self.keys():
|
||||
raise KeyError(item)
|
||||
return getattr(self, item)
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
if key not in self.keys():
|
||||
# Disallow modifying invalid keys
|
||||
raise KeyError(key)
|
||||
setattr(self, key, value)
|
||||
|
||||
def __delitem__(self, key, /):
|
||||
if key not in self.keys():
|
||||
raise KeyError(key)
|
||||
delattr(self, key)
|
||||
|
||||
|
||||
TensorizerConfig._fields = tuple(f.name for f in fields(TensorizerConfig))
|
||||
TensorizerConfig._keys = frozenset(TensorizerConfig._fields)
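

# Usage sketch (not part of the original file): the round-trip contract that
# to_serializable() documents above. The local path is hypothetical.
def _example_tensorizer_config_roundtrip() -> None:
    cfg = TensorizerConfig(tensorizer_uri="/models/opt-125m/model.tensors")
    # __post_init__ infers the directory from the URI
    assert cfg.tensorizer_dir == "/models/opt-125m"
    # Only public, initializable, non-None fields survive, so the dict form
    # can rebuild an equivalent config (the msgpack round trip):
    rebuilt = TensorizerConfig(**cfg.to_serializable())
    assert rebuilt.tensorizer_uri == cfg.tensorizer_uri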


@dataclass
class TensorizerArgs:
    tensorizer_uri: str | None = None
    tensorizer_dir: str | None = None
    encryption_keyfile: str | None = None

    def __init__(self, tensorizer_config: TensorizerConfig):
        for k, v in tensorizer_config.items():
            setattr(self, k, v)
        self.file_obj = tensorizer_config.tensorizer_uri
        self.s3_access_key_id = (
            tensorizer_config.s3_access_key_id or envs.S3_ACCESS_KEY_ID
        )
        self.s3_secret_access_key = (
            tensorizer_config.s3_secret_access_key or envs.S3_SECRET_ACCESS_KEY
        )
        self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL

        # Use the env-resolved credentials above so the S3_* environment
        # variable fallbacks also apply to the stream kwargs.
        self.stream_kwargs = {
            "s3_access_key_id": self.s3_access_key_id,
            "s3_secret_access_key": self.s3_secret_access_key,
            "s3_endpoint": self.s3_endpoint,
            **(tensorizer_config.stream_kwargs or {}),
        }

        self.deserialization_kwargs = {
            "verify_hash": tensorizer_config.verify_hash,
            "encryption": tensorizer_config.encryption_keyfile,
            "num_readers": tensorizer_config.num_readers,
            **(tensorizer_config.deserialization_kwargs or {}),
        }

        if self.encryption_keyfile:
            with open_stream(
                tensorizer_config.encryption_keyfile,
                **self.stream_kwargs,
            ) as stream:
                key = stream.read()
                decryption_params = DecryptionParams.from_key(key)
                self.deserialization_kwargs["encryption"] = decryption_params

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Tensorizer CLI arguments"""

        # Tensorizer options arg group
        group = parser.add_argument_group(
            "tensorizer options",
            description=(
                "Options for configuring the behavior of the "
                "tensorizer deserializer when load_format=tensorizer "
                "is specified when initializing an LLMEngine, either "
                "via the CLI when running the vLLM OpenAI inference "
                "server with a JSON string passed to "
                "--model-loader-extra-config, or as arguments given "
                "to TensorizerConfig when passed to "
                "model_loader_extra_config in the constructor "
                "for LLMEngine."
            ),
        )

        group.add_argument(
            "--tensorizer-uri",
            type=str,
            help="Path to serialized model tensors. Can be a local file path,"
            " or an HTTP(S) or S3 URI.",
        )
        group.add_argument(
            "--verify-hash",
            action="store_true",
            help="If enabled, the hashes of each tensor will be verified"
            " against the hashes stored in the file metadata. An exception"
            " will be raised if any of the hashes do not match.",
        )
        group.add_argument(
            "--encryption-keyfile",
            type=str,
            default=None,
            help="The file path to a binary file containing a binary key to "
            "use for decryption. Can be a file path or S3 network URI.",
        )
        group.add_argument(
            "--num-readers",
            default=None,
            type=int,
            help="Controls how many threads are allowed to read concurrently "
            "from the source file. Default is `None`, which will dynamically "
            "set the number of readers based on the available resources "
            "and model size. This greatly increases performance.",
        )
        group.add_argument(
            "--s3-access-key-id",
            type=str,
            default=None,
            help="The access key for the S3 bucket. Can also be set via the "
            "S3_ACCESS_KEY_ID environment variable.",
        )
        group.add_argument(
            "--s3-secret-access-key",
            type=str,
            default=None,
            help="The secret access key for the S3 bucket. Can also be set via "
            "the S3_SECRET_ACCESS_KEY environment variable.",
        )
        group.add_argument(
            "--s3-endpoint",
            type=str,
            default=None,
            help="The endpoint for the S3 bucket. Can also be set via the "
            "S3_ENDPOINT_URL environment variable.",
        )

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        # The custom __init__ above takes a TensorizerConfig, so build one
        # from the matching CLI attributes rather than passing them directly.
        tensorizer_config = TensorizerConfig(
            **{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)}
        )
        return cls(tensorizer_config)
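

# Usage sketch (not part of the original file): wiring the argument group into
# a parser and building TensorizerArgs from the parsed namespace. The import
# path for FlexibleArgumentParser and the S3 URI are assumptions.
def _example_tensorizer_cli() -> "TensorizerArgs":
    from vllm.utils import FlexibleArgumentParser  # assumed import path

    parser = TensorizerArgs.add_cli_args(FlexibleArgumentParser())
    args = parser.parse_args(["--tensorizer-uri", "s3://my-bucket/model.tensors"])
    return TensorizerArgs.from_cli_args(args)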


def _check_tensors_on_meta_device(model: nn.Module) -> None:
    for tensor in model.state_dict().values():
        if tensor.device.type == "meta":
            raise ValueError(
                "The serialized model contains tensors on the meta device,"
                " indicating that some tensors were not loaded properly."
                " Please check that the parameters of the model being"
                " specified match that of the serialized model, such as"
                " its quantization."
            )


def _resize_lora_embeddings(model: nn.Module):
    """Modify LoRA embedding layers to use bigger tensors
    to allow for adapter added tokens."""
    for child in model.modules():
        if (
            isinstance(child, VocabParallelEmbedding)
            and child.weight.shape[0] < child.num_embeddings_per_partition
        ):
            new_weight = torch.empty(
                child.num_embeddings_per_partition,
                child.embedding_dim,
                dtype=child.weight.dtype,
                device=child.weight.device,
            )
            new_weight[: child.weight.shape[0]].copy_(child.weight.data)
            new_weight[child.weight.shape[0] :].fill_(0)
            child.weight.data = new_weight


def init_tensorizer_model(
    tensorizer_config: TensorizerConfig, vllm_config: VllmConfig
) -> nn.Module:
    assert tensorizer_config.hf_config is not None
    model_args = tensorizer_config.hf_config
    model_args.dtype = tensorizer_config.dtype
    assert tensorizer_config.model_class is not None
    # TODO: Do we need to consider old-style model class?
    with meta_tensor_mode(), set_current_vllm_config(vllm_config, check_compile=True):
        return tensorizer_config.model_class(vllm_config=vllm_config)


def deserialize_tensorizer_model(
    model: nn.Module, tensorizer_config: TensorizerConfig
) -> None:
    tensorizer_args = tensorizer_config._construct_tensorizer_args()
    if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
        raise ValueError(
            f"{tensorizer_config.tensorizer_uri} is not a valid "
            f"tensorizer URI. Please check that the URI is correct. "
            f"It must either point to a local existing file, or have an "
            f"S3, HTTP or HTTPS scheme."
        )
    before_mem = get_mem_usage()
    start = time.perf_counter()
    with (
        open_stream(
            tensorizer_config.tensorizer_uri, mode="rb", **tensorizer_args.stream_kwargs
        ) as stream,
        TensorDeserializer(
            stream,
            dtype=tensorizer_config.dtype,
            device=f"xpu:{torch.xpu.current_device()}"
            if current_platform.is_xpu()
            else f"cuda:{torch.cuda.current_device()}",
            **tensorizer_args.deserialization_kwargs,
        ) as deserializer,
    ):
        deserializer.load_into_module(model)
        end = time.perf_counter()

    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    duration = end - start
    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
    after_mem = get_mem_usage()
    deserializer.close()
    logger.info(
        "Deserialized %s in %0.2fs, %s/s", total_bytes_str, duration, per_second
    )
    logger.info("Memory usage before: %s", before_mem)
    logger.info("Memory usage after: %s", after_mem)

    _check_tensors_on_meta_device(model)
    _resize_lora_embeddings(model)
    del model.vllm_tensorized_marker


def tensorizer_weights_iterator(
    tensorizer_args: "TensorizerArgs",
) -> Generator[tuple[str, torch.Tensor], None, None]:
    logger.warning(
        "Deserializing HuggingFace models is not optimized for "
        "loading on vLLM, as tensorizer is forced to load to CPU. "
        "Consider deserializing a vLLM model instead for faster "
        "load times. See the "
        "examples/others/tensorize_vllm_model.py example script "
        "for serializing vLLM models."
    )

    deserializer_args = tensorizer_args.deserialization_kwargs
    stream_kwargs = tensorizer_args.stream_kwargs
    stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs)
    with TensorDeserializer(stream, **deserializer_args, device="cpu") as state:
        yield from state.items()
    del state
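

# Usage sketch (not part of the original file): streaming HF-style weights
# from a serialized checkpoint; the path is hypothetical.
def _example_weights_iterator() -> None:
    cfg = TensorizerConfig(tensorizer_uri="/models/opt-125m/model.tensors")
    tensorizer_args = cfg._construct_tensorizer_args()
    for name, tensor in tensorizer_weights_iterator(tensorizer_args):
        logger.debug("loaded %s with shape %s", name, tuple(tensor.shape))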


def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
    """
    Infer if the model is a vLLM model by checking the weights for
    a vLLM tensorized marker.

    Args:
        tensorizer_config: The TensorizerConfig object containing the
            tensorizer_uri to the serialized model.

    Returns:
        bool: True if the model is a vLLM model, False otherwise.
    """
    tensorizer_args = tensorizer_config._construct_tensorizer_args()
    deserializer = TensorDeserializer(
        open_stream(tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs),
        **tensorizer_args.deserialization_kwargs,
        lazy_load=True,
    )
    if tensorizer_config.vllm_tensorized:
        logger.warning(
            "Please note that newly serialized vLLM models are automatically "
            "inferred as vLLM models, so setting vllm_tensorized=True is "
            "only necessary for models serialized prior to this change."
        )
        return True
    return ".vllm_tensorized_marker" in deserializer


def serialize_extra_artifacts(
    tensorizer_args: TensorizerArgs, served_model_name: str | list[str] | None
) -> None:
    if not isinstance(served_model_name, str):
        raise ValueError(
            f"served_model_name must be a str for serialize_extra_artifacts, "
            f"not {type(served_model_name)}."
        )

    with tempfile.TemporaryDirectory() as tmpdir:
        snapshot_download(
            served_model_name,
            local_dir=tmpdir,
            ignore_patterns=[
                "*.pt",
                "*.safetensors",
                "*.bin",
                "*.cache",
                "*.gitattributes",
                "*.md",
            ],
        )
        for artifact in os.scandir(tmpdir):
            if not artifact.is_file():
                continue
            with (
                open(artifact.path, "rb") as f,
                open_stream(
                    f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
                    mode="wb+",
                    **tensorizer_args.stream_kwargs,
                ) as stream,
            ):
                logger.info("Writing artifact %s", artifact.name)
                stream.write(f.read())


def serialize_vllm_model(
    model: nn.Module,
    tensorizer_config: TensorizerConfig,
    model_config: "ModelConfig",
) -> nn.Module:
    model.register_parameter(
        "vllm_tensorized_marker",
        nn.Parameter(torch.tensor((1,), device="meta"), requires_grad=False),
    )

    tensorizer_args = tensorizer_config._construct_tensorizer_args()

    encryption_params = None
    if (keyfile := tensorizer_config.encryption_keyfile) is not None:
        with open(keyfile, "rb") as f:
            key = f.read()
            encryption_params = EncryptionParams(key=key)

    output_file = tensorizer_args.tensorizer_uri
    if tensorizer_config._is_sharded:
        from vllm.distributed import get_tensor_model_parallel_rank

        output_file = output_file % get_tensor_model_parallel_rank()

    with open_stream(
        output_file, mode="wb+", **tensorizer_args.stream_kwargs
    ) as stream:
        serializer = TensorSerializer(
            stream,
            encryption=encryption_params,
            **tensorizer_config.serialization_kwargs,
        )
        serializer.write_module(model)
        serializer.close()

    serialize_extra_artifacts(tensorizer_args, model_config.served_model_name)

    logger.info("Successfully serialized model to %s", str(output_file))
    return model


def tensorize_vllm_model(
    engine_args: "EngineArgs",
    tensorizer_config: TensorizerConfig,
    generate_keyfile: bool = True,
):
    """Utility to load a model and then serialize it with Tensorizer

    Intended to be used separately from running a vLLM server since it
    creates its own Engine instance.
    """
    engine_config = engine_args.create_engine_config()
    tensorizer_config.verify_with_model_config(engine_config.model_config)
    tensorizer_config.verify_with_parallel_config(engine_config.parallel_config)

    # generate the encryption key before creating the engine to support sharding
    if (
        generate_keyfile
        and (keyfile := tensorizer_config.encryption_keyfile) is not None
    ):
        encryption_params = EncryptionParams.random()
        with open_stream(
            keyfile,
            mode="wb+",
            s3_access_key_id=tensorizer_config.s3_access_key_id,
            s3_secret_access_key=tensorizer_config.s3_secret_access_key,
            s3_endpoint=tensorizer_config.s3_endpoint,
        ) as stream:
            stream.write(encryption_params.key)

    from vllm.v1.engine.llm_engine import LLMEngine

    engine = LLMEngine.from_vllm_config(engine_config)
    engine.collective_rpc(
        "save_tensorized_model",
        kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
    )
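

# Usage sketch (not part of the original file): serializing a model with this
# utility. Model name and output path are hypothetical; see
# examples/others/tensorize_vllm_model.py for the full script.
def _example_tensorize_vllm_model() -> None:
    from vllm.engine.arg_utils import EngineArgs

    tensorize_vllm_model(
        engine_args=EngineArgs(model="facebook/opt-125m"),
        tensorizer_config=TensorizerConfig(
            tensorizer_uri="/models/opt-125m/model.tensors"
        ),
    )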


def tensorize_lora_adapter(lora_path: str, tensorizer_config: TensorizerConfig):
    """
    Uses tensorizer to serialize a LoRA adapter. Assumes that the files
    needed to load a LoRA adapter are an adapter model file
    (adapter_model.safetensors or adapter_model.bin) and a JSON config
    file called adapter_config.json.

    Serializes the files to tensorizer_config.tensorizer_dir.
    """
    import safetensors.torch

    from vllm.lora.utils import get_adapter_absolute_path

    lora_dir = get_adapter_absolute_path(lora_path)

    tensor_path = config_path = ""

    for file in os.listdir(lora_dir):
        if file.startswith("adapter_model"):
            tensor_path = lora_dir + "/" + file
        if file.startswith("adapter_config"):
            config_path = lora_dir + "/" + file
        if tensor_path and config_path:
            break

    if tensor_path.endswith(".safetensors"):
        tensors = safetensors.torch.load_file(tensor_path)
    elif tensor_path.endswith(".bin"):
        tensors = torch.load(tensor_path, weights_only=True)
    else:
        raise ValueError(
            f"Unsupported adapter model file: {tensor_path}. "
            f"Must be a .safetensors or .bin file."
        )

    with open(config_path) as f:
        config = json.load(f)

    tensorizer_args = tensorizer_config._construct_tensorizer_args()

    with open_stream(
        f"{tensorizer_config.tensorizer_dir}/adapter_config.json",
        mode="wb+",
        **tensorizer_args.stream_kwargs,
    ) as f:
        f.write(json.dumps(config).encode("utf-8"))

    lora_uri = f"{tensorizer_config.tensorizer_dir}/adapter_model.tensors"
    with open_stream(lora_uri, mode="wb+", **tensorizer_args.stream_kwargs) as f:
        serializer = TensorSerializer(f)
        serializer.write_state_dict(tensors)
        serializer.close()

    logger.info(
        "Successfully serialized LoRA files to %s",
        str(tensorizer_config.tensorizer_dir),
    )
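

# Usage sketch (not part of the original file): serializing a LoRA adapter.
# The adapter repo and output directory are hypothetical.
def _example_tensorize_lora_adapter() -> None:
    cfg = TensorizerConfig(tensorizer_dir="s3://my-bucket/my-lora")
    tensorize_lora_adapter("my-org/my-lora-adapter", cfg)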
152
vllm/model_executor/model_loader/tensorizer_loader.py
Normal file
@@ -0,0 +1,152 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117
import copy
from collections.abc import Generator

import torch
from torch import nn

from vllm.config import ModelConfig, ParallelConfig, VllmConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.tensorizer import (
    TensorizerConfig,
    deserialize_tensorizer_model,
    init_tensorizer_model,
    is_vllm_tensorized,
    serialize_vllm_model,
    tensorizer_weights_iterator,
)
from vllm.model_executor.model_loader.utils import (
    get_model_architecture,
    initialize_model,
)
from vllm.utils.torch_utils import set_default_torch_dtype

logger = init_logger(__name__)

BLACKLISTED_TENSORIZER_ARGS = {
    "device",  # vLLM decides this
    "dtype",  # vLLM decides this
    "mode",  # Not meant to be configurable by the user
}


def validate_config(config: dict):
    for k, v in config.items():
        if v is not None and k in BLACKLISTED_TENSORIZER_ARGS:
            raise ValueError(f"{k} is not an allowed Tensorizer argument.")


class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            validate_config(load_config.model_loader_extra_config)
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config["tensorizer_config"]
            )

    def _verify_config(
        self, model_config: ModelConfig, parallel_config: ParallelConfig
    ):
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
        self,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/others/tensorize_vllm_model.py). This should still
        be faster than default HuggingFace loading, but will be slower than
        loading a vLLM-tensorized model.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = initialize_model(vllm_config=vllm_config, prefix=prefix)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        self.tensorizer_config.verify_with_model_config(model_config)

        # Opening the stream is enough to verify that the serialized
        # tensors are reachable; nothing needs to be read here.
        with self.tensorizer_config.open_stream():
            pass

    def _patch_tensorizer_config(self, model_config: ModelConfig) -> TensorizerConfig:
        model_class = get_model_architecture(model_config)[0]
        tensorizer_config = copy.copy(self.tensorizer_config)
        tensorizer_config.model_class = model_class
        tensorizer_config.hf_config = model_config.hf_config
        tensorizer_config.dtype = model_config.dtype
        return tensorizer_config

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load serialized model weights with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/others/tensorize_vllm_model.py example script
        for serializing vLLM models."""
        if is_vllm_tensorized(self.tensorizer_config):
            tensorizer_config = self._patch_tensorizer_config(model_config)
            deserialize_tensorizer_model(model, tensorizer_config)
        else:
            model.load_weights(self._get_weights_iterator())

    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig, prefix: str = ""
    ) -> nn.Module:
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            self.tensorizer_config.tensorizer_uri = (
                self.tensorizer_config.tensorizer_uri % get_tensor_model_parallel_rank()
            )

        if is_vllm_tensorized(self.tensorizer_config):
            tensorizer_config = self._patch_tensorizer_config(model_config)
            device_config = vllm_config.device_config
            with set_default_torch_dtype(model_config.dtype):
                with torch.device(device_config.device):
                    model = init_tensorizer_model(
                        tensorizer_config=tensorizer_config, vllm_config=vllm_config
                    )
            self.load_weights(model, model_config)
            return model
        return self._load_model_serialized_cpu(vllm_config=vllm_config, prefix=prefix)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig | dict,
        model_config: ModelConfig,
    ) -> None:
        if isinstance(tensorizer_config, dict):
            tensorizer_config = TensorizerConfig(**tensorizer_config)
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
            model_config=model_config,
        )
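

# Usage sketch (not part of the original file): selecting this loader when
# constructing an engine. Model name and path are hypothetical, and the
# extra-config shape follows TensorizerLoader.__init__ above.
def _example_load_with_tensorizer() -> None:
    from vllm import LLM

    LLM(
        model="facebook/opt-125m",
        load_format="tensorizer",
        model_loader_extra_config={
            "tensorizer_config": {
                "tensorizer_uri": "/models/opt-125m/model.tensors",
            }
        },
    )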
286
vllm/model_executor/model_loader/utils.py
Normal file
@@ -0,0 +1,286 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading models."""

import inspect
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field

import torch
from torch import nn
from typing_extensions import assert_never

import vllm.envs as envs
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.model_executor.layers.attention import Attention, MLAAttention
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from vllm.model_executor.model_loader.reload import (
    record_metadata_for_reloading,
    set_torchao_reload_attrs,
)
from vllm.model_executor.models.interfaces import SupportsQuant
from vllm.tracing import instrument
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import get_accelerator_view_from_cpu_tensor

logger = init_logger(__name__)


@instrument(span_name="Initialize model")
def initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    model_class: type[nn.Module] | None = None,
    model_config: ModelConfig | None = None,
) -> nn.Module:
    """Initialize a model with the given configurations."""
    if model_config is None:
        model_config = vllm_config.model_config
    if model_class is None:
        model_class, _ = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        configure_quant_config(vllm_config.quant_config, model_class)

    signatures = inspect.signature(model_class.__init__)
    all_params = [param.name for param in signatures.parameters.values()]
    if "vllm_config" in all_params and "prefix" in all_params:
        # new-style model class
        with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
            model = model_class(vllm_config=vllm_config, prefix=prefix)
        record_metadata_for_reloading(model)
        return model

    msg = (
        "vLLM model class should accept `vllm_config` and `prefix` as "
        "input arguments. Possibly you have an old-style model class"
        " registered from out of tree and it is used for new vLLM version. "
        "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
        "for the design and update the model class accordingly."
    )
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )
    # try to be compatible with old-style model class
    kwargs = {}
    if "prefix" in all_params:
        kwargs["prefix"] = prefix
    if "config" in all_params:
        kwargs["config"] = model_config.hf_config
    if "cache_config" in all_params:
        kwargs["cache_config"] = vllm_config.cache_config
    if "quant_config" in all_params:
        kwargs["quant_config"] = vllm_config.quant_config
    if "lora_config" in all_params:
        kwargs["lora_config"] = vllm_config.lora_config
    if "scheduler_config" in all_params:
        kwargs["scheduler_config"] = vllm_config.scheduler_config
    with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
        model = model_class(**kwargs)
    record_metadata_for_reloading(model)

    return model


def process_weights_after_loading(
    model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if isinstance(quant_method, QuantizeMethodBase):
            # When quant methods need to process weights after loading
            # (for repacking, quantizing, etc), they expect parameters
            # to be on the global target device. This scope is for the
            # case where cpu offloading is used, where we will move the
            # parameters onto device for processing and back off after.
            with device_loading_context(module, target_device):
                quant_method.process_weights_after_loading(module)

    # Initialize post-load attention weights for both Attention and MLA.
    # NOTE: Happens after other modules so we can easily decompress weights.
    for _, module in model.named_modules():
        if isinstance(module, (Attention, MLAAttention)) and hasattr(
            module, "process_weights_after_loading"
        ):
            # TODO(lucas): see if there is a way to unify the signatures
            # of process_weights_after_loading
            with device_loading_context(module, target_device):
                module.process_weights_after_loading(model_config.dtype)

    # Needed for torchao model reloading via model.reload_weights
    # @kylesayrs @jerryzh168 this can be removed if callers move to `reload_weights`
    if model_config.quantization == "torchao":
        set_torchao_reload_attrs(model, model_config)


@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    original_device_states: dict[str, torch.device] = {}
    uva_offloaded_parameters: list[str] = []

    # Store original device states and move parameters to GPU if they're on CPU
    for name, p in module.named_parameters():
        if p.device.type == "cpu":
            original_device_states[name] = p.device
            p.data = p.data.to(target_device)
            if getattr(p, "_vllm_is_uva_offloaded", False):
                uva_offloaded_parameters.append(name)
        # Parameters already on target device are not touched

    try:
        yield module

    finally:
        use_pin_memory = (
            is_pin_memory_available()
            and not envs.VLLM_WEIGHT_OFFLOADING_DISABLE_PIN_MEMORY
        )
        # Restore parameters to their original devices, ignoring new parameters
        for name, p in module.named_parameters():
            if name in original_device_states:
                original_device: torch.device = original_device_states[name]
                p.data = p.data.to(original_device)

                # parameter is UVA offloaded, but was replaced with a new
                # device tensor; re-offload it to CPU using UVA
                if name in uva_offloaded_parameters and not getattr(
                    p, "_vllm_is_uva_offloaded", False
                ):
                    cpu_data = p.data.to(device="cpu")
                    if use_pin_memory:
                        cpu_data = cpu_data.pin_memory()
                    p.data = get_accelerator_view_from_cpu_tensor(cpu_data)
                    p._vllm_is_uva_offloaded = True
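

# Usage sketch (not part of the original file): parameters on CPU are moved to
# the target device inside the context and restored on exit.
def _example_device_loading_context() -> None:
    module = nn.Linear(16, 16)  # stand-in for a CPU-offloaded submodule
    with device_loading_context(module, torch.device("cuda:0")):
        assert all(p.device.type == "cuda" for p in module.parameters())
    assert all(p.device.type == "cpu" for p in module.parameters())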


_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
"""Caches the outputs of `_get_model_architecture`."""


def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    from vllm.model_executor.models.adapters import as_embedding_model, as_seq_cls_model

    architectures = getattr(model_config.hf_config, "architectures", [])

    model_cls, arch = model_config.registry.resolve_model_cls(
        architectures,
        model_config=model_config,
    )

    if arch == model_config._get_transformers_backend_cls():
        assert model_config.model_impl != "vllm"
        if model_config.model_impl == "auto":
            logger.warning_once(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.",
                arch,
            )

    convert_type = model_config.convert_type
    if convert_type == "none":
        pass
    elif convert_type == "embed":
        logger.debug_once("Converting to embedding model.")
        model_cls = as_embedding_model(model_cls)
    elif convert_type == "classify":
        logger.debug_once("Converting to sequence classification model.")
        model_cls = as_seq_cls_model(model_cls)
    else:
        assert_never(convert_type)

    return model_cls, arch


def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    key = hash(
        (
            model_config.model,
            model_config.convert_type,
            model_config.runner_type,
            model_config.trust_remote_code,
            model_config.model_impl,
            tuple(getattr(model_config.hf_config, "architectures", [])),
        )
    )
    if key in _MODEL_ARCH_BY_HASH:
        return _MODEL_ARCH_BY_HASH[key]

    model_arch = _get_model_architecture(model_config)
    _MODEL_ARCH_BY_HASH[key] = model_arch
    return model_arch


def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
    return get_model_architecture(model_config)[0]


def get_architecture_class_name(model_config: ModelConfig) -> str:
    return get_model_architecture(model_config)[1]


@dataclass
class ParamMapping:
    """
    A class to handle parameter mapping for model weight loading.
    It creates a bidirectional mapping between packed parameters and their
    constituent parts.
    """

    packed_mapping: dict[str, list[str]]
    inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)

    def __post_init__(self):
        for packed_name, sub_params in self.packed_mapping.items():
            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
            if len(sub_params) == 1 and sub_params[0] == packed_name:
                continue
            for index, param_name in enumerate(sub_params):
                self.inverse_packed_mapping[param_name] = (
                    packed_name,
                    index,
                )

    def get_sub_modules(self, module_name: str) -> tuple[str, list[str]] | None:
        for key, value in self.packed_mapping.items():
            if module_name.endswith(key):
                return key, value
        return None
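

# Usage sketch (not part of the original file), using vLLM's common fused
# QKV-projection layout:
def _example_param_mapping() -> None:
    mapping = ParamMapping(
        packed_mapping={"qkv_proj": ["q_proj", "k_proj", "v_proj"]}
    )
    # each sub-parameter maps back to its packed parent and shard index
    assert mapping.inverse_packed_mapping["k_proj"] == ("qkv_proj", 1)
    # suffix matching resolves fully-qualified module names
    assert mapping.get_sub_modules("model.layers.0.self_attn.qkv_proj") == (
        "qkv_proj",
        ["q_proj", "k_proj", "v_proj"],
    )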


def configure_quant_config(
    quant_config: QuantizationConfig, model_class: type[nn.Module]
):
    """
    Pass packed_modules_mapping by reference to quant_config so that
    quant_config can properly match fused modules.

    Note that model attributes are passed by reference to quant_config,
    enabling them to be updated by model_class.__new__ (e.g. chatglm, qwen).

    Once the `SupportsQuant` mixin has been added to all models, this
    function can be removed.
    """
    if not issubclass(model_class, SupportsQuant):
        hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
        packed_mapping = getattr(model_class, "packed_modules_mapping", None)

        # pass mappings by reference to quant_config
        if hf_to_vllm_mapper is not None:
            quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
        if packed_mapping is not None:
            quant_config.packed_modules_mapping = packed_mapping
1428
vllm/model_executor/model_loader/weight_utils.py
Normal file
File diff suppressed because it is too large