Sync from v0.13
This commit is contained in:
@@ -1,30 +1,150 @@
|
||||
from typing import Optional
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
|
||||
ParallelConfig, SchedulerConfig, VisionLanguageConfig)
|
||||
from vllm.model_executor.model_loader.loader import (BaseModelLoader,
|
||||
get_model_loader)
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.bitsandbytes_loader import BitsAndBytesModelLoader
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
|
||||
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
|
||||
from vllm.model_executor.model_loader.runai_streamer_loader import (
|
||||
RunaiModelStreamerLoader,
|
||||
)
|
||||
from vllm.model_executor.model_loader.sharded_state_loader import ShardedStateLoader
|
||||
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
get_architecture_class_name, get_model_architecture)
|
||||
get_architecture_class_name,
|
||||
get_model_architecture,
|
||||
get_model_cls,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Reminder: Please update docstring in `LoadConfig`
# if a new load format is added here
LoadFormats = Literal[
    "auto",
    "hf",
    "bitsandbytes",
    "dummy",
    "fastsafetensors",
    "gguf",
    "mistral",
    "npcache",
    "pt",
    "runai_streamer",
    "runai_streamer_sharded",
    "safetensors",
    "sharded_state",
    "tensorizer",
]

# Registry mapping each load-format name to the loader class that handles it.
# `register_model_loader` adds to (or overwrites entries in) this table and
# `get_model_loader` reads from it.
_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
    "auto": DefaultModelLoader,
    "hf": DefaultModelLoader,
    "bitsandbytes": BitsAndBytesModelLoader,
    "dummy": DummyModelLoader,
    "fastsafetensors": DefaultModelLoader,
    "gguf": GGUFModelLoader,
    "mistral": DefaultModelLoader,
    "npcache": DefaultModelLoader,
    "pt": DefaultModelLoader,
    "runai_streamer": RunaiModelStreamerLoader,
    "runai_streamer_sharded": ShardedStateLoader,
    "safetensors": DefaultModelLoader,
    "sharded_state": ShardedStateLoader,
    "tensorizer": TensorizerLoader,
}
|
||||
|
||||
|
||||
def register_model_loader(load_format: str):
    """Register a customized vllm model loader.

    When a load format is not supported by vllm, you can register a customized
    model loader to support it.

    Args:
        load_format (str): The model loader format name.

    Returns:
        A class decorator. It stores the decorated class in the loader
        registry under `load_format` and returns the class unchanged.

    Raises:
        ValueError: Raised by the returned decorator when the decorated class
            is not a subclass of `BaseModelLoader`.

    Examples:
        >>> from vllm.config.load import LoadConfig
        >>> from vllm.model_executor.model_loader import (
        ...     get_model_loader,
        ...     register_model_loader,
        ... )
        >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
        >>>
        >>> @register_model_loader("my_loader")
        ... class MyModelLoader(BaseModelLoader):
        ...     def download_model(self, model_config):
        ...         pass
        ...
        ...     def load_weights(self, model, model_config):
        ...         pass
        >>>
        >>> load_config = LoadConfig(load_format="my_loader")
        >>> type(get_model_loader(load_config))
        <class 'MyModelLoader'>
    """  # noqa: E501

    def _wrapper(model_loader_cls):
        # Validate before anything is logged: a class that is about to be
        # rejected must not produce a "will be overwritten" warning.
        if not issubclass(model_loader_cls, BaseModelLoader):
            raise ValueError(
                "The model loader must be a subclass of `BaseModelLoader`."
            )
        if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
            logger.warning(
                "Load format `%s` is already registered, and will be "
                "overwritten by the new loader class `%s`.",
                load_format,
                model_loader_cls,
            )
        _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
        logger.info(
            "Registered model loader `%s` with load format `%s`",
            model_loader_cls,
            load_format,
        )
        return model_loader_cls

    return _wrapper
