[gpt-oss] Add gpt-oss bf16 support
This commit is contained in:
76
vllm/model_executor/model_loader/__init__.py
Normal file
76
vllm/model_executor/model_loader/__init__.py
Normal file
@@ -0,0 +1,76 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.bitsandbytes_loader import (
|
||||
BitsAndBytesModelLoader)
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
|
||||
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
|
||||
from vllm.model_executor.model_loader.runai_streamer_loader import (
|
||||
RunaiModelStreamerLoader)
|
||||
from vllm.model_executor.model_loader.sharded_state_loader import (
|
||||
ShardedStateLoader)
|
||||
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
get_architecture_class_name, get_model_architecture)
|
||||
|
||||
|
||||
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
|
||||
"""Get a model loader based on the load format."""
|
||||
if isinstance(load_config.load_format, type):
|
||||
return load_config.load_format(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.DUMMY:
|
||||
return DummyModelLoader(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.TENSORIZER:
|
||||
return TensorizerLoader(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.SHARDED_STATE:
|
||||
return ShardedStateLoader(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.BITSANDBYTES:
|
||||
return BitsAndBytesModelLoader(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.GGUF:
|
||||
return GGUFModelLoader(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.RUNAI_STREAMER:
|
||||
return RunaiModelStreamerLoader(load_config)
|
||||
|
||||
if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
|
||||
return ShardedStateLoader(load_config, runai_model_streamer=True)
|
||||
|
||||
return DefaultModelLoader(load_config)
|
||||
|
||||
|
||||
def get_model(*,
|
||||
vllm_config: VllmConfig,
|
||||
model_config: Optional[ModelConfig] = None) -> nn.Module:
|
||||
loader = get_model_loader(vllm_config.load_config)
|
||||
if model_config is None:
|
||||
model_config = vllm_config.model_config
|
||||
return loader.load_model(vllm_config=vllm_config,
|
||||
model_config=model_config)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_model",
|
||||
"get_model_loader",
|
||||
"get_architecture_class_name",
|
||||
"get_model_architecture",
|
||||
"BaseModelLoader",
|
||||
"BitsAndBytesModelLoader",
|
||||
"GGUFModelLoader",
|
||||
"DefaultModelLoader",
|
||||
"DummyModelLoader",
|
||||
"RunaiModelStreamerLoader",
|
||||
"ShardedStateLoader",
|
||||
"TensorizerLoader",
|
||||
]
|
||||
43
vllm/model_executor/model_loader/base_loader.py
Normal file
43
vllm/model_executor/model_loader/base_loader.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model, process_weights_after_loading, set_default_torch_dtype)
|
||||
|
||||
|
||||
class BaseModelLoader(ABC):
|
||||
"""Base class for model loaders."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
self.load_config = load_config
|
||||
|
||||
@abstractmethod
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
"""Download a model so that it can be immediately loaded."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
"""Load weights into a model. This standalone API allows
|
||||
inplace weights loading for an already-initialized model"""
|
||||
raise NotImplementedError
|
||||
|
||||
def load_model(self, vllm_config: VllmConfig,
|
||||
model_config: ModelConfig) -> nn.Module:
|
||||
"""Load a model with the given configurations."""
|
||||
device_config = vllm_config.device_config
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(vllm_config=vllm_config,
|
||||
model_config=model_config)
|
||||
# Quantization does not happen in `load_weights` but after it
|
||||
self.load_weights(model, model_config)
|
||||
process_weights_after_loading(model, model_config, target_device)
|
||||
return model.eval()
|
||||
570
vllm/model_executor/model_loader/bitsandbytes_loader.py
Normal file
570
vllm/model_executor/model_loader/bitsandbytes_loader.py
Normal file
@@ -0,0 +1,570 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import fnmatch
|
||||
import glob
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfApi
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.distributed import (get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
# yapf: enable
|
||||
from vllm.logger import init_logger
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.model_executor.layers.linear import (LinearBase,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (ParamMapping,
|
||||
set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.models import is_pooling_model
|
||||
from vllm.model_executor.utils import (get_packed_modules_mapping,
|
||||
set_weight_attrs)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
"""Model loader to load model weights with BitAndBytes quantization."""
|
||||
|
||||
possible_config_file_names = ["adapter_config.json"]
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
|
||||
# Save the module names without sharding.
|
||||
self.unsharded_weights_modules: list[str] = []
|
||||
# Save the module names that are sharded by column.
|
||||
self.column_sharded_weights_modules: list[str] = []
|
||||
# Store all module names (from transformers) that support
|
||||
# BNB quantization.
|
||||
self.target_modules: list[str] = []
|
||||
# mapping weight names from transformers to vllm.
|
||||
self.weight_mapper: Callable = lambda name: name
|
||||
|
||||
def _get_weight_files(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
allowed_patterns: list[str],
|
||||
revision: Optional[str] = None,
|
||||
) -> tuple[str, list[str], str]:
|
||||
"""Retrieve weight files. Download the files if necessary.
|
||||
|
||||
Return the weight files and the file pattern."""
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
|
||||
if is_local:
|
||||
for pattern in allowed_patterns:
|
||||
weight_files = glob.glob(
|
||||
os.path.join(model_name_or_path, pattern))
|
||||
if weight_files:
|
||||
return model_name_or_path, weight_files, pattern
|
||||
else:
|
||||
hf_api = HfApi()
|
||||
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
|
||||
for pattern in allowed_patterns:
|
||||
matching_files = fnmatch.filter(repo_files, pattern)
|
||||
if matching_files:
|
||||
hf_folder = download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
[pattern],
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
)
|
||||
return hf_folder, glob.glob(
|
||||
os.path.join(hf_folder, pattern)), pattern
|
||||
|
||||
raise RuntimeError(
|
||||
f"No model weights found in: `{model_name_or_path}`")
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str]) -> tuple[list[str], bool]:
|
||||
"""Prepare weight files for the model."""
|
||||
|
||||
allowed_patterns = ["*.safetensors", "*.bin", "*.pt"]
|
||||
|
||||
hf_folder, hf_weights_files, matched_pattern = self._get_weight_files(
|
||||
model_name_or_path, allowed_patterns, revision)
|
||||
|
||||
use_safetensors = matched_pattern == "*.safetensors"
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
if use_safetensors:
|
||||
# For models like Mistral-7B-Instruct-v0.3
|
||||
# there are both sharded safetensors files and a consolidated
|
||||
# safetensors file. Using both breaks.
|
||||
# Here, we download the `model.safetensors.index.json` and filter
|
||||
# any files not found in the index.
|
||||
if not is_local:
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path,
|
||||
index_file,
|
||||
self.load_config.download_dir,
|
||||
revision,
|
||||
)
|
||||
hf_weights_files = filter_duplicate_safetensors_files(
|
||||
hf_weights_files, hf_folder, index_file)
|
||||
else:
|
||||
hf_weights_files = filter_files_not_needed_for_inference(
|
||||
hf_weights_files)
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||
|
||||
return hf_weights_files, use_safetensors
|
||||
|
||||
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
|
||||
def _maybe_pool_model(module_name:str):
|
||||
# For pool model, we need to add the prefix `model.`
|
||||
# for the weight name if possible.
|
||||
if self.is_pool_model and self.target_modules[0]. \
|
||||
startswith("model.") and not module_name.startswith(
|
||||
"model."):
|
||||
return "model."+module_name
|
||||
|
||||
return module_name
|
||||
|
||||
if use_safetensors:
|
||||
iterator = safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
else:
|
||||
iterator = pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.pt_load_map_location,
|
||||
)
|
||||
for org_name, param in iterator:
|
||||
# mapping weight names from transformers to vllm while preserving
|
||||
# original names.
|
||||
mapped_name = self.weight_mapper(org_name)
|
||||
mapped_name=_maybe_pool_model(mapped_name)
|
||||
|
||||
|
||||
yield org_name, mapped_name, param
|
||||
|
||||
def _get_quantized_weights_iterator(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
revision: Optional[str],
|
||||
pre_quant: bool,
|
||||
load_8bit: bool,
|
||||
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
|
||||
Any]]:
|
||||
"""Get an iterator to the model weights with bitsandbytes quantization,
|
||||
as well as the quantization state dictionary."""
|
||||
|
||||
# only load the bitsandbytes module when needed
|
||||
try:
|
||||
import bitsandbytes
|
||||
|
||||
if bitsandbytes.__version__ < "0.45.3":
|
||||
raise ImportError("bitsandbytes version is wrong. Please "
|
||||
"install bitsandbytes>=0.45.3.")
|
||||
except ImportError as err:
|
||||
raise ImportError("Please install bitsandbytes>=0.45.3 via "
|
||||
"`pip install bitsandbytes>=0.45.3` to use "
|
||||
"bitsandbytes quantizer.") from err
|
||||
|
||||
hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
model_name_or_path, revision)
|
||||
|
||||
quant_state_dict: dict[str, Any] = {}
|
||||
|
||||
if pre_quant:
|
||||
if load_8bit:
|
||||
return self._quantized_8bit_generator(
|
||||
hf_weights_files, use_safetensors,
|
||||
quant_state_dict), quant_state_dict
|
||||
else:
|
||||
return self._quantized_4bit_generator(
|
||||
hf_weights_files, use_safetensors,
|
||||
quant_state_dict), quant_state_dict
|
||||
|
||||
return self._unquantized_generator(hf_weights_files, use_safetensors,
|
||||
quant_state_dict), quant_state_dict
|
||||
|
||||
def _is_8bit_weight_name(self, weight_name: str):
|
||||
quantized_suffix = {".scb", ".weight_format"}
|
||||
return any(weight_name.lower().endswith(suffix)
|
||||
for suffix in quantized_suffix)
|
||||
|
||||
def _is_4bit_weight_name(self, weight_name: str):
|
||||
quantized_suffix = {
|
||||
"absmax",
|
||||
"quant_map",
|
||||
"nested_absmax",
|
||||
"nested_quant_map",
|
||||
"bitsandbytes",
|
||||
}
|
||||
suffix = weight_name.split(".")[-1]
|
||||
return any(q_suffix in suffix for q_suffix in quantized_suffix)
|
||||
|
||||
def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
|
||||
quant_state_dict) -> Generator:
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if not mapped_weight_name.lower().endswith(".scb"):
|
||||
continue
|
||||
|
||||
weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
|
||||
quant_state_dict[weight_key] = weight_tensor
|
||||
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if self._is_8bit_weight_name(mapped_weight_name):
|
||||
continue
|
||||
|
||||
if mapped_weight_name in quant_state_dict:
|
||||
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
|
||||
yield org_weight_name, weight_tensor
|
||||
else:
|
||||
yield org_weight_name, weight_tensor
|
||||
|
||||
def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
|
||||
quant_state_dict) -> Generator:
|
||||
from bitsandbytes.functional import QuantState
|
||||
|
||||
# First iterate over all quant state weights
|
||||
weight_iterator = self._hf_weight_iter(hf_weights_files,
|
||||
use_safetensors)
|
||||
temp_state_dict = {}
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in weight_iterator:
|
||||
if not self._is_4bit_weight_name(mapped_weight_name):
|
||||
continue
|
||||
# bitsandbytes library requires
|
||||
# weight.quant_state.bitsandbytes__* in CPU
|
||||
if "quant_state.bitsandbytes" in mapped_weight_name:
|
||||
temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
|
||||
else:
|
||||
temp_state_dict[mapped_weight_name] = weight_tensor
|
||||
|
||||
# Closure to parse quant_state for each prequant weight
|
||||
def _parse_quant_state(param_name: str,
|
||||
temp_state_dict: dict) -> QuantState:
|
||||
quant_state = {}
|
||||
for k in temp_state_dict:
|
||||
if param_name + "." in k:
|
||||
quant_state[k] = temp_state_dict[k]
|
||||
|
||||
return QuantState.from_dict(quant_state,
|
||||
device=current_platform.device_type)
|
||||
|
||||
# Second iterate over all prequant and normal weights
|
||||
# pre quantized weights would have a quant_state
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if self._is_4bit_weight_name(mapped_weight_name):
|
||||
continue
|
||||
|
||||
if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
|
||||
in temp_state_dict) or (
|
||||
f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
|
||||
in temp_state_dict):
|
||||
quant_state = _parse_quant_state(mapped_weight_name,
|
||||
temp_state_dict)
|
||||
quant_state_dict[mapped_weight_name] = quant_state
|
||||
yield org_weight_name, weight_tensor
|
||||
else:
|
||||
yield org_weight_name, weight_tensor
|
||||
|
||||
def _unquantized_generator(self, hf_weights_files, use_safetensors,
|
||||
quant_state_dict) -> Generator:
|
||||
from bitsandbytes.functional import quantize_4bit
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
for (
|
||||
org_weight_name,
|
||||
mapped_weight_name,
|
||||
weight_tensor,
|
||||
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
|
||||
if any(target_module in mapped_weight_name
|
||||
for target_module in self.target_modules
|
||||
) and mapped_weight_name.endswith(".weight"):
|
||||
# Without sharding
|
||||
if any(
|
||||
mapped_weight_name.startswith(module)
|
||||
for module in self.unsharded_weights_modules):
|
||||
weight_sub_tensor = weight_tensor
|
||||
# Shard by column
|
||||
elif any(
|
||||
mapped_weight_name.startswith(module)
|
||||
for module in self.column_sharded_weights_modules):
|
||||
total_size = weight_tensor.size(-1)
|
||||
start_index = total_size // tp_size * tp_rank
|
||||
end_index = total_size // tp_size * (tp_rank + 1)
|
||||
weight_sub_tensor = weight_tensor[...,
|
||||
start_index:end_index]
|
||||
# Weights have fused on disk. In this case, we assume that the
|
||||
# weight and module use same name.
|
||||
elif any(
|
||||
mapped_weight_name.startswith(module)
|
||||
for module in self.maybe_fused_weights_modules):
|
||||
# special case for fused weights
|
||||
# get the size of each shard weight tensor
|
||||
total_shard_sizes = next(
|
||||
(sizes for module, sizes in
|
||||
self.maybe_fused_weights_modules.items()
|
||||
if mapped_weight_name.startswith(module)))
|
||||
total_size = weight_tensor.size(0)
|
||||
assert total_size == sum(total_shard_sizes)
|
||||
# get the start/end index of each shard weight tensor
|
||||
total_start_index = list(
|
||||
itertools.accumulate([0] + total_shard_sizes))[:-1]
|
||||
shard_weights_index = [(
|
||||
idx + size // tp_size * tp_rank,
|
||||
idx + size // tp_size * (tp_rank + 1),
|
||||
) for idx, size in zip(total_start_index,
|
||||
total_shard_sizes)]
|
||||
# slice and reorder the weight tensor
|
||||
weight_tensor = [
|
||||
weight_tensor[start_index:end_index, ...]
|
||||
for start_index, end_index in shard_weights_index
|
||||
]
|
||||
weight_sub_tensor = torch.cat(weight_tensor, dim=0)
|
||||
# Shard by row
|
||||
else:
|
||||
total_size = weight_tensor.size(0)
|
||||
start_index = total_size // tp_size * tp_rank
|
||||
end_index = total_size // tp_size * (tp_rank + 1)
|
||||
weight_sub_tensor = weight_tensor[start_index:end_index,
|
||||
...]
|
||||
|
||||
# bitsandbytes requires data in GPU
|
||||
if weight_sub_tensor.is_cuda:
|
||||
loaded_weight = weight_sub_tensor
|
||||
else:
|
||||
loaded_weight = weight_sub_tensor.cuda()
|
||||
|
||||
# remove the following after the issue is fixed:
|
||||
# https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
|
||||
if loaded_weight.is_contiguous() is False:
|
||||
loaded_weight = loaded_weight.contiguous()
|
||||
|
||||
with set_default_torch_dtype(torch.float32):
|
||||
processed_weight, quant_state = quantize_4bit(
|
||||
loaded_weight,
|
||||
compress_statistics=True,
|
||||
quant_type="nf4",
|
||||
)
|
||||
|
||||
quant_state_dict[mapped_weight_name] = quant_state
|
||||
else:
|
||||
processed_weight = weight_tensor
|
||||
yield org_weight_name, processed_weight
|
||||
|
||||
def _get_bnb_target_modules(self, model: nn.Module) -> None:
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if (isinstance(module, LinearBase) and
|
||||
hasattr(module.quant_method, "quant_config")):
|
||||
if modules_info := self.modules_mapping.get_sub_modules(name):
|
||||
# Map vllm's names to transformers's names.
|
||||
rep_name, sub_modules = modules_info
|
||||
for sub_name in sub_modules:
|
||||
self.target_modules.append(
|
||||
name.replace(rep_name, sub_name))
|
||||
# Add original module name even if the module has stacked map,
|
||||
# in case model has a mixture of disk-merged and disk-splitted
|
||||
# weights with same last name.
|
||||
self.target_modules.append(name)
|
||||
|
||||
assert (self.target_modules
|
||||
), "vllm currently does not support BNB quantization for"
|
||||
f" {type(model).__name__}"
|
||||
|
||||
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
|
||||
if not hasattr(model, "load_weights"):
|
||||
raise AttributeError(
|
||||
"The required method 'load_weights' is not defined in class"
|
||||
f" {type(model).__name__}.")
|
||||
|
||||
if not hasattr(model, "packed_modules_mapping"):
|
||||
raise AttributeError(
|
||||
f"Model {type(model).__name__} does not support BitsAndBytes "
|
||||
"quantization yet. No 'packed_modules_mapping' found.")
|
||||
self.is_pool_model=is_pooling_model(model)
|
||||
|
||||
self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
|
||||
|
||||
# For some models like Molmo, we need to use hf_to_vllm_mapper
|
||||
# to ensure correct loading of weights.
|
||||
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
|
||||
self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
|
||||
|
||||
# Modules whose weights might have fused on disk
|
||||
# we need their output_sizes to make shard in flight correctly with TP
|
||||
self.maybe_fused_weights_modules: dict[str, list[int]] = {}
|
||||
self._get_bnb_target_modules(model)
|
||||
for name, module in model.named_modules():
|
||||
# Some modules like `ReplicatedLinear` should not have their weights
|
||||
# sharded. The reason for implementing it this way is to avoid new
|
||||
# static variable in the model implementation.
|
||||
if isinstance(module, (ReplicatedLinear, )):
|
||||
self.unsharded_weights_modules.append(name)
|
||||
# `QKVParallelLinear` and `MergedColumnParallelLinear` might have
|
||||
# fused weights on disk. We need to use the output sizes of these
|
||||
# modules to shard the weights correctly.
|
||||
elif isinstance(module,
|
||||
(QKVParallelLinear, MergedColumnParallelLinear)):
|
||||
self.maybe_fused_weights_modules[name] = module.output_sizes
|
||||
# In TP, these weights are partitioned along the column
|
||||
# dimension (dim=-1)
|
||||
elif isinstance(module, (RowParallelLinear, )):
|
||||
self.column_sharded_weights_modules.append(name)
|
||||
|
||||
self.model_type = type(model).__name__
|
||||
|
||||
logger.info("Loading weights with BitsAndBytes quantization. "
|
||||
"May take a while ...")
|
||||
|
||||
quant_config = getattr(model_config.hf_config, "quantization_config",
|
||||
None)
|
||||
|
||||
pre_quant = False
|
||||
if quant_config is not None:
|
||||
quant_method = quant_config.get("quant_method")
|
||||
if quant_method == "bitsandbytes":
|
||||
pre_quant = True
|
||||
else:
|
||||
raise ValueError(
|
||||
f"BitsAndBytes loader does not support {quant_method} "
|
||||
"quantization")
|
||||
|
||||
# The quant_states in pre_quantized models cannot work with a split
|
||||
# weight tensor. So TP does not work with pre_quantized bnb models.
|
||||
if pre_quant and get_tensor_model_parallel_world_size() > 1:
|
||||
raise ValueError(
|
||||
"Prequant BitsAndBytes models with tensor parallelism is not "
|
||||
"supported. Please try with pipeline parallelism.")
|
||||
|
||||
load_8bit = False
|
||||
if pre_quant:
|
||||
load_8bit = quant_config.get("load_in_8bit", False)
|
||||
|
||||
qweight_iterator, quant_state_dict = (
|
||||
self._get_quantized_weights_iterator(model_config.model,
|
||||
model_config.revision,
|
||||
pre_quant, load_8bit))
|
||||
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
loaded_weights = model.load_weights(qweight_iterator)
|
||||
# Some models may have weights loading tracker unimplemented.
|
||||
if loaded_weights is not None:
|
||||
weights_not_loaded = weights_to_load - loaded_weights
|
||||
if weights_not_loaded:
|
||||
raise ValueError("Following weights were not initialized from "
|
||||
f"checkpoint: {weights_not_loaded}")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
param_dict = dict(model.named_parameters())
|
||||
stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
|
||||
# TODO: Change this lazy import to normal import
|
||||
# after the checks are updated to run on a new version
|
||||
from vllm.model_executor.models.utils import is_pp_missing_parameter
|
||||
|
||||
for quant_param_name in quant_state_dict:
|
||||
if is_pp_missing_parameter(quant_param_name, model):
|
||||
continue
|
||||
|
||||
non_stacked_param_name = quant_param_name
|
||||
|
||||
shard_index = 0
|
||||
for shard_name, (
|
||||
weight_name,
|
||||
index,
|
||||
) in self.modules_mapping.inverse_packed_mapping.items():
|
||||
# Some models, such as MiniCPM V2.5/2.6, contain both
|
||||
# module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
|
||||
# from being incorrectly identified as being present in
|
||||
# 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
|
||||
shard_pos = quant_param_name.find(shard_name)
|
||||
can_correct_rename = (shard_pos
|
||||
> 0) and (quant_param_name[shard_pos - 1]
|
||||
== ".")
|
||||
# If the quant_param_name is packed, it won't occur in the
|
||||
# param_dict before renaming.
|
||||
new_quant_param_name = quant_param_name.replace(
|
||||
shard_name, weight_name)
|
||||
need_rename = (quant_param_name not in param_dict) \
|
||||
and (new_quant_param_name in param_dict)
|
||||
if can_correct_rename and need_rename:
|
||||
shard_index = index
|
||||
quant_param_name = new_quant_param_name
|
||||
break
|
||||
|
||||
# Models like Clip/Siglip may skip some layers in initialization,
|
||||
# causing unused quant_param_name in state_dict.
|
||||
if quant_param_name not in param_dict:
|
||||
continue
|
||||
|
||||
if quant_param_name not in stacked_quant_state_dict:
|
||||
stacked_quant_state_dict[quant_param_name] = {}
|
||||
|
||||
stacked_quant_state_dict[quant_param_name][shard_index] = (
|
||||
quant_state_dict[non_stacked_param_name])
|
||||
|
||||
# save quant_states and offsets as the attributes of the parameters
|
||||
for param_name, param in param_dict.items():
|
||||
if param_name in stacked_quant_state_dict:
|
||||
quant_states = stacked_quant_state_dict[param_name]
|
||||
set_weight_attrs(param, {"bnb_quant_state": quant_states})
|
||||
|
||||
pack_ratio = getattr(param, "pack_factor", -1)
|
||||
if pack_ratio == -1:
|
||||
raise ValueError(
|
||||
f"pack_factor not set for parameter {param_name}.")
|
||||
|
||||
num_elements = [0] * len(quant_states)
|
||||
for seq, quant_state in quant_states.items():
|
||||
num_elements[seq] = (math.prod(quant_state.shape) //
|
||||
pack_ratio)
|
||||
|
||||
offsets = np.concatenate(([0], np.cumsum(num_elements)))
|
||||
# Make torch infer_schema happy
|
||||
offsets = torch.tensor(offsets).cpu()
|
||||
set_weight_attrs(param, {"bnb_shard_offsets": offsets})
|
||||
|
||||
if load_8bit:
|
||||
set_weight_attrs(
|
||||
param, {"matmul_state": [None] * len(quant_states)})
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(model_config.model, model_config.revision)
|
||||
282
vllm/model_executor/model_loader/default_loader.py
Normal file
282
vllm/model_executor/model_loader/default_loader.py
Normal file
@@ -0,0 +1,282 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import glob
|
||||
import os
|
||||
import time
|
||||
from collections.abc import Generator, Iterable
|
||||
from typing import Optional, cast
|
||||
|
||||
import huggingface_hub
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm import envs
|
||||
from vllm.config import LoadConfig, LoadFormat, ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
|
||||
filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator,
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DefaultModelLoader(BaseModelLoader):
|
||||
"""Model loader that can load different file types from disk."""
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Source:
|
||||
"""A source for weights."""
|
||||
|
||||
model_or_path: str
|
||||
"""The model ID or path."""
|
||||
|
||||
revision: Optional[str]
|
||||
"""The optional model revision."""
|
||||
|
||||
prefix: str = ""
|
||||
"""A prefix to prepend to all weights."""
|
||||
|
||||
fall_back_to_pt: bool = True
|
||||
"""Whether .pt weights can be used."""
|
||||
|
||||
allow_patterns_overrides: Optional[list[str]] = None
|
||||
"""If defined, weights will load exclusively using these patterns."""
|
||||
|
||||
counter_before_loading_weights: float = 0.0
|
||||
counter_after_loading_weights: float = 0.0
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
raise ValueError(f"Model loader extra config is not supported for "
|
||||
f"load format {load_config.load_format}")
|
||||
|
||||
def _maybe_download_from_modelscope(
|
||||
self, model: str, revision: Optional[str]) -> Optional[str]:
|
||||
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
|
||||
|
||||
Returns the path to the downloaded model, or None if the model is not
|
||||
downloaded from ModelScope."""
|
||||
if envs.VLLM_USE_MODELSCOPE:
|
||||
# download model from ModelScope hub,
|
||||
# lazy import so that modelscope is not required for normal use.
|
||||
# pylint: disable=C.
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
|
||||
if not os.path.exists(model):
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model, self.load_config.download_dir):
|
||||
model_path = snapshot_download(
|
||||
model_id=model,
|
||||
cache_dir=self.load_config.download_dir,
|
||||
local_files_only=huggingface_hub.constants.
|
||||
HF_HUB_OFFLINE,
|
||||
revision=revision,
|
||||
ignore_file_pattern=self.load_config.ignore_patterns,
|
||||
)
|
||||
else:
|
||||
model_path = model
|
||||
return model_path
|
||||
return None
|
||||
|
||||
def _prepare_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
revision: Optional[str],
|
||||
fall_back_to_pt: bool,
|
||||
allow_patterns_overrides: Optional[list[str]],
|
||||
) -> tuple[str, list[str], bool]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
model_name_or_path = (self._maybe_download_from_modelscope(
|
||||
model_name_or_path, revision) or model_name_or_path)
|
||||
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
load_format = self.load_config.load_format
|
||||
use_safetensors = False
|
||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
# Some quantized models use .pt files for storing the weights.
|
||||
if load_format == LoadFormat.AUTO:
|
||||
allow_patterns = ["*.safetensors", "*.bin"]
|
||||
elif (load_format == LoadFormat.SAFETENSORS
|
||||
or load_format == LoadFormat.FASTSAFETENSORS):
|
||||
use_safetensors = True
|
||||
allow_patterns = ["*.safetensors"]
|
||||
elif load_format == LoadFormat.MISTRAL:
|
||||
use_safetensors = True
|
||||
allow_patterns = ["consolidated*.safetensors"]
|
||||
index_file = "consolidated.safetensors.index.json"
|
||||
elif load_format == LoadFormat.PT:
|
||||
allow_patterns = ["*.pt"]
|
||||
elif load_format == LoadFormat.NPCACHE:
|
||||
allow_patterns = ["*.bin"]
|
||||
else:
|
||||
raise ValueError(f"Unknown load_format: {load_format}")
|
||||
|
||||
if fall_back_to_pt:
|
||||
allow_patterns += ["*.pt"]
|
||||
|
||||
if allow_patterns_overrides is not None:
|
||||
allow_patterns = allow_patterns_overrides
|
||||
|
||||
if not is_local:
|
||||
hf_folder = download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
allow_patterns,
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
hf_weights_files: list[str] = []
|
||||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
break
|
||||
|
||||
if use_safetensors:
|
||||
# For models like Mistral-7B-Instruct-v0.3
|
||||
# there are both sharded safetensors files and a consolidated
|
||||
# safetensors file. Using both breaks.
|
||||
# Here, we download the `model.safetensors.index.json` and filter
|
||||
# any files not found in the index.
|
||||
if not is_local:
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path,
|
||||
index_file,
|
||||
self.load_config.download_dir,
|
||||
revision,
|
||||
)
|
||||
hf_weights_files = filter_duplicate_safetensors_files(
|
||||
hf_weights_files, hf_folder, index_file)
|
||||
else:
|
||||
hf_weights_files = filter_files_not_needed_for_inference(
|
||||
hf_weights_files)
|
||||
|
||||
if len(hf_weights_files) == 0:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any model weights with `{model_name_or_path}`")
|
||||
|
||||
return hf_folder, hf_weights_files, use_safetensors
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, source: "Source"
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
|
||||
source.model_or_path, source.revision, source.fall_back_to_pt,
|
||||
source.allow_patterns_overrides)
|
||||
if self.load_config.load_format == LoadFormat.NPCACHE:
|
||||
# Currently np_cache only support *.bin checkpoints
|
||||
assert use_safetensors is False
|
||||
weights_iterator = np_cache_weights_iterator(
|
||||
source.model_or_path,
|
||||
self.load_config.download_dir,
|
||||
hf_folder,
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
elif use_safetensors:
|
||||
if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
|
||||
weights_iterator = fastsafetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
else:
|
||||
weights_iterator = safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
else:
|
||||
weights_iterator = pt_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
self.load_config.pt_load_map_location,
|
||||
)
|
||||
|
||||
if current_platform.is_tpu():
|
||||
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
|
||||
# not too many ops are accumulated in the XLA program.
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
def _xla_weights_iterator(iterator: Generator):
|
||||
for weights in iterator:
|
||||
yield weights
|
||||
xm.mark_step()
|
||||
|
||||
weights_iterator = _xla_weights_iterator(weights_iterator)
|
||||
|
||||
elif current_platform.is_hpu():
|
||||
import habana_frameworks.torch.core as htcore
|
||||
|
||||
def _hpu_weights_iterator(iterator: Generator):
|
||||
for weights in iterator:
|
||||
yield weights
|
||||
htcore.mark_step()
|
||||
|
||||
weights_iterator = _hpu_weights_iterator(weights_iterator)
|
||||
|
||||
if self.counter_before_loading_weights == 0.0:
|
||||
self.counter_before_loading_weights = time.perf_counter()
|
||||
# Apply the prefix.
|
||||
return ((source.prefix + name, tensor)
|
||||
for (name, tensor) in weights_iterator)
|
||||
|
||||
def get_all_weights(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
model: nn.Module,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
primary_weights = DefaultModelLoader.Source(
|
||||
model_config.model,
|
||||
model_config.revision,
|
||||
prefix="",
|
||||
fall_back_to_pt=getattr(model, "fall_back_to_pt_during_load",
|
||||
True),
|
||||
allow_patterns_overrides=getattr(model, "allow_patterns_overrides",
|
||||
None),
|
||||
)
|
||||
yield from self._get_weights_iterator(primary_weights)
|
||||
|
||||
secondary_weights = cast(
|
||||
Iterable[DefaultModelLoader.Source],
|
||||
getattr(model, "secondary_weights", ()),
|
||||
)
|
||||
for source in secondary_weights:
|
||||
yield from self._get_weights_iterator(source)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(model_config.model,
|
||||
model_config.revision,
|
||||
fall_back_to_pt=True,
|
||||
allow_patterns_overrides=None)
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
weights_to_load = {name for name, _ in model.named_parameters()}
|
||||
loaded_weights = model.load_weights(
|
||||
self.get_all_weights(model_config, model))
|
||||
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info(
|
||||
"Loading weights took %.2f seconds",
|
||||
self.counter_after_loading_weights -
|
||||
self.counter_before_loading_weights)
|
||||
# We only enable strict check for non-quantized models
|
||||
# that have loaded weights tracking currently.
|
||||
if model_config.quantization is None and loaded_weights is not None:
|
||||
weights_not_loaded = weights_to_load - loaded_weights
|
||||
if weights_not_loaded:
|
||||
raise ValueError("Following weights were not initialized from "
|
||||
f"checkpoint: {weights_not_loaded}")
|
||||
27
vllm/model_executor/model_loader/dummy_loader.py
Normal file
27
vllm/model_executor/model_loader/dummy_loader.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
initialize_dummy_weights)
|
||||
|
||||
|
||||
class DummyModelLoader(BaseModelLoader):
|
||||
"""Model loader that will set model weights to random values."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
raise ValueError(f"Model loader extra config is not supported for "
|
||||
f"load format {load_config.load_format}")
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
pass # Nothing to download
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
# NOTE(woosuk): For accurate performance evaluation, we assign
|
||||
# random values to the weights.
|
||||
initialize_dummy_weights(model)
|
||||
120
vllm/model_executor/model_loader/gguf_loader.py
Normal file
120
vllm/model_executor/model_loader/gguf_loader.py
Normal file
@@ -0,0 +1,120 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
|
||||
import gguf
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig, VllmConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model, process_weights_after_loading, set_default_torch_dtype)
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
get_gguf_extra_tensor_names, gguf_quant_weights_iterator)
|
||||
|
||||
|
||||
class GGUFModelLoader(BaseModelLoader):
|
||||
"""
|
||||
Model loader that can load GGUF files. This is useful for loading models
|
||||
that are quantized with GGUF and saved in the GGUF format. This loader
|
||||
supports loading both full models and sharded models.
|
||||
"""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
raise ValueError(f"Model loader extra config is not supported for "
|
||||
f"load format {load_config.load_format}")
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str):
|
||||
if os.path.isfile(model_name_or_path):
|
||||
return model_name_or_path
|
||||
else:
|
||||
raise ValueError(f"{model_name_or_path} is not a file.")
|
||||
|
||||
def _get_gguf_weights_map(self, model_config: ModelConfig):
|
||||
"""
|
||||
GGUF uses this naming convention for their tensors from HF checkpoint:
|
||||
`blk.N.BB.weight` and `blk.N.BB.bias`
|
||||
where N signifies the block number of a layer, and BB signifies the
|
||||
attention/mlp layer components.
|
||||
See "Standardized tensor names" in
|
||||
https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
|
||||
"""
|
||||
config = model_config.hf_config
|
||||
model_type = config.model_type
|
||||
gguf_to_hf_name_map = {}
|
||||
# hack: ggufs have a different name than transformers
|
||||
if model_type == "cohere":
|
||||
model_type = "command-r"
|
||||
if model_type in ("deepseek_v3", "deepseek_v2"):
|
||||
model_type = "deepseek2"
|
||||
# GGUF layer map assumes that we will have a merged expert weights
|
||||
# so we need to map them manually
|
||||
for idx in range(config.num_hidden_layers):
|
||||
gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = \
|
||||
f"model.layers.{idx}.mlp.gate.e_score_correction_bias"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.down_proj.weight"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.gate_proj.weight"
|
||||
gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = \
|
||||
f"model.layers.{idx}.mlp.experts.0.up_proj.weight"
|
||||
|
||||
arch = None
|
||||
for key, value in gguf.MODEL_ARCH_NAMES.items():
|
||||
if value == model_type:
|
||||
arch = key
|
||||
break
|
||||
if arch is None:
|
||||
raise RuntimeError(f"Unknown gguf model_type: {model_type}")
|
||||
num_layers = config.num_hidden_layers
|
||||
name_map = gguf.get_tensor_name_map(arch, num_layers)
|
||||
with torch.device("meta"):
|
||||
dummy_model = AutoModelForCausalLM.from_config(
|
||||
config, trust_remote_code=model_config.trust_remote_code)
|
||||
state_dict = dummy_model.state_dict()
|
||||
|
||||
for hf_name in state_dict:
|
||||
name, suffix = hf_name.rsplit(".", 1)
|
||||
gguf_name = name_map.get_name(name)
|
||||
gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
|
||||
return gguf_to_hf_name_map
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str]
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
return gguf_quant_weights_iterator(model_name_or_path,
|
||||
gguf_to_hf_name_map)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(model_config.model)
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
local_model_path = self._prepare_weights(model_config.model)
|
||||
gguf_weights_map = self._get_gguf_weights_map(model_config)
|
||||
model.load_weights(
|
||||
self._get_weights_iterator(local_model_path, gguf_weights_map))
|
||||
|
||||
def load_model(self, vllm_config: VllmConfig,
|
||||
model_config: ModelConfig) -> nn.Module:
|
||||
device_config = vllm_config.device_config
|
||||
local_model_path = self._prepare_weights(model_config.model)
|
||||
gguf_weights_map = self._get_gguf_weights_map(model_config)
|
||||
# we can only know if tie word embeddings after mapping weights
|
||||
if "lm_head.weight" in get_gguf_extra_tensor_names(
|
||||
local_model_path, gguf_weights_map):
|
||||
model_config.hf_config.update({"tie_word_embeddings": True})
|
||||
|
||||
target_device = torch.device(device_config.device)
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
self.load_weights(model, model_config)
|
||||
|
||||
process_weights_after_loading(model, model_config, target_device)
|
||||
return model
|
||||
476
vllm/model_executor/model_loader/neuron.py
Normal file
476
vllm/model_executor/model_loader/neuron.py
Normal file
@@ -0,0 +1,476 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for selecting and loading Neuron models in transformers-neuronx
|
||||
framework."""
|
||||
import ast
|
||||
import copy
|
||||
import importlib
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import get_quantization_config
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SequenceOutput)
|
||||
|
||||
TORCH_DTYPE_TO_NEURON_AMP = {
|
||||
"auto": "f32",
|
||||
"half": "f16",
|
||||
"float16": "f16",
|
||||
"bfloat16": "bf16",
|
||||
"float": "f32",
|
||||
"float32": "f32",
|
||||
torch.float16: "f16",
|
||||
torch.bfloat16: "bf16",
|
||||
torch.float32: "f32",
|
||||
}
|
||||
|
||||
# Models supported by Neuron.
|
||||
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str, str]] = {
|
||||
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
|
||||
"LlamaForSampling", "LlamaForCausalLM"),
|
||||
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
|
||||
"MistralForSampling", "MistralForCausalLM")
|
||||
}
|
||||
|
||||
|
||||
class NeuronCausalLM(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
on_device_sampling_disabled: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
||||
logits_as_input=True)
|
||||
|
||||
self.on_device_sampling_disabled = on_device_sampling_disabled
|
||||
if self.on_device_sampling_disabled:
|
||||
# Use default sampler
|
||||
self.sampler = Sampler()
|
||||
|
||||
# Lazy initialized
|
||||
self.model: nn.Module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_block_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
logits = self.model(input_ids,
|
||||
cache_ids=positions,
|
||||
start_ids=input_block_ids)
|
||||
return logits
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(None, hidden_states, sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
|
||||
if self.on_device_sampling_disabled:
|
||||
next_tokens = self.sampler(logits, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
# On-device sampling outputs the token ids directly.
|
||||
sampled_token_ids = logits.flatten()
|
||||
next_tokens = []
|
||||
sample_idx = 0
|
||||
for seq_group in sampling_metadata.seq_groups:
|
||||
samples = []
|
||||
for seq_id in seq_group.seq_ids:
|
||||
token_id = sampled_token_ids[sample_idx].item()
|
||||
samples.append(
|
||||
SequenceOutput(parent_seq_id=seq_id,
|
||||
output_token=token_id,
|
||||
logprobs={token_id: Logprob(token_id)}))
|
||||
sample_idx += 1
|
||||
next_tokens.append(
|
||||
CompletionSequenceGroupOutput(samples=samples,
|
||||
prompt_logprobs=None))
|
||||
|
||||
return SamplerOutput(outputs=next_tokens)
|
||||
|
||||
def load_weights(self, model_name_or_path: str, **kwargs):
|
||||
arch = _get_model_architecture(self.config)
|
||||
neuronx_module_path, neuronx_model_cls_name, hf_model_cls_name = (
|
||||
_NEURON_SUPPORTED_MODELS[arch])
|
||||
neuronx_module = importlib.import_module(neuronx_module_path)
|
||||
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
|
||||
|
||||
self.model = neuronx_model_cls.from_pretrained(model_name_or_path,
|
||||
**kwargs)
|
||||
self.model.to_neuron()
|
||||
|
||||
|
||||
class NeuronSpeculationCausalLM(nn.Module):
|
||||
"""A Neuron-optimized causal language model with speculative decoding."""
|
||||
|
||||
SPECULATION_TERMINATION_ID = -1
|
||||
|
||||
def __init__(self, speculation_model) -> None:
|
||||
super().__init__()
|
||||
self.model = speculation_model
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_block_ids: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
tokens, counts = self.model.speculative_iteration(
|
||||
input_ids, positions, input_block_ids)
|
||||
|
||||
# Mark the end of accepted speculative tokens for each sequence with the
|
||||
# speculation termination id.
|
||||
batch_size, steps = tokens.shape
|
||||
mask = torch.arange(steps).expand(batch_size, -1) >= counts
|
||||
tokens[mask] = self.SPECULATION_TERMINATION_ID
|
||||
|
||||
return tokens
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[list[SamplerOutput]]:
|
||||
batch_size, num_steps = logits.shape
|
||||
seq_ids = [
|
||||
seq_id for sg in sampling_metadata.seq_groups
|
||||
for seq_id in sg.seq_ids
|
||||
]
|
||||
# Organize input tensors by step instead of by sequence.
|
||||
accepted_token_ids_by_step = logits.transpose(0, 1)
|
||||
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
|
||||
|
||||
sampler_output_list = []
|
||||
for step_index in range(num_steps):
|
||||
if all(token_id == self.SPECULATION_TERMINATION_ID
|
||||
for token_id in accepted_token_ids_by_step[step_index]):
|
||||
break
|
||||
step_output_token_ids = []
|
||||
for sequence_index in range(batch_size):
|
||||
token_id = accepted_token_ids_by_step[step_index][
|
||||
sequence_index]
|
||||
step_output_token_ids.append(
|
||||
CompletionSequenceGroupOutput(samples=[
|
||||
SequenceOutput(parent_seq_id=seq_ids[sequence_index],
|
||||
output_token=token_id,
|
||||
logprobs={token_id: Logprob(token_id)})
|
||||
],
|
||||
prompt_logprobs=None))
|
||||
sampler_output_list.append(
|
||||
SamplerOutput(outputs=step_output_token_ids))
|
||||
return sampler_output_list
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> str:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _NEURON_SUPPORTED_MODELS:
|
||||
return arch
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported on Neuron "
|
||||
f"for now. Supported architectures: "
|
||||
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
|
||||
|
||||
|
||||
def _get_buckets(env: str, default_value: list[int]) -> list[int]:
|
||||
env_value = os.getenv(env)
|
||||
if env_value is None:
|
||||
return default_value
|
||||
buckets_remove_empty = filter(
|
||||
lambda x: x is not None and len(x.strip()) > 0, env_value.split(","))
|
||||
buckets_int = map(int, buckets_remove_empty)
|
||||
buckets_list = list(buckets_int)
|
||||
return buckets_list
|
||||
|
||||
|
||||
def _get_default_neuron_config(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig):
|
||||
"""Generate a neuron config based on vllm config args."""
|
||||
from transformers_neuronx.config import ContinuousBatchingConfig
|
||||
from transformers_neuronx.constants import LAYOUT_BSH
|
||||
|
||||
continuous_batching_config = ContinuousBatchingConfig(
|
||||
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
|
||||
quant_config = dict(
|
||||
dequant_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
quantize_method="vector_dynamic")
|
||||
neuron_quantization_config_builder = lambda quant: get_quantization_config(
|
||||
quant).from_config(quant_config).get_quant_method(None, "")
|
||||
# TODO: Add Paged attention config to the default neuron arguments.
|
||||
default_neuron_args = dict(
|
||||
collectives_layout=LAYOUT_BSH,
|
||||
attention_layout=LAYOUT_BSH,
|
||||
fuse_qkv=True,
|
||||
quant=neuron_quantization_config_builder(model_config.quantization)
|
||||
if model_config.quantization else None,
|
||||
continuous_batching=continuous_batching_config,
|
||||
weight_tiling=bool(model_config.quantization),
|
||||
on_device_generation=_get_neuron_on_device_generation_config(
|
||||
model_config))
|
||||
return default_neuron_args
|
||||
|
||||
|
||||
def _get_default_neuron_config_for_speculation(
|
||||
model_config: ModelConfig, parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig):
|
||||
"""Generate a neuron config for speculative decoding based on
|
||||
vllm config args."""
|
||||
from transformers_neuronx.config import ContinuousBatchingConfig
|
||||
from transformers_neuronx.constants import LAYOUT_BSH
|
||||
|
||||
continuous_batching_config = ContinuousBatchingConfig(
|
||||
batch_size_for_shared_caches=scheduler_config.max_num_seqs)
|
||||
|
||||
default_neuron_args = dict(collectives_layout=LAYOUT_BSH,
|
||||
attention_layout=LAYOUT_BSH,
|
||||
fuse_qkv=True,
|
||||
on_device_embedding=True,
|
||||
continuous_batching=continuous_batching_config,
|
||||
on_device_generation=copy.deepcopy(
|
||||
model_config.neuron_sampling_params))
|
||||
return default_neuron_args
|
||||
|
||||
|
||||
def _get_neuron_on_device_generation_config(model_config: ModelConfig):
|
||||
if not _is_neuron_on_device_sampling_disabled(model_config):
|
||||
return copy.deepcopy(model_config.neuron_sampling_params)
|
||||
return None
|
||||
|
||||
|
||||
def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
|
||||
return not getattr(model_config, "neuron_sampling_params", None)
|
||||
|
||||
|
||||
def _get_neuron_config_after_override(default_neuron_config,
|
||||
overridden_neuron_config):
|
||||
from transformers_neuronx.config import (ContinuousBatchingConfig,
|
||||
GenerationConfig,
|
||||
KVCacheQuantizationConfig,
|
||||
NeuronConfig, QuantizationConfig,
|
||||
SparseAttnConfig)
|
||||
|
||||
sparse_attn = overridden_neuron_config.pop("sparse_attn", {})
|
||||
if sparse_attn:
|
||||
overridden_neuron_config["sparse_attn"] = SparseAttnConfig(
|
||||
**sparse_attn)
|
||||
|
||||
kv_cache_quant = overridden_neuron_config.pop("kv_cache_quant", {})
|
||||
if kv_cache_quant:
|
||||
overridden_neuron_config["kv_cache_quant"] = KVCacheQuantizationConfig(
|
||||
**kv_cache_quant)
|
||||
|
||||
continuous_batching = overridden_neuron_config.pop("continuous_batching",
|
||||
{})
|
||||
if continuous_batching:
|
||||
overridden_neuron_config[
|
||||
"continuous_batching"] = ContinuousBatchingConfig(
|
||||
**continuous_batching)
|
||||
|
||||
quant = overridden_neuron_config.pop("quant", {})
|
||||
if quant:
|
||||
overridden_neuron_config["quant"] = QuantizationConfig(**quant)
|
||||
|
||||
on_device_generation = overridden_neuron_config.pop(
|
||||
"on_device_generation", {})
|
||||
if on_device_generation:
|
||||
overridden_neuron_config["on_device_generation"] = GenerationConfig(
|
||||
**on_device_generation)
|
||||
default_neuron_config.update(overridden_neuron_config)
|
||||
return NeuronConfig(**default_neuron_config)
|
||||
|
||||
|
||||
def get_neuron_model(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig) -> nn.Module:
|
||||
"""Initializes a neuron-optimized model for inference."""
|
||||
# Create a model instance.
|
||||
model = NeuronCausalLM(
|
||||
model_config.hf_config,
|
||||
_is_neuron_on_device_sampling_disabled(model_config))
|
||||
|
||||
default_neuron_config_args = _get_default_neuron_config(
|
||||
model_config, parallel_config, scheduler_config)
|
||||
|
||||
neuron_config = _get_neuron_config_after_override(
|
||||
default_neuron_config_args, model_config.override_neuron_config)
|
||||
|
||||
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
|
||||
model.load_weights(model_config.model,
|
||||
tp_degree=parallel_config.tensor_parallel_size,
|
||||
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
neuron_config=neuron_config,
|
||||
context_length_estimate=context_length_estimates,
|
||||
n_positions=n_positions,
|
||||
batch_size=scheduler_config.max_num_seqs)
|
||||
|
||||
return model.eval()
|
||||
|
||||
|
||||
def get_neuron_speculation_model(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
speculation_config: SpeculativeConfig):
|
||||
"""Initializes a neuron-optimized speculation model for inference.
|
||||
|
||||
This method is only applicable for speculation with a standalone draft model
|
||||
"""
|
||||
from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder
|
||||
|
||||
# For Eagle SD, we need to pass in additional parameters in neuron config.
|
||||
is_eagle = getattr(speculation_config.draft_model_config.hf_config,
|
||||
"is_eagle", False)
|
||||
|
||||
# Create target model instance.
|
||||
target_model = NeuronCausalLM(model_config.hf_config)
|
||||
|
||||
default_neuron_config_args = _get_default_neuron_config_for_speculation(
|
||||
model_config, parallel_config, scheduler_config)
|
||||
if is_eagle:
|
||||
default_neuron_config_args['is_eagle_target'] = True
|
||||
|
||||
neuron_config = _get_neuron_config_after_override(
|
||||
default_neuron_config_args, model_config.override_neuron_config)
|
||||
|
||||
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
|
||||
target_model.load_weights(
|
||||
model_config.model,
|
||||
tp_degree=parallel_config.tensor_parallel_size,
|
||||
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
neuron_config=neuron_config,
|
||||
context_length_estimate=context_length_estimates,
|
||||
n_positions=n_positions,
|
||||
batch_size=scheduler_config.max_num_seqs)
|
||||
|
||||
target_model.eval()
|
||||
|
||||
# Create draft model instance.
|
||||
draft_model = NeuronCausalLM(
|
||||
speculation_config.draft_model_config.hf_config)
|
||||
|
||||
default_draft_neuron_config_args = (
|
||||
_get_default_neuron_config_for_speculation(
|
||||
speculation_config.draft_model_config, parallel_config,
|
||||
scheduler_config))
|
||||
if is_eagle:
|
||||
default_draft_neuron_config_args['is_eagle_draft'] = True
|
||||
default_draft_neuron_config_args['has_pre_attention_norm'] = False
|
||||
|
||||
draft_neuron_config = _get_neuron_config_after_override(
|
||||
default_draft_neuron_config_args,
|
||||
speculation_config.draft_model_config.override_neuron_config)
|
||||
|
||||
draft_model.load_weights(speculation_config.draft_model_config.model,
|
||||
tp_degree=speculation_config.
|
||||
draft_parallel_config.tensor_parallel_size,
|
||||
amp=TORCH_DTYPE_TO_NEURON_AMP[
|
||||
speculation_config.draft_model_config.dtype],
|
||||
neuron_config=draft_neuron_config,
|
||||
context_length_estimate=context_length_estimates,
|
||||
n_positions=n_positions,
|
||||
batch_size=scheduler_config.max_num_seqs)
|
||||
|
||||
draft_model.eval()
|
||||
|
||||
num_speculative_tokens = speculation_config.num_speculative_tokens
|
||||
# Create speculation model instance.
|
||||
speculation_model = FusedSpeculativeDecoder(draft_model.model,
|
||||
target_model.model,
|
||||
num_speculative_tokens)
|
||||
speculation_model.to_neuron()
|
||||
|
||||
return NeuronSpeculationCausalLM(speculation_model)
|
||||
|
||||
|
||||
def get_neuron_eagle_speculation_model(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
speculation_config: SpeculativeConfig):
|
||||
"""Initializes a neuron-optimized EAGLE speculation model for inference."""
|
||||
from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder
|
||||
|
||||
# Create target model instance.
|
||||
target_model = NeuronCausalLM(model_config.hf_config)
|
||||
|
||||
default_neuron_config_args = _get_default_neuron_config_for_speculation(
|
||||
model_config, parallel_config, scheduler_config)
|
||||
default_neuron_config_args['is_eagle_target'] = True
|
||||
neuron_config = _get_neuron_config_after_override(
|
||||
default_neuron_config_args, model_config.override_neuron_config)
|
||||
|
||||
context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
|
||||
[scheduler_config.max_model_len])
|
||||
|
||||
target_model.load_weights(
|
||||
model_config.model,
|
||||
tp_degree=parallel_config.tensor_parallel_size,
|
||||
amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
neuron_config=neuron_config,
|
||||
context_length_estimate=context_length_estimates,
|
||||
n_positions=n_positions,
|
||||
batch_size=scheduler_config.max_num_seqs)
|
||||
|
||||
target_model.eval()
|
||||
|
||||
# Create draft model instance.
|
||||
draft_model = NeuronCausalLM(
|
||||
speculation_config.draft_model_config.hf_config)
|
||||
|
||||
default_draft_neuron_config_args = (
|
||||
_get_default_neuron_config_for_speculation(
|
||||
speculation_config.draft_model_config, parallel_config,
|
||||
scheduler_config))
|
||||
default_draft_neuron_config_args['is_eagle_draft'] = True
|
||||
default_draft_neuron_config_args['has_pre_attention_norm'] = False
|
||||
draft_neuron_config = _get_neuron_config_after_override(
|
||||
default_draft_neuron_config_args,
|
||||
speculation_config.draft_model_config.override_neuron_config)
|
||||
|
||||
draft_model.load_weights(speculation_config.draft_model_config.model,
|
||||
tp_degree=speculation_config.
|
||||
draft_parallel_config.tensor_parallel_size,
|
||||
amp=TORCH_DTYPE_TO_NEURON_AMP[
|
||||
speculation_config.draft_model_config.dtype],
|
||||
neuron_config=draft_neuron_config,
|
||||
context_length_estimate=context_length_estimates,
|
||||
n_positions=n_positions,
|
||||
batch_size=scheduler_config.max_num_seqs)
|
||||
|
||||
draft_model.eval()
|
||||
|
||||
token_tree: dict[int, list[int]] = ast.literal_eval(
|
||||
speculation_config.speculative_token_tree)
|
||||
|
||||
speculation_model = EagleSpeculativeDecoder(draft_model.model,
|
||||
target_model.model,
|
||||
token_tree=token_tree)
|
||||
speculation_model.to_neuron()
|
||||
|
||||
return NeuronSpeculationCausalLM(speculation_model)
|
||||
685
vllm/model_executor/model_loader/neuronx_distributed.py
Normal file
685
vllm/model_executor/model_loader/neuronx_distributed.py
Normal file
@@ -0,0 +1,685 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for selecting and loading Neuron models in
|
||||
neuronx-distributed-inference framework."""
|
||||
# Disabling yapf because yapf and isort have conflicts for the below imports
|
||||
# yapf: disable
|
||||
import copy
|
||||
import hashlib
|
||||
import importlib
|
||||
import multiprocessing
|
||||
import os
|
||||
import shutil
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from neuronx_distributed_inference.models.config import (
|
||||
FusedSpecNeuronConfig, OnDeviceSamplingConfig)
|
||||
from neuronx_distributed_inference.models.mllama.utils import (
|
||||
create_vision_mask)
|
||||
from neuronx_distributed_inference.modules.lora_serving import (
|
||||
LoraServingConfig)
|
||||
from neuronx_distributed_inference.utils.hf_adapter import (
|
||||
load_pretrained_config)
|
||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
|
||||
|
||||
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
|
||||
SpeculativeConfig)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
|
||||
SequenceOutput)
|
||||
|
||||
# yapf: enable
|
||||
logger = init_logger(__name__)
|
||||
|
||||
TORCH_DTYPE_TO_NEURON_AMP = {
|
||||
"auto": "float32",
|
||||
"half": "float16",
|
||||
"float16": "float16",
|
||||
"bfloat16": "bfloat16",
|
||||
"float": "float32",
|
||||
"float32": "float32",
|
||||
torch.float16: "float16",
|
||||
torch.bfloat16: "bfloat16",
|
||||
torch.float32: "float32",
|
||||
}
|
||||
|
||||
# Models supported by Neuronx distributed for inference.
|
||||
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = {
|
||||
"LlamaForCausalLM":
|
||||
("neuronx_distributed_inference.models.llama.modeling_llama",
|
||||
"NeuronLlamaForCausalLM"),
|
||||
"MistralForCausalLM":
|
||||
("neuronx_distributed_inference.models.llama.modeling_llama",
|
||||
"NeuronLlamaForCausalLM"),
|
||||
"DbrxForCausalLM":
|
||||
("neuronx_distributed_inference.models.dbrx.modeling_dbrx",
|
||||
"NeuronDbrxForCausalLM"),
|
||||
"MixtralForCausalLM":
|
||||
("neuronx_distributed_inference.models.mixtral.modeling_mixtral",
|
||||
"NeuronMixtralForCausalLM"),
|
||||
"MllamaForConditionalGeneration":
|
||||
("neuronx_distributed_inference.models.mllama.modeling_mllama",
|
||||
"NeuronMllamaForCausalLM"),
|
||||
}
|
||||
|
||||
|
||||
class NeuronCausalLM(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
||||
logits_as_input=True)
|
||||
self.sampler = Sampler()
|
||||
|
||||
# Lazy initialized
|
||||
self.model: nn.Module
|
||||
|
||||
def forward(self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_block_ids: torch.Tensor,
|
||||
sampling_params: torch.Tensor,
|
||||
prev_hidden: Optional[torch.Tensor] = None,
|
||||
adapter_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
# sort block ids sequentially for perf/neuron support reasons
|
||||
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
|
||||
input_ids = torch.index_select(input_ids, 0, sorted_indices)
|
||||
positions = torch.index_select(positions, 0, sorted_indices)
|
||||
sampling_params = torch.index_select(sampling_params, 0,
|
||||
sorted_indices)
|
||||
output = self.model(input_ids,
|
||||
attention_mask=None,
|
||||
position_ids=positions,
|
||||
seq_ids=sorted_input_block_ids,
|
||||
sampling_params=sampling_params,
|
||||
prev_hidden=prev_hidden,
|
||||
adapter_ids=adapter_ids)
|
||||
# on-device sampling
|
||||
if self.config.neuron_config.on_device_sampling_config:
|
||||
output = output.hidden_states
|
||||
else:
|
||||
output = output.logits[:, -1, :]
|
||||
|
||||
restored_indices = torch.argsort(sorted_indices)
|
||||
if input_block_ids.shape[0] != 1:
|
||||
output = torch.index_select(output, 0, restored_indices)
|
||||
|
||||
return output
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(None, hidden_states, sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
# on-device sampling
|
||||
if self.config.neuron_config.on_device_sampling_config:
|
||||
batch_size = logits.shape
|
||||
seq_ids = [
|
||||
seq_id for sg in sampling_metadata.seq_groups
|
||||
for seq_id in sg.seq_ids
|
||||
]
|
||||
assert len(seq_ids) == list(batch_size)[0], "batch size mismatch"
|
||||
# Organize input tensors by step instead of by sequence.
|
||||
accepted_token_ids_by_step = logits.flatten()
|
||||
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
|
||||
|
||||
step_output_token_ids = []
|
||||
for i, seq_id in enumerate(seq_ids):
|
||||
token_id = accepted_token_ids_by_step[i]
|
||||
step_output_token_ids.append(
|
||||
CompletionSequenceGroupOutput(samples=[
|
||||
SequenceOutput(parent_seq_id=seq_id,
|
||||
output_token=token_id,
|
||||
logprobs={token_id: Logprob(token_id)})
|
||||
],
|
||||
prompt_logprobs=None))
|
||||
return SamplerOutput(outputs=step_output_token_ids)
|
||||
else:
|
||||
return self.sampler(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, model_name_or_path: str, **kwargs):
|
||||
arch = _get_model_architecture(self.config)
|
||||
neuronx_module_path, neuronx_model_cls_name = (
|
||||
_NEURON_SUPPORTED_MODELS[arch])
|
||||
neuronx_module = importlib.import_module(neuronx_module_path)
|
||||
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
|
||||
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
|
||||
**kwargs['neuron_config'])
|
||||
self.config.neuron_config = neuron_config
|
||||
config = neuronx_model_cls.get_config_cls()(
|
||||
neuron_config,
|
||||
load_config=load_pretrained_config(model_name_or_path))
|
||||
hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
|
||||
usedforsecurity=False).hexdigest()
|
||||
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
|
||||
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
|
||||
elif os.path.exists(model_name_or_path):
|
||||
compiled_model_path = os.path.join(model_name_or_path,
|
||||
"neuron-compiled-artifacts",
|
||||
hashed_config)
|
||||
shutil.rmtree(compiled_model_path, ignore_errors=True)
|
||||
else:
|
||||
compiled_model_path = os.path.join("local-models",
|
||||
model_name_or_path,
|
||||
"neuron-compiled-artifacts",
|
||||
hashed_config)
|
||||
shutil.rmtree(compiled_model_path, ignore_errors=True)
|
||||
try:
|
||||
self.model = neuronx_model_cls(compiled_model_path)
|
||||
override_neuron_config = kwargs["override_neuron_config"]
|
||||
for k, v in override_neuron_config.items():
|
||||
setattr(self.model.config.neuron_config, k, v)
|
||||
self.model.load(compiled_model_path)
|
||||
return
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
logger.warning("Exception: %s", e)
|
||||
logger.warning("Failed to load the model from %s, Recompiling...",
|
||||
compiled_model_path)
|
||||
if not os.path.exists(model_name_or_path):
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
|
||||
saved_path = os.path.join("local-models", model_name_or_path)
|
||||
hf_model.save_pretrained(saved_path)
|
||||
model_name_or_path = saved_path
|
||||
self.model = neuronx_model_cls(model_name_or_path, config)
|
||||
self.model.compile(compiled_model_path)
|
||||
self.model.load(compiled_model_path)
|
||||
|
||||
|
||||
class NeuronMllamaForCausalLM(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
on_device_sampling_disabled: bool = False) -> None:
|
||||
super().__init__()
|
||||
# has_image is the only multimodal input that is used in
|
||||
# token-generation
|
||||
# This is a cache (on CPU) that saves has_image data per sequence id
|
||||
# The number of entries in this cache is <= Batch-Size
|
||||
self.has_image_cache: dict[int, torch.Tensor] = {}
|
||||
self.config = config
|
||||
self.logits_processor = LogitsProcessor(
|
||||
config.get_text_config().vocab_size, logits_as_input=True)
|
||||
|
||||
self.on_device_sampling_disabled = on_device_sampling_disabled
|
||||
if self.on_device_sampling_disabled:
|
||||
# Use default sampler
|
||||
self.sampler = Sampler()
|
||||
|
||||
# Lazy initialized
|
||||
self.model: nn.Module
|
||||
self.is_reorder_needed: bool = True
|
||||
|
||||
def read_from_has_image_cache(self, seq_ids: torch.Tensor):
|
||||
has_image_list = []
|
||||
for index in range(len(seq_ids)):
|
||||
seq_id = seq_ids[index].item()
|
||||
if seq_id in self.has_image_cache:
|
||||
has_image_list.append(self.has_image_cache[seq_id])
|
||||
else:
|
||||
has_image_list.append(torch.tensor([0]))
|
||||
return torch.tensor(has_image_list)
|
||||
|
||||
def write_to_has_image_cache(self, seq_ids: torch.Tensor,
|
||||
has_image: torch.Tensor):
|
||||
for index in range(len(seq_ids)):
|
||||
seq_id = seq_ids[index].item()
|
||||
if index < len(has_image):
|
||||
self.has_image_cache[seq_id] = has_image[index]
|
||||
else:
|
||||
self.has_image_cache[seq_id] = torch.zeros(1)
|
||||
|
||||
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
|
||||
seq_ids: torch.Tensor, pixel_values: torch.Tensor,
|
||||
aspect_ratios: torch.Tensor, num_chunks: torch.Tensor,
|
||||
has_image: torch.Tensor, sampling_params) -> torch.Tensor:
|
||||
|
||||
# We update the has_image cache during prefill
|
||||
# and read the has_image cache during decode
|
||||
if input_ids.shape[-1] > 1: # prefill
|
||||
self.write_to_has_image_cache(seq_ids, has_image)
|
||||
else:
|
||||
has_image = self.read_from_has_image_cache(seq_ids)
|
||||
bs = input_ids.shape[0]
|
||||
num_chunks = torch.zeros((bs, 1))
|
||||
aspect_ratios = torch.zeros((bs, 1, 2))
|
||||
|
||||
input_block_ids = seq_ids
|
||||
origin_input_block_ids = seq_ids
|
||||
if self.is_reorder_needed:
|
||||
# sort block ids sequentially for perf/neuron support reasons
|
||||
input_block_ids, sorted_indices = torch.sort(input_block_ids)
|
||||
input_ids = torch.index_select(input_ids, 0, sorted_indices)
|
||||
positions = torch.index_select(positions, 0, sorted_indices)
|
||||
sampling_params = torch.index_select(sampling_params, 0,
|
||||
sorted_indices)
|
||||
pixel_values = torch.index_select(pixel_values, 0, sorted_indices)
|
||||
aspect_ratios = torch.index_select(aspect_ratios, 0,
|
||||
sorted_indices)
|
||||
num_chunks = torch.index_select(num_chunks, 0, sorted_indices)
|
||||
has_image = torch.index_select(has_image, 0, sorted_indices)
|
||||
|
||||
self.vision_mask = create_vision_mask(input_ids, self.vision_token_id)
|
||||
output = self.model(
|
||||
input_ids.to(torch.int32),
|
||||
attention_mask=None,
|
||||
position_ids=positions.to(torch.int32),
|
||||
seq_ids=seq_ids.flatten().to(torch.int32),
|
||||
pixel_values=pixel_values.to(
|
||||
self.config.vision_config.torch_dtype),
|
||||
aspect_ratios=aspect_ratios.to(torch.int32),
|
||||
vision_mask=self.vision_mask.to(torch.int32),
|
||||
sampling_params=sampling_params,
|
||||
num_chunks=num_chunks.to(torch.int32),
|
||||
has_image=has_image.to(torch.int32),
|
||||
)
|
||||
if self.config.neuron_config.on_device_sampling_config:
|
||||
output = output.hidden_states
|
||||
else:
|
||||
output = output.logits[:, -1, :]
|
||||
|
||||
if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1:
|
||||
restored_indices = torch.argsort(sorted_indices)
|
||||
output = torch.index_select(output, 0, restored_indices)
|
||||
return output
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
logits = self.logits_processor(None, hidden_states, sampling_metadata)
|
||||
return logits
|
||||
|
||||
def sample(self, hidden_states, sampling_metadata):
|
||||
if not self.on_device_sampling_disabled:
|
||||
with torch.profiler.record_function("sample"):
|
||||
hidden_states = hidden_states.flatten()
|
||||
res = []
|
||||
sample_idx = 0
|
||||
for seq_group in sampling_metadata.seq_groups:
|
||||
seq_ids = seq_group.seq_ids
|
||||
samples = []
|
||||
for seq_id in seq_ids:
|
||||
token_id = hidden_states[sample_idx].item()
|
||||
samples.append(
|
||||
SequenceOutput(
|
||||
parent_seq_id=seq_id,
|
||||
output_token=token_id,
|
||||
logprobs={token_id: Logprob(token_id)}))
|
||||
sample_idx += 1
|
||||
res.append(
|
||||
CompletionSequenceGroupOutput(samples=samples,
|
||||
prompt_logprobs=None))
|
||||
next_tokens = SamplerOutput(outputs=res)
|
||||
else:
|
||||
next_tokens = self.sampler(None, hidden_states, sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self, model_name_or_path: str, **kwargs):
|
||||
arch = _get_model_architecture(self.config)
|
||||
neuronx_module_path, neuronx_model_cls_name = (
|
||||
_NEURON_SUPPORTED_MODELS[arch])
|
||||
neuronx_module = importlib.import_module(neuronx_module_path)
|
||||
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
|
||||
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
|
||||
**kwargs['neuron_config'])
|
||||
self.config.neuron_config = neuron_config
|
||||
logger.info("neuron_config buckets: %s",
|
||||
self.config.neuron_config.buckets)
|
||||
config = neuronx_model_cls.get_config_cls()(
|
||||
neuron_config,
|
||||
load_config=load_pretrained_config(model_name_or_path))
|
||||
hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
|
||||
usedforsecurity=False).hexdigest()
|
||||
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
|
||||
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
|
||||
elif os.path.exists(model_name_or_path):
|
||||
compiled_model_path = os.path.join(model_name_or_path,
|
||||
"neuron-compiled-artifacts",
|
||||
hashed_config)
|
||||
else:
|
||||
compiled_model_path = os.path.join("local-models",
|
||||
model_name_or_path,
|
||||
"neuron-compiled-artifacts",
|
||||
hashed_config)
|
||||
try:
|
||||
self.model = neuronx_model_cls(compiled_model_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
self.vision_token_id = tokenizer(
|
||||
"<|image|>", add_special_tokens=False).input_ids[0]
|
||||
self.model.load(compiled_model_path)
|
||||
return
|
||||
except (FileNotFoundError, ValueError):
|
||||
logger.warning("Failed to load the model from %s, Recompiling...",
|
||||
compiled_model_path)
|
||||
if not os.path.exists(model_name_or_path):
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
|
||||
saved_path = os.path.join("local-models", model_name_or_path)
|
||||
hf_model.save_pretrained(saved_path)
|
||||
model_name_or_path = saved_path
|
||||
self.model = neuronx_model_cls(model_name_or_path, config)
|
||||
|
||||
logger.info("\nCompiling and saving model to %s", model_name_or_path)
|
||||
|
||||
p = multiprocessing.Process(target=compile_model,
|
||||
args=(self, compiled_model_path))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
|
||||
tokenizer.save_pretrained(compiled_model_path)
|
||||
logger.info("Successfully compiled and saved the model in %s",
|
||||
compiled_model_path)
|
||||
|
||||
# Read "<|image|>" token_id from the tokenizer
|
||||
self.vision_token_id = tokenizer("<|image|>",
|
||||
add_special_tokens=False).input_ids[0]
|
||||
logger.info("\nLoading model from compiled checkpoint...")
|
||||
self.model.load(compiled_model_path)
|
||||
|
||||
|
||||
def compile_model(neuron_model, traced_model_path):
|
||||
neuron_model.model.compile(traced_model_path)
|
||||
|
||||
|
||||
class NeuronSpeculationCausalLM(nn.Module):
|
||||
"""A Neuron-optimized causal language model with speculative decoding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size,
|
||||
logits_as_input=True)
|
||||
# Lazy initialized
|
||||
self.model: nn.Module
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
input_block_ids: torch.Tensor,
|
||||
sampling_params: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
# sort block ids sequentially for perf/neuron support reasons
|
||||
sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
|
||||
input_ids = torch.index_select(input_ids, 0, sorted_indices)
|
||||
positions = torch.index_select(positions, 0, sorted_indices)
|
||||
sampling_params = torch.index_select(sampling_params, 0,
|
||||
sorted_indices)
|
||||
|
||||
output = self.model(input_ids,
|
||||
attention_mask=None,
|
||||
position_ids=positions,
|
||||
seq_ids=sorted_input_block_ids,
|
||||
sampling_params=sampling_params)
|
||||
restored_indices = torch.argsort(sorted_indices)
|
||||
|
||||
# CTX encoding
|
||||
if (positions[:, 0]).sum().item() == 0:
|
||||
output = output.fused_outputs[0][:, 0:1]
|
||||
if input_block_ids.shape[0] != 1:
|
||||
output = torch.index_select(output, 0, restored_indices)
|
||||
return output
|
||||
|
||||
# Fused Spec (Generation)
|
||||
accepted_tokens_with_padding = output.fused_outputs[0]
|
||||
next_pos_ids = output.fused_outputs[-1]
|
||||
generated_token_counts = next_pos_ids - positions
|
||||
|
||||
assert torch.any(generated_token_counts == 0).item() is False, \
|
||||
"NxDI model generated no output for one or more sequences."
|
||||
|
||||
batch_size, steps = accepted_tokens_with_padding.shape
|
||||
mask = torch.arange(steps).expand(batch_size,
|
||||
-1) >= generated_token_counts
|
||||
accepted_tokens_with_padding[mask] = -1
|
||||
|
||||
if input_block_ids.shape[0] != 1:
|
||||
accepted_tokens_with_padding = torch.index_select(
|
||||
accepted_tokens_with_padding, 0, restored_indices)
|
||||
|
||||
return accepted_tokens_with_padding
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[list[SamplerOutput]]:
|
||||
batch_size, num_steps = logits.shape
|
||||
seq_ids = [
|
||||
seq_id for sg in sampling_metadata.seq_groups
|
||||
for seq_id in sg.seq_ids
|
||||
]
|
||||
# Organize input tensors by step instead of by sequence.
|
||||
accepted_token_ids_by_step = logits.transpose(0, 1)
|
||||
accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()
|
||||
|
||||
sampler_output_list = []
|
||||
for step_index in range(num_steps):
|
||||
if all(token_id == -1
|
||||
for token_id in accepted_token_ids_by_step[step_index]):
|
||||
break
|
||||
step_output_token_ids = []
|
||||
for sequence_index in range(batch_size):
|
||||
token_id = accepted_token_ids_by_step[step_index][
|
||||
sequence_index]
|
||||
step_output_token_ids.append(
|
||||
CompletionSequenceGroupOutput(samples=[
|
||||
SequenceOutput(parent_seq_id=seq_ids[sequence_index],
|
||||
output_token=token_id,
|
||||
logprobs={token_id: Logprob(token_id)})
|
||||
],
|
||||
prompt_logprobs=None))
|
||||
sampler_output_list.append(
|
||||
SamplerOutput(outputs=step_output_token_ids))
|
||||
return sampler_output_list
|
||||
|
||||
def load_weights(self, model_name_or_path: str,
|
||||
draft_model_name_or_path: str, **kwargs):
|
||||
arch = _get_model_architecture(self.config)
|
||||
neuronx_module_path, neuronx_model_cls_name = (
|
||||
_NEURON_SUPPORTED_MODELS[arch])
|
||||
neuronx_module = importlib.import_module(neuronx_module_path)
|
||||
neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
|
||||
neuron_config = neuronx_model_cls.get_neuron_config_cls()(
|
||||
**kwargs['neuron_config'])
|
||||
config = neuronx_model_cls.get_config_cls()(
|
||||
neuron_config,
|
||||
load_config=load_pretrained_config(model_name_or_path))
|
||||
|
||||
draft_neuron_config = copy.deepcopy(config.neuron_config)
|
||||
if not config.neuron_config.enable_eagle_speculation:
|
||||
draft_neuron_config.speculation_length = 0
|
||||
draft_neuron_config.trace_tokengen_model = True
|
||||
draft_neuron_config.enable_fused_speculation = False
|
||||
if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
|
||||
None):
|
||||
draft_neuron_config.modules_to_not_convert = (
|
||||
draft_neuron_config.draft_model_modules_to_not_convert)
|
||||
if config.neuron_config.enable_eagle_speculation:
|
||||
draft_neuron_config.is_eagle_draft = True
|
||||
draft_neuron_config.sequence_parallel_enabled = False
|
||||
draft_config = neuronx_model_cls.get_config_cls()(
|
||||
draft_neuron_config,
|
||||
load_config=load_pretrained_config(draft_model_name_or_path))
|
||||
fused_spec_config = (FusedSpecNeuronConfig(
|
||||
neuronx_model_cls._model_cls,
|
||||
draft_config=draft_config,
|
||||
draft_model_path=draft_model_name_or_path))
|
||||
config.fused_spec_config = fused_spec_config
|
||||
self.config.neuron_config = neuron_config
|
||||
|
||||
hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
|
||||
usedforsecurity=False).hexdigest()
|
||||
if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
|
||||
compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
|
||||
elif os.path.exists(model_name_or_path):
|
||||
compiled_model_path = os.path.join(model_name_or_path,
|
||||
"neuron-compiled-artifacts",
|
||||
hashed_config)
|
||||
shutil.rmtree(compiled_model_path, ignore_errors=True)
|
||||
else:
|
||||
compiled_model_path = os.path.join("local-models",
|
||||
model_name_or_path,
|
||||
"neuron-compiled-artifacts",
|
||||
hashed_config)
|
||||
shutil.rmtree(compiled_model_path, ignore_errors=True)
|
||||
try:
|
||||
self.model = neuronx_model_cls(compiled_model_path)
|
||||
override_neuron_config = kwargs["override_neuron_config"]
|
||||
for k, v in override_neuron_config.items():
|
||||
setattr(self.model.config.neuron_config, k, v)
|
||||
self.model.load(compiled_model_path)
|
||||
return
|
||||
except (FileNotFoundError, ValueError) as e:
|
||||
logger.warning("Exception: %s", e)
|
||||
logger.warning("Failed to load the model from %s Recompiling...",
|
||||
compiled_model_path)
|
||||
if not os.path.exists(model_name_or_path):
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
|
||||
saved_path = os.path.join("local-models", model_name_or_path)
|
||||
hf_model.save_pretrained(saved_path)
|
||||
model_name_or_path = saved_path
|
||||
if not os.path.exists(draft_model_name_or_path):
|
||||
if draft_model_name_or_path != model_name_or_path:
|
||||
hf_model = AutoModelForCausalLM.from_pretrained(
|
||||
draft_model_name_or_path)
|
||||
saved_path = os.path.join("local-models",
|
||||
draft_model_name_or_path)
|
||||
hf_model.save_pretrained(saved_path)
|
||||
draft_model_name_or_path = saved_path
|
||||
else:
|
||||
draft_model_name_or_path = model_name_or_path
|
||||
config.fused_spec_config.draft_model_path = draft_model_name_or_path
|
||||
self.model = neuronx_model_cls(model_name_or_path, config)
|
||||
self.model.compile(compiled_model_path)
|
||||
self.model.load(compiled_model_path)
|
||||
|
||||
|
||||
def _get_model_architecture(config: PretrainedConfig) -> str:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _NEURON_SUPPORTED_MODELS:
|
||||
return arch
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported on Neuron "
|
||||
f"for now. Supported architectures: "
|
||||
f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
|
||||
|
||||
|
||||
def _get_default_neuron_config(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
lora_serving_config: LoraServingConfig):
|
||||
"""Generate a neuron config based on vllm config args."""
|
||||
on_device_sampling_config = OnDeviceSamplingConfig(dynamic=True,
|
||||
deterministic=False)
|
||||
batch_size = scheduler_config.max_num_seqs
|
||||
|
||||
neuron_config = dict(
|
||||
tp_degree=parallel_config.tensor_parallel_size,
|
||||
ctx_batch_size=1,
|
||||
batch_size=batch_size,
|
||||
max_context_length=scheduler_config.max_model_len,
|
||||
seq_len=scheduler_config.max_model_len,
|
||||
enable_bucketing=True,
|
||||
is_continuous_batching=True,
|
||||
quantized=False,
|
||||
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
padding_side="right",
|
||||
on_device_sampling_config=on_device_sampling_config,
|
||||
sequence_parallel_enabled=True,
|
||||
lora_serving_config=lora_serving_config)
|
||||
return neuron_config
|
||||
|
||||
|
||||
def _get_default_speculation_config(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
speculation_config: SpeculativeConfig):
|
||||
"""Generate a neuron config for speculative decoding based on vllm config
|
||||
args."""
|
||||
neuron_config = dict(
|
||||
tp_degree=parallel_config.tensor_parallel_size,
|
||||
ctx_batch_size=1,
|
||||
batch_size=scheduler_config.max_num_seqs,
|
||||
max_context_length=scheduler_config.max_model_len,
|
||||
seq_len=scheduler_config.max_model_len,
|
||||
speculation_length=speculation_config.num_speculative_tokens,
|
||||
trace_tokengen_model=False,
|
||||
enable_fused_speculation=True,
|
||||
enable_bucketing=True,
|
||||
is_continuous_batching=True,
|
||||
quantized=False,
|
||||
torch_dtype=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
|
||||
on_device_sampling_config=dict(
|
||||
top_k=1,
|
||||
do_sample=False,
|
||||
))
|
||||
return neuron_config
|
||||
|
||||
|
||||
def _get_neuron_config_after_override(default_neuron_config,
|
||||
overridden_neuron_config):
|
||||
"""Update default neuron config values with override args"""
|
||||
overridden_neuron_config = overridden_neuron_config or {}
|
||||
default_neuron_config.update(overridden_neuron_config)
|
||||
return default_neuron_config
|
||||
|
||||
|
||||
def get_neuron_model(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
lora_serving_config: LoraServingConfig) -> nn.Module:
|
||||
"""Initializes a neuron-optimized model for inference."""
|
||||
model_arch = _get_model_architecture(model_config.hf_config)
|
||||
if model_arch == "MllamaForConditionalGeneration":
|
||||
model = NeuronMllamaForCausalLM(model_config.hf_config)
|
||||
else:
|
||||
model = NeuronCausalLM(model_config.hf_config)
|
||||
default_neuron_config_args = _get_default_neuron_config(
|
||||
model_config, parallel_config, scheduler_config, lora_serving_config)
|
||||
neuron_config = _get_neuron_config_after_override(
|
||||
default_neuron_config_args, model_config.override_neuron_config)
|
||||
|
||||
override_neuron_config = model_config.override_neuron_config
|
||||
model.load_weights(model_config.model,
|
||||
neuron_config=neuron_config,
|
||||
override_neuron_config=override_neuron_config)
|
||||
return model.eval()
|
||||
|
||||
|
||||
def get_neuron_speculation_model(model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
speculation_config: SpeculativeConfig):
|
||||
"""Initializes a neuron-optimized speculation model for inference.
|
||||
|
||||
This model handles speculation using both a draft model and an EAGLE draft.
|
||||
"""
|
||||
model = NeuronSpeculationCausalLM(model_config.hf_config)
|
||||
default_neuron_config_args = _get_default_speculation_config(
|
||||
model_config, parallel_config, scheduler_config, speculation_config)
|
||||
neuron_config = _get_neuron_config_after_override(
|
||||
default_neuron_config_args, model_config.override_neuron_config)
|
||||
|
||||
override_neuron_config = model_config.override_neuron_config
|
||||
model.load_weights(model_config.model,
|
||||
speculation_config.draft_model_config.model,
|
||||
neuron_config=neuron_config,
|
||||
override_neuron_config=override_neuron_config)
|
||||
return model.eval()
|
||||
109
vllm/model_executor/model_loader/runai_streamer_loader.py
Normal file
109
vllm/model_executor/model_loader/runai_streamer_loader.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_safetensors_index_file_from_hf, download_weights_from_hf,
|
||||
runai_safetensors_weights_iterator)
|
||||
from vllm.transformers_utils.s3_utils import glob as s3_glob
|
||||
from vllm.transformers_utils.utils import is_s3
|
||||
|
||||
|
||||
class RunaiModelStreamerLoader(BaseModelLoader):
|
||||
"""
|
||||
Model loader that can load safetensors
|
||||
files from local FS or S3 bucket.
|
||||
"""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if load_config.model_loader_extra_config:
|
||||
extra_config = load_config.model_loader_extra_config
|
||||
|
||||
if ("concurrency" in extra_config
|
||||
and isinstance(extra_config.get("concurrency"), int)):
|
||||
os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
|
||||
extra_config.get("concurrency"))
|
||||
|
||||
if ("memory_limit" in extra_config
|
||||
and isinstance(extra_config.get("memory_limit"), int)):
|
||||
os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
|
||||
extra_config.get("memory_limit"))
|
||||
|
||||
runai_streamer_s3_endpoint = os.getenv(
|
||||
'RUNAI_STREAMER_S3_ENDPOINT')
|
||||
aws_endpoint_url = os.getenv('AWS_ENDPOINT_URL')
|
||||
if (runai_streamer_s3_endpoint is None
|
||||
and aws_endpoint_url is not None):
|
||||
os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str]) -> list[str]:
|
||||
"""Prepare weights for the model.
|
||||
|
||||
If the model is not local, it will be downloaded."""
|
||||
|
||||
is_s3_path = is_s3(model_name_or_path)
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
safetensors_pattern = "*.safetensors"
|
||||
index_file = SAFE_WEIGHTS_INDEX_NAME
|
||||
|
||||
hf_folder = (model_name_or_path if
|
||||
(is_local or is_s3_path) else download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
[safetensors_pattern],
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
))
|
||||
if is_s3_path:
|
||||
hf_weights_files = s3_glob(path=hf_folder,
|
||||
allow_pattern=[safetensors_pattern])
|
||||
else:
|
||||
hf_weights_files = glob.glob(
|
||||
os.path.join(hf_folder, safetensors_pattern))
|
||||
|
||||
if not is_local and not is_s3_path:
|
||||
download_safetensors_index_file_from_hf(
|
||||
model_name_or_path, index_file, self.load_config.download_dir,
|
||||
revision)
|
||||
|
||||
if not hf_weights_files:
|
||||
raise RuntimeError(
|
||||
f"Cannot find any safetensors model weights with "
|
||||
f"`{model_name_or_path}`")
|
||||
|
||||
return hf_weights_files
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, model_or_path: str,
|
||||
revision: str) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Get an iterator for the model weights based on the load format."""
|
||||
hf_weights_files = self._prepare_weights(model_or_path, revision)
|
||||
return runai_safetensors_weights_iterator(
|
||||
hf_weights_files,
|
||||
self.load_config.use_tqdm_on_load,
|
||||
)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
"""Download model if necessary"""
|
||||
self._prepare_weights(model_config.model, model_config.revision)
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
"""Load weights into a model."""
|
||||
model_weights = model_config.model
|
||||
if hasattr(model_config, "model_weights"):
|
||||
model_weights = model_config.model_weights
|
||||
model.load_weights(
|
||||
self._get_weights_iterator(model_weights, model_config.revision))
|
||||
201
vllm/model_executor/model_loader/sharded_state_loader.py
Normal file
201
vllm/model_executor/model_loader/sharded_state_loader.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import collections
|
||||
import glob
|
||||
import os
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.weight_utils import (
|
||||
download_weights_from_hf, runai_safetensors_weights_iterator)
|
||||
from vllm.transformers_utils.s3_utils import glob as s3_glob
|
||||
from vllm.transformers_utils.utils import is_s3
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class ShardedStateLoader(BaseModelLoader):
|
||||
"""
|
||||
Model loader that directly loads each worker's model state dict, which
|
||||
enables a fast load path for large tensor-parallel models where each worker
|
||||
only needs to read its own shard rather than the entire checkpoint. See
|
||||
`examples/offline_inference/save_sharded_state.py` for creating a sharded
|
||||
checkpoint.
|
||||
"""
|
||||
|
||||
DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
|
||||
|
||||
def __init__(self,
|
||||
load_config: LoadConfig,
|
||||
runai_model_streamer: bool = False):
|
||||
super().__init__(load_config)
|
||||
|
||||
self.runai_model_streamer = runai_model_streamer
|
||||
extra_config = ({} if load_config.model_loader_extra_config is None
|
||||
else load_config.model_loader_extra_config.copy())
|
||||
self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
|
||||
if extra_config:
|
||||
raise ValueError(f"Unexpected extra config keys for load format "
|
||||
f"{load_config.load_format}: "
|
||||
f"{load_config.model_loader_extra_config.keys()}")
|
||||
|
||||
@staticmethod
|
||||
def _filter_subtensors(
|
||||
tensors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Filter out all tensors that share the same memory or a subset of the
|
||||
memory of another tensor.
|
||||
"""
|
||||
same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = (
|
||||
collections.defaultdict(list))
|
||||
for key, tensor in tensors.items():
|
||||
if tensor.numel():
|
||||
ptr = tensor.untyped_storage().data_ptr()
|
||||
same_storage_groups[tensor.device, ptr].append((key, tensor))
|
||||
|
||||
def get_end_ptr(tensor: torch.Tensor) -> int:
|
||||
return tensor.view(-1)[-1].data_ptr() + tensor.element_size()
|
||||
|
||||
result: dict[str, torch.Tensor] = {}
|
||||
for group in same_storage_groups.values():
|
||||
for k, t in group:
|
||||
a, b = t.data_ptr(), get_end_ptr(t)
|
||||
for k2, t2 in group:
|
||||
if not t2.is_contiguous():
|
||||
continue
|
||||
a2, b2 = t2.data_ptr(), get_end_ptr(t2)
|
||||
if a < a2 or b2 < b:
|
||||
continue
|
||||
if a2 < a or b < b2 or not t.is_contiguous():
|
||||
break # t2 covers strictly more memory than t.
|
||||
if k2 < k:
|
||||
# Same tensors, keep the one with the smaller key.
|
||||
break
|
||||
else:
|
||||
result[k] = t
|
||||
return result
|
||||
|
||||
def _prepare_weights(self, model_name_or_path: str,
|
||||
revision: Optional[str]):
|
||||
if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
|
||||
return model_name_or_path
|
||||
else:
|
||||
allow_patterns = ["*.safetensors"]
|
||||
return download_weights_from_hf(
|
||||
model_name_or_path,
|
||||
self.load_config.download_dir,
|
||||
allow_patterns,
|
||||
revision,
|
||||
ignore_patterns=self.load_config.ignore_patterns,
|
||||
)
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self._prepare_weights(model_config.model, model_config.revision)
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
model_weights = model_config.model
|
||||
if hasattr(model_config, "model_weights"):
|
||||
model_weights = model_config.model_weights
|
||||
local_model_path = model_weights
|
||||
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
pattern = os.path.join(
|
||||
local_model_path,
|
||||
self.pattern.format(rank=rank, part="*"),
|
||||
)
|
||||
|
||||
filepaths = []
|
||||
if is_s3(local_model_path):
|
||||
file_pattern = f"*{self.pattern.format(rank=rank, part=' * ')}"
|
||||
filepaths = s3_glob(path=local_model_path,
|
||||
allow_pattern=[file_pattern])
|
||||
else:
|
||||
filepaths = glob.glob(pattern)
|
||||
if not filepaths:
|
||||
# TODO: support un-sharded checkpoints too
|
||||
raise ValueError(
|
||||
f"Could not find checkpoint files '{pattern}', only "
|
||||
f"pre-sharded checkpoints are currently supported!")
|
||||
state_dict = self._filter_subtensors(model.state_dict())
|
||||
for key, tensor in self.iterate_over_files(filepaths):
|
||||
# If loading with LoRA enabled, additional padding may
|
||||
# be added to certain parameters. We only load into a
|
||||
# narrowed view of the parameter data.
|
||||
param_data = state_dict[key].data
|
||||
param_shape = state_dict[key].shape
|
||||
for dim, size in enumerate(tensor.shape):
|
||||
if size < param_shape[dim]:
|
||||
param_data = param_data.narrow(dim, 0, size)
|
||||
if tensor.shape != param_shape:
|
||||
logger.warning(
|
||||
"loading tensor of shape %s into "
|
||||
"parameter '%s' of shape %s",
|
||||
tensor.shape,
|
||||
key,
|
||||
param_shape,
|
||||
)
|
||||
param_data.copy_(tensor)
|
||||
state_dict.pop(key)
|
||||
if state_dict:
|
||||
raise ValueError(
|
||||
f"Missing keys {tuple(state_dict)} in loaded state!")
|
||||
|
||||
def iterate_over_files(
|
||||
self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
if self.runai_model_streamer:
|
||||
yield from runai_safetensors_weights_iterator(paths, True)
|
||||
else:
|
||||
from safetensors.torch import safe_open
|
||||
for path in paths:
|
||||
with safe_open(path, framework="pt") as f:
|
||||
for key in f.keys(): # noqa: SIM118
|
||||
tensor = f.get_tensor(key)
|
||||
yield key, tensor
|
||||
|
||||
@staticmethod
|
||||
def save_model(
|
||||
model: torch.nn.Module,
|
||||
path: str,
|
||||
pattern: Optional[str] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> None:
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
if pattern is None:
|
||||
pattern = ShardedStateLoader.DEFAULT_PATTERN
|
||||
rank = get_tensor_model_parallel_rank()
|
||||
part_idx = 0
|
||||
total_size = 0
|
||||
state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
|
||||
state_dict_part: dict[str, torch.Tensor] = {}
|
||||
for key, tensor in state_dict.items():
|
||||
param_size = tensor.nelement() * tensor.element_size()
|
||||
if max_size is not None and total_size + param_size > max_size:
|
||||
filename = pattern.format(rank=rank, part=part_idx)
|
||||
save_file(
|
||||
state_dict_part,
|
||||
os.path.join(path, filename),
|
||||
)
|
||||
part_idx += 1
|
||||
total_size = 0
|
||||
state_dict_part = {}
|
||||
state_dict_part[key] = tensor
|
||||
total_size += param_size
|
||||
if len(state_dict_part) > 0:
|
||||
filename = pattern.format(rank=rank, part=part_idx)
|
||||
save_file(
|
||||
state_dict_part,
|
||||
os.path.join(path, filename),
|
||||
)
|
||||
600
vllm/model_executor/model_loader/tensorizer.py
Normal file
600
vllm/model_executor/model_loader/tensorizer.py
Normal file
@@ -0,0 +1,600 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
import contextlib
|
||||
import contextvars
|
||||
import dataclasses
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import Any, BinaryIO, Optional, Union
|
||||
|
||||
import regex as re
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding)
|
||||
from vllm.utils import FlexibleArgumentParser, PlaceholderModule
|
||||
|
||||
try:
|
||||
from tensorizer import (DecryptionParams, EncryptionParams,
|
||||
TensorDeserializer, TensorSerializer)
|
||||
from tensorizer.stream_io import open_stream
|
||||
from tensorizer.utils import (convert_bytes, get_mem_usage,
|
||||
no_init_or_tensor)
|
||||
|
||||
_read_stream, _write_stream = (partial(
|
||||
open_stream,
|
||||
mode=mode,
|
||||
) for mode in ("rb", "wb+"))
|
||||
except ImportError:
|
||||
tensorizer = PlaceholderModule("tensorizer")
|
||||
DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
|
||||
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
|
||||
TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer")
|
||||
TensorSerializer = tensorizer.placeholder_attr("TensorSerializer")
|
||||
open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
|
||||
convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
|
||||
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
|
||||
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
|
||||
|
||||
_read_stream = tensorizer.placeholder_attr("_read_stream")
|
||||
_write_stream = tensorizer.placeholder_attr("_write_stream")
|
||||
|
||||
__all__ = [
|
||||
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
|
||||
'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
|
||||
'no_init_or_tensor', 'TensorizerConfig'
|
||||
]
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class MetaTensorMode(TorchDispatchMode):
|
||||
|
||||
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
||||
kwargs = kwargs or {}
|
||||
|
||||
if func._schema.name == "aten::empty" and "device" not in kwargs:
|
||||
kwargs["device"] = "meta"
|
||||
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
def meta_tensor_mode(loading_code=None, ):
|
||||
|
||||
if loading_code is None:
|
||||
return _NoInitOrTensorImpl.context_manager()
|
||||
elif callable(loading_code):
|
||||
with _NoInitOrTensorImpl.context_manager():
|
||||
return loading_code()
|
||||
else:
|
||||
raise TypeError(
|
||||
"expected a callable to evaluate,"
|
||||
" or None if being used as a context manager;"
|
||||
f' got an object of type "{type(loading_code).__name__}" instead.')
|
||||
|
||||
|
||||
class _NoInitOrTensorImpl:
|
||||
_MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
|
||||
_MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)
|
||||
|
||||
is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active",
|
||||
default=False)
|
||||
_count_active: int = 0
|
||||
_count_active_lock = threading.Lock()
|
||||
|
||||
@classmethod
|
||||
@contextlib.contextmanager
|
||||
def context_manager(cls):
|
||||
if cls.is_active.get():
|
||||
yield
|
||||
return
|
||||
|
||||
with cls._count_active_lock:
|
||||
cls._count_active += 1
|
||||
if cls._count_active == 1:
|
||||
for mod in cls._MODULES:
|
||||
mod.reset_parameters = cls._disable(mod.reset_parameters)
|
||||
|
||||
reset_token = cls.is_active.set(True)
|
||||
|
||||
try:
|
||||
with MetaTensorMode():
|
||||
yield
|
||||
finally:
|
||||
cls.is_active.reset(reset_token)
|
||||
with cls._count_active_lock:
|
||||
cls._count_active -= 1
|
||||
if cls._count_active == 0:
|
||||
for mod, original in cls._MODULE_ORIGINALS:
|
||||
mod.reset_parameters = original
|
||||
|
||||
@staticmethod
|
||||
def _disable(func):
|
||||
|
||||
def wrapper(*args, **kwargs):
|
||||
if not _NoInitOrTensorImpl.is_active.get():
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerConfig:
|
||||
tensorizer_uri: Union[str, None] = None
|
||||
vllm_tensorized: Optional[bool] = False
|
||||
verify_hash: Optional[bool] = False
|
||||
num_readers: Optional[int] = None
|
||||
encryption_keyfile: Optional[str] = None
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
model_class: Optional[type[torch.nn.Module]] = None
|
||||
hf_config: Optional[PretrainedConfig] = None
|
||||
dtype: Optional[Union[str, torch.dtype]] = None
|
||||
lora_dir: Optional[str] = None
|
||||
_is_sharded: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# check if the configuration is for a sharded vLLM model
|
||||
self._is_sharded = isinstance(self.tensorizer_uri, str) \
|
||||
and re.search(r'%0\dd', self.tensorizer_uri) is not None
|
||||
if not self.tensorizer_uri and not self.lora_dir:
|
||||
raise ValueError("tensorizer_uri must be provided.")
|
||||
if not self.tensorizer_uri and self.lora_dir:
|
||||
self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
|
||||
assert self.tensorizer_uri is not None, ("tensorizer_uri must be "
|
||||
"provided.")
|
||||
self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
|
||||
self.lora_dir = self.tensorizer_dir
|
||||
|
||||
@classmethod
|
||||
def as_dict(cls, *args, **kwargs) -> dict[str, Any]:
|
||||
cfg = TensorizerConfig(*args, **kwargs)
|
||||
return dataclasses.asdict(cfg)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return dataclasses.asdict(self)
|
||||
|
||||
def _construct_tensorizer_args(self) -> "TensorizerArgs":
|
||||
tensorizer_args = {
|
||||
"tensorizer_uri": self.tensorizer_uri,
|
||||
"vllm_tensorized": self.vllm_tensorized,
|
||||
"verify_hash": self.verify_hash,
|
||||
"num_readers": self.num_readers,
|
||||
"encryption_keyfile": self.encryption_keyfile,
|
||||
"s3_access_key_id": self.s3_access_key_id,
|
||||
"s3_secret_access_key": self.s3_secret_access_key,
|
||||
"s3_endpoint": self.s3_endpoint,
|
||||
}
|
||||
return TensorizerArgs(**tensorizer_args) # type: ignore
|
||||
|
||||
def verify_with_parallel_config(
|
||||
self,
|
||||
parallel_config: "ParallelConfig",
|
||||
) -> None:
|
||||
if parallel_config.tensor_parallel_size > 1 \
|
||||
and not self._is_sharded:
|
||||
raise ValueError(
|
||||
"For a sharded model, tensorizer_uri should include a"
|
||||
" string format template like '%04d' to be formatted"
|
||||
" with the rank of the shard")
|
||||
|
||||
def verify_with_model_config(self, model_config: "ModelConfig") -> None:
|
||||
if (model_config.quantization is not None
|
||||
and self.tensorizer_uri is not None):
|
||||
logger.warning(
|
||||
"Loading a model using Tensorizer with quantization on vLLM"
|
||||
" is unstable and may lead to errors.")
|
||||
|
||||
def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
|
||||
if tensorizer_args is None:
|
||||
tensorizer_args = self._construct_tensorizer_args()
|
||||
|
||||
return open_stream(self.tensorizer_uri,
|
||||
**tensorizer_args.stream_params)
|
||||
|
||||
|
||||
@dataclass
|
||||
class TensorizerArgs:
|
||||
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
|
||||
bytes, os.PathLike, int]
|
||||
vllm_tensorized: Optional[bool] = False
|
||||
verify_hash: Optional[bool] = False
|
||||
num_readers: Optional[int] = None
|
||||
encryption_keyfile: Optional[str] = None
|
||||
s3_access_key_id: Optional[str] = None
|
||||
s3_secret_access_key: Optional[str] = None
|
||||
s3_endpoint: Optional[str] = None
|
||||
"""
|
||||
Args for the TensorizerAgent class. These are used to configure the behavior
|
||||
of the TensorDeserializer when loading tensors from a serialized model.
|
||||
|
||||
Args:
|
||||
tensorizer_uri: Path to serialized model tensors. Can be a local file
|
||||
path or a S3 URI. This is a required field unless lora_dir is
|
||||
provided and the config is meant to be used for the
|
||||
`tensorize_lora_adapter` function.
|
||||
vllm_tensorized: If True, indicates that the serialized model is a
|
||||
vLLM model. This is used to determine the behavior of the
|
||||
TensorDeserializer when loading tensors from a serialized model.
|
||||
It is far faster to deserialize a vLLM model as it utilizes
|
||||
tensorizer's optimized GPU loading. Note that this is now
|
||||
deprecated, as serialized vLLM models are now automatically
|
||||
inferred as vLLM models.
|
||||
verify_hash: If True, the hashes of each tensor will be verified against
|
||||
the hashes stored in the metadata. A `HashMismatchError` will be
|
||||
raised if any of the hashes do not match.
|
||||
num_readers: Controls how many threads are allowed to read concurrently
|
||||
from the source file. Default is `None`, which will dynamically set
|
||||
the number of readers based on the number of available
|
||||
resources and model size. This greatly increases performance.
|
||||
encryption_keyfile: File path to a binary file containing a
|
||||
binary key to use for decryption. `None` (the default) means
|
||||
no decryption. See the example script in
|
||||
examples/others/tensorize_vllm_model.py.
|
||||
s3_access_key_id: The access key for the S3 bucket. Can also be set via
|
||||
the S3_ACCESS_KEY_ID environment variable.
|
||||
s3_secret_access_key: The secret access key for the S3 bucket. Can also
|
||||
be set via the S3_SECRET_ACCESS_KEY environment variable.
|
||||
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
|
||||
S3_ENDPOINT_URL environment variable.
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
self.file_obj = self.tensorizer_uri
|
||||
self.s3_access_key_id = self.s3_access_key_id or envs.S3_ACCESS_KEY_ID
|
||||
self.s3_secret_access_key = (self.s3_secret_access_key
|
||||
or envs.S3_SECRET_ACCESS_KEY)
|
||||
self.s3_endpoint = self.s3_endpoint or envs.S3_ENDPOINT_URL
|
||||
self.stream_params = {
|
||||
"s3_access_key_id": self.s3_access_key_id,
|
||||
"s3_secret_access_key": self.s3_secret_access_key,
|
||||
"s3_endpoint": self.s3_endpoint,
|
||||
}
|
||||
|
||||
self.deserializer_params = {
|
||||
"verify_hash": self.verify_hash,
|
||||
"encryption": self.encryption_keyfile,
|
||||
"num_readers": self.num_readers
|
||||
}
|
||||
|
||||
if self.encryption_keyfile:
|
||||
with open_stream(
|
||||
self.encryption_keyfile,
|
||||
**self.stream_params,
|
||||
) as stream:
|
||||
key = stream.read()
|
||||
decryption_params = DecryptionParams.from_key(key)
|
||||
self.deserializer_params['encryption'] = decryption_params
|
||||
|
||||
@staticmethod
|
||||
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
|
||||
"""Tensorizer CLI arguments"""
|
||||
|
||||
# Tensorizer options arg group
|
||||
group = parser.add_argument_group(
|
||||
'tensorizer options',
|
||||
description=('Options for configuring the behavior of the'
|
||||
' tensorizer deserializer when '
|
||||
'load_format=tensorizer is specified when '
|
||||
'initializing an LLMEngine, either via the CLI '
|
||||
'when running the vLLM OpenAI inference server '
|
||||
'with a JSON string passed to '
|
||||
'--model-loader-extra-config or as arguments given '
|
||||
'to TensorizerConfig when passed to '
|
||||
'model_loader_extra_config in the constructor '
|
||||
'for LLMEngine.'))
|
||||
|
||||
group.add_argument(
|
||||
"--tensorizer-uri",
|
||||
type=str,
|
||||
help="Path to serialized model tensors. Can be a local file path,"
|
||||
" or an HTTP(S) or S3 URI.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--verify-hash",
|
||||
action="store_true",
|
||||
help="If enabled, the hashes of each tensor will be verified"
|
||||
" against the hashes stored in the file metadata. An exception"
|
||||
" will be raised if any of the hashes do not match.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--encryption-keyfile",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The file path to a binary file containing a binary key to "
|
||||
"use for decryption. Can be a file path or S3 network URI.")
|
||||
group.add_argument(
|
||||
"--num-readers",
|
||||
default=None,
|
||||
type=int,
|
||||
help="Controls how many threads are allowed to read concurrently "
|
||||
"from the source file. Default is `None`, which will dynamically "
|
||||
"set the number of readers based on the available resources "
|
||||
"and model size. This greatly increases performance.")
|
||||
group.add_argument(
|
||||
"--s3-access-key-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The access key for the S3 bucket. Can also be set via the "
|
||||
"S3_ACCESS_KEY_ID environment variable.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--s3-secret-access-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The secret access key for the S3 bucket. Can also be set via "
|
||||
"the S3_SECRET_ACCESS_KEY environment variable.",
|
||||
)
|
||||
group.add_argument(
|
||||
"--s3-endpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The endpoint for the S3 bucket. Can also be set via the "
|
||||
"S3_ENDPOINT_URL environment variable.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
|
||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||
tensorizer_args = cls(**{
|
||||
attr: getattr(args, attr)
|
||||
for attr in attrs if hasattr(args, attr)
|
||||
})
|
||||
return tensorizer_args
|
||||
|
||||
|
||||
def _check_tensors_on_meta_device(model: nn.Module) -> None:
|
||||
for tensor in model.state_dict().values():
|
||||
if tensor.device.type == 'meta':
|
||||
raise ValueError(
|
||||
"The serialized model contains tensors on the meta device,"
|
||||
" indicating that some tensors were not loaded properly."
|
||||
" Please check that the parameters of the model being"
|
||||
" specified match that of the serialized model, such as"
|
||||
" its quantization.")
|
||||
|
||||
|
||||
def _resize_lora_embeddings(model: nn.Module):
|
||||
"""Modify LoRA embedding layers to use bigger tensors
|
||||
to allow for adapter added tokens."""
|
||||
for child in model.modules():
|
||||
if (isinstance(child, VocabParallelEmbedding) and child.weight.shape[0]
|
||||
< child.num_embeddings_per_partition):
|
||||
new_weight = torch.empty(child.num_embeddings_per_partition,
|
||||
child.embedding_dim,
|
||||
dtype=child.weight.dtype,
|
||||
device=child.weight.device)
|
||||
new_weight[:child.weight.shape[0]].copy_(child.weight.data)
|
||||
new_weight[child.weight.shape[0]:].fill_(0)
|
||||
child.weight.data = new_weight
|
||||
|
||||
|
||||
def init_tensorizer_model(tensorizer_config: TensorizerConfig,
|
||||
vllm_config: VllmConfig) -> nn.Module:
|
||||
assert tensorizer_config.hf_config is not None
|
||||
model_args = tensorizer_config.hf_config
|
||||
model_args.torch_dtype = tensorizer_config.dtype
|
||||
assert tensorizer_config.model_class is not None
|
||||
# TODO: Do we need to consider old-style model class?
|
||||
with meta_tensor_mode(), set_current_vllm_config(vllm_config,
|
||||
check_compile=True):
|
||||
return tensorizer_config.model_class(vllm_config=vllm_config)
|
||||
|
||||
|
||||
def deserialize_tensorizer_model(model: nn.Module,
|
||||
tensorizer_config: TensorizerConfig) -> None:
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
before_mem = get_mem_usage()
|
||||
start = time.perf_counter()
|
||||
with _read_stream(
|
||||
tensorizer_config.tensorizer_uri,
|
||||
**tensorizer_args.stream_params) as stream, TensorDeserializer(
|
||||
stream,
|
||||
dtype=tensorizer_config.dtype,
|
||||
device=f'cuda:{torch.cuda.current_device()}',
|
||||
**tensorizer_args.deserializer_params) as deserializer:
|
||||
deserializer.load_into_module(model)
|
||||
end = time.perf_counter()
|
||||
|
||||
total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
|
||||
duration = end - start
|
||||
per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
|
||||
after_mem = get_mem_usage()
|
||||
deserializer.close()
|
||||
logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
|
||||
end - start, per_second)
|
||||
logger.info("Memory usage before: %s", before_mem)
|
||||
logger.info("Memory usage after: %s", after_mem)
|
||||
|
||||
_check_tensors_on_meta_device(model)
|
||||
_resize_lora_embeddings(model)
|
||||
del model.vllm_tensorized_marker
|
||||
|
||||
|
||||
def tensorizer_weights_iterator(
|
||||
tensorizer_args: "TensorizerArgs"
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
logger.warning("Deserializing HuggingFace models is not optimized for "
|
||||
"loading on vLLM, as tensorizer is forced to load to CPU. "
|
||||
"Consider deserializing a vLLM model instead for faster "
|
||||
"load times. See the "
|
||||
"examples/others/tensorize_vllm_model.py example script "
|
||||
"for serializing vLLM models.")
|
||||
|
||||
deserializer_args = tensorizer_args.deserializer_params
|
||||
stream_params = tensorizer_args.stream_params
|
||||
stream = open_stream(tensorizer_args.tensorizer_uri, **stream_params)
|
||||
with TensorDeserializer(stream, **deserializer_args,
|
||||
device="cpu") as state:
|
||||
yield from state.items()
|
||||
del state
|
||||
|
||||
|
||||
def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
|
||||
"""
|
||||
Infer if the model is a vLLM model by checking the weights for
|
||||
a vLLM tensorized marker.
|
||||
|
||||
Args:
|
||||
tensorizer_config: The TensorizerConfig object containing the
|
||||
tensorizer_uri to the serialized model.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is a vLLM model, False otherwise.
|
||||
"""
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
deserializer = TensorDeserializer(open_stream(
|
||||
tensorizer_args.tensorizer_uri, **tensorizer_args.stream_params),
|
||||
**tensorizer_args.deserializer_params,
|
||||
lazy_load=True)
|
||||
if tensorizer_config.vllm_tensorized:
|
||||
logger.warning(
|
||||
"Please note that newly serialized vLLM models are automatically "
|
||||
"inferred as vLLM models, so setting vllm_tensorized=True is "
|
||||
"only necessary for models serialized prior to this change.")
|
||||
return True
|
||||
return ".vllm_tensorized_marker" in deserializer
|
||||
|
||||
|
||||
def serialize_vllm_model(
|
||||
model: nn.Module,
|
||||
tensorizer_config: TensorizerConfig,
|
||||
) -> nn.Module:
|
||||
model.register_parameter(
|
||||
"vllm_tensorized_marker",
|
||||
nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False))
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
|
||||
encryption_params = None
|
||||
if (keyfile := tensorizer_config.encryption_keyfile) is not None:
|
||||
with open(keyfile, "rb") as f:
|
||||
key = f.read()
|
||||
encryption_params = EncryptionParams(key=key)
|
||||
|
||||
output_file = tensorizer_args.tensorizer_uri
|
||||
if tensorizer_config._is_sharded:
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
output_file = output_file % get_tensor_model_parallel_rank()
|
||||
|
||||
with _write_stream(output_file, **tensorizer_args.stream_params) as stream:
|
||||
serializer = TensorSerializer(stream, encryption=encryption_params)
|
||||
serializer.write_module(model)
|
||||
serializer.close()
|
||||
logger.info("Successfully serialized model to %s", str(output_file))
|
||||
return model
|
||||
|
||||
|
||||
def tensorize_vllm_model(engine_args: EngineArgs,
|
||||
tensorizer_config: TensorizerConfig,
|
||||
generate_keyfile: bool = True):
|
||||
"""Utility to load a model and then serialize it with Tensorizer
|
||||
|
||||
Intended to be used separately from running a vLLM server since it
|
||||
creates its own Engine instance.
|
||||
"""
|
||||
engine_config = engine_args.create_engine_config()
|
||||
tensorizer_config.verify_with_model_config(engine_config.model_config)
|
||||
tensorizer_config.verify_with_parallel_config(
|
||||
engine_config.parallel_config)
|
||||
|
||||
# generate the encryption key before creating the engine to support sharding
|
||||
if generate_keyfile and (keyfile :=
|
||||
tensorizer_config.encryption_keyfile) is not None:
|
||||
encryption_params = EncryptionParams.random()
|
||||
with _write_stream(
|
||||
keyfile,
|
||||
s3_access_key_id=tensorizer_config.s3_access_key_id,
|
||||
s3_secret_access_key=tensorizer_config.s3_secret_access_key,
|
||||
s3_endpoint=tensorizer_config.s3_endpoint,
|
||||
) as stream:
|
||||
stream.write(encryption_params.key)
|
||||
|
||||
from vllm import LLMEngine
|
||||
from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
engine.model_executor.collective_rpc(
|
||||
"save_tensorized_model",
|
||||
kwargs=dict(tensorizer_config=tensorizer_config),
|
||||
)
|
||||
else:
|
||||
engine = V1LLMEngine.from_vllm_config(engine_config)
|
||||
engine.collective_rpc(
|
||||
"save_tensorized_model",
|
||||
kwargs=dict(tensorizer_config=tensorizer_config),
|
||||
)
|
||||
|
||||
|
||||
def tensorize_lora_adapter(lora_path: str,
|
||||
tensorizer_config: TensorizerConfig):
|
||||
"""
|
||||
Uses tensorizer to serialize a LoRA adapter. Assumes that the files
|
||||
needed to load a LoRA adapter are a safetensors-format file called
|
||||
adapter_model.safetensors and a json config file called adapter_config.json.
|
||||
|
||||
Serializes the files in the tensorizer_config.lora_dir
|
||||
"""
|
||||
import safetensors
|
||||
|
||||
from vllm.lora.utils import get_adapter_absolute_path
|
||||
|
||||
lora_dir = get_adapter_absolute_path(lora_path)
|
||||
|
||||
tensor_path = config_path = ""
|
||||
|
||||
for file in os.listdir(lora_dir):
|
||||
if file.startswith("adapter_model"):
|
||||
tensor_path = lora_dir + "/" + file
|
||||
if file.startswith("adapter_config"):
|
||||
config_path = lora_dir + "/" + file
|
||||
if tensor_path and config_path:
|
||||
break
|
||||
|
||||
if tensor_path.endswith(".safetensors"):
|
||||
tensors = safetensors.torch.load_file(tensor_path)
|
||||
elif tensor_path.endswith(".bin"):
|
||||
tensors = torch.load(tensor_path)
|
||||
else:
|
||||
raise ValueError("Unsupported file: %s", tensor_path)
|
||||
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
|
||||
tensorizer_args = tensorizer_config._construct_tensorizer_args()
|
||||
|
||||
with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json",
|
||||
mode="wb+",
|
||||
**tensorizer_args.stream_params) as f:
|
||||
|
||||
f.write(json.dumps(config).encode("utf-8"))
|
||||
|
||||
lora_uri = (f"{tensorizer_config.lora_dir}"
|
||||
f"/adapter_model.tensors")
|
||||
with open_stream(lora_uri, mode="wb+",
|
||||
**tensorizer_args.stream_params) as f:
|
||||
serializer = TensorSerializer(f)
|
||||
serializer.write_state_dict(tensors)
|
||||
serializer.close()
|
||||
|
||||
logger.info("Successfully serialized LoRA files to %s",
|
||||
str(tensorizer_config.lora_dir))
|
||||
123
vllm/model_executor/model_loader/tensorizer_loader.py
Normal file
123
vllm/model_executor/model_loader/tensorizer_loader.py
Normal file
@@ -0,0 +1,123 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
# ruff: noqa: SIM117
|
||||
import copy
|
||||
from collections.abc import Generator
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
|
||||
from vllm.model_executor.model_loader.tensorizer import (
|
||||
TensorizerConfig, deserialize_tensorizer_model, init_tensorizer_model,
|
||||
is_vllm_tensorized, serialize_vllm_model, tensorizer_weights_iterator)
|
||||
from vllm.model_executor.model_loader.utils import (get_model_architecture,
|
||||
initialize_model,
|
||||
set_default_torch_dtype)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TensorizerLoader(BaseModelLoader):
|
||||
"""Model loader using CoreWeave's tensorizer library."""
|
||||
|
||||
def __init__(self, load_config: LoadConfig):
|
||||
super().__init__(load_config)
|
||||
if isinstance(load_config.model_loader_extra_config, TensorizerConfig):
|
||||
self.tensorizer_config = load_config.model_loader_extra_config
|
||||
else:
|
||||
self.tensorizer_config = TensorizerConfig(
|
||||
**load_config.model_loader_extra_config)
|
||||
|
||||
def _verify_config(self, model_config: ModelConfig,
|
||||
parallel_config: ParallelConfig):
|
||||
self.tensorizer_config.verify_with_model_config(model_config)
|
||||
self.tensorizer_config.verify_with_parallel_config(parallel_config)
|
||||
|
||||
def _get_weights_iterator(
|
||||
self, ) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
tensorizer_args = self.tensorizer_config._construct_tensorizer_args()
|
||||
return tensorizer_weights_iterator(tensorizer_args)
|
||||
|
||||
def _load_model_serialized_cpu(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
) -> nn.Module:
|
||||
"""Load a serialized model with tensorizer to the CPU.
|
||||
|
||||
This is only necessary when the model isn't vLLM-tensorized (see
|
||||
examples/others/tensorize_vllm_model.py) This should still
|
||||
be faster than default HuggingFace loading, but will be slower than
|
||||
loading a vLLM-tensorized model.
|
||||
"""
|
||||
device_config = vllm_config.device_config
|
||||
model_config = vllm_config.model_config
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with torch.device(device_config.device):
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
|
||||
model.load_weights(self._get_weights_iterator())
|
||||
return model.eval()
|
||||
|
||||
def download_model(self, model_config: ModelConfig) -> None:
|
||||
self.tensorizer_config.verify_with_model_config(model_config)
|
||||
|
||||
with self.tensorizer_config.open_stream():
|
||||
pass
|
||||
|
||||
def _patch_tensorizer_config(
|
||||
self, model_config: ModelConfig) -> TensorizerConfig:
|
||||
model_class = get_model_architecture(model_config)[0]
|
||||
tensorizer_config = copy.copy(self.tensorizer_config)
|
||||
tensorizer_config.model_class = model_class
|
||||
tensorizer_config.hf_config = model_config.hf_config
|
||||
tensorizer_config.dtype = model_config.dtype
|
||||
return tensorizer_config
|
||||
|
||||
def load_weights(self, model: nn.Module,
|
||||
model_config: ModelConfig) -> None:
|
||||
"""Load serialized model weights with tensorizer.
|
||||
|
||||
Expects a vLLM-tensorized model. See the
|
||||
examples/others/tensorize_vllm_model.py example script
|
||||
for serializing vLLM models."""
|
||||
if is_vllm_tensorized(self.tensorizer_config):
|
||||
tensorizer_config = self._patch_tensorizer_config(model_config)
|
||||
deserialize_tensorizer_model(model, tensorizer_config)
|
||||
else:
|
||||
model.load_weights(self._get_weights_iterator())
|
||||
|
||||
def load_model(self, vllm_config: VllmConfig,
|
||||
model_config: ModelConfig) -> nn.Module:
|
||||
parallel_config = vllm_config.parallel_config
|
||||
self._verify_config(model_config, parallel_config)
|
||||
|
||||
if parallel_config.tensor_parallel_size > 1:
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
|
||||
self.tensorizer_config.tensorizer_uri = (
|
||||
self.tensorizer_config.tensorizer_uri %
|
||||
get_tensor_model_parallel_rank())
|
||||
|
||||
if is_vllm_tensorized(self.tensorizer_config):
|
||||
tensorizer_config = self._patch_tensorizer_config(model_config)
|
||||
model = init_tensorizer_model(tensorizer_config=tensorizer_config,
|
||||
vllm_config=vllm_config)
|
||||
self.load_weights(model, model_config)
|
||||
return model
|
||||
return self._load_model_serialized_cpu(vllm_config=vllm_config)
|
||||
|
||||
@staticmethod
|
||||
def save_model(
|
||||
model: torch.nn.Module,
|
||||
tensorizer_config: Union[TensorizerConfig, dict],
|
||||
) -> None:
|
||||
if isinstance(tensorizer_config, dict):
|
||||
tensorizer_config = TensorizerConfig(**tensorizer_config)
|
||||
serialize_vllm_model(
|
||||
model=model,
|
||||
tensorizer_config=tensorizer_config,
|
||||
)
|
||||
112
vllm/model_executor/model_loader/tpu.py
Normal file
112
vllm/model_executor/model_loader/tpu.py
Normal file
@@ -0,0 +1,112 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.distributed.spmd as xs
|
||||
|
||||
from vllm.config import ModelConfig, VllmConfig
|
||||
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
|
||||
from vllm.model_executor.model_loader.utils import (
|
||||
initialize_model, process_weights_after_loading, set_default_torch_dtype)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class TPUModelLoader(DefaultModelLoader):
|
||||
"""
|
||||
A TPU model loader for model loading under SPMD mode.
|
||||
"""
|
||||
|
||||
def load_model(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
model_config: ModelConfig,
|
||||
mesh: Optional[xs.Mesh] = None,
|
||||
) -> nn.Module:
|
||||
# Initialize model and load weights on CPU. Then, during SPMD partition,
|
||||
# weights are sharded and transferred to TPUs.
|
||||
self.counter_before_loading_weights = time.perf_counter()
|
||||
model_config = vllm_config.model_config
|
||||
assert model_config.quantization is None, "Quantization not supported"
|
||||
target_device = torch.device('cpu')
|
||||
with set_default_torch_dtype(model_config.dtype):
|
||||
with target_device:
|
||||
model = initialize_model(vllm_config=vllm_config)
|
||||
|
||||
load_format = vllm_config.load_config.load_format
|
||||
if load_format != "dummy":
|
||||
weights_to_load = {
|
||||
name
|
||||
for name, _ in model.named_parameters()
|
||||
}
|
||||
all_weights = self.get_all_weights(model_config, model)
|
||||
loaded_weights = model.load_weights(all_weights)
|
||||
self.counter_after_loading_weights = time.perf_counter()
|
||||
logger.info(
|
||||
"Loading weights took %.2f seconds",
|
||||
self.counter_after_loading_weights -
|
||||
self.counter_before_loading_weights)
|
||||
# We only enable strict check for non-quantized models
|
||||
# that have loaded weights tracking currently.
|
||||
if model_config.quantization is None and \
|
||||
loaded_weights is not None:
|
||||
weights_not_loaded = weights_to_load - loaded_weights
|
||||
if weights_not_loaded:
|
||||
raise ValueError(
|
||||
"Following weights were not initialized from "
|
||||
f"checkpoint: {weights_not_loaded}")
|
||||
else:
|
||||
logger.info("Use dummy weight during weight loading.")
|
||||
|
||||
process_weights_after_loading(model, model_config, target_device)
|
||||
|
||||
counter_before_partition = time.perf_counter()
|
||||
model = model.eval()
|
||||
model = model.to('xla')
|
||||
shard_model(model, mesh)
|
||||
counter_after_partition = time.perf_counter()
|
||||
logger.info("Partition model took %.2f seconds",
|
||||
counter_after_partition - counter_before_partition)
|
||||
|
||||
# Ensure the model is properly loaded.
|
||||
self._check_model_is_loaded(mesh, model)
|
||||
|
||||
# Need to torch compile after model sharding are done. Because the
|
||||
# compiler hints ('xs.mark_sharding') are torch ops.
|
||||
if not model_config.is_multimodal_model:
|
||||
model.model = torch.compile(model.model, backend="openxla")
|
||||
else:
|
||||
model.language_model.model = \
|
||||
torch.compile(model.language_model.model, backend="openxla")
|
||||
return model
|
||||
|
||||
def _check_model_is_loaded(self, mesh: Optional[xs.Mesh],
|
||||
model: nn.Module) -> None:
|
||||
"""
|
||||
Ensure the model is properly loaded.
|
||||
1. All model parameters and buffers are on XLA device.
|
||||
2. Non-SPMD friendly layers are replaced as expected.
|
||||
"""
|
||||
device = xm.xla_device()
|
||||
device_type = str(device.type)
|
||||
|
||||
# Check parameters
|
||||
for name, param in model.named_parameters():
|
||||
assert param.device.type == device_type, f"Parameter {name} is on \
|
||||
{param.device.type} instead of {device_type}"
|
||||
|
||||
# Check buffers
|
||||
for name, buffer in model.named_buffers():
|
||||
assert buffer.device.type == device_type, \
|
||||
f"Buffer {name} is on {buffer.device.type} instead of \
|
||||
{device_type}"
|
||||
|
||||
for module in model.modules():
|
||||
if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'):
|
||||
raise AssertionError("QKVParallelLinear should be replaced by \
|
||||
XlaQKVParallelLinear under SPMD mode.")
|
||||
302
vllm/model_executor/model_loader/utils.py
Normal file
302
vllm/model_executor/model_loader/utils.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for selecting and loading models."""
|
||||
import contextlib
|
||||
import inspect
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from torch import nn
|
||||
from transformers.dynamic_module_utils import get_class_from_dynamic_module
|
||||
|
||||
from vllm.attention import Attention
|
||||
from vllm.config import (ModelConfig, ModelImpl, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.linear import QKVCrossParallelLinear
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig, QuantizeMethodBase)
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.models.adapters import (as_classification_model,
|
||||
as_embedding_model,
|
||||
as_reward_model)
|
||||
from vllm.utils import is_pin_memory_available
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_default_torch_dtype(dtype: torch.dtype):
|
||||
"""Sets the default torch dtype to the given dtype."""
|
||||
old_dtype = torch.get_default_dtype()
|
||||
torch.set_default_dtype(dtype)
|
||||
yield
|
||||
torch.set_default_dtype(old_dtype)
|
||||
|
||||
|
||||
def initialize_model(
|
||||
vllm_config: VllmConfig,
|
||||
*,
|
||||
prefix: str = "",
|
||||
model_class: Optional[type[nn.Module]] = None,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
) -> nn.Module:
|
||||
"""Initialize a model with the given configurations."""
|
||||
if model_config is None:
|
||||
model_config = vllm_config.model_config
|
||||
if model_class is None:
|
||||
model_class, _ = get_model_architecture(model_config)
|
||||
|
||||
if vllm_config.quant_config is not None:
|
||||
configure_quant_config(vllm_config.quant_config, model_class)
|
||||
|
||||
signatures = inspect.signature(model_class.__init__)
|
||||
all_params = [param.name for param in signatures.parameters.values()]
|
||||
if "vllm_config" in all_params and "prefix" in all_params:
|
||||
# new-style model class
|
||||
with set_current_vllm_config(vllm_config, check_compile=True):
|
||||
return model_class(vllm_config=vllm_config, prefix=prefix)
|
||||
|
||||
msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
|
||||
"input arguments. Possibly you have an old-style model class"
|
||||
" registered from out of tree and it is used for new vLLM version. "
|
||||
"Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
|
||||
"for the design and update the model class accordingly.")
|
||||
warnings.warn(msg, DeprecationWarning, stacklevel=2)
|
||||
|
||||
logger.warning(
|
||||
"Trying to guess the arguments for old-style model class %s",
|
||||
model_class,
|
||||
)
|
||||
# try to be compatible with old-style model class
|
||||
kwargs = {}
|
||||
if "prefix" in all_params:
|
||||
kwargs["prefix"] = prefix
|
||||
if "config" in all_params:
|
||||
kwargs["config"] = model_config.hf_config
|
||||
if "cache_config" in all_params:
|
||||
kwargs["cache_config"] = vllm_config.cache_config
|
||||
if "quant_config" in all_params:
|
||||
kwargs["quant_config"] = vllm_config.quant_config
|
||||
if "lora_config" in all_params:
|
||||
kwargs["lora_config"] = vllm_config.lora_config
|
||||
if "scheduler_config" in all_params:
|
||||
kwargs["scheduler_config"] = vllm_config.scheduler_config
|
||||
with set_current_vllm_config(vllm_config, check_compile=True):
|
||||
return model_class(**kwargs)
|
||||
|
||||
|
||||
def process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
|
||||
target_device: torch.device) -> None:
|
||||
for _, module in model.named_modules():
|
||||
if isinstance(module, QKVCrossParallelLinear):
|
||||
# NOTE(Isotr0py): special case for cross QKV layer because
|
||||
# q and kv proj aren't registered as submodules intentionally
|
||||
module.process_weights_after_loading()
|
||||
continue
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if isinstance(quant_method, QuantizeMethodBase):
|
||||
# When quant methods need to process weights after loading
|
||||
# (for repacking, quantizing, etc), they expect parameters
|
||||
# to be on the global target device. This scope is for the
|
||||
# case where cpu offloading is used, where we will move the
|
||||
# parameters onto device for processing and back off after.
|
||||
with device_loading_context(module, target_device):
|
||||
quant_method.process_weights_after_loading(module)
|
||||
|
||||
# Currently only used by MLA.
|
||||
# NOTE: This intentionally happens after other modules so we can easily
|
||||
# decompress the weights for MLA.
|
||||
for _, module in model.named_modules():
|
||||
if isinstance(module, Attention) and \
|
||||
hasattr(module, "process_weights_after_loading"):
|
||||
# TODO(lucas): see if there is a way to unify the signatures
|
||||
# of process_weights_after_loading
|
||||
module.process_weights_after_loading(model_config.dtype)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def device_loading_context(module: torch.nn.Module,
|
||||
target_device: torch.device):
|
||||
if target_device.type == "cpu":
|
||||
# If target is CPU, no need to move anything
|
||||
yield module
|
||||
return
|
||||
|
||||
original_device_states: dict[str, torch.device] = {}
|
||||
|
||||
# Store original device states and move parameters to GPU if they're on CPU
|
||||
for name, p in module.named_parameters():
|
||||
if p.device.type == "cpu":
|
||||
original_device_states[name] = p.device
|
||||
p.data = p.data.to(target_device)
|
||||
# Parameters already on target device are not touched
|
||||
|
||||
try:
|
||||
yield module
|
||||
|
||||
finally:
|
||||
# Restore parameters to their original devices, ignoring new parameters
|
||||
pin_memory = is_pin_memory_available()
|
||||
for name, p in module.named_parameters():
|
||||
if name in original_device_states:
|
||||
original_device: torch.device = original_device_states[name]
|
||||
if original_device.type == "cpu":
|
||||
# `torch.empty_like` does not support `pin_memory` argument
|
||||
cpu_data = torch.empty_strided(
|
||||
size=p.data.size(),
|
||||
stride=p.data.stride(),
|
||||
dtype=p.data.dtype,
|
||||
layout=p.data.layout,
|
||||
device="cpu",
|
||||
pin_memory=pin_memory,
|
||||
)
|
||||
cpu_data.copy_(p.data)
|
||||
p.data = cpu_data
|
||||
else:
|
||||
p.data = p.data.to(original_device)
|
||||
# New parameters or parameters already on target device are untouched
|
||||
|
||||
|
||||
def resolve_transformers_arch(model_config: ModelConfig,
|
||||
architectures: list[str]):
|
||||
for i, arch in enumerate(architectures):
|
||||
if arch == "TransformersForCausalLM":
|
||||
continue
|
||||
auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
|
||||
None) or dict()
|
||||
# Make sure that config class is always initialized before model class,
|
||||
# otherwise the model class won't be able to access the config class,
|
||||
# the expected auto_map should have correct order like:
|
||||
# "auto_map": {
|
||||
# "AutoConfig": "<your-repo-name>--<config-name>",
|
||||
# "AutoModel": "<your-repo-name>--<config-name>",
|
||||
# "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
|
||||
# },
|
||||
auto_modules = {
|
||||
name:
|
||||
get_class_from_dynamic_module(module,
|
||||
model_config.model,
|
||||
revision=model_config.revision)
|
||||
for name, module in sorted(auto_map.items(), key=lambda x: x[0])
|
||||
}
|
||||
model_module = getattr(transformers, arch, None)
|
||||
if model_module is None:
|
||||
if "AutoModel" not in auto_map:
|
||||
raise ValueError(
|
||||
f"Cannot find model module. '{arch}' is not a registered "
|
||||
"model in the Transformers library (only relevant if the "
|
||||
"model is meant to be in Transformers) and 'AutoModel' is "
|
||||
"not present in the model config's 'auto_map' (relevant "
|
||||
"if the model is custom).")
|
||||
model_module = auto_modules["AutoModel"]
|
||||
# TODO(Isotr0py): Further clean up these raises.
|
||||
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
|
||||
if model_config.model_impl == ModelImpl.TRANSFORMERS:
|
||||
if not model_module.is_backend_compatible():
|
||||
raise ValueError(
|
||||
f"The Transformers implementation of {arch} is not "
|
||||
"compatible with vLLM.")
|
||||
architectures[i] = "TransformersForCausalLM"
|
||||
if model_config.model_impl == ModelImpl.AUTO:
|
||||
if not model_module.is_backend_compatible():
|
||||
raise ValueError(
|
||||
f"{arch} has no vLLM implementation and the Transformers "
|
||||
"implementation is not compatible with vLLM. Try setting "
|
||||
"VLLM_USE_V1=0.")
|
||||
logger.warning(
|
||||
"%s has no vLLM implementation, falling back to Transformers "
|
||||
"implementation. Some features may not be supported and "
|
||||
"performance may not be optimal.", arch)
|
||||
architectures[i] = "TransformersForCausalLM"
|
||||
return architectures
|
||||
|
||||
|
||||
def get_model_architecture(
|
||||
model_config: ModelConfig) -> tuple[type[nn.Module], str]:
|
||||
architectures = getattr(model_config.hf_config, "architectures", [])
|
||||
|
||||
# Special handling for quantized Mixtral.
|
||||
# FIXME(woosuk): This is a temporary hack.
|
||||
mixtral_supported = [
|
||||
"fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
|
||||
]
|
||||
|
||||
vllm_supported_archs = ModelRegistry.get_supported_archs()
|
||||
vllm_not_supported = not any(arch in vllm_supported_archs
|
||||
for arch in architectures)
|
||||
if (model_config.model_impl == ModelImpl.TRANSFORMERS or
|
||||
model_config.model_impl != ModelImpl.VLLM and vllm_not_supported):
|
||||
architectures = resolve_transformers_arch(model_config, architectures)
|
||||
elif (model_config.quantization is not None
|
||||
and model_config.quantization not in mixtral_supported
|
||||
and "MixtralForCausalLM" in architectures):
|
||||
architectures = ["QuantMixtralForCausalLM"]
|
||||
|
||||
model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
|
||||
if model_config.task == "embed":
|
||||
model_cls = as_embedding_model(model_cls)
|
||||
elif model_config.task == "classify":
|
||||
model_cls = as_classification_model(model_cls)
|
||||
elif model_config.task == "reward":
|
||||
model_cls = as_reward_model(model_cls)
|
||||
|
||||
return model_cls, arch
|
||||
|
||||
|
||||
def get_architecture_class_name(model_config: ModelConfig) -> str:
|
||||
return get_model_architecture(model_config)[1]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParamMapping:
|
||||
"""
|
||||
A class to handle parameter mapping for model weight loading.
|
||||
It creates a bidirectional mapping between packed parameters and their
|
||||
constituent parts.
|
||||
"""
|
||||
packed_mapping: dict[str, list[str]]
|
||||
inverse_packed_mapping: dict[str, tuple[str,
|
||||
int]] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
for packed_name, sub_params in self.packed_mapping.items():
|
||||
# Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
|
||||
if len(sub_params) == 1 and sub_params[0] == packed_name:
|
||||
continue
|
||||
for index, param_name in enumerate(sub_params):
|
||||
self.inverse_packed_mapping[param_name] = (
|
||||
packed_name,
|
||||
index,
|
||||
)
|
||||
|
||||
def get_sub_modules(self,
|
||||
module_name: str) -> Optional[tuple[str, list[str]]]:
|
||||
for key, value in self.packed_mapping.items():
|
||||
if module_name.endswith(key):
|
||||
return key, value
|
||||
return None
|
||||
|
||||
|
||||
def configure_quant_config(quant_config: QuantizationConfig,
|
||||
model_class: type[nn.Module]):
|
||||
"""
|
||||
Pass packed_modules_mapping by reference to quant_config so that
|
||||
quant_config can properly match fused modules
|
||||
|
||||
Note that model attributes are passed by reference to quant_config,
|
||||
enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)
|
||||
"""
|
||||
packed_mapping = getattr(model_class, "packed_modules_mapping", None)
|
||||
if packed_mapping is not None:
|
||||
# pass packed_modules_mapping by reference to quant_config
|
||||
quant_config.packed_modules_mapping = packed_mapping
|
||||
else:
|
||||
logger.warning(
|
||||
"The model class %s has not defined `packed_modules_mapping`, "
|
||||
"this may lead to incorrect mapping of quantized or ignored "
|
||||
"modules", model_class.__name__)
|
||||
782
vllm/model_executor/model_loader/weight_utils.py
Normal file
782
vllm/model_executor/model_loader/weight_utils.py
Normal file
@@ -0,0 +1,782 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utilities for downloading and initializing model weights."""
|
||||
import fnmatch
|
||||
import glob
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import filelock
|
||||
import gguf
|
||||
import huggingface_hub.constants
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
|
||||
from safetensors.torch import load_file, safe_open, save_file
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from vllm.config import LoadConfig, ModelConfig
|
||||
from vllm.distributed import get_tensor_model_parallel_rank
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.quantization import (QuantizationConfig,
|
||||
get_quantization_config)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import PlaceholderModule
|
||||
|
||||
try:
|
||||
from runai_model_streamer import SafetensorsStreamer
|
||||
except (ImportError, OSError):
|
||||
# see https://github.com/run-ai/runai-model-streamer/issues/26
|
||||
# OSError will be raised on arm64 platform
|
||||
runai_model_streamer = PlaceholderModule(
|
||||
"runai_model_streamer") # type: ignore[assignment]
|
||||
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
|
||||
"SafetensorsStreamer")
|
||||
|
||||
try:
|
||||
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
|
||||
except ImportError:
|
||||
fastsafetensors = PlaceholderModule("fastsafetensors")
|
||||
SafeTensorsFileLoader = fastsafetensors.placeholder_attr(
|
||||
"SafeTensorsFileLoader")
|
||||
SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# use system-level temp directory for file locks, so that multiple users
|
||||
# can share the same lock without error.
|
||||
# lock files in the temp directory will be automatically deleted when the
|
||||
# system reboots, so users will not complain about annoying lock files
|
||||
temp_dir = tempfile.gettempdir()
|
||||
|
||||
|
||||
def enable_hf_transfer():
|
||||
"""automatically activates hf_transfer
|
||||
"""
|
||||
if "HF_HUB_ENABLE_HF_TRANSFER" not in os.environ:
|
||||
try:
|
||||
# enable hf hub transfer if available
|
||||
import hf_transfer # type: ignore # noqa
|
||||
huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
enable_hf_transfer()
|
||||
|
||||
|
||||
class DisabledTqdm(tqdm):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs, disable=True)
|
||||
|
||||
|
||||
def get_lock(model_name_or_path: Union[str, Path],
|
||||
cache_dir: Optional[str] = None):
|
||||
lock_dir = cache_dir or temp_dir
|
||||
model_name_or_path = str(model_name_or_path)
|
||||
os.makedirs(os.path.dirname(lock_dir), exist_ok=True)
|
||||
model_name = model_name_or_path.replace("/", "-")
|
||||
hash_name = hashlib.sha256(model_name.encode()).hexdigest()
|
||||
# add hash to avoid conflict with old users' lock files
|
||||
lock_file_name = hash_name + model_name + ".lock"
|
||||
# mode 0o666 is required for the filelock to be shared across users
|
||||
lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
|
||||
mode=0o666)
|
||||
return lock
|
||||
|
||||
|
||||
def _shared_pointers(tensors):
|
||||
ptrs = defaultdict(list)
|
||||
for k, v in tensors.items():
|
||||
ptrs[v.data_ptr()].append(k)
|
||||
failing = []
|
||||
for _, names in ptrs.items():
|
||||
if len(names) > 1:
|
||||
failing.append(names)
|
||||
return failing
|
||||
|
||||
|
||||
def convert_bin_to_safetensor_file(
|
||||
pt_filename: str,
|
||||
sf_filename: str,
|
||||
) -> None:
|
||||
loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
|
||||
if "state_dict" in loaded:
|
||||
loaded = loaded["state_dict"]
|
||||
shared = _shared_pointers(loaded)
|
||||
for shared_weights in shared:
|
||||
for name in shared_weights[1:]:
|
||||
loaded.pop(name)
|
||||
|
||||
# For tensors to be contiguous
|
||||
loaded = {k: v.contiguous() for k, v in loaded.items()}
|
||||
|
||||
dirname = os.path.dirname(sf_filename)
|
||||
os.makedirs(dirname, exist_ok=True)
|
||||
save_file(loaded, sf_filename, metadata={"format": "pt"})
|
||||
|
||||
# check file size
|
||||
sf_size = os.stat(sf_filename).st_size
|
||||
pt_size = os.stat(pt_filename).st_size
|
||||
if (sf_size - pt_size) / pt_size > 0.01:
|
||||
raise RuntimeError(f"""The file size different is more than 1%:
|
||||
- {sf_filename}: {sf_size}
|
||||
- {pt_filename}: {pt_size}
|
||||
""")
|
||||
|
||||
# check if the tensors are the same
|
||||
reloaded = load_file(sf_filename)
|
||||
for k in loaded:
|
||||
pt_tensor = loaded[k]
|
||||
sf_tensor = reloaded[k]
|
||||
if not torch.equal(pt_tensor, sf_tensor):
|
||||
raise RuntimeError(f"The output tensors do not match for key {k}")
|
||||
|
||||
|
||||
# TODO(woosuk): Move this to other place.
|
||||
def get_quant_config(model_config: ModelConfig,
|
||||
load_config: LoadConfig) -> QuantizationConfig:
|
||||
|
||||
quant_cls = get_quantization_config(model_config.quantization)
|
||||
|
||||
# GGUF doesn't have config file
|
||||
if model_config.quantization == "gguf":
|
||||
return quant_cls.from_config({})
|
||||
|
||||
# Read the quantization config from the HF model config, if available.
|
||||
hf_quant_config = getattr(model_config.hf_config, "quantization_config",
|
||||
None)
|
||||
# some vision model may keep quantization_config in their text_config
|
||||
hf_text_config = getattr(model_config.hf_config, "text_config", None)
|
||||
if hf_quant_config is None and hf_text_config is not None:
|
||||
hf_quant_config = getattr(hf_text_config, "quantization_config", None)
|
||||
if hf_quant_config is None:
|
||||
# compressed-tensors uses a compressions_config
|
||||
hf_quant_config = getattr(model_config.hf_config, "compression_config",
|
||||
None)
|
||||
if hf_quant_config is not None:
|
||||
return quant_cls.from_config(hf_quant_config)
|
||||
# Inflight BNB quantization
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
return quant_cls.from_config({})
|
||||
is_local = os.path.isdir(model_config.model)
|
||||
if not is_local:
|
||||
# Download the config files.
|
||||
with get_lock(model_config.model, load_config.download_dir):
|
||||
hf_folder = snapshot_download(
|
||||
model_config.model,
|
||||
revision=model_config.revision,
|
||||
allow_patterns="*.json",
|
||||
cache_dir=load_config.download_dir,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
tqdm_class=DisabledTqdm,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_config.model
|
||||
|
||||
possible_config_filenames = quant_cls.get_config_filenames()
|
||||
|
||||
# If the quantization config is not found, use the default config.
|
||||
if not possible_config_filenames:
|
||||
return quant_cls()
|
||||
|
||||
config_files = glob.glob(os.path.join(hf_folder, "*.json"))
|
||||
|
||||
quant_config_files = [
|
||||
f for f in config_files if any(
|
||||
f.endswith(x) for x in possible_config_filenames)
|
||||
]
|
||||
if len(quant_config_files) == 0:
|
||||
raise ValueError(
|
||||
f"Cannot find the config file for {model_config.quantization}")
|
||||
if len(quant_config_files) > 1:
|
||||
raise ValueError(
|
||||
f"Found multiple config files for {model_config.quantization}: "
|
||||
f"{quant_config_files}")
|
||||
|
||||
quant_config_file = quant_config_files[0]
|
||||
with open(quant_config_file) as f:
|
||||
config = json.load(f)
|
||||
|
||||
if model_config.quantization == "bitsandbytes":
|
||||
config["adapter_name_or_path"] = model_config.model
|
||||
elif model_config.quantization == "modelopt":
|
||||
if config["producer"]["name"] == "modelopt":
|
||||
return quant_cls.from_config(config)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported quantization config"
|
||||
f" found for {model_config.quantization} in {f}.")
|
||||
|
||||
return quant_cls.from_config(config)
|
||||
|
||||
|
||||
def get_sparse_attention_config(
|
||||
model_config: ModelConfig,
|
||||
load_config: LoadConfig,
|
||||
sparse_attention_config_filename: str = "sparse_attention_config.json",
|
||||
) -> dict[str, Any]:
|
||||
model_name_or_path = model_config.model
|
||||
is_local = os.path.isdir(model_name_or_path)
|
||||
if not is_local:
|
||||
# Download the config files.
|
||||
with get_lock(model_name_or_path, load_config.download_dir):
|
||||
hf_folder = snapshot_download(
|
||||
model_name_or_path,
|
||||
revision=model_config.revision,
|
||||
allow_patterns="*.json",
|
||||
cache_dir=load_config.download_dir,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
tqdm_class=DisabledTqdm,
|
||||
)
|
||||
else:
|
||||
hf_folder = model_name_or_path
|
||||
|
||||
config_file = os.path.join(hf_folder, sparse_attention_config_filename)
|
||||
if not os.path.exists(config_file):
|
||||
return {}
|
||||
|
||||
# Load the sparse attention config.
|
||||
with open(config_file) as f:
|
||||
config = json.load(f)
|
||||
logger.info("Loaded sparse attention config from %s", config_file)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def download_weights_from_hf(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
allow_patterns: list[str],
|
||||
revision: Optional[str] = None,
|
||||
ignore_patterns: Optional[Union[str, list[str]]] = None,
|
||||
) -> str:
|
||||
"""Download model weights from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The model name or path.
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
allow_patterns (list[str]): The allowed patterns for the
|
||||
weight files. Files matched by any of the patterns will be
|
||||
downloaded.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
|
||||
filter out the weight files. Files matched by any of the patterns
|
||||
will be ignored.
|
||||
|
||||
Returns:
|
||||
str: The path to the downloaded model weights.
|
||||
"""
|
||||
local_only = huggingface_hub.constants.HF_HUB_OFFLINE
|
||||
if not local_only:
|
||||
# Before we download we look at that is available:
|
||||
fs = HfFileSystem()
|
||||
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||
|
||||
# depending on what is available we download different things
|
||||
for pattern in allow_patterns:
|
||||
matching = fnmatch.filter(file_list, pattern)
|
||||
if len(matching) > 0:
|
||||
allow_patterns = [pattern]
|
||||
break
|
||||
|
||||
logger.info("Using model weights format %s", allow_patterns)
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
start_time = time.perf_counter()
|
||||
hf_folder = snapshot_download(
|
||||
model_name_or_path,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=DisabledTqdm,
|
||||
revision=revision,
|
||||
local_files_only=local_only,
|
||||
)
|
||||
time_taken = time.perf_counter() - start_time
|
||||
if time_taken > 0.5:
|
||||
logger.info("Time spent downloading weights for %s: %.6f seconds",
|
||||
model_name_or_path, time_taken)
|
||||
return hf_folder
|
||||
|
||||
|
||||
def download_safetensors_index_file_from_hf(
|
||||
model_name_or_path: str,
|
||||
index_file: str,
|
||||
cache_dir: Optional[str],
|
||||
revision: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Download hf safetensors index file from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
model_name_or_path (str): The model name or path.
|
||||
index_file (str): The safetensors index file name
|
||||
cache_dir (Optional[str]): The cache directory to store the model
|
||||
weights. If None, will use HF defaults.
|
||||
revision (Optional[str]): The revision of the model.
|
||||
"""
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
try:
|
||||
# Download the safetensors index file.
|
||||
hf_hub_download(
|
||||
repo_id=model_name_or_path,
|
||||
filename=index_file,
|
||||
cache_dir=cache_dir,
|
||||
revision=revision,
|
||||
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
|
||||
)
|
||||
# If file not found on remote or locally, we should not fail since
|
||||
# only some models will have index_file.
|
||||
except huggingface_hub.utils.LocalEntryNotFoundError:
|
||||
logger.info("No %s found in local cache.", index_file)
|
||||
except huggingface_hub.utils.EntryNotFoundError:
|
||||
logger.info("No %s found in remote.", index_file)
|
||||
|
||||
|
||||
# For models like Mistral-7B-v0.3, there are both sharded
|
||||
# safetensors files and a consolidated safetensors file.
|
||||
# Passing both of these to the weight loader functionality breaks.
|
||||
# So, we use the index_file to
|
||||
# look up which safetensors files should be used.
|
||||
def filter_duplicate_safetensors_files(hf_weights_files: list[str],
|
||||
hf_folder: str,
|
||||
index_file: str) -> list[str]:
|
||||
# model.safetensors.index.json is a mapping from keys in the
|
||||
# torch state_dict to safetensors file holding that weight.
|
||||
index_file_name = os.path.join(hf_folder, index_file)
|
||||
if not os.path.isfile(index_file_name):
|
||||
return hf_weights_files
|
||||
|
||||
# Iterate through the weight_map (weight_name: safetensors files)
|
||||
# to identify weights that we should use.
|
||||
with open(index_file_name) as f:
|
||||
weight_map = json.load(f)["weight_map"]
|
||||
weight_files_in_index = set()
|
||||
for weight_name in weight_map:
|
||||
weight_files_in_index.add(
|
||||
os.path.join(hf_folder, weight_map[weight_name]))
|
||||
# Filter out any fields that are not found in the index file.
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files if f in weight_files_in_index
|
||||
]
|
||||
return hf_weights_files
|
||||
|
||||
|
||||
def filter_files_not_needed_for_inference(
|
||||
hf_weights_files: list[str]) -> list[str]:
|
||||
"""
|
||||
Exclude files that are not needed for inference.
|
||||
|
||||
See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
|
||||
"""
|
||||
blacklist = [
|
||||
"training_args.bin",
|
||||
"optimizer.bin",
|
||||
"optimizer.pt",
|
||||
"scheduler.pt",
|
||||
"scaler.pt",
|
||||
]
|
||||
hf_weights_files = [
|
||||
f for f in hf_weights_files
|
||||
if not any(f.endswith(x) for x in blacklist)
|
||||
]
|
||||
return hf_weights_files
|
||||
|
||||
|
||||
# explicitly use pure text format, with a newline at the end
|
||||
# this makes it impossible to see the animation in the progress bar
|
||||
# but will avoid messing up with ray or multiprocessing, which wraps
|
||||
# each line of output with some prefix.
|
||||
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501
|
||||
|
||||
|
||||
def enable_tqdm(use_tqdm_on_load: bool):
|
||||
return use_tqdm_on_load and (not torch.distributed.is_initialized()
|
||||
or torch.distributed.get_rank() == 0)
|
||||
|
||||
|
||||
def np_cache_weights_iterator(
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str],
|
||||
hf_folder: str,
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model np files.
|
||||
|
||||
Will dump the model weights to numpy files if they are not already dumped.
|
||||
"""
|
||||
# Convert the model weights from torch tensors to numpy arrays for
|
||||
# faster loading.
|
||||
np_folder = os.path.join(hf_folder, "np")
|
||||
os.makedirs(np_folder, exist_ok=True)
|
||||
weight_names_file = os.path.join(np_folder, "weight_names.json")
|
||||
# Use file lock to prevent multiple processes from
|
||||
# dumping the same model weights to numpy at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
if not os.path.exists(weight_names_file):
|
||||
weight_names: list[str] = []
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading np_cache checkpoint shards",
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file,
|
||||
map_location="cpu",
|
||||
weights_only=True)
|
||||
for name, param in state.items():
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "wb") as f:
|
||||
np.save(f, param.cpu().detach().numpy())
|
||||
weight_names.append(name)
|
||||
with open(weight_names_file, "w") as f:
|
||||
json.dump(weight_names, f)
|
||||
|
||||
with open(weight_names_file) as f:
|
||||
weight_names = json.load(f)
|
||||
|
||||
for name in weight_names:
|
||||
param_path = os.path.join(np_folder, name)
|
||||
with open(param_path, "rb") as f:
|
||||
param = np.load(f)
|
||||
yield name, torch.from_numpy(param)
|
||||
|
||||
|
||||
def safetensors_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading safetensors checkpoint shards",
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
with safe_open(st_file, framework="pt") as f:
|
||||
for name in f.keys(): # noqa: SIM118
|
||||
param = f.get_tensor(name)
|
||||
yield name, param
|
||||
|
||||
|
||||
def runai_safetensors_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files."""
|
||||
with SafetensorsStreamer() as streamer:
|
||||
for st_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading safetensors using Runai Model Streamer",
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
streamer.stream_file(st_file)
|
||||
yield from streamer.get_tensors()
|
||||
|
||||
|
||||
def fastsafetensors_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model safetensor files
|
||||
using fastsafetensor library."""
|
||||
if torch.distributed.is_initialized():
|
||||
pg = torch.distributed.group.WORLD
|
||||
else:
|
||||
pg = SingleGroup()
|
||||
|
||||
device = torch.device(f'cuda:{pg.rank()}')
|
||||
weight_files_sub_lists = [
|
||||
hf_weights_files[i:i + pg.size()]
|
||||
for i in range(0, len(hf_weights_files), pg.size())
|
||||
]
|
||||
|
||||
for f_list in tqdm(
|
||||
weight_files_sub_lists,
|
||||
desc="Loading safetensors using Fastsafetensor loader",
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
loader = SafeTensorsFileLoader(pg, device)
|
||||
rank_file_map = {i: [f] for i, f in enumerate(f_list)}
|
||||
loader.add_filenames(rank_file_map)
|
||||
try:
|
||||
fb = loader.copy_files_to_device()
|
||||
try:
|
||||
keys = list(fb.key_to_rank_lidx.keys())
|
||||
for k in keys:
|
||||
t = fb.get_tensor(k)
|
||||
yield k, t
|
||||
finally:
|
||||
fb.close()
|
||||
finally:
|
||||
loader.close()
|
||||
|
||||
|
||||
def pt_weights_iterator(
|
||||
hf_weights_files: list[str],
|
||||
use_tqdm_on_load: bool,
|
||||
pt_load_map_location: Union[str, dict[str, str]] = "cpu",
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""Iterate over the weights in the model bin/pt files."""
|
||||
for bin_file in tqdm(
|
||||
hf_weights_files,
|
||||
desc="Loading pt checkpoint shards",
|
||||
disable=not enable_tqdm(use_tqdm_on_load),
|
||||
bar_format=_BAR_FORMAT,
|
||||
):
|
||||
state = torch.load(bin_file,
|
||||
map_location=pt_load_map_location,
|
||||
weights_only=True)
|
||||
yield from state.items()
|
||||
del state
|
||||
|
||||
|
||||
def get_gguf_extra_tensor_names(
|
||||
gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]:
|
||||
reader = gguf.GGUFReader(gguf_file)
|
||||
expected_gguf_keys = set(gguf_to_hf_name_map.keys())
|
||||
exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
|
||||
extra_keys = expected_gguf_keys - exact_gguf_keys
|
||||
return [gguf_to_hf_name_map[key] for key in extra_keys]
|
||||
|
||||
|
||||
def gguf_quant_weights_iterator(
|
||||
gguf_file: str, gguf_to_hf_name_map: dict[str, str]
|
||||
) -> Generator[tuple[str, torch.Tensor], None, None]:
|
||||
"""
|
||||
Iterate over the quant weights in the model gguf files and convert
|
||||
them to torch tensors
|
||||
"""
|
||||
|
||||
reader = gguf.GGUFReader(gguf_file)
|
||||
|
||||
for tensor in reader.tensors:
|
||||
if tensor.name in gguf_to_hf_name_map:
|
||||
weight_type = tensor.tensor_type
|
||||
name = gguf_to_hf_name_map[tensor.name]
|
||||
|
||||
if weight_type.name != "F32":
|
||||
weight_type_name = name.replace("weight", "qweight_type")
|
||||
weight_type = torch.tensor(weight_type)
|
||||
yield weight_type_name, weight_type
|
||||
|
||||
for tensor in reader.tensors:
|
||||
if tensor.name in gguf_to_hf_name_map:
|
||||
weight = tensor.data
|
||||
weight_type = tensor.tensor_type
|
||||
name = gguf_to_hf_name_map[tensor.name]
|
||||
if weight_type.name != "F32":
|
||||
name = name.replace("weight", "qweight")
|
||||
param = torch.tensor(weight)
|
||||
yield name, param
|
||||
|
||||
|
||||
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
|
||||
"""convert PySafeSlice object from safetensors to torch.Tensor
|
||||
|
||||
PySafeSlice object supports indexing, which is done before loading the
|
||||
actual tensor and can reduce the amount of memory being read into the
|
||||
memory. However, it does not support more advanced functionalities
|
||||
like `.view()` or `.t()`. Therefore, if we need to modify the loaded
|
||||
tensor with these more complicated operators, we need to convert to
|
||||
tensor first.
|
||||
"""
|
||||
if not isinstance(x, torch.Tensor):
|
||||
x = x[:]
|
||||
return x
|
||||
|
||||
|
||||
def default_weight_loader(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
"""Default weight loader."""
|
||||
try:
|
||||
if param.numel() == 1 and loaded_weight.numel() == 1:
|
||||
# Sometimes scalar values aren't considered tensors with shapes
|
||||
# so if both param and loaded_weight are a scalar,
|
||||
# "broadcast" instead of copy
|
||||
param.data.fill_(loaded_weight.item())
|
||||
else:
|
||||
assert param.size() == loaded_weight.size(), (
|
||||
f"Attempted to load weight ({loaded_weight.size()}) "
|
||||
f"into parameter ({param.size()})")
|
||||
|
||||
param.data.copy_(loaded_weight)
|
||||
except Exception:
|
||||
# NOTE: This exception is added for the purpose of setting breakpoint to
|
||||
# debug weight loading issues.
|
||||
raise
|
||||
|
||||
|
||||
def row_parallel_weight_loader(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
"""Load weights that are row-parallelized."""
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
shard_dim = 0 if param.dim() != 1 else None
|
||||
|
||||
if shard_dim is not None:
|
||||
shard_size = param.data.shape[shard_dim]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
|
||||
|
||||
return default_weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
|
||||
|
||||
|
||||
def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
|
||||
"""Create a weight loader that shards the weights along the given axis"""
|
||||
|
||||
def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
|
||||
shard_size = param.data.shape[shard_axis]
|
||||
start_idx = tp_rank * shard_size
|
||||
loaded_weight = loaded_weight.narrow(shard_axis, start_idx, shard_size)
|
||||
|
||||
return default_weight_loader(param, loaded_weight)
|
||||
|
||||
return loader
|
||||
|
||||
|
||||
def composed_weight_loader(
|
||||
loader: LoaderFunction, fn: Callable[[torch.Tensor],
|
||||
torch.Tensor]) -> LoaderFunction:
|
||||
"""Create a weight loader that post-processes the weights after loading"""
|
||||
|
||||
def composed_loader(param: torch.Tensor,
|
||||
loaded_weight: torch.Tensor) -> None:
|
||||
loader(param, loaded_weight)
|
||||
param.data.copy_(fn(param))
|
||||
return
|
||||
|
||||
return composed_loader
|
||||
|
||||
|
||||
def initialize_dummy_weights(
|
||||
model: torch.nn.Module,
|
||||
low: float = -1e-3,
|
||||
high: float = 1e-3,
|
||||
seed: int = 1234,
|
||||
) -> None:
|
||||
"""Initialize model weights with random values.
|
||||
|
||||
The model weights must be randomly initialized for accurate performance
|
||||
measurements. Additionally, the model weights should not cause NaNs in the
|
||||
forward pass. We empirically found that initializing the weights with
|
||||
values between -1e-3 and 1e-3 works well for most models.
|
||||
|
||||
We use per-parameter random seed, so that dummy weights are consistent,
|
||||
even if the model is partitioned across multiple devices. When the seed
|
||||
is fixed, the random values generated by this function only depends on
|
||||
the parameter's number of elements and its data type.
|
||||
"""
|
||||
for param in model.state_dict().values():
|
||||
if torch.is_floating_point(param):
|
||||
if current_platform.is_tpu():
|
||||
generator = torch.Generator(device="cpu")
|
||||
generator.manual_seed(seed)
|
||||
# Note: The param.uniform_ function cannot be used in this
|
||||
# context because it demands more TPU HBM than directly copying
|
||||
# from a CPU tensor.
|
||||
# Note: We avoid using torch.rank_like as it doesn't currently
|
||||
# support the generator argument.
|
||||
param.copy_((high - low) *
|
||||
torch.rand(param.shape,
|
||||
generator=generator,
|
||||
dtype=param.dtype,
|
||||
layout=param.layout,
|
||||
requires_grad=param.requires_grad,
|
||||
device="cpu") + low)
|
||||
torch._sync(param)
|
||||
continue
|
||||
|
||||
generator = torch.Generator(device=param.data.device)
|
||||
generator.manual_seed(seed)
|
||||
if torch.finfo(param.data.dtype).bits < 16:
|
||||
# uniform_ doesn't support < 16-bit datatypes (FP8)
|
||||
dtype = param.data.dtype
|
||||
tmp_param = param.data.to(torch.float16)
|
||||
tmp_param = tmp_param.uniform_(low, high,
|
||||
generator=generator).to(dtype)
|
||||
param.data.copy_(tmp_param)
|
||||
else:
|
||||
param.uniform_(low, high, generator=generator)
|
||||
|
||||
|
||||
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
|
||||
"""Remap the name of FP8 k/v_scale parameters.
|
||||
|
||||
This function handles the remapping of FP8 k/v_scale parameter names.
|
||||
It detects if the given name ends with a suffix and attempts to remap
|
||||
it to the expected name format in the model. If the remapped name is not
|
||||
found in the params_dict, a warning is printed and None is returned.
|
||||
|
||||
Args:
|
||||
name (str): The original loaded checkpoint parameter name.
|
||||
params_dict (dict): Dictionary containing the model's named parameters.
|
||||
|
||||
Returns:
|
||||
str: The remapped parameter name if successful, or the original name
|
||||
if no remapping is needed.
|
||||
None: If the remapped name is not found in params_dict.
|
||||
"""
|
||||
if name.endswith(".kv_scale"):
|
||||
logger.warning_once(
|
||||
"DEPRECATED. Found kv_scale in the checkpoint. "
|
||||
"This format is deprecated in favor of separate k_scale and "
|
||||
"v_scale tensors and will be removed in a future release. "
|
||||
"Functionally, we will remap kv_scale to k_scale and duplicate "
|
||||
"k_scale to v_scale")
|
||||
# NOTE: we remap the deprecated kv_scale to k_scale
|
||||
remapped_name = name.replace(".kv_scale", ".attn.k_scale")
|
||||
if remapped_name not in params_dict:
|
||||
logger.warning_once(
|
||||
"Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.", # noqa: E501
|
||||
name,
|
||||
remapped_name,
|
||||
)
|
||||
return None
|
||||
return remapped_name
|
||||
|
||||
possible_scale_names = [".k_scale", ".v_scale"]
|
||||
modelopt_scale_names = [
|
||||
".self_attn.k_proj.k_scale", ".self_attn.v_proj.v_scale"
|
||||
]
|
||||
for scale_name in possible_scale_names:
|
||||
if name.endswith(scale_name):
|
||||
if any(mo_scale_name in name
|
||||
for mo_scale_name in modelopt_scale_names):
|
||||
remapped_name = name.replace(
|
||||
f".self_attn.{scale_name[1]}_proj{scale_name}",
|
||||
f".self_attn.attn{scale_name}")
|
||||
else:
|
||||
remapped_name = name.replace(scale_name, f".attn{scale_name}")
|
||||
if remapped_name not in params_dict:
|
||||
logger.warning_once(
|
||||
"Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.", # noqa: E501
|
||||
scale_name,
|
||||
name,
|
||||
remapped_name,
|
||||
scale_name,
|
||||
)
|
||||
return None
|
||||
return remapped_name
|
||||
|
||||
# If there were no matches, return the untouched param name
|
||||
return name
|
||||
Reference in New Issue
Block a user