# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import glob
import os
import time
from collections.abc import Generator, Iterable
from typing import Optional, cast

import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm.config import ModelConfig
from vllm.config.load import LoadConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
    filter_files_not_needed_for_inference, maybe_download_from_modelscope,
    multi_thread_pt_weights_iterator,
    multi_thread_safetensors_weights_iterator, np_cache_weights_iterator,
    pt_weights_iterator, safetensors_weights_iterator)
from vllm.platforms import current_platform

logger = init_logger(__name__)


class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    # default number of thread when enable multithread weight loading
    DEFAULT_NUM_THREADS = 8

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: Optional[list[str]] = None
        """If defined, weights will load exclusively using these patterns."""

    # `time.perf_counter()` readings bracketing a weight load. A value of
    # 0.0 for the "before" counter means "not started yet".
    counter_before_loading_weights: float = 0.0
    counter_after_loading_weights: float = 0.0

    def __init__(self, load_config: LoadConfig):
        """Validate the loader-specific extra config.

        Args:
            load_config: The load configuration. Its
                `model_loader_extra_config` may only contain the keys
                `enable_multithread_load` and `num_threads`.

        Raises:
            ValueError: If the extra config contains any other key.
        """
        super().__init__(load_config)

        # Treat a missing (None) extra config like an empty one instead of
        # failing with an AttributeError on `.keys()`.
        extra_config = load_config.model_loader_extra_config or {}
        allowed_keys = {"enable_multithread_load", "num_threads"}
        unexpected_keys = set(extra_config.keys()) - allowed_keys

        if unexpected_keys:
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{unexpected_keys}")

    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        fall_back_to_pt: bool,
        allow_patterns_overrides: Optional[list[str]],
    ) -> tuple[str, list[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Args:
            model_name_or_path: A local directory or a model ID.
            revision: The optional model revision.
            fall_back_to_pt: Whether `*.pt` is appended to the glob patterns
                derived from the load format.
            allow_patterns_overrides: If not None, replaces the derived glob
                patterns entirely.

        Returns:
            A tuple `(hf_folder, hf_weights_files, use_safetensors)`: the
            folder holding the weights, the weight files selected in it, and
            whether those files are to be read as safetensors.

        Raises:
            ValueError: If the configured load format is not recognised.
            RuntimeError: If no weight file is left after filtering.
        """
        model_name_or_path = (maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path)

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        # Some quantized models use .pt files for storing the weights.
        if load_format == "auto":
            allow_patterns = ["*.safetensors", "*.bin"]
        elif (load_format == "safetensors"
              or load_format == "fastsafetensors"):
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == "mistral":
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == "pt":
            allow_patterns = ["*.pt"]
        elif load_format == "npcache":
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        if not is_local:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            hf_folder = model_name_or_path

        # Patterns are tried in priority order; the first one that matches
        # any file wins and later patterns are not consulted.
        hf_weights_files: list[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if len(hf_weights_files) > 0:
                if pattern == "*.safetensors":
                    use_safetensors = True
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if len(hf_weights_files) == 0:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    def _get_weights_iterator(
            self, source: "Source"
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format.

        Args:
            source: The weight source to resolve and read.

        Returns:
            A generator of `(name, tensor)` pairs where every name has
            `source.prefix` prepended.
        """
        # Treat a missing (None) extra config like an empty one so the
        # `.get()` lookups below cannot fail.
        extra_config = self.load_config.model_loader_extra_config or {}
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt,
            source.allow_patterns_overrides)
        if self.load_config.load_format == "npcache":
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                self.load_config.download_dir,
                hf_folder,
                hf_weights_files,
                self.load_config.use_tqdm_on_load,
            )
        elif use_safetensors:
            if self.load_config.load_format == "fastsafetensors":
                weights_iterator = fastsafetensors_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                )
            else:
                if extra_config.get("enable_multithread_load"):
                    weights_iterator = (
                        multi_thread_safetensors_weights_iterator(
                            hf_weights_files,
                            self.load_config.use_tqdm_on_load,
                            max_workers=extra_config.get(
                                "num_threads", self.DEFAULT_NUM_THREADS),
                        ))
                else:
                    weights_iterator = safetensors_weights_iterator(
                        hf_weights_files,
                        self.load_config.use_tqdm_on_load,
                        self.load_config.safetensors_load_strategy,
                    )
        else:
            if extra_config.get("enable_multithread_load"):
                weights_iterator = multi_thread_pt_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.pt_load_map_location,
                    max_workers=extra_config.get("num_threads",
                                                 self.DEFAULT_NUM_THREADS),
                )
            else:
                weights_iterator = pt_weights_iterator(
                    hf_weights_files,
                    self.load_config.use_tqdm_on_load,
                    self.load_config.pt_load_map_location,
                )

        if current_platform.is_tpu():
            from vllm.platforms.tpu import USE_TPU_COMMONS

            if not USE_TPU_COMMONS:
                # In PyTorch XLA, we should call `torch_xla.sync`
                # frequently so that not too many ops are accumulated
                # in the XLA program.
                import torch_xla

                def _xla_weights_iterator(iterator: Generator):
                    for weights in iterator:
                        yield weights
                        torch_xla.sync(wait=False)

                weights_iterator = _xla_weights_iterator(weights_iterator)

        # Only the first source of a load starts the timer; later sources
        # leave the recorded start untouched. The timestamp is taken after
        # `_prepare_weights`, so that step is not part of the measured time.
        if self.counter_before_loading_weights == 0.0:
            self.counter_before_loading_weights = time.perf_counter()
        # Apply the prefix.
        return ((source.prefix + name, tensor)
                for (name, tensor) in weights_iterator)

    def get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Yield the weights of the primary source, then of secondary ones.

        The primary source is built from `model_config.model` and
        `model_config.revision`. The optional model attributes
        `fall_back_to_pt_during_load`, `allow_patterns_overrides` and
        `secondary_weights` customise the sources when present.

        Args:
            model_config: Supplies the primary model ID/path and revision.
            model: The module being loaded; only inspected for the optional
                attributes listed above.

        Yields:
            `(name, tensor)` pairs.
        """
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                    True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
                                             None),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        """Resolve (and, if not local, download) the primary weights.

        Args:
            model_config: Supplies the model ID/path and revision.
        """
        self._prepare_weights(model_config.model,
                              model_config.revision,
                              fall_back_to_pt=True,
                              allow_patterns_overrides=None)

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        """Load all weights into `model` and log how long it took.

        Args:
            model: The module to populate through its `load_weights` method.
            model_config: Supplies the weight location and the quantization
                setting that decides whether the strict check runs.

        Raises:
            ValueError: If the model is not quantized, reports which weights
                it loaded, and some of its named parameters are missing from
                that report.
        """
        weights_to_load = {name for name, _ in model.named_parameters()}
        # Clear the start timestamp so that a second load on the same loader
        # is timed from its own start rather than from the first load's.
        self.counter_before_loading_weights = 0.0
        loaded_weights = model.load_weights(
            self.get_all_weights(model_config, model))
        self.counter_after_loading_weights = time.perf_counter()
        logger.info(
            "Loading weights took %.2f seconds",
            self.counter_after_loading_weights -
            self.counter_before_loading_weights)
        # We only enable strict check for non-quantized models
        # that have loaded weights tracking currently.
        if model_config.quantization is None and loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError("Following weights were not initialized from "
                                 f"checkpoint: {weights_not_loaded}")