|
||||
|
||||
|
||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Build the loader registered for ``load_config.load_format``.

    Raises:
        ValueError: If no loader class is registered for the load format.
    """
    requested_format = load_config.load_format
    loader_cls = _LOAD_FORMAT_TO_MODEL_LOADER.get(requested_format)
    if loader_cls is None:
        raise ValueError(f"Load format `{requested_format}` is not supported")
    return loader_cls(load_config)
|
||||
|
||||
|
||||
def get_model(
    *, vllm_config: VllmConfig, model_config: ModelConfig | None = None
) -> nn.Module:
    """Load a model using the loader selected by ``vllm_config.load_config``.

    Args:
        vllm_config: Top-level configuration; its ``load_config`` picks the
            loader and the whole object is forwarded to ``load_model``.
        model_config: Model configuration to load. When ``None``,
            ``vllm_config.model_config`` is used.

    Returns:
        The module returned by the loader's ``load_model``.
    """
    loader = get_model_loader(vllm_config.load_config)
    if model_config is None:
        model_config = vllm_config.model_config
    return loader.load_model(vllm_config=vllm_config, model_config=model_config)
|
||||
|
||||
|
||||
# Public API of the model_loader package: entry points first, then the
# loader classes.
__all__ = [
    "get_model",
    "get_model_loader",
    "get_architecture_class_name",
    "get_model_architecture",
    "get_model_cls",
    "register_model_loader",
    "BaseModelLoader",
    "BitsAndBytesModelLoader",
    "GGUFModelLoader",
    "DefaultModelLoader",
    "DummyModelLoader",
    "RunaiModelStreamerLoader",
    "ShardedStateLoader",
    "TensorizerLoader",
]
|
||||
|
||||
57
vllm/model_executor/model_loader/base_loader.py
Normal file
57
vllm/model_executor/model_loader/base_loader.py
Normal file
@@ -0,0 +1,57 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model,
|
||||
process_weights_after_loading,
|
||||
)
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BaseModelLoader(ABC):
    """Abstract base for model loaders.

    Subclasses implement `download_model` and `load_weights`; `load_model`
    combines model initialization, weight loading and post-load processing.
    """

    def __init__(self, load_config: LoadConfig):
        # Kept on the instance so subclasses can read loader options.
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Download a model so that it can be immediately loaded."""
        raise NotImplementedError

    @abstractmethod
    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load weights into an already-initialized model, in place."""
        raise NotImplementedError

    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig
    ) -> nn.Module:
        """Initialize the model, load its weights and return it in eval mode."""
        device_config = vllm_config.device_config
        load_config = vllm_config.load_config

        # An explicit device on the load config wins over the device config.
        if load_config.device is None:
            load_device = device_config.device
        else:
            load_device = load_config.device
        target_device = torch.device(load_device)

        with set_default_torch_dtype(model_config.dtype):
            # Only model construction runs under the device context.
            with target_device:
                model = initialize_model(
                    vllm_config=vllm_config, model_config=model_config
                )

            logger.debug("Loading weights on %s ...", load_device)
            # Quantization does not happen in `load_weights` but after it
            self.load_weights(model, model_config)
            process_weights_after_loading(model, model_config, target_device)
        return model.eval()
|
||||
822
vllm/model_executor/model_loader/bitsandbytes_loader.py
Normal file
822
vllm/model_executor/model_loader/bitsandbytes_loader.py
Normal file
@@ -0,0 +1,822 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import fnmatch
|
||||
import glob
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from packaging import version
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm.model_executor.layers.linear import (
|
||||
LinearBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import ParamMapping
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf,
|
||||
download_weights_from_hf,
|
||||
filter_duplicate_safetensors_files,
|
||||
filter_files_not_needed_for_inference,
|
||||
pt_weights_iterator,
|
||||
safetensors_weights_iterator,
|
||||
)
|
||||
from vllm.model_executor.models import is_pooling_model
|
||||
from vllm.model_executor.utils import (
|
||||
get_moe_expert_mapping,
|
||||
get_packed_modules_mapping,
|
||||
set_weight_attrs,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def is_moe_model(model: torch.nn.Module) -> bool:
    """Return True if any module yielded by ``model.modules()`` is a FusedMoE."""
    for submodule in model.modules():
        if isinstance(submodule, FusedMoE):
            return True
    return False
|
||||
|
||||
|
||||
class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
"""Model loader to load model weights with BitAndBytes quantization."""
|
||||
|
||||
possible_config_file_names = ["adapter_config.json"]
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
    """Set up empty bookkeeping; it is populated later from the model."""
    super().__init__(load_config)

    # Names of modules whose weights are loaded without sharding.
    self.unsharded_weights_modules: list[str] = []
    # Names of modules whose weights are sharded by column.
    self.column_sharded_weights_modules: list[str] = []
    # Modules whose weights may be stored fused on disk, mapped to their
    # output sizes so they can be sharded correctly under TP.
    self.maybe_fused_weights_modules: dict[str, list[int]] = {}
    # All module names (transformers naming) that support BNB quantization.
    self.target_modules: list[str] = []
    # Names of target modules that have tensor parallelism disabled.
    self.tp_disabled_modules: list[str] = []
    # Expert parameter mapping, only filled for MoE models.
    self.expert_params_mapping: list[tuple[str, str, int, str]] = []
    # Translates transformers weight names to vllm names; identity by default.
    self.weight_mapper: Callable = lambda name: name
    self.pre_quant: bool = False
    self.load_8bit: bool = False
    self.is_pool_model: bool = False
|
||||
|
||||
def _get_weight_files(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
allowed_patterns: list[str],
|
||||
revision: str | None = None,
|
||||
) -> tuple[str, list[str], str]:
|
||||
"""Retrieve weight files. Download the files if necessary.
|
||||
|
||||
Return the weight files and the file pattern."""
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
|
||||
if is_local:
|
||||
for pattern in allowed_patterns:
|
||||
weight_files = glob.glob(os.path.join(model_name_or_path, pattern))
|
||||
if weight_files:
|
||||
return model_name_or_path, weight_files, pattern
|
||||
else:
|
||||
hf_api = HfApi()
|
||||
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
|
||||
for pattern in allowed_patterns:
|
||||
matching_files = fnmatch.filter(repo_files, pattern)
|
||||
if matching_files:
|
||||
hf_folder = download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
[pattern],
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
)
|
||||
return (
|
||||
hf_folder,
|
||||
glob.glob(os.path.join(hf_folder, pattern)),
|
||||
pattern,
|
||||
)
|
||||
|
||||
raise RuntimeError(f"No model weights found in: `{model_name_or_path}`")
|
||||
|
||||
def _prepare_weights(
    self, model_name_or_path: str, revision: str | None
) -> tuple[list[str], bool]:
    """Resolve the weight files to load.

    Returns:
        The filtered list of weight files and whether they are safetensors.

    Raises:
        RuntimeError: If filtering leaves no weight files.
    """
    folder, weight_files, matched_pattern = self._get_weight_files(
        model_name_or_path, ["*.safetensors", "*.bin", "*.pt"], revision
    )
    use_safetensors = matched_pattern == "*.safetensors"
    index_file = SAFE_WEIGHTS_INDEX_NAME

    if not use_safetensors:
        weight_files = filter_files_not_needed_for_inference(weight_files)
    else:
        # For models like Mistral-7B-Instruct-v0.3
        # there are both sharded safetensors files and a consolidated
        # safetensors file. Using both breaks.
        # Here, we download the `model.safetensors.index.json` and filter
        # any files not found in the index.
        if not os.path.isdir(model_name_or_path):
            download_safetensors_index_file_from_hf(
                model_name_or_path,
                index_file,
                self.load_config.download_dir,
                revision,
            )
        weight_files = filter_duplicate_safetensors_files(
            weight_files, folder, index_file
        )

    if not weight_files:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`"
        )

    return weight_files, use_safetensors
|
||||
|
||||
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
    """Yield ``(original_name, mapped_name, tensor)`` for every stored weight.

    ``mapped_name`` is the original name passed through ``self.weight_mapper``
    and, for pooling models, optionally prefixed with ``model.``.
    """

    def _with_pool_prefix(module_name: str):
        # For pool model, we need to add the prefix `model.`
        # for the weight name if possible.
        needs_prefix = (
            self.is_pool_model
            and self.target_modules[0].startswith("model.")
            and not module_name.startswith("model.")
        )
        return "model." + module_name if needs_prefix else module_name

    use_tqdm = self.load_config.use_tqdm_on_load
    if use_safetensors:
        weights = safetensors_weights_iterator(hf_weights_files, use_tqdm)
    else:
        weights = pt_weights_iterator(
            hf_weights_files,
            use_tqdm,
            self.load_config.pt_load_map_location,
        )

    for original_name, tensor in weights:
        # Map transformers names to vllm names while keeping the original.
        yield original_name, _with_pool_prefix(self.weight_mapper(original_name)), tensor
|
||||
|
||||
def _get_quantized_weights_iterator(
    self,
    model_name_or_path: str,
    revision: str | None,
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str, Any]]:
    """Return a weight generator plus the quant-state dict it fills.

    The dict is returned empty and is populated as the generator is consumed.

    Raises:
        ImportError: If bitsandbytes is missing or older than 0.46.1.
    """
    # only load the bitsandbytes module when needed
    try:
        import bitsandbytes

        if version.parse(bitsandbytes.__version__) < version.parse("0.46.1"):
            raise ImportError(
                "bitsandbytes version is wrong. Please "
                "install bitsandbytes>=0.46.1."
            )
    except ImportError as err:
        raise ImportError(
            "Please install bitsandbytes>=0.46.1 via "
            "`pip install bitsandbytes>=0.46.1` to use "
            "bitsandbytes quantizer."
        ) from err

    weight_files, use_safetensors = self._prepare_weights(
        model_name_or_path, revision
    )
    quant_state_dict: dict[str, Any] = {}

    # Choose the generator: quantize on the fly for plain checkpoints,
    # otherwise read the 8-bit or 4-bit pre-quantized layout.
    if not self.pre_quant:
        make_generator = self._unquantized_generator
    elif self.load_8bit:
        make_generator = self._quantized_8bit_generator
    else:
        make_generator = self._quantized_4bit_generator

    return (
        make_generator(weight_files, use_safetensors, quant_state_dict),
        quant_state_dict,
    )
|
||||
|
||||
def _is_8bit_weight_name(self, weight_name: str):
|
||||
quantized_suffix = {".scb", ".weight_format"}
|
||||
return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
|
||||
|
||||
def _is_4bit_weight_name(self, weight_name: str):
|
||||
quantized_suffix = {
|
||||
"absmax",
|
||||
"quant_map",
|
||||
"nested_absmax",
|
||||
"nested_quant_map",
|
||||
"bitsandbytes",
|
||||
}
|
||||
suffix = weight_name.split(".")[-1]
|
||||
return any(q_suffix in suffix for q_suffix in quantized_suffix)
|
||||
|
||||
def _quantized_8bit_generator(
|
||||
self, hf_weights_files, use_safetensors, quant_state_dict
|
||||
) -> Generator:
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if not mapped_weight_name.lower().endswith(".scb"):
|
||||
continue
|
||||
|
||||
weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
|
||||
quant_state_dict[weight_key] = weight_tensor
|
||||
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if self._is_8bit_weight_name(mapped_weight_name):
|
||||
continue
|
||||
|
||||
if mapped_weight_name in quant_state_dict:
|
||||
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
|
||||
yield org_weight_name, weight_tensor
|
||||
else:
|
||||
yield org_weight_name, weight_tensor
|
||||
|
||||
def _quantized_4bit_generator(
    self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
    """Yield ``(original_name, tensor)`` for a pre-quantized 4-bit checkpoint.

    Quant-state tensors are collected in a first pass. In the second pass a
    ``QuantState`` is stored in ``quant_state_dict`` for each weight that has
    an nf4 or fp4 state entry.
    """
    from bitsandbytes.functional import QuantState

    # First iterate over all quant state weights
    temp_state_dict = {}
    for _, mapped_name, tensor in self._hf_weight_iter(
        hf_weights_files, use_safetensors
    ):
        if not self._is_4bit_weight_name(mapped_name):
            continue
        # bitsandbytes library requires
        # weight.quant_state.bitsandbytes__* in CPU
        if "quant_state.bitsandbytes" in mapped_name:
            tensor = tensor.cpu().data
        temp_state_dict[mapped_name] = tensor

    # Closure to parse quant_state for each prequant weight
    def _parse_quant_state(param_name: str, temp_state_dict: dict) -> QuantState:
        prefix = param_name + "."
        entries = {
            key: value for key, value in temp_state_dict.items() if prefix in key
        }
        return QuantState.from_dict(entries, device=current_platform.device_type)

    # Second iterate over all prequant and normal weights
    # pre quantized weights would have a quant_state
    for original_name, mapped_name, tensor in self._hf_weight_iter(
        hf_weights_files, use_safetensors
    ):
        if self._is_4bit_weight_name(mapped_name):
            continue

        nf4_key = f"{mapped_name}.quant_state.bitsandbytes__nf4"
        fp4_key = f"{mapped_name}.quant_state.bitsandbytes__fp4"
        if nf4_key in temp_state_dict or fp4_key in temp_state_dict:
            quant_state_dict[mapped_name] = _parse_quant_state(
                mapped_name, temp_state_dict
            )
        yield original_name, tensor
|
||||
|
||||
def _unquantized_generator(
    self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator:
    """Yield ``(original_name, tensor)``, quantizing target weights to nf4.

    A weight is quantized when its mapped name contains a target module name
    and ends in ``.weight``. Such weights are first sliced for this rank
    according to how their module is sharded, and their quant state is
    stored in ``quant_state_dict``. All other tensors are yielded unchanged.
    """
    from bitsandbytes.functional import quantize_4bit

    global_tp_size = get_tensor_model_parallel_world_size()
    global_tp_rank = get_tensor_model_parallel_rank()

    def _belongs_to(weight_name, module_name):
        return weight_name.removesuffix(".weight") == module_name

    for original_name, mapped_name, weight in self._hf_weight_iter(
        hf_weights_files, use_safetensors
    ):
        # override tp_size and tp_rank if the module has disabled TP
        tp_disabled = any(
            disabled in mapped_name for disabled in self.tp_disabled_modules
        )
        if tp_disabled:
            tp_size, tp_rank = 1, 0
        else:
            tp_size, tp_rank = global_tp_size, global_tp_rank

        is_target = any(target in mapped_name for target in self.target_modules)
        if not (is_target and mapped_name.endswith(".weight")):
            yield original_name, weight
            continue

        # Without sharding
        if any(
            _belongs_to(mapped_name, module)
            for module in self.unsharded_weights_modules
        ):
            shard = weight
        # Shard by column
        elif any(
            _belongs_to(mapped_name, module)
            for module in self.column_sharded_weights_modules
        ):
            per_rank = weight.size(-1) // tp_size
            shard = weight[..., per_rank * tp_rank : per_rank * (tp_rank + 1)]
        # Weights have fused on disk. In this case, we assume that the
        # weight and module use same name.
        elif any(
            _belongs_to(mapped_name, module)
            for module in self.maybe_fused_weights_modules
        ):
            # special case for fused weights
            # get the size of each shard weight tensor
            shard_sizes = next(
                sizes
                for module, sizes in self.maybe_fused_weights_modules.items()
                if _belongs_to(mapped_name, module)
            )
            assert weight.size(0) == sum(shard_sizes)
            # get the start index of each shard weight tensor
            shard_offsets = list(itertools.accumulate([0] + shard_sizes))[:-1]
            # slice and reorder the weight tensor
            pieces = []
            for offset, size in zip(shard_offsets, shard_sizes):
                per_rank = size // tp_size
                begin = offset + per_rank * tp_rank
                end = offset + per_rank * (tp_rank + 1)
                pieces.append(weight[begin:end, ...])
            shard = torch.cat(pieces, dim=0)
        # Shard by row
        else:
            per_rank = weight.size(0) // tp_size
            shard = weight[per_rank * tp_rank : per_rank * (tp_rank + 1), ...]

        # bitsandbytes requires data in GPU
        if shard.is_cuda:
            loaded_weight = shard
        else:
            loaded_weight = shard.to(device=current_platform.device_type)

        # remove the following after the issue is fixed:
        # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
        if loaded_weight.is_contiguous() is False:
            loaded_weight = loaded_weight.contiguous()

        with set_default_torch_dtype(torch.float32):
            processed_weight, quant_state = quantize_4bit(
                loaded_weight,
                compress_statistics=True,
                quant_type="nf4",
            )

        quant_state_dict[mapped_name] = quant_state
        yield original_name, processed_weight
|
||||
|
||||
def _get_bnb_target_modules(self, model: nn.Module) -> None:
|
||||
"""
|
||||
Identify and collect all modules that support BitsAndBytes
|
||||
quantization.
|
||||
"""
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LinearBase) and hasattr(
|
||||
module.quant_method, "quant_config"
|
||||
):
|
||||
if modules_info := self.modules_mapping.get_sub_modules(name):
|
||||
# Map vllm's names to transformers's names.
|
||||
rep_name, sub_modules = modules_info
|
||||
for sub_name in sub_modules:
|
||||
new_name = name.replace(rep_name, sub_name)
|
||||
self.target_modules.append(new_name)
|
||||
if module.disable_tp:
|
||||
self.tp_disabled_modules.append(new_name)
|
||||
# Add original module name even if the module has stacked map,
|
||||
# in case model has a mixture of disk-merged and disk-split
|
||||
# weights with same last name.
|
||||
self.target_modules.append(name)
|
||||
if module.disable_tp:
|
||||
self.tp_disabled_modules.append(name)
|
||||
elif isinstance(module, FusedMoE) and hasattr(
|
||||
module.quant_method, "quant_config"
|
||||
):
|
||||
# TODO: support FusedMoE with prequant and 8bit.
|
||||
if self.pre_quant and self.load_8bit:
|
||||
raise ValueError(
|
||||
"Prequant BitsAndBytes 8bit models with FusedMoE "
|
||||
"is not supported yet."
|
||||
)
|
||||
# Get the corresponding weight name using module name and
|
||||
# expert_params_mapping.
|
||||
|
||||
for exp in self.expert_params_mapping:
|
||||
weight_name = exp[1]
|
||||
rep_name = name.replace("experts", "") + weight_name.removesuffix(
|
||||
"."
|
||||
)
|
||||
self.target_modules.append(rep_name)
|
||||
|
||||
assert self.target_modules, (
|
||||
"vLLM currently does not support BNB quantization for"
|
||||
)
|
||||
f" {type(model).__name__}"
|
||||
|
||||
def _classify_module_sharding(self, model: nn.Module):
    """
    Sort the model's modules into the sharding buckets kept on the loader
    (unsharded, maybe-fused, column-sharded) for tensor parallelism.
    """
    for name, module in model.named_modules():
        # Some modules like `ReplicatedLinear` should not have their weights
        # sharded. The reason for implementing it this way is to avoid new
        # static variable in the model implementation.
        if isinstance(module, ReplicatedLinear):
            self.unsharded_weights_modules.append(name)
        # `QKVParallelLinear` and `MergedColumnParallelLinear` might have
        # fused weights on disk. We need to use the output sizes of these
        # modules to shard the weights correctly.
        elif isinstance(module, (QKVParallelLinear, MergedColumnParallelLinear)):
            self.maybe_fused_weights_modules[name] = module.output_sizes
        # In TP, these weights are partitioned along the column
        # dimension (dim=-1)
        elif isinstance(module, RowParallelLinear):
            self.column_sharded_weights_modules.append(name)
        elif isinstance(module, FusedMoE):
            # Only the `w2` expert weights are recorded as column-sharded.
            layer_prefix = name.replace("experts", "")
            self.column_sharded_weights_modules.extend(
                layer_prefix + exp[1].removesuffix(".")
                for exp in self.expert_params_mapping
                if exp[-1] == "w2"
            )
|
||||
|
||||
def _verify_model_compatibility(
    self, model: nn.Module, model_config: ModelConfig
) -> None:
    """
    Check that the model can be loaded with BitsAndBytes and record whether
    the checkpoint is pre-quantized and, if so, whether it is 8-bit.

    Raises:
        AttributeError: If the model lacks ``load_weights`` or
            ``packed_modules_mapping``.
        ValueError: If the checkpoint uses another quantization method, or
            is pre-quantized while tensor parallelism is enabled.
    """
    model_name = type(model).__name__
    if not hasattr(model, "load_weights"):
        raise AttributeError(
            "The required method 'load_weights' is not defined in class"
            f" {model_name}."
        )

    if not hasattr(model, "packed_modules_mapping"):
        raise AttributeError(
            f"Model {model_name} does not support BitsAndBytes "
            "quantization yet. No 'packed_modules_mapping' found."
        )

    quant_config = getattr(model_config.hf_config, "quantization_config", None)
    quant_method = quant_config.get("quant_method") if quant_config else None
    if quant_method:
        if quant_method != "bitsandbytes":
            raise ValueError(
                f"BitsAndBytes loader does not support {quant_method} quantization"
            )
        self.pre_quant = True

    if self.pre_quant:
        # The quant_states in pre_quantized models cannot work with a split
        # weight tensor. So TP does not work with pre_quantized bnb models.
        if get_tensor_model_parallel_world_size() > 1:
            raise ValueError(
                "Prequant BitsAndBytes models with tensor parallelism is not "
                "supported. Please try with pipeline parallelism."
            )
        if quant_config:
            self.load_8bit = quant_config.get("load_in_8bit", False)
|
||||
|
||||
def _initialize_loader_state(
    self, model: nn.Module, model_config: ModelConfig
) -> None:
    """
    Populate the loader's bookkeeping from the model: pooling flag, packed
    module mapping, MoE expert mapping, weight-name mapper, target modules
    and sharding classification.

    Raises:
        AttributeError: If the model is MoE but has no expert mapping.
    """
    self.is_pool_model = is_pooling_model(model)
    self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))

    if is_moe_model(model):
        self.expert_params_mapping = get_moe_expert_mapping(model)
        if not self.expert_params_mapping:
            raise AttributeError(
                f"MoE Model {type(model).__name__} does not support "
                "BitsAndBytes quantization yet. Ensure this model has "
                "'get_expert_mapping' method."
            )

    # For some models like Molmo, we need to use hf_to_vllm_mapper
    # to ensure correct loading of weights.
    name_mapper = getattr(model, "hf_to_vllm_mapper", None)
    if name_mapper:
        self.weight_mapper = lambda name: name_mapper._map_name(name)

    self._get_bnb_target_modules(model)
    self._classify_module_sharding(model)
|
||||
|
||||
def _dequantize_dq(self, quant_states: Any):
    """
    When BNB employs Double Quantization, we perform the dequantization of
    these constants during weight loading rather than at inference time,
    thereby avoiding this computational overhead during inference. This
    comes at the cost of increased memory usage.

    Accepts either one quant state or a dict of them. Nested ``QuantState``
    objects are updated in place and the argument is returned.
    """
    from bitsandbytes.functional import QuantState, dequantize_blockwise

    if isinstance(quant_states, dict):
        candidates = quant_states.values()
    else:
        candidates = (quant_states,)

    for state in candidates:
        # Only nested (double-quantized) QuantState objects need work.
        if not isinstance(state, QuantState) or not state.nested:
            continue

        # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
        absmax = dequantize_blockwise(state.absmax, state.state2)
        absmax += state.offset

        # Ensure float32 dtype
        if absmax.dtype != torch.float32:
            absmax = absmax.float()

        state.absmax = absmax
        state.nested = False
        state.offset = None
        state.state2 = None

    return quant_states
|
||||
|
||||
def _fuse_moe_quant_states(self, model: nn.Module, quant_states_dict: dict) -> dict:
    """
    Consolidate individual expert quantization states into fused
    representations for w13 and w2.

    For every `FusedMoE` module in `model`, the per-expert states named by
    `self.expert_params_mapping` are looked up in `quant_states_dict`,
    passed through `_dequantize_dq`, and merged into one state for
    `<module>.w13_weight` and one for `<module>.w2_weight`.

    Args:
        model: The model whose `FusedMoE` modules are processed.
        quant_states_dict: Quantization states keyed by weight name.
            Entries consumed here are deleted from this dict.

    Returns:
        A dict mapping the fused weight names to their new `QuantState`;
        empty when `self.expert_params_mapping` is falsy.

    Raises:
        ValueError: If a mapping entry ends in a shard id other than
            "w1", "w2" or "w3".
    """
    from bitsandbytes.functional import QuantState

    if not self.expert_params_mapping:
        return {}

    def _build_fused_state(absmax, total_rows, template):
        # Everything except absmax/shape is copied from the first expert's
        # state; the quant type is fixed to "nf4".
        return QuantState(
            absmax=absmax,
            shape=(total_rows, template.shape[1]),
            code=template.code,
            blocksize=template.blocksize,
            quant_type="nf4",
            dtype=template.dtype,
        )

    expert_mapping = self.expert_params_mapping
    valid_shards = ("w1", "w2", "w3")
    fused_states = {}

    for module_name, module in model.named_modules():
        if not isinstance(module, FusedMoE):
            continue

        # The prefix only depends on the module, not on the expert entry.
        layer_prefix = module_name.split("experts")[0]
        states_by_shard = {shard: [] for shard in valid_shards}

        for mapping_entry in expert_mapping:
            shard_id = mapping_entry[-1]
            if shard_id not in valid_shards:
                raise ValueError(
                    f"shard_id must be ['w1','w2','w3'] but got {shard_id}."
                )
            weight_key = layer_prefix + mapping_entry[1] + "weight"
            states_by_shard[shard_id].append(
                self._dequantize_dq(quant_states_dict[weight_key])
            )
            # Consumed entries are removed from the caller's dict.
            del quant_states_dict[weight_key]

        w1_states = states_by_shard["w1"]
        w2_states = states_by_shard["w2"]
        w3_states = states_by_shard["w3"]
        assert len(w1_states) == len(w2_states) == len(w3_states)

        w13_absmax_parts = []
        w2_absmax_parts = []
        w13_rows = 0
        w2_rows = 0
        for qs_w1, qs_w2, qs_w3 in zip(w1_states, w2_states, w3_states):
            assert qs_w1.shape == qs_w3.shape
            assert qs_w1.blocksize == qs_w2.blocksize == qs_w3.blocksize
            assert qs_w1.dtype == qs_w2.dtype == qs_w3.dtype
            # w1 and w3 are interleaved in storage
            w13_absmax_parts += [qs_w1.absmax, qs_w3.absmax]
            w2_absmax_parts.append(qs_w2.absmax)
            w13_rows += qs_w1.shape[0] + qs_w3.shape[0]
            w2_rows += qs_w2.shape[0]

        w13_absmax = torch.cat(w13_absmax_parts)
        w2_absmax = torch.cat(w2_absmax_parts)

        fused_w13 = _build_fused_state(w13_absmax, w13_rows, w1_states[0])
        fused_w2 = _build_fused_state(w2_absmax, w2_rows, w2_states[0])

        # The weight suffixes .w13_weight and .w2_weight are consistent
        # with the param in BitsAndBytesMoEMethod.
        fused_states[module_name + ".w13_weight"] = fused_w13
        fused_states[module_name + ".w2_weight"] = fused_w2

    return fused_states
|
||||
|
||||
def _stack_quantization_states(
    self, model: nn.Module, quant_state_dict: dict
) -> dict[str, dict[int, Any]]:
    """
    Group quantization states by the model parameter they belong to.

    A checkpoint name that is not itself a model parameter is rewritten
    using `self.modules_mapping.inverse_packed_mapping` (shard name ->
    (packed weight name, shard index)); if the rewritten name is a model
    parameter, the state is filed under that parameter at the given shard
    index. States whose (possibly rewritten) name is not a model parameter
    are dropped.

    Args:
        model: The model providing the parameter names.
        quant_state_dict: Quantization states keyed by checkpoint name.

    Returns:
        A dict of parameter name -> {shard index -> quantization state}.
    """
    grouped: dict[str, dict[int, Any]] = {}
    # TODO: Change this lazy import to normal import
    # after the checks are updated to run on a new version
    from vllm.model_executor.models.utils import is_pp_missing_parameter

    known_params = {param_name for param_name, _ in model.named_parameters()}

    for source_name in quant_state_dict:
        if is_pp_missing_parameter(source_name, model):
            continue

        target_name = source_name
        shard_index = 0

        # If the quant_param_name is packed, it won't occur in the
        # param_dict before renaming; only then is a rename attempted.
        if source_name not in known_params:
            packed_mapping = self.modules_mapping.inverse_packed_mapping
            for shard_name, (packed_name, packed_index) in packed_mapping.items():
                # Some models, such as MiniCPM V2.5/2.6, contain both
                # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
                # from being incorrectly identified as being present in
                # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight', the
                # match must be preceded by a dot.
                match_pos = source_name.find(shard_name)
                if match_pos <= 0 or source_name[match_pos - 1] != ".":
                    continue
                candidate = source_name.replace(shard_name, packed_name)
                if candidate in known_params:
                    target_name = candidate
                    shard_index = packed_index
                    break

        # Models like Clip/Siglip may skip some layers in initialization,
        # causing unused quant_param_name in state_dict.
        if target_name not in known_params:
            continue

        grouped.setdefault(target_name, {})[shard_index] = quant_state_dict[
            source_name
        ]

    return grouped
|
||||
|
||||
def _bind_quant_states_to_params(
    self, model: nn.Module, stacked_quant_state_dict: dict
) -> None:
    """
    Save quant states and shard offsets as attributes of the parameters.

    Every parameter with an entry in `stacked_quant_state_dict` gets a
    `bnb_quant_state` attribute (after `_dequantize_dq` has been applied).
    When that entry is a dict of per-shard states, a `bnb_shard_offsets`
    tensor is attached as well, plus a `matmul_state` list when
    `self.load_8bit` is set.

    Args:
        model: The model whose parameters receive the attributes.
        stacked_quant_state_dict: Parameter name -> quantization state, or
            parameter name -> {shard index -> quantization state}.

    Raises:
        ValueError: If a parameter with per-shard states has no
            `pack_factor` attribute.
    """
    for param_name, param in dict(model.named_parameters()).items():
        if param_name not in stacked_quant_state_dict:
            continue

        quant_states = stacked_quant_state_dict[param_name]
        # Dequantize double quantized values during weight loading.
        self._dequantize_dq(quant_states)
        set_weight_attrs(param, {"bnb_quant_state": quant_states})

        # Only dict-valued (per-shard) entries need offsets.
        if not isinstance(quant_states, dict):
            continue

        pack_ratio = getattr(param, "pack_factor", -1)
        if pack_ratio == -1:
            raise ValueError(f"pack_factor not set for parameter {param_name}.")

        shard_count = len(quant_states)
        packed_sizes = [0] * shard_count
        for shard_index, shard_state in quant_states.items():
            packed_sizes[shard_index] = math.prod(shard_state.shape) // pack_ratio

        # Offsets are the running totals of the shard sizes, starting at 0.
        boundaries = np.concatenate(([0], np.cumsum(packed_sizes)))
        # Make torch infer_schema happy
        set_weight_attrs(
            param, {"bnb_shard_offsets": torch.tensor(boundaries).cpu()}
        )

        if self.load_8bit:
            set_weight_attrs(param, {"matmul_state": [None] * shard_count})
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
    """
    Load BitsAndBytes-quantized weights into `model` and attach the
    quantization states to its parameters.

    Args:
        model: The model to populate.
        model_config: Supplies the model name/path and revision.

    Raises:
        ValueError: If the model reports which weights it loaded and some
            of its parameters are not among them.
    """
    self._verify_model_compatibility(model, model_config)
    self._initialize_loader_state(model, model_config)

    logger.info(
        "Loading weights with BitsAndBytes quantization. May take a while ..."
    )
    weight_stream, quant_state_dict = self._get_quantized_weights_iterator(
        model_config.model,
        model_config.revision,
    )

    # Snapshot the expected names before loading so they can be compared
    # against what the model reports afterwards.
    expected_names = {param_name for param_name, _ in model.named_parameters()}
    loaded_names = model.load_weights(weight_stream)

    # Some models may have weights loading tracker unimplemented.
    if loaded_names is not None:
        weights_not_loaded = expected_names - loaded_names
        if weights_not_loaded:
            raise ValueError(
                "Following weights were not initialized from "
                f"checkpoint: {weights_not_loaded}"
            )

    # Fusing runs first: it removes the expert entries it consumes from
    # quant_state_dict before the remaining entries are stacked.
    fused_expert_states = self._fuse_moe_quant_states(model, quant_state_dict)
    stacked_states = self._stack_quantization_states(model, quant_state_dict)

    # On a key clash the stacked entry wins over the fused expert entry.
    combined_states = {**fused_expert_states, **stacked_states}
    self._bind_quant_states_to_params(model, combined_states)
    torch.cuda.empty_cache()
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
    """
    Prepare the weights for `model_config.model` at `model_config.revision`
    by delegating to `_prepare_weights`; the result is discarded.
    """
    model_ref, revision = model_config.model, model_config.revision
    self._prepare_weights(model_ref, revision)
|
||||
321
vllm/model_executor/model_loader/default_loader.py
Normal file
321
vllm/model_executor/model_loader/default_loader.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import glob
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf,
|
||||
download_weights_from_hf,
|
||||
fastsafetensors_weights_iterator,
|
||||
filter_duplicate_safetensors_files,
|
||||
filter_files_not_needed_for_inference,
|
||||
get_quant_config,
|
||||
maybe_download_from_modelscope,
|
||||
multi_thread_pt_weights_iterator,
|
||||
multi_thread_safetensors_weights_iterator,
|
||||
np_cache_weights_iterator,
|
||||
pt_weights_iterator,
|
||||
safetensors_weights_iterator,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DefaultModelLoader(BaseModelLoader):
|
||||
"""Model loader that can load different file types from disk."""
|
||||
|
||||
# default number of thread when enable multithread weight loading
|
||||
DEFAULT_NUM_THREADS = 8
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Source:
|
||||
"""A source for weights."""
|
||||
|
||||
model_or_path: str
|
||||
"""The model ID or path."""
|
||||
|
||||
revision: str | None
|
||||
"""The optional model revision."""
|
||||
|
||||
prefix: str = ""
|
||||
"""A prefix to prepend to all weights."""
|
||||
|
||||
fall_back_to_pt: bool = True
|
||||
"""Whether .pt weights can be used."""
|
||||
|
||||
allow_patterns_overrides: list[str] | None = None
|
||||
"""If defined, weights will load exclusively using these patterns."""
|
||||
|
||||
counter_before_loading_weights: float = 0.0
|
||||
counter_after_loading_weights: float = 0.0
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
|
||||
extra_config = load_config.model_loader_extra_config
|
||||
allowed_keys = {"enable_multithread_load", "num_threads"}
|
||||
unexpected_keys = set(extra_config.keys()) - allowed_keys
|
||||
|
||||
if unexpected_keys:
|
||||
raise ValueError(
|
||||
f"Unexpected extra config keys for load format "
|
||||
f"{load_config.load_format}: "
|
||||
f"{unexpected_keys}"
|
||||
)
|
||||
|
||||
def _prepare_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
revision: str | None,
|
||||
fall_back_to_pt: bool,
|
||||
allow_patterns_overrides: list[str] | None,
|
||||
) -> tuple[str, list[str], bool]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
model_name_or_path = (
|
||||
maybe_download_from_modelscope(model_name_or_path, revision)
|
||||
or model_name_or_path
|
||||
)
|
||||
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
load_format = self.load_config.load_format
|
||||
use_safetensors = False
|
||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
# First check for 'auto' format that mistral files format are present.
|
||||
# This is to load mistral models with official format by default.
|
||||
if load_format == "auto":
|
||||
load_format = (
|
||||
"mistral"
|
||||
if len(
|
||||
list_filtered_repo_files(
|
||||
model_name_or_path=model_name_or_path,
|
||||
allow_patterns=["consolidated*.safetensors"],
|
||||
revision=revision,
|
||||
)
|
||||
)
|
||||
> 0
|
||||
else "hf"
|
||||
)
|
||||
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
if load_format == "hf":
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif load_format == "safetensors" or load_format == "fastsafetensors":
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == "mistral":
|
||||
use_safetensors = True
|
||||
allow_patterns = ["consolidated*.safetensors"]
|
||||
index_file = "consolidated.safetensors.index.json"
|
||||
elif load_format == "pt":
|
||||
allow_patterns = ["*.pt"]
|
||||
elif load_format == "npcache":
|
||||
allow_patterns = ["*.bin"]
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
|
||||
if fall_back_to_pt:
|
||||
allow_patterns += ["*.pt"]
|
||||
|
||||
if allow_patterns_overrides is not None:
|
||||
allow_patterns = allow_patterns_overrides
|
||||
|
||||
if not is_local:
|
||||
hf_folder = download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
allow_patterns,
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
hf_weights_files: list[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
break
|
||||
|
||||
if use_safetensors:
|
||||
# For models like Mistral-7B-Instruct-v0.3
|
||||
# there are both sharded safetensors files and a consolidated
|
||||
# safetensors file. Using both breaks.
|
||||
# Here, we download the `model.safetensors.index.json` and filter
|
||||
# any files not found in the index.
|
||||
if not is_local:
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path,
|
||||
index_file,
|
||||
self.load_config.download_dir,
|
||||
revision,
|
||||
)
|
||||
hf_weights_files = filter_duplicate_safetensors_files(
|
||||
hf_weights_files, hf_folder, index_file
|
||||
)
|
||||
else:
|
||||
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{model_name_or_path}`"
|
||||
)
|
||||
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, source: "Source"
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
extra_config = self.load_config.model_loader_extra_config
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
source.model_or_path,
|
||||
source.revision,
|
||||
source.fall_back_to_pt,
|
||||
source.allow_patterns_overrides,
|
||||
)
|
||||
if self.load_config.load_format == "npcache":
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
weights_iterator = np_cache_weights_iterator(
|
||||
source.model_or_path,
|
||||
self.load_config.download_dir,
|
||||
hf_folder,
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
elif use_safetensors:
|
||||
if self.load_config.load_format == "fastsafetensors":
|
||||
weights_iterator = fastsafetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
else:
|
||||
if extra_config.get("enable_multithread_load"):
|
||||
weights_iterator = multi_thread_safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
max_workers=extra_config.get(
|
||||
"num_threads", self.DEFAULT_NUM_THREADS
|
||||
),
|
||||
)
|
||||
else:
|
||||
weights_iterator = safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.safetensors_load_strategy,
|
||||
)
|
||||
else:
|
||||
if extra_config.get("enable_multithread_load"):
|
||||
weights_iterator = multi_thread_pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.pt_load_map_location,
|
||||
max_workers=extra_config.get(
|
||||
"num_threads", self.DEFAULT_NUM_THREADS
|
||||
),
|
||||
)
|
||||
else:
|
||||
weights_iterator = pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.pt_load_map_location,
|
||||
)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
from vllm.platforms.tpu import USE_TPU_INFERENCE
|
||||
|
||||
if not USE_TPU_INFERENCE:
|
||||
# In PyTorch XLA, we should call `torch_xla.sync`
|
||||
# frequently so that not too many ops are accumulated
|
||||
# in the XLA program.
|
||||
import torch_xla
|
||||
|
||||
def _xla_weights_iterator(iterator: Generator):
|
||||
for weights in iterator:
|
||||
yield weights
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
weights_iterator = _xla_weights_iterator(weights_iterator)
|
||||
|
||||
if self.counter_before_loading_weights == 0.0:
|
||||
self.counter_before_loading_weights = time.perf_counter()
|
||||
# Apply the prefix.
|
||||
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
|
||||
|
||||
def get_all_weights(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
model: nn.Module,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
primary_weights = DefaultModelLoader.Source(
|
||||
model_config.model,
|
||||
model_config.revision,
|
||||
prefix="",
|
||||
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
|
||||
allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
|
||||
)
|
||||
yield from self._get_weights_iterator(primary_weights)
|
||||
|
||||
secondary_weights = cast(
|
||||
Iterable[DefaultModelLoader.Source],
|
||||
getattr(model, "secondary_weights", ()),
|
||||
)
|
||||
for source in secondary_weights:
|
||||
yield from self._get_weights_iterator(source)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(
|
||||
model_config.model,
|
||||
model_config.revision,
|
||||
fall_back_to_pt=True,
|
||||
allow_patterns_overrides=None,
|
||||
)
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
if model_config.quantization == "torchao":
|
||||
quant_config = get_quant_config(model_config, self.load_config)
|
||||
if (
|
||||
hasattr(quant_config, "is_checkpoint_torchao_serialized")
|
||||
and quant_config.is_checkpoint_torchao_serialized
|
||||
and torchao_version_at_least("0.15.0")
|
||||
):
|
||||
self.load_config.safetensors_load_strategy = "torchao"
|
||||
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
|
||||
|
||||
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info_once(
|
||||
"Loading weights took %.2f seconds",
|
||||
self.counter_after_loading_weights - self.counter_before_loading_weights,
|
||||
scope="local",
|
||||
)
|
||||
# We only enable strict check for non-quantized models
|
||||
# that have loaded weights tracking currently.
|
||||
if model_config.quantization is None and loaded_weights is not None:
|
||||
weights_not_loaded = weights_to_load - loaded_weights
|
||||
if weights_not_loaded:
|
||||
raise ValueError(
|
||||
"Following weights were not initialized from "
|
||||
f"checkpoint: {weights_not_loaded}"
|
||||
)
|
||||
28
vllm/model_executor/model_loader/dummy_loader.py
Normal file
28
vllm/model_executor/model_loader/dummy_loader.py
Normal file
@@ -0,0 +1,28 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
|
||||
|
||||
|
||||
class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        """Store the load config; this loader accepts no extra config.

        Raises:
            ValueError: If `model_loader_extra_config` is non-empty.
        """
        super().__init__(load_config)
        if not load_config.model_loader_extra_config:
            return
        raise ValueError(
            f"Model loader extra config is not supported for "
            f"load format {load_config.load_format}"
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """Do nothing: there is nothing to download."""
        return None

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Fill `model` via `initialize_dummy_weights`."""
        # NOTE(woosuk): For accurate performance evaluation, we assign
        # random values to the weights.
        initialize_dummy_weights(model)
