Sync from v0.13
This commit is contained in:
321
vllm/model_executor/model_loader/default_loader.py
Normal file
321
vllm/model_executor/model_loader/default_loader.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import glob
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import cast
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.load import LoadConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization.torchao import torchao_version_at_least
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf,
|
||||
download_weights_from_hf,
|
||||
fastsafetensors_weights_iterator,
|
||||
filter_duplicate_safetensors_files,
|
||||
filter_files_not_needed_for_inference,
|
||||
get_quant_config,
|
||||
maybe_download_from_modelscope,
|
||||
multi_thread_pt_weights_iterator,
|
||||
multi_thread_safetensors_weights_iterator,
|
||||
np_cache_weights_iterator,
|
||||
pt_weights_iterator,
|
||||
safetensors_weights_iterator,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.repo_utils import list_filtered_repo_files
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DefaultModelLoader(BaseModelLoader):
|
||||
"""Model loader that can load different file types from disk."""
|
||||
|
||||
# default number of thread when enable multithread weight loading
|
||||
DEFAULT_NUM_THREADS = 8
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Source:
|
||||
"""A source for weights."""
|
||||
|
||||
model_or_path: str
|
||||
"""The model ID or path."""
|
||||
|
||||
revision: str | None
|
||||
"""The optional model revision."""
|
||||
|
||||
prefix: str = ""
|
||||
"""A prefix to prepend to all weights."""
|
||||
|
||||
fall_back_to_pt: bool = True
|
||||
"""Whether .pt weights can be used."""
|
||||
|
||||
allow_patterns_overrides: list[str] | None = None
|
||||
"""If defined, weights will load exclusively using these patterns."""
|
||||
|
||||
counter_before_loading_weights: float = 0.0
|
||||
counter_after_loading_weights: float = 0.0
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
|
||||
extra_config = load_config.model_loader_extra_config
|
||||
allowed_keys = {"enable_multithread_load", "num_threads"}
|
||||
unexpected_keys = set(extra_config.keys()) - allowed_keys
|
||||
|
||||
if unexpected_keys:
|
||||
raise ValueError(
|
||||
f"Unexpected extra config keys for load format "
|
||||
f"{load_config.load_format}: "
|
||||
f"{unexpected_keys}"
|
||||
)
|
||||
|
||||
def _prepare_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
revision: str | None,
|
||||
fall_back_to_pt: bool,
|
||||
allow_patterns_overrides: list[str] | None,
|
||||
) -> tuple[str, list[str], bool]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
model_name_or_path = (
|
||||
maybe_download_from_modelscope(model_name_or_path, revision)
|
||||
or model_name_or_path
|
||||
)
|
||||
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
load_format = self.load_config.load_format
|
||||
use_safetensors = False
|
||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
# First check for 'auto' format that mistral files format are present.
|
||||
# This is to load mistral models with official format by default.
|
||||
if load_format == "auto":
|
||||
load_format = (
|
||||
"mistral"
|
||||
if len(
|
||||
list_filtered_repo_files(
|
||||
model_name_or_path=model_name_or_path,
|
||||
allow_patterns=["consolidated*.safetensors"],
|
||||
revision=revision,
|
||||
)
|
||||
)
|
||||
> 0
|
||||
else "hf"
|
||||
)
|
||||
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
if load_format == "hf":
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif load_format == "safetensors" or load_format == "fastsafetensors":
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == "mistral":
|
||||
use_safetensors = True
|
||||
allow_patterns = ["consolidated*.safetensors"]
|
||||
index_file = "consolidated.safetensors.index.json"
|
||||
elif load_format == "pt":
|
||||
allow_patterns = ["*.pt"]
|
||||
elif load_format == "npcache":
|
||||
allow_patterns = ["*.bin"]
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
|
||||
if fall_back_to_pt:
|
||||
allow_patterns += ["*.pt"]
|
||||
|
||||
if allow_patterns_overrides is not None:
|
||||
allow_patterns = allow_patterns_overrides
|
||||
|
||||
if not is_local:
|
||||
hf_folder = download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
allow_patterns,
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
hf_weights_files: list[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
break
|
||||
|
||||
if use_safetensors:
|
||||
# For models like Mistral-7B-Instruct-v0.3
|
||||
# there are both sharded safetensors files and a consolidated
|
||||
# safetensors file. Using both breaks.
|
||||
# Here, we download the `model.safetensors.index.json` and filter
|
||||
# any files not found in the index.
|
||||
if not is_local:
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path,
|
||||
index_file,
|
||||
self.load_config.download_dir,
|
||||
revision,
|
||||
)
|
||||
hf_weights_files = filter_duplicate_safetensors_files(
|
||||
hf_weights_files, hf_folder, index_file
|
||||
)
|
||||
else:
|
||||
hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{model_name_or_path}`"
|
||||
)
|
||||
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, source: "Source"
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
extra_config = self.load_config.model_loader_extra_config
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
source.model_or_path,
|
||||
source.revision,
|
||||
source.fall_back_to_pt,
|
||||
source.allow_patterns_overrides,
|
||||
)
|
||||
if self.load_config.load_format == "npcache":
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
weights_iterator = np_cache_weights_iterator(
|
||||
source.model_or_path,
|
||||
self.load_config.download_dir,
|
||||
hf_folder,
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
elif use_safetensors:
|
||||
if self.load_config.load_format == "fastsafetensors":
|
||||
weights_iterator = fastsafetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
else:
|
||||
if extra_config.get("enable_multithread_load"):
|
||||
weights_iterator = multi_thread_safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
max_workers=extra_config.get(
|
||||
"num_threads", self.DEFAULT_NUM_THREADS
|
||||
),
|
||||
)
|
||||
else:
|
||||
weights_iterator = safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.safetensors_load_strategy,
|
||||
)
|
||||
else:
|
||||
if extra_config.get("enable_multithread_load"):
|
||||
weights_iterator = multi_thread_pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.pt_load_map_location,
|
||||
max_workers=extra_config.get(
|
||||
"num_threads", self.DEFAULT_NUM_THREADS
|
||||
),
|
||||
)
|
||||
else:
|
||||
weights_iterator = pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.pt_load_map_location,
|
||||
)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
from vllm.platforms.tpu import USE_TPU_INFERENCE
|
||||
|
||||
if not USE_TPU_INFERENCE:
|
||||
# In PyTorch XLA, we should call `torch_xla.sync`
|
||||
# frequently so that not too many ops are accumulated
|
||||
# in the XLA program.
|
||||
import torch_xla
|
||||
|
||||
def _xla_weights_iterator(iterator: Generator):
|
||||
for weights in iterator:
|
||||
yield weights
|
||||
torch_xla.sync(wait=False)
|
||||
|
||||
weights_iterator = _xla_weights_iterator(weights_iterator)
|
||||
|
||||
if self.counter_before_loading_weights == 0.0:
|
||||
self.counter_before_loading_weights = time.perf_counter()
|
||||
# Apply the prefix.
|
||||
return ((source.prefix + name, tensor) for (name, tensor) in weights_iterator)
|
||||
|
||||
def get_all_weights(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
model: nn.Module,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
primary_weights = DefaultModelLoader.Source(
|
||||
model_config.model,
|
||||
model_config.revision,
|
||||
prefix="",
|
||||
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load", True),
|
||||
allow_patterns_overrides=getattr(model, "allow_patterns_overrides", None),
|
||||
)
|
||||
yield from self._get_weights_iterator(primary_weights)
|
||||
|
||||
secondary_weights = cast(
|
||||
Iterable[DefaultModelLoader.Source],
|
||||
getattr(model, "secondary_weights", ()),
|
||||
)
|
||||
for source in secondary_weights:
|
||||
yield from self._get_weights_iterator(source)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(
|
||||
model_config.model,
|
||||
model_config.revision,
|
||||
fall_back_to_pt=True,
|
||||
allow_patterns_overrides=None,
|
||||
)
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
if model_config.quantization == "torchao":
|
||||
quant_config = get_quant_config(model_config, self.load_config)
|
||||
if (
|
||||
hasattr(quant_config, "is_checkpoint_torchao_serialized")
|
||||
and quant_config.is_checkpoint_torchao_serialized
|
||||
and torchao_version_at_least("0.15.0")
|
||||
):
|
||||
self.load_config.safetensors_load_strategy = "torchao"
|
||||
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
loaded_weights = model.load_weights(self.get_all_weights(model_config, model))
|
||||
|
||||
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info_once(
|
||||
"Loading weights took %.2f seconds",
|
||||
self.counter_after_loading_weights - self.counter_before_loading_weights,
|
||||
scope="local",
|
||||
)
|
||||
# We only enable strict check for non-quantized models
|
||||
# that have loaded weights tracking currently.
|
||||
if model_config.quantization is None and loaded_weights is not None:
|
||||
weights_not_loaded = weights_to_load - loaded_weights
|
||||
if weights_not_loaded:
|
||||
raise ValueError(
|
||||
"Following weights were not initialized from "
|
||||
f"checkpoint: {weights_not_loaded}"
|
||||
)
|
||||
Reference in New Issue
Block a user