v1.0
This commit is contained in:
152
model_executor/model_loader/__init__.py
Normal file
152
model_executor/model_loader/__init__.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.bitsandbytes_loader import BitsAndBytesModelLoader
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
|
||||
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
|
||||
from vllm.model_executor.model_loader.runai_streamer_loader import (
|
||||
RunaiModelStreamerLoader,
|
||||
)
|
||||
from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader
|
||||
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
get_architecture_class_name,
|
||||
get_model_architecture,
|
||||
get_model_cls,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
padding_weight_loader
|
||||
)
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Reminder: Please update docstring in `LoadConfig`
# if a new load format is added here
LoadFormats = Literal[
    "auto",
    "bitsandbytes",
    "dummy",
    "fastsafetensors",
    "gguf",
    "mistral",
    "npcache",
    "pt",
    "runai_streamer",
    "runai_streamer_sharded",
    "safetensors",
    "sharded_state",
    "tensorizer",
]

# Private registry mapping each load format name to the loader class that
# handles it. Several formats intentionally share `DefaultModelLoader`, which
# dispatches on the format internally. The registry is extended at runtime by
# `register_model_loader` and consulted by `get_model_loader`.
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
    "auto": DefaultModelLoader,
    "bitsandbytes": BitsAndBytesModelLoader,
    "dummy": DummyModelLoader,
    "fastsafetensors": DefaultModelLoader,
    "gguf": GGUFModelLoader,
    "mistral": DefaultModelLoader,
    "npcache": DefaultModelLoader,
    "pt": DefaultModelLoader,
    "runai_streamer": RunaiModelStreamerLoader,
    "runai_streamer_sharded": ShardedStateLoader,
    "safetensors": DefaultModelLoader,
    "sharded_state": ShardedStateLoader,
    "tensorizer": TensorizerLoader,
}
|
||||
|
||||
|
||||
def register_model_loader(load_format: str):
    """Register a customized vllm model loader.

    When a load format is not supported by vllm, you can register a customized
    model loader to support it.

    Args:
        load_format (str): The model loader format name.

    Examples:
        >>> from vllm.config.load import LoadConfig
        >>> from vllm.model_executor.model_loader import (
        ...     get_model_loader,
        ...     register_model_loader,
        ... )
        >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
        >>>
        >>> @register_model_loader("my_loader")
        ... class MyModelLoader(BaseModelLoader):
        ...     def download_model(self):
        ...         pass
        ...
        ...     def load_weights(self):
        ...         pass
        >>>
        >>> load_config = LoadConfig(load_format="my_loader")
        >>> type(get_model_loader(load_config))
        <class 'MyModelLoader'>
    """  # noqa: E501

    def _register(loader_cls):
        # Re-registering an existing format is allowed but noisy: warn before
        # the old entry gets overwritten so the replacement is traceable.
        if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
            logger.warning(
                "Load format `%s` is already registered, and will be "
                "overwritten by the new loader class `%s`.",
                load_format,
                loader_cls,
            )
        # Only `BaseModelLoader` subclasses satisfy the loader contract.
        if not issubclass(loader_cls, BaseModelLoader):
            raise ValueError(
                "The model loader must be a subclass of `BaseModelLoader`."
            )
        _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = loader_cls
        logger.info(
            "Registered model loader `%s` with load format `%s`",
            loader_cls,
            load_format,
        )
        # Return the class unchanged so this works as a decorator.
        return loader_cls

    return _register