|
||||
371
vllm/model_executor/model_loader/gguf_loader.py
Normal file
371
vllm/model_executor/model_loader/gguf_loader.py
Normal file
@@ -0,0 +1,371 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import gguf
|
||||
import regex as re
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model,
|
||||
process_weights_after_loading,
|
||||
)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_gguf,
|
||||
get_gguf_extra_tensor_names,
|
||||
get_gguf_weight_type_map,
|
||||
gguf_quant_weights_iterator,
|
||||
)
|
||||
from vllm.transformers_utils.gguf_utils import detect_gguf_multimodal
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class GGUFModelLoader(BaseModelLoader):
|
||||
"""
|
||||
Model loader that can load GGUF files. This is useful for loading models
|
||||
that are quantized with GGUF and saved in the GGUF format. This loader
|
||||
supports loading both full models and sharded models.
|
||||
"""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
    """Store the load config; this loader accepts no extra config.

    Raises:
        ValueError: If `model_loader_extra_config` is non-empty.
    """
    super().__init__(load_config)
    if not load_config.model_loader_extra_config:
        return
    raise ValueError(
        f"Model loader extra config is not supported for "
        f"load format {load_config.load_format}"
    )
|
||||
|
||||
def _prepare_weights(self, model_config: ModelConfig):
|
||||
model_name_or_path = model_config.model
|
||||
if os.path.isfile(model_name_or_path):
|
||||
return model_name_or_path
|
||||
# for raw HTTPS link
|
||||
if model_name_or_path.startswith(
|
||||
("http://", "https://")
|
||||
) and model_name_or_path.endswith(".gguf"):
|
||||
return hf_hub_download(url=model_name_or_path)
|
||||
# repo id/filename.gguf
|
||||
if "/" in model_name_or_path and model_name_or_path.endswith(".gguf"):
|
||||
repo_id, filename = model_name_or_path.rsplit("/", 1)
|
||||
return hf_hub_download(repo_id=repo_id, filename=filename)
|
||||
# repo_id:quant_type
|
||||
elif "/" in model_name_or_path and ":" in model_name_or_path:
|
||||
repo_id, quant_type = model_name_or_path.rsplit(":", 1)
|
||||
return download_gguf(
|
||||
repo_id,
|
||||
quant_type,
|
||||
cache_dir=self.load_config.download_dir,
|
||||
revision=model_config.revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
)
|
||||
|
||||
raise ValueError(
|
||||
f"Unrecognised GGUF reference: {model_name_or_path} "
|
||||
"(expected local file, raw URL, <repo_id>/<filename>.gguf, "
|
||||
"or <repo_id>:<quant_type>)"
|
||||
)
|
||||
|
||||
def _get_gguf_weights_map(self, model_config: ModelConfig) -> dict[str, str]:
    """
    Build the mapping from GGUF tensor names to HF parameter names.

    GGUF uses this naming convention for their tensors from HF checkpoint:
    `blk.N.BB.weight` and `blk.N.BB.bias`
    where N signifies the block number of a layer, and BB signifies the
    attention/mlp layer components.
    See "Standardized tensor names" in
    https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

    The parameter names are taken from a model instantiated from
    `model_config.hf_config` under `torch.device("meta")`.

    Args:
        model_config: Supplies `hf_config` and `trust_remote_code`.

    Returns:
        A dict mapping each GGUF tensor name (with suffix) to the
        corresponding HF parameter name.

    Raises:
        RuntimeError: If the model type has no entry in
            `gguf.MODEL_ARCH_NAMES`, or if some parameters can be neither
            mapped nor matched by a sideload pattern.
    """
    config = model_config.hf_config
    # Get text config to handle both nested (multimodal) and flat
    # (text-only) config structures. For multimodal models like
    # Gemma3Config, this returns config.text_config. For text-only
    # models, this returns config itself.
    text_config = config.get_text_config()
    model_type = config.model_type
    # Multimodal means a `vision_config` attribute that exists AND is set.
    is_multimodal = (
        hasattr(config, "vision_config") and config.vision_config is not None
    )
    gguf_to_hf_name_map = {}
    # Patterns for parameters that are allowed to stay unmapped (see the
    # filter near the end of this method).
    sideload_params: list[re.Pattern] = []
    # hack: ggufs have a different name than transformers
    if model_type == "cohere":
        model_type = "command-r"
    if model_type == "gemma3_text":
        # Gemma3 models use "gemma3_text" in HuggingFace but
        # "gemma3" in GGUF architecture naming
        model_type = "gemma3"
    if model_type in ("deepseek_v3", "deepseek_v2"):
        model_type = "deepseek2"
        # GGUF layer map assumes that we will have a merged expert weights
        # so we need to map them manually
        for idx in range(config.num_hidden_layers):
            gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = (
                f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
            )
            # Only expert 0 is mapped above; the per-expert names of this
            # layer are exempted from the "must be mapped" check instead.
            sideload_params.append(
                re.compile(
                    f"model\\.layers\\.{idx}"
                    r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
                )
            )
    if model_type in ("qwen2_moe", "qwen3_moe"):
        # "qwen2_moe" -> "qwen2moe", "qwen3_moe" -> "qwen3moe"
        model_type = model_type.replace("_", "")
        # GGUF layer map assumes that we will have a merged expert weights
        # so we need to map them manually
        for idx in range(config.num_hidden_layers):
            gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
            )
            gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = (
                f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
            )
            sideload_params.append(
                re.compile(
                    f"model\\.layers\\.{idx}"
                    r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight"
                )
            )

    # Reverse lookup: find the gguf architecture key whose name equals the
    # (possibly rewritten) model type.
    arch = None
    for key, value in gguf.MODEL_ARCH_NAMES.items():
        if value == model_type:
            arch = key
            break
    if arch is None:
        raise RuntimeError(f"Unknown gguf model_type: {model_type}")
    text_num_layers = text_config.num_hidden_layers
    text_name_map = gguf.get_tensor_name_map(arch, text_num_layers)

    if is_multimodal:
        mm_proj_arch = gguf.MODEL_ARCH.MMPROJ
        vision_num_layers = config.vision_config.num_hidden_layers
        vision_name_map = gguf.get_tensor_name_map(mm_proj_arch, vision_num_layers)
    else:
        vision_name_map = None

    # Create dummy model to extract parameter names
    # For multimodal: use AutoModelForImageTextToText to get
    # language + vision + projector params
    # For text-only: use AutoModelForCausalLM to get language model params
    auto_cls = (
        AutoModelForImageTextToText if is_multimodal else AutoModelForCausalLM
    )
    with torch.device("meta"):
        dummy_model = auto_cls.from_config(
            config, trust_remote_code=model_config.trust_remote_code
        )

    state_dict = dummy_model.state_dict()
    # If the model exposes a `_checkpoint_conversion_mapping`, undo it so
    # the names below are the pre-conversion ones.
    if hf_checkpoint_map := getattr(
        dummy_model, "_checkpoint_conversion_mapping", None
    ):

        def revert_hf_rename(name: str) -> str:
            """Replace each mapped HF name found in `name` by its original."""
            for original_name, hf_name in hf_checkpoint_map.items():
                if hf_name in name:
                    # Any leading "^" characters are stripped from the result.
                    name = name.replace(hf_name, original_name).lstrip("^")
            return name

        state_dict = {
            revert_hf_rename(name): tensor for name, tensor in state_dict.items()
        }

    def find_hf_name_in_tensor_map(hf_name: str) -> str | None:
        """
        Map HuggingFace parameter name to GGUF tensor name.

        This function handles the mismatch between HF parameter naming
        conventions and gguf-py's expected format:
        1. Strips a leading 'language_model.' prefix
        2. Converts '_weight' suffix to '.weight' (Gemma3 compatibility)
        3. Searches vision_name_map for multimodal parameters
        4. Falls back to text_name_map for language model parameters

        Args:
            hf_name: Full HuggingFace parameter name (e.g.,
                'model.multi_modal_projector.mm_soft_emb_norm.weight')

        Returns:
            GGUF tensor name with suffix (e.g., 'mm.soft_emb_norm.weight')
            or None if no mapping found
        """
        # Strip 'language_model.' prefix for multimodal models - gguf-py
        # tensor mappings expect parameter names without this prefix.
        # Note: 'model.' prefix should be KEPT for text-only models as
        # gguf-py expects it.
        if hf_name.startswith("language_model."):
            hf_name = hf_name[15:]  # Remove 'language_model.'

        # Parse parameter name and suffix
        if hf_name.endswith((".weight", ".bias")):
            base_name, suffix = hf_name.rsplit(".", 1)
        else:
            base_name, suffix = hf_name, ""
            # Handle '_weight' suffix (Gemma3 naming: parameter ends with
            # '_weight' instead of '.weight')
            if base_name.endswith("_weight"):
                base_name = base_name[:-7]  # Remove '_weight'
                suffix = "weight"

        gguf_name = None
        # Priority 1: Search vision/projector parameters for multimodal models
        if vision_name_map is not None:
            gguf_name = vision_name_map.get_name(base_name)

        # Priority 2: Search text backbone parameters
        if gguf_name is None:
            gguf_name = text_name_map.get_name(base_name)

        if gguf_name is None:
            return None

        # NOTE(review): when `suffix` is "" this yields a name ending in
        # "."; confirm whether such names can actually be mapped.
        return gguf_name + "." + suffix

    # Build mapping and track unmapped parameters
    unmapped_params = []
    for hf_name in state_dict:
        gguf_name_with_suffix = find_hf_name_in_tensor_map(hf_name)

        # Track mapping success
        if gguf_name_with_suffix is not None:
            gguf_to_hf_name_map[gguf_name_with_suffix] = hf_name
            logger.debug("Mapped GGUF %s → HF %s", gguf_name_with_suffix, hf_name)
        elif hf_name not in gguf_to_hf_name_map.values():
            # Parameter not in manual overrides either
            unmapped_params.append(hf_name)

    # All parameters (except those initialized by other means) must be mapped:
    # both vision/projector and backbone
    if unmapped_params:
        # Drop the names that fully match one of the sideload patterns.
        unmapped_params = list(
            filter(
                lambda x: not any(re.fullmatch(p, x) for p in sideload_params),
                unmapped_params,
            )
        )
    if unmapped_params:
        raise RuntimeError(
            f"Failed to map GGUF parameters "
            f"({len(unmapped_params)}): "
            f"{unmapped_params}"
        )
    return gguf_to_hf_name_map
