first commit
This commit is contained in:
138
vllm/model_executor/model_loader/__init__.py
Normal file
138
vllm/model_executor/model_loader/__init__.py
Normal file
@@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.bitsandbytes_loader import (
|
||||
BitsAndBytesModelLoader)
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
|
||||
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
|
||||
from vllm.model_executor.model_loader.runai_streamer_loader import (
|
||||
RunaiModelStreamerLoader)
|
||||
from vllm.model_executor.model_loader.sharded_state_loader import (
|
||||
ShardedStateLoader)
|
||||
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
get_architecture_class_name, get_model_architecture, get_model_cls)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Reminder: Please update docstring in `LoadConfig`
|
||||
# if a new load format is added here
|
||||
LoadFormats = Literal[
|
||||
"auto",
|
||||
"bitsandbytes",
|
||||
"dummy",
|
||||
"fastsafetensors",
|
||||
"gguf",
|
||||
"mistral",
|
||||
"npcache",
|
||||
"pt",
|
||||
"runai_streamer",
|
||||
"runai_streamer_sharded",
|
||||
"safetensors",
|
||||
"sharded_state",
|
||||
"tensorizer",
|
||||
]
|
||||
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
|
||||
"auto": DefaultModelLoader,
|
||||
"bitsandbytes": BitsAndBytesModelLoader,
|
||||
"dummy": DummyModelLoader,
|
||||
"fastsafetensors": DefaultModelLoader,
|
||||
"gguf": GGUFModelLoader,
|
||||
"mistral": DefaultModelLoader,
|
||||
"npcache": DefaultModelLoader,
|
||||
"pt": DefaultModelLoader,
|
||||
"runai_streamer": RunaiModelStreamerLoader,
|
||||
"runai_streamer_sharded": ShardedStateLoader,
|
||||
"safetensors": DefaultModelLoader,
|
||||
"sharded_state": ShardedStateLoader,
|
||||
"tensorizer": TensorizerLoader,
|
||||
}
|
||||
|
||||
|
||||
def register_model_loader(load_format: str):
|
||||
"""Register a customized vllm model loader.
|
||||
|
||||
When a load format is not supported by vllm, you can register a customized
|
||||
model loader to support it.
|
||||
|
||||
Args:
|
||||
load_format (str): The model loader format name.
|
||||
|
||||
Examples:
|
||||
>>> from vllm.config.load import LoadConfig
|
||||
>>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader
|
||||
>>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
>>>
|
||||
>>> @register_model_loader("my_loader")
|
||||
... class MyModelLoader(BaseModelLoader):
|
||||
... def download_model(self):
|
||||
... pass
|
||||
...
|
||||
... def load_weights(self):
|
||||
... pass
|
||||
>>>
|
||||
>>> load_config = LoadConfig(load_format="my_loader")
|
||||
>>> type(get_model_loader(load_config))
|
||||
<class 'MyModelLoader'>
|
||||
""" # noqa: E501
|
||||
|
||||
def _wrapper(model_loader_cls):
|
||||
if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
|
||||
logger.warning(
|
||||
"Load format `%s` is already registered, and will be "
|
||||
"overwritten by the new loader class `%s`.", load_format,
|
||||
model_loader_cls)
|
||||
if not issubclass(model_loader_cls, BaseModelLoader):
|
||||
raise ValueError("The model loader must be a subclass of "
|
||||
"`BaseModelLoader`.")
|
||||
_LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
|
||||
logger.info("Registered model loader `%s` with load format `%s`",
|
||||
model_loader_cls, load_format)
|
||||
return model_loader_cls
|
||||
|
||||
return _wrapper
|
||||
|
||||
|
||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
"""Get a model loader based on the load format."""
|
||||
load_format = load_config.load_format
|
||||
if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER:
|
||||
raise ValueError(f"Load format `{load_format}` is not supported")
|
||||
return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config)
|
||||
|
||||
|
||||
def get_model(*,
|
||||
vllm_config: VllmConfig,
|
||||
model_config: Optional[ModelConfig] = None) -> nn.Module:
|
||||
loader = get_model_loader(vllm_config.load_config)
|
||||
if model_config is None:
|
||||
model_config = vllm_config.model_config
|
||||
return loader.load_model(vllm_config=vllm_config,
|
||||
model_config=model_config)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_model",
|
||||
"get_model_loader",
|
||||
"get_architecture_class_name",
|
||||
"get_model_architecture",
|
||||
"get_model_cls",
|
||||
"register_model_loader",
|
||||
"BaseModelLoader",
|
||||
"BitsAndBytesModelLoader",
|
||||
"GGUFModelLoader",
|
||||
"DefaultModelLoader",
|
||||
"DummyModelLoader",
|
||||
"RunaiModelStreamerLoader",
|
||||
"ShardedStateLoader",
|
||||
"TensorizerLoader",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
BIN
vllm/model_executor/model_loader/__pycache__/tpu.cpython-310.pyc
Normal file
BIN
vllm/model_executor/model_loader/__pycache__/tpu.cpython-310.pyc
Normal file
Binary file not shown.
Binary file not shown.
Binary file not shown.
52
vllm/model_executor/model_loader/base_loader.py
Normal file
52
vllm/model_executor/model_loader/base_loader.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model, process_weights_after_loading, set_default_torch_dtype)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BaseModelLoader(ABC):
|
||||
"""Base class for model loaders."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
self.load_config = load_config
|
||||
|
||||
@abstractmethod
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
"""Download a model so that it can be immediately loaded."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
"""Load weights into a model. This standalone API allows
|
||||
inplace weights loading for an already-initialized model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def load_model(self, vllm_config: VllmConfig,
|
||||
model_config: ModelConfig) -> nn.Module:
|
||||
"""Load a model with the given configurations."""
|
||||
device_config = vllm_config.device_config
|
||||
load_config = vllm_config.load_config
|
||||
load_device = device_config.device if load_config.device is None else \
|
||||
load_config.device
|
||||
target_device = torch.device(load_device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(vllm_config=vllm_config,
|
||||
model_config=model_config)
|
||||
|
||||
logger.debug("Loading weights on %s ...", load_device)
|
||||
# Quantization does not happen in `load_weights` but after it
|
||||
self.load_weights(model, model_config)
|
||||
process_weights_after_loading(model, model_config, target_device)
|
||||
return model.eval()
|
||||
788
vllm/model_executor/model_loader/bitsandbytes_loader.py
Normal file
788
vllm/model_executor/model_loader/bitsandbytes_loader.py
Normal file
@@ -0,0 +1,788 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import fnmatch
|
||||
import glob
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
# yapf: enable
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (ParamMapping,
|
||||
set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.models import is_pooling_model
|
||||
from vllm.model_executor.utils import (get_moe_expert_mapping,
|
||||
get_packed_modules_mapping,
|
||||
set_weight_attrs)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
# yapf conflicts with isort for this block
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def is_moe_model(model: torch.nn.Module) -> bool:
|
||||
"""Checks if the model contains FusedMoE layers."""
|
||||
return bool(any(
|
||||
isinstance(module, FusedMoE) for module in model.modules()))
|
||||
|
||||
|
||||
class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
"""Model loader to load model weights with BitAndBytes quantization."""
|
||||
|
||||
possible_config_file_names = ["adapter_config.json"]
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
|
||||
# Save the module names without sharding.
|
||||
self.unsharded_weights_modules: list[str] = []
|
||||
# Save the module names that are sharded by column.
|
||||
self.column_sharded_weights_modules: list[str] = []
|
||||
# Modules whose weights might have fused on disk
|
||||
# we need their output_sizes to make shard in flight correctly with TP
|
||||
self.maybe_fused_weights_modules: dict[str, list[int]] = {}
|
||||
# Store all module names (from transformers) that support
|
||||
# BNB quantization.
|
||||
self.target_modules: list[str] = []
|
||||
self.tp_disabled_modules: list[str] = []
|
||||
# Store the mapping of expert parameters for MoE models.
|
||||
self.expert_params_mapping: list[tuple[str, str, int, str]] = []
|
||||
# mapping weight names from transformers to vllm.
|
||||
self.weight_mapper: Callable = lambda name: name
|
||||
self.pre_quant: bool = False
|
||||
self.load_8bit: bool = False
|
||||
self.is_pool_model: bool = False
|
||||
|
||||
def _get_weight_files(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
allowed_patterns: list[str],
|
||||
revision: Optional[str] = None,
|
||||
) -> tuple[str, list[str], str]:
|
||||
"""Retrieve weight files. Download the files if necessary.
|
||||
|
||||
Return the weight files and the file pattern."""
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
|
||||
if is_local:
|
||||
for pattern in allowed_patterns:
|
||||
weight_files = glob.glob(
|
||||
os.path.join(model_name_or_path, pattern))
|
||||
if weight_files:
|
||||
return model_name_or_path, weight_files, pattern
|
||||
else:
|
||||
hf_api = HfApi()
|
||||
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
|
||||
for pattern in allowed_patterns:
|
||||
matching_files = fnmatch.filter(repo_files, pattern)
|
||||
if matching_files:
|
||||
hf_folder = download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
[pattern],
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
)
|
||||
return hf_folder, glob.glob(
|
||||
os.path.join(hf_folder, pattern)), pattern
|
||||
|
||||
raise RuntimeError(
|
||||
f"No model weights found in: `{model_name_or_path}`")
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str]) -> tuple[list[str], bool]:
|
||||
"""Prepare weight files for the model."""
|
||||
|
||||
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
|
||||
|
||||
hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
|
||||
model_name_or_path, allowed_patterns, revision)
|
||||
|
||||
use_safetensors = matched_pattern == "*.safetensors"
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
if use_safetensors:
|
||||
# For models like Mistral-7B-Instruct-v0.3
|
||||
# there are both sharded safetensors files and a consolidated
|
||||
# safetensors file. Using both breaks.
|
||||
# Here, we download the `model.safetensors.index.json` and filter
|
||||
# any files not found in the index.
|
||||
if not is_local:
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path,
|
||||
index_file,
|
||||
self.load_config.download_dir,
|
||||
revision,
|
||||
)
|
||||
hf_weights_files = filter_duplicate_safetensors_files(
|
||||
hf_weights_files, hf_folder, index_file)
|
||||
else:
|
||||
hf_weights_files = filter_files_not_needed_for_inference(
|
||||
hf_weights_files)
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||
|
||||
return hf_weights_files, use_safetensors
|
||||
|
||||
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
|
||||
|
||||
def _maybe_pool_model(module_name: str):
|
||||
# For pool model, we need to add the prefix `model.`
|
||||
# for the weight name if possible.
|
||||
if self.is_pool_model and self.target_modules[0]. \
|
||||
startswith("model.") and not module_name.startswith(
|
||||
"model."):
|
||||
return "model." + module_name
|
||||
|
||||
return module_name
|
||||
|
||||
if use_safetensors:
|
||||
iterator = safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
else:
|
||||
iterator = pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.pt_load_map_location,
|
||||
)
|
||||
for org_name, param in iterator:
|
||||
# mapping weight names from transformers to vllm while preserving
|
||||
# original names.
|
||||
mapped_name = self.weight_mapper(org_name)
|
||||
mapped_name = _maybe_pool_model(mapped_name)
|
||||
|
||||
yield org_name, mapped_name, param
|
||||
|
||||
def _get_quantized_weights_iterator(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
revision: Optional[str],
|
||||
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
|
||||
Any]]:
|
||||
"""Get an iterator to the model weights with bitsandbytes quantization,
|
||||
as well as the quantization state dictionary."""
|
||||
|
||||
# only load the bitsandbytes module when needed
|
||||
try:
|
||||
import bitsandbytes
|
||||
|
||||
if version.parse(
|
||||
bitsandbytes.__version__) < version.parse("0.46.1"):
|
||||
raise ImportError("bitsandbytes version is wrong. Please "
|
||||
"install bitsandbytes>=0.46.1.")
|
||||
except ImportError as err:
|
||||
raise ImportError("Please install bitsandbytes>=0.46.1 via "
|
||||
"`pip install bitsandbytes>=0.46.1` to use "
|
||||
"bitsandbytes quantizer.") from err
|
||||
|
||||
hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
model_name_or_path, revision)
|
||||
|
||||
quant_state_dict: dict[str, Any] = {}
|
||||
|
||||
if self.pre_quant:
|
||||
if self.load_8bit:
|
||||
return self._quantized_8bit_generator(
|
||||
hf_weights_files, use_safetensors,
|
||||
quant_state_dict), quant_state_dict
|
||||
else:
|
||||
return self._quantized_4bit_generator(
|
||||
hf_weights_files, use_safetensors,
|
||||
quant_state_dict), quant_state_dict
|
||||
|
||||
return self._unquantized_generator(hf_weights_files, use_safetensors,
|
||||
quant_state_dict), quant_state_dict
|
||||
|
||||
def _is_8bit_weight_name(self, weight_name: str):
|
||||
quantized_suffix = {".scb", ".weight_format"}
|
||||
return any(weight_name.lower().endswith(suffix)
|
||||
for suffix in quantized_suffix)
|
||||
|
||||
def _is_4bit_weight_name(self, weight_name: str):
|
||||
quantized_suffix = {
|
||||
"absmax",
|
||||
"quant_map",
|
||||
"nested_absmax",
|
||||
"nested_quant_map",
|
||||
"bitsandbytes",
|
||||
}
|
||||
suffix = weight_name.split(".")[-1]
|
||||
return any(q_suffix in suffix for q_suffix in quantized_suffix)
|
||||
|
||||
def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
|
||||
quant_state_dict) -> Generator:
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if not mapped_weight_name.lower().endswith(".scb"):
|
||||
continue
|
||||
|
||||
weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
|
||||
quant_state_dict[weight_key] = weight_tensor
|
||||
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if self._is_8bit_weight_name(mapped_weight_name):
|
||||
continue
|
||||
|
||||
if mapped_weight_name in quant_state_dict:
|
||||
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
|
||||
yield org_weight_name, weight_tensor
|
||||
else:
|
||||
yield org_weight_name, weight_tensor
|
||||
|
||||
def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
|
||||
quant_state_dict) -> Generator:
|
||||
from bitsandbytes.functional import QuantState
|
||||
|
||||
# First iterate over all quant state weights
|
||||
weight_iterator = self._hf_weight_iter(hf_weights_files,
|
||||
use_safetensors)
|
||||
temp_state_dict = {}
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in weight_iterator:
|
||||
if not self._is_4bit_weight_name(mapped_weight_name):
|
||||
continue
|
||||
# bitsandbytes library requires
|
||||
# weight.quant_state.bitsandbytes__* in CPU
|
||||
if "quant_state.bitsandbytes" in mapped_weight_name:
|
||||
temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
|
||||
else:
|
||||
temp_state_dict[mapped_weight_name] = weight_tensor
|
||||
|
||||
# Closure to parse quant_state for each prequant weight
|
||||
def _parse_quant_state(param_name: str,
|
||||
temp_state_dict: dict) -> QuantState:
|
||||
quant_state = {}
|
||||
for k in temp_state_dict:
|
||||
if param_name + "." in k:
|
||||
quant_state[k] = temp_state_dict[k]
|
||||
|
||||
return QuantState.from_dict(quant_state,
|
||||
device=current_platform.device_type)
|
||||
|
||||
# Second iterate over all prequant and normal weights
|
||||
# pre quantized weights would have a quant_state
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if self._is_4bit_weight_name(mapped_weight_name):
|
||||
continue
|
||||
|
||||
if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
|
||||
in temp_state_dict) or (
|
||||
f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
|
||||
in temp_state_dict):
|
||||
quant_state = _parse_quant_state(mapped_weight_name,
|
||||
temp_state_dict)
|
||||
quant_state_dict[mapped_weight_name] = quant_state
|
||||
yield org_weight_name, weight_tensor
|
||||
else:
|
||||
yield org_weight_name, weight_tensor
|
||||
|
||||
def _unquantized_generator(self, hf_weights_files, use_safetensors,
|
||||
quant_state_dict) -> Generator:
|
||||
from bitsandbytes.functional import quantize_4bit
|
||||
|
||||
global_tp_size = get_tensor_model_parallel_world_size()
|
||||
global_tp_rank = get_tensor_model_parallel_rank()
|
||||
check_match = (lambda weight_name, module_name: weight_name.
|
||||
removesuffix(".weight") == module_name)
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
|
||||
# override tp_size and tp_rank if the module has disabled TP
|
||||
if any(tp_disabled_module in mapped_weight_name
|
||||
for tp_disabled_module in self.tp_disabled_modules):
|
||||
tp_size = 1
|
||||
tp_rank = 0
|
||||
else:
|
||||
tp_size = global_tp_size
|
||||
tp_rank = global_tp_rank
|
||||
|
||||
if any(target_module in mapped_weight_name
|
||||
for target_module in self.target_modules
|
||||
) and mapped_weight_name.endswith(".weight"):
|
||||
# Without sharding
|
||||
if any(
|
||||
check_match(mapped_weight_name, module)
|
||||
for module in self.unsharded_weights_modules):
|
||||
weight_sub_tensor = weight_tensor
|
||||
# Shard by column
|
||||
elif any(
|
||||
check_match(mapped_weight_name, module)
|
||||
for module in self.column_sharded_weights_modules):
|
||||
total_size = weight_tensor.size(-1)
|
||||
start_index = total_size // tp_size * tp_rank
|
||||
end_index = total_size // tp_size * (tp_rank + 1)
|
||||
weight_sub_tensor = weight_tensor[...,
|
||||
start_index:end_index]
|
||||
# Weights have fused on disk. In this case, we assume that the
|
||||
# weight and module use same name.
|
||||
elif any(
|
||||
check_match(mapped_weight_name, module)
|
||||
for module in self.maybe_fused_weights_modules):
|
||||
# special case for fused weights
|
||||
# get the size of each shard weight tensor
|
||||
total_shard_sizes = next(
|
||||
(sizes for module, sizes in
|
||||
self.maybe_fused_weights_modules.items()
|
||||
if check_match(mapped_weight_name, module)))
|
||||
total_size = weight_tensor.size(0)
|
||||
assert total_size == sum(total_shard_sizes)
|
||||
# get the start/end index of each shard weight tensor
|
||||
total_start_index = list(
|
||||
itertools.accumulate([0] + total_shard_sizes))[:-1]
|
||||
shard_weights_index = [(
|
||||
idx + size // tp_size * tp_rank,
|
||||
idx + size // tp_size * (tp_rank + 1),
|
||||
) for idx, size in zip(total_start_index,
|
||||
total_shard_sizes)]
|
||||
# slice and reorder the weight tensor
|
||||
weight_tensor = [
|
||||
weight_tensor[start_index:end_index, ...]
|
||||
for start_index, end_index in shard_weights_index
|
||||
]
|
||||
weight_sub_tensor = torch.cat(weight_tensor, dim=0)
|
||||
# Shard by row
|
||||
else:
|
||||
total_size = weight_tensor.size(0)
|
||||
start_index = total_size // tp_size * tp_rank
|
||||
end_index = total_size // tp_size * (tp_rank + 1)
|
||||
weight_sub_tensor = weight_tensor[start_index:end_index,
|
||||
...]
|
||||
|
||||
# bitsandbytes requires data in GPU
|
||||
if weight_sub_tensor.is_cuda:
|
||||
loaded_weight = weight_sub_tensor
|
||||
else:
|
||||
loaded_weight = weight_sub_tensor.to(
|
||||
device=current_platform.device_type)
|
||||
|
||||
# remove the following after the issue is fixed:
|
||||
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
|
||||
if loaded_weight.is_contiguous() is False:
|
||||
loaded_weight = loaded_weight.contiguous()
|
||||
|
||||
with set_default_torch_dtype(torch.float32):
|
||||
processed_weight, quant_state = quantize_4bit(
|
||||
loaded_weight,
|
||||
compress_statistics=True,
|
||||
quant_type="nf4",
|
||||
)
|
||||
|
||||
quant_state_dict[mapped_weight_name] = quant_state
|
||||
else:
|
||||
processed_weight = weight_tensor
|
||||
yield org_weight_name, processed_weight
|
||||
|
||||
def _get_bnb_target_modules(self, model: nn.Module) -> None:
|
||||
"""
|
||||
Identify and collect all modules that support BitsAndBytes
|
||||
quantization.
|
||||
"""
|
||||
for name, module in model.named_modules():
|
||||
if (isinstance(module, LinearBase)
|
||||
and hasattr(module.quant_method, "quant_config")):
|
||||
if modules_info := self.modules_mapping.get_sub_modules(name):
|
||||
# Map vllm's names to transformers's names.
|
||||
rep_name, sub_modules = modules_info
|
||||
for sub_name in sub_modules:
|
||||
new_name = name.replace(rep_name, sub_name)
|
||||
self.target_modules.append(new_name)
|
||||
if module.disable_tp:
|
||||
self.tp_disabled_modules.append(new_name)
|
||||
# Add original module name even if the module has stacked map,
|
||||
# in case model has a mixture of disk-merged and disk-split
|
||||
# weights with same last name.
|
||||
self.target_modules.append(name)
|
||||
if module.disable_tp:
|
||||
self.tp_disabled_modules.append(name)
|
||||
elif isinstance(module, FusedMoE) and hasattr(
|
||||
module.quant_method, "quant_config"):
|
||||
# TODO: support FusedMoE with prequant and 8bit.
|
||||
if self.pre_quant and self.load_8bit:
|
||||
raise ValueError(
|
||||
"Prequant BitsAndBytes 8bit models with FusedMoE "
|
||||
"is not supported yet.")
|
||||
# Get the corresponding weight name using module name and
|
||||
# expert_params_mapping.
|
||||
|
||||
for exp in self.expert_params_mapping:
|
||||
weight_name = exp[1]
|
||||
rep_name = name.replace("experts",
|
||||
"") + weight_name.removesuffix(".")
|
||||
self.target_modules.append(rep_name)
|
||||
|
||||
assert (self.target_modules
|
||||
), "vLLM currently does not support BNB quantization for"
|
||||
f" {type(model).__name__}"
|
||||
|
||||
def _classify_module_sharding(self, model: nn.Module):
|
||||
"""
|
||||
Categorize modules based on their weight sharding requirements
|
||||
for tensor parallelism.
|
||||
"""
|
||||
for name, module in model.named_modules():
|
||||
# Some modules like `ReplicatedLinear` should not have their weights
|
||||
# sharded. The reason for implementing it this way is to avoid new
|
||||
# static variable in the model implementation.
|
||||
if isinstance(module, (ReplicatedLinear, )):
|
||||
self.unsharded_weights_modules.append(name)
|
||||
# `QKVParallelLinear` and `MergedColumnParallelLinear` might have
|
||||
# fused weights on disk. We need to use the output sizes of these
|
||||
# modules to shard the weights correctly.
|
||||
elif isinstance(module,
|
||||
(QKVParallelLinear, MergedColumnParallelLinear)):
|
||||
self.maybe_fused_weights_modules[name] = module.output_sizes
|
||||
# In TP, these weights are partitioned along the column
|
||||
# dimension (dim=-1)
|
||||
elif isinstance(module, (RowParallelLinear, )):
|
||||
self.column_sharded_weights_modules.append(name)
|
||||
elif isinstance(module, FusedMoE):
|
||||
expert_mapping = self.expert_params_mapping
|
||||
for exp in expert_mapping:
|
||||
if exp[-1] == "w2":
|
||||
weight_name = exp[1]
|
||||
rep_name = name.replace(
|
||||
"experts", "") + weight_name.removesuffix(".")
|
||||
self.column_sharded_weights_modules.append(rep_name)
|
||||
|
||||
def _verify_model_compatibility(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
"""
|
||||
Verify that the model is compatible with BitsAndBytes quantization.
|
||||
"""
|
||||
if not hasattr(model, "load_weights"):
|
||||
raise AttributeError(
|
||||
"The required method 'load_weights' is not defined in class"
|
||||
f" {type(model).__name__}.")
|
||||
|
||||
if not hasattr(model, "packed_modules_mapping"):
|
||||
raise AttributeError(
|
||||
f"Model {type(model).__name__} does not support BitsAndBytes "
|
||||
"quantization yet. No 'packed_modules_mapping' found.")
|
||||
|
||||
quant_config = getattr(model_config.hf_config, "quantization_config",
|
||||
None)
|
||||
if quant_config is not None:
|
||||
quant_method = quant_config.get("quant_method")
|
||||
if quant_method == "bitsandbytes":
|
||||
self.pre_quant = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"BitsAndBytes loader does not support {quant_method} "
|
||||
"quantization")
|
||||
|
||||
# The quant_states in pre_quantized models cannot work with a split
|
||||
# weight tensor. So TP does not work with pre_quantized bnb models.
|
||||
if self.pre_quant and get_tensor_model_parallel_world_size() > 1:
|
||||
raise ValueError(
|
||||
"Prequant BitsAndBytes models with tensor parallelism is not "
|
||||
"supported. Please try with pipeline parallelism.")
|
||||
if self.pre_quant:
|
||||
self.load_8bit = quant_config.get("load_in_8bit", False)
|
||||
|
||||
def _initialize_loader_state(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
"""
|
||||
Initialize the loader's internal state based on the model and
|
||||
configuration.
|
||||
"""
|
||||
self.is_pool_model = is_pooling_model(model)
|
||||
self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
|
||||
|
||||
if is_moe_model(model):
|
||||
self.expert_params_mapping = get_moe_expert_mapping(model)
|
||||
if not self.expert_params_mapping:
|
||||
raise AttributeError(
|
||||
f"MoE Model {type(model).__name__} does not support "
|
||||
"BitsAndBytes quantization yet. Ensure this model has "
|
||||
"'get_expert_mapping' method.")
|
||||
# For some models like Molmo, we need to use hf_to_vllm_mapper
|
||||
# to ensure correct loading of weights.
|
||||
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
|
||||
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
|
||||
|
||||
self._get_bnb_target_modules(model)
|
||||
self._classify_module_sharding(model)
|
||||
|
||||
def _dequantize_dq(self, quant_states: Any):
|
||||
"""
|
||||
When BNB employs Double Quantization, we perform the dequantization of
|
||||
these constants during weight loading rather than at inference time,
|
||||
thereby avoiding this computational overhead during inference. This
|
||||
comes at the cost of increased memory usage.
|
||||
"""
|
||||
from bitsandbytes.functional import QuantState, dequantize_blockwise
|
||||
|
||||
def _dequantize_single_state(quant_state):
|
||||
"""Helper function to dequantize a single QuantState object."""
|
||||
if not (isinstance(quant_state, QuantState)
|
||||
and quant_state.nested):
|
||||
return
|
||||
|
||||
# Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
|
||||
absmax = dequantize_blockwise(quant_state.absmax,
|
||||
quant_state.state2)
|
||||
absmax += quant_state.offset
|
||||
|
||||
# Ensure float32 dtype
|
||||
if absmax.dtype != torch.float32:
|
||||
absmax = absmax.float()
|
||||
|
||||
quant_state.absmax = absmax
|
||||
quant_state.nested = False
|
||||
quant_state.offset = None
|
||||
quant_state.state2 = None
|
||||
|
||||
if isinstance(quant_states, dict):
|
||||
for quant_state in quant_states.values():
|
||||
_dequantize_single_state(quant_state)
|
||||
else:
|
||||
_dequantize_single_state(quant_states)
|
||||
return quant_states
|
||||
|
||||
def _fuse_moe_quant_states(self, model: nn.Module,
|
||||
quant_states_dict: dict) -> dict:
|
||||
"""
|
||||
|
||||
This function consolidates individual expert quantization states into
|
||||
fused representations for w13 and w2.
|
||||
"""
|
||||
from bitsandbytes.functional import QuantState
|
||||
|
||||
if not self.expert_params_mapping:
|
||||
return dict()
|
||||
|
||||
expert_mapping = self.expert_params_mapping
|
||||
expert_qs_dict = {}
|
||||
for name, module in model.named_modules():
|
||||
if not isinstance(module, FusedMoE):
|
||||
continue
|
||||
w1_states_lst = []
|
||||
w2_states_lst = []
|
||||
w3_states_lst = []
|
||||
for exp in expert_mapping:
|
||||
shard_id = exp[-1]
|
||||
if shard_id not in ("w1", "w2", "w3"):
|
||||
raise ValueError(f"shard_id must be ['w1','w2','w3'] but "
|
||||
f"got {shard_id}.")
|
||||
layer_prefix = name.split("experts")[0]
|
||||
weight_qual_name = layer_prefix + exp[1] + "weight"
|
||||
quant_state = self._dequantize_dq(
|
||||
quant_states_dict[weight_qual_name])
|
||||
if shard_id == "w1":
|
||||
w1_states_lst.append(quant_state)
|
||||
elif shard_id == "w2":
|
||||
w2_states_lst.append(quant_state)
|
||||
else:
|
||||
w3_states_lst.append(quant_state)
|
||||
del quant_states_dict[weight_qual_name]
|
||||
assert (len(w1_states_lst) == len(w2_states_lst) ==
|
||||
len(w3_states_lst))
|
||||
w13_absmax_lst = []
|
||||
w2_absmax_lst = []
|
||||
w13_total_dim0 = 0
|
||||
w2_total_dim0 = 0
|
||||
for w1_qs, w2_qs, w3_qs in zip(w1_states_lst, w2_states_lst,
|
||||
w3_states_lst):
|
||||
assert w1_qs.shape == w3_qs.shape
|
||||
assert w1_qs.blocksize == w2_qs.blocksize == w3_qs.blocksize
|
||||
assert w1_qs.dtype == w2_qs.dtype == w3_qs.dtype
|
||||
# w1 and w3 are interleaved in storage
|
||||
w13_absmax_lst.append(w1_qs.absmax)
|
||||
w13_absmax_lst.append(w3_qs.absmax)
|
||||
w2_absmax_lst.append(w2_qs.absmax)
|
||||
w13_total_dim0 += w1_qs.shape[0] + w3_qs.shape[0]
|
||||
w2_total_dim0 += w2_qs.shape[0]
|
||||
|
||||
w13_absmax = torch.cat(w13_absmax_lst)
|
||||
w2_absmax = torch.cat(w2_absmax_lst)
|
||||
# Create fused quantization state for w13.
|
||||
w13_qs = QuantState(
|
||||
absmax=w13_absmax,
|
||||
shape=(w13_total_dim0, w1_states_lst[0].shape[1]),
|
||||
code=w1_states_lst[0].code,
|
||||
blocksize=w1_states_lst[0].blocksize,
|
||||
quant_type="nf4",
|
||||
dtype=w1_states_lst[0].dtype,
|
||||
)
|
||||
# Create fused quantization state for w2.
|
||||
w2_qs = QuantState(
|
||||
absmax=w2_absmax,
|
||||
shape=(w2_total_dim0, w2_states_lst[0].shape[1]),
|
||||
code=w2_states_lst[0].code,
|
||||
blocksize=w2_states_lst[0].blocksize,
|
||||
quant_type="nf4",
|
||||
dtype=w2_states_lst[0].dtype,
|
||||
)
|
||||
# The weight suffixes .w13_weight and .w2_weight are consistent
|
||||
# with the param in BitsAndBytesMoEMethod.
|
||||
w13_weight_name = name + ".w13_weight"
|
||||
w2_weight_name = name + ".w2_weight"
|
||||
expert_qs_dict[w13_weight_name] = w13_qs
|
||||
expert_qs_dict[w2_weight_name] = w2_qs
|
||||
return expert_qs_dict
|
||||
|
||||
def _stack_quantization_states(
|
||||
self, model: nn.Module,
|
||||
quant_state_dict: dict) -> dict[str, dict[int, Any]]:
|
||||
stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
|
||||
# TODO: Change this lazy import to normal import
|
||||
# after the checks are updated to run on a new version
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
param_dict = dict(model.named_parameters())
|
||||
for quant_param_name in quant_state_dict:
|
||||
if is_pp_missing_parameter(quant_param_name, model):
|
||||
continue
|
||||
|
||||
non_stacked_param_name = quant_param_name
|
||||
|
||||
shard_index = 0
|
||||
for shard_name, (
|
||||
weight_name,
|
||||
index,
|
||||
) in self.modules_mapping.inverse_packed_mapping.items():
|
||||
# Some models, such as MiniCPM V2.5/2.6, contain both
|
||||
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
|
||||
# from being incorrectly identified as being present in
|
||||
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
|
||||
shard_pos = quant_param_name.find(shard_name)
|
||||
can_correct_rename = (shard_pos
|
||||
> 0) and (quant_param_name[shard_pos - 1]
|
||||
== ".")
|
||||
# If the quant_param_name is packed, it won't occur in the
|
||||
# param_dict before renaming.
|
||||
new_quant_param_name = quant_param_name.replace(
|
||||
shard_name, weight_name)
|
||||
need_rename = (quant_param_name not in param_dict) \
|
||||
and (new_quant_param_name in param_dict)
|
||||
if can_correct_rename and need_rename:
|
||||
shard_index = index
|
||||
quant_param_name = new_quant_param_name
|
||||
break
|
||||
|
||||
# Models like Clip/Siglip may skip some layers in initialization,
|
||||
# causing unused quant_param_name in state_dict.
|
||||
if quant_param_name not in param_dict:
|
||||
continue
|
||||
|
||||
if quant_param_name not in stacked_quant_state_dict:
|
||||
stacked_quant_state_dict[quant_param_name] = {}
|
||||
|
||||
stacked_quant_state_dict[quant_param_name][shard_index] = (
|
||||
quant_state_dict[non_stacked_param_name])
|
||||
return stacked_quant_state_dict
|
||||
|
||||
def _bind_quant_states_to_params(self, model: nn.Module,
|
||||
stacked_quant_state_dict: dict) -> None:
|
||||
# save quant_states and offsets as the attributes of the parameters
|
||||
param_dict = dict(model.named_parameters())
|
||||
for param_name, param in param_dict.items():
|
||||
if param_name in stacked_quant_state_dict:
|
||||
quant_states = stacked_quant_state_dict[param_name]
|
||||
# Dequantize double quantized values during weight loading.
|
||||
self._dequantize_dq(quant_states)
|
||||
set_weight_attrs(param, {"bnb_quant_state": quant_states})
|
||||
if not isinstance(quant_states, dict):
|
||||
continue
|
||||
|
||||
pack_ratio = getattr(param, "pack_factor", -1)
|
||||
if pack_ratio == -1:
|
||||
raise ValueError(
|
||||
f"pack_factor not set for parameter {param_name}.")
|
||||
|
||||
num_elements = [0] * len(quant_states)
|
||||
for seq, quant_state in quant_states.items():
|
||||
num_elements[seq] = (math.prod(quant_state.shape) //
|
||||
pack_ratio)
|
||||
|
||||
offsets = np.concatenate(([0], np.cumsum(num_elements)))
|
||||
# Make torch infer_schema happy
|
||||
offsets = torch.tensor(offsets).cpu()
|
||||
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
|
||||
|
||||
if self.load_8bit:
|
||||
set_weight_attrs(
|
||||
param, {"matmul_state": [None] * len(quant_states)})
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
|
||||
self._verify_model_compatibility(model, model_config)
|
||||
self._initialize_loader_state(model, model_config)
|
||||
|
||||
logger.info("Loading weights with BitsAndBytes quantization. "
|
||||
"May take a while ...")
|
||||
qweight_iterator, quant_state_dict = (
|
||||
self._get_quantized_weights_iterator(
|
||||
model_config.model,
|
||||
model_config.revision,
|
||||
))
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
loaded_weights = model.load_weights(qweight_iterator)
|
||||
# Some models may have weights loading tracker unimplemented.
|
||||
if loaded_weights is not None:
|
||||
weights_not_loaded = weights_to_load - loaded_weights
|
||||
if weights_not_loaded:
|
||||
raise ValueError("Following weights were not initialized from "
|
||||
f"checkpoint: {weights_not_loaded}")
|
||||
expert_quant_state_dict = self._fuse_moe_quant_states(
|
||||
model, quant_state_dict)
|
||||
|
||||
stacked_quant_state_dict = self._stack_quantization_states(
|
||||
model, quant_state_dict)
|
||||
|
||||
stacked_quant_state_dict = {
|
||||
**expert_quant_state_dict,
|
||||
**stacked_quant_state_dict
|
||||
}
|
||||
self._bind_quant_states_to_params(model, stacked_quant_state_dict)
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(model_config.model, model_config.revision)
|
||||
277
vllm/model_executor/model_loader/default_loader.py
Normal file
277
vllm/model_executor/model_loader/default_loader.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import glob
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import Optional, cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
|
||||
filter_files_not_needed_for_inference, maybe_download_from_modelscope,
|
||||
multi_thread_pt_weights_iterator,
|
||||
multi_thread_safetensors_weights_iterator, np_cache_weights_iterator,
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DefaultModelLoader(BaseModelLoader):
|
||||
"""Model loader that can load different file types from disk."""
|
||||
|
||||
# default number of thread when enable multithread weight loading
|
||||
DEFAULT_NUM_THREADS = 8
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Source:
|
||||
"""A source for weights."""
|
||||
|
||||
model_or_path: str
|
||||
"""The model ID or path."""
|
||||
|
||||
revision: Optional[str]
|
||||
"""The optional model revision."""
|
||||
|
||||
prefix: str = ""
|
||||
"""A prefix to prepend to all weights."""
|
||||
|
||||
fall_back_to_pt: bool = True
|
||||
"""Whether .pt weights can be used."""
|
||||
|
||||
allow_patterns_overrides: Optional[list[str]] = None
|
||||
"""If defined, weights will load exclusively using these patterns."""
|
||||
|
||||
counter_before_loading_weights: float = 0.0
|
||||
counter_after_loading_weights: float = 0.0
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
|
||||
extra_config = load_config.model_loader_extra_config
|
||||
allowed_keys = {"enable_multithread_load", "num_threads"}
|
||||
unexpected_keys = set(extra_config.keys()) - allowed_keys
|
||||
|
||||
if unexpected_keys:
|
||||
raise ValueError(f"Unexpected extra config keys for load format "
|
||||
f"{load_config.load_format}: "
|
||||
f"{unexpected_keys}")
|
||||
|
||||
def _prepare_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
revision: Optional[str],
|
||||
fall_back_to_pt: bool,
|
||||
allow_patterns_overrides: Optional[list[str]],
|
||||
) -> tuple[str, list[str], bool]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
model_name_or_path = (maybe_download_from_modelscope(
|
||||
model_name_or_path, revision) or model_name_or_path)
|
||||
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
load_format = self.load_config.load_format
|
||||
use_safetensors = False
|
||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
if load_format == "auto":
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif (load_format == "safetensors"
|
||||
or load_format == "fastsafetensors"):
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == "mistral":
|
||||
use_safetensors = True
|
||||
allow_patterns = ["consolidated*.safetensors"]
|
||||
index_file = "consolidated.safetensors.index.json"
|
||||
elif load_format == "pt":
|
||||
allow_patterns = ["*.pt"]
|
||||
elif load_format == "npcache":
|
||||
allow_patterns = ["*.bin"]
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
|
||||
if fall_back_to_pt:
|
||||
allow_patterns += ["*.pt"]
|
||||
|
||||
if allow_patterns_overrides is not None:
|
||||
allow_patterns = allow_patterns_overrides
|
||||
|
||||
if not is_local:
|
||||
hf_folder = download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
allow_patterns,
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
hf_weights_files: list[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
break
|
||||
|
||||
if use_safetensors:
|
||||
# For models like Mistral-7B-Instruct-v0.3
|
||||
# there are both sharded safetensors files and a consolidated
|
||||
# safetensors file. Using both breaks.
|
||||
# Here, we download the `model.safetensors.index.json` and filter
|
||||
# any files not found in the index.
|
||||
if not is_local:
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path,
|
||||
index_file,
|
||||
self.load_config.download_dir,
|
||||
revision,
|
||||
)
|
||||
hf_weights_files = filter_duplicate_safetensors_files(
|
||||
hf_weights_files, hf_folder, index_file)
|
||||
else:
|
||||
hf_weights_files = filter_files_not_needed_for_inference(
|
||||
hf_weights_files)
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, source: "Source"
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
extra_config = self.load_config.model_loader_extra_config
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
source.model_or_path, source.revision, source.fall_back_to_pt,
|
||||
source.allow_patterns_overrides)
|
||||
if self.load_config.load_format == "npcache":
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
weights_iterator = np_cache_weights_iterator(
|
||||
source.model_or_path,
|
||||
self.load_config.download_dir,
|
||||
hf_folder,
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
elif use_safetensors:
|
||||
if self.load_config.load_format == "fastsafetensors":
|
||||
weights_iterator = fastsafetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
else:
|
||||
if extra_config.get("enable_multithread_load"):
|
||||
weights_iterator = (
|
||||
multi_thread_safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
max_workers=extra_config.get(
|
||||
"num_threads", self.DEFAULT_NUM_THREADS),
|
||||
))
|
||||
else:
|
||||
weights_iterator = safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.safetensors_load_strategy,
|
||||
)
|
||||
else:
|
||||
if extra_config.get("enable_multithread_load"):
|
||||
weights_iterator = multi_thread_pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.pt_load_map_location,
|
||||
max_workers=extra_config.get("num_threads",
|
||||
self.DEFAULT_NUM_THREADS),
|
||||
)
|
||||
else:
|
||||
weights_iterator = pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.pt_load_map_location,
|
||||
)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
from vllm.platforms.tpu import USE_TPU_COMMONS
|
||||
|
||||
if not USE_TPU_COMMONS:
|
||||
# In PyTorch XLA, we should call `torch_xla.sync`
|
||||
# frequently so that not too many ops are accumulated
|
||||
# in the XLA program.
|
||||
import torch_xla
|
||||
|
||||
def _xla_weights_iterator(iterator: Generator):
|
||||
for weights in iterator:
|
||||
yield weights
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
weights_iterator = _xla_weights_iterator(weights_iterator)
|
||||
|
||||
if self.counter_before_loading_weights == 0.0:
|
||||
self.counter_before_loading_weights = time.perf_counter()
|
||||
# Apply the prefix.
|
||||
return ((source.prefix + name, tensor)
|
||||
for (name, tensor) in weights_iterator)
|
||||
|
||||
def get_all_weights(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
model: nn.Module,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
primary_weights = DefaultModelLoader.Source(
|
||||
model_config.model,
|
||||
model_config.revision,
|
||||
prefix="",
|
||||
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
|
||||
True),
|
||||
allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
|
||||
None),
|
||||
)
|
||||
yield from self._get_weights_iterator(primary_weights)
|
||||
|
||||
secondary_weights = cast(
|
||||
Iterable[DefaultModelLoader.Source],
|
||||
getattr(model, "secondary_weights", ()),
|
||||
)
|
||||
for source in secondary_weights:
|
||||
yield from self._get_weights_iterator(source)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(model_config.model,
|
||||
model_config.revision,
|
||||
fall_back_to_pt=True,
|
||||
allow_patterns_overrides=None)
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
loaded_weights = model.load_weights(
|
||||
self.get_all_weights(model_config, model))
|
||||
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info(
|
||||
"Loading weights took %.2f seconds",
|
||||
self.counter_after_loading_weights -
|
||||
self.counter_before_loading_weights)
|
||||
# We only enable strict check for non-quantized models
|
||||
# that have loaded weights tracking currently.
|
||||
if model_config.quantization is None and loaded_weights is not None:
|
||||
weights_not_loaded = weights_to_load - loaded_weights
|
||||
if weights_not_loaded:
|
||||
raise ValueError("Following weights were not initialized from "
|
||||
f"checkpoint: {weights_not_loaded}")
|
||||
28
vllm/model_executor/model_loader/dummy_loader.py
Normal file
28
vllm/model_executor/model_loader/dummy_loader.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
initialize_dummy_weights)
|
||||
|
||||
|
||||
class DummyModelLoader(BaseModelLoader):
|
||||
"""Model loader that will set model weights to random values."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
raise ValueError(f"Model loader extra config is not supported for "
|
||||
f"load format {load_config.load_format}")
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
pass # Nothing to download
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model)
|
||||
155
vllm/model_executor/model_loader/gguf_loader.py
Normal file
155
vllm/model_executor/model_loader/gguf_loader.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model, process_weights_after_loading, set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
get_gguf_extra_tensor_names, get_gguf_weight_type_map,
|
||||
gguf_quant_weights_iterator)
|
||||
|
||||
|
||||
class GGUFModelLoader(BaseModelLoader):
|
||||
"""
|
||||
Model loader that can load GGUF files. This is useful for loading models
|
||||
that are quantized with GGUF and saved in the GGUF format. This loader
|
||||
supports loading both full models and sharded models.
|
||||
"""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
raise ValueError(f"Model loader extra config is not supported for "
|
||||
f"load format {load_config.load_format}")
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str):
|
||||
if os.path.isfile(model_name_or_path):
|
||||
return model_name_or_path
|
||||
# for raw HTTPS link
|
||||
if model_name_or_path.startswith(
|
||||
("http://", "https://")) and model_name_or_path.endswith(".gguf"):
|
||||
return hf_hub_download(url=model_name_or_path)
|
||||
# repo id/filename.gguf
|
||||
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
|
||||
repo_id, filename = model_name_or_path.rsplit("/", 1)
|
||||
return hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unrecognised GGUF reference: {model_name_or_path} "
|
||||
"(expected local file, raw URL, or <repo_id>/<filename>.gguf)")
|
||||
|
||||
def _get_gguf_weights_map(self, model_config: ModelConfig):
|
||||
"""
|
||||
GGUF uses this naming convention for their tensors from HF checkpoint:
|
||||
`blk.N.BB.weight` and `blk.N.BB.bias`
|
||||
where N signifies the block number of a layer, and BB signifies the
|
||||
attention/mlp layer components.
|
||||
See "Standardized tensor names" in
|
||||
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
|
||||
"""
|
||||
config = model_config.hf_config
|
||||
model_type = config.model_type
|
||||
gguf_to_hf_name_map = {}
|
||||
# hack: ggufs have a different name than transformers
|
||||
if model_type == "cohere":
|
||||
model_type = "command-r"
|
||||
if model_type in ("deepseek_v3", "deepseek_v2"):
|
||||
model_type = "deepseek2"
|
||||
# GGUF layer map assumes that we will have a merged expert weights
|
||||
# so we need to map them manually
|
||||
for idx in range(config.num_hidden_layers):
|
||||
gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \
|
||||
f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
|
||||
if model_type in ("qwen2_moe", "qwen3_moe"):
|
||||
model_type = model_type.replace("_", "")
|
||||
# GGUF layer map assumes that we will have a merged expert weights
|
||||
# so we need to map them manually
|
||||
for idx in range(config.num_hidden_layers):
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
|
||||
|
||||
arch = None
|
||||
for key, value in gguf.MODEL_ARCH_NAMES.items():
|
||||
if value == model_type:
|
||||
arch = key
|
||||
break
|
||||
if arch is None:
|
||||
raise RuntimeError(f"Unknown gguf model_type: {model_type}")
|
||||
num_layers = config.num_hidden_layers
|
||||
name_map = gguf.get_tensor_name_map(arch, num_layers)
|
||||
with torch.device("meta"):
|
||||
dummy_model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=model_config.trust_remote_code)
|
||||
state_dict = dummy_model.state_dict()
|
||||
|
||||
for hf_name in state_dict:
|
||||
name, suffix = hf_name.rsplit(".", 1)
|
||||
gguf_name = name_map.get_name(name)
|
||||
gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
|
||||
return gguf_to_hf_name_map
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str]
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
return gguf_quant_weights_iterator(model_name_or_path,
|
||||
gguf_to_hf_name_map)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(model_config.model)
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
local_model_path = self._prepare_weights(model_config.model)
|
||||
gguf_weights_map = self._get_gguf_weights_map(model_config)
|
||||
model.load_weights(
|
||||
self._get_weights_iterator(local_model_path, gguf_weights_map))
|
||||
|
||||
def load_model(self, vllm_config: VllmConfig,
|
||||
model_config: ModelConfig) -> nn.Module:
|
||||
device_config = vllm_config.device_config
|
||||
local_model_path = self._prepare_weights(model_config.model)
|
||||
gguf_weights_map = self._get_gguf_weights_map(model_config)
|
||||
# we can only know if tie word embeddings after mapping weights
|
||||
if "lm_head.weight" in get_gguf_extra_tensor_names(
|
||||
local_model_path, gguf_weights_map):
|
||||
model_config.hf_config.update({"tie_word_embeddings": True})
|
||||
|
||||
weight_type_map = get_gguf_weight_type_map(model_config.model,
|
||||
gguf_weights_map)
|
||||
|
||||
# filter out unquantized modules to skip
|
||||
unquant_names = [
|
||||
name.removesuffix(".weight")
|
||||
for name, weight_type in weight_type_map.items()
|
||||
if weight_type == "F32" and name.endswith(".weight")
|
||||
]
|
||||
vllm_config.quant_config.unquantized_modules.extend(unquant_names)
|
||||
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
self.load_weights(model, model_config)
|
||||
|
||||
process_weights_after_loading(model, model_config, target_device)
|
||||
return model
|
||||
104
vllm/model_executor/model_loader/runai_streamer_loader.py
Normal file
104
vllm/model_executor/model_loader/runai_streamer_loader.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
runai_safetensors_weights_iterator)
|
||||
from vllm.transformers_utils.runai_utils import (is_runai_obj_uri,
|
||||
list_safetensors)
|
||||
|
||||
|
||||
class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
"""
|
||||
Model loader that can load safetensors
|
||||
files from local FS or S3 bucket.
|
||||
"""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
extra_config = load_config.model_loader_extra_config
|
||||
|
||||
if ("concurrency" in extra_config
|
||||
and isinstance(extra_config.get("concurrency"), int)):
|
||||
os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
|
||||
extra_config.get("concurrency"))
|
||||
|
||||
if ("memory_limit" in extra_config
|
||||
and isinstance(extra_config.get("memory_limit"), int)):
|
||||
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
|
||||
extra_config.get("memory_limit"))
|
||||
|
||||
runai_streamer_s3_endpoint = os.getenv(
|
||||
'RUNAI_STREAMER_S3_ENDPOINT')
|
||||
aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
|
||||
if (runai_streamer_s3_endpoint is None
|
||||
and aws_endpoint_url is not None):
|
||||
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str]) -> list[str]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
|
||||
is_object_storage_path = is_runai_obj_uri(model_name_or_path)
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
safetensors_pattern = "*.safetensors"
|
||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
hf_folder = (model_name_or_path if (is_local or is_object_storage_path)
|
||||
else download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
[safetensors_pattern],
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
))
|
||||
hf_weights_files = list_safetensors(path=hf_folder)
|
||||
|
||||
if not is_local and not is_object_storage_path:
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path, index_file, self.load_config.download_dir,
|
||||
revision)
|
||||
|
||||
if not hf_weights_files:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any safetensors model weights with "
|
||||
f"`{model_name_or_path}`")
|
||||
|
||||
return hf_weights_files
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, model_or_path: str,
|
||||
revision: str) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_weights_files = self._prepare_weights(model_or_path, revision)
|
||||
return runai_safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
"""Download model if necessary"""
|
||||
self._prepare_weights(model_config.model, model_config.revision)
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
"""Load weights into a model."""
|
||||
model_weights = model_config.model
|
||||
if hasattr(model_config, "model_weights"):
|
||||
model_weights = model_config.model_weights
|
||||
model.load_weights(
|
||||
self._get_weights_iterator(model_weights, model_config.revision))
|
||||
199
vllm/model_executor/model_loader/sharded_state_loader.py
Normal file
199
vllm/model_executor/model_loader/sharded_state_loader.py
Normal file
@@ -0,0 +1,199 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import collections
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_weights_from_hf, runai_safetensors_weights_iterator)
|
||||
from vllm.transformers_utils.s3_utils import glob as s3_glob
|
||||
from vllm.transformers_utils.utils import is_s3
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ShardedStateLoader(BaseModelLoader):
|
||||
"""
|
||||
Model loader that directly loads each worker's model state dict, which
|
||||
enables a fast load path for large tensor-parallel models where each worker
|
||||
only needs to read its own shard rather than the entire checkpoint. See
|
||||
`examples/offline_inference/save_sharded_state.py` for creating a sharded
|
||||
checkpoint.
|
||||
"""
|
||||
|
||||
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
|
||||
extra_config = ({} if load_config.model_loader_extra_config is None
|
||||
else load_config.model_loader_extra_config.copy())
|
||||
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
|
||||
if extra_config:
|
||||
raise ValueError(f"Unexpected extra config keys for load format "
|
||||
f"{load_config.load_format}: "
|
||||
f"{load_config.model_loader_extra_config.keys()}")
|
||||
|
||||
@staticmethod
|
||||
def _filter_subtensors(
|
||||
tensors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Filter out all tensors that share the same memory or a subset of the
|
||||
memory of another tensor.
|
||||
"""
|
||||
same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = (
|
||||
collections.defaultdict(list))
|
||||
for key, tensor in tensors.items():
|
||||
if tensor.numel():
|
||||
ptr = tensor.untyped_storage().data_ptr()
|
||||
same_storage_groups[tensor.device, ptr].append((key, tensor))
|
||||
|
||||
def get_end_ptr(tensor: torch.Tensor) -> int:
|
||||
return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
|
||||
|
||||
result: dict[str, torch.Tensor] = {}
|
||||
for group in same_storage_groups.values():
|
||||
for k, t in group:
|
||||
a, b = t.data_ptr(), get_end_ptr(t)
|
||||
for k2, t2 in group:
|
||||
if not t2.is_contiguous():
|
||||
continue
|
||||
a2, b2 = t2.data_ptr(), get_end_ptr(t2)
|
||||
if a < a2 or b2 < b:
|
||||
continue
|
||||
if a2 < a or b < b2 or not t.is_contiguous():
|
||||
break # t2 covers strictly more memory than t.
|
||||
if k2 < k:
|
||||
# Same tensors, keep the one with the smaller key.
|
||||
break
|
||||
else:
|
||||
result[k] = t
|
||||
return result
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str]):
|
||||
if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
|
||||
return model_name_or_path
|
||||
else:
|
||||
allow_patterns = ["*.safetensors"]
|
||||
return download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
allow_patterns,
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(model_config.model, model_config.revision)
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
model_weights = model_config.model
|
||||
if hasattr(model_config, "model_weights"):
|
||||
model_weights = model_config.model_weights
|
||||
local_model_path = model_weights
|
||||
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
pattern = os.path.join(
|
||||
local_model_path,
|
||||
self.pattern.format(rank=rank, part="*"),
|
||||
)
|
||||
|
||||
filepaths = []
|
||||
if is_s3(local_model_path):
|
||||
file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}"
|
||||
filepaths = s3_glob(path=local_model_path,
|
||||
allow_pattern=[file_pattern])
|
||||
else:
|
||||
filepaths = glob.glob(pattern)
|
||||
if not filepaths:
|
||||
# TODO: support un-sharded checkpoints too
|
||||
raise ValueError(
|
||||
f"Could not find checkpoint files '{pattern}', only "
|
||||
f"pre-sharded checkpoints are currently supported!")
|
||||
state_dict = self._filter_subtensors(model.state_dict())
|
||||
for key, tensor in self.iterate_over_files(filepaths):
|
||||
# If loading with LoRA enabled, additional padding may
|
||||
# be added to certain parameters. We only load into a
|
||||
# narrowed view of the parameter data.
|
||||
param_data = state_dict[key].data
|
||||
param_shape = state_dict[key].shape
|
||||
for dim, size in enumerate(tensor.shape):
|
||||
if size < param_shape[dim]:
|
||||
param_data = param_data.narrow(dim, 0, size)
|
||||
if tensor.shape != param_shape:
|
||||
logger.warning(
|
||||
"loading tensor of shape %s into "
|
||||
"parameter '%s' of shape %s",
|
||||
tensor.shape,
|
||||
key,
|
||||
param_shape,
|
||||
)
|
||||
param_data.copy_(tensor)
|
||||
state_dict.pop(key)
|
||||
if state_dict:
|
||||
raise ValueError(
|
||||
f"Missing keys {tuple(state_dict)} in loaded state!")
|
||||
|
||||
def iterate_over_files(
|
||||
self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
if self.load_config.load_format == "runai_streamer_sharded":
|
||||
yield from runai_safetensors_weights_iterator(paths, True)
|
||||
else:
|
||||
from safetensors.torch import safe_open
|
||||
for path in paths:
|
||||
with safe_open(path, framework="pt") as f:
|
||||
for key in f.keys(): # noqa: SIM118
|
||||
tensor = f.get_tensor(key)
|
||||
yield key, tensor
|
||||
|
||||
@staticmethod
|
||||
def save_model(
|
||||
model: torch.nn.Module,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
if pattern is None:
|
||||
pattern = ShardedStateLoader.DEFAULT_PATTERN
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
part_idx = 0
|
||||
total_size = 0
|
||||
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
|
||||
state_dict_part: dict[str, torch.Tensor] = {}
|
||||
for key, tensor in state_dict.items():
|
||||
param_size = tensor.nelement() * tensor.element_size()
|
||||
if max_size is not None and total_size + param_size > max_size:
|
||||
filename = pattern.format(rank=rank, part=part_idx)
|
||||
save_file(
|
||||
state_dict_part,
|
||||
os.path.join(path, filename),
|
||||
)
|
||||
part_idx += 1
|
||||
total_size = 0
|
||||
state_dict_part = {}
|
||||
state_dict_part[key] = tensor
|
||||
total_size += param_size
|
||||
if len(state_dict_part) > 0:
|
||||
filename = pattern.format(rank=rank, part=part_idx)
|
||||
save_file(
|
||||
state_dict_part,
|
||||
os.path.join(path, filename),
|
||||
)
|
||||
738
vllm/model_executor/model_loader/tensorizer.py
Normal file
738
vllm/model_executor/model_loader/tensorizer.py
Normal file
@@ -0,0 +1,738 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import contextvars
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator, MutableMapping
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Optional, Union
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from torch import nn
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import FlexibleArgumentParser, PlaceholderModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
|
||||
try:
|
||||
from tensorizer import (DecryptionParams, EncryptionParams,
|
||||
TensorDeserializer, TensorSerializer)
|
||||
from tensorizer.stream_io import open_stream
|
||||
from tensorizer.utils import (convert_bytes, get_mem_usage,
|
||||
no_init_or_tensor)
|
||||
|
||||
except ImportError:
|
||||
tensorizer = PlaceholderModule("tensorizer")
|
||||
DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
|
||||
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
|
||||
TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer")
|
||||
TensorSerializer = tensorizer.placeholder_attr("TensorSerializer")
|
||||
open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
|
||||
convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
|
||||
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
|
||||
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
|
||||
|
||||
__all__ = [
|
||||
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
|
||||
'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
|
||||
'no_init_or_tensor', 'TensorizerConfig'
|
||||
]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def is_valid_deserialization_uri(uri: Optional[str]) -> bool:
|
||||
if uri:
|
||||
scheme = uri.lower().split("://")[0]
|
||||
return scheme in {"s3", "http", "https"} or os.path.exists(uri)
|
||||
return False
|
||||
|
||||
|
||||
def tensorizer_kwargs_arg(value):
|
||||
loaded = json.loads(value)
|
||||
if not isinstance(loaded, dict):
|
||||
raise argparse.ArgumentTypeError(
|
||||
f"Not deserializable to dict: {value}. serialization_kwargs and "
|
||||
f"deserialization_kwargs must be "
|
||||
f"deserializable from a JSON string to a dictionary. ")
|
||||
return loaded
|
||||
|
||||
|
||||
class MetaTensorMode(TorchDispatchMode):
|
||||
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
if func._schema.name == "aten::empty" and "device" not in kwargs:
|
||||
kwargs["device"] = "meta"
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def meta_tensor_mode(loading_code=None, ):
|
||||
|
||||
if loading_code is None:
|
||||
return _NoInitOrTensorImpl.context_manager()
|
||||
elif callable(loading_code):
|
||||
with _NoInitOrTensorImpl.context_manager():
|
||||
return loading_code()
|
||||
else:
|
||||
raise TypeError(
|
||||
"expected a callable to evaluate,"
|
||||
" or None if being used as a context manager;"
|
||||
f' got an object of type "{type(loading_code).__name__}" instead.')
|
||||
|
||||
|
||||
class _NoInitOrTensorImpl:
|
||||
_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
|
||||
_MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)
|
||||
|
||||
is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active",
|
||||
default=False)
|
||||
_count_active: int = 0
|
||||
_count_active_lock = threading.Lock()
|
||||
|
||||
@classmethod
|
||||
@contextlib.contextmanager
|
||||
def context_manager(cls):
|
||||
if cls.is_active.get():
|
||||
yield
|
||||
return
|
||||
|
||||
with cls._count_active_lock:
|
||||
cls._count_active += 1
|
||||
if cls._count_active == 1:
|
||||
for mod in cls._MODULES:
|
||||
mod.reset_parameters = cls._disable(mod.reset_parameters)
|
||||
|
||||
reset_token = cls.is_active.set(True)
|
||||
|
||||
try:
|
||||
with MetaTensorMode():
|
||||
yield
|
||||
finally:
|
||||
cls.is_active.reset(reset_token)
|
||||
with cls._count_active_lock:
|
||||
cls._count_active -= 1
|
||||
if cls._count_active == 0:
|
||||
for mod, original in cls._MODULE_ORIGINALS:
|
||||
mod.reset_parameters = original
|
||||
|
||||
@staticmethod
|
||||
def _disable(func):
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if not _NoInitOrTensorImpl.is_active.get():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerConfig(MutableMapping):
|
||||
tensorizer_uri: Optional[str] = None
|
||||
tensorizer_dir: Optional[str] = None
|
||||
vllm_tensorized: Optional[bool] = None
|
||||
verify_hash: Optional[bool] = None
|
||||
num_readers: Optional[int] = None
|
||||
encryption_keyfile: Optional[str] = None
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
lora_dir: Optional[str] = None
|
||||
stream_kwargs: Optional[dict[str, Any]] = None
|
||||
serialization_kwargs: Optional[dict[str, Any]] = None
|
||||
deserialization_kwargs: Optional[dict[str, Any]] = None
|
||||
_extra_serialization_attrs: Optional[dict[str, Any]] = field(init=False,
|
||||
default=None)
|
||||
model_class: Optional[type[torch.nn.Module]] = field(init=False,
|
||||
default=None)
|
||||
hf_config: Optional[PretrainedConfig] = field(init=False, default=None)
|
||||
dtype: Optional[Union[str, torch.dtype]] = field(init=False, default=None)
|
||||
_is_sharded: bool = field(init=False, default=False)
|
||||
_fields: ClassVar[tuple[str, ...]]
|
||||
_keys: ClassVar[frozenset[str]]
|
||||
"""Configuration class for Tensorizer settings.
|
||||
|
||||
These settings configure the behavior of model serialization and
|
||||
deserialization using Tensorizer.
|
||||
|
||||
Attributes:
|
||||
tensorizer_uri: Path to serialized model tensors. Can be a local file
|
||||
path or a S3 URI. This is a required field unless lora_dir is
|
||||
provided and the config is meant to be used for the
|
||||
`tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
|
||||
`lora_dir` is passed to this object's initializer, this is
|
||||
a required argument.
|
||||
tensorizer_dir: Path to a directory containing serialized model tensors,
|
||||
and all other potential model artifacts to load the model, such as
|
||||
configs and tokenizer files. Can be passed instead of
|
||||
`tensorizer_uri` where the `model.tensors` file will be assumed
|
||||
to be in this directory.
|
||||
vllm_tensorized: If True, indicates that the serialized model is a
|
||||
vLLM model. This is used to determine the behavior of the
|
||||
TensorDeserializer when loading tensors from a serialized model.
|
||||
It is far faster to deserialize a vLLM model as it utilizes
|
||||
tensorizer's optimized GPU loading. Note that this is now
|
||||
deprecated, as serialized vLLM models are now automatically
|
||||
inferred as vLLM models.
|
||||
verify_hash: If True, the hashes of each tensor will be verified
|
||||
against the hashes stored in the metadata. A `HashMismatchError`
|
||||
will be raised if any of the hashes do not match.
|
||||
num_readers: Controls how many threads are allowed to read concurrently
|
||||
from the source file. Default is `None`, which will dynamically set
|
||||
the number of readers based on the number of available
|
||||
resources and model size. This greatly increases performance.
|
||||
encryption_keyfile: File path to a binary file containing a
|
||||
binary key to use for decryption. `None` (the default) means
|
||||
no decryption. See the example script in
|
||||
examples/others/tensorize_vllm_model.py.
|
||||
s3_access_key_id: The access key for the S3 bucket. Can also be set via
|
||||
the S3_ACCESS_KEY_ID environment variable.
|
||||
s3_secret_access_key: The secret access key for the S3 bucket. Can also
|
||||
be set via the S3_SECRET_ACCESS_KEY environment variable.
|
||||
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
|
||||
S3_ENDPOINT_URL environment variable.
|
||||
lora_dir: Path to a directory containing LoRA adapter artifacts for
|
||||
serialization or deserialization. When serializing LoRA adapters
|
||||
this is the only necessary parameter to pass to this object's
|
||||
initializer.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
# check if the configuration is for a sharded vLLM model
|
||||
self._is_sharded = isinstance(self.tensorizer_uri, str) \
|
||||
and re.search(r'%0\dd', self.tensorizer_uri) is not None
|
||||
|
||||
if self.tensorizer_dir and self.lora_dir:
|
||||
raise ValueError(
|
||||
"Only one of tensorizer_dir or lora_dir may be specified. "
|
||||
"Use lora_dir exclusively when serializing LoRA adapters, "
|
||||
"and tensorizer_dir or tensorizer_uri otherwise.")
|
||||
if self.tensorizer_dir and self.tensorizer_uri:
|
||||
logger.warning_once(
|
||||
"Provided both tensorizer_dir and tensorizer_uri. "
|
||||
"Inferring tensorizer_dir from tensorizer_uri as the "
|
||||
"latter takes precedence.")
|
||||
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
|
||||
if not self.tensorizer_uri:
|
||||
if self.lora_dir:
|
||||
self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
|
||||
elif self.tensorizer_dir:
|
||||
self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors"
|
||||
else:
|
||||
raise ValueError("Unable to resolve tensorizer_uri. "
|
||||
"A valid tensorizer_uri or tensorizer_dir "
|
||||
"must be provided for deserialization, and a "
|
||||
"valid tensorizer_uri, tensorizer_uri, or "
|
||||
"lora_dir for serialization.")
|
||||
else:
|
||||
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
|
||||
|
||||
if not self.serialization_kwargs:
|
||||
self.serialization_kwargs = {}
|
||||
if not self.deserialization_kwargs:
|
||||
self.deserialization_kwargs = {}
|
||||
|
||||
def to_serializable(self) -> dict[str, Any]:
|
||||
# Due to TensorizerConfig needing to be msgpack-serializable, it needs
|
||||
# support for morphing back and forth between itself and its dict
|
||||
# representation
|
||||
|
||||
# TensorizerConfig's representation as a dictionary is meant to be
|
||||
# linked to TensorizerConfig in such a way that the following is
|
||||
# technically initializable:
|
||||
# TensorizerConfig(**my_tensorizer_cfg.to_serializable())
|
||||
|
||||
# This means the dict must not retain non-initializable parameters
|
||||
# and post-init attribute states
|
||||
|
||||
# Also don't want to retain private and unset parameters, so only retain
|
||||
# not None values and public attributes
|
||||
|
||||
raw_tc_dict = asdict(self)
|
||||
blacklisted = []
|
||||
|
||||
if "tensorizer_uri" in raw_tc_dict and "tensorizer_dir" in raw_tc_dict:
|
||||
blacklisted.append("tensorizer_dir")
|
||||
|
||||
if "tensorizer_dir" in raw_tc_dict and "lora_dir" in raw_tc_dict:
|
||||
blacklisted.append("tensorizer_dir")
|
||||
|
||||
tc_dict = {}
|
||||
for k, v in raw_tc_dict.items():
|
||||
if (k not in blacklisted and k not in tc_dict
|
||||
and not k.startswith("_") and v is not None):
|
||||
tc_dict[k] = v
|
||||
|
||||
return tc_dict
|
||||
|
||||
def _construct_tensorizer_args(self) -> "TensorizerArgs":
|
||||
return TensorizerArgs(self) # type: ignore
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
) -> None:
|
||||
if parallel_config.tensor_parallel_size > 1 \
|
||||
and not self._is_sharded:
|
||||
raise ValueError(
|
||||
"For a sharded model, tensorizer_uri should include a"
|
||||
" string format template like '%04d' to be formatted"
|
||||
" with the rank of the shard")
|
||||
|
||||
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
|
||||
if (model_config.quantization is not None
|
||||
and self.tensorizer_uri is not None):
|
||||
logger.warning(
|
||||
"Loading a model using Tensorizer with quantization on vLLM"
|
||||
" is unstable and may lead to errors.")
|
||||
|
||||
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
|
||||
if tensorizer_args is None:
|
||||
tensorizer_args = self._construct_tensorizer_args()
|
||||
|
||||
return open_stream(self.tensorizer_uri,
|
||||
**tensorizer_args.stream_kwargs)
|
||||
|
||||
def keys(self):
|
||||
return self._keys
|
||||
|
||||
def __len__(self):
|
||||
return len(fields(self))
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self._fields)
|
||||
|
||||
def __getitem__(self, item: str) -> Any:
|
||||
if item not in self.keys():
|
||||
raise KeyError(item)
|
||||
return getattr(self, item)
|
||||
|
||||
def __setitem__(self, key: str, value: Any) -> None:
|
||||
if key not in self.keys():
|
||||
# Disallow modifying invalid keys
|
||||
raise KeyError(key)
|
||||
setattr(self, key, value)
|
||||
|
||||
def __delitem__(self, key, /):
|
||||
if key not in self.keys():
|
||||
raise KeyError(key)
|
||||
delattr(self, key)
|
||||
|
||||
|
||||
TensorizerConfig._fields = tuple(f.name for f in fields(TensorizerConfig))
|
||||
TensorizerConfig._keys = frozenset(TensorizerConfig._fields)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerArgs:
|
||||
tensorizer_uri: Optional[str] = None
|
||||
tensorizer_dir: Optional[str] = None
|
||||
encryption_keyfile: Optional[str] = None
|
||||
|
||||
def __init__(self, tensorizer_config: TensorizerConfig):
|
||||
for k, v in tensorizer_config.items():
|
||||
setattr(self, k, v)
|
||||
self.file_obj = tensorizer_config.tensorizer_uri
|
||||
self.s3_access_key_id = (tensorizer_config.s3_access_key_id
|
||||
or envs.S3_ACCESS_KEY_ID)
|
||||
self.s3_secret_access_key = (tensorizer_config.s3_secret_access_key
|
||||
or envs.S3_SECRET_ACCESS_KEY)
|
||||
self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL
|
||||
|
||||
self.stream_kwargs = {
|
||||
"s3_access_key_id": tensorizer_config.s3_access_key_id,
|
||||
"s3_secret_access_key": tensorizer_config.s3_secret_access_key,
|
||||
"s3_endpoint": tensorizer_config.s3_endpoint,
|
||||
**(tensorizer_config.stream_kwargs or {})
|
||||
}
|
||||
|
||||
self.deserialization_kwargs = {
|
||||
"verify_hash": tensorizer_config.verify_hash,
|
||||
"encryption": tensorizer_config.encryption_keyfile,
|
||||
"num_readers": tensorizer_config.num_readers,
|
||||
**(tensorizer_config.deserialization_kwargs or {})
|
||||
}
|
||||
|
||||
if self.encryption_keyfile:
|
||||
with open_stream(
|
||||
tensorizer_config.encryption_keyfile,
|
||||
**self.stream_kwargs,
|
||||
) as stream:
|
||||
key = stream.read()
|
||||
decryption_params = DecryptionParams.from_key(key)
|
||||
self.deserialization_kwargs['encryption'] = decryption_params
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Tensorizer CLI arguments"""
|
||||
|
||||
# Tensorizer options arg group
|
||||
group = parser.add_argument_group(
|
||||
'tensorizer options',
|
||||
description=('Options for configuring the behavior of the'
|
||||
' tensorizer deserializer when '
|
||||
'load_format=tensorizer is specified when '
|
||||
'initializing an LLMEngine, either via the CLI '
|
||||
'when running the vLLM OpenAI inference server '
|
||||
'with a JSON string passed to '
|
||||
'--model-loader-extra-config or as arguments given '
|
||||
'to TensorizerConfig when passed to '
|
||||
'model_loader_extra_config in the constructor '
|
||||
'for LLMEngine.'))
|
||||
|
||||
group.add_argument(
|
||||
"--tensorizer-uri",
|
||||
type=str,
|
||||
help="Path to serialized model tensors. Can be a local file path,"
|
||||
" or an HTTP(S) or S3 URI.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--verify-hash",
|
||||
action="store_true",
|
||||
help="If enabled, the hashes of each tensor will be verified"
|
||||
" against the hashes stored in the file metadata. An exception"
|
||||
" will be raised if any of the hashes do not match.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--encryption-keyfile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file path to a binary file containing a binary key to "
|
||||
"use for decryption. Can be a file path or S3 network URI.")
|
||||
group.add_argument(
|
||||
"--num-readers",
|
||||
default=None,
|
||||
type=int,
|
||||
help="Controls how many threads are allowed to read concurrently "
|
||||
"from the source file. Default is `None`, which will dynamically "
|
||||
"set the number of readers based on the available resources "
|
||||
"and model size. This greatly increases performance.")
|
||||
group.add_argument(
|
||||
"--s3-access-key-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The access key for the S3 bucket. Can also be set via the "
|
||||
"S3_ACCESS_KEY_ID environment variable.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--s3-secret-access-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The secret access key for the S3 bucket. Can also be set via "
|
||||
"the S3_SECRET_ACCESS_KEY environment variable.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--s3-endpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The endpoint for the S3 bucket. Can also be set via the "
|
||||
"S3_ENDPOINT_URL environment variable.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
|
||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||
tensorizer_args = cls(**{
|
||||
attr: getattr(args, attr)
|
||||
for attr in attrs if hasattr(args, attr)
|
||||
})
|
||||
return tensorizer_args
|
||||
|
||||
|
||||
def _check_tensors_on_meta_device(model: nn.Module) -> None:
|
||||
for tensor in model.state_dict().values():
|
||||
if tensor.device.type == 'meta':
|
||||
raise ValueError(
|
||||
"The serialized model contains tensors on the meta device,"
|
||||
" indicating that some tensors were not loaded properly."
|
||||
" Please check that the parameters of the model being"
|
||||
" specified match that of the serialized model, such as"
|
||||
" its quantization.")
|
||||
|
||||
|
||||
def _resize_lora_embeddings(model: nn.Module):
|
||||
"""Modify LoRA embedding layers to use bigger tensors
|
||||
to allow for adapter added tokens."""
|
||||
for child in model.modules():
|
||||
if (isinstance(child, VocabParallelEmbedding) and child.weight.shape[0]
|
||||
< child.num_embeddings_per_partition):
|
||||
new_weight = torch.empty(child.num_embeddings_per_partition,
|
||||
child.embedding_dim,
|
||||
dtype=child.weight.dtype,
|
||||
device=child.weight.device)
|
||||
new_weight[:child.weight.shape[0]].copy_(child.weight.data)
|
||||
new_weight[child.weight.shape[0]:].fill_(0)
|
||||
child.weight.data = new_weight
|
||||
|
||||
|
||||
def init_tensorizer_model(tensorizer_config: TensorizerConfig,
|
||||
vllm_config: VllmConfig) -> nn.Module:
|
||||
assert tensorizer_config.hf_config is not None
|
||||
model_args = tensorizer_config.hf_config
|
||||
model_args.torch_dtype = tensorizer_config.dtype
|
||||
assert tensorizer_config.model_class is not None
|
||||
# TODO: Do we need to consider old-style model class?
|
||||
with meta_tensor_mode(), set_current_vllm_config(vllm_config,
|
||||
check_compile=True):
|
||||
return tensorizer_config.model_class(vllm_config=vllm_config)
|
||||
|
||||
|
||||
def deserialize_tensorizer_model(model: nn.Module,
|
||||
tensorizer_config: TensorizerConfig) -> None:
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
|
||||
raise ValueError(
|
||||
f"{tensorizer_config.tensorizer_uri} is not a valid "
|
||||
f"tensorizer URI. Please check that the URI is correct. "
|
||||
f"It must either point to a local existing file, or have a "
|
||||
f"S3, HTTP or HTTPS scheme.")
|
||||
before_mem = get_mem_usage()
|
||||
start = time.perf_counter()
|
||||
with open_stream(
|
||||
tensorizer_config.tensorizer_uri,
|
||||
mode="rb",
|
||||
**tensorizer_args.stream_kwargs) as stream, TensorDeserializer(
|
||||
stream,
|
||||
dtype=tensorizer_config.dtype,
|
||||
device=f'xpu:{torch.xpu.current_device()}'
|
||||
if current_platform.is_xpu() else
|
||||
f'cuda:{torch.cuda.current_device()}',
|
||||
**tensorizer_args.deserialization_kwargs) as deserializer:
|
||||
deserializer.load_into_module(model)
|
||||
end = time.perf_counter()
|
||||
|
||||
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
|
||||
duration = end - start
|
||||
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
|
||||
after_mem = get_mem_usage()
|
||||
deserializer.close()
|
||||
logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
|
||||
end - start, per_second)
|
||||
logger.info("Memory usage before: %s", before_mem)
|
||||
logger.info("Memory usage after: %s", after_mem)
|
||||
|
||||
_check_tensors_on_meta_device(model)
|
||||
_resize_lora_embeddings(model)
|
||||
del model.vllm_tensorized_marker
|
||||
|
||||
|
||||
def tensorizer_weights_iterator(
|
||||
tensorizer_args: "TensorizerArgs"
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
logger.warning("Deserializing HuggingFace models is not optimized for "
|
||||
"loading on vLLM, as tensorizer is forced to load to CPU. "
|
||||
"Consider deserializing a vLLM model instead for faster "
|
||||
"load times. See the "
|
||||
"examples/others/tensorize_vllm_model.py example script "
|
||||
"for serializing vLLM models.")
|
||||
|
||||
deserializer_args = tensorizer_args.deserialization_kwargs
|
||||
stream_kwargs = tensorizer_args.stream_kwargs
|
||||
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs)
|
||||
with TensorDeserializer(stream, **deserializer_args,
|
||||
device="cpu") as state:
|
||||
yield from state.items()
|
||||
del state
|
||||
|
||||
|
||||
def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
|
||||
"""
|
||||
Infer if the model is a vLLM model by checking the weights for
|
||||
a vLLM tensorized marker.
|
||||
|
||||
Args:
|
||||
tensorizer_config: The TensorizerConfig object containing the
|
||||
tensorizer_uri to the serialized model.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is a vLLM model, False otherwise.
|
||||
"""
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
deserializer = TensorDeserializer(open_stream(
|
||||
tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs),
|
||||
**tensorizer_args.deserialization_kwargs,
|
||||
lazy_load=True)
|
||||
if tensorizer_config.vllm_tensorized:
|
||||
logger.warning(
|
||||
"Please note that newly serialized vLLM models are automatically "
|
||||
"inferred as vLLM models, so setting vllm_tensorized=True is "
|
||||
"only necessary for models serialized prior to this change.")
|
||||
return True
|
||||
return ".vllm_tensorized_marker" in deserializer
|
||||
|
||||
|
||||
def serialize_extra_artifacts(
|
||||
tensorizer_args: TensorizerArgs,
|
||||
served_model_name: Union[str, list[str], None]) -> None:
|
||||
if not isinstance(served_model_name, str):
|
||||
raise ValueError(
|
||||
f"served_model_name must be a str for serialize_extra_artifacts, "
|
||||
f"not {type(served_model_name)}.")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
snapshot_download(served_model_name,
|
||||
local_dir=tmpdir,
|
||||
ignore_patterns=[
|
||||
"*.pt", "*.safetensors", "*.bin", "*.cache",
|
||||
"*.gitattributes", "*.md"
|
||||
])
|
||||
for artifact in os.scandir(tmpdir):
|
||||
if not artifact.is_file():
|
||||
continue
|
||||
with open(artifact.path, "rb") as f, open_stream(
|
||||
f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
|
||||
mode="wb+",
|
||||
**tensorizer_args.stream_kwargs) as stream:
|
||||
logger.info("Writing artifact %s", artifact.name)
|
||||
stream.write(f.read())
|
||||
|
||||
|
||||
def serialize_vllm_model(
|
||||
model: nn.Module,
|
||||
tensorizer_config: TensorizerConfig,
|
||||
model_config: "ModelConfig",
|
||||
) -> nn.Module:
|
||||
model.register_parameter(
|
||||
"vllm_tensorized_marker",
|
||||
nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False))
|
||||
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
|
||||
encryption_params = None
|
||||
if (keyfile := tensorizer_config.encryption_keyfile) is not None:
|
||||
with open(keyfile, "rb") as f:
|
||||
key = f.read()
|
||||
encryption_params = EncryptionParams(key=key)
|
||||
|
||||
output_file = tensorizer_args.tensorizer_uri
|
||||
if tensorizer_config._is_sharded:
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
output_file = output_file % get_tensor_model_parallel_rank()
|
||||
|
||||
with open_stream(output_file, mode="wb+",
|
||||
**tensorizer_args.stream_kwargs) as stream:
|
||||
serializer = TensorSerializer(stream,
|
||||
encryption=encryption_params,
|
||||
**tensorizer_config.serialization_kwargs)
|
||||
serializer.write_module(model)
|
||||
serializer.close()
|
||||
|
||||
serialize_extra_artifacts(tensorizer_args, model_config.served_model_name)
|
||||
|
||||
logger.info("Successfully serialized model to %s", str(output_file))
|
||||
return model
|
||||
|
||||
|
||||
def tensorize_vllm_model(engine_args: "EngineArgs",
|
||||
tensorizer_config: TensorizerConfig,
|
||||
generate_keyfile: bool = True):
|
||||
"""Utility to load a model and then serialize it with Tensorizer
|
||||
|
||||
Intended to be used separately from running a vLLM server since it
|
||||
creates its own Engine instance.
|
||||
"""
|
||||
engine_config = engine_args.create_engine_config()
|
||||
tensorizer_config.verify_with_model_config(engine_config.model_config)
|
||||
tensorizer_config.verify_with_parallel_config(
|
||||
engine_config.parallel_config)
|
||||
|
||||
# generate the encryption key before creating the engine to support sharding
|
||||
if generate_keyfile and (keyfile :=
|
||||
tensorizer_config.encryption_keyfile) is not None:
|
||||
encryption_params = EncryptionParams.random()
|
||||
with open_stream(
|
||||
keyfile,
|
||||
mode="wb+",
|
||||
s3_access_key_id=tensorizer_config.s3_access_key_id,
|
||||
s3_secret_access_key=tensorizer_config.s3_secret_access_key,
|
||||
s3_endpoint=tensorizer_config.s3_endpoint,
|
||||
) as stream:
|
||||
stream.write(encryption_params.key)
|
||||
|
||||
assert envs.VLLM_USE_V1
|
||||
|
||||
from vllm.v1.engine.llm_engine import LLMEngine
|
||||
|
||||
engine = LLMEngine.from_vllm_config(engine_config)
|
||||
engine.collective_rpc(
|
||||
"save_tensorized_model",
|
||||
kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
|
||||
)
|
||||
|
||||
|
||||
def tensorize_lora_adapter(lora_path: str,
|
||||
tensorizer_config: TensorizerConfig):
|
||||
"""
|
||||
Uses tensorizer to serialize a LoRA adapter. Assumes that the files
|
||||
needed to load a LoRA adapter are a safetensors-format file called
|
||||
adapter_model.safetensors and a json config file called adapter_config.json.
|
||||
|
||||
Serializes the files in the tensorizer_config.tensorizer_dir
|
||||
"""
|
||||
import safetensors
|
||||
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
|
||||
lora_dir = get_adapter_absolute_path(lora_path)
|
||||
|
||||
tensor_path = config_path = ""
|
||||
|
||||
for file in os.listdir(lora_dir):
|
||||
if file.startswith("adapter_model"):
|
||||
tensor_path = lora_dir + "/" + file
|
||||
if file.startswith("adapter_config"):
|
||||
config_path = lora_dir + "/" + file
|
||||
if tensor_path and config_path:
|
||||
break
|
||||
|
||||
if tensor_path.endswith(".safetensors"):
|
||||
tensors = safetensors.torch.load_file(tensor_path)
|
||||
elif tensor_path.endswith(".bin"):
|
||||
tensors = torch.load(tensor_path)
|
||||
else:
|
||||
raise ValueError("Unsupported file: %s", tensor_path)
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
|
||||
with open_stream(f"{tensorizer_config.tensorizer_dir}/adapter_config.json",
|
||||
mode="wb+",
|
||||
**tensorizer_args.stream_kwargs) as f:
|
||||
|
||||
f.write(json.dumps(config).encode("utf-8"))
|
||||
|
||||
lora_uri = (f"{tensorizer_config.tensorizer_dir}"
|
||||
f"/adapter_model.tensors")
|
||||
with open_stream(lora_uri, mode="wb+",
|
||||
**tensorizer_args.stream_kwargs) as f:
|
||||
serializer = TensorSerializer(f)
|
||||
serializer.write_state_dict(tensors)
|
||||
serializer.close()
|
||||
|
||||
logger.info("Successfully serialized LoRA files to %s",
|
||||
str(tensorizer_config.tensorizer_dir))
|
||||
143
vllm/model_executor/model_loader/tensorizer_loader.py
Normal file
143
vllm/model_executor/model_loader/tensorizer_loader.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import copy
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig, deserialize_tensorizer_model, init_tensorizer_model,
|
||||
is_vllm_tensorized, serialize_vllm_model, tensorizer_weights_iterator)
|
||||
from vllm.model_executor.model_loader.utils import (get_model_architecture,
|
||||
initialize_model,
|
||||
set_default_torch_dtype)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
BLACKLISTED_TENSORIZER_ARGS = {
|
||||
"device", # vLLM decides this
|
||||
"dtype", # vLLM decides this
|
||||
"mode", # Not meant to be configurable by the user
|
||||
}
|
||||
|
||||
|
||||
def validate_config(config: dict):
|
||||
for k, v in config.items():
|
||||
if v is not None and k in BLACKLISTED_TENSORIZER_ARGS:
|
||||
raise ValueError(f"{k} is not an allowed Tensorizer argument.")
|
||||
|
||||
|
||||
class TensorizerLoader(BaseModelLoader):
|
||||
"""Model loader using CoreWeave's tensorizer library."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
|
||||
self.tensorizer_config = load_config.model_loader_extra_config
|
||||
else:
|
||||
validate_config(load_config.model_loader_extra_config)
|
||||
self.tensorizer_config = TensorizerConfig(
|
||||
**load_config.model_loader_extra_config["tensorizer_config"])
|
||||
|
||||
def _verify_config(self, model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig):
|
||||
self.tensorizer_config.verify_with_model_config(model_config)
|
||||
self.tensorizer_config.verify_with_parallel_config(parallel_config)
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, ) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
|
||||
return tensorizer_weights_iterator(tensorizer_args)
|
||||
|
||||
def _load_model_serialized_cpu(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
) -> nn.Module:
|
||||
"""Load a serialized model with tensorizer to the CPU.
|
||||
|
||||
This is only necessary when the model isn't vLLM-tensorized (see
|
||||
examples/others/tensorize_vllm_model.py) This should still
|
||||
be faster than default HuggingFace loading, but will be slower than
|
||||
loading a vLLM-tensorized model.
|
||||
"""
|
||||
device_config = vllm_config.device_config
|
||||
model_config = vllm_config.model_config
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
|
||||
model.load_weights(self._get_weights_iterator())
|
||||
return model.eval()
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self.tensorizer_config.verify_with_model_config(model_config)
|
||||
|
||||
with self.tensorizer_config.open_stream():
|
||||
pass
|
||||
|
||||
def _patch_tensorizer_config(
|
||||
self, model_config: ModelConfig) -> TensorizerConfig:
|
||||
model_class = get_model_architecture(model_config)[0]
|
||||
tensorizer_config = copy.copy(self.tensorizer_config)
|
||||
tensorizer_config.model_class = model_class
|
||||
tensorizer_config.hf_config = model_config.hf_config
|
||||
tensorizer_config.dtype = model_config.dtype
|
||||
return tensorizer_config
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
"""Load serialized model weights with tensorizer.
|
||||
|
||||
Expects a vLLM-tensorized model. See the
|
||||
examples/others/tensorize_vllm_model.py example script
|
||||
for serializing vLLM models."""
|
||||
if is_vllm_tensorized(self.tensorizer_config):
|
||||
tensorizer_config = self._patch_tensorizer_config(model_config)
|
||||
deserialize_tensorizer_model(model, tensorizer_config)
|
||||
else:
|
||||
model.load_weights(self._get_weights_iterator())
|
||||
|
||||
def load_model(self, vllm_config: VllmConfig,
|
||||
model_config: ModelConfig) -> nn.Module:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self._verify_config(model_config, parallel_config)
|
||||
|
||||
if parallel_config.tensor_parallel_size > 1:
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
self.tensorizer_config.tensorizer_uri = (
|
||||
self.tensorizer_config.tensorizer_uri %
|
||||
get_tensor_model_parallel_rank())
|
||||
|
||||
if is_vllm_tensorized(self.tensorizer_config):
|
||||
tensorizer_config = self._patch_tensorizer_config(model_config)
|
||||
device_config = vllm_config.device_config
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = init_tensorizer_model(
|
||||
tensorizer_config=tensorizer_config,
|
||||
vllm_config=vllm_config)
|
||||
self.load_weights(model, model_config)
|
||||
return model
|
||||
return self._load_model_serialized_cpu(vllm_config=vllm_config)
|
||||
|
||||
@staticmethod
|
||||
def save_model(
|
||||
model: torch.nn.Module,
|
||||
tensorizer_config: Union[TensorizerConfig, dict],
|
||||
model_config: ModelConfig,
|
||||
) -> None:
|
||||
if isinstance(tensorizer_config, dict):
|
||||
tensorizer_config = TensorizerConfig(**tensorizer_config)
|
||||
serialize_vllm_model(
|
||||
model=model,
|
||||
tensorizer_config=tensorizer_config,
|
||||
model_config=model_config,
|
||||
)
|
||||
114
vllm/model_executor/model_loader/tpu.py
Normal file
114
vllm/model_executor/model_loader/tpu.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.spmd as xs
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model, process_weights_after_loading, set_default_torch_dtype)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUModelLoader(DefaultModelLoader):
|
||||
"""
|
||||
A TPU model loader for model loading under SPMD mode.
|
||||
"""
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
model_config: ModelConfig,
|
||||
mesh: Optional[xs.Mesh] = None,
|
||||
) -> nn.Module:
|
||||
# Initialize model and load weights on CPU. Then, during SPMD partition,
|
||||
# weights are sharded and transferred to TPUs.
|
||||
self.counter_before_loading_weights = time.perf_counter()
|
||||
model_config = vllm_config.model_config
|
||||
assert model_config.quantization is None, "Quantization not supported"
|
||||
target_device = torch.device('cpu')
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
|
||||
load_format = vllm_config.load_config.load_format
|
||||
if load_format != "dummy":
|
||||
weights_to_load = {
|
||||
name
|
||||
for name, _ in model.named_parameters()
|
||||
}
|
||||
all_weights = self.get_all_weights(model_config, model)
|
||||
loaded_weights = model.load_weights(all_weights)
|
||||
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info(
|
||||
"Loading weights took %.2f seconds",
|
||||
self.counter_after_loading_weights -
|
||||
self.counter_before_loading_weights)
|
||||
# We only enable strict check for non-quantized models
|
||||
# that have loaded weights tracking currently.
|
||||
if model_config.quantization is None and \
|
||||
loaded_weights is not None:
|
||||
weights_not_loaded = weights_to_load - loaded_weights
|
||||
if weights_not_loaded:
|
||||
raise ValueError(
|
||||
"Following weights were not initialized from "
|
||||
f"checkpoint: {weights_not_loaded}")
|
||||
else:
|
||||
logger.info("Use dummy weight during weight loading.")
|
||||
|
||||
process_weights_after_loading(model, model_config, target_device)
|
||||
|
||||
counter_before_partition = time.perf_counter()
|
||||
model = model.eval()
|
||||
model = model.to('xla')
|
||||
shard_model(model, mesh)
|
||||
counter_after_partition = time.perf_counter()
|
||||
logger.info("Partition model took %.2f seconds",
|
||||
counter_after_partition - counter_before_partition)
|
||||
|
||||
# Ensure the model is properly loaded.
|
||||
self._check_model_is_loaded(mesh, model)
|
||||
|
||||
# Need to torch compile after model sharding are done. Because the
|
||||
# compiler hints ('xs.mark_sharding') are torch ops.
|
||||
if not model_config.is_multimodal_model:
|
||||
model.model = torch.compile(model.model, backend="openxla")
|
||||
else:
|
||||
model.language_model.model = \
|
||||
torch.compile(model.language_model.model, backend="openxla")
|
||||
return model
|
||||
|
||||
def _check_model_is_loaded(self, mesh: Optional[xs.Mesh],
|
||||
model: nn.Module) -> None:
|
||||
"""
|
||||
Ensure the model is properly loaded.
|
||||
1. All model parameters and buffers are on XLA device.
|
||||
2. Non-SPMD friendly layers are replaced as expected.
|
||||
"""
|
||||
device = xm.xla_device()
|
||||
device_type = str(device.type)
|
||||
|
||||
# Check parameters
|
||||
for name, param in model.named_parameters():
|
||||
assert param.device.type == device_type, (
|
||||
f"Parameter {name} is on {param.device.type} "
|
||||
f"instead of {device_type}")
|
||||
|
||||
# Check buffers
|
||||
for name, buffer in model.named_buffers():
|
||||
assert buffer.device.type == device_type, (
|
||||
f"Buffer {name} is on {buffer.device.type} "
|
||||
f"instead of {device_type}")
|
||||
|
||||
for module in model.modules():
|
||||
if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'):
|
||||
raise AssertionError("QKVParallelLinear should be replaced by \
|
||||
XlaQKVParallelLinear under SPMD mode.")
|
||||
292
vllm/model_executor/model_loader/utils.py
Normal file
292
vllm/model_executor/model_loader/utils.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for selecting and loading models."""
|
||||
import contextlib
|
||||
import inspect
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import QKVCrossParallelLinear
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.models.adapters import (
|
||||
as_embedding_model, as_reward_model, as_seq_cls_model,
|
||||
try_create_mm_pooling_model_cls)
|
||||
from vllm.model_executor.models.interfaces import (SupportsQuant,
|
||||
supports_multimodal)
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_default_torch_dtype(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
old_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
def initialize_model(
|
||||
vllm_config: VllmConfig,
|
||||
*,
|
||||
prefix: str = "",
|
||||
model_class: Optional[type[nn.Module]] = None,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
) -> nn.Module:
|
||||
"""Initialize a model with the given configurations."""
|
||||
if model_config is None:
|
||||
model_config = vllm_config.model_config
|
||||
if model_class is None:
|
||||
model_class, _ = get_model_architecture(model_config)
|
||||
|
||||
if vllm_config.quant_config is not None:
|
||||
configure_quant_config(vllm_config.quant_config, model_class)
|
||||
|
||||
signatures = inspect.signature(model_class.__init__)
|
||||
all_params = [param.name for param in signatures.parameters.values()]
|
||||
if "vllm_config" in all_params and "prefix" in all_params:
|
||||
# new-style model class
|
||||
with set_current_vllm_config(vllm_config,
|
||||
check_compile=True,
|
||||
prefix=prefix):
|
||||
return model_class(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
|
||||
"input arguments. Possibly you have an old-style model class"
|
||||
" registered from out of tree and it is used for new vLLM version. "
|
||||
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
|
||||
"for the design and update the model class accordingly.")
|
||||
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
||||
|
||||
logger.warning(
|
||||
"Trying to guess the arguments for old-style model class %s",
|
||||
model_class,
|
||||
)
|
||||
# try to be compatible with old-style model class
|
||||
kwargs = {}
|
||||
if "prefix" in all_params:
|
||||
kwargs["prefix"] = prefix
|
||||
if "config" in all_params:
|
||||
kwargs["config"] = model_config.hf_config
|
||||
if "cache_config" in all_params:
|
||||
kwargs["cache_config"] = vllm_config.cache_config
|
||||
if "quant_config" in all_params:
|
||||
kwargs["quant_config"] = vllm_config.quant_config
|
||||
if "lora_config" in all_params:
|
||||
kwargs["lora_config"] = vllm_config.lora_config
|
||||
if "scheduler_config" in all_params:
|
||||
kwargs["scheduler_config"] = vllm_config.scheduler_config
|
||||
with set_current_vllm_config(vllm_config,
|
||||
check_compile=True,
|
||||
prefix=prefix):
|
||||
return model_class(**kwargs)
|
||||
|
||||
|
||||
def process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
|
||||
target_device: torch.device) -> None:
|
||||
for _, module in model.named_modules():
|
||||
if isinstance(module, QKVCrossParallelLinear):
|
||||
# NOTE(Isotr0py): special case for cross QKV layer because
|
||||
# q and kv proj aren't registered as submodules intentionally
|
||||
module.process_weights_after_loading()
|
||||
continue
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if isinstance(quant_method, QuantizeMethodBase):
|
||||
# When quant methods need to process weights after loading
|
||||
# (for repacking, quantizing, etc), they expect parameters
|
||||
# to be on the global target device. This scope is for the
|
||||
# case where cpu offloading is used, where we will move the
|
||||
# parameters onto device for processing and back off after.
|
||||
with device_loading_context(module, target_device):
|
||||
quant_method.process_weights_after_loading(module)
|
||||
|
||||
# Currently only used by MLA.
|
||||
# NOTE: This intentionally happens after other modules so we can easily
|
||||
# decompress the weights for MLA.
|
||||
for _, module in model.named_modules():
|
||||
if isinstance(module, Attention) and \
|
||||
hasattr(module, "process_weights_after_loading"):
|
||||
# TODO(lucas): see if there is a way to unify the signatures
|
||||
# of process_weights_after_loading
|
||||
module.process_weights_after_loading(model_config.dtype)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def device_loading_context(module: torch.nn.Module,
|
||||
target_device: torch.device):
|
||||
if target_device.type == "cpu":
|
||||
# If target is CPU, no need to move anything
|
||||
yield module
|
||||
return
|
||||
|
||||
original_device_states: dict[str, torch.device] = {}
|
||||
|
||||
# Store original device states and move parameters to GPU if they're on CPU
|
||||
for name, p in module.named_parameters():
|
||||
if p.device.type == "cpu":
|
||||
original_device_states[name] = p.device
|
||||
p.data = p.data.to(target_device)
|
||||
# Parameters already on target device are not touched
|
||||
|
||||
try:
|
||||
yield module
|
||||
|
||||
finally:
|
||||
# Restore parameters to their original devices, ignoring new parameters
|
||||
pin_memory = is_pin_memory_available()
|
||||
for name, p in module.named_parameters():
|
||||
if name in original_device_states:
|
||||
original_device: torch.device = original_device_states[name]
|
||||
if original_device.type == "cpu":
|
||||
# `torch.empty_like` does not support `pin_memory` argument
|
||||
cpu_data = torch.empty_strided(
|
||||
size=p.data.size(),
|
||||
stride=p.data.stride(),
|
||||
dtype=p.data.dtype,
|
||||
layout=p.data.layout,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
cpu_data.copy_(p.data)
|
||||
p.data = cpu_data
|
||||
else:
|
||||
p.data = p.data.to(original_device)
|
||||
# New parameters or parameters already on target device are untouched
|
||||
|
||||
|
||||
_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
|
||||
"""Caches the outputs of `_get_model_architecture`."""
|
||||
|
||||
|
||||
def _get_model_architecture(
|
||||
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||
|
||||
model_cls, arch = model_config.registry.resolve_model_cls(
|
||||
architectures,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
if arch == model_config._get_transformers_backend_cls():
|
||||
assert model_config.model_impl != "vllm"
|
||||
if model_config.model_impl == "auto":
|
||||
logger.warning_once(
|
||||
"%s has no vLLM implementation, falling back to Transformers "
|
||||
"implementation. Some features may not be supported and "
|
||||
"performance may not be optimal.", arch)
|
||||
|
||||
convert_type = model_config.convert_type
|
||||
if convert_type != "none" and supports_multimodal(model_cls):
|
||||
logger.debug_once("Detected conversion of Multi Modal model.")
|
||||
converted = try_create_mm_pooling_model_cls(model_cls)
|
||||
if converted is not None:
|
||||
logger.debug_once("Creating wrapper class to forward pooler.")
|
||||
return converted, arch
|
||||
else:
|
||||
logger.debug_once("Attempting direct conversion.")
|
||||
|
||||
if convert_type == "none":
|
||||
pass
|
||||
elif convert_type == "embed":
|
||||
logger.debug_once("Converting to embedding model.")
|
||||
model_cls = as_embedding_model(model_cls)
|
||||
elif convert_type == "classify":
|
||||
logger.debug_once("Converting to sequence classification model.")
|
||||
model_cls = as_seq_cls_model(model_cls)
|
||||
elif convert_type == "reward":
|
||||
logger.debug_once("Converting to reward model.")
|
||||
model_cls = as_reward_model(model_cls)
|
||||
else:
|
||||
assert_never(convert_type)
|
||||
|
||||
return model_cls, arch
|
||||
|
||||
|
||||
def get_model_architecture(
|
||||
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
||||
key = hash((
|
||||
model_config.model,
|
||||
model_config.convert_type,
|
||||
model_config.runner_type,
|
||||
model_config.trust_remote_code,
|
||||
model_config.model_impl,
|
||||
tuple(getattr(model_config.hf_config, "architectures", [])),
|
||||
))
|
||||
if key in _MODEL_ARCH_BY_HASH:
|
||||
return _MODEL_ARCH_BY_HASH[key]
|
||||
|
||||
model_arch = _get_model_architecture(model_config)
|
||||
_MODEL_ARCH_BY_HASH[key] = model_arch
|
||||
return model_arch
|
||||
|
||||
|
||||
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
|
||||
return get_model_architecture(model_config)[0]
|
||||
|
||||
|
||||
def get_architecture_class_name(model_config: ModelConfig) -> str:
|
||||
return get_model_architecture(model_config)[1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamMapping:
|
||||
"""
|
||||
A class to handle parameter mapping for model weight loading.
|
||||
It creates a bidirectional mapping between packed parameters and their
|
||||
constituent parts.
|
||||
"""
|
||||
packed_mapping: dict[str, list[str]]
|
||||
inverse_packed_mapping: dict[str, tuple[str,
|
||||
int]] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
for packed_name, sub_params in self.packed_mapping.items():
|
||||
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
|
||||
if len(sub_params) == 1 and sub_params[0] == packed_name:
|
||||
continue
|
||||
for index, param_name in enumerate(sub_params):
|
||||
self.inverse_packed_mapping[param_name] = (
|
||||
packed_name,
|
||||
index,
|
||||
)
|
||||
|
||||
def get_sub_modules(self,
|
||||
module_name: str) -> Optional[tuple[str, list[str]]]:
|
||||
for key, value in self.packed_mapping.items():
|
||||
if module_name.endswith(key):
|
||||
return key, value
|
||||
return None
|
||||
|
||||
|
||||
def configure_quant_config(quant_config: QuantizationConfig,
|
||||
model_class: type[nn.Module]):
|
||||
"""
|
||||
Pass packed_modules_mapping by reference to quant_config so that
|
||||
quant_config can properly match fused modules
|
||||
|
||||
Note that model attributes are passed by reference to quant_config,
|
||||
enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
|
||||
|
||||
Once the `SupportsQuant` mixin has been added to all models, this
|
||||
function can be removed
|
||||
"""
|
||||
if not issubclass(model_class, SupportsQuant):
|
||||
hf_to_vllm_mapper = getattr(model_class, "hf_to_vllm_mapper", None)
|
||||
packed_mapping = getattr(model_class, "packed_modules_mapping", None)
|
||||
|
||||
# pass mappings by reference to quant_config
|
||||
if hf_to_vllm_mapper is not None:
|
||||
quant_config.apply_vllm_mapper(hf_to_vllm_mapper)
|
||||
if packed_mapping is not None:
|
||||
quant_config.packed_modules_mapping = packed_mapping
|
||||
990
vllm/model_executor/model_loader/weight_utils.py
Normal file
990
vllm/model_executor/model_loader/weight_utils.py
Normal file
@@ -0,0 +1,990 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for downloading and initializing model weights."""
|
||||
import concurrent.futures
|
||||
import fnmatch
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from pathlib import Path
|
||||
from typing import IO, Any, Callable, Optional, Union
|
||||
|
||||
import filelock
|
||||
import huggingface_hub.constants
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load, load_file, safe_open, save_file
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (QuantizationConfig,
|
||||
get_quantization_config)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
from runai_model_streamer import SafetensorsStreamer
|
||||
except ImportError:
|
||||
runai_model_streamer = PlaceholderModule(
|
||||
"runai_model_streamer") # type: ignore[assignment]
|
||||
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
|
||||
"SafetensorsStreamer")
|
||||
|
||||
try:
|
||||
import gguf
|
||||
except ImportError:
|
||||
gguf = PlaceholderModule("gguf")
|
||||
|
||||
try:
|
||||
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
|
||||
except ImportError:
|
||||
fastsafetensors = PlaceholderModule("fastsafetensors")
|
||||
SafeTensorsFileLoader = fastsafetensors.placeholder_attr(
|
||||
"SafeTensorsFileLoader")
|
||||
SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# use system-level temp directory for file locks, so that multiple users
|
||||
# can share the same lock without error.
|
||||
# lock files in the temp directory will be automatically deleted when the
|
||||
# system reboots, so users will not complain about annoying lock files
|
||||
temp_dir = tempfile.gettempdir()
|
||||
|
||||
|
||||
def enable_hf_transfer():
|
||||
"""automatically activates hf_transfer
|
||||
"""
|
||||
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
|
||||
try:
|
||||
# enable hf hub transfer if available
|
||||
import hf_transfer # type: ignore # noqa
|
||||
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
enable_hf_transfer()
|
||||
|
||||
|
||||
class DisabledTqdm(tqdm):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs, disable=True)
|
||||
|
||||
|
||||
def get_lock(model_name_or_path: Union[str, Path],
|
||||
cache_dir: Optional[str] = None):
|
||||
lock_dir = cache_dir or temp_dir
|
||||
model_name_or_path = str(model_name_or_path)
|
||||
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
|
||||
model_name = model_name_or_path.replace("/", "-")
|
||||
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
|
||||
# add hash to avoid conflict with old users' lock files
|
||||
lock_file_name = hash_name + model_name + ".lock"
|
||||
# mode 0o666 is required for the filelock to be shared across users
|
||||
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
|
||||
mode=0o666)
|
||||
return lock
|
||||
|
||||
|
||||
@contextmanager
|
||||
def atomic_writer(filepath: Union[str, Path],
|
||||
mode: str = 'w',
|
||||
encoding: Optional[str] = None) -> Generator[IO]:
|
||||
"""
|
||||
Context manager that provides an atomic file writing routine.
|
||||
|
||||
The context manager writes to a temporary file and, if successful,
|
||||
atomically replaces the original file.
|
||||
|
||||
Args:
|
||||
filepath (str or Path): The path to the file to write.
|
||||
mode (str): The file mode for the temporary file (e.g., 'w', 'wb').
|
||||
encoding (str): The encoding for text mode.
|
||||
|
||||
Yields:
|
||||
file object: A handle to the temporary file.
|
||||
"""
|
||||
# Create a temporary file in the same directory as the target file
|
||||
# to ensure it's on the same filesystem for an atomic replace.
|
||||
temp_dir = os.path.dirname(filepath)
|
||||
temp_fd, temp_path = tempfile.mkstemp(dir=temp_dir)
|
||||
|
||||
try:
|
||||
# Open the temporary file for writing
|
||||
with os.fdopen(temp_fd, mode=mode, encoding=encoding) as temp_file:
|
||||
yield temp_file
|
||||
|
||||
# If the 'with' block completes successfully,
|
||||
# perform the atomic replace.
|
||||
os.replace(temp_path, filepath)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Error during atomic write. Original file '%s' not modified",
|
||||
filepath)
|
||||
raise
|
||||
finally:
|
||||
# Clean up the temporary file if it still exists.
|
||||
if os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
|
||||
def maybe_download_from_modelscope(
|
||||
model: str,
|
||||
revision: Optional[str] = None,
|
||||
download_dir: Optional[str] = None,
|
||||
ignore_patterns: Optional[Union[str, list[str]]] = None,
|
||||
allow_patterns: Optional[Union[list[str],
|
||||
str]] = None) -> Optional[str]:
|
||||
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
|
||||
|
||||
Returns the path to the downloaded model, or None if the model is not
|
||||
downloaded from ModelScope."""
|
||||
if envs.VLLM_USE_MODELSCOPE:
|
||||
# download model from ModelScope hub,
|
||||
# lazy import so that modelscope is not required for normal use.
|
||||
# pylint: disable=C.
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model, download_dir):
|
||||
if not os.path.exists(model):
|
||||
model_path = snapshot_download(
|
||||
model_id=model,
|
||||
cache_dir=download_dir,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
revision=revision,
|
||||
ignore_file_pattern=ignore_patterns,
|
||||
allow_patterns=allow_patterns,
|
||||
)
|
||||
else:
|
||||
model_path = model
|
||||
return model_path
|
||||
return None
|
||||
|
||||
|
||||
def _shared_pointers(tensors):
|
||||
ptrs = defaultdict(list)
|
||||
for k, v in tensors.items():
|
||||
ptrs[v.data_ptr()].append(k)
|
||||
failing = []
|
||||
for _, names in ptrs.items():
|
||||
if len(names) > 1:
|
||||
failing.append(names)
|
||||
return failing
|
||||
|
||||
|
||||
def convert_bin_to_safetensor_file(
|
||||
pt_filename: str,
|
||||
sf_filename: str,
|
||||
) -> None:
|
||||
loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
shared = _shared_pointers(loaded)
|
||||
for shared_weights in shared:
|
||||
for name in shared_weights[1:]:
|
||||
loaded.pop(name)
|
||||
|
||||
# For tensors to be contiguous
|
||||
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
||||
|
||||
dirname = os.path.dirname(sf_filename)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||
|
||||
# check file size
|
||||
sf_size = os.stat(sf_filename).st_size
|
||||
pt_size = os.stat(pt_filename).st_size
|
||||
if (sf_size - pt_size) / pt_size > 0.01:
|
||||
raise RuntimeError(f"""The file size different is more than 1%:
|
||||
- {sf_filename}: {sf_size}
|
||||
- {pt_filename}: {pt_size}
|
||||
""")
|
||||
|
||||
# check if the tensors are the same
|
||||
reloaded = load_file(sf_filename)
|
||||
for k in loaded:
|
||||
pt_tensor = loaded[k]
|
||||
sf_tensor = reloaded[k]
|
||||
if not torch.equal(pt_tensor, sf_tensor):
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
# TODO(woosuk): Move this to other place.
|
||||
def get_quant_config(model_config: ModelConfig,
|
||||
load_config: LoadConfig) -> QuantizationConfig:
|
||||
|
||||
quant_cls = get_quantization_config(model_config.quantization)
|
||||
|
||||
# GGUF doesn't have config file
|
||||
if model_config.quantization in ("gguf", "inc"):
|
||||
return quant_cls()
|
||||
|
||||
# Read the quantization config from the HF model config, if available.
|
||||
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
|
||||
None)
|
||||
# some vision model may keep quantization_config in their text_config
|
||||
hf_text_config = getattr(model_config.hf_config, "text_config", None)
|
||||
if hf_quant_config is None and hf_text_config is not None:
|
||||
hf_quant_config = getattr(hf_text_config, "quantization_config", None)
|
||||
if hf_quant_config is None:
|
||||
# compressed-tensors uses a compressions_config
|
||||
hf_quant_config = getattr(model_config.hf_config, "compression_config",
|
||||
None)
|
||||
if hf_quant_config is not None:
|
||||
return quant_cls.from_config(hf_quant_config)
|
||||
# Inflight BNB quantization
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
return quant_cls.from_config({})
|
||||
model_name_or_path = maybe_download_from_modelscope(
|
||||
model_config.model,
|
||||
revision=model_config.revision,
|
||||
download_dir=load_config.download_dir,
|
||||
allow_patterns=["*.json"],
|
||||
) or model_config.model
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
# Download the config files.
|
||||
with get_lock(model_config.model, load_config.download_dir):
|
||||
hf_folder = snapshot_download(
|
||||
model_config.model,
|
||||
revision=model_config.revision,
|
||||
allow_patterns="*.json",
|
||||
cache_dir=load_config.download_dir,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
tqdm_class=DisabledTqdm,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
possible_config_filenames = quant_cls.get_config_filenames()
|
||||
|
||||
# If the quantization config is not found, use the default config.
|
||||
if not possible_config_filenames:
|
||||
return quant_cls()
|
||||
|
||||
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||
|
||||
quant_config_files = [
|
||||
f for f in config_files if any(
|
||||
f.endswith(x) for x in possible_config_filenames)
|
||||
]
|
||||
if len(quant_config_files) == 0:
|
||||
raise ValueError(
|
||||
f"Cannot find the config file for {model_config.quantization}")
|
||||
if len(quant_config_files) > 1:
|
||||
raise ValueError(
|
||||
f"Found multiple config files for {model_config.quantization}: "
|
||||
f"{quant_config_files}")
|
||||
|
||||
quant_config_file = quant_config_files[0]
|
||||
with open(quant_config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
config["adapter_name_or_path"] = model_config.model
|
||||
elif model_config.quantization == "modelopt":
|
||||
if config["producer"]["name"] == "modelopt":
|
||||
return quant_cls.from_config(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quantization config"
|
||||
f" found for {model_config.quantization} in {f}.")
|
||||
|
||||
return quant_cls.from_config(config)
|
||||
|
||||
|
||||
def get_sparse_attention_config(
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig,
|
||||
sparse_attention_config_filename: str = "sparse_attention_config.json",
|
||||
) -> dict[str, Any]:
|
||||
model_name_or_path = model_config.model
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
# Download the config files.
|
||||
with get_lock(model_name_or_path, load_config.download_dir):
|
||||
hf_folder = snapshot_download(
|
||||
model_name_or_path,
|
||||
revision=model_config.revision,
|
||||
allow_patterns="*.json",
|
||||
cache_dir=load_config.download_dir,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
tqdm_class=DisabledTqdm,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
config_file = os.path.join(hf_folder, sparse_attention_config_filename)
|
||||
if not os.path.exists(config_file):
|
||||
return {}
|
||||
|
||||
# Load the sparse attention config.
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
logger.info("Loaded sparse attention config from %s", config_file)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def download_weights_from_hf(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
allow_patterns: list[str],
|
||||
revision: Optional[str] = None,
|
||||
ignore_patterns: Optional[Union[str, list[str]]] = None,
|
||||
) -> str:
|
||||
"""Download model weights from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The model name or path.
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
allow_patterns (list[str]): The allowed patterns for the
|
||||
weight files. Files matched by any of the patterns will be
|
||||
downloaded.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
|
||||
filter out the weight files. Files matched by any of the patterns
|
||||
will be ignored.
|
||||
|
||||
Returns:
|
||||
str: The path to the downloaded model weights.
|
||||
"""
|
||||
assert len(allow_patterns) > 0
|
||||
local_only = huggingface_hub.constants.HF_HUB_OFFLINE
|
||||
if not local_only:
|
||||
# Attempt to reduce allow_patterns to a single pattern
|
||||
# so we only have to call snapshot_download once.
|
||||
try:
|
||||
fs = HfFileSystem()
|
||||
file_list = fs.ls(model_name_or_path,
|
||||
detail=False,
|
||||
revision=revision)
|
||||
|
||||
# Use the first pattern found in the HF repo's files.
|
||||
for pattern in allow_patterns:
|
||||
matching = fnmatch.filter(file_list, pattern)
|
||||
if len(matching) > 0:
|
||||
allow_patterns = [pattern]
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to get file list for '%s'. Trying each pattern in "
|
||||
"allow_patterns individually until weights have been "
|
||||
"downloaded. Error: %s", model_name_or_path, e)
|
||||
|
||||
logger.info("Using model weights format %s", allow_patterns)
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
start_time = time.perf_counter()
|
||||
for allow_pattern in allow_patterns:
|
||||
hf_folder = snapshot_download(
|
||||
model_name_or_path,
|
||||
allow_patterns=allow_pattern,
|
||||
ignore_patterns=ignore_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=DisabledTqdm,
|
||||
revision=revision,
|
||||
local_files_only=local_only,
|
||||
)
|
||||
# If we have downloaded weights for this allow_pattern,
|
||||
# we don't need to check the rest.
|
||||
if any(Path(hf_folder).glob(allow_pattern)):
|
||||
break
|
||||
time_taken = time.perf_counter() - start_time
|
||||
if time_taken > 0.5:
|
||||
logger.info("Time spent downloading weights for %s: %.6f seconds",
|
||||
model_name_or_path, time_taken)
|
||||
return hf_folder
|
||||
|
||||
|
||||
def download_safetensors_index_file_from_hf(
|
||||
model_name_or_path: str,
|
||||
index_file: str,
|
||||
cache_dir: Optional[str],
|
||||
revision: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Download hf safetensors index file from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The model name or path.
|
||||
index_file (str): The safetensors index file name
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
"""
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
try:
|
||||
# Download the safetensors index file.
|
||||
hf_hub_download(
|
||||
repo_id=model_name_or_path,
|
||||
filename=index_file,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
)
|
||||
# If file not found on remote or locally, we should not fail since
|
||||
# only some models will have index_file.
|
||||
except huggingface_hub.utils.LocalEntryNotFoundError:
|
||||
logger.info("No %s found in local cache.", index_file)
|
||||
except huggingface_hub.utils.EntryNotFoundError:
|
||||
logger.info("No %s found in remote.", index_file)
|
||||
|
||||
|
||||
# For models like Mistral-7B-v0.3, there are both sharded
|
||||
# safetensors files and a consolidated safetensors file.
|
||||
# Passing both of these to the weight loader functionality breaks.
|
||||
# So, we use the index_file to
|
||||
# look up which safetensors files should be used.
|
||||
def filter_duplicate_safetensors_files(hf_weights_files: list[str],
|
||||
hf_folder: str,
|
||||
index_file: str) -> list[str]:
|
||||
# model.safetensors.index.json is a mapping from keys in the
|
||||
# torch state_dict to safetensors file holding that weight.
|
||||
index_file_name = os.path.join(hf_folder, index_file)
|
||||
if not os.path.isfile(index_file_name):
|
||||
return hf_weights_files
|
||||
|
||||
# Iterate through the weight_map (weight_name: safetensors files)
|
||||
# to identify weights that we should use.
|
||||
with open(index_file_name) as f:
|
||||
weight_map = json.load(f)["weight_map"]
|
||||
weight_files_in_index = set()
|
||||
for weight_name in weight_map:
|
||||
weight_files_in_index.add(
|
||||
os.path.join(hf_folder, weight_map[weight_name]))
|
||||
# Filter out any fields that are not found in the index file.
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files if f in weight_files_in_index
|
||||
]
|
||||
return hf_weights_files
|
||||
|
||||
|
||||
def filter_files_not_needed_for_inference(
|
||||
hf_weights_files: list[str]) -> list[str]:
|
||||
"""
|
||||
Exclude files that are not needed for inference.
|
||||
|
||||
See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||
"""
|
||||
blacklist = [
|
||||
"training_args.bin",
|
||||
"optimizer.bin",
|
||||
"optimizer.pt",
|
||||
"scheduler.pt",
|
||||
"scaler.pt",
|
||||
]
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files
|
||||
if not any(f.endswith(x) for x in blacklist)
|
||||
]
|
||||
return hf_weights_files
|
||||
|
||||
|
||||
# explicitly use pure text format, with a newline at the end
|
||||
# this makes it impossible to see the animation in the progress bar
|
||||
# but will avoid messing up with ray or multiprocessing, which wraps
|
||||
# each line of output with some prefix.
|
||||
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
|
||||
|
||||
|
||||
def enable_tqdm(use_tqdm_on_load: bool):
|
||||
return use_tqdm_on_load and (not torch.distributed.is_initialized()
|
||||
or torch.distributed.get_rank() == 0)
|
||||
|
||||
|
||||
def np_cache_weights_iterator(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
hf_folder: str,
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model np files.
|
||||
|
||||
Will dump the model weights to numpy files if they are not already dumped.
|
||||
"""
|
||||
# Convert the model weights from torch tensors to numpy arrays for
|
||||
# faster loading.
|
||||
np_folder = os.path.join(hf_folder, "np")
|
||||
os.makedirs(np_folder, exist_ok=True)
|
||||
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
||||
# Use file lock to prevent multiple processes from
|
||||
# dumping the same model weights to numpy at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
if not os.path.exists(weight_names_file):
|
||||
weight_names: list[str] = []
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading np_cache checkpoint shards",
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file,
|
||||
map_location="cpu",
|
||||
weights_only=True)
|
||||
for name, param in state.items():
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "wb") as f:
|
||||
np.save(f, param.cpu().detach().numpy())
|
||||
weight_names.append(name)
|
||||
with open(weight_names_file, "w") as f:
|
||||
json.dump(weight_names, f)
|
||||
|
||||
with open(weight_names_file) as f:
|
||||
weight_names = json.load(f)
|
||||
|
||||
for name in weight_names:
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "rb") as f:
|
||||
param = np.load(f)
|
||||
yield name, torch.from_numpy(param)
|
||||
|
||||
|
||||
def safetensors_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
safetensors_load_strategy: str = "lazy",
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
loading_desc = "Loading safetensors checkpoint shards"
|
||||
if safetensors_load_strategy == "eager":
|
||||
loading_desc += " (eager)"
|
||||
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc=loading_desc,
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
if safetensors_load_strategy == "eager":
|
||||
with open(st_file, "rb") as f:
|
||||
state_dict = load(f.read())
|
||||
yield from state_dict.items()
|
||||
else:
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
|
||||
|
||||
def multi_thread_safetensors_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
max_workers: int = 4,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Multi-Thread iterate over the weights in the model safetensor files."""
|
||||
|
||||
def _load_file(st_file: str):
|
||||
result = load_file(st_file, device="cpu")
|
||||
return result
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=max_workers) as executor:
|
||||
futures = [
|
||||
executor.submit(_load_file, st_file)
|
||||
for st_file in hf_weights_files
|
||||
]
|
||||
futures_iter = tqdm(
|
||||
concurrent.futures.as_completed(futures),
|
||||
total=len(hf_weights_files),
|
||||
desc="Multi-thread loading shards",
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
)
|
||||
|
||||
for future in futures_iter:
|
||||
state_dict = future.result()
|
||||
yield from state_dict.items()
|
||||
|
||||
|
||||
def runai_safetensors_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
with SafetensorsStreamer() as streamer:
|
||||
streamer.stream_files(hf_weights_files)
|
||||
total_tensors = sum(
|
||||
len(tensors_meta)
|
||||
for tensors_meta in streamer.files_to_tensors_metadata.values())
|
||||
|
||||
tensor_iter = tqdm(
|
||||
streamer.get_tensors(),
|
||||
total=total_tensors,
|
||||
desc="Loading safetensors using Runai Model Streamer",
|
||||
bar_format=_BAR_FORMAT,
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
)
|
||||
|
||||
yield from tensor_iter
|
||||
|
||||
|
||||
def fastsafetensors_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files
|
||||
using fastsafetensor library."""
|
||||
if torch.distributed.is_initialized():
|
||||
pg = torch.distributed.group.WORLD
|
||||
else:
|
||||
pg = SingleGroup()
|
||||
|
||||
device = torch.device(f'cuda:{pg.rank()}')
|
||||
weight_files_sub_lists = [
|
||||
hf_weights_files[i:i + pg.size()]
|
||||
for i in range(0, len(hf_weights_files), pg.size())
|
||||
]
|
||||
|
||||
for f_list in tqdm(
|
||||
weight_files_sub_lists,
|
||||
desc="Loading safetensors using Fastsafetensor loader",
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
loader = SafeTensorsFileLoader(pg, device)
|
||||
rank_file_map = {i: [f] for i, f in enumerate(f_list)}
|
||||
loader.add_filenames(rank_file_map)
|
||||
try:
|
||||
fb = loader.copy_files_to_device()
|
||||
try:
|
||||
keys = list(fb.key_to_rank_lidx.keys())
|
||||
for k in keys:
|
||||
t = fb.get_tensor(k)
|
||||
yield k, t
|
||||
finally:
|
||||
fb.close()
|
||||
finally:
|
||||
loader.close()
|
||||
|
||||
|
||||
def pt_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
pt_load_map_location: Union[str, dict[str, str]] = "cpu",
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model bin/pt files."""
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading pt checkpoint shards",
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file,
|
||||
map_location=pt_load_map_location,
|
||||
weights_only=True)
|
||||
yield from state.items()
|
||||
del state
|
||||
|
||||
|
||||
def multi_thread_pt_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
pt_load_map_location: Union[str, dict[str, str]] = "cpu",
|
||||
max_workers: int = 4,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Multi-Thread iterate over the weights in the model bin/pt files."""
|
||||
|
||||
def _load_file(bin_file: str):
|
||||
return torch.load(bin_file,
|
||||
map_location=pt_load_map_location,
|
||||
weights_only=True)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=max_workers) as executor:
|
||||
futures = [
|
||||
executor.submit(_load_file, bin_file)
|
||||
for bin_file in hf_weights_files
|
||||
]
|
||||
futures_iter = tqdm(
|
||||
concurrent.futures.as_completed(futures),
|
||||
total=len(hf_weights_files),
|
||||
desc="Multi-thread loading pt checkpoint shards",
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
)
|
||||
|
||||
for future in futures_iter:
|
||||
state = future.result()
|
||||
yield from state.items()
|
||||
del state
|
||||
|
||||
|
||||
def get_gguf_extra_tensor_names(
|
||||
gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]:
|
||||
reader = gguf.GGUFReader(gguf_file)
|
||||
expected_gguf_keys = set(gguf_to_hf_name_map.keys())
|
||||
exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
|
||||
extra_keys = expected_gguf_keys - exact_gguf_keys
|
||||
return [gguf_to_hf_name_map[key] for key in extra_keys]
|
||||
|
||||
|
||||
def get_gguf_weight_type_map(
|
||||
gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> dict[str, str]:
|
||||
"""
|
||||
Return GGUF mapped weight's name and its quant type
|
||||
"""
|
||||
reader = gguf.GGUFReader(gguf_file)
|
||||
return {
|
||||
gguf_to_hf_name_map[tensor.name]: tensor.tensor_type.name
|
||||
for tensor in reader.tensors if tensor.name in gguf_to_hf_name_map
|
||||
}
|
||||
|
||||
|
||||
def gguf_quant_weights_iterator(
|
||||
gguf_file: str, gguf_to_hf_name_map: dict[str, str]
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""
|
||||
Iterate over the quant weights in the model gguf files and convert
|
||||
them to torch tensors
|
||||
"""
|
||||
|
||||
reader = gguf.GGUFReader(gguf_file)
|
||||
|
||||
for tensor in reader.tensors:
|
||||
if tensor.name in gguf_to_hf_name_map:
|
||||
weight_type = tensor.tensor_type
|
||||
name = gguf_to_hf_name_map[tensor.name]
|
||||
|
||||
if weight_type.name != "F32":
|
||||
weight_type_name = name.replace("weight", "qweight_type")
|
||||
weight_type = torch.tensor(weight_type)
|
||||
yield weight_type_name, weight_type
|
||||
|
||||
for tensor in reader.tensors:
|
||||
if tensor.name in gguf_to_hf_name_map:
|
||||
weight = tensor.data
|
||||
weight_type = tensor.tensor_type
|
||||
name = gguf_to_hf_name_map[tensor.name]
|
||||
if weight_type.name != "F32":
|
||||
name = name.replace("weight", "qweight")
|
||||
param = torch.tensor(weight)
|
||||
yield name, param
|
||||
|
||||
|
||||
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
||||
"""convert PySafeSlice object from safetensors to torch.Tensor
|
||||
|
||||
PySafeSlice object supports indexing, which is done before loading the
|
||||
actual tensor and can reduce the amount of memory being read into the
|
||||
memory. However, it does not support more advanced functionalities
|
||||
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
|
||||
tensor with these more complicated operators, we need to convert to
|
||||
tensor first.
|
||||
"""
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = x[:]
|
||||
return x
|
||||
|
||||
|
||||
def default_weight_loader(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
"""Default weight loader."""
|
||||
try:
|
||||
if param.numel() == 1 and loaded_weight.numel() == 1:
|
||||
# Sometimes scalar values aren't considered tensors with shapes
|
||||
# so if both param and loaded_weight are a scalar,
|
||||
# "broadcast" instead of copy
|
||||
param.data.fill_(loaded_weight.item())
|
||||
else:
|
||||
assert param.size() == loaded_weight.size(), (
|
||||
f"Attempted to load weight ({loaded_weight.size()}) "
|
||||
f"into parameter ({param.size()})")
|
||||
|
||||
param.data.copy_(loaded_weight)
|
||||
except Exception:
|
||||
# NOTE: This exception is added for the purpose of setting breakpoint to
|
||||
# debug weight loading issues.
|
||||
raise
|
||||
|
||||
|
||||
def row_parallel_weight_loader(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
"""Load weights that are row-parallelized."""
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_dim = 0 if param.dim() != 1 else None
|
||||
|
||||
if shard_dim is not None:
|
||||
shard_size = param.data.shape[shard_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
|
||||
|
||||
return default_weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
|
||||
|
||||
|
||||
def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
|
||||
"""Create a weight loader that shards the weights along the given axis"""
|
||||
|
||||
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
shard_size = param.data.shape[shard_axis]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)
|
||||
|
||||
return default_weight_loader(param, loaded_weight)
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def composed_weight_loader(
|
||||
loader: LoaderFunction, fn: Callable[[torch.Tensor],
|
||||
torch.Tensor]) -> LoaderFunction:
|
||||
"""Create a weight loader that post-processes the weights after loading"""
|
||||
|
||||
def composed_loader(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
loader(param, loaded_weight)
|
||||
param.data.copy_(fn(param))
|
||||
return
|
||||
|
||||
return composed_loader
|
||||
|
||||
|
||||
def initialize_dummy_weights(
|
||||
model: torch.nn.Module,
|
||||
low: float = -1e-3,
|
||||
high: float = 1e-3,
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
"""Initialize model weights with random values.
|
||||
|
||||
The model weights must be randomly initialized for accurate performance
|
||||
measurements. Additionally, the model weights should not cause NaNs in the
|
||||
forward pass. We empirically found that initializing the weights with
|
||||
values between -1e-3 and 1e-3 works well for most models.
|
||||
|
||||
We use per-parameter random seed, so that dummy weights are consistent,
|
||||
even if the model is partitioned across multiple devices. When the seed
|
||||
is fixed, the random values generated by this function only depends on
|
||||
the parameter's number of elements and its data type.
|
||||
"""
|
||||
for param in model.state_dict().values():
|
||||
if torch.is_floating_point(param):
|
||||
if current_platform.is_tpu():
|
||||
generator = torch.Generator(device="cpu")
|
||||
generator.manual_seed(seed)
|
||||
# Note: The param.uniform_ function cannot be used in this
|
||||
# context because it demands more TPU HBM than directly copying
|
||||
# from a CPU tensor.
|
||||
# Note: We avoid using torch.rank_like as it doesn't currently
|
||||
# support the generator argument.
|
||||
param.copy_((high - low) *
|
||||
torch.rand(param.shape,
|
||||
generator=generator,
|
||||
dtype=param.dtype,
|
||||
layout=param.layout,
|
||||
requires_grad=param.requires_grad,
|
||||
device="cpu") + low)
|
||||
torch._sync(param)
|
||||
continue
|
||||
|
||||
generator = torch.Generator(device=param.data.device)
|
||||
generator.manual_seed(seed)
|
||||
if torch.finfo(param.data.dtype).bits < 16:
|
||||
# uniform_ doesn't support < 16-bit datatypes (FP8)
|
||||
dtype = param.data.dtype
|
||||
tmp_param = param.data.to(torch.float16)
|
||||
tmp_param = tmp_param.uniform_(low, high,
|
||||
generator=generator).to(dtype)
|
||||
param.data.copy_(tmp_param)
|
||||
else:
|
||||
param.uniform_(low, high, generator=generator)
|
||||
|
||||
|
||||
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
|
||||
"""Remap the name of FP8 k/v_scale parameters.
|
||||
|
||||
This function handles the remapping of FP8 k/v_scale parameter names.
|
||||
It detects if the given name ends with a suffix and attempts to remap
|
||||
it to the expected name format in the model. If the remapped name is not
|
||||
found in the params_dict, a warning is printed and None is returned.
|
||||
|
||||
Args:
|
||||
name (str): The original loaded checkpoint parameter name.
|
||||
params_dict (dict): Dictionary containing the model's named parameters.
|
||||
|
||||
Returns:
|
||||
str: The remapped parameter name if successful, or the original name
|
||||
if no remapping is needed.
|
||||
None: If the remapped name is not found in params_dict.
|
||||
"""
|
||||
if name.endswith(".kv_scale"):
|
||||
logger.warning_once(
|
||||
"DEPRECATED. Found kv_scale in the checkpoint. "
|
||||
"This format is deprecated in favor of separate k_scale and "
|
||||
"v_scale tensors and will be removed in a future release. "
|
||||
"Functionally, we will remap kv_scale to k_scale and duplicate "
|
||||
"k_scale to v_scale")
|
||||
# NOTE: we remap the deprecated kv_scale to k_scale
|
||||
remapped_name = name.replace(".kv_scale", ".attn.k_scale")
|
||||
if remapped_name not in params_dict:
|
||||
logger.warning_once(
|
||||
"Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.", # noqa: E501
|
||||
name,
|
||||
remapped_name,
|
||||
)
|
||||
return None
|
||||
return remapped_name
|
||||
|
||||
# Define scale name mapping patterns in order of precedence
|
||||
scale_mapping_patterns = [
|
||||
# ModelOpt format: .self_attn.{k,v}_proj.{k,v}_scale ->
|
||||
# .self_attn.attn.{k,v}_scale
|
||||
(r"\.self_attn\.([kv])_proj\.([kv])_scale$",
|
||||
r".self_attn.attn.\2_scale"),
|
||||
# QKV proj format: .self_attn.qkv_proj.{k,v}_scale ->
|
||||
# .self_attn.attn.{k,v}_scale
|
||||
(r"\.self_attn\.qkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"),
|
||||
# Qwen3 MoE format: .self_attn.qkqkv_proj.{k,v}_scale ->
|
||||
# .self_attn.attn.{k,v}_scale
|
||||
(r"\.self_attn\.qkqkv_proj\.([kv])_scale$", r".self_attn.attn.\1_scale"
|
||||
),
|
||||
# Default format: .{k,v}_scale -> .attn.{k,v}_scale
|
||||
(r"\.([kv])_scale$", r".attn.\1_scale"),
|
||||
]
|
||||
|
||||
# Check if name ends with k_scale or v_scale
|
||||
if name.endswith((".k_scale", ".v_scale")):
|
||||
import regex as re
|
||||
|
||||
for pattern, replacement in scale_mapping_patterns:
|
||||
if re.search(pattern, name):
|
||||
remapped_name = re.sub(pattern, replacement, name)
|
||||
if remapped_name not in params_dict:
|
||||
scale_type = name.split(".")[-1]
|
||||
logger.warning_once(
|
||||
"Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.", # noqa: E501
|
||||
scale_type,
|
||||
name,
|
||||
remapped_name,
|
||||
scale_type,
|
||||
)
|
||||
return None
|
||||
return remapped_name
|
||||
|
||||
# If there were no matches, return the untouched param name
|
||||
return name
|
||||
Reference in New Issue
Block a user