283 lines
11 KiB
Python
283 lines
11 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import dataclasses
|
|
import glob
|
|
import os
|
|
import time
|
|
from collections.abc import Generator, Iterable
|
|
from typing import Optional, cast
|
|
|
|
import huggingface_hub
|
|
import torch
|
|
from torch import nn
|
|
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
|
|
|
from vllm import envs
|
|
from vllm.config import LoadConfig, LoadFormat, ModelConfig
|
|
from vllm.logger import init_logger
|
|
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
|
from vllm.model_executor.model_loader.weight_utils import (
|
|
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
|
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
|
|
filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator,
|
|
pt_weights_iterator, safetensors_weights_iterator)
|
|
from vllm.platforms import current_platform
|
|
|
|
# Module-level logger, created through vLLM's `init_logger` for this module.
logger = init_logger(__name__)
|
|
|
|
|
|
class DefaultModelLoader(BaseModelLoader):
    """Model loader that can load different file types from disk."""

    @dataclasses.dataclass
    class Source:
        """A source for weights."""

        model_or_path: str
        """The model ID or path."""

        revision: Optional[str]
        """The optional model revision."""

        prefix: str = ""
        """A prefix to prepend to all weights."""

        fall_back_to_pt: bool = True
        """Whether .pt weights can be used."""

        allow_patterns_overrides: Optional[list[str]] = None
        """If defined, weights will load exclusively using these patterns."""

    # `time.perf_counter()` readings taken around weight loading.
    # 0.0 in the "before" counter means "not stamped yet".
    counter_before_loading_weights: float = 0.0
    counter_after_loading_weights: float = 0.0

    def __init__(self, load_config: LoadConfig):
        """Create the loader.

        Raises:
            ValueError: If `load_config.model_loader_extra_config` is set,
                since this loader does not accept any extra config.
        """
        super().__init__(load_config)
        if load_config.model_loader_extra_config:
            raise ValueError("Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope. When `model` already exists on disk it is
        returned unchanged without downloading."""
        if not envs.VLLM_USE_MODELSCOPE:
            return None

        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        if os.path.exists(model):
            return model

        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model, self.load_config.download_dir):
            return snapshot_download(
                model_id=model,
                cache_dir=self.load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                revision=revision,
                ignore_file_pattern=self.load_config.ignore_patterns,
            )

    def _prepare_weights(
        self,
        model_name_or_path: str,
        revision: Optional[str],
        fall_back_to_pt: bool,
        allow_patterns_overrides: Optional[list[str]],
    ) -> tuple[str, list[str], bool]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Args:
            model_name_or_path: Local directory or hub model ID.
            revision: Optional model revision passed to the downloaders.
            fall_back_to_pt: Whether `*.pt` is appended to the glob patterns
                derived from the load format.
            allow_patterns_overrides: If not None, replaces the glob patterns
                derived from the load format entirely.

        Returns:
            A tuple `(folder, weight_files, use_safetensors)` where
            `use_safetensors` tells whether `weight_files` are safetensors
            files.

        Raises:
            ValueError: If the configured load format is not handled here.
            RuntimeError: If no weight file is left after filtering.
        """
        model_name_or_path = (self._maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path)

        is_local = os.path.isdir(model_name_or_path)
        load_format = self.load_config.load_format
        use_safetensors = False
        index_file = SAFE_WEIGHTS_INDEX_NAME
        if load_format == LoadFormat.AUTO:
            allow_patterns = ["*.safetensors", "*.bin"]
        elif load_format in (LoadFormat.SAFETENSORS,
                             LoadFormat.FASTSAFETENSORS):
            use_safetensors = True
            allow_patterns = ["*.safetensors"]
        elif load_format == LoadFormat.MISTRAL:
            use_safetensors = True
            allow_patterns = ["consolidated*.safetensors"]
            index_file = "consolidated.safetensors.index.json"
        elif load_format == LoadFormat.PT:
            allow_patterns = ["*.pt"]
        elif load_format == LoadFormat.NPCACHE:
            allow_patterns = ["*.bin"]
        else:
            raise ValueError(f"Unknown load_format: {load_format}")

        # Some quantized models use .pt files for storing the weights.
        if fall_back_to_pt:
            allow_patterns += ["*.pt"]

        if allow_patterns_overrides is not None:
            allow_patterns = allow_patterns_overrides

        if is_local:
            hf_folder = model_name_or_path
        else:
            hf_folder = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

        # Patterns are tried in priority order; the first one that matches
        # anything wins.
        hf_weights_files: list[str] = []
        for pattern in allow_patterns:
            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
            if hf_weights_files:
                # Decide the reader from the files that were actually found
                # rather than from the requested load format: a safetensors
                # load format may end up on the `*.pt` fallback, and an
                # override pattern may select safetensors files without being
                # literally "*.safetensors".
                use_safetensors = all(
                    path.endswith(".safetensors")
                    for path in hf_weights_files)
                break

        if use_safetensors:
            # For models like Mistral-7B-Instruct-v0.3
            # there are both sharded safetensors files and a consolidated
            # safetensors file. Using both breaks.
            # Here, we download the `model.safetensors.index.json` and filter
            # any files not found in the index.
            if not is_local:
                download_safetensors_index_file_from_hf(
                    model_name_or_path,
                    index_file,
                    self.load_config.download_dir,
                    revision,
                )
            hf_weights_files = filter_duplicate_safetensors_files(
                hf_weights_files, hf_folder, index_file)
        else:
            hf_weights_files = filter_files_not_needed_for_inference(
                hf_weights_files)

        if not hf_weights_files:
            raise RuntimeError(
                f"Cannot find any model weights with `{model_name_or_path}`")

        return hf_folder, hf_weights_files, use_safetensors

    @staticmethod
    def _mark_step_after_each(
        iterator: Iterable[tuple[str, torch.Tensor]],
        mark_step,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Yield every item of `iterator`, calling `mark_step()` after each.

        `mark_step` runs once the consumer asks for the next item, i.e. after
        it has finished handling the item that was just yielded."""
        for weights in iterator:
            yield weights
            mark_step()

    def _get_weights_iterator(
            self, source: "Source"
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format.

        The weight files of `source` are prepared (and downloaded if needed)
        first; the yielded names carry `source.prefix`."""
        hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
            source.model_or_path, source.revision, source.fall_back_to_pt,
            source.allow_patterns_overrides)
        load_config = self.load_config

        if load_config.load_format == LoadFormat.NPCACHE:
            # Currently np_cache only support *.bin checkpoints
            assert use_safetensors is False
            weights_iterator = np_cache_weights_iterator(
                source.model_or_path,
                load_config.download_dir,
                hf_folder,
                hf_weights_files,
                load_config.use_tqdm_on_load,
            )
        elif use_safetensors:
            if load_config.load_format == LoadFormat.FASTSAFETENSORS:
                weights_iterator = fastsafetensors_weights_iterator(
                    hf_weights_files,
                    load_config.use_tqdm_on_load,
                )
            else:
                weights_iterator = safetensors_weights_iterator(
                    hf_weights_files,
                    load_config.use_tqdm_on_load,
                )
        else:
            weights_iterator = pt_weights_iterator(
                hf_weights_files,
                load_config.use_tqdm_on_load,
                load_config.pt_load_map_location,
            )

        # The accelerator modules are imported lazily so that they are only
        # required on the platform that uses them.
        mark_step = None
        if current_platform.is_tpu():
            # In PyTorch XLA, we should call `xm.mark_step` frequently so that
            # not too many ops are accumulated in the XLA program.
            import torch_xla.core.xla_model as xm
            mark_step = xm.mark_step
        elif current_platform.is_hpu():
            import habana_frameworks.torch.core as htcore
            mark_step = htcore.mark_step

        if mark_step is not None:
            weights_iterator = self._mark_step_after_each(
                weights_iterator, mark_step)

        # Stamp the start time only once so that, with several sources, the
        # measurement starts at the first one. This happens after
        # `_prepare_weights`, so preparing/downloading the first source is
        # not included.
        if self.counter_before_loading_weights == 0.0:
            self.counter_before_loading_weights = time.perf_counter()
        # Apply the prefix.
        return ((source.prefix + name, tensor)
                for (name, tensor) in weights_iterator)

    def get_all_weights(
        self,
        model_config: ModelConfig,
        model: nn.Module,
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Yield `(name, tensor)` pairs from the primary and secondary sources.

        The primary source is built from `model_config.model` and
        `model_config.revision`. The optional model attributes
        `fall_back_to_pt_during_load`, `allow_patterns_overrides` and
        `secondary_weights` customize what is loaded."""
        primary_weights = DefaultModelLoader.Source(
            model_config.model,
            model_config.revision,
            prefix="",
            fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
                                    True),
            allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
                                             None),
        )
        yield from self._get_weights_iterator(primary_weights)

        secondary_weights = cast(
            Iterable[DefaultModelLoader.Source],
            getattr(model, "secondary_weights", ()),
        )
        for source in secondary_weights:
            yield from self._get_weights_iterator(source)

    def download_model(self, model_config: ModelConfig) -> None:
        """Make sure the weight files of `model_config.model` are available.

        Only prepares the files (with the `.pt` fallback enabled and no
        pattern overrides); nothing is loaded into a model."""
        self._prepare_weights(model_config.model,
                              model_config.revision,
                              fall_back_to_pt=True,
                              allow_patterns_overrides=None)

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        """Load the checkpoint weights into `model` and log how long it took.

        Raises:
            ValueError: For non-quantized models whose `load_weights` reports
                the loaded names, if some parameter of `model` was not
                loaded from the checkpoint.
        """
        weights_to_load = {name for name, _ in model.named_parameters()}
        # Reset the start stamp so that a loader used for several loads
        # measures each load on its own instead of from the very first one.
        self.counter_before_loading_weights = 0.0
        loaded_weights = model.load_weights(
            self.get_all_weights(model_config, model))
        self.counter_after_loading_weights = time.perf_counter()
        elapsed = (self.counter_after_loading_weights -
                   self.counter_before_loading_weights)
        logger.info("Loading weights took %.2f seconds", elapsed)
        # We only enable strict check for non-quantized models
        # that have loaded weights tracking currently.
        if model_config.quantization is None and loaded_weights is not None:
            weights_not_loaded = weights_to_load - loaded_weights
            if weights_not_loaded:
                raise ValueError("Following weights were not initialized from "
                                 f"checkpoint: {weights_not_loaded}")
|