|
||||
|
||||
def _get_gguf_weight_type(
    self,
    model_config: ModelConfig,
    model_name_or_path: str,
    gguf_to_hf_name_map: dict[str, str],
) -> dict[str, str]:
    """Collect the weight-type map for the GGUF checkpoint.

    The map for the main GGUF file is always read. When the HF config has a
    ``vision_config`` attribute the model is treated as multimodal, and the
    map read from the companion mm_proj file is merged on top of it.

    Args:
        model_config: Model configuration; only ``hf_config`` is inspected.
        model_name_or_path: Location of the main GGUF file.
        gguf_to_hf_name_map: GGUF-name to HF-name mapping forwarded to
            ``get_gguf_weight_type_map``.

    Returns:
        The (possibly merged) weight-type map.

    Raises:
        AssertionError: If the model is multimodal but no mm_proj file is
            detected.
    """
    type_by_name = get_gguf_weight_type_map(model_name_or_path, gguf_to_hf_name_map)

    # Text-only models have nothing more to merge.
    if not hasattr(model_config.hf_config, "vision_config"):
        return type_by_name

    projector_path = detect_gguf_multimodal(model_name_or_path)
    assert projector_path is not None, (
        "Could not find mm_proj file for multimodal GGUF model"
    )
    logger.info("Loading extra mm_proj weights from %s...", projector_path)
    # Entries from the mm_proj file override same-named entries, as before.
    type_by_name.update(
        get_gguf_weight_type_map(projector_path, gguf_to_hf_name_map)
    )
    return type_by_name