|
||||
|
||||
|
||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format."""
    fmt = load_config.load_format
    loader_cls = _LOAD_FORMAT_TO_MODEL_LOADER.get(fmt)
    if loader_cls is None:
        raise ValueError(f"Load format `{fmt}` is not supported")
    # Instantiate the loader with the full load configuration.
    return loader_cls(load_config)
|
||||
|
||||
|
||||
def get_model(
    *, vllm_config: VllmConfig, model_config: ModelConfig | None = None
) -> nn.Module:
    """Build and load a model using the loader selected by the load config.

    Falls back to ``vllm_config.model_config`` when no explicit model config
    is supplied.
    """
    effective_model_config = (
        vllm_config.model_config if model_config is None else model_config
    )
    loader = get_model_loader(vllm_config.load_config)
    return loader.load_model(
        vllm_config=vllm_config, model_config=effective_model_config
    )
|
||||
|
||||
|
||||
# Public API of this package: helper functions plus the concrete loader
# classes re-exported for external use.
__all__ = [
    "get_model",
    "get_model_loader",
    "get_architecture_class_name",
    "get_model_architecture",
    "get_model_cls",
    "register_model_loader",
    "BaseModelLoader",
    "BitsAndBytesModelLoader",
    "GGUFModelLoader",
    "DefaultModelLoader",
    "DummyModelLoader",
    "RunaiModelStreamerLoader",
    "ShardedStateLoader",
    "TensorizerLoader",
]
|
||||
BIN
model_executor/model_loader/__pycache__/__init__.cpython-312.pyc
Normal file
BIN
model_executor/model_loader/__pycache__/__init__.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
model_executor/model_loader/__pycache__/tpu.cpython-312.pyc
Normal file
BIN
model_executor/model_loader/__pycache__/tpu.cpython-312.pyc
Normal file
Binary file not shown.
BIN
model_executor/model_loader/__pycache__/utils.cpython-312.pyc
Normal file
BIN
model_executor/model_loader/__pycache__/utils.cpython-312.pyc
Normal file
Binary file not shown.
Binary file not shown.
57
model_executor/model_loader/base_loader.py
Normal file
57
model_executor/model_loader/base_loader.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model,
|
||||
process_weights_after_loading,
|
||||
)
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BaseModelLoader(ABC):
    """Base class for model loaders.

    Subclasses implement the format-specific steps (`download_model`,
    `load_weights`); `load_model` provides the common orchestration.
    """

    def __init__(self, load_config: LoadConfig):
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Download a model so that it can be immediately loaded."""
        raise NotImplementedError

    @abstractmethod
    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load weights into a model. This standalone API allows
        inplace weights loading for an already-initialized model"""
        raise NotImplementedError

    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig
    ) -> nn.Module:
        """Load a model with the given configurations."""
        load_config = vllm_config.load_config
        # An explicit device on the load config takes precedence over the
        # device config.
        device = load_config.device
        if device is None:
            device = vllm_config.device_config.device
        target_device = torch.device(device)

        with set_default_torch_dtype(model_config.dtype):
            # Materialize the model structure on the target device first.
            with target_device:
                model = initialize_model(
                    vllm_config=vllm_config, model_config=model_config
                )

            logger.debug("Loading weights on %s ...", device)
            # Quantization does not happen in `load_weights` but after it
            self.load_weights(model, model_config)
            process_weights_after_loading(model, model_config, target_device)
        return model.eval()
|
||||
822
model_executor/model_loader/bitsandbytes_loader.py
Normal file
822
model_executor/model_loader/bitsandbytes_loader.py
Normal file
@@ -0,0 +1,822 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import fnmatch
|
||||
import glob
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import ParamMapping
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf,
|
||||
download_weights_from_hf,
|
||||
filter_duplicate_safetensors_files,
|
||||
filter_files_not_needed_for_inference,
|
||||
pt_weights_iterator,
|
||||
safetensors_weights_iterator,
|
||||
)
|
||||
from vllm.model_executor.models import is_pooling_model
|
||||
from vllm.model_executor.utils import (
|
||||
get_moe_expert_mapping,
|
||||
get_packed_modules_mapping,
|
||||
set_weight_attrs,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def is_moe_model(model: torch.nn.Module) -> bool:
    """Return True if the model contains any FusedMoE layer.

    Args:
        model: The module tree to scan.
    """
    # `any(...)` already returns a bool, so no extra `bool()` wrapper is
    # needed; it also short-circuits on the first FusedMoE found.
    return any(isinstance(module, FusedMoE) for module in model.modules())
|
||||
|
||||
|
||||
class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
"""Model loader to load model weights with BitAndBytes quantization."""
|
||||
|
||||
possible_config_file_names = ["adapter_config.json"]
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
    """Initialize bookkeeping used while scanning/quantizing model weights."""
    super().__init__(load_config)

    # Save the module names without sharding.
    self.unsharded_weights_modules: list[str] = []
    # Save the module names that are sharded by column.
    self.column_sharded_weights_modules: list[str] = []
    # Modules whose weights might have fused on disk
    # we need their output_sizes to make shard in flight correctly with TP
    self.maybe_fused_weights_modules: dict[str, list[int]] = {}
    # Store all module names (from transformers) that support
    # BNB quantization.
    self.target_modules: list[str] = []
    # Modules whose tensor parallelism is disabled; their weights are
    # loaded unsharded (tp_size=1, tp_rank=0) in `_unquantized_generator`.
    self.tp_disabled_modules: list[str] = []
    # Store the mapping of expert parameters for MoE models.
    self.expert_params_mapping: list[tuple[str, str, int, str]] = []
    # mapping weight names from transformers to vllm.
    self.weight_mapper: Callable = lambda name: name
    # True when the checkpoint is already BNB-quantized (pre-quantized).
    self.pre_quant: bool = False
    # True when a pre-quantized checkpoint uses 8-bit (vs 4-bit) weights.
    self.load_8bit: bool = False
    # True for pooling models (affects weight-name prefix handling).
    self.is_pool_model: bool = False
|
||||
|
||||
def _get_weight_files(
    self,
    model_name_or_path: str,
    allowed_patterns: list[str],
    revision: str | None = None,
) -> tuple[str, list[str], str]:
    """Retrieve weight files. Download the files if necessary.

    Return the weight files and the file pattern.

    Args:
        model_name_or_path: Local directory or HuggingFace repo id.
        allowed_patterns: Glob patterns tried in order; the first pattern
            that matches at least one file wins.
        revision: Optional HF revision to download.

    Returns:
        Tuple of (folder holding the weights, matched file paths, the
        pattern that matched).

    Raises:
        RuntimeError: If no pattern matches any file.
    """
    is_local = os.path.isdir(model_name_or_path)

    if is_local:
        for pattern in allowed_patterns:
            weight_files = glob.glob(os.path.join(model_name_or_path, pattern))
            if weight_files:
                return model_name_or_path, weight_files, pattern
    else:
        # Remote repo: list files first so only the winning pattern is
        # actually downloaded.
        hf_api = HfApi()
        repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
        for pattern in allowed_patterns:
            matching_files = fnmatch.filter(repo_files, pattern)
            if matching_files:
                hf_folder = download_weights_from_hf(
                    model_name_or_path,
                    self.load_config.download_dir,
                    [pattern],
                    revision,
                    ignore_patterns=self.load_config.ignore_patterns,
                )
                return (
                    hf_folder,
                    glob.glob(os.path.join(hf_folder, pattern)),
                    pattern,
                )

    raise RuntimeError(f"No model weights found in: `{model_name_or_path}`")
|
||||
|
||||
def _prepare_weights(
    self, model_name_or_path: str, revision: str | None
) -> tuple[list[str], bool]:
    """Prepare weight files for the model.

    Returns the list of weight files to read plus a flag indicating
    whether they are safetensors files.
    """

    # Preference order: safetensors, then torch .bin, then .pt files.
    allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]

    hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
        model_name_or_path, allowed_patterns, revision
    )

    use_safetensors = matched_pattern == "*.safetensors"
    is_local = os.path.isdir(model_name_or_path)
    index_file = SAFE_WEIGHTS_INDEX_NAME
    if use_safetensors:
        # For models like Mistral-7B-Instruct-v0.3
        # there are both sharded safetensors files and a consolidated
        # safetensors file. Using both breaks.
        # Here, we download the `model.safetensors.index.json` and filter
        # any files not found in the index.
        if not is_local:
            download_safetensors_index_file_from_hf(
                model_name_or_path,
                index_file,
                self.load_config.download_dir,
                revision,
            )
        hf_weights_files = filter_duplicate_safetensors_files(
            hf_weights_files, hf_folder, index_file
        )
    else:
        # Drop files (e.g. optimizer states) irrelevant for inference.
        hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`"
        )

    return hf_weights_files, use_safetensors
|
||||
|
||||
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
    """Yield ``(original_name, mapped_name, tensor)`` for every checkpoint
    weight, applying the transformers->vllm name mapping and, for pooling
    models, the ``model.`` prefix fixup."""

    def _maybe_add_pool_prefix(name: str) -> str:
        # Pooling models may keep weights under a `model.` prefix; add it
        # when the target modules carry the prefix but the weight does not.
        needs_prefix = (
            self.is_pool_model
            and self.target_modules[0].startswith("model.")
            and not name.startswith("model.")
        )
        return "model." + name if needs_prefix else name

    if use_safetensors:
        raw_iterator = safetensors_weights_iterator(
            hf_weights_files,
            self.load_config.use_tqdm_on_load,
        )
    else:
        raw_iterator = pt_weights_iterator(
            hf_weights_files,
            self.load_config.use_tqdm_on_load,
            self.load_config.pt_load_map_location,
        )

    for org_name, tensor in raw_iterator:
        # Preserve the original name alongside the vllm-mapped one.
        yield org_name, _maybe_add_pool_prefix(self.weight_mapper(org_name)), tensor
|
||||
|
||||
def _get_quantized_weights_iterator(
    self,
    model_name_or_path: str,
    revision: str | None,
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, Any]]:
    """Get an iterator to the model weights with bitsandbytes quantization,
    as well as the quantization state dictionary.

    The returned dict is populated lazily as the generator is consumed.
    """

    # only load the bitsandbytes module when needed
    try:
        import bitsandbytes

        if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
            raise ImportError(
                "bitsandbytes version is wrong. Please "
                "install bitsandbytes>=0.46.1."
            )
    except ImportError as err:
        raise ImportError(
            "Please install bitsandbytes>=0.46.1 via "
            "`pip install bitsandbytes>=0.46.1` to use "
            "bitsandbytes quantizer."
        ) from err

    hf_weights_files, use_safetensors = self._prepare_weights(
        model_name_or_path, revision
    )

    # Shared mutable dict: the chosen generator fills it in as it yields.
    quant_state_dict: dict[str, Any] = {}

    # Pre-quantized checkpoints carry their own quant states (8- or 4-bit);
    # otherwise weights are quantized to NF4 on the fly.
    if self.pre_quant:
        if self.load_8bit:
            return self._quantized_8bit_generator(
                hf_weights_files, use_safetensors, quant_state_dict
            ), quant_state_dict
        else:
            return self._quantized_4bit_generator(
                hf_weights_files, use_safetensors, quant_state_dict
            ), quant_state_dict

    return self._unquantized_generator(
        hf_weights_files, use_safetensors, quant_state_dict
    ), quant_state_dict
|
||||
|
||||
def _is_8bit_weight_name(self, weight_name: str):
|
||||
quantized_suffix = {".scb", ".weight_format"}
|
||||
return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
|
||||
|
||||
def _is_4bit_weight_name(self, weight_name: str):
|
||||
quantized_suffix = {
|
||||
"absmax",
|
||||
"quant_map",
|
||||
"nested_absmax",
|
||||
"nested_quant_map",
|
||||
"bitsandbytes",
|
||||
}
|
||||
suffix = weight_name.split(".")[-1]
|
||||
return any(q_suffix in suffix for q_suffix in quantized_suffix)
|
||||
|
||||
def _quantized_8bit_generator(
    self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
    """Yield weights from a pre-quantized 8-bit checkpoint.

    Pass 1 collects the `.SCB` scale tensors into `quant_state_dict`,
    keyed by the matching `.weight` name. Pass 2 yields the real weight
    tensors, tagging those with a recorded scale via `load_in_8bit` so
    downstream loading treats them as quantized data.
    """
    # Pass 1: record per-weight scales (".scb" tensors).
    for (
        org_weight_name,
        mapped_weight_name,
        weight_tensor,
    ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
        if not mapped_weight_name.lower().endswith(".scb"):
            continue

        # NOTE(review): the key is the fully lowercased name — assumes
        # checkpoint weight names are lowercase; verify for mixed-case names.
        weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
        quant_state_dict[weight_key] = weight_tensor

    # Pass 2: yield actual weights, skipping the auxiliary 8-bit tensors.
    for (
        org_weight_name,
        mapped_weight_name,
        weight_tensor,
    ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
        if self._is_8bit_weight_name(mapped_weight_name):
            continue

        if mapped_weight_name in quant_state_dict:
            # Mark pre-quantized weights so the consumer knows the tensor
            # holds int8 data with an associated scale.
            set_weight_attrs(weight_tensor, {"load_in_8bit": True})
        # The yield is identical in both cases, so it is hoisted out of the
        # conditional (the original duplicated it in each branch).
        yield org_weight_name, weight_tensor
|
||||
|
||||
def _quantized_4bit_generator(
    self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
    """Yield weights from a pre-quantized 4-bit checkpoint, rebuilding the
    bitsandbytes `QuantState` for each quantized weight into
    `quant_state_dict`."""
    from bitsandbytes.functional import QuantState

    # First iterate over all quant state weights
    weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
    temp_state_dict = {}
    for (
        org_weight_name,
        mapped_weight_name,
        weight_tensor,
    ) in weight_iterator:
        if not self._is_4bit_weight_name(mapped_weight_name):
            continue
        # bitsandbytes library requires
        # weight.quant_state.bitsandbytes__* in CPU
        if "quant_state.bitsandbytes" in mapped_weight_name:
            temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
        else:
            temp_state_dict[mapped_weight_name] = weight_tensor

    # Closure to parse quant_state for each prequant weight
    def _parse_quant_state(param_name: str, temp_state_dict: dict) -> QuantState:
        # Gather every auxiliary tensor whose name starts with this
        # parameter's name and rebuild the QuantState from them.
        quant_state = {}
        for k in temp_state_dict:
            if param_name + "." in k:
                quant_state[k] = temp_state_dict[k]

        return QuantState.from_dict(
            quant_state, device=current_platform.device_type
        )

    # Second iterate over all prequant and normal weights
    # pre quantized weights would have a quant_state
    for (
        org_weight_name,
        mapped_weight_name,
        weight_tensor,
    ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
        if self._is_4bit_weight_name(mapped_weight_name):
            continue

        if (
            f"{mapped_weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict
        ) or (
            f"{mapped_weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
        ):
            quant_state = _parse_quant_state(mapped_weight_name, temp_state_dict)
            quant_state_dict[mapped_weight_name] = quant_state
            yield org_weight_name, weight_tensor
        else:
            yield org_weight_name, weight_tensor
|
||||
|
||||
def _unquantized_generator(
    self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
    """Yield weights from an unquantized checkpoint, shard-slicing each
    BNB-target weight for tensor parallelism and quantizing it to NF4 on
    the fly; quant states are recorded into `quant_state_dict`."""
    from bitsandbytes.functional import quantize_4bit

    global_tp_size = get_tensor_model_parallel_world_size()
    global_tp_rank = get_tensor_model_parallel_rank()
    # True when the weight name (minus ".weight") equals the module name.
    check_match = (
        lambda weight_name, module_name: weight_name.removesuffix(".weight")
        == module_name
    )
    for (
        org_weight_name,
        mapped_weight_name,
        weight_tensor,
    ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
        # override tp_size and tp_rank if the module has disabled TP
        if any(
            tp_disabled_module in mapped_weight_name
            for tp_disabled_module in self.tp_disabled_modules
        ):
            tp_size = 1
            tp_rank = 0
        else:
            tp_size = global_tp_size
            tp_rank = global_tp_rank

        if any(
            target_module in mapped_weight_name
            for target_module in self.target_modules
        ) and mapped_weight_name.endswith(".weight"):
            # Without sharding
            if any(
                check_match(mapped_weight_name, module)
                for module in self.unsharded_weights_modules
            ):
                weight_sub_tensor = weight_tensor
            # Shard by column
            elif any(
                check_match(mapped_weight_name, module)
                for module in self.column_sharded_weights_modules
            ):
                total_size = weight_tensor.size(-1)
                start_index = total_size // tp_size * tp_rank
                end_index = total_size // tp_size * (tp_rank + 1)
                weight_sub_tensor = weight_tensor[..., start_index:end_index]
            # Weights have fused on disk. In this case, we assume that the
            # weight and module use same name.
            elif any(
                check_match(mapped_weight_name, module)
                for module in self.maybe_fused_weights_modules
            ):
                # special case for fused weights
                # get the size of each shard weight tensor
                total_shard_sizes = next(
                    (
                        sizes
                        for module, sizes in self.maybe_fused_weights_modules.items()  # noqa: E501
                        if check_match(mapped_weight_name, module)
                    )
                )
                total_size = weight_tensor.size(0)
                assert total_size == sum(total_shard_sizes)
                # get the start/end index of each shard weight tensor
                total_start_index = list(
                    itertools.accumulate([0] + total_shard_sizes)
                )[:-1]
                shard_weights_index = [
                    (
                        idx + size // tp_size * tp_rank,
                        idx + size // tp_size * (tp_rank + 1),
                    )
                    for idx, size in zip(total_start_index, total_shard_sizes)
                ]
                # slice and reorder the weight tensor: take this rank's
                # slice of every fused sub-weight, then re-concatenate.
                weight_tensor = [
                    weight_tensor[start_index:end_index, ...]
                    for start_index, end_index in shard_weights_index
                ]
                weight_sub_tensor = torch.cat(weight_tensor, dim=0)
            # Shard by row
            else:
                total_size = weight_tensor.size(0)
                start_index = total_size // tp_size * tp_rank
                end_index = total_size // tp_size * (tp_rank + 1)
                weight_sub_tensor = weight_tensor[start_index:end_index, ...]

            # bitsandbytes requires data in GPU
            if weight_sub_tensor.is_cuda:
                loaded_weight = weight_sub_tensor
            else:
                loaded_weight = weight_sub_tensor.to(
                    device=current_platform.device_type
                )

            # remove the following after the issue is fixed:
            # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
            if loaded_weight.is_contiguous() is False:
                loaded_weight = loaded_weight.contiguous()

            # Quantization statistics are computed in float32 for accuracy.
            with set_default_torch_dtype(torch.float32):
                processed_weight, quant_state = quantize_4bit(
                    loaded_weight,
                    compress_statistics=True,
                    quant_type="nf4",
                )

            quant_state_dict[mapped_weight_name] = quant_state
        else:
            processed_weight = weight_tensor
        yield org_weight_name, processed_weight
|
||||
|
||||
def _get_bnb_target_modules(self, model: nn.Module) -> None:
    """
    Identify and collect all modules that support BitsAndBytes
    quantization.

    Populates `self.target_modules` (and `self.tp_disabled_modules` for
    modules with TP disabled) from the model's Linear and FusedMoE layers.

    Raises:
        ValueError: For pre-quantized 8-bit checkpoints with FusedMoE.
        AssertionError: If no quantizable module is found.
    """
    for name, module in model.named_modules():
        if isinstance(module, LinearBase) and hasattr(
            module.quant_method, "quant_config"
        ):
            if modules_info := self.modules_mapping.get_sub_modules(name):
                # Map vllm's names to transformers's names.
                rep_name, sub_modules = modules_info
                for sub_name in sub_modules:
                    new_name = name.replace(rep_name, sub_name)
                    self.target_modules.append(new_name)
                    if module.disable_tp:
                        self.tp_disabled_modules.append(new_name)
            # Add original module name even if the module has stacked map,
            # in case model has a mixture of disk-merged and disk-split
            # weights with same last name.
            self.target_modules.append(name)
            if module.disable_tp:
                self.tp_disabled_modules.append(name)
        elif isinstance(module, FusedMoE) and hasattr(
            module.quant_method, "quant_config"
        ):
            # TODO: support FusedMoE with prequant and 8bit.
            if self.pre_quant and self.load_8bit:
                raise ValueError(
                    "Prequant BitsAndBytes 8bit models with FusedMoE "
                    "is not supported yet."
                )
            # Get the corresponding weight name using module name and
            # expert_params_mapping.

            for exp in self.expert_params_mapping:
                weight_name = exp[1]
                rep_name = name.replace("experts", "") + weight_name.removesuffix(
                    "."
                )
                self.target_modules.append(rep_name)

    # BUG FIX: the f-string with the model name was previously a dangling
    # expression *after* the assert's closing paren, so it never appeared
    # in the failure message. Fold it into the assertion message.
    assert self.target_modules, (
        "vLLM currently does not support BNB quantization for"
        f" {type(model).__name__}"
    )
|
||||
|
||||
def _classify_module_sharding(self, model: nn.Module):
    """
    Categorize modules based on their weight sharding requirements
    for tensor parallelism.

    Fills `self.unsharded_weights_modules`,
    `self.column_sharded_weights_modules` and
    `self.maybe_fused_weights_modules` for use by `_unquantized_generator`.
    """
    for name, module in model.named_modules():
        # Some modules like `ReplicatedLinear` should not have their weights
        # sharded. The reason for implementing it this way is to avoid new
        # static variable in the model implementation.
        if isinstance(module, (ReplicatedLinear,)):
            self.unsharded_weights_modules.append(name)
        # `QKVParallelLinear` and `MergedColumnParallelLinear` might have
        # fused weights on disk. We need to use the output sizes of these
        # modules to shard the weights correctly.
        elif isinstance(module, (QKVParallelLinear, MergedColumnParallelLinear)):
            self.maybe_fused_weights_modules[name] = module.output_sizes
        # In TP, these weights are partitioned along the column
        # dimension (dim=-1)
        elif isinstance(module, (RowParallelLinear,)):
            self.column_sharded_weights_modules.append(name)
        elif isinstance(module, FusedMoE):
            # Only the "w2" (down-projection) expert weights are
            # column-sharded; derive their on-disk names from the mapping.
            expert_mapping = self.expert_params_mapping
            for exp in expert_mapping:
                if exp[-1] == "w2":
                    weight_name = exp[1]
                    rep_name = name.replace(
                        "experts", ""
                    ) + weight_name.removesuffix(".")
                    self.column_sharded_weights_modules.append(rep_name)
|
||||
|
||||
def _verify_model_compatibility(
    self, model: nn.Module, model_config: ModelConfig
) -> None:
    """
    Verify that the model is compatible with BitsAndBytes quantization.

    Also sets `self.pre_quant` / `self.load_8bit` from the checkpoint's
    HF quantization config when present.

    Raises:
        AttributeError: If the model lacks `load_weights` or
            `packed_modules_mapping`.
        ValueError: For non-BNB quantization methods, or pre-quantized
            BNB checkpoints combined with tensor parallelism.
    """
    if not hasattr(model, "load_weights"):
        raise AttributeError(
            "The required method 'load_weights' is not defined in class"
            f" {type(model).__name__}."
        )

    if not hasattr(model, "packed_modules_mapping"):
        raise AttributeError(
            f"Model {type(model).__name__} does not support BitsAndBytes "
            "quantization yet. No 'packed_modules_mapping' found."
        )

    quant_config = getattr(model_config.hf_config, "quantization_config", None)
    if quant_config and (quant_method := quant_config.get("quant_method")):
        if quant_method == "bitsandbytes":
            self.pre_quant = True
        else:
            raise ValueError(
                f"BitsAndBytes loader does not support {quant_method} quantization"
            )

    # The quant_states in pre_quantized models cannot work with a split
    # weight tensor. So TP does not work with pre_quantized bnb models.
    if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
        raise ValueError(
            "Prequant BitsAndBytes models with tensor parallelism is not "
            "supported. Please try with pipeline parallelism."
        )
    if quant_config and self.pre_quant:
        self.load_8bit = quant_config.get("load_in_8bit", False)
|
||||
|
||||
def _initialize_loader_state(
    self, model: nn.Module, model_config: ModelConfig
) -> None:
    """
    Initialize the loader's internal state based on the model and
    configuration.

    Raises:
        AttributeError: If a MoE model does not expose an expert mapping.
    """
    self.is_pool_model = is_pooling_model(model)
    self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))

    if is_moe_model(model):
        self.expert_params_mapping = get_moe_expert_mapping(model)
        if not self.expert_params_mapping:
            raise AttributeError(
                f"MoE Model {type(model).__name__} does not support "
                "BitsAndBytes quantization yet. Ensure this model has "
                "'get_expert_mapping' method."
            )
    # For some models like Molmo, we need to use hf_to_vllm_mapper
    # to ensure correct loading of weights.
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)

    # Discover quantizable modules, then classify their sharding needs.
    self._get_bnb_target_modules(model)
    self._classify_module_sharding(model)
|
||||
|
||||
def _dequantize_dq(self, quant_states: Any):
    """
    When BNB employs Double Quantization, we perform the dequantization of
    these constants during weight loading rather than at inference time,
    thereby avoiding this computational overhead during inference. This
    comes at the cost of increased memory usage.

    Accepts either a single quant state or a dict of them; non-nested
    states are left untouched. Returns the (mutated) input.
    """
    from bitsandbytes.functional import QuantState, dequantize_blockwise

    def _dequantize_single_state(quant_state):
        """Helper function to dequantize a single QuantState object."""
        # Only nested (double-quantized) states need unpacking.
        if not (isinstance(quant_state, QuantState) and quant_state.nested):
            return

        # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
        absmax += quant_state.offset

        # Ensure float32 dtype
        if absmax.dtype != torch.float32:
            absmax = absmax.float()

        # Replace the nested representation with the flat absmax and drop
        # the now-unused second-level state.
        quant_state.absmax = absmax
        quant_state.nested = False
        quant_state.offset = None
        quant_state.state2 = None

    if isinstance(quant_states, dict):
        for quant_state in quant_states.values():
            _dequantize_single_state(quant_state)
    else:
        _dequantize_single_state(quant_states)
    return quant_states
|
||||
|
||||
def _fuse_moe_quant_states(self, model: nn.Module, quant_states_dict: dict) -> dict:
    """Fuse per-expert BNB quant states into per-layer fused states.

    This function consolidates individual expert quantization states into
    fused representations for w13 and w2.

    For every ``FusedMoE`` module found in ``model``, the per-expert
    entries in ``quant_states_dict`` are consumed (deleted) and replaced
    by two fused ``QuantState`` objects returned in a new dict keyed by
    ``<module_name>.w13_weight`` and ``<module_name>.w2_weight``.
    Returns an empty dict for non-MoE models.
    """
    from bitsandbytes.functional import QuantState

    # No expert mapping means no MoE layers to fuse.
    if not self.expert_params_mapping:
        return dict()

    expert_mapping = self.expert_params_mapping
    expert_qs_dict = {}
    for name, module in model.named_modules():
        if not isinstance(module, FusedMoE):
            continue
        # Bucket each expert's quant state by its shard id.
        w1_states_lst = []
        w2_states_lst = []
        w3_states_lst = []
        for exp in expert_mapping:
            # exp[-1] is the shard id; exp[1] is the per-expert checkpoint
            # weight name fragment (presumably ends with "." — the code
            # appends "weight" directly; TODO confirm against
            # get_expert_mapping's format).
            shard_id = exp[-1]
            if shard_id not in ("w1", "w2", "w3"):
                raise ValueError(
                    f"shard_id must be ['w1','w2','w3'] but got {shard_id}."
                )
            layer_prefix = name.split("experts")[0]
            weight_qual_name = layer_prefix + exp[1] + "weight"
            # Undo double quantization now so absmax tensors can simply be
            # concatenated below.
            quant_state = self._dequantize_dq(quant_states_dict[weight_qual_name])
            if shard_id == "w1":
                w1_states_lst.append(quant_state)
            elif shard_id == "w2":
                w2_states_lst.append(quant_state)
            else:
                w3_states_lst.append(quant_state)
            # The per-expert entry is superseded by the fused state; remove
            # it so later stacking passes don't see it again.
            del quant_states_dict[weight_qual_name]
        assert len(w1_states_lst) == len(w2_states_lst) == len(w3_states_lst)
        w13_absmax_lst = []
        w2_absmax_lst = []
        w13_total_dim0 = 0
        w2_total_dim0 = 0
        for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst, w3_states_lst):
            # Fusing is only valid if all experts share quantization params.
            assert w1_qs.shape == w3_qs.shape
            assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize
            assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype
            # w1 and w3 are interleaved in storage
            w13_absmax_lst.append(w1_qs.absmax)
            w13_absmax_lst.append(w3_qs.absmax)
            w2_absmax_lst.append(w2_qs.absmax)
            w13_total_dim0 += w1_qs.shape[0] + w3_qs.shape[0]
            w2_total_dim0 += w2_qs.shape[0]

        w13_absmax = torch.cat(w13_absmax_lst)
        w2_absmax = torch.cat(w2_absmax_lst)
        # Create fused quantization state for w13.
        # NOTE(review): quant_type is hard-coded to "nf4" — confirm the
        # MoE path is 4-bit only.
        w13_qs = QuantState(
            absmax=w13_absmax,
            shape=(w13_total_dim0, w1_states_lst[0].shape[1]),
            code=w1_states_lst[0].code,
            blocksize=w1_states_lst[0].blocksize,
            quant_type="nf4",
            dtype=w1_states_lst[0].dtype,
        )
        # Create fused quantization state for w2.
        w2_qs = QuantState(
            absmax=w2_absmax,
            shape=(w2_total_dim0, w2_states_lst[0].shape[1]),
            code=w2_states_lst[0].code,
            blocksize=w2_states_lst[0].blocksize,
            quant_type="nf4",
            dtype=w2_states_lst[0].dtype,
        )
        # The weight suffixes .w13_weight and .w2_weight are consistent
        # with the param in BitsAndBytesMoEMethod.
        w13_weight_name = name + ".w13_weight"
        w2_weight_name = name + ".w2_weight"
        expert_qs_dict[w13_weight_name] = w13_qs
        expert_qs_dict[w2_weight_name] = w2_qs
    return expert_qs_dict
|
||||
|
||||
def _stack_quantization_states(
    self, model: nn.Module, quant_state_dict: dict
) -> dict[str, dict[int, Any]]:
    """Regroup quant states under the (possibly packed) vLLM param names.

    Checkpoint-level names (e.g. ``...q_proj.weight``) are renamed to
    the packed vLLM parameter (e.g. ``...qkv_proj.weight``) using
    ``self.modules_mapping.inverse_packed_mapping``, and each quant
    state is stored under its shard index, yielding
    ``{vllm_param_name: {shard_index: quant_state}}``.
    """
    stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
    # TODO: Change this lazy import to normal import
    # after the checks are updated to run on a new version
    from vllm.model_executor.models.utils import is_pp_missing_parameter

    param_dict = dict(model.named_parameters())
    for quant_param_name in quant_state_dict:
        # Skip parameters that live on other pipeline-parallel ranks.
        if is_pp_missing_parameter(quant_param_name, model):
            continue

        # Keep the original (checkpoint) name as the lookup key; the
        # loop below may rewrite quant_param_name to the packed name.
        non_stacked_param_name = quant_param_name

        # Default shard index for unpacked parameters.
        shard_index = 0
        for shard_name, (
            weight_name,
            index,
        ) in self.modules_mapping.inverse_packed_mapping.items():
            # Some models, such as MiniCPM V2.5/2.6, contain both
            # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
            # from being incorrectly identified as being present in
            # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
            shard_pos = quant_param_name.find(shard_name)
            can_correct_rename = (shard_pos > 0) and (
                quant_param_name[shard_pos - 1] == "."
            )
            # If the quant_param_name is packed, it won't occur in the
            # param_dict before renaming.
            new_quant_param_name = quant_param_name.replace(shard_name, weight_name)
            need_rename = (quant_param_name not in param_dict) and (
                new_quant_param_name in param_dict
            )
            if can_correct_rename and need_rename:
                shard_index = index
                quant_param_name = new_quant_param_name
                break

        # Models like Clip/Siglip may skip some layers in initialization,
        # causing unused quant_param_name in state_dict.
        if quant_param_name not in param_dict:
            continue

        if quant_param_name not in stacked_quant_state_dict:
            stacked_quant_state_dict[quant_param_name] = {}

        stacked_quant_state_dict[quant_param_name][shard_index] = quant_state_dict[
            non_stacked_param_name
        ]
    return stacked_quant_state_dict
|
||||
|
||||
def _bind_quant_states_to_params(
    self, model: nn.Module, stacked_quant_state_dict: dict
) -> None:
    """Attach BNB quant states (and shard offsets) to model parameters.

    For every parameter with an entry in ``stacked_quant_state_dict``,
    stores the quant state as ``param.bnb_quant_state``. For stacked
    (shard-dict) entries, also computes each shard's element count in
    the packed buffer and stores cumulative offsets as
    ``param.bnb_shard_offsets``.
    """
    # save quant_states and offsets as the attributes of the parameters
    param_dict = dict(model.named_parameters())
    for param_name, param in param_dict.items():
        if param_name in stacked_quant_state_dict:
            quant_states = stacked_quant_state_dict[param_name]
            # Dequantize double quantized values during weight loading.
            self._dequantize_dq(quant_states)
            set_weight_attrs(param, {"bnb_quant_state": quant_states})
            # Fused-MoE entries hold a single quant state rather than a
            # shard dict; the offset bookkeeping below does not apply.
            if not isinstance(quant_states, dict):
                continue

            pack_ratio = getattr(param, "pack_factor", -1)
            if pack_ratio == -1:
                raise ValueError(f"pack_factor not set for parameter {param_name}.")

            # Per-shard element counts in the packed (quantized) buffer.
            num_elements = [0] * len(quant_states)
            for seq, quant_state in quant_states.items():
                num_elements[seq] = math.prod(quant_state.shape) // pack_ratio

            offsets = np.concatenate(([0], np.cumsum(num_elements)))
            # Make torch infer_schema happy
            offsets = torch.tensor(offsets).cpu()
            set_weight_attrs(param, {"bnb_shard_offsets": offsets})

            if self.load_8bit:
                # Per-shard matmul state cache used by the 8-bit kernels.
                set_weight_attrs(
                    param, {"matmul_state": [None] * len(quant_states)}
                )
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
    """Load checkpoint weights into ``model`` with BNB quantization.

    Pipeline: verify compatibility, initialize loader state, stream the
    (quantized) weights into the model, then fuse MoE expert quant
    states, regroup the rest under packed parameter names, and bind all
    quant states to their parameters.
    """
    self._verify_model_compatibility(model, model_config)
    self._initialize_loader_state(model, model_config)

    logger.info(
        "Loading weights with BitsAndBytes quantization. May take a while ..."
    )
    qweight_iterator, quant_state_dict = self._get_quantized_weights_iterator(
        model_config.model,
        model_config.revision,
    )
    weights_to_load = {name for name, _ in model.named_parameters()}
    loaded_weights = model.load_weights(qweight_iterator)
    # Some models may have weights loading tracker unimplemented.
    if loaded_weights is not None:
        weights_not_loaded = weights_to_load - loaded_weights
        if weights_not_loaded:
            raise ValueError(
                "Following weights were not initialized from "
                f"checkpoint: {weights_not_loaded}"
            )
    # Fuse MoE expert states first: this removes the consumed per-expert
    # entries from quant_state_dict before the stacking pass below.
    expert_quant_state_dict = self._fuse_moe_quant_states(model, quant_state_dict)

    stacked_quant_state_dict = self._stack_quantization_states(
        model, quant_state_dict
    )

    # Merge both groups into one mapping of param name -> quant state(s).
    stacked_quant_state_dict = {
        **expert_quant_state_dict,
        **stacked_quant_state_dict,
    }
    self._bind_quant_states_to_params(model, stacked_quant_state_dict)
    # Release temporary GPU buffers created while streaming weights.
    torch.cuda.empty_cache()
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
    """Fetch the checkpoint files without loading any weights."""
    model_ref = model_config.model
    revision = model_config.revision
    self._prepare_weights(model_ref, revision)
|
||||
329
model_executor/model_loader/default_loader.py
Normal file
329
model_executor/model_loader/default_loader.py
Normal file
@@ -0,0 +1,329 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import glob
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf,
|
||||
download_weights_from_hf,
|
||||
fastsafetensors_weights_iterator,
|
||||
filter_duplicate_safetensors_files,
|
||||
filter_files_not_needed_for_inference,
|
||||
maybe_download_from_modelscope,
|
||||
multi_thread_pt_weights_iterator,
|
||||
multi_thread_safetensors_weights_iterator,
|
||||
np_cache_weights_iterator,
|
||||
pt_weights_iterator,
|
||||
safetensors_weights_iterator,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    # default number of thread when enable multithread weight loading
    DEFAULT_NUM_THREADS = 8

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: str | None
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: list[str] | None = None
        """If defined, weights will load exclusively using these patterns."""

    # Wall-clock markers used to report the total weight-loading time.
    counter_before_loading_weights: float = 0.0
    counter_after_loading_weights: float = 0.0

    def __init__(self, load_config: LoadConfig):
        """Validate that only multithread-load keys are present in
        ``model_loader_extra_config``.

        Raises:
            ValueError: if any unexpected extra-config key is given.
        """
        super().__init__(load_config)

        extra_config = load_config.model_loader_extra_config
        allowed_keys = {"enable_multithread_load", "num_threads"}
        unexpected_keys = set(extra_config.keys()) - allowed_keys

        if unexpected_keys:
            raise ValueError(
                f"Unexpected extra config keys for load format "
                f"{load_config.load_format}: "
                f"{unexpected_keys}"
            )

    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: str | None,
        fall_back_to_pt: bool,
        allow_patterns_overrides: list[str] | None,
    ) -> tuple[str, list[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Returns a tuple of (local weights folder, list of weight files,
        whether the files are safetensors).

        Raises:
            ValueError: for an unrecognized load format.
            RuntimeError: if no weight files match any pattern.
        """
        model_name_or_path = (
            maybe_download_from_modelscope(model_name_or_path, revision)
            or model_name_or_path
        )

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        # Some quantized models use .pt files for storing the weights.
        if load_format == "auto":
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == "safetensors" or load_format == "fastsafetensors":
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == "mistral":
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == "pt":
            allow_patterns = ["*.pt"]
        elif load_format == "npcache":
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        hf_weights_files: list[str] = []
        # Patterns are tried in priority order; the first pattern that
        # matches any file wins.
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file
            )
        else:
            hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`"
            )

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
        self, source: "Source"
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        extra_config = self.load_config.model_loader_extra_config
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path,
            source.revision,
            source.fall_back_to_pt,
            source.allow_patterns_overrides,
        )
        if self.load_config.load_format == "npcache":
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )
        elif use_safetensors:
            if self.load_config.load_format == "fastsafetensors":
                weights_iterator = fastsafetensors_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                )
            else:
                if extra_config.get("enable_multithread_load"):
                    weights_iterator = multi_thread_safetensors_weights_iterator(
                        hf_weights_files,
                        self.load_config.use_tqdm_on_load,
                        max_workers=extra_config.get(
                            "num_threads", self.DEFAULT_NUM_THREADS
                        ),
                    )
                else:
                    weights_iterator = safetensors_weights_iterator(
                        hf_weights_files,
                        self.load_config.use_tqdm_on_load,
                        self.load_config.safetensors_load_strategy,
                    )
        else:
            if extra_config.get("enable_multithread_load"):
                weights_iterator = multi_thread_pt_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.pt_load_map_location,
                    max_workers=extra_config.get(
                        "num_threads", self.DEFAULT_NUM_THREADS
                    ),
                )
            else:
                weights_iterator = pt_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.pt_load_map_location,
                )

        if current_platform.is_tpu():
            from vllm.platforms.tpu import USE_TPU_INFERENCE

            if not USE_TPU_INFERENCE:
                # In PyTorch XLA, we should call `torch_xla.sync`
                # frequently so that not too many ops are accumulated
                # in the XLA program.
                import torch_xla

                def _xla_weights_iterator(iterator: Generator):
                    for weights in iterator:
                        yield weights
                        torch_xla.sync(wait=False)

                weights_iterator = _xla_weights_iterator(weights_iterator)

        # Only record the start time once, even if this method is called
        # for multiple weight sources.
        if self.counter_before_loading_weights == 0.0:
            self.counter_before_loading_weights = time.perf_counter()
        # Apply the prefix.
        return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)

    def get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Yield weights from the primary checkpoint followed by any
        model-declared secondary weight sources."""
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        """Download checkpoint files only; no weights are loaded."""
        self._prepare_weights(
            model_config.model,
            model_config.revision,
            fall_back_to_pt=True,
            allow_patterns_overrides=None,
        )

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Stream checkpoint weights into ``model``.

        Handles three paths: unquantized models, offline-quantized (or
        first-run online-quantized) models, and subsequent reloads under
        online quantization (RL-style weight reloading).
        """
        if model_config.quantization == "torchao" and torchao_version_at_least(
            "0.14.0"
        ):
            self.load_config.safetensors_load_strategy = "torchao"
        weights_to_load = {name for name, _ in model.named_parameters()}

        # if we don't have `model.weight_metadata_and_attr_saved` defined and
        # set to True, it means that this is either offline quantization case
        # or the first run of online quantization
        # see online_quantization.py for detailed notes
        offline_quantization_or_first_run_of_online_quantization = not getattr(
            model, "weight_metadata_and_attr_saved", False
        )

        if model_config.quantization is None:
            # model is not quantized
            loaded_weights = model.load_weights(
                self.get_all_weights(model_config, model)
            )
        elif offline_quantization_or_first_run_of_online_quantization:
            # case 1: offline quantized checkpoint
            # case 2: Step I1 first run of weight loading with
            # online quantization
            # see online_quantization.py for detailed notes
            loaded_weights = model.load_weights(
                self.get_all_weights(model_config, model)
            )
        else:
            # to avoid circular dependency
            from vllm.model_executor.model_loader.online_quantization import (
                load_weights_and_online_quantize,
            )

            # subsequent runs of weight loading with online
            # quantization
            loaded_weights = load_weights_and_online_quantize(self, model, model_config)

        self.counter_after_loading_weights = time.perf_counter()
        logger.info_once(
            "Loading weights took %.2f seconds",
            self.counter_after_loading_weights - self.counter_before_loading_weights,
            scope="local",
        )
        # We only enable strict check for non-quantized models
        # that have loaded weights tracking currently.
        opt_flag = envs.VLLM_MOE_OPT_LEVEL != 0 or envs.VLLM_LINEAR_OPT_LEVEL != 0
        if model_config.quantization is None and loaded_weights is not None and not opt_flag:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError(
                    "Following weights were not initialized from "
                    f"checkpoint: {weights_not_loaded}"
                )
|
||||
28
model_executor/model_loader/dummy_loader.py
Normal file
28
model_executor/model_loader/dummy_loader.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
|
||||
|
||||
|
||||
class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # This loader accepts no extra configuration.
        extra = load_config.model_loader_extra_config
        if extra:
            raise ValueError(
                f"Model loader extra config is not supported for "
                f"load format {load_config.load_format}"
            )

    def download_model(self, model_config: ModelConfig) -> None:
        """No-op: there is nothing to download for dummy weights."""

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        # NOTE(woosuk): For accurate performance evaluation, we assign
        # random values to the weights.
        initialize_dummy_weights(model)
|
||||
176
model_executor/model_loader/gguf_loader.py
Normal file
176
model_executor/model_loader/gguf_loader.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model,
|
||||
process_weights_after_loading,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
get_gguf_extra_tensor_names,
|
||||
get_gguf_weight_type_map,
|
||||
gguf_quant_weights_iterator,
|
||||
)
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
|
||||
class GGUFModelLoader(BaseModelLoader):
    """
    Model loader that can load GGUF files. This is useful for loading models
    that are quantized with GGUF and saved in the GGUF format. This loader
    supports loading both full models and sharded models.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # GGUF loading takes no extra configuration.
        if load_config.model_loader_extra_config:
            raise ValueError(
                f"Model loader extra config is not supported for "
                f"load format {load_config.load_format}"
            )

    def _prepare_weights(self, model_name_or_path: str):
        """Resolve ``model_name_or_path`` to a local GGUF file path,
        downloading from the HF Hub when given a repo reference.

        Raises:
            ValueError: when the reference is not a local file, raw URL,
                or ``<repo_id>/<filename>.gguf`` string.
        """
        if os.path.isfile(model_name_or_path):
            return model_name_or_path
        # for raw HTTPS link
        if model_name_or_path.startswith(
            ("http://", "https://")
        ) and model_name_or_path.endswith(".gguf"):
            # NOTE(review): `hf_hub_download` takes repo_id/filename and
            # has no `url` keyword — this call looks like it would raise
            # TypeError at runtime; confirm against the installed
            # huggingface_hub API.
            return hf_hub_download(url=model_name_or_path)
        # repo id/filename.gguf
        if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
            repo_id, filename = model_name_or_path.rsplit("/", 1)
            return hf_hub_download(repo_id=repo_id, filename=filename)
        else:
            raise ValueError(
                f"Unrecognised GGUF reference: {model_name_or_path} "
                "(expected local file, raw URL, or <repo_id>/<filename>.gguf)"
            )

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        GGUF uses this naming convention for their tensors from HF checkpoint:
        `blk.N.BB.weight` and `blk.N.BB.bias`
        where N signifies the block number of a layer, and BB signifies the
        attention/mlp layer components.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
        """
        config = model_config.hf_config
        model_type = config.model_type
        gguf_to_hf_name_map = {}
        # hack: ggufs have a different name than transformers
        if model_type == "cohere":
            model_type = "command-r"
        if model_type == "gemma3_text":
            # Gemma3 models use "gemma3_text" in HuggingFace but
            # "gemma3" in GGUF architecture naming
            model_type = "gemma3"
        if model_type in ("deepseek_v3", "deepseek_v2"):
            model_type = "deepseek2"
            # GGUF layer map assumes that we will have a merged expert weights
            # so we need to map them manually
            for idx in range(config.num_hidden_layers):
                gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
                    f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
                )
        if model_type in ("qwen2_moe", "qwen3_moe"):
            model_type = model_type.replace("_", "")
            # GGUF layer map assumes that we will have a merged expert weights
            # so we need to map them manually
            for idx in range(config.num_hidden_layers):
                gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
                )
                gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                    f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
                )

        # Reverse-lookup the gguf architecture enum from its name.
        arch = None
        for key, value in gguf.MODEL_ARCH_NAMES.items():
            if value == model_type:
                arch = key
                break
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {model_type}")
        num_layers = config.num_hidden_layers
        name_map = gguf.get_tensor_name_map(arch, num_layers)
        # Build the HF model on the meta device purely to enumerate its
        # state-dict names; no weights are materialized.
        with torch.device("meta"):
            dummy_model = AutoModelForCausalLM.from_config(
                config, trust_remote_code=model_config.trust_remote_code
            )
        state_dict = dummy_model.state_dict()

        for hf_name in state_dict:
            name, suffix = hf_name.rsplit(".", 1)
            gguf_name = name_map.get_name(name)
            gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
        return gguf_to_hf_name_map

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str]
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Iterate (hf_name, tensor) pairs from the GGUF file."""
        return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map)

    def download_model(self, model_config: ModelConfig) -> None:
        """Resolve/download the GGUF file without loading weights."""
        self._prepare_weights(model_config.model)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load GGUF weights into an already-initialized ``model``."""
        local_model_path = self._prepare_weights(model_config.model)
        gguf_weights_map = self._get_gguf_weights_map(model_config)
        model.load_weights(
            self._get_weights_iterator(local_model_path, gguf_weights_map)
        )

    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig
    ) -> nn.Module:
        """Initialize the model and load GGUF weights into it."""
        device_config = vllm_config.device_config
        local_model_path = self._prepare_weights(model_config.model)
        gguf_weights_map = self._get_gguf_weights_map(model_config)
        # we can only know if tie word embeddings after mapping weights
        if "lm_head.weight" in get_gguf_extra_tensor_names(
            local_model_path, gguf_weights_map
        ):
            model_config.hf_config.update({"tie_word_embeddings": True})

        weight_type_map = get_gguf_weight_type_map(model_config.model, gguf_weights_map)

        # filter out unquantized modules to skip
        unquant_names = [
            name.removesuffix(".weight")
            for name, weight_type in weight_type_map.items()
            if weight_type == "F32" and name.endswith(".weight")
        ]
        vllm_config.quant_config.unquantized_modules.extend(unquant_names)

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config)
            self.load_weights(model, model_config)

            process_weights_after_loading(model, model_config, target_device)
        return model
|
||||
224
model_executor/model_loader/online_quantization.py
Normal file
224
model_executor/model_loader/online_quantization.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import types
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.utils import process_weights_after_loading
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Notes for Online Quantization
|
||||
# In terms of state of checkpoints, quantization config and their
|
||||
# correspondence to online quantization:
|
||||
# | Use Case | Checkpoints | model_config.quantization |
|
||||
# | no quant | high precision | None |
|
||||
# | offline quant | quantized | fp8, torchao etc. |
|
||||
# | online quant | high precision | torchao etc. |
|
||||
#
|
||||
# The process for loading non-quantized checkpoint
|
||||
# 1. load non-quantized weights (load_weights)
|
||||
# 2. do any additional post processing (process_weights_after_loading)
|
||||
#
|
||||
# The process for loading offline quantized checkpoint
|
||||
# 1. load offline-quantized weights (load_weights)
|
||||
# 2. do any additional post processing (process_weights_after_loading)
|
||||
|
||||
# The process for unquantized model reloading
|
||||
# (repeated run in RL training loop)
|
||||
# first run
|
||||
# UI1. load_weights: load bfloat16 weights
|
||||
# UI2. process_weights_after_loading: any additional post processing
|
||||
# subsequent run
|
||||
# UC1: load_weights: load bfloat16 weights
|
||||
# (shouldn't be any issues since we didn't change any attributes
|
||||
# of the weights)
|
||||
# UC2: process_weights_after_loading: any additional post processing
|
||||
|
||||
# The process for weight reloading with online quantization
|
||||
# (repeated run in RL training loop)
|
||||
# first run
|
||||
# I1. load_weights: load bfloat16 weights
|
||||
# I2. process_weights_after_loading:
|
||||
# record weight metadata and attributes for R1 and R2
|
||||
# quantize weights to fp8
|
||||
# subsequent run
|
||||
# (beginning model weight is in fp8)
|
||||
# load_weights:
|
||||
# R1. restore bfloat16 model weight metadata
|
||||
# R2. restore the model weight attributes
|
||||
# R3. reload bfloat16 weights
|
||||
# R4. quantize weights (by calling process_weights_after_loading),
|
||||
# also set `process_weights_after_loading_already_called` to
|
||||
# True to stop it from running again
|
||||
# process_weights_after_loading (if called):
|
||||
# this will be skipped since it's already ran in
|
||||
# load_weights
|
||||
|
||||
|
||||
def maybe_save_metadata_and_attributes_for_weight_reloading(
    model: nn.Module, model_config: ModelConfig
):
    """Record per-parameter metadata and attributes needed to reload weights.

    This is step I2 of the online-quantization flow described in the Notes
    above: before the high-precision weights are quantized in place, remember
    each parameter's (shape, dtype, device) and its extra attributes so that
    steps R1/R2 can rebuild the high-precision parameters on a later reload.
    No-op unless `model_config.quantization == "torchao"`.
    """
    # following is to support on the fly quantization, currently only supported
    # for torchao
    if model_config.quantization != "torchao":
        return

    if getattr(model, "process_weights_after_loading_already_called", False):
        # In case `process_weights_after_loading` is called multiple times
        # we'll skip it at later times
        logger.warning(
            "process_weights_after_loading already called for model %s", model
        )
        return

    # imported here (not at module top) — presumably to avoid a circular
    # import at module load time; confirm before hoisting
    from vllm.model_executor.model_loader.weight_utils import get_quant_config

    quant_config = get_quant_config(model_config, None)

    # If checkpoint is already torchao serialized, this means it's
    # pre-quantized quantization case, we'll skip saving the metadata
    # Otherwise, this is Step I2 of initialization steps of
    # online quantization
    # This step records the weights metadata and weight attributes so we can
    # restore the bfloat16 model weights during the reload step (R1 and R2)
    # see Notes in online_quantization.py for more details
    if not (
        hasattr(quant_config, "is_checkpoint_torchao_serialized")
        and not quant_config.is_checkpoint_torchao_serialized
    ):
        return

    # This is the I2 step of online quantization that saves
    # metadata and attributes of weights so they can be used in R1 and
    # R2 step, note that we only save these during initialization

    # Includes two things
    # 1. save floating point metadata (shape, dtype, device) for init
    # 2. save weight attributes, e.g. `output_dim`, `weight_loader` for init

    if getattr(model, "weight_metadata_and_attr_saved", False):
        return

    # save the dtype, shape and device for model parameter, used for
    # restoring the model high precision parameters before
    # reloading the weights
    assert not hasattr(model, "original_weights_rebuild_keys")
    model.original_weights_rebuild_keys = {}
    for name, p in model.named_parameters():
        model.original_weights_rebuild_keys[name] = {
            "shape": p.shape,
            "dtype": p.dtype,
            "device": p.device,
        }

    # record the weight attributes (loader functions etc.)
    # so these can be recovered later when we reload the weights
    # structure: {"weight_name": {"weight_attr_key": attr}}
    assert not hasattr(model, "recorded_weight_attr")
    model.recorded_weight_attr = {}
    for name, param in model.named_parameters():
        model.recorded_weight_attr[name] = {}
        # param.__dict__ holds the extra attributes vLLM attaches to
        # parameters (e.g. `output_dim`, `weight_loader`); the hasattr
        # check below is defensive — keys in __dict__ normally resolve
        for key in param.__dict__:
            if hasattr(param, key):
                attr = getattr(param, key)
                if not callable(attr):
                    model.recorded_weight_attr[name][key] = attr
                elif hasattr(attr, "__self__") and param is attr.__self__:
                    # if attr is a bonded method for an instance, and
                    # attr.__self__ points to the instance (param)
                    # we'll record the underlying function object
                    # (presumably so the record does not pin this exact
                    # parameter object — confirm)
                    model.recorded_weight_attr[name][key] = attr.__func__
                else:
                    model.recorded_weight_attr[name][key] = attr
    # mark the metadata and attributes saved so we don't run it again
    model.weight_metadata_and_attr_saved = True
|
||||
|
||||
|
||||
def _bond_method_to_cls(func, obj):
|
||||
if hasattr(func, "__self__") or not callable(func):
|
||||
# If the function is already bound to an instance, return it as is
|
||||
return func
|
||||
else:
|
||||
return types.MethodType(func, obj)
|
||||
|
||||
|
||||
def load_weights_and_online_quantize(
    model_loader: DefaultModelLoader, model: nn.Module, model_config: ModelConfig
) -> set[str]:
    """Reload high-precision weights into an already-quantized model, then
    re-quantize them — steps R1-R4 of the Notes above.

    Returns the set of loaded weight names reported by ``model.load_weights``.
    Requires that `maybe_save_metadata_and_attributes_for_weight_reloading`
    already populated ``model.original_weights_rebuild_keys`` and
    ``model.recorded_weight_attr`` on the first run.
    """
    # online quantization, right now only enabled for
    # torchao
    # R1, R2, R3, R4 in the Notes

    # TODO: Add fp8 support
    assert model_config.quantization == "torchao", (
        "online quantization is only enabled for torchao currently"
    )
    # TODO: use create_weights to restore the weights to original state

    # Step R1: First restore the quantized weights to original bfloat16
    # weights, with original metadata (shape, dtype, device)
    # and attributes, so that bfloat16 weights can be loaded properly
    existing_param_names = dict(model.named_parameters(remove_duplicate=False)).keys()
    named_modules = dict(model.named_modules(remove_duplicate=False))
    model_device = None

    # Step R1 (cont.): recover each parameter to the state before first
    # loading, using the metadata saved in step I2
    for name, d in model.original_weights_rebuild_keys.items():
        _shape = d["shape"]
        _dtype = d["dtype"]
        _device = d["device"]
        # All recorded weights must live on one device; remember it so
        # process_weights_after_loading below targets the same device.
        if model_device is not None:
            assert model_device == _device, (
                "Expecting all weights "
                "to be in the same device for now, got both: "
                f"{model_device} and {_device}"
            )
        else:
            model_device = _device

        if name in existing_param_names:
            module_name, weight_name = name.rsplit(".", 1)
            module = named_modules[module_name]
            # Replace the quantized parameter with an empty placeholder
            # carrying the original shape/dtype/device.
            # NOTE(review): relies on a module-level `import torch`, which is
            # not visible in this file's import header — confirm it exists.
            setattr(
                module,
                weight_name,
                torch.nn.Parameter(torch.empty(_shape, dtype=_dtype, device=_device)),
            )

    # Step R2: restore the recorded weight attributes.
    # recorded_weight_attr is
    # {"weight_name": {"weight_attr_key": attr}}
    # e.g.
    # {
    #     "layer.0.weight": {
    #         "weight_loader": weight_loader_function_object,
    #         "input_dim": 0, ...
    #     },
    #     "layer.1.weight": ...,
    # }
    for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
        for attr_name, attr in weight_attr_dict.items():
            module_name, weight_name = full_weight_name.rsplit(".", 1)
            module = named_modules[module_name]
            weight = getattr(module, weight_name)
            if not hasattr(weight, attr_name):
                # Plain functions recorded in step I2 are re-bound to the
                # freshly created parameter object here.
                setattr(weight, attr_name, _bond_method_to_cls(attr, weight))

    # Step R3: reload bfloat16 / high precision weights
    loaded_weights = model.load_weights(
        model_loader.get_all_weights(model_config, model)
    )

    # Step R4: online quantize the weights
    # manually process weights after loading
    model.process_weights_after_loading_already_called = False
    process_weights_after_loading(model, model_config, model_device)
    # Set the flag so a later external call to process_weights_after_loading
    # is skipped (it already ran here).
    model.process_weights_after_loading_already_called = True
    return loaded_weights
|
||||
116
model_executor/model_loader/runai_streamer_loader.py
Normal file
116
model_executor/model_loader/runai_streamer_loader.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf,
|
||||
download_weights_from_hf,
|
||||
runai_safetensors_weights_iterator,
|
||||
)
|
||||
from vllm.transformers_utils.runai_utils import is_runai_obj_uri, list_safetensors
|
||||
|
||||
|
||||
class RunaiModelStreamerLoader(BaseModelLoader):
    """
    Model loader that can load safetensors
    files from local FS or S3 bucket.
    """

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        # Whether to use the streamer's distributed read path; enabled via
        # model_loader_extra_config={"distributed": True}.
        self._is_distributed = False
        if load_config.model_loader_extra_config:
            extra_config = load_config.model_loader_extra_config

            if "distributed" in extra_config and isinstance(
                extra_config.get("distributed"), bool
            ):
                self._is_distributed = extra_config.get("distributed")

            # The streamer reads its tunables from environment variables,
            # so extra_config values are forwarded as env vars below.
            if "concurrency" in extra_config and isinstance(
                extra_config.get("concurrency"), int
            ):
                os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
                    extra_config.get("concurrency")
                )

            if "memory_limit" in extra_config and isinstance(
                extra_config.get("memory_limit"), int
            ):
                os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
                    extra_config.get("memory_limit")
                )

        # Fall back to the generic AWS endpoint when a streamer-specific S3
        # endpoint was not set explicitly.
        runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
        aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
        if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url

    def _prepare_weights(
        self, model_name_or_path: str, revision: str | None
    ) -> list[str]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Returns the list of safetensors file paths found.
        Raises RuntimeError when no safetensors files are found.
        """

        is_object_storage_path = is_runai_obj_uri(model_name_or_path)
        is_local = os.path.isdir(model_name_or_path)
        safetensors_pattern = "*.safetensors"
        index_file = SAFE_WEIGHTS_INDEX_NAME

        # Local directories and object-storage URIs are used as-is; anything
        # else is treated as a HF repo id and downloaded first.
        hf_folder = (
            model_name_or_path
            if (is_local or is_object_storage_path)
            else download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                [safetensors_pattern],
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        )
        hf_weights_files = list_safetensors(path=hf_folder)

        if not is_local and not is_object_storage_path:
            download_safetensors_index_file_from_hf(
                model_name_or_path, index_file, self.load_config.download_dir, revision
            )

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any safetensors model weights with `{model_name_or_path}`"
            )

        return hf_weights_files

    def _get_weights_iterator(
        self, model_or_path: str, revision: str | None
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        hf_weights_files = self._prepare_weights(model_or_path, revision)
        return runai_safetensors_weights_iterator(
            hf_weights_files, self.load_config.use_tqdm_on_load, self._is_distributed
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """Download model if necessary"""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load weights into a model."""
        # model_weights may point somewhere other than model_config.model
        # (a separate weights location); prefer it when present.
        model_weights = model_config.model
        if hasattr(model_config, "model_weights"):
            model_weights = model_config.model_weights
        model.load_weights(
            self._get_weights_iterator(model_weights, model_config.revision)
        )
|
||||
206
model_executor/model_loader/sharded_state_loader.py
Normal file
206
model_executor/model_loader/sharded_state_loader.py
Normal file
@@ -0,0 +1,206 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import collections
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_weights_from_hf,
|
||||
runai_safetensors_weights_iterator,
|
||||
)
|
||||
from vllm.transformers_utils.s3_utils import glob as s3_glob
|
||||
from vllm.transformers_utils.utils import is_s3
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/offline_inference/save_sharded_state.py` for creating a sharded
    checkpoint.
    """

    # Shard filename template; {rank} is the TP rank, {part} the part index.
    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)

        extra_config = (
            {}
            if load_config.model_loader_extra_config is None
            else load_config.model_loader_extra_config.copy()
        )
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            # NOTE(review): the message prints the original (uncopied) keys,
            # so a valid "pattern" key also appears in the error text.
            raise ValueError(
                f"Unexpected extra config keys for load format "
                f"{load_config.load_format}: "
                f"{load_config.model_loader_extra_config.keys()}"
            )

    @staticmethod
    def _filter_subtensors(
        tensors: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        # Group tensors by (device, storage pointer): only tensors backed by
        # the same storage can overlap in memory.
        same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = (
            collections.defaultdict(list)
        )
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            # One byte past the tensor's last element.
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    # Skip t2 unless its span [a2, b2) contains t's span.
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    # No other tensor in the group supersedes t: keep it.
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str, revision: str | None):
        # S3 URIs and local directories are used in place; anything else is
        # assumed to be a HF repo id and downloaded first.
        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

    def download_model(self, model_config: ModelConfig) -> None:
        # Ensure the checkpoint is available locally (no-op for S3/local).
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load this rank's shard files into `model`'s state dict in place.

        Raises ValueError when no shard files match this rank's pattern, or
        when some parameters received no data.
        """
        from vllm.distributed import get_tensor_model_parallel_rank

        # model_weights may point somewhere other than model_config.model;
        # prefer it when present.
        model_weights = model_config.model
        if hasattr(model_config, "model_weights"):
            model_weights = model_config.model_weights
        local_model_path = model_weights

        rank = get_tensor_model_parallel_rank()
        pattern = os.path.join(
            local_model_path,
            self.pattern.format(rank=rank, part="*"),
        )

        filepaths = []
        if is_s3(local_model_path):
            file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
            filepaths = s3_glob(path=local_model_path, allow_pattern=[file_pattern])
        else:
            filepaths = glob.glob(pattern)
        if not filepaths:
            # TODO: support un-sharded checkpoints too
            raise ValueError(
                f"Could not find checkpoint files '{pattern}', only "
                f"pre-sharded checkpoints are currently supported!"
            )
        state_dict = self._filter_subtensors(model.state_dict())
        for key, tensor in self.iterate_over_files(filepaths):
            # If loading with LoRA enabled, additional padding may
            # be added to certain parameters. We only load into a
            # narrowed view of the parameter data.
            param_data = state_dict[key].data
            param_shape = state_dict[key].shape
            for dim, size in enumerate(tensor.shape):
                if size < param_shape[dim]:
                    param_data = param_data.narrow(dim, 0, size)
            if tensor.shape != param_shape:
                logger.warning(
                    "loading tensor of shape %s into parameter '%s' of shape %s",
                    tensor.shape,
                    key,
                    param_shape,
                )
            param_data.copy_(tensor)
            state_dict.pop(key)
        if state_dict:
            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")

    def iterate_over_files(
        self, paths: list[str]
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        # Yield (name, tensor) pairs from each shard file, through the
        # streamer iterator or plain safetensors depending on load format.
        if self.load_config.load_format == "runai_streamer_sharded":
            yield from runai_safetensors_weights_iterator(paths, True)
        else:
            from safetensors.torch import safe_open

            for path in paths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        yield key, tensor

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save `model`'s deduplicated state dict as this rank's shard files
        under `path`, splitting into parts of at most `max_size` bytes.
        """
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            # Flush the current part before it would exceed max_size.
            if max_size is not None and total_size + param_size > max_size:
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
|
||||
790
model_executor/model_loader/tensorizer.py
Normal file
790
model_executor/model_loader/tensorizer.py
Normal file
@@ -0,0 +1,790 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import contextvars
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator, MutableMapping
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from torch import nn
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import ModelConfig, ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.argparse_utils import FlexibleArgumentParser
|
||||
from vllm.utils.import_utils import PlaceholderModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
try:
|
||||
from tensorizer import (
|
||||
DecryptionParams,
|
||||
EncryptionParams,
|
||||
TensorDeserializer,
|
||||
TensorSerializer,
|
||||
)
|
||||
from tensorizer.stream_io import open_stream
|
||||
from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
|
||||
|
||||
except ImportError:
|
||||
tensorizer = PlaceholderModule("tensorizer")
|
||||
DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
|
||||
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
|
||||
TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer")
|
||||
TensorSerializer = tensorizer.placeholder_attr("TensorSerializer")
|
||||
open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
|
||||
convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
|
||||
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
|
||||
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
|
||||
|
||||
__all__ = [
|
||||
"EncryptionParams",
|
||||
"DecryptionParams",
|
||||
"TensorDeserializer",
|
||||
"TensorSerializer",
|
||||
"open_stream",
|
||||
"convert_bytes",
|
||||
"get_mem_usage",
|
||||
"no_init_or_tensor",
|
||||
"TensorizerConfig",
|
||||
]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def is_valid_deserialization_uri(uri: str | None) -> bool:
|
||||
if uri:
|
||||
scheme = uri.lower().split("://")[0]
|
||||
return scheme in {"s3", "http", "https"} or os.path.exists(uri)
|
||||
return False
|
||||
|
||||
|
||||
def tensorizer_kwargs_arg(value):
    """argparse ``type=`` hook: parse *value* as JSON and return the dict.

    Raises argparse.ArgumentTypeError when the JSON decodes to anything
    other than an object (dict).
    """
    parsed = json.loads(value)
    if isinstance(parsed, dict):
        return parsed
    raise argparse.ArgumentTypeError(
        f"Not deserializable to dict: {value}. serialization_kwargs and "
        f"deserialization_kwargs must be "
        f"deserializable from a JSON string to a dictionary. "
    )
|
||||
|
||||
|
||||
class MetaTensorMode(TorchDispatchMode):
    """Dispatch mode that redirects bare ``aten::empty`` allocations to the
    "meta" device, so tensors created under it carry no real storage."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        # Only rewrite allocations that did not explicitly request a device;
        # an explicit device= argument is honored as-is.
        if func._schema.name == "aten::empty" and "device" not in kwargs:
            kwargs["device"] = "meta"

        return func(*args, **kwargs)
|
||||
|
||||
|
||||
def meta_tensor_mode(
    loading_code=None,
):
    """Dual-use helper around ``_NoInitOrTensorImpl``.

    With no argument, returns a context manager that suppresses parameter
    initialization and routes allocations to the meta device. With a
    callable, runs it under that context and returns its result.
    Raises TypeError for any other argument.
    """
    if loading_code is None:
        return _NoInitOrTensorImpl.context_manager()
    if not callable(loading_code):
        raise TypeError(
            "expected a callable to evaluate,"
            " or None if being used as a context manager;"
            f' got an object of type "{type(loading_code).__name__}" instead.'
        )
    with _NoInitOrTensorImpl.context_manager():
        return loading_code()
|
||||
|
||||
|
||||
class _NoInitOrTensorImpl:
    # Layer classes whose (potentially expensive) reset_parameters routines
    # are patched out while a context is active.
    _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
    # Originals kept so the patch can be undone when the last context exits.
    _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)

    # Per-context (thread/task-local) flag: is this context currently active?
    is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False)
    # Process-wide count of active contexts; the class-level monkeypatch is
    # installed on the 0 -> 1 transition and removed on 1 -> 0.
    _count_active: int = 0
    _count_active_lock = threading.Lock()

    @classmethod
    @contextlib.contextmanager
    def context_manager(cls):
        """Context that disables reset_parameters on common layer types and
        routes fresh allocations to the meta device (via MetaTensorMode).

        Re-entrant: nesting inside an already-active context is a no-op.
        """
        if cls.is_active.get():
            yield
            return

        with cls._count_active_lock:
            cls._count_active += 1
            if cls._count_active == 1:
                for mod in cls._MODULES:
                    mod.reset_parameters = cls._disable(mod.reset_parameters)

        # Token lets the finally block restore the previous per-context state.
        reset_token = cls.is_active.set(True)

        try:
            with MetaTensorMode():
                yield
        finally:
            cls.is_active.reset(reset_token)
            with cls._count_active_lock:
                cls._count_active -= 1
                if cls._count_active == 0:
                    for mod, original in cls._MODULE_ORIGINALS:
                        mod.reset_parameters = original

    @staticmethod
    def _disable(func):
        # Wrap `func` so it becomes a no-op while any context is active in
        # the current execution context; otherwise defer to the original.
        def wrapper(*args, **kwargs):
            if not _NoInitOrTensorImpl.is_active.get():
                return func(*args, **kwargs)

        return wrapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerConfig(MutableMapping):
