# ruff: noqa: SIM117 import collections import copy import dataclasses import fnmatch import glob import inspect import json import math import os from abc import ABC, abstractmethod from contextlib import contextmanager from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast import gguf import huggingface_hub import numpy as np import torch from huggingface_hub import HfApi, hf_hub_download from torch import nn from transformers import AutoModelForCausalLM from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig, VllmConfig) from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.model_executor.layers.linear import (ReplicatedLinear, RowParallelLinear) from vllm.model_executor.model_loader.tensorizer import ( TensorizerConfig, is_vllm_tensorized, load_with_tensorizer, serialize_vllm_model, tensorizer_weights_iterator) from vllm.model_executor.model_loader.utils import (get_model_architecture, set_default_torch_dtype) from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, get_gguf_extra_tensor_names, gguf_quant_weights_iterator, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import is_pin_memory_available @contextmanager def device_loading_context(module: torch.nn.Module, target_device: torch.device): if target_device.type == "cpu": # If target is CPU, no need to move anything yield module return original_device_states: Dict[str, torch.device] = {} # Store original device states and move parameters to GPU if they're on CPU for name, p in module.named_parameters(): if p.device.type == "cpu": original_device_states[name] = p.device p.data = p.data.to(target_device) # Parameters already on target device are not touched try: yield module finally: # Restore parameters to their original devices, ignoring new parameters pin_memory = is_pin_memory_available() for name, p in module.named_parameters(): if name in original_device_states: original_device: torch.device = original_device_states[name] if original_device.type == "cpu": # `torch.empty_like` does not support `pin_memory` argument cpu_data = torch.empty_strided(size=p.data.size(), stride=p.data.stride(), dtype=p.data.dtype, layout=p.data.layout, device="cpu", pin_memory=pin_memory) cpu_data.copy_(p.data) p.data = cpu_data else: p.data = p.data.to(original_device) # New parameters or parameters already on target device are untouched logger = init_logger(__name__) def _get_device_memory_info_loader(): """Get device memory info for debug logging. Returns dict or None.""" try: import torch.mlu allocated = torch.mlu.memory_allocated() / (1024 ** 3) reserved = torch.mlu.memory_reserved() / (1024 ** 3) free, total = torch.mlu.mem_get_info() return {"allocated": allocated, "reserved": reserved, "free": free / (1024 ** 3), "total": total / (1024 ** 3)} except Exception: pass try: if torch.cuda.is_available(): allocated = torch.cuda.memory_allocated() / (1024 ** 3) reserved = torch.cuda.memory_reserved() / (1024 ** 3) free, total = torch.cuda.mem_get_info() return {"allocated": allocated, "reserved": reserved, "free": free / (1024 ** 3), "total": total / (1024 ** 3)} except Exception: pass return None def _log_mem(tag: str): info = _get_device_memory_info_loader() if info: logger.info( "[DEBUG-MEM] %s: allocated=%.2f GiB, reserved=%.2f GiB, " "free=%.2f GiB, total=%.2f GiB", tag, info["allocated"], info["reserved"], info["free"], info["total"]) def _initialize_model(vllm_config: VllmConfig, prefix: str = "") -> nn.Module: """Initialize a model with the given configurations.""" model_config = vllm_config.model_config model_class, _ = get_model_architecture(model_config) logger.info("[DEBUG-MEM] Model class: %s, dtype: %s", model_class.__name__, model_config.dtype) _log_mem("Before _initialize_model") signatures = inspect.signature(model_class.__init__) all_params = [param.name for param in signatures.parameters.values()] if "vllm_config" in all_params and "prefix" in all_params: # new-style model class model = model_class(vllm_config=vllm_config, prefix=prefix) _log_mem("After _initialize_model (empty weights created)") # Print model parameter summary total_params = 0 total_bytes = 0 for name, param in model.named_parameters(): total_params += param.numel() total_bytes += param.numel() * param.element_size() logger.info( "[DEBUG-MEM] Model params: %d, " "estimated size: %.2f GiB", total_params, total_bytes / (1024 ** 3)) return model msg = ("vLLM model class should accept `vllm_config` and `prefix` as " "input arguments. Possibly you have an old-style model class" " registered from out of tree and it is used for new vLLM version. " "Check https://docs.vllm.ai/en/latest/design/class_hierarchy.html " "for the design and update the model class accordingly.") logger.warning(msg) logger.warning( "Trying to guess the arguments for old-style model class %s", model_class) # try to be compatible with old-style model class kwargs = {} if "prefix" in all_params: kwargs["prefix"] = prefix if "config" in all_params: kwargs["config"] = model_config.hf_config if "cache_config" in all_params: kwargs["cache_config"] = vllm_config.cache_config if "quant_config" in all_params: kwargs["quant_config"] = vllm_config.quant_config if "lora_config" in all_params: kwargs["lora_config"] = vllm_config.lora_config if "scheduler_config" in all_params: kwargs["scheduler_config"] = vllm_config.scheduler_config return model_class(**kwargs) class BaseModelLoader(ABC): """Base class for model loaders.""" def __init__(self, load_config: LoadConfig): self.load_config = load_config @abstractmethod def download_model(self, model_config: ModelConfig) -> None: """Download a model so that it can be immediately loaded.""" raise NotImplementedError @abstractmethod def load_model(self, *, vllm_config: VllmConfig) -> nn.Module: """Load a model with the given configurations.""" raise NotImplementedError class DefaultModelLoader(BaseModelLoader): """Model loader that can load different file types from disk.""" @dataclasses.dataclass class Source: """A source for weights.""" model_or_path: str """The model ID or path.""" revision: Optional[str] """The optional model revision.""" prefix: str = "" """A prefix to prepend to all weights.""" fall_back_to_pt: bool = True """Whether .pt weights can be used.""" def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: raise ValueError(f"Model loader extra config is not supported for " f"load format {load_config.load_format}") def _maybe_download_from_modelscope( self, model: str, revision: Optional[str]) -> Optional[str]: """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. Returns the path to the downloaded model, or None if the model is not downloaded from ModelScope.""" if VLLM_USE_MODELSCOPE: # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. # pylint: disable=C. from modelscope.hub.snapshot_download import snapshot_download if not os.path.exists(model): model_path = snapshot_download( model_id=model, cache_dir=self.load_config.download_dir, local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, revision=revision, ignore_file_pattern=self.load_config.ignore_patterns, ) else: model_path = model return model_path return None def _prepare_weights(self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: """Prepare weights for the model. If the model is not local, it will be downloaded.""" model_name_or_path = self._maybe_download_from_modelscope( model_name_or_path, revision) or model_name_or_path is_local = os.path.isdir(model_name_or_path) load_format = self.load_config.load_format use_safetensors = False index_file = SAFE_WEIGHTS_INDEX_NAME # Some quantized models use .pt files for storing the weights. if load_format == LoadFormat.AUTO: allow_patterns = ["*.safetensors", "*.bin"] elif load_format == LoadFormat.SAFETENSORS: use_safetensors = True allow_patterns = ["*.safetensors"] elif load_format == LoadFormat.MISTRAL: use_safetensors = True allow_patterns = ["consolidated*.safetensors"] index_file = "consolidated.safetensors.index.json" elif load_format == LoadFormat.PT: allow_patterns = ["*.pt"] elif load_format == LoadFormat.NPCACHE: allow_patterns = ["*.bin"] else: raise ValueError(f"Unknown load_format: {load_format}") if fall_back_to_pt: allow_patterns += ["*.pt"] if not is_local: hf_folder = download_weights_from_hf( model_name_or_path, self.load_config.download_dir, allow_patterns, revision, ignore_patterns=self.load_config.ignore_patterns, ) else: hf_folder = model_name_or_path hf_weights_files: List[str] = [] for pattern in allow_patterns: hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) if len(hf_weights_files) > 0: if pattern == "*.safetensors": use_safetensors = True break if use_safetensors: # For models like Mistral-7B-Instruct-v0.3 # there are both sharded safetensors files and a consolidated # safetensors file. Using both breaks. # Here, we download the `model.safetensors.index.json` and filter # any files not found in the index. if not is_local: download_safetensors_index_file_from_hf( model_name_or_path, index_file, self.load_config.download_dir, revision) hf_weights_files = filter_duplicate_safetensors_files( hf_weights_files, hf_folder, index_file) else: hf_weights_files = filter_files_not_needed_for_inference( hf_weights_files) if len(hf_weights_files) == 0: raise RuntimeError( f"Cannot find any model weights with `{model_name_or_path}`") return hf_folder, hf_weights_files, use_safetensors def _get_weights_iterator( self, source: "Source" ) -> Generator[Tuple[str, torch.Tensor], None, None]: """Get an iterator for the model weights based on the load format.""" hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( source.model_or_path, source.revision, source.fall_back_to_pt) if self.load_config.load_format == LoadFormat.NPCACHE: # Currently np_cache only support *.bin checkpoints assert use_safetensors is False weights_iterator = np_cache_weights_iterator( source.model_or_path, self.load_config.download_dir, hf_folder, hf_weights_files) elif use_safetensors: weights_iterator = safetensors_weights_iterator(hf_weights_files) else: weights_iterator = pt_weights_iterator(hf_weights_files) if current_platform.is_tpu(): # In PyTorch XLA, we should call `xm.mark_step` frequently so that # not too many ops are accumulated in the XLA program. import torch_xla.core.xla_model as xm def _xla_weights_iterator(iterator: Generator): for weights in iterator: yield weights xm.mark_step() weights_iterator = _xla_weights_iterator(weights_iterator) # Apply the prefix. return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator) def _get_all_weights( self, model_config: ModelConfig, model: nn.Module, ) -> Generator[Tuple[str, torch.Tensor], None, None]: primary_weights = DefaultModelLoader.Source( model_config.model, model_config.revision, prefix="", fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True)) yield from self._get_weights_iterator(primary_weights) secondary_weights = cast(Iterable[DefaultModelLoader.Source], getattr(model, "secondary_weights", ())) for source in secondary_weights: yield from self._get_weights_iterator(source) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision, fall_back_to_pt=True) def load_model(self, vllm_config: VllmConfig) -> nn.Module: device_config = vllm_config.device_config model_config = vllm_config.model_config target_device = torch.device(device_config.device) _log_mem("load_model start, target_device=%s" % target_device) with set_default_torch_dtype(model_config.dtype): with target_device: model = _initialize_model(vllm_config=vllm_config) _log_mem("Before load_weights") model.load_weights(self._get_all_weights(model_config, model)) _log_mem("After load_weights") for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: # When quant methods need to process weights after loading # (for repacking, quantizing, etc), they expect parameters # to be on the global target device. This scope is for the # case where cpu offloading is used, where we will move the # parameters onto device for processing and back off after. with device_loading_context(module, target_device): quant_method.process_weights_after_loading(module) return model.eval() class DummyModelLoader(BaseModelLoader): """Model loader that will set model weights to random values.""" def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: raise ValueError(f"Model loader extra config is not supported for " f"load format {load_config.load_format}") def download_model(self, model_config: ModelConfig) -> None: pass # Nothing to download def load_model(self, vllm_config: VllmConfig) -> nn.Module: device_config = vllm_config.device_config model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(vllm_config=vllm_config) # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. initialize_dummy_weights(model) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: # When quant methods need to process weights after loading # (for repacking, quantizing, etc), they expect parameters # to be on the global target device. This scope is for the # case where cpu offloading is used, where we will move the # parameters onto device for processing and back off after. with device_loading_context( module, torch.device(device_config.device)): quant_method.process_weights_after_loading(module) return model.eval() class TensorizerLoader(BaseModelLoader): """Model loader using CoreWeave's tensorizer library.""" def __init__(self, load_config: LoadConfig): super().__init__(load_config) if isinstance(load_config.model_loader_extra_config, TensorizerConfig): self.tensorizer_config = load_config.model_loader_extra_config else: self.tensorizer_config = TensorizerConfig( **load_config.model_loader_extra_config) def _verify_config(self, model_config: ModelConfig, parallel_config: ParallelConfig): self.tensorizer_config.verify_with_model_config(model_config) self.tensorizer_config.verify_with_parallel_config(parallel_config) def _get_weights_iterator( self) -> Generator[Tuple[str, torch.Tensor], None, None]: tensorizer_args = self.tensorizer_config._construct_tensorizer_args() return tensorizer_weights_iterator(tensorizer_args) def _load_model_serialized_cpu( self, vllm_config: VllmConfig, ) -> nn.Module: """Load a serialized model with tensorizer to the CPU. This is only necessary when the model isn't vLLM-tensorized (see examples/tensorize_vllm_model.py) This should still be faster than default HuggingFace loading, but will be slower than loading a vLLM-tensorized model. """ device_config = vllm_config.device_config model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(vllm_config=vllm_config) model.load_weights(self._get_weights_iterator()) return model.eval() def _load_model_serialized( self, vllm_config: VllmConfig, ) -> nn.Module: """Load a serialized model with tensorizer. Expects a vLLM-tensorized model. See the examples/tensorize_vllm_model.py example script for serializing vLLM models.""" device_config = vllm_config.device_config model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model_class = get_model_architecture(model_config)[0] tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config.model_class = model_class tensorizer_config.hf_config = model_config.hf_config tensorizer_config.dtype = model_config.dtype model = load_with_tensorizer(tensorizer_config, vllm_config=vllm_config) return model.eval() def download_model(self, model_config: ModelConfig) -> None: self.tensorizer_config.verify_with_model_config(model_config) with self.tensorizer_config.open_stream(): pass def load_model(self, vllm_config: VllmConfig) -> nn.Module: model_config = vllm_config.model_config parallel_config = vllm_config.parallel_config self._verify_config(model_config, parallel_config) if parallel_config.tensor_parallel_size > 1: from vllm.distributed import get_tensor_model_parallel_rank self.tensorizer_config.tensorizer_uri = \ self.tensorizer_config.tensorizer_uri \ % get_tensor_model_parallel_rank() if is_vllm_tensorized(self.tensorizer_config): return self._load_model_serialized(vllm_config=vllm_config) return self._load_model_serialized_cpu(vllm_config=vllm_config) @staticmethod def save_model( model: torch.nn.Module, tensorizer_config: TensorizerConfig, ) -> None: serialize_vllm_model( model=model, tensorizer_config=tensorizer_config, ) class ShardedStateLoader(BaseModelLoader): """ Model loader that directly loads each worker's model state dict, which enables a fast load path for large tensor-parallel models where each worker only needs to read its own shard rather than the entire checkpoint. See `examples/save_sharded_state.py` for creating a sharded checkpoint. """ DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors" def __init__(self, load_config: LoadConfig): super().__init__(load_config) extra_config = ({} if load_config.model_loader_extra_config is None else load_config.model_loader_extra_config.copy()) self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN) if extra_config: raise ValueError(f"Unexpected extra config keys for load format " f"{load_config.load_format}: " f"{load_config.model_loader_extra_config.keys()}") @staticmethod def _filter_subtensors( tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: """ Filter out all tensors that share the same memory or a subset of the memory of another tensor. """ same_storage_groups: Dict[Any, List[Tuple[ str, torch.Tensor]]] = collections.defaultdict(list) for key, tensor in tensors.items(): if tensor.numel(): ptr = tensor.untyped_storage().data_ptr() same_storage_groups[tensor.device, ptr].append((key, tensor)) def get_end_ptr(tensor: torch.Tensor) -> int: return tensor.view(-1)[-1].data_ptr() + tensor.element_size() result: Dict[str, torch.Tensor] = {} for group in same_storage_groups.values(): for k, t in group: a, b = t.data_ptr(), get_end_ptr(t) for k2, t2 in group: if not t2.is_contiguous(): continue a2, b2 = t2.data_ptr(), get_end_ptr(t2) if a < a2 or b2 < b: continue if a2 < a or b < b2 or not t.is_contiguous(): break # t2 covers strictly more memory than t. if k2 < k: # Same tensors, keep the one with the smaller key. break else: result[k] = t return result def _prepare_weights(self, model_name_or_path: str, revision: Optional[str]): if os.path.isdir(model_name_or_path): return model_name_or_path else: allow_patterns = ["*.safetensors"] return download_weights_from_hf( model_name_or_path, self.load_config.download_dir, allow_patterns, revision, ignore_patterns=self.load_config.ignore_patterns, ) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) def load_model(self, vllm_config: VllmConfig) -> nn.Module: device_config = vllm_config.device_config model_config = vllm_config.model_config from safetensors.torch import safe_open from vllm.distributed import get_tensor_model_parallel_rank local_model_path = self._prepare_weights(model_config.model, model_config.revision) with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(vllm_config=vllm_config) for _, module in model.named_modules(): quant_method = getattr(module, "quant_method", None) if quant_method is not None: quant_method.process_weights_after_loading(module) rank = get_tensor_model_parallel_rank() pattern = os.path.join( local_model_path, self.pattern.format(rank=rank, part="*"), ) filepaths = glob.glob(pattern) if not filepaths: # TODO: support un-sharded checkpoints too raise ValueError( f"Could not find checkpoint files '{pattern}', only " f"pre-sharded checkpoints are currently supported!") state_dict = self._filter_subtensors(model.state_dict()) for path in filepaths: with safe_open(path, framework="pt") as f: for key in f.keys(): # noqa: SIM118 tensor = f.get_tensor(key) # If loading with LoRA enabled, additional padding may # be added to certain parameters. We only load into a # narrowed view of the parameter data. param_data = state_dict[key].data param_shape = state_dict[key].shape for dim, size in enumerate(tensor.shape): if size < param_shape[dim]: param_data = param_data.narrow(dim, 0, size) if tensor.shape != param_shape: logger.warning( "loading tensor of shape %s into " "parameter '%s' of shape %s", tensor.shape, key, param_shape) param_data.copy_(tensor) state_dict.pop(key) if state_dict: raise ValueError( f"Missing keys {tuple(state_dict)} in loaded state!") return model.eval() @staticmethod def save_model( model: torch.nn.Module, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None, ) -> None: from safetensors.torch import save_file from vllm.distributed import get_tensor_model_parallel_rank if pattern is None: pattern = ShardedStateLoader.DEFAULT_PATTERN rank = get_tensor_model_parallel_rank() part_idx = 0 total_size = 0 state_dict = ShardedStateLoader._filter_subtensors(model.state_dict()) state_dict_part: Dict[str, torch.Tensor] = {} for key, tensor in state_dict.items(): param_size = tensor.nelement() * tensor.element_size() if max_size is not None and total_size + param_size > max_size: filename = pattern.format(rank=rank, part=part_idx) save_file( state_dict_part, os.path.join(path, filename), ) part_idx += 1 total_size = 0 state_dict_part = {} state_dict_part[key] = tensor total_size += param_size if len(state_dict_part) > 0: filename = pattern.format(rank=rank, part=part_idx) save_file( state_dict_part, os.path.join(path, filename), ) class BitsAndBytesModelLoader(BaseModelLoader): """Model loader to load model weights with BitAndBytes quantization.""" possible_config_file_names = ["adapter_config.json"] default_target_modules = [ ".gate_proj.", ".down_proj.", ".up_proj.", ".q_proj.", ".k_proj.", ".v_proj.", ".o_proj.", '.fc1.', '.fc2.', '.dense.', '.query_key_value.', '.qkv_proj.', '.dense_h_to_4h.', '.dense_4h_to_h.', '.out_proj.', ] def __init__(self, load_config: LoadConfig): super().__init__(load_config) # Save the module names without sharding. self.unsharded_weights_modules: List[str] = [] # Save the module names that are sharded by column. self.column_sharded_weights_modules: List[str] = [] # we don't need to quantize the whole model, only the target modules # that are specified in the adapter config file. If the adapter config # file is not provided, we will quantize the default modules. if (not load_config.model_loader_extra_config or "qlora_adapter_name_or_path" not in load_config.model_loader_extra_config): self.target_modules = [] return qlora_adapter = load_config.model_loader_extra_config[ "qlora_adapter_name_or_path"] config_file_path = self._get_config_file(qlora_adapter) with open(config_file_path) as f: config = json.load(f) self.target_modules = config["target_modules"] def _get_config_file(self, qlora_adapter: str) -> str: is_local = os.path.isdir(qlora_adapter) config_file_path = None if is_local: for file in self.possible_config_file_names: config_file_path = os.path.join(qlora_adapter, file) if os.path.exists(config_file_path): break else: hf_api = HfApi() repo_files = hf_api.list_repo_files(repo_id=qlora_adapter) for file in self.possible_config_file_names: if file in repo_files: config_file_path = hf_hub_download(repo_id=qlora_adapter, filename=file) break if not config_file_path: raise ValueError( f"Cannot find adapter config file in {qlora_adapter}") return config_file_path def _get_weight_files( self, model_name_or_path: str, allowed_patterns: List[str], revision: Optional[str] = None) -> Tuple[List[str], str]: """Retrieve weight files. Download the files if necessary. Return the weight files and the file pattern.""" is_local = os.path.isdir(model_name_or_path) if is_local: for pattern in allowed_patterns: weight_files = glob.glob( os.path.join(model_name_or_path, pattern)) if weight_files: return weight_files, pattern else: hf_api = HfApi() repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) for pattern in allowed_patterns: matching_files = fnmatch.filter(repo_files, pattern) if matching_files: hf_folder = download_weights_from_hf( model_name_or_path, self.load_config.download_dir, [pattern], revision, ignore_patterns=self.load_config.ignore_patterns, ) return glob.glob(os.path.join(hf_folder, pattern)), pattern raise RuntimeError( f"No model weights found in: `{model_name_or_path}`") def _prepare_weights(self, model_name_or_path: str, revision: Optional[str]) -> Tuple[List[str], bool]: """Prepare weight files for the model.""" allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] hf_weights_files, matched_pattern = self._get_weight_files( model_name_or_path, allowed_patterns, revision) if matched_pattern != "*.safetensors": hf_weights_files = filter_files_not_needed_for_inference( hf_weights_files) if len(hf_weights_files) == 0: raise RuntimeError( f"Cannot find any model weights with `{model_name_or_path}`") return hf_weights_files, matched_pattern == "*.safetensors" def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): if use_safetensors: return safetensors_weights_iterator(hf_weights_files) else: return pt_weights_iterator(hf_weights_files) def _get_quantized_weights_iterator( self, model_name_or_path: str, revision: Optional[str], pre_quant: bool, load_8bit: bool, ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, Any]]: """Get an iterator to the model weights with bitsandbytes quantization, as well as the quantization state dictionary.""" # only load the bitsandbytes module when needed try: import bitsandbytes if bitsandbytes.__version__ < "0.44.0": raise ImportError("bitsandbytes version is wrong. Please " "install bitsandbytes>=0.44.0.") except ImportError as err: raise ImportError("Please install bitsandbytes>=0.44.0 via " "`pip install bitsandbytes>=0.44.0` to use " "bitsandbytes quantizer.") from err hf_weights_files, use_safetensors = self._prepare_weights( model_name_or_path, revision) quant_state_dict: Dict[str, Any] = {} if pre_quant: if load_8bit: return self._quantized_8bit_generator( hf_weights_files, use_safetensors, quant_state_dict), quant_state_dict else: return self._quantized_4bit_generator( hf_weights_files, use_safetensors, quant_state_dict), quant_state_dict return self._unquantized_generator(hf_weights_files, use_safetensors, quant_state_dict), quant_state_dict def _is_8bit_weight_name(self, weight_name: str): quantized_suffix = {".scb", ".weight_format"} return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix) def _is_4bit_weight_name(self, weight_name: str): quantized_suffix = { "absmax", "quant_map", "nested_absmax", "nested_quant_map", "bitsandbytes" } suffix = weight_name.split(".")[-1] return any(q_suffix in suffix for q_suffix in quantized_suffix) def _quantized_8bit_generator(self, hf_weights_files, use_safetensors, quant_state_dict) -> Generator: for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): if not weight_name.lower().endswith(".scb"): continue weight_key = weight_name.lower().replace(".scb", ".weight") quant_state_dict[weight_key] = weight_tensor for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): if self._is_8bit_weight_name(weight_name): continue if weight_name in quant_state_dict: set_weight_attrs(weight_tensor, {"load_in_8bit": True}) yield weight_name, weight_tensor else: yield weight_name, weight_tensor def _quantized_4bit_generator(self, hf_weights_files, use_safetensors, quant_state_dict) -> Generator: from bitsandbytes.functional import QuantState # First iterate over all quant state weights weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors) temp_state_dict = {} for weight_name, weight_tensor in weight_iterator: if not self._is_4bit_weight_name(weight_name): continue # bitsandbytes library requires # weight.quant_state.bitsandbytes__* in CPU if "quant_state.bitsandbytes" in weight_name: temp_state_dict[weight_name] = weight_tensor.cpu().data else: temp_state_dict[weight_name] = weight_tensor # Closure to parse quant_state for each prequant weight def _parse_quant_state(param_name: str, temp_state_dict: Dict) -> QuantState: quant_state = {} for k in temp_state_dict: if param_name + "." in k: quant_state[k] = temp_state_dict[k] return QuantState.from_dict(quant_state, device="cuda") # Second iterate over all prequant and normal weights # pre quantized weights would have a quant_state for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): if self._is_4bit_weight_name(weight_name): continue if (f"{weight_name}.quant_state.bitsandbytes__nf4" \ in temp_state_dict) or \ (f"{weight_name}.quant_state.bitsandbytes__fp4" \ in temp_state_dict): quant_state = _parse_quant_state(weight_name, temp_state_dict) quant_state_dict[weight_name] = quant_state yield weight_name, weight_tensor else: yield weight_name, weight_tensor def _unquantized_generator(self, hf_weights_files, use_safetensors, quant_state_dict) -> Generator: from bitsandbytes.functional import quantize_4bit tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() for weight_name, weight_tensor in self._hf_weight_iter( hf_weights_files, use_safetensors): if any(target_module in weight_name for target_module in self.target_modules) and weight_name.endswith(".weight"): # Without sharding if any( weight_name.startswith(module) for module in self.unsharded_weights_modules): weight_sub_tensor = weight_tensor # Shard by column elif any( weight_name.startswith(module) for module in self.column_sharded_weights_modules): total_size = weight_tensor.size(-1) start_index = total_size // tp_size * tp_rank end_index = total_size // tp_size * (tp_rank + 1) weight_sub_tensor = weight_tensor[..., start_index:end_index] # Shard by row else: total_size = weight_tensor.size(0) start_index = total_size // tp_size * tp_rank end_index = total_size // tp_size * (tp_rank + 1) weight_sub_tensor = weight_tensor[start_index:end_index, ...] # bitsandbytes requires data in GPU if weight_sub_tensor.is_cuda: loaded_weight = weight_sub_tensor else: loaded_weight = weight_sub_tensor.cuda() # remove the following after the issue is fixed: # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342 if loaded_weight.is_contiguous() is False: loaded_weight = loaded_weight.contiguous() with set_default_torch_dtype(torch.float32): processed_weight, quant_state = quantize_4bit( loaded_weight, compress_statistics=True, quant_type="nf4") quant_state_dict[weight_name] = quant_state else: processed_weight = weight_tensor yield weight_name, processed_weight def _load_weights(self, model_config: ModelConfig, model: nn.Module) -> None: if not hasattr(model, 'load_weights'): raise AttributeError( "The required method 'load_weights' is not defined in class" f" {type(model).__name__}.") if not hasattr(model, 'bitsandbytes_stacked_params_mapping'): raise AttributeError( f"Model {type(model).__name__} does not support BitsAndBytes " "quantization yet.") if len(self.target_modules) == 0: if hasattr(model, 'default_bitsandbytes_target_modules'): self.target_modules = model.default_bitsandbytes_target_modules else: self.target_modules = self.default_target_modules for name, module in model.named_modules(): # Some modules like `ReplicatedLinear` should not have their weights # sharded. The reason for implementing it this way is to avoid new # static variable in the model implementation. if isinstance(module, (ReplicatedLinear, )): self.unsharded_weights_modules.append(name) # In TP, these weights are partitioned along the column # dimension (dim=-1) elif isinstance(module, (RowParallelLinear, )): self.column_sharded_weights_modules.append(name) self.model_type = type(model).__name__ logger.info("Loading weights with BitsAndBytes quantization. " " May take a while ...") quant_config = getattr(model_config.hf_config, "quantization_config", None) pre_quant = False if quant_config is not None: quant_method = quant_config.get('quant_method') if quant_method == "bitsandbytes": pre_quant = True else: raise ValueError( f"BitsAndBytes loader does not support {quant_method} " "quantization") # The quant_states in pre_quantized models cannot work with a split # weight tensor. So TP does not work with pre_quantized bnb models. if pre_quant and get_tensor_model_parallel_world_size() > 1: raise ValueError( "Prequant BitsAndBytes models with TP is not supported." "Please try with PP.") load_8bit = False if pre_quant: load_8bit = quant_config.get('load_in_8bit', False) qweight_iterator, quant_state_dict = \ self._get_quantized_weights_iterator( model_config.model, model_config.revision, pre_quant, load_8bit) model.load_weights(qweight_iterator) torch.cuda.empty_cache() param_dict = dict(model.named_parameters()) stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} # TODO: Change this lazy import to normal import # after the checks are updated to run on a new version from vllm.model_executor.models.utils import is_pp_missing_parameter for quant_param_name in quant_state_dict: if is_pp_missing_parameter(quant_param_name, model): continue non_stacked_param_name = quant_param_name shard_index = 0 for shard_name, ( weight_name, index ) in model.bitsandbytes_stacked_params_mapping.items(): shard_pos = quant_param_name.find(shard_name) # Some models, such as MiniCPM V2.5/2.6, contain both # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj' # from being incorrectly identified as being present in # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight if shard_pos > 0 and quant_param_name[shard_pos - 1] == ".": shard_index = index quant_param_name = quant_param_name.replace( shard_name, weight_name) break if quant_param_name not in param_dict: raise ValueError( f"Parameter {quant_param_name} not found in the model.") if quant_param_name not in stacked_quant_state_dict: stacked_quant_state_dict[quant_param_name] = {} stacked_quant_state_dict[quant_param_name][shard_index] = ( quant_state_dict[non_stacked_param_name]) # save quant_states and offsets as the attributes of the parameters for param_name, param in param_dict.items(): if param_name in stacked_quant_state_dict: quant_states = stacked_quant_state_dict[param_name] set_weight_attrs(param, {"bnb_quant_state": quant_states}) pack_ratio = getattr(param, "pack_factor", -1) if pack_ratio == -1: raise ValueError( f"pack_factor not set for parameter {param_name}.") num_elements = [0] * len(quant_states) for seq, quant_state in quant_states.items(): num_elements[seq] = math.prod( quant_state.shape) // pack_ratio offsets = np.concatenate(([0], np.cumsum(num_elements))) set_weight_attrs(param, {"bnb_shard_offsets": offsets}) if load_8bit: set_weight_attrs( param, {"matmul_state": [None] * len(quant_states)}) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model, model_config.revision) def load_model(self, vllm_config: VllmConfig) -> nn.Module: device_config = vllm_config.device_config model_config = vllm_config.model_config with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(vllm_config=vllm_config) self._load_weights(model_config, model) return model.eval() class GGUFModelLoader(BaseModelLoader): """ Model loader that can load GGUF files. This is useful for loading models that are quantized with GGUF and saved in the GGUF format. This loader supports loading both full models and sharded models. """ def __init__(self, load_config: LoadConfig): super().__init__(load_config) if load_config.model_loader_extra_config: raise ValueError(f"Model loader extra config is not supported for " f"load format {load_config.load_format}") def _prepare_weights(self, model_name_or_path: str): if os.path.isfile(model_name_or_path): return model_name_or_path else: raise ValueError(f"{model_name_or_path} is not a file.") def _get_gguf_weights_map(self, model_config: ModelConfig): """ GGUF uses this naming convention for their tensors from HF checkpoint: `blk.N.BB.weight` and `blk.N.BB.bias` where N signifies the block number of a layer, and BB signifies the attention/mlp layer components. See "Standardized tensor names" in https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details. """ config = model_config.hf_config model_type = config.model_type # hack: ggufs have a different name than transformers if model_type == "cohere": model_type = "command-r" arch = None for key, value in gguf.MODEL_ARCH_NAMES.items(): if value == model_type: arch = key break if arch is None: raise RuntimeError(f"Unknown gguf model_type: {model_type}") num_layers = config.num_hidden_layers name_map = gguf.get_tensor_name_map(arch, num_layers) with torch.device("meta"): dummy_model = AutoModelForCausalLM.from_config(config) state_dict = dummy_model.state_dict() gguf_to_hf_name_map = {} for hf_name in state_dict: name, suffix = hf_name.rsplit(".", 1) gguf_name = name_map.get_name(name) gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name return gguf_to_hf_name_map def _get_weights_iterator( self, model_name_or_path: str, gguf_to_hf_name_map: Dict[str, str] ) -> Generator[Tuple[str, torch.Tensor], None, None]: return gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config.model) def load_model(self, vllm_config: VllmConfig) -> nn.Module: device_config = vllm_config.device_config model_config = vllm_config.model_config local_model_path = self._prepare_weights(model_config.model) gguf_weights_map = self._get_gguf_weights_map(model_config) # we can only know if tie word embeddings after mapping weights if "lm_head.weight" in get_gguf_extra_tensor_names( local_model_path, gguf_weights_map): model_config.hf_config.update({"tie_word_embeddings": True}) with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model = _initialize_model(vllm_config=vllm_config) model.load_weights( self._get_weights_iterator(local_model_path, gguf_weights_map)) return model def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: """Get a model loader based on the load format.""" if isinstance(load_config.load_format, type): return load_config.load_format(load_config) if load_config.load_format == LoadFormat.DUMMY: return DummyModelLoader(load_config) if load_config.load_format == LoadFormat.TENSORIZER: return TensorizerLoader(load_config) if load_config.load_format == LoadFormat.SHARDED_STATE: return ShardedStateLoader(load_config) if load_config.load_format == LoadFormat.BITSANDBYTES: return BitsAndBytesModelLoader(load_config) if load_config.load_format == LoadFormat.GGUF: return GGUFModelLoader(load_config) return DefaultModelLoader(load_config)