|
||||
|
||||
def _get_weights_iterator(
    self,
    model_config: ModelConfig,
    model_name_or_path: str,
    gguf_to_hf_name_map: dict[str, str],
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield GGUF weights, mm_proj file first for multimodal models.

    A model is considered multimodal when its HF config has a
    ``vision_config`` attribute. In that case the companion mm_proj GGUF
    file (vision encoder + projector weights) is streamed before the main
    model file; otherwise only the main file is read.

    Yields:
        ``(parameter_name, tensor)`` tuples as produced by
        ``gguf_quant_weights_iterator``.

    Raises:
        AssertionError: If the model is multimodal but no mm_proj file is
            detected.
    """
    gguf_sources = []

    if hasattr(model_config.hf_config, "vision_config"):
        # Multimodal weights live in a separate mm_proj file.
        projector_path = detect_gguf_multimodal(model_name_or_path)
        assert projector_path is not None, (
            "Could not find mm_proj file for multimodal GGUF model"
        )
        gguf_sources.append(projector_path)

    gguf_sources.append(model_name_or_path)

    # Each file is opened only once the previous one is exhausted.
    for gguf_path in gguf_sources:
        yield from gguf_quant_weights_iterator(gguf_path, gguf_to_hf_name_map)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
    """Make the model weights available by delegating to ``_prepare_weights``.

    The value returned by ``_prepare_weights`` is intentionally discarded.
    """
    self._prepare_weights(model_config)
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
    """Stream the GGUF weights for ``model_config`` into ``model``.

    Resolves the weights path, builds the GGUF-to-HF name mapping and hands
    the resulting weight iterator to ``model.load_weights``.
    """
    weights_path = self._prepare_weights(model_config)
    name_map = self._get_gguf_weights_map(model_config)
    weight_stream = self._get_weights_iterator(model_config, weights_path, name_map)
    model.load_weights(weight_stream)
|
||||
|
||||
def load_model(
    self, vllm_config: VllmConfig, model_config: ModelConfig
) -> nn.Module:
    """Instantiate the model and populate it from a GGUF checkpoint.

    Steps, in order:
      1. Resolve the weights path and the GGUF-to-HF name mapping.
      2. Set ``tie_word_embeddings`` on the HF config when
         ``lm_head.weight`` is reported by ``get_gguf_extra_tensor_names``.
      3. Register every ``*.weight`` tensor stored as F32/F16/BF16 as an
         unquantized module on ``vllm_config.quant_config``.
      4. Build the model on the target device, load its weights and run
         ``process_weights_after_loading``.

    Returns:
        The initialized model.
    """
    device_cfg = vllm_config.device_config
    weights_path = self._prepare_weights(model_config)
    name_map = self._get_gguf_weights_map(model_config)

    # we can only know if tie word embeddings after mapping weights
    extra_tensor_names = get_gguf_extra_tensor_names(weights_path, name_map)
    if "lm_head.weight" in extra_tensor_names:
        model_config.hf_config.update({"tie_word_embeddings": True})

    weight_types = self._get_gguf_weight_type(model_config, weights_path, name_map)

    # filter out unquantized modules to skip
    skipped_modules = []
    for tensor_name, tensor_type in weight_types.items():
        if tensor_type in ("F32", "F16", "BF16") and tensor_name.endswith(".weight"):
            skipped_modules.append(tensor_name.removesuffix(".weight"))
    logger.debug("GGUF unquantized modules: %s", skipped_modules)
    vllm_config.quant_config.unquantized_modules.extend(skipped_modules)

    target_device = torch.device(device_cfg.device)
    with set_default_torch_dtype(model_config.dtype):
        # Only model construction runs under the device context.
        with target_device:
            model = initialize_model(vllm_config=vllm_config)
        self.load_weights(model, model_config)

        process_weights_after_loading(model, model_config, target_device)
    return model
|
||||
@@ -1,362 +0,0 @@
|
||||
# ruff: noqa: SIM117
|
||||
import copy
|
||||
import glob
|
||||
import os
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple, Type
|
||||
|
||||
import huggingface_hub
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import (DeviceConfig, LoadConfig, LoadFormat, LoRAConfig,
|
||||
ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
VisionLanguageConfig)
|
||||
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer,
|
||||
tensorizer_weights_iterator)
|
||||
from vllm.model_executor.model_loader.utils import (get_model_architecture,
|
||||
set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_weights_from_hf, filter_files_not_needed_for_inference,
|
||||
get_quant_config, initialize_dummy_weights, np_cache_weights_iterator,
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
|
||||
|
||||
# Model classes that are checked (by membership) when deciding whether to pass
# `vision_language_config` as a constructor keyword argument.
_VISION_MODEL_CLASSES = [LlavaForConditionalGeneration]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def _get_quantization_config(
        model_config: ModelConfig,
        load_config: LoadConfig) -> Optional[QuantizationConfig]:
    """Return the validated quantization config, or None if not quantized.

    Raises:
        ValueError: If the current CUDA device capability is below the
            method's minimum, or if ``model_config.dtype`` is not among the
            activation dtypes the method supports.
    """
    if model_config.quantization is None:
        return None

    quant_config = get_quant_config(model_config, load_config)

    # Encode (major, minor) as a two-digit integer, e.g. (8, 0) -> 80.
    major, minor = torch.cuda.get_device_capability()
    device_capability = major * 10 + minor
    if device_capability < quant_config.get_min_capability():
        raise ValueError(
            f"The quantization method {model_config.quantization} is not "
            "supported for the current GPU. "
            f"Minimum capability: {quant_config.get_min_capability()}. "
            f"Current capability: {device_capability}.")

    supported_dtypes = quant_config.get_supported_act_dtypes()
    if model_config.dtype not in supported_dtypes:
        raise ValueError(
            f"{model_config.dtype} is not supported for quantization "
            f"method {model_config.quantization}. Supported dtypes: "
            f"{supported_dtypes}")
    return quant_config
|
||||
|
||||
|
||||
def _get_model_initialization_kwargs(
        model_class: Type[nn.Module], lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig]
) -> Dict[str, Any]:
    """Get extra kwargs for model initialization.

    Args:
        model_class: The model class about to be instantiated.
        lora_config: LoRA configuration, or None when LoRA is disabled.
        vision_language_config: Vision-language configuration, or None.

    Returns:
        Keyword arguments to pass to ``model_class`` in addition to the
        common ones: ``lora_config`` for classes exposing
        ``supported_lora_modules`` and ``vision_language_config`` for
        classes listed in ``_VISION_MODEL_CLASSES``.

    Raises:
        ValueError: If LoRA is enabled but ``model_class`` does not
            support it.
    """
    extra_kwargs: Dict[str, Any] = {}

    if hasattr(model_class, "supported_lora_modules"):
        extra_kwargs["lora_config"] = lora_config
    elif lora_config:
        raise ValueError(
            f"Model {model_class.__name__} does not support LoRA, "
            "but LoRA is enabled. Support for this model may "
            "be added in the future. If this is important to you, "
            "please open an issue on github.")

    # Vision support is independent of LoRA support. This used to be an
    # `elif` of the chain above, so a vision model that also supported LoRA
    # never received its vision_language_config.
    if model_class in _VISION_MODEL_CLASSES:
        extra_kwargs["vision_language_config"] = vision_language_config
    return extra_kwargs
|
||||
|
||||
|
||||
def _initialize_model(
        model_config: ModelConfig, load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module:
    """Construct (without loading weights) the model described by the configs.

    The model class is the first element returned by
    ``get_model_architecture``; it is called with the HF config, the
    quantization config and any class-specific extra keyword arguments.
    """
    model_class = get_model_architecture(model_config)[0]
    quant_config = _get_quantization_config(model_config, load_config)
    extra_kwargs = _get_model_initialization_kwargs(
        model_class, lora_config, vision_language_config)

    return model_class(config=model_config.hf_config,
                       quant_config=quant_config,
                       **extra_kwargs)
|
||||
|
||||
|
||||
class BaseModelLoader(ABC):
    """Abstract interface shared by all model loaders.

    Subclasses receive the ``LoadConfig`` at construction time and must
    implement ``load_model``.
    """

    def __init__(self, load_config: LoadConfig):
        # Kept on the instance so subclasses can consult it while loading.
        self.load_config = load_config

    @abstractmethod
    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        """Load a model with the given configurations.

        All arguments are keyword-only.
        """
        ...
|
||||
|
||||
|
||||
class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # This loader does not accept any extra configuration.
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Fetch the model from ModelScope when VLLM_USE_MODELSCOPE is set.

        Returns:
            The model path (``model`` itself when it already exists on
            disk, otherwise the downloaded snapshot), or None when
            ModelScope is not enabled.
        """
        if not VLLM_USE_MODELSCOPE:
            return None

        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        if os.path.exists(model):
            return model
        return snapshot_download(
            model_id=model,
            cache_dir=self.load_config.download_dir,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            revision=revision,
        )

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str],
                         fall_back_to_pt: bool) -> Tuple[str, List[str], bool]:
        """Locate the weight files, downloading them when not local.

        Returns:
            ``(folder, weight_files, use_safetensors)``.

        Raises:
            ValueError: If the configured load format is not handled here.
            RuntimeError: If no weight file matches the allowed patterns.
        """
        modelscope_path = self._maybe_download_from_modelscope(
            model_name_or_path, revision)
        model_name_or_path = modelscope_path or model_name_or_path

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        # Some quantized models use .pt files for storing the weights.
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format == LoadFormat.SAFETENSORS:
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns.append("*.pt")

        if is_local:
            weights_dir = model_name_or_path
        else:
            weights_dir = download_weights_from_hf(
                model_name_or_path, self.load_config.download_dir,
                allow_patterns, revision)

        # Patterns are tried in priority order; the first one that matches
        # anything wins.
        matched_files: List[str] = []
        for pattern in allow_patterns:
            matched_files.extend(glob.glob(os.path.join(weights_dir, pattern)))
            if not matched_files:
                continue
            if pattern == "*.safetensors":
                use_safetensors = True
            break

        if not use_safetensors:
            matched_files = filter_files_not_needed_for_inference(
                matched_files)

        if len(matched_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return weights_dir, matched_files, use_safetensors

    def _get_weights_iterator(
        self, model_name_or_path: str, revision: Optional[str],
        fall_back_to_pt: bool
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        weights_dir, weight_files, use_safetensors = self._prepare_weights(
            model_name_or_path, revision, fall_back_to_pt)

        if self.load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            return np_cache_weights_iterator(model_name_or_path,
                                             self.load_config.download_dir,
                                             weights_dir, weight_files)

        if use_safetensors:
            return safetensors_weights_iterator(weight_files)
        return pt_weights_iterator(weight_files)

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        """Build the model, load its weights and run post-load processing.

        Returns:
            The model in eval mode.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model = _initialize_model(model_config, self.load_config,
                                          lora_config, vision_language_config)

            # Models may opt out of the *.pt fallback via an attribute.
            allow_pt_fallback = getattr(model, "fall_back_to_pt_during_load",
                                        True)
            model.load_weights(
                self._get_weights_iterator(model_config.model,
                                           model_config.revision,
                                           fall_back_to_pt=allow_pt_fallback))

            for _, submodule in model.named_modules():
                quant_method = getattr(submodule, "quant_method", None)
                if quant_method is not None:
                    quant_method.process_weights_after_loading(submodule)
                # FIXME: Remove this after Mixtral is updated
                # to use quant_method.
                if hasattr(submodule, "process_weights_after_loading"):
                    submodule.process_weights_after_loading()
        return model.eval()
|
||||
|
||||
|
||||
class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # No extra configuration is accepted by this loader.
        if load_config.model_loader_extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        """Build the model and fill it with dummy weights.

        Returns:
            The model in eval mode.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                dummy_model = _initialize_model(model_config,
                                                self.load_config, lora_config,
                                                vision_language_config)
            # NOTE(woosuk): For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(dummy_model)
        return dummy_model.eval()
|
||||
|
||||
|
||||
class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        # Accept either a ready TensorizerConfig or a mapping of its fields.
        if isinstance(extra_config, TensorizerConfig):
            self.tensorizer_config = extra_config
        else:
            self.tensorizer_config = TensorizerConfig(**extra_config)

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        """Validate the tensorizer config against model and parallel configs."""
        self.tensorizer_config.verify_with_model_config(model_config)
        self.tensorizer_config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
            self) -> Generator[Tuple[str, torch.Tensor], None, None]:
        """Return the tensorizer weight iterator for the configured args."""
        return tensorizer_weights_iterator(
            self.tensorizer_config._construct_tensorizer_args())

    def _load_model_unserialized(
        self, model_config: ModelConfig, device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig]
    ) -> nn.Module:
        """Load an unserialized model with tensorizer.

        Unserialized here means "not serialized with tensorizer". This
        should still be faster than default HuggingFace loading, but will
        be slower than loading a tensorizer-serialized model.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                loaded = _initialize_model(model_config, self.load_config,
                                           lora_config,
                                           vision_language_config)

            loaded.load_weights(self._get_weights_iterator())
        return loaded.eval()

    def _load_model_serialized(
        self, model_config: ModelConfig, device_config: DeviceConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig]
    ) -> nn.Module:
        """Load a serialized model with tensorizer.

        See the examples/tensorize_vllm_model.py example
        script for serializing vLLM models.
        """
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
                model_class = get_model_architecture(model_config)[0]
                quant_config = _get_quantization_config(
                    model_config, self.load_config)
                init_kwargs = _get_model_initialization_kwargs(
                    model_class, lora_config, vision_language_config)
                init_kwargs["quant_config"] = quant_config

                # Work on a shallow copy so the loader's own config is not
                # mutated by the per-model fields set below.
                per_model_config = copy.copy(self.tensorizer_config)
                per_model_config.model_class = model_class
                per_model_config.hf_config = model_config.hf_config
                per_model_config.dtype = model_config.dtype

                loaded = load_with_tensorizer(per_model_config, **init_kwargs)
        return loaded.eval()

    def load_model(self, *, model_config: ModelConfig,
                   device_config: DeviceConfig,
                   lora_config: Optional[LoRAConfig],
                   vision_language_config: Optional[VisionLanguageConfig],
                   parallel_config: ParallelConfig,
                   scheduler_config: SchedulerConfig) -> nn.Module:
        """Verify the configs, then load via the matching tensorizer path."""
        self._verify_config(model_config, parallel_config)

        if is_vllm_serialized_tensorizer(self.tensorizer_config):
            load_fn = self._load_model_serialized
        else:
            load_fn = self._load_model_unserialized
        return load_fn(model_config, device_config, lora_config,
                       vision_language_config)
|
||||
|
||||
|
||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Get a model loader based on the load format.

    A ``load_format`` that is itself a class is instantiated directly with
    ``load_config``. Otherwise DUMMY and TENSORIZER select their dedicated
    loaders and every other value uses ``DefaultModelLoader``.
    """
    load_format = load_config.load_format

    # A class given as the load format acts as a custom loader.
    if isinstance(load_format, type):
        return load_format(load_config)

    if load_format == LoadFormat.DUMMY:
        loader_cls = DummyModelLoader
    elif load_format == LoadFormat.TENSORIZER:
        loader_cls = TensorizerLoader
    else:
        loader_cls = DefaultModelLoader
    return loader_cls(load_config)