|
||||
tensorizer_uri: str | None = None
|
||||
tensorizer_dir: str | None = None
|
||||
vllm_tensorized: bool | None = None
|
||||
verify_hash: bool | None = None
|
||||
num_readers: int | None = None
|
||||
encryption_keyfile: str | None = None
|
||||
s3_access_key_id: str | None = None
|
||||
s3_secret_access_key: str | None = None
|
||||
s3_endpoint: str | None = None
|
||||
lora_dir: str | None = None
|
||||
stream_kwargs: dict[str, Any] | None = None
|
||||
serialization_kwargs: dict[str, Any] | None = None
|
||||
deserialization_kwargs: dict[str, Any] | None = None
|
||||
_extra_serialization_attrs: dict[str, Any] | None = field(init=False, default=None)
|
||||
model_class: type[torch.nn.Module] | None = field(init=False, default=None)
|
||||
hf_config: PretrainedConfig | None = field(init=False, default=None)
|
||||
dtype: str | torch.dtype | None = field(init=False, default=None)
|
||||
_is_sharded: bool = field(init=False, default=False)
|
||||
_fields: ClassVar[tuple[str, ...]]
|
||||
_keys: ClassVar[frozenset[str]]
|
||||
"""Configuration class for Tensorizer settings.
|
||||
|
||||
These settings configure the behavior of model serialization and
|
||||
deserialization using Tensorizer.
|
||||
|
||||
Attributes:
|
||||
tensorizer_uri: Path to serialized model tensors. Can be a local file
|
||||
path or a S3 URI. This is a required field unless lora_dir is
|
||||
provided and the config is meant to be used for the
|
||||
`tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
|
||||
`lora_dir` is passed to this object's initializer, this is
|
||||
a required argument.
|
||||
tensorizer_dir: Path to a directory containing serialized model tensors,
|
||||
and all other potential model artifacts to load the model, such as
|
||||
configs and tokenizer files. Can be passed instead of
|
||||
`tensorizer_uri` where the `model.tensors` file will be assumed
|
||||
to be in this directory.
|
||||
vllm_tensorized: If True, indicates that the serialized model is a
|
||||
vLLM model. This is used to determine the behavior of the
|
||||
TensorDeserializer when loading tensors from a serialized model.
|
||||
It is far faster to deserialize a vLLM model as it utilizes
|
||||
tensorizer's optimized GPU loading. Note that this is now
|
||||
deprecated, as serialized vLLM models are now automatically
|
||||
inferred as vLLM models.
|
||||
verify_hash: If True, the hashes of each tensor will be verified
|
||||
against the hashes stored in the metadata. A `HashMismatchError`
|
||||
will be raised if any of the hashes do not match.
|
||||
num_readers: Controls how many threads are allowed to read concurrently
|
||||
from the source file. Default is `None`, which will dynamically set
|
||||
the number of readers based on the number of available
|
||||
resources and model size. This greatly increases performance.
|
||||
encryption_keyfile: File path to a binary file containing a
|
||||
binary key to use for decryption. `None` (the default) means
|
||||
no decryption. See the example script in
|
||||
examples/others/tensorize_vllm_model.py.
|
||||
s3_access_key_id: The access key for the S3 bucket. Can also be set via
|
||||
the S3_ACCESS_KEY_ID environment variable.
|
||||
s3_secret_access_key: The secret access key for the S3 bucket. Can also
|
||||
be set via the S3_SECRET_ACCESS_KEY environment variable.
|
||||
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
|
||||
S3_ENDPOINT_URL environment variable.
|
||||
lora_dir: Path to a directory containing LoRA adapter artifacts for
|
||||
serialization or deserialization. When serializing LoRA adapters
|
||||
this is the only necessary parameter to pass to this object's
|
||||
initializer.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
# check if the configuration is for a sharded vLLM model
|
||||
self._is_sharded = (
|
||||
isinstance(self.tensorizer_uri, str)
|
||||
and re.search(r"%0\dd", self.tensorizer_uri) is not None
|
||||
)
|
||||
|
||||
if self.tensorizer_dir and self.lora_dir:
|
||||
raise ValueError(
|
||||
"Only one of tensorizer_dir or lora_dir may be specified. "
|
||||
"Use lora_dir exclusively when serializing LoRA adapters, "
|
||||
"and tensorizer_dir or tensorizer_uri otherwise."
|
||||
)
|
||||
if self.tensorizer_dir and self.tensorizer_uri:
|
||||
logger.warning_once(
|
||||
"Provided both tensorizer_dir and tensorizer_uri. "
|
||||
"Inferring tensorizer_dir from tensorizer_uri as the "
|
||||
"latter takes precedence."
|
||||
)
|
||||
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
|
||||
if not self.tensorizer_uri:
|
||||
if self.lora_dir:
|
||||
self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
|
||||
elif self.tensorizer_dir:
|
||||
self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors"
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unable to resolve tensorizer_uri. "
|
||||
"A valid tensorizer_uri or tensorizer_dir "
|
||||
"must be provided for deserialization, and a "
|
||||
"valid tensorizer_uri, tensorizer_uri, or "
|
||||
"lora_dir for serialization."
|
||||
)
|
||||
else:
|
||||
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
|
||||
|
||||
if not self.serialization_kwargs:
|
||||
self.serialization_kwargs = {}
|
||||
if not self.deserialization_kwargs:
|
||||
self.deserialization_kwargs = {}
|
||||
|
||||
def to_serializable(self) -> dict[str, Any]:
|
||||
# Due to TensorizerConfig needing to be msgpack-serializable, it needs
|
||||
# support for morphing back and forth between itself and its dict
|
||||
# representation
|
||||
|
||||
# TensorizerConfig's representation as a dictionary is meant to be
|
||||
# linked to TensorizerConfig in such a way that the following is
|
||||
# technically initializable:
|
||||
# TensorizerConfig(**my_tensorizer_cfg.to_serializable())
|
||||
|
||||
# This means the dict must not retain non-initializable parameters
|
||||
# and post-init attribute states
|
||||
|
||||
# Also don't want to retain private and unset parameters, so only retain
|
||||
# not None values and public attributes
|
||||
|
||||
raw_tc_dict = asdict(self)
|
||||
blacklisted = []
|
||||
|
||||
if "tensorizer_uri" in raw_tc_dict and "tensorizer_dir" in raw_tc_dict:
|
||||
blacklisted.append("tensorizer_dir")
|
||||
|
||||
if "tensorizer_dir" in raw_tc_dict and "lora_dir" in raw_tc_dict:
|
||||
blacklisted.append("tensorizer_dir")
|
||||
|
||||
tc_dict = {}
|
||||
for k, v in raw_tc_dict.items():
|
||||
if (
|
||||
k not in blacklisted
|
||||
and k not in tc_dict
|
||||
and not k.startswith("_")
|
||||
and v is not None
|
||||
):
|
||||
tc_dict[k] = v
|
||||
|
||||
return tc_dict
|
||||
|
||||
def _construct_tensorizer_args(self) -> "TensorizerArgs":
|
||||
return TensorizerArgs(self) # type: ignore
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
) -> None:
|
||||
if parallel_config.tensor_parallel_size > 1 and not self._is_sharded:
|
||||
raise ValueError(
|
||||
"For a sharded model, tensorizer_uri should include a"
|
||||
" string format template like '%04d' to be formatted"
|
||||
" with the rank of the shard"
|
||||
)
|
||||
|
||||
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
|
||||
if model_config.quantization is not None and self.tensorizer_uri is not None:
|
||||
logger.warning(
|
||||
"Loading a model using Tensorizer with quantization on vLLM"
|
||||
" is unstable and may lead to errors."
|
||||
)
|
||||
|
||||
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
|
||||
if tensorizer_args is None:
|
||||
tensorizer_args = self._construct_tensorizer_args()
|
||||
|
||||
return open_stream(self.tensorizer_uri, **tensorizer_args.stream_kwargs)
|
||||
|
||||
def keys(self):
    """Mapping-protocol view: the frozen set of config field names."""
    field_names = self._keys
    return field_names
|
||||
|
||||
def __len__(self):
    """Mapping-protocol length: the number of dataclass fields."""
    return sum(1 for _ in fields(self))
|
||||
|
||||
def __iter__(self):
    """Mapping-protocol iteration over the cached field names."""
    yield from self._fields
|
||||
|
||||
def __getitem__(self, item: str) -> Any:
    """Dict-style read access, restricted to known field names."""
    if item in self.keys():
        return getattr(self, item)
    raise KeyError(item)
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
    """Dict-style write access, restricted to known field names."""
    if key in self.keys():
        setattr(self, key, value)
    else:
        # Item assignment must not create attributes the config doesn't have.
        raise KeyError(key)
|
||||
|
||||
def __delitem__(self, key, /):
    """Dict-style deletion, restricted to known field names."""
    if key in self.keys():
        delattr(self, key)
    else:
        raise KeyError(key)
|
||||
|
||||
|
||||
# Cache the dataclass field names once at import time so the mapping-protocol
# methods (keys/__iter__/__getitem__/...) can consult them without
# re-inspecting dataclass metadata on every access.
TensorizerConfig._fields = tuple(f.name for f in fields(TensorizerConfig))
TensorizerConfig._keys = frozenset(TensorizerConfig._fields)
|
||||
|
||||
|
||||
@dataclass
class TensorizerArgs:
    """Flattened, call-ready view of a :class:`TensorizerConfig`.

    Mirrors every public config field as an attribute and precomputes the
    kwarg dicts (``stream_kwargs`` / ``deserialization_kwargs``) handed to
    tensorizer's ``open_stream`` and ``TensorDeserializer``.
    """

    tensorizer_uri: str | None = None
    tensorizer_dir: str | None = None
    encryption_keyfile: str | None = None

    def __init__(self, tensorizer_config: TensorizerConfig):
        # Copy every public config field onto this object so callers can
        # treat TensorizerArgs like the config itself.
        for k, v in tensorizer_config.items():
            setattr(self, k, v)
        self.file_obj = tensorizer_config.tensorizer_uri
        # Resolve S3 credentials, falling back to environment variables.
        self.s3_access_key_id = (
            tensorizer_config.s3_access_key_id or envs.S3_ACCESS_KEY_ID
        )
        self.s3_secret_access_key = (
            tensorizer_config.s3_secret_access_key or envs.S3_SECRET_ACCESS_KEY
        )
        self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL

        # BUGFIX: use the resolved credentials (with env-var fallback) here.
        # The previous code read the raw config values, silently dropping
        # credentials supplied only via the S3_* environment variables.
        self.stream_kwargs = {
            "s3_access_key_id": self.s3_access_key_id,
            "s3_secret_access_key": self.s3_secret_access_key,
            "s3_endpoint": self.s3_endpoint,
            **(tensorizer_config.stream_kwargs or {}),
        }

        self.deserialization_kwargs = {
            "verify_hash": tensorizer_config.verify_hash,
            "encryption": tensorizer_config.encryption_keyfile,
            "num_readers": tensorizer_config.num_readers,
            **(tensorizer_config.deserialization_kwargs or {}),
        }

        # If a keyfile is configured, read the raw key now and replace the
        # keyfile path with ready-to-use DecryptionParams.
        if self.encryption_keyfile:
            with open_stream(
                tensorizer_config.encryption_keyfile,
                **self.stream_kwargs,
            ) as stream:
                key = stream.read()
                decryption_params = DecryptionParams.from_key(key)
                self.deserialization_kwargs["encryption"] = decryption_params

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Tensorizer CLI arguments"""

        # Tensorizer options arg group
        group = parser.add_argument_group(
            "tensorizer options",
            description=(
                "Options for configuring the behavior of the"
                " tensorizer deserializer when "
                "load_format=tensorizer is specified when "
                "initializing an LLMEngine, either via the CLI "
                "when running the vLLM OpenAI inference server "
                "with a JSON string passed to "
                "--model-loader-extra-config or as arguments given "
                "to TensorizerConfig when passed to "
                "model_loader_extra_config in the constructor "
                "for LLMEngine."
            ),
        )

        group.add_argument(
            "--tensorizer-uri",
            type=str,
            help="Path to serialized model tensors. Can be a local file path,"
            " or an HTTP(S) or S3 URI.",
        )
        group.add_argument(
            "--verify-hash",
            action="store_true",
            help="If enabled, the hashes of each tensor will be verified"
            " against the hashes stored in the file metadata. An exception"
            " will be raised if any of the hashes do not match.",
        )
        group.add_argument(
            "--encryption-keyfile",
            type=str,
            default=None,
            help="The file path to a binary file containing a binary key to "
            "use for decryption. Can be a file path or S3 network URI.",
        )
        group.add_argument(
            "--num-readers",
            default=None,
            type=int,
            help="Controls how many threads are allowed to read concurrently "
            "from the source file. Default is `None`, which will dynamically "
            "set the number of readers based on the available resources "
            "and model size. This greatly increases performance.",
        )
        group.add_argument(
            "--s3-access-key-id",
            type=str,
            default=None,
            help="The access key for the S3 bucket. Can also be set via the "
            "S3_ACCESS_KEY_ID environment variable.",
        )
        group.add_argument(
            "--s3-secret-access-key",
            type=str,
            default=None,
            help="The secret access key for the S3 bucket. Can also be set via "
            "the S3_SECRET_ACCESS_KEY environment variable.",
        )
        group.add_argument(
            "--s3-endpoint",
            type=str,
            default=None,
            help="The endpoint for the S3 bucket. Can also be set via the "
            "S3_ENDPOINT_URL environment variable.",
        )

        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
        """Build TensorizerArgs from parsed CLI args, ignoring absent attrs.

        NOTE(review): ``__init__`` accepts a single TensorizerConfig, not
        field kwargs, so this classmethod looks incompatible with the custom
        constructor — verify against actual callers.
        """
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        tensorizer_args = cls(
            **{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)}
        )
        return tensorizer_args