|
||||
@@ -1,136 +0,0 @@
|
||||
"""Utilities for selecting and loading neuron models."""
|
||||
import importlib
|
||||
import os
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
# Maps a dtype, given either as a string or as a torch dtype object, to the
# string passed as the `amp` argument when loading a Neuron model.
TORCH_DTYPE_TO_NEURON_AMP = {
    # String spellings.
    "auto": "f32",
    "half": "f16",
    "float16": "f16",
    "bfloat16": "bf16",
    "float": "f32",
    "float32": "f32",
    # torch dtype objects.
    torch.float16: "f16",
    torch.bfloat16: "bf16",
    torch.float32: "f32",
}
|
||||
|
||||
# Models supported by Neuron.
# Each entry maps an architecture name to
# (neuronx module path, neuronx model class name, HF model class name).
_NEURON_SUPPORTED_MODELS: Dict[str, Tuple[str, str, str]] = {
    "LlamaForCausalLM": (
        "transformers_neuronx.llama.model",
        "LlamaForSampling",
        "LlamaForCausalLM",
    ),
    "MistralForCausalLM": (
        "transformers_neuronx.mistral.model",
        "MistralForSampling",
        "MistralForCausalLM",
    ),
}
|
||||
|
||||
|
||||
class NeuronCasualLM(nn.Module):
    """Wrapper module around a transformers-neuronx causal language model.

    The wrapped model is created lazily by ``load_weights``.
    """

    def __init__(
        self,
        config: PretrainedConfig,
    ) -> None:
        super().__init__()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)
        self.sampler = Sampler()

        # Lazy initialized
        self.model: nn.Module

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the wrapped model and return its output as logits."""
        return self.model(input_ids,
                          cache_ids=positions,
                          start_ids=input_block_ids)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Pass ``hidden_states`` through the logits processor."""
        return self.logits_processor(None, hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Sample next tokens from ``logits``."""
        return self.sampler(logits, sampling_metadata)

    def load_weights(self, model_name_or_path: str, **kwargs):
        """Create the Neuron model from a split checkpoint and compile it.

        If ``model_name_or_path`` contains a ``pytorch_model.bin``
        directory it is used as the split checkpoint directly. Otherwise
        ``<model_name_or_path>-split`` is used, and is created from the HF
        model first when it does not exist yet.
        """
        arch = _get_model_architecture(self.config)
        module_path, neuronx_cls_name, hf_cls_name = (
            _NEURON_SUPPORTED_MODELS[arch])
        neuronx_model_cls = getattr(importlib.import_module(module_path),
                                    neuronx_cls_name)

        split_model_dir = f"{model_name_or_path}-split"
        if os.path.isdir(os.path.join(model_name_or_path,
                                      "pytorch_model.bin")):
            split_model_dir = model_name_or_path
        elif not os.path.exists(split_model_dir):
            hf_model_cls = getattr(transformers, hf_cls_name)
            from transformers_neuronx.module import save_pretrained_split

            hf_model = hf_model_cls.from_pretrained(model_name_or_path,
                                                    low_cpu_mem_usage=True)
            save_pretrained_split(hf_model, split_model_dir)

        self.model = neuronx_model_cls.from_pretrained(split_model_dir,
                                                       **kwargs)
        self.model.to_neuron()
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> str:
    """Return the first architecture of ``config`` that Neuron supports.

    Args:
        config: HF config whose ``architectures`` attribute is inspected.

    Returns:
        The first entry of ``config.architectures`` that is a key of
        ``_NEURON_SUPPORTED_MODELS``.

    Raises:
        ValueError: If no listed architecture is supported, including when
            ``architectures`` is missing, ``None`` or empty.
    """
    # The attribute may exist but be None; treat that like an empty list so
    # the informative ValueError below is raised instead of a TypeError.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _NEURON_SUPPORTED_MODELS:
            return arch
    raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
|
||||
|
||||
|
||||
def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig) -> nn.Module:
    """Create a ``NeuronCasualLM``, load its weights and return it in eval mode.

    The scheduler's ``max_num_seqs`` is used as the batch size and its
    ``max_model_len`` as both the context-length estimate and the number of
    positions.
    """
    from transformers_neuronx.config import (ContinuousBatchingConfig,
                                             NeuronConfig)

    # Create a model instance.
    neuron_model = NeuronCasualLM(model_config.hf_config)

    neuron_config = NeuronConfig(
        continuous_batching=ContinuousBatchingConfig(
            batch_size_for_shared_caches=scheduler_config.max_num_seqs))

    # Load the weights from the cached or downloaded files.
    load_kwargs = dict(
        tp_degree=parallel_config.tensor_parallel_size,
        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        neuron_config=neuron_config,
        context_length_estimate=[scheduler_config.max_model_len],
        n_positions=[scheduler_config.max_model_len],
        batch_size=scheduler_config.max_num_seqs,
    )
    neuron_model.load_weights(model_config.model, **load_kwargs)

    return neuron_model.eval()
|
||||
275
vllm/model_executor/model_loader/online_quantization.py
Normal file
275
vllm/model_executor/model_loader/online_quantization.py
Normal file
@@ -0,0 +1,275 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import types
|
||||
from collections.abc import Iterable
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.utils import process_weights_after_loading
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Notes for Online Quantization
|
||||
# In terms of state of checkpoints, quantization config and their
|
||||
# correspondence to online quantization:
|
||||
# | Use Case | Checkpoints | model_config.quantization |
|
||||
# | no quant | high precision | None |
|
||||
# | offline quant | quantized | fp8, torchao etc. |
|
||||
# | online quant | high precision | torchao etc. |
|
||||
#
|
||||
# The process for loading non-quantized checkpoint
|
||||
# 1. load non-quantized weights (load_weights)
|
||||
# 2. do any additional post processing (process_weights_after_loading)
|
||||
#
|
||||
# The process for loading offline quantized checkpoint
|
||||
# 1. load offline-quantized weights (load_weights)
|
||||
# 2. do any additional post processing (process_weights_after_loading)
|
||||
|
||||
# The process for unquantized model reloading
|
||||
# (repeated run in RL training loop)
|
||||
# first run
|
||||
# UI1. load_weights: load bfloat16 weights
|
||||
# UI2. process_weights_after_loading: any additional post processing
|
||||
# subsequent run
|
||||
# UC1: load_weights: load bfloat16 weights
|
||||
# (shouldn't be any issues since we didn't change any attributes
|
||||
# of the weights)
|
||||
# UC2: process_weights_after_loading: any additional post processing
|
||||
|
||||
# The process for weight reloading with online quantization
|
||||
# (repeated run in RL training loop)
|
||||
# first run
|
||||
# I1. load_weights: load bfloat16 weights
|
||||
# I2. process_weights_after_loading:
|
||||
# record weight metadata and attributes for R1 and R2
|
||||
# quantize weights to fp8
|
||||
# subsequent run
|
||||
# (beginning model weight is in fp8)
|
||||
# load_weights:
|
||||
# R1. restore bfloat16 model weight metadata
|
||||
# R2. restore the model weight attributes
|
||||
# R3. reload bfloat16 weights
|
||||
# R4. quantize weights (by calling process_weights_after_loading),
|
||||
# also set `process_weights_after_loading_already_called` to
|
||||
# True to stop it from running again
|
||||
# R5. (workaround for cudagraph), we restore the weight params to original quantized
|
||||
# weights params, and use original_weight_param.copy_(updated_weight_param) so that
|
||||
# the weight update works well with cudagraph
|
||||
# process_weights_after_loading (if called):
|
||||
# this will be skipped since it has already run in
|
||||
# load_weights
|
||||
|
||||
|
||||
def maybe_save_metadata_and_attributes_for_weight_reloading(
    model: nn.Module, model_config: ModelConfig
):
    """Record what is needed to restore high-precision weights on reload.

    This is step I2 of online quantization and only applies when
    ``model_config.quantization`` is ``"torchao"`` and the quant config
    reports that the checkpoint is *not* already torchao-serialized. It
    runs at most once per model and stores on ``model``:

    * ``original_weights_rebuild_keys``: per-parameter shape, dtype and
      device, used to rebuild the parameters before reloading (R1).
    * ``recorded_weight_attr``: per-parameter instance attributes such as
      loader functions, used to restore them (R2).
    * ``_model_config`` and ``weight_metadata_and_attr_saved``.
    """
    # On-the-fly quantization is currently only supported for torchao.
    if model_config.quantization != "torchao":
        return

    from vllm.model_executor.model_loader.weight_utils import get_quant_config

    quant_config = get_quant_config(model_config, None)

    # A checkpoint that is already torchao-serialized is the pre-quantized
    # case, for which nothing has to be recorded. Configs without the flag
    # are skipped as well.
    if (
        not hasattr(quant_config, "is_checkpoint_torchao_serialized")
        or quant_config.is_checkpoint_torchao_serialized
    ):
        return

    # Metadata is only captured during initialization, never again.
    if getattr(model, "weight_metadata_and_attr_saved", False):
        return

    # 1. Floating point metadata (shape, dtype, device) of every parameter,
    #    used to restore the high precision parameters before reloading.
    assert not hasattr(model, "original_weights_rebuild_keys")
    model.original_weights_rebuild_keys = {
        name: {"shape": p.shape, "dtype": p.dtype, "device": p.device}
        for name, p in model.named_parameters()
    }

    # 2. Weight attributes (e.g. `output_dim`, `weight_loader`) so they can
    #    be recovered when the weights are reloaded.
    #    structure: {"weight_name": {"weight_attr_key": attr}}
    assert not hasattr(model, "recorded_weight_attr")
    model.recorded_weight_attr = {}
    for name, param in model.named_parameters():
        saved_attrs = {}
        model.recorded_weight_attr[name] = saved_attrs
        for key in param.__dict__:
            if not hasattr(param, key):
                continue
            attr = getattr(param, key)
            # For a method bound to this very parameter, keep the underlying
            # function object rather than the bound method.
            bound_to_param = (
                callable(attr)
                and hasattr(attr, "__self__")
                and param is attr.__self__
            )
            saved_attrs[key] = attr.__func__ if bound_to_param else attr

    # mark the metadata and attributes saved so we don't run it again
    model._model_config = model_config
    model.weight_metadata_and_attr_saved = True
|
||||
|
||||
|
||||
def _bond_method_to_cls(func, obj):
    """Bind ``func`` to ``obj`` as a method unless that is not applicable.

    ``func`` is returned unchanged when it is not callable or when it
    already has a ``__self__`` attribute (i.e. it is already bound).
    Otherwise a new bound method of ``obj`` is returned.
    """
    needs_binding = callable(func) and not hasattr(func, "__self__")
    return types.MethodType(func, obj) if needs_binding else func
|
||||
|
||||
|
||||
def support_quantized_model_reload_from_hp_weights(original_load_weights):
    """Decorate ``AutoWeightsLoader.load_weights`` to support weight reloading.

    The returned wrapper lets high precision (bfloat16/float16/float32)
    weights be loaded again into a model that was already quantized online:
    the quantized parameters are swapped for empty high precision ones, the
    weights are loaded through ``original_load_weights``, and the model is
    quantized again.

    Args:
        original_load_weights: The undecorated ``load_weights`` callable. It
            is invoked as
            ``original_load_weights(loader, weights, mapper=mapper)``.

    Returns:
        A wrapper with the same calling convention as
        ``original_load_weights`` that returns whatever it returns.
    """
    # online quantization, right now only enabled for
    # torchao
    # R1, R2, R3, R4, R5 in the Notes

    def patched_model_load_weights(
        auto_weight_loader, weights: Iterable[tuple[str, torch.Tensor]], *, mapper=None
    ) -> set[str]:
        model = auto_weight_loader.module
        offline_quantization_or_first_run_of_online_quantization = not getattr(
            model, "weight_metadata_and_attr_saved", False
        )

        # if we don't have `model.weight_metadata_and_attr_saved` defined and
        # set to True, it means that this is either offline quantization case
        # or the first run of online quantization
        # see Notes in this file for more details
        if offline_quantization_or_first_run_of_online_quantization:
            # case 1: offline quantized checkpoint
            # case 2: Step I1 first run of weight loading with
            # online quantization
            return original_load_weights(auto_weight_loader, weights, mapper=mapper)

        model_config = model._model_config

        # TODO: Add fp8 support
        assert model_config.quantization == "torchao", (
            "online quantization is only enabled for torchao currently"
        )
        # TODO: use create_weights to restore the weights to original state

        # Step R1: First restore the quantized weights to original bfloat16
        # weights, with original metadata (shape, dtype, device)
        # and attributes, so that bfloat16 weights can be loaded properly
        # TODO: maybe set remove_duplicate to True?
        original_quantized_weight_dict = dict(
            model.named_parameters(remove_duplicate=False)
        )
        named_modules = dict(model.named_modules(remove_duplicate=False))
        model_device = None

        for name, d in model.original_weights_rebuild_keys.items():
            _shape = d["shape"]
            _dtype = d["dtype"]
            _device = d["device"]
            if model_device is not None:
                assert model_device == _device, (
                    "Expecting all weights "
                    "to be in the same device for now, got both: "
                    f"{model_device} and {_device}"
                )
            else:
                model_device = _device

            if name in original_quantized_weight_dict:
                module_name, weight_name = name.rsplit(".", 1)
                module = named_modules[module_name]
                setattr(
                    module,
                    weight_name,
                    torch.nn.Parameter(
                        torch.empty(_shape, dtype=_dtype, device=_device),
                        requires_grad=False,
                    ),
                )

        # Step R2: recover the weight attributes to the state before first loading
        # recorded_weight_attr is
        # {"weight_name": {"weight_attr_key": attr}}
        # e.g.
        # {
        #     "layer.0.weight": {
        #         "weight_loader": weight_loader_function_object,
        #         "input_dim": 0, ...
        #     },
        #     "layer.1.weight": ...,
        # }
        for full_weight_name, weight_attr_dict in model.recorded_weight_attr.items():
            if not weight_attr_dict:
                # nothing recorded for this weight, so nothing to resolve
                continue
            # The owning module and the weight are the same for every
            # attribute of this weight, so resolve them once per weight.
            module_name, weight_name = full_weight_name.rsplit(".", 1)
            module = named_modules[module_name]
            weight = getattr(module, weight_name)
            for attr_name, attr in weight_attr_dict.items():
                if not hasattr(weight, attr_name):
                    setattr(weight, attr_name, _bond_method_to_cls(attr, weight))

        # Step R3: reload bfloat16 / high precision weights
        updated_params = original_load_weights(
            auto_weight_loader, weights, mapper=mapper
        )

        # Step R4: online quantize the weights
        # manually process weights after loading
        model.process_weights_after_loading_already_called = False
        if model_device is not None:
            process_weights_after_loading(model, model_config, model_device)
        else:
            logger.warning_once(
                "model_device is None, skip calling process_weights_after_loading"
            )

        # Step R5 (workaround for cudagraph): restore the original quantized weights
        # and do a copy_ of the currents weights to the original weights
        updated_quantized_weights = dict(model.named_parameters(remove_duplicate=False))
        for name in model.original_weights_rebuild_keys:
            if name in original_quantized_weight_dict:
                original_quantized_weight = original_quantized_weight_dict[name]
                updated_quantized_weight = updated_quantized_weights[name]

                module_name, weight_name = name.rsplit(".", 1)
                module = named_modules[module_name]
                setattr(module, weight_name, original_quantized_weight)
                with torch.no_grad():
                    original_quantized_weight.copy_(updated_quantized_weight)

        # Drop the lookup tables. Only names that are always bound are deleted
        # here: the per-iteration temporaries of the loop above are unbound
        # when no rebuild key matched a parameter, and deleting them would
        # raise UnboundLocalError.
        del original_quantized_weight_dict
        del named_modules
        del updated_quantized_weights

        model.process_weights_after_loading_already_called = True
        return updated_params

    return patched_model_load_weights
|
||||
116
vllm/model_executor/model_loader/runai_streamer_loader.py
Normal file
116
vllm/model_executor/model_loader/runai_streamer_loader.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf,
|
||||
download_weights_from_hf,
|
||||
runai_safetensors_weights_iterator,
|
||||
)
|
||||
from vllm.transformers_utils.runai_utils import is_runai_obj_uri, list_safetensors
|
||||
|
||||
|
||||
class RunaiModelStreamerLoader(BaseModelLoader):
    """
    Model loader that can load safetensors
    files from local FS or S3 bucket.
    """

    def __init__(self, load_config: LoadConfig):
        """Read the optional loader settings from ``load_config``.

        Recognised ``model_loader_extra_config`` entries: ``distributed``
        (bool), ``concurrency`` (int) and ``memory_limit`` (int). The latter
        two are exported as ``RUNAI_STREAMER_*`` environment variables.
        Entries of any other type are ignored.
        """
        super().__init__(load_config)

        self._is_distributed = False
        extra_config = load_config.model_loader_extra_config
        if not extra_config:
            return

        distributed = extra_config.get("distributed")
        if isinstance(distributed, bool):
            self._is_distributed = distributed

        concurrency = extra_config.get("concurrency")
        if isinstance(concurrency, int):
            os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(concurrency)

        memory_limit = extra_config.get("memory_limit")
        if isinstance(memory_limit, int):
            os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(memory_limit)

        # Fall back to the generic AWS endpoint when no streamer-specific
        # endpoint has been configured.
        aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
        if (
            os.getenv("RUNAI_STREAMER_S3_ENDPOINT") is None
            and aws_endpoint_url is not None
        ):
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url

    def _prepare_weights(
        self, model_name_or_path: str, revision: str | None
    ) -> list[str]:
        """Return the safetensors files that make up the model.

        When ``model_name_or_path`` is neither a local directory nor an
        object-storage URI, the weights and the safetensors index file are
        downloaded first.

        Raises:
            RuntimeError: If no safetensors file is found.
        """
        is_object_storage_path = is_runai_obj_uri(model_name_or_path)
        is_local = os.path.isdir(model_name_or_path)
        needs_download = not (is_local or is_object_storage_path)

        if needs_download:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                ["*.safetensors"],
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path
        hf_weights_files = list_safetensors(path=hf_folder)

        if needs_download:
            download_safetensors_index_file_from_hf(
                model_name_or_path,
                SAFE_WEIGHTS_INDEX_NAME,
                self.load_config.download_dir,
                revision,
            )

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any safetensors model weights with `{model_name_or_path}`"
            )

        return hf_weights_files

    def _get_weights_iterator(
        self, model_or_path: str, revision: str
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Build the weight iterator over the prepared safetensors files."""
        weight_files = self._prepare_weights(model_or_path, revision)
        return runai_safetensors_weights_iterator(
            weight_files, self.load_config.use_tqdm_on_load, self._is_distributed
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """Download model if necessary"""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load weights into a model.

        ``model_config.model_weights`` takes precedence over
        ``model_config.model`` when the attribute exists.
        """
        weights_location = getattr(model_config, "model_weights", model_config.model)
        weights_iterator = self._get_weights_iterator(
            weights_location, model_config.revision
        )
        model.load_weights(weights_iterator)
|
||||
214
vllm/model_executor/model_loader/sharded_state_loader.py
Normal file
214
vllm/model_executor/model_loader/sharded_state_loader.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import collections
|
||||
import glob
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_weights_from_hf,
|
||||
runai_safetensors_weights_iterator,
|
||||
)
|
||||
from vllm.transformers_utils.s3_utils import glob as s3_glob
|
||||
from vllm.transformers_utils.utils import is_s3
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/offline_inference/save_sharded_state.py` for creating a sharded
    checkpoint.
    """

    # File name template; `{rank}` and `{part}` are filled in via str.format.
    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self, load_config: LoadConfig):
        """Read the optional ``pattern`` entry from the extra loader config.

        Raises:
            ValueError: If the extra config holds any key besides ``pattern``.
        """
        super().__init__(load_config)

        # Work on a copy so the caller's config is not mutated by `pop`.
        extra_config = (
            {}
            if load_config.model_loader_extra_config is None
            else load_config.model_loader_extra_config.copy()
        )
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            # Report only the leftover keys; the valid "pattern" key has
            # already been consumed above.
            raise ValueError(
                f"Unexpected extra config keys for load format "
                f"{load_config.load_format}: "
                f"{list(extra_config)}"
            )

    @staticmethod
    def _filter_subtensors(
        tensors: dict[str, torch.Tensor],
    ) -> dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        # Group non-empty tensors by (device, start of underlying storage).
        same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = (
            collections.defaultdict(list)
        )
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            # Address one past the last element of the flattened tensor.
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        # t2 does not cover the whole range of t.
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    # No tensor in the group supersedes t.
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str, revision: str | None):
        """Return the checkpoint location, downloading it when necessary.

        S3 paths and existing local directories are returned unchanged.
        """
        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

    def download_model(self, model_config: ModelConfig) -> None:
        """Make sure the checkpoint referenced by ``model_config`` is available."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Copy this rank's sharded checkpoint files into ``model``.

        Raises:
            ValueError: If no checkpoint file matches the pattern for this
                rank, or if some entries of the model state are not present
                in the loaded files.
        """
        from vllm.distributed import get_tensor_model_parallel_rank

        model_weights = model_config.model
        if hasattr(model_config, "model_weights"):
            model_weights = model_config.model_weights
        local_model_path = model_weights

        rank = get_tensor_model_parallel_rank()
        pattern = os.path.join(
            local_model_path,
            self.pattern.format(rank=rank, part="*"),
        )

        filepaths = []
        if is_s3(local_model_path):
            file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
            filepaths = s3_glob(path=local_model_path, allow_pattern=[file_pattern])
        else:
            filepaths = glob.glob(pattern)
        if not filepaths:
            # TODO: support un-sharded checkpoints too
            raise ValueError(
                f"Could not find checkpoint files '{pattern}', only "
                f"pre-sharded checkpoints are currently supported!"
            )
        state_dict = self._filter_subtensors(model.state_dict())
        counter_before_loading_weights = time.perf_counter()
        for key, tensor in self.iterate_over_files(filepaths):
            # If loading with LoRA enabled, additional padding may
            # be added to certain parameters. We only load into a
            # narrowed view of the parameter data.
            param_data = state_dict[key].data
            param_shape = state_dict[key].shape
            for dim, size in enumerate(tensor.shape):
                if size < param_shape[dim]:
                    param_data = param_data.narrow(dim, 0, size)
            if tensor.shape != param_shape:
                logger.warning(
                    "loading tensor of shape %s into parameter '%s' of shape %s",
                    tensor.shape,
                    key,
                    param_shape,
                )
            param_data.copy_(tensor)
            # Whatever is left in state_dict afterwards was never loaded.
            state_dict.pop(key)
        counter_after_loading_weights = time.perf_counter()
        logger.info_once(
            "Loading weights took %.2f seconds",
            counter_after_loading_weights - counter_before_loading_weights,
            scope="local",
        )
        if state_dict:
            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")

    def iterate_over_files(
        self, paths
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Yield ``(key, tensor)`` pairs from every file in ``paths``."""
        if self.load_config.load_format == "runai_streamer_sharded":
            yield from runai_safetensors_weights_iterator(paths, True)
        else:
            from safetensors.torch import safe_open

            for path in paths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        yield key, tensor

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: str | None = None,
        max_size: int | None = None,
    ) -> None:
        """Save this rank's model state as one or more safetensors files.

        Args:
            model: Module whose state dict is saved.
            path: Directory the files are written to.
            pattern: File name template with ``{rank}`` and ``{part}``
                fields; defaults to ``DEFAULT_PATTERN``.
            max_size: Size threshold in bytes at which a new part file is
                started; ``None`` writes a single part. A single tensor
                larger than the threshold still goes into a part of its own.
        """
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            # Only flush a non-empty part: when the very first tensor already
            # exceeds max_size there is nothing to write yet, and flushing
            # would produce an empty file.
            if (
                max_size is not None
                and state_dict_part
                and total_size + param_size > max_size
            ):
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )
|
||||
File diff suppressed because it is too large
Load Diff
151
vllm/model_executor/model_loader/tensorizer_loader.py
Normal file
151
vllm/model_executor/model_loader/tensorizer_loader.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import copy
|
||||
from collections.abc import Generator
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import ModelConfig, ParallelConfig, VllmConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig,
|
||||
deserialize_tensorizer_model,
|
||||
init_tensorizer_model,
|
||||
is_vllm_tensorized,
|
||||
serialize_vllm_model,
|
||||
tensorizer_weights_iterator,
|
||||
)
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
get_model_architecture,
|
||||
initialize_model,
|
||||
)
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Tensorizer arguments that callers are not allowed to set explicitly.
BLACKLISTED_TENSORIZER_ARGS = {
    "device",  # vLLM decides this
    "dtype",  # vLLM decides this
    "mode",  # Not meant to be configurable by the user
}


def validate_config(config: dict):
    """Reject configs that set a blacklisted Tensorizer argument.

    A blacklisted key whose value is ``None`` is tolerated.

    Raises:
        ValueError: For the first blacklisted key (in dict order) that has a
            non-``None`` value.
    """
    offending_key = next(
        (
            key
            for key, value in config.items()
            if key in BLACKLISTED_TENSORIZER_ARGS and value is not None
        ),
        None,
    )
    if offending_key is not None:
        raise ValueError(f"{offending_key} is not an allowed Tensorizer argument.")
|
||||
|
||||
|
||||
class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        """Take the tensorizer config from ``model_loader_extra_config``.

        A ready ``TensorizerConfig`` is used as is; otherwise the extra
        config is validated and a ``TensorizerConfig`` is built from its
        ``"tensorizer_config"`` entry.
        """
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if isinstance(extra_config, TensorizerConfig):
            self.tensorizer_config = extra_config
            return

        validate_config(extra_config)
        self.tensorizer_config = TensorizerConfig(**extra_config["tensorizer_config"])

    def _verify_config(
        self, model_config: ModelConfig, parallel_config: ParallelConfig
    ):
        """Check the tensorizer config against the model and parallel configs."""
        config = self.tensorizer_config
        config.verify_with_model_config(model_config)
        config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
        self,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Return the tensorizer weight iterator for the current config."""
        return tensorizer_weights_iterator(
            self.tensorizer_config._construct_tensorizer_args()
        )

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/others/tensorize_vllm_model.py) This should still
        be faster than default HuggingFace loading, but will be slower than
        loading a vLLM-tensorized model.
        """
        model_config = vllm_config.model_config
        target_device = vllm_config.device_config.device
        with set_default_torch_dtype(model_config.dtype), torch.device(target_device):
            model = initialize_model(vllm_config=vllm_config)

        model.load_weights(self._get_weights_iterator())
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        """Verify the config, then open and immediately close the stream."""
        self.tensorizer_config.verify_with_model_config(model_config)

        with self.tensorizer_config.open_stream():
            pass

    def _patch_tensorizer_config(self, model_config: ModelConfig) -> TensorizerConfig:
        """Return a shallow copy of the config filled in from ``model_config``."""
        model_class = get_model_architecture(model_config)[0]
        patched = copy.copy(self.tensorizer_config)
        patched.model_class = model_class
        patched.hf_config = model_config.hf_config
        patched.dtype = model_config.dtype
        return patched

    def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
        """Load serialized model weights with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/others/tensorize_vllm_model.py example script
        for serializing vLLM models."""
        if not is_vllm_tensorized(self.tensorizer_config):
            model.load_weights(self._get_weights_iterator())
            return

        patched_config = self._patch_tensorizer_config(model_config)
        deserialize_tensorizer_model(model, patched_config)

    def load_model(
        self, vllm_config: VllmConfig, model_config: ModelConfig
    ) -> nn.Module:
        """Build the model and load its weights with tensorizer."""
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            # The URI is a %-style template that receives the rank.
            uri_template = self.tensorizer_config.tensorizer_uri
            self.tensorizer_config.tensorizer_uri = (
                uri_template % get_tensor_model_parallel_rank()
            )

        if not is_vllm_tensorized(self.tensorizer_config):
            return self._load_model_serialized_cpu(vllm_config=vllm_config)

        patched_config = self._patch_tensorizer_config(model_config)
        target_device = vllm_config.device_config.device
        with set_default_torch_dtype(model_config.dtype), torch.device(target_device):
            model = init_tensorizer_model(
                tensorizer_config=patched_config, vllm_config=vllm_config
            )
        self.load_weights(model, model_config)
        return model

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: TensorizerConfig | dict,
        model_config: ModelConfig,
    ) -> None:
        """Serialize ``model``; a dict config is turned into a TensorizerConfig."""
        resolved_config = (
            TensorizerConfig(**tensorizer_config)
            if isinstance(tensorizer_config, dict)
            else tensorizer_config
        )
        serialize_vllm_model(
            model=model,
            tensorizer_config=resolved_config,
            model_config=model_config,
        )
|
||||
118
vllm/model_executor/model_loader/tpu.py
Normal file
118
vllm/model_executor/model_loader/tpu.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import time
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.spmd as xs
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model,
|
||||
process_weights_after_loading,
|
||||
)
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUModelLoader(DefaultModelLoader):
    """
    A TPU model loader for model loading under SPMD mode.
    """

    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
        mesh: xs.Mesh | None = None,
    ) -> nn.Module:
        """Build the model on CPU, load weights, then shard and compile it.

        The ``model_config`` argument is replaced by
        ``vllm_config.model_config`` before it is used.
        """
        # Initialize model and load weights on CPU. Then, during SPMD partition,
        # weights are sharded and transferred to TPUs.
        self.counter_before_loading_weights = time.perf_counter()
        model_config = vllm_config.model_config
        assert model_config.quantization is None, "Quantization not supported"
        target_device = torch.device("cpu")
        with set_default_torch_dtype(model_config.dtype), target_device:
            model = initialize_model(vllm_config=vllm_config)

        if vllm_config.load_config.load_format == "dummy":
            logger.info("Use dummy weight during weight loading.")
        else:
            weights_to_load = {name for name, _ in model.named_parameters()}
            loaded_weights = model.load_weights(
                self.get_all_weights(model_config, model)
            )
            self.counter_after_loading_weights = time.perf_counter()
            logger.info(
                "Loading weights took %.2f seconds",
                self.counter_after_loading_weights
                - self.counter_before_loading_weights,
            )
            # We only enable strict check for non-quantized models
            # that have loaded weights tracking currently.
            if model_config.quantization is None and loaded_weights is not None:
                weights_not_loaded = weights_to_load - loaded_weights
                if weights_not_loaded:
                    raise ValueError(
                        "Following weights were not initialized from "
                        f"checkpoint: {weights_not_loaded}"
                    )

        process_weights_after_loading(model, model_config, target_device)

        partition_start = time.perf_counter()
        model = model.eval()
        model = model.to("xla")
        shard_model(model, mesh)
        partition_end = time.perf_counter()
        logger.info(
            "Partition model took %.2f seconds",
            partition_end - partition_start,
        )

        # Ensure the model is properly loaded.
        self._check_model_is_loaded(mesh, model)

        # Need to torch compile after model sharding are done. Because the
        # compiler hints ('xs.mark_sharding') are torch ops.
        if model_config.is_multimodal_model:
            model.language_model.model = torch.compile(
                model.language_model.model, backend="openxla"
            )
        else:
            model.model = torch.compile(model.model, backend="openxla")
        return model

    def _check_model_is_loaded(self, mesh: xs.Mesh | None, model: nn.Module) -> None:
        """
        Ensure the model is properly loaded.
        1. All model parameters and buffers are on XLA device.
        2. Non-SPMD friendly layers are replaced as expected.
        """
        device_type = str(xm.xla_device().type)

        # Check parameters
        for name, param in model.named_parameters():
            assert param.device.type == device_type, (
                f"Parameter {name} is on {param.device.type} instead of {device_type}"
            )

        # Check buffers
        for name, buffer in model.named_buffers():
            assert buffer.device.type == device_type, (
                f"Buffer {name} is on {buffer.device.type} instead of {device_type}"
            )

        # The layer replacement check only applies when a mesh is given.
        if mesh is None:
            return
        for module in model.modules():
            if get_fqn(module) == "QKVParallelLinear":
                raise AssertionError(
                    "QKVParallelLinear should be replaced by "
                    "XlaQKVParallelLinear under SPMD mode."
                )
|
||||
@@ -1,41 +1,292 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for selecting and loading models."""
|
||||
import contextlib
|
||||
from typing import Tuple, Type
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from typing_extensions import assert_never
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.attention.layer import Attention, MLAAttention
|
||||
from vllm.config import ModelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig,
|
||||
QuantizeMethodBase,
|
||||
)
|
||||
from vllm.model_executor.models.interfaces import SupportsQuant, supports_multimodal
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype to ``dtype``.

    The previous default dtype is restored when the ``with`` block exits,
    including when the block raises.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore even on error so a failure inside the block does not leak
        # the temporary dtype to the rest of the process.
        torch.set_default_dtype(old_dtype)
|
||||
def initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    model_class: type[nn.Module] | None = None,
    model_config: ModelConfig | None = None,
) -> nn.Module:
    """Instantiate the model class described by the given configuration.

    ``model_config`` defaults to ``vllm_config.model_config`` and
    ``model_class`` to the architecture resolved from that model config.
    Classes whose ``__init__`` accepts both ``vllm_config`` and ``prefix``
    are constructed directly; for any other class a deprecation warning is
    issued and the constructor arguments are guessed from its signature.
    """
    if model_config is None:
        model_config = vllm_config.model_config
    if model_class is None:
        model_class, _ = get_model_architecture(model_config)

    if vllm_config.quant_config is not None:
        configure_quant_config(vllm_config.quant_config, model_class)

    accepted = list(inspect.signature(model_class.__init__).parameters)
    is_new_style = "vllm_config" in accepted and "prefix" in accepted
    if is_new_style:
        # new-style model class
        with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = (
        "vLLM model class should accept `vllm_config` and `prefix` as "
        "input arguments. Possibly you have an old-style model class"
        " registered from out of tree and it is used for new vLLM version. "
        "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
        "for the design and update the model class accordingly."
    )
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )
    # try to be compatible with old-style model class
    # Each value is fetched lazily, only when the constructor asks for it.
    legacy_sources = (
        ("prefix", lambda: prefix),
        ("config", lambda: model_config.hf_config),
        ("cache_config", lambda: vllm_config.cache_config),
        ("quant_config", lambda: vllm_config.quant_config),
        ("lora_config", lambda: vllm_config.lora_config),
        ("scheduler_config", lambda: vllm_config.scheduler_config),
    )
    kwargs = {name: fetch() for name, fetch in legacy_sources if name in accepted}
    with set_current_vllm_config(vllm_config, check_compile=True, prefix=prefix):
        return model_class(**kwargs)