|
||||
|
||||
|
||||
def _check_tensors_on_meta_device(model: nn.Module) -> None:
    """Raise if any tensor in ``model`` was left on the meta device.

    A meta-device tensor after deserialization means the checkpoint did not
    supply a value for it.
    """
    any_on_meta = any(
        t.device.type == "meta" for t in model.state_dict().values()
    )
    if any_on_meta:
        raise ValueError(
            "The serialized model contains tensors on the meta device,"
            " indicating that some tensors were not loaded properly."
            " Please check that the parameters of the model being"
            " specified match that of the serialized model, such as"
            " its quantization."
        )
|
||||
|
||||
|
||||
def _resize_lora_embeddings(model: nn.Module):
    """Modify LoRA embedding layers to use bigger tensors
    to allow for adapter added tokens."""
    for module in model.modules():
        if not isinstance(module, VocabParallelEmbedding):
            continue
        old_rows = module.weight.shape[0]
        if old_rows >= module.num_embeddings_per_partition:
            continue
        # Allocate the padded table, copy the existing rows, zero the tail.
        padded = torch.zeros(
            module.num_embeddings_per_partition,
            module.embedding_dim,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        padded[:old_rows].copy_(module.weight.data)
        module.weight.data = padded
|
||||
|
||||
|
||||
def init_tensorizer_model(
    tensorizer_config: TensorizerConfig, vllm_config: VllmConfig
) -> nn.Module:
    """Instantiate the configured model class on the meta device.

    The returned module has no real weights; they are streamed in later by
    ``deserialize_tensorizer_model``.
    """
    hf_config = tensorizer_config.hf_config
    assert hf_config is not None
    hf_config.dtype = tensorizer_config.dtype
    model_class = tensorizer_config.model_class
    assert model_class is not None
    # TODO: Do we need to consider old-style model class?
    with meta_tensor_mode(), set_current_vllm_config(vllm_config, check_compile=True):
        return model_class(vllm_config=vllm_config)
|
||||
|
||||
|
||||
def deserialize_tensorizer_model(
    model: nn.Module, tensorizer_config: TensorizerConfig
) -> None:
    """Stream serialized weights from ``tensorizer_config`` into ``model``.

    Deserializes directly onto the current accelerator device (XPU or CUDA),
    logs throughput and memory stats, then validates that no tensor was left
    on the meta device and resizes LoRA embeddings if needed.

    Raises:
        ValueError: if the configured URI is not a valid local path or
            S3/HTTP(S) URI.
    """
    tensorizer_args = tensorizer_config._construct_tensorizer_args()
    if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
        raise ValueError(
            f"{tensorizer_config.tensorizer_uri} is not a valid "
            f"tensorizer URI. Please check that the URI is correct. "
            f"It must either point to a local existing file, or have a "
            f"S3, HTTP or HTTPS scheme."
        )
    before_mem = get_mem_usage()
    start = time.perf_counter()
    with (
        open_stream(
            tensorizer_config.tensorizer_uri, mode="rb", **tensorizer_args.stream_kwargs
        ) as stream,
        TensorDeserializer(
            stream,
            dtype=tensorizer_config.dtype,
            # Place tensors straight onto this rank's current device.
            device=f"xpu:{torch.xpu.current_device()}"
            if current_platform.is_xpu()
            else f"cuda:{torch.cuda.current_device()}",
            **tensorizer_args.deserialization_kwargs,
        ) as deserializer,
    ):
        deserializer.load_into_module(model)
        end = time.perf_counter()

    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    duration = end - start
    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
    after_mem = get_mem_usage()
    # NOTE(review): close() after the `with` block has already exited looks
    # redundant — confirm against the tensorizer API.
    deserializer.close()
    logger.info(
        "Deserialized %s in %0.2fs, %s/s", total_bytes_str, end - start, per_second
    )
    logger.info("Memory usage before: %s", before_mem)
    logger.info("Memory usage after: %s", after_mem)

    _check_tensors_on_meta_device(model)
    _resize_lora_embeddings(model)
    # Drop the marker parameter added at serialization time (see
    # serialize_vllm_model); it is not a real model weight.
    del model.vllm_tensorized_marker
|
||||
|
||||
|
||||
def tensorizer_weights_iterator(
    tensorizer_args: "TensorizerArgs",
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, tensor)`` pairs from a non-vLLM tensorizer checkpoint.

    Used for HuggingFace-serialized models: tensors are deserialized to CPU
    and fed to the model's regular ``load_weights`` path.
    """
    logger.warning(
        "Deserializing HuggingFace models is not optimized for "
        "loading on vLLM, as tensorizer is forced to load to CPU. "
        "Consider deserializing a vLLM model instead for faster "
        "load times. See the "
        "examples/others/tensorize_vllm_model.py example script "
        "for serializing vLLM models."
    )

    deserializer_args = tensorizer_args.deserialization_kwargs
    stream_kwargs = tensorizer_args.stream_kwargs
    # NOTE(review): the stream is not explicitly closed here — presumably the
    # TensorDeserializer takes ownership of it; confirm against the
    # tensorizer API.
    stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs)
    with TensorDeserializer(stream, **deserializer_args, device="cpu") as state:
        yield from state.items()
        del state
|
||||
|
||||
|
||||
def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
    """
    Infer if the model is a vLLM model by checking the weights for
    a vLLM tensorized marker.

    Args:
        tensorizer_config: The TensorizerConfig object containing the
            tensorizer_uri to the serialized model.

    Returns:
        bool: True if the model is a vLLM model, False otherwise.
    """
    if tensorizer_config.vllm_tensorized:
        logger.warning(
            "Please note that newly serialized vLLM models are automatically "
            "inferred as vLLM models, so setting vllm_tensorized=True is "
            "only necessary for models serialized prior to this change."
        )
        # Short-circuit before touching the checkpoint: previously the
        # deserializer (and its stream) was opened even on this path.
        return True

    tensorizer_args = tensorizer_config._construct_tensorizer_args()
    # lazy_load avoids materializing tensors; only the key listing is needed.
    # The context manager ensures the deserializer/stream are closed, which
    # the previous code never did.
    with TensorDeserializer(
        open_stream(tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs),
        **tensorizer_args.deserialization_kwargs,
        lazy_load=True,
    ) as deserializer:
        return ".vllm_tensorized_marker" in deserializer
|
||||
|
||||
|
||||
def serialize_extra_artifacts(
    tensorizer_args: TensorizerArgs, served_model_name: str | list[str] | None
) -> None:
    """Copy non-weight artifacts (configs, tokenizer files, etc.) alongside
    the serialized tensors under ``tensorizer_args.tensorizer_dir``.

    Downloads the HF repo for ``served_model_name`` into a temp dir (skipping
    weight files), then streams each regular file to the tensorizer dir.

    Raises:
        ValueError: if ``served_model_name`` is not a single string.
    """
    if not isinstance(served_model_name, str):
        raise ValueError(
            f"served_model_name must be a str for serialize_extra_artifacts, "
            f"not {type(served_model_name)}."
        )

    with tempfile.TemporaryDirectory() as tmpdir:
        # Weights are serialized separately; exclude them from the snapshot.
        snapshot_download(
            served_model_name,
            local_dir=tmpdir,
            ignore_patterns=[
                "*.pt",
                "*.safetensors",
                "*.bin",
                "*.cache",
                "*.gitattributes",
                "*.md",
            ],
        )
        for artifact in os.scandir(tmpdir):
            # Skip sub-directories; only top-level regular files are copied.
            if not artifact.is_file():
                continue
            with (
                open(artifact.path, "rb") as f,
                open_stream(
                    f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
                    mode="wb+",
                    **tensorizer_args.stream_kwargs,
                ) as stream,
            ):
                logger.info("Writing artifact %s", artifact.name)
                stream.write(f.read())
|
||||
|
||||
|
||||
def serialize_vllm_model(
    model: nn.Module,
    tensorizer_config: TensorizerConfig,
    model_config: "ModelConfig",
) -> nn.Module:
    """Serialize ``model`` with tensorizer to the configured URI.

    Registers a meta-device marker parameter so the checkpoint can later be
    recognized as vLLM-tensorized (see ``is_vllm_tensorized``), optionally
    encrypts with the configured keyfile, and serializes extra non-weight
    artifacts. Returns the (marker-carrying) model.
    """
    # Marker consumed by is_vllm_tensorized(); lives on the meta device so it
    # adds no real data to the checkpoint.
    model.register_parameter(
        "vllm_tensorized_marker",
        nn.Parameter(torch.tensor((1,), device="meta"), requires_grad=False),
    )

    tensorizer_args = tensorizer_config._construct_tensorizer_args()

    encryption_params = None
    if (keyfile := tensorizer_config.encryption_keyfile) is not None:
        with open(keyfile, "rb") as f:
            key = f.read()
            encryption_params = EncryptionParams(key=key)

    output_file = tensorizer_args.tensorizer_uri
    if tensorizer_config._is_sharded:
        from vllm.distributed import get_tensor_model_parallel_rank

        # URI contains a %-style template (e.g. '%04d'); fill in this rank.
        output_file = output_file % get_tensor_model_parallel_rank()

    with open_stream(
        output_file, mode="wb+", **tensorizer_args.stream_kwargs
    ) as stream:
        serializer = TensorSerializer(
            stream,
            encryption=encryption_params,
            **tensorizer_config.serialization_kwargs,
        )
        serializer.write_module(model)
        serializer.close()

    serialize_extra_artifacts(tensorizer_args, model_config.served_model_name)

    logger.info("Successfully serialized model to %s", str(output_file))
    return model
|
||||
|
||||
|
||||
def tensorize_vllm_model(
    engine_args: "EngineArgs",
    tensorizer_config: TensorizerConfig,
    generate_keyfile: bool = True,
):
    """Utility to load a model and then serialize it with Tensorizer

    Intended to be used separately from running a vLLM server since it
    creates its own Engine instance.

    Args:
        engine_args: Arguments used to construct the engine/model.
        tensorizer_config: Where and how to serialize.
        generate_keyfile: When True and an encryption keyfile is configured,
            a fresh random key is written to that location first.
    """
    engine_config = engine_args.create_engine_config()
    tensorizer_config.verify_with_model_config(engine_config.model_config)
    tensorizer_config.verify_with_parallel_config(engine_config.parallel_config)

    # generate the encryption key before creating the engine to support sharding
    if (
        generate_keyfile
        and (keyfile := tensorizer_config.encryption_keyfile) is not None
    ):
        encryption_params = EncryptionParams.random()
        with open_stream(
            keyfile,
            mode="wb+",
            s3_access_key_id=tensorizer_config.s3_access_key_id,
            s3_secret_access_key=tensorizer_config.s3_secret_access_key,
            s3_endpoint=tensorizer_config.s3_endpoint,
        ) as stream:
            stream.write(encryption_params.key)

    # Imported here rather than at module level — presumably to avoid an
    # import cycle with the engine package; confirm before moving it.
    from vllm.v1.engine.llm_engine import LLMEngine

    engine = LLMEngine.from_vllm_config(engine_config)
    # Each worker serializes its own (possibly sharded) copy of the weights.
    engine.collective_rpc(
        "save_tensorized_model",
        kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
    )
|
||||
|
||||
|
||||
def tensorize_lora_adapter(lora_path: str, tensorizer_config: TensorizerConfig):
    """
    Uses tensorizer to serialize a LoRA adapter. Assumes that the files
    needed to load a LoRA adapter are a safetensors-format file called
    adapter_model.safetensors and a json config file called adapter_config.json.

    Serializes the files in the tensorizer_config.tensorizer_dir

    Raises:
        ValueError: if the adapter files cannot be found or the weight file
            has an unsupported extension.
    """
    import safetensors

    from vllm.lora.utils import get_adapter_absolute_path

    lora_dir = get_adapter_absolute_path(lora_path)

    tensor_path = config_path = ""

    # Locate the weight and config files; stop as soon as both are found.
    for file in os.listdir(lora_dir):
        if file.startswith("adapter_model"):
            tensor_path = os.path.join(lora_dir, file)
        if file.startswith("adapter_config"):
            config_path = os.path.join(lora_dir, file)
        if tensor_path and config_path:
            break

    # Fail with a clear message instead of an opaque downstream error when
    # either file is missing.
    if not tensor_path or not config_path:
        raise ValueError(
            f"Could not find adapter_model/adapter_config files in {lora_dir}"
        )

    if tensor_path.endswith(".safetensors"):
        tensors = safetensors.torch.load_file(tensor_path)
    elif tensor_path.endswith(".bin"):
        # SECURITY: torch.load unpickles arbitrary objects; only load
        # adapter checkpoints from trusted sources here.
        tensors = torch.load(tensor_path)
    else:
        # f-string instead of the old ValueError("...%s", path) tuple, whose
        # message was never actually interpolated.
        raise ValueError(f"Unsupported file: {tensor_path}")

    with open(config_path) as f:
        config = json.load(f)

    tensorizer_args = tensorizer_config._construct_tensorizer_args()

    # Write the adapter config next to the serialized tensors.
    with open_stream(
        f"{tensorizer_config.tensorizer_dir}/adapter_config.json",
        mode="wb+",
        **tensorizer_args.stream_kwargs,
    ) as f:
        f.write(json.dumps(config).encode("utf-8"))

    lora_uri = f"{tensorizer_config.tensorizer_dir}/adapter_model.tensors"
    with open_stream(lora_uri, mode="wb+", **tensorizer_args.stream_kwargs) as f:
        serializer = TensorSerializer(f)
        serializer.write_state_dict(tensors)
        serializer.close()

    logger.info(
        "Successfully serialized LoRA files to %s",
        str(tensorizer_config.tensorizer_dir),
    )
|
||||
151
model_executor/model_loader/tensorizer_loader.py
Normal file
151
model_executor/model_loader/tensorizer_loader.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import copy
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig,
|
||||
deserialize_tensorizer_model,
|
||||
init_tensorizer_model,
|
||||
is_vllm_tensorized,
|
||||
serialize_vllm_model,
|
||||
tensorizer_weights_iterator,
|
||||
)
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
get_model_architecture,
|
||||
initialize_model,
|
||||
)
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Tensorizer kwargs users must not override: vLLM owns device/dtype
# placement, and "mode" is an internal stream detail.
BLACKLISTED_TENSORIZER_ARGS = {
    "device",  # vLLM decides this
    "dtype",  # vLLM decides this
    "mode",  # Not meant to be configurable by the user
}
|
||||
|
||||
|
||||
def validate_config(config: dict):
    """Reject user-supplied tensorizer kwargs that vLLM must control."""
    for k, v in config.items():
        if k in BLACKLISTED_TENSORIZER_ARGS and v is not None:
            raise ValueError(f"{k} is not an allowed Tensorizer argument.")
|
||||
|
||||
|
||||
class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
            # A ready-made config object was passed through; use it directly.
            self.tensorizer_config = load_config.model_loader_extra_config
        else:
            # NOTE(review): validate_config() checks the top-level dict while
            # the kwargs come from its "tensorizer_config" sub-dict — confirm
            # blacklisted keys cannot slip through inside the sub-dict.
            validate_config(load_config.model_loader_extra_config)
            self.tensorizer_config = TensorizerConfig(
                **load_config.model_loader_extra_config["tensorizer_config"]
            )

    def _verify_config(
        self, model_config: ModelConfig, parallel_config: ParallelConfig
    ):
        # Surface config incompatibilities (quantization warning, missing
        # shard template) before any expensive loading starts.
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
        self,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        # (name, tensor) stream used for non-vLLM-tensorized checkpoints.
        tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
        return tensorizer_weights_iterator(tensorizer_args)

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/others/tensorize_vllm_model.py) This should still
        be faster than default HuggingFace loading, but will be slower than
        loading a vLLM-tensorized model.
        """
        device_config = vllm_config.device_config
        model_config = vllm_config.model_config
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = initialize_model(vllm_config=vllm_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        """Pre-fetch/validate the checkpoint by opening (then closing) it."""
        self.tensorizer_config.verify_with_model_config(model_config)

        with self.tensorizer_config.open_stream():
            pass

    def _patch_tensorizer_config(self, model_config: ModelConfig) -> TensorizerConfig:
        # Shallow-copy so per-call patches don't mutate the shared config.
        model_class = get_model_architecture(model_config)[0]
        tensorizer_config = copy.copy(self.tensorizer_config)
        tensorizer_config.model_class = model_class
        tensorizer_config.hf_config = model_config.hf_config
        tensorizer_config.dtype = model_config.dtype
        return tensorizer_config

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load serialized model weights with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/others/tensorize_vllm_model.py example script
        for serializing vLLM models."""
        if is_vllm_tensorized(self.tensorizer_config):
            tensorizer_config = self._patch_tensorizer_config(model_config)
            deserialize_tensorizer_model(model, tensorizer_config)
        else:
            # Fall back to the slower HF-style weights iterator path.
            model.load_weights(self._get_weights_iterator())

    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig
    ) -> nn.Module:
        """Create the model and populate it with tensorizer weights."""
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            # Fill this rank into the '%04d'-style shard template.
            # NOTE(review): this mutates the shared config in place, so it
            # must only run once per process — confirm.
            self.tensorizer_config.tensorizer_uri = (
                self.tensorizer_config.tensorizer_uri % get_tensor_model_parallel_rank()
            )

        if is_vllm_tensorized(self.tensorizer_config):
            tensorizer_config = self._patch_tensorizer_config(model_config)
            device_config = vllm_config.device_config
            with set_default_torch_dtype(model_config.dtype):
                with torch.device(device_config.device):
                    # Build the module on the meta device; real weights are
                    # streamed in by load_weights() below.
                    model = init_tensorizer_model(
                        tensorizer_config=tensorizer_config, vllm_config=vllm_config
                    )
            self.load_weights(model, model_config)
            return model
        return self._load_model_serialized_cpu(vllm_config=vllm_config)

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig | dict,
        model_config: ModelConfig,
    ) -> None:
        """Serialize ``model`` to the location described by ``tensorizer_config``."""
        if isinstance(tensorizer_config, dict):
            tensorizer_config = TensorizerConfig(**tensorizer_config)
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
            model_config=model_config,
        )
|
||||
118
model_executor/model_loader/tpu.py
Normal file
118
model_executor/model_loader/tpu.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.spmd as xs
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model,
|
||||
process_weights_after_loading,
|
||||
)
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUModelLoader(DefaultModelLoader):
    """
    A TPU model loader for model loading under SPMD mode.
    """

    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
        mesh: xs.Mesh | None = None,
    ) -> nn.Module:
        """Initialize on CPU, load weights, move to XLA, shard, and compile.

        NOTE(review): the ``model_config`` parameter is immediately shadowed
        by ``vllm_config.model_config`` below — confirm whether callers rely
        on passing a different config object.
        """
        # Initialize model and load weights on CPU. Then, during SPMD partition,
        # weights are sharded and transferred to TPUs.
        self.counter_before_loading_weights = time.perf_counter()
        model_config = vllm_config.model_config
        assert model_config.quantization is None, "Quantization not supported"
        target_device = torch.device("cpu")
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config)

        load_format = vllm_config.load_config.load_format
        if load_format != "dummy":
            weights_to_load = {name for name, _ in model.named_parameters()}
            all_weights = self.get_all_weights(model_config, model)
            loaded_weights = model.load_weights(all_weights)
            self.counter_after_loading_weights = time.perf_counter()
            logger.info(
                "Loading weights took %.2f seconds",
                self.counter_after_loading_weights
                - self.counter_before_loading_weights,
            )
            # We only enable strict check for non-quantized models
            # that have loaded weights tracking currently.
            if model_config.quantization is None and loaded_weights is not None:
                weights_not_loaded = weights_to_load - loaded_weights
                if weights_not_loaded:
                    raise ValueError(
                        "Following weights were not initialized from "
                        f"checkpoint: {weights_not_loaded}"
                    )
        else:
            logger.info("Use dummy weight during weight loading.")

        process_weights_after_loading(model, model_config, target_device)

        counter_before_partition = time.perf_counter()
        model = model.eval()
        model = model.to("xla")
        shard_model(model, mesh)
        counter_after_partition = time.perf_counter()
        logger.info(
            "Partition model took %.2f seconds",
            counter_after_partition - counter_before_partition,
        )

        # Ensure the model is properly loaded.
        self._check_model_is_loaded(mesh, model)

        # Need to torch compile after model sharding are done. Because the
        # compiler hints ('xs.mark_sharding') are torch ops.
        if not model_config.is_multimodal_model:
            model.model = torch.compile(model.model, backend="openxla")
        else:
            model.language_model.model = torch.compile(
                model.language_model.model, backend="openxla"
            )
        return model

    def _check_model_is_loaded(self, mesh: xs.Mesh | None, model: nn.Module) -> None:
        """
        Ensure the model is properly loaded.
        1. All model parameters and buffers are on XLA device.
        2. Non-SPMD friendly layers are replaced as expected.
        """
        device = xm.xla_device()
        device_type = str(device.type)

        # Check parameters
        for name, param in model.named_parameters():
            assert param.device.type == device_type, (
                f"Parameter {name} is on {param.device.type} instead of {device_type}"
            )

        # Check buffers
        for name, buffer in model.named_buffers():
            assert buffer.device.type == device_type, (
                f"Buffer {name} is on {buffer.device.type} instead of {device_type}"
            )

        for module in model.modules():
            if (mesh is not None) and (get_fqn(module) == "QKVParallelLinear"):
                raise AssertionError(
                    "QKVParallelLinear should be replaced by \
                    XlaQKVParallelLinear under SPMD mode."
                )
|
||||
288
model_executor/model_loader/utils.py
Normal file
288
model_executor/model_loader/utils.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for selecting and loading models."""
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.attention.layer import MLAAttention
|
||||
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.models.adapters import (
|
||||
as_embedding_model,
|
||||
as_reward_model,
|
||||
as_seq_cls_model,
|
||||
try_create_mm_pooling_model_cls,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    model_class: type[nn.Module] | None = None,
    model_config: ModelConfig | None = None,
) -> nn.Module:
    """Initialize a model with the given configurations.

    Resolves the model class from the config when not supplied, wires in the
    quantization config, and instantiates the class. New-style classes are
    called with ``(vllm_config, prefix)``; old-style classes get a
    best-effort kwargs guess plus a deprecation warning.
    """
    if model_config is None:
        model_config = vllm_config.model_config
    if model_class is None:
        model_class, _ = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        configure_quant_config(vllm_config.quant_config, model_class)

    # Decide new- vs old-style by inspecting the constructor signature.
    signatures = inspect.signature(model_class.__init__)
    all_params = [param.name for param in signatures.parameters.values()]
    if "vllm_config" in all_params and "prefix" in all_params:
        # new-style model class
        with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = (
        "vLLM model class should accept `vllm_config` and `prefix` as "
        "input arguments. Possibly you have an old-style model class"
        " registered from out of tree and it is used for new vLLM version. "
        "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
        "for the design and update the model class accordingly."
    )
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )
    # try to be compatible with old-style model class
    kwargs = {}
    if "prefix" in all_params:
        kwargs["prefix"] = prefix
    if "config" in all_params:
        kwargs["config"] = model_config.hf_config
    if "cache_config" in all_params:
        kwargs["cache_config"] = vllm_config.cache_config
    if "quant_config" in all_params:
        kwargs["quant_config"] = vllm_config.quant_config
    if "lora_config" in all_params:
        kwargs["lora_config"] = vllm_config.lora_config
    if "scheduler_config" in all_params:
        kwargs["scheduler_config"] = vllm_config.scheduler_config
    with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
        return model_class(**kwargs)
|
||||
|
||||
|
||||
def process_weights_after_loading(
    model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
    """Run post-load weight processing for every quantized/attention module.

    Invokes each module's ``process_weights_after_loading`` hook (repacking,
    quantizing, etc.) with parameters temporarily moved to ``target_device``
    when CPU offloading is in use.
    """
    # to avoid circular dependency
    from vllm.model_executor.model_loader.online_quantization import (
        maybe_save_metadata_and_attributes_for_weight_reloading,
    )

    maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)

    for _, module in model.named_modules():
        quant_method = getattr(module, "quant_method", None)
        if isinstance(quant_method, QuantizeMethodBase):
            # When quant methods need to process weights after loading
            # (for repacking, quantizing, etc), they expect parameters
            # to be on the global target device. This scope is for the
            # case where cpu offloading is used, where we will move the
            # parameters onto device for processing and back off after.
            with device_loading_context(module, target_device):
                quant_method.process_weights_after_loading(module)

    # Initialize post-load attention weights for both Attention and MLA.
    # NOTE: Happens after other modules so we can easily decompress weights.
    for _, module in model.named_modules():
        if isinstance(module, (Attention, MLAAttention)) and hasattr(
            module, "process_weights_after_loading"
        ):
            # TODO(lucas): see if there is a way to unify the signatures
            # of process_weights_after_loading
            module.process_weights_after_loading(model_config.dtype)
|
||||
|
||||
|
||||
@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
    """Temporarily move *module*'s CPU parameters onto *target_device*.

    On exit, every parameter that was originally on CPU is copied back to CPU
    (into pinned memory when available); parameters that were already on the
    target device — and any parameters created inside the context — are left
    untouched. Yields the (mutated in place) module.

    Args:
        module: Module whose parameters may need to be moved for processing.
        target_device: Device required by the processing done inside the
            context. If it is CPU, this is a no-op context.
    """
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    # name -> original device, recorded only for parameters we actually move.
    original_device_states: dict[str, torch.device] = {}

    # Store original device states and move parameters to GPU if they're on CPU
    for name, p in module.named_parameters():
        if p.device.type == "cpu":
            original_device_states[name] = p.device
            p.data = p.data.to(target_device)
        # Parameters already on target device are not touched

    try:
        yield module

    finally:
        # Restore parameters to their original devices, ignoring new parameters
        pin_memory = is_pin_memory_available()
        for name, p in module.named_parameters():
            if name in original_device_states:
                original_device: torch.device = original_device_states[name]
                if original_device.type == "cpu":
                    # `torch.empty_like` does not support `pin_memory` argument
                    # so allocate a matching (possibly pinned) CPU buffer
                    # manually, preserving size/stride/dtype/layout.
                    cpu_data = torch.empty_strided(
                        size=p.data.size(),
                        stride=p.data.stride(),
                        dtype=p.data.dtype,
                        layout=p.data.layout,
                        device="cpu",
                        pin_memory=pin_memory,
                    )
                    cpu_data.copy_(p.data)
                    p.data = cpu_data
                else:
                    p.data = p.data.to(original_device)
            # New parameters or parameters already on target device are untouched
_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
|
||||
"""Caches the outputs of `_get_model_architecture`."""
|
||||
|
||||
|
||||
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Resolve the model class and architecture name for *model_config*.

    Uncached worker behind `get_model_architecture`: resolves the class via the
    model registry, warns when falling back to the Transformers backend, and
    applies any requested runner conversion (embed/classify/reward).

    Returns:
        Tuple of (possibly converted model class, resolved architecture name).
    """
    architectures = getattr(model_config.hf_config, "architectures", [])

    model_cls, arch = model_config.registry.resolve_model_cls(
        architectures,
        model_config=model_config,
    )

    if arch == model_config._get_transformers_backend_cls():
        # Resolving to the Transformers backend class is only possible when
        # the user did not force the native vLLM implementation.
        assert model_config.model_impl != "vllm"
        if model_config.model_impl == "auto":
            logger.warning_once(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.",
                arch,
            )

    convert_type = model_config.convert_type
    if convert_type != "none" and supports_multimodal(model_cls):
        # Multimodal models get a pooling wrapper when possible; otherwise we
        # fall through to the direct conversion below.
        logger.debug_once("Detected conversion of Multi Modal model.")
        converted = try_create_mm_pooling_model_cls(model_cls)
        if converted is not None:
            logger.debug_once("Creating wrapper class to forward pooler.")
            return converted, arch
        else:
            logger.debug_once("Attempting direct conversion.")

    if convert_type == "none":
        pass
    elif convert_type == "embed":
        logger.debug_once("Converting to embedding model.")
        model_cls = as_embedding_model(model_cls)
    elif convert_type == "classify":
        logger.debug_once("Converting to sequence classification model.")
        model_cls = as_seq_cls_model(model_cls)
    elif convert_type == "reward":
        logger.debug_once("Converting to reward model.")
        model_cls = as_reward_model(model_cls)
    else:
        # Exhaustiveness check: fails at type-check time if a new convert
        # type is added without a branch here.
        assert_never(convert_type)

    return model_cls, arch
def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Resolve (model class, architecture name) for *model_config*, with caching.

    Results are memoized in `_MODEL_ARCH_BY_HASH`, keyed by a hash of the
    config fields that influence architecture resolution.
    """
    cache_key = hash(
        (
            model_config.model,
            model_config.convert_type,
            model_config.runner_type,
            model_config.trust_remote_code,
            model_config.model_impl,
            tuple(getattr(model_config.hf_config, "architectures", [])),
        )
    )

    # Cached values are always tuples, so `None` safely marks a miss.
    cached = _MODEL_ARCH_BY_HASH.get(cache_key)
    if cached is None:
        cached = _get_model_architecture(model_config)
        _MODEL_ARCH_BY_HASH[cache_key] = cached
    return cached
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
    """Return only the resolved model class for *model_config* (cached)."""
    model_cls, _arch = get_model_architecture(model_config)
    return model_cls
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return only the resolved architecture name for *model_config* (cached)."""
    _model_cls, arch = get_model_architecture(model_config)
    return arch
@dataclass
|
||||
class ParamMapping:
|
||||
"""
|
||||
A class to handle parameter mapping for model weight loading.
|
||||
It creates a bidirectional mapping between packed parameters and their
|
||||
constituent parts.
|
||||
"""
|
||||
|
||||
packed_mapping: dict[str, list[str]]
|
||||
inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
for packed_name, sub_params in self.packed_mapping.items():
|
||||
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
|
||||
if len(sub_params) == 1 and sub_params[0] == packed_name:
|
||||
continue
|
||||
for index, param_name in enumerate(sub_params):
|
||||
self.inverse_packed_mapping[param_name] = (
|
||||
packed_name,
|
||||
index,
|
||||
)
|
||||
|
||||
def get_sub_modules(self, module_name: str) -> tuple[str, list[str]] | None:
|
||||
for key, value in self.packed_mapping.items():
|
||||
if module_name.endswith(key):
|
||||
return key, value
|
||||
return None
|
||||
|
||||
|
||||
def configure_quant_config(
    quant_config: QuantizationConfig, model_class: type[nn.Module]
):
    """
    Pass packed_modules_mapping by reference to quant_config so that
    quant_config can properly match fused modules

    Note that model attributes are passed by reference to quant_config,
    enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)

    Once the `SupportsQuant` mixin has been added to all models, this
    function can be removed
    """
    # Models using the `SupportsQuant` mixin wire these mappings themselves.
    if issubclass(model_class, SupportsQuant):
        return

    # pass mappings by reference to quant_config
    hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
    if hf_to_vllm_mapper is not None:
        quant_config.apply_vllm_mapper(hf_to_vllm_mapper)

    packed_mapping = getattr(model_class, "packed_modules_mapping", None)
    if packed_mapping is not None:
        quant_config.packed_modules_mapping = packed_mapping
1106
model_executor/model_loader/weight_utils.py
Normal file
1106
model_executor/model_loader/weight_utils.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user