|
||||
|
||||
|
||||
def get_model_architecture(
|
||||
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
|
||||
def process_weights_after_loading(
    model: nn.Module, model_config: ModelConfig, target_device: torch.device
) -> None:
    """Run the post-load weight processing hooks of `model`'s submodules.

    Two passes are made over `model.named_modules()`:

    1. Every submodule whose `quant_method` attribute is a
       `QuantizeMethodBase` has `quant_method.process_weights_after_loading`
       invoked on it inside `device_loading_context`.
    2. Every `Attention` / `MLAAttention` submodule that defines
       `process_weights_after_loading` has it invoked with
       `model_config.dtype`.

    Args:
        model: The model whose submodules are processed.
        model_config: Supplies the dtype for the attention pass and is
            forwarded to the weight-reloading metadata helper.
        target_device: Device that parameters are moved to while the
            quantization methods process them.
    """
    already_done = getattr(
        model, "process_weights_after_loading_already_called", False
    )
    if already_done:
        # This function may be invoked several times for the same model;
        # every invocation after the flag is set only logs and returns.
        logger.debug_once(
            "process_weights_after_loading already called for model %s", model
        )
        return

    # Imported here rather than at module level to avoid a circular dependency.
    from vllm.model_executor.model_loader.online_quantization import (
        maybe_save_metadata_and_attributes_for_weight_reloading,
    )

    maybe_save_metadata_and_attributes_for_weight_reloading(model, model_config)

    for _, submodule in model.named_modules():
        method = getattr(submodule, "quant_method", None)
        if not isinstance(method, QuantizeMethodBase):
            continue
        # Quant methods that post-process weights (repacking, quantizing,
        # etc.) expect parameters on the global target device. With CPU
        # offloading the parameters are moved onto the device for the
        # duration of the processing and moved back off afterwards.
        with device_loading_context(submodule, target_device):
            method.process_weights_after_loading(submodule)

    # Post-load initialization of attention weights, for both Attention and
    # MLA.
    # NOTE: This runs after the other modules so that weights can easily be
    # decompressed.
    for _, submodule in model.named_modules():
        if not isinstance(submodule, (Attention, MLAAttention)):
            continue
        if not hasattr(submodule, "process_weights_after_loading"):
            continue
        # TODO(lucas): see if there is a way to unify the signatures
        # of process_weights_after_loading
        submodule.process_weights_after_loading(model_config.dtype)
|
||||
|
||||
|
||||
@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
    """Temporarily place a module's CPU-resident parameters on `target_device`.

    On entry, every parameter of `module` that lives on CPU has its data
    moved to `target_device`. On exit, those same parameters (matched by
    name) are copied back into freshly allocated CPU tensors, which are
    pinned when `is_pin_memory_available()` reports that pinning is
    available. Parameters that were not on CPU at entry, and parameters
    created inside the context, are left alone.

    Args:
        module: Module whose parameters may be moved.
        target_device: Device to move CPU parameters to. When its type is
            "cpu" nothing is moved.

    Yields:
        The same `module` object.
    """
    if target_device.type == "cpu":
        # A CPU target means there is nothing to move.
        yield module
        return

    # Remember, by parameter name, the device each moved parameter came from.
    moved_from: dict[str, torch.device] = {}

    for param_name, param in module.named_parameters():
        if param.device.type != "cpu":
            # Parameters that are not on CPU stay where they are.
            continue
        moved_from[param_name] = param.device
        param.data = param.data.to(target_device)

    try:
        yield module

    finally:
        # Send moved parameters back to where they came from; anything not
        # recorded at entry (new parameters, or ones that were never moved)
        # is skipped.
        use_pinned = is_pin_memory_available()
        for param_name, param in module.named_parameters():
            if param_name not in moved_from:
                continue
            home: torch.device = moved_from[param_name]
            if home.type != "cpu":
                param.data = param.data.to(home)
                continue
            # `torch.empty_like` does not support `pin_memory` argument,
            # so the destination buffer is built with `torch.empty_strided`.
            on_device = param.data
            host_copy = torch.empty_strided(
                size=on_device.size(),
                stride=on_device.stride(),
                dtype=on_device.dtype,
                layout=on_device.layout,
                device="cpu",
                pin_memory=use_pinned,
            )
            host_copy.copy_(on_device)
            param.data = host_copy
|
||||
|
||||
|
||||
_MODEL_ARCH_BY_HASH = dict[int, tuple[type[nn.Module], str]]()
|
||||
"""Caches the outputs of `_get_model_architecture`."""
|
||||
|
||||
|
||||
def _get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Resolve the model class and architecture name for `model_config`.

    The architectures listed on `model_config.hf_config` are resolved through
    `model_config.registry.resolve_model_cls`. Depending on
    `model_config.convert_type`, the resolved class is then wrapped into an
    embedding or sequence-classification model.

    NOTE(review): the legacy resolution path (quantized-Mixtral special case
    plus a `ModelRegistry.load_model_cls` loop ending in an unconditional
    `raise ValueError`) used to sit in front of the registry call, which made
    the Transformers-fallback warning and all conversion logic unreachable.
    It has been removed so that this logic actually runs.

    Args:
        model_config: Model configuration providing `hf_config`, `registry`,
            `model_impl` and `convert_type`.

    Returns:
        A `(model_cls, arch)` tuple: the (possibly converted) model class and
        the architecture name reported by the registry.
    """
    # Imported lazily, inside the function, rather than at module level.
    from vllm.model_executor.models.adapters import (
        as_embedding_model,
        as_seq_cls_model,
        try_create_mm_pooling_model_cls,
    )

    # A config without an `architectures` attribute yields an empty list.
    architectures = getattr(model_config.hf_config, "architectures", [])

    model_cls, arch = model_config.registry.resolve_model_cls(
        architectures,
        model_config=model_config,
    )

    if arch == model_config._get_transformers_backend_cls():
        # Resolving to the Transformers backend class is only valid when the
        # vLLM implementation was not explicitly requested.
        assert model_config.model_impl != "vllm"
        if model_config.model_impl == "auto":
            logger.warning_once(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.",
                arch,
            )

    convert_type = model_config.convert_type
    if convert_type != "none" and supports_multimodal(model_cls):
        # Multimodal models first try a dedicated pooling wrapper; if none is
        # produced, fall through to the generic conversion below.
        logger.debug_once("Detected conversion of Multi Modal model.")
        converted = try_create_mm_pooling_model_cls(model_cls)
        if converted is not None:
            logger.debug_once("Creating wrapper class to forward pooler.")
            return converted, arch
        else:
            logger.debug_once("Attempting direct conversion.")

    if convert_type == "none":
        pass
    elif convert_type == "embed":
        logger.debug_once("Converting to embedding model.")
        model_cls = as_embedding_model(model_cls)
    elif convert_type == "classify":
        logger.debug_once("Converting to sequence classification model.")
        model_cls = as_seq_cls_model(model_cls)
    else:
        # Exhaustiveness check over the known `convert_type` values.
        assert_never(convert_type)

    return model_cls, arch
|
||||
|
||||
|
||||
def get_model_architecture(model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Return the cached `(model class, architecture name)` for `model_config`.

    The cache key is the hash of a tuple of the config fields read below
    (model, convert type, runner type, trust_remote_code, model_impl and the
    HF architectures). On a miss the result of `_get_model_architecture` is
    stored in `_MODEL_ARCH_BY_HASH` before being returned.

    Args:
        model_config: Model configuration to resolve.

    Returns:
        The `(model_cls, arch)` tuple produced by `_get_model_architecture`.
    """
    identifying_fields = (
        model_config.model,
        model_config.convert_type,
        model_config.runner_type,
        model_config.trust_remote_code,
        model_config.model_impl,
        tuple(getattr(model_config.hf_config, "architectures", [])),
    )
    key = hash(identifying_fields)

    try:
        return _MODEL_ARCH_BY_HASH[key]
    except KeyError:
        # Cache miss: resolve below, outside the handler.
        pass

    resolved = _get_model_architecture(model_config)
    _MODEL_ARCH_BY_HASH[key] = resolved
    return resolved
|
||||
|
||||
|
||||
def get_model_cls(model_config: ModelConfig) -> type[nn.Module]:
    """Return only the model class resolved by `get_model_architecture`."""
    resolved = get_model_architecture(model_config)
    return resolved[0]
|
||||
|
||||
|
||||
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return only the architecture name resolved by `get_model_architecture`."""
    resolved = get_model_architecture(model_config)
    return resolved[1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamMapping:
|
||||
"""
|
||||
A class to handle parameter mapping for model weight loading.
|
||||
It creates a bidirectional mapping between packed parameters and their
|
||||
constituent parts.
|
||||
"""
|
||||
|
||||
packed_mapping: dict[str, list[str]]
|
||||
inverse_packed_mapping: dict[str, tuple[str, int]] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
for packed_name, sub_params in self.packed_mapping.items():
|
||||
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
|
||||
if len(sub_params) == 1 and sub_params[0] == packed_name:
|
||||
continue
|
||||
for index, param_name in enumerate(sub_params):
|
||||
self.inverse_packed_mapping[param_name] = (
|
||||
packed_name,
|
||||
index,
|
||||
)
|
||||
|
||||
def get_sub_modules(self, module_name: str) -> tuple[str, list[str]] | None:
|
||||
for key, value in self.packed_mapping.items():
|
||||
if module_name.endswith(key):
|
||||
return key, value
|
||||
return None
|
||||
|
||||
|
||||
def configure_quant_config(
    quant_config: QuantizationConfig, model_class: type[nn.Module]
):
    """
    Hand the model class's `packed_modules_mapping` to `quant_config` by
    reference so that `quant_config` can properly match fused modules.

    Because the model attributes are shared by reference rather than copied,
    updates made to them by model_class.__new__ (ex. chatglm, qwen) are
    visible to `quant_config` as well.

    This function only acts on model classes that do not subclass
    `SupportsQuant`; once that mixin has been added to all models, this
    function can be removed.
    """
    if issubclass(model_class, SupportsQuant):
        return

    mapper = getattr(model_class, "hf_to_vllm_mapper", None)
    packed = getattr(model_class, "packed_modules_mapping", None)

    # Both mappings are optional class attributes; apply whichever exist.
    if mapper is not None:
        quant_config.apply_vllm_mapper(mapper)
    if packed is not None:
        # Assigned by reference, not copied.
        quant_config.packed_modules_mapping = packed
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user