[gpt-oss] Add gpt-oss bf16 support

This commit is contained in:
2025-08-13 21:25:57 +08:00
parent 5d2e7edf78
commit 17ea2ec6aa
1232 changed files with 777 additions and 36 deletions

View File

@@ -0,0 +1,76 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
from torch import nn
from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.bitsandbytes_loader import (
BitsAndBytesModelLoader)
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.dummy_loader import DummyModelLoader
from vllm.model_executor.model_loader.gguf_loader import GGUFModelLoader
from vllm.model_executor.model_loader.runai_streamer_loader import (
RunaiModelStreamerLoader)
from vllm.model_executor.model_loader.sharded_state_loader import (
ShardedStateLoader)
from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
from vllm.model_executor.model_loader.utils import (
get_architecture_class_name, get_model_architecture)
def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
    """Return the loader that matches ``load_config.load_format``.

    A class passed as the load format is instantiated directly with
    ``load_config``. Known ``LoadFormat`` members map to their dedicated
    loader classes, and every other value falls back to
    ``DefaultModelLoader``.
    """
    fmt = load_config.load_format

    # A class given as the load format is used as the loader itself.
    if isinstance(fmt, type):
        return fmt(load_config)

    # Formats whose loader is built from ``load_config`` alone; compared in
    # the listed order.
    simple_loaders = (
        (LoadFormat.DUMMY, DummyModelLoader),
        (LoadFormat.TENSORIZER, TensorizerLoader),
        (LoadFormat.SHARDED_STATE, ShardedStateLoader),
        (LoadFormat.BITSANDBYTES, BitsAndBytesModelLoader),
        (LoadFormat.GGUF, GGUFModelLoader),
        (LoadFormat.RUNAI_STREAMER, RunaiModelStreamerLoader),
    )
    for known_format, loader_cls in simple_loaders:
        if fmt == known_format:
            return loader_cls(load_config)

    # This format reuses ShardedStateLoader with the streamer flag set.
    if fmt == LoadFormat.RUNAI_STREAMER_SHARDED:
        return ShardedStateLoader(load_config, runai_model_streamer=True)

    return DefaultModelLoader(load_config)
def get_model(*,
              vllm_config: VllmConfig,
              model_config: Optional[ModelConfig] = None) -> nn.Module:
    """Load a model with the loader selected by ``vllm_config.load_config``.

    When ``model_config`` is omitted, ``vllm_config.model_config`` is used.
    """
    model_loader = get_model_loader(vllm_config.load_config)
    effective_model_config = (vllm_config.model_config
                              if model_config is None else model_config)
    return model_loader.load_model(vllm_config=vllm_config,
                                   model_config=effective_model_config)
# Names exported by this module: the two factory helpers defined above, the
# architecture helpers imported from `utils`, and the loader classes.
__all__ = [
    "get_model",
    "get_model_loader",
    "get_architecture_class_name",
    "get_model_architecture",
    "BaseModelLoader",
    "BitsAndBytesModelLoader",
    "GGUFModelLoader",
    "DefaultModelLoader",
    "DummyModelLoader",
    "RunaiModelStreamerLoader",
    "ShardedStateLoader",
    "TensorizerLoader",
]

View File

@@ -0,0 +1,43 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
class BaseModelLoader(ABC):
    """Abstract interface shared by the model loaders."""

    def __init__(self, load_config: LoadConfig):
        # Kept on the instance so subclasses can read loader settings.
        self.load_config = load_config

    @abstractmethod
    def download_model(self, model_config: ModelConfig) -> None:
        """Download a model so that it can be loaded right away."""
        raise NotImplementedError

    @abstractmethod
    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        """Load weights into ``model`` in place.

        Being a standalone API, it can also be used on a model that has
        already been initialized.
        """
        raise NotImplementedError

    def load_model(self, vllm_config: VllmConfig,
                   model_config: ModelConfig) -> nn.Module:
        """Build the model, load its weights and return it in eval mode."""
        device = torch.device(vllm_config.device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            # Only the module construction runs under the device context.
            with device:
                loaded = initialize_model(vllm_config=vllm_config,
                                          model_config=model_config)
            # Quantization does not happen in `load_weights` but after it
            self.load_weights(loaded, model_config)
            process_weights_after_loading(loaded, model_config, device)
        return loaded.eval()

View File

@@ -0,0 +1,570 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117
import fnmatch
import glob
import itertools
import math
import os
from collections.abc import Generator
from typing import Any, Callable, Optional
import numpy as np
import torch
from huggingface_hub import HfApi
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
# yapf: enable
from vllm.logger import init_logger
# yapf conflicts with isort for this block
# yapf: disable
from vllm.model_executor.layers.linear import (LinearBase,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (ParamMapping,
set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models import is_pooling_model
from vllm.model_executor.utils import (get_packed_modules_mapping,
set_weight_attrs)
from vllm.platforms import current_platform
logger = init_logger(__name__)
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""
possible_config_file_names = ["adapter_config.json"]
def __init__(self, load_config: LoadConfig):
    """Set up the module-name bookkeeping used while loading weights."""
    super().__init__(load_config)
    # Maps weight names from transformers to vllm; identity by default.
    self.weight_mapper: Callable = lambda name: name
    # All module names (from transformers) that support BNB quantization.
    self.target_modules: list[str] = []
    # Module names whose weights are sharded by column.
    self.column_sharded_weights_modules: list[str] = []
    # Module names whose weights are not sharded at all.
    self.unsharded_weights_modules: list[str] = []
def _get_weight_files(
self,
model_name_or_path: str,
allowed_patterns: list[str],
revision: Optional[str] = None,
) -> tuple[str, list[str], str]:
"""Retrieve weight files. Download the files if necessary.
Return the weight files and the file pattern."""
is_local = os.path.isdir(model_name_or_path)
if is_local:
for pattern in allowed_patterns:
weight_files = glob.glob(
os.path.join(model_name_or_path, pattern))
if weight_files:
return model_name_or_path, weight_files, pattern
else:
hf_api = HfApi()
repo_files = hf_api.list_repo_files(repo_id=model_name_or_path)
for pattern in allowed_patterns:
matching_files = fnmatch.filter(repo_files, pattern)
if matching_files:
hf_folder = download_weights_from_hf(
model_name_or_path,
self.load_config.download_dir,
[pattern],
revision,
ignore_patterns=self.load_config.ignore_patterns,
)
return hf_folder, glob.glob(
os.path.join(hf_folder, pattern)), pattern
raise RuntimeError(
f"No model weights found in: `{model_name_or_path}`")
def _prepare_weights(self, model_name_or_path: str,
                     revision: Optional[str]) -> tuple[list[str], bool]:
    """Return the weight files to load and whether they are safetensors.

    Raises:
        RuntimeError: If no weight file is left after filtering.
    """
    candidate_patterns = ["*.safetensors", "*.bin", "*.pt"]
    folder, weight_files, winning_pattern = self._get_weight_files(
        model_name_or_path, candidate_patterns, revision)

    use_safetensors = winning_pattern == "*.safetensors"
    if not use_safetensors:
        weight_files = filter_files_not_needed_for_inference(weight_files)
    else:
        # For models like Mistral-7B-Instruct-v0.3
        # there are both sharded safetensors files and a consolidated
        # safetensors file. Using both breaks.
        # Here, we download the `model.safetensors.index.json` and filter
        # any files not found in the index.
        if not os.path.isdir(model_name_or_path):
            download_safetensors_index_file_from_hf(
                model_name_or_path,
                SAFE_WEIGHTS_INDEX_NAME,
                self.load_config.download_dir,
                revision,
            )
        weight_files = filter_duplicate_safetensors_files(
            weight_files, folder, SAFE_WEIGHTS_INDEX_NAME)

    if len(weight_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")
    return weight_files, use_safetensors
def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
    """Yield ``(original_name, mapped_name, tensor)`` per checkpoint weight.

    ``mapped_name`` is the name produced by ``self.weight_mapper``, with a
    ``model.`` prefix added for pooling models when needed.
    """
    if not use_safetensors:
        raw_weights = pt_weights_iterator(
            hf_weights_files,
            self.load_config.use_tqdm_on_load,
            self.load_config.pt_load_map_location,
        )
    else:
        raw_weights = safetensors_weights_iterator(
            hf_weights_files,
            self.load_config.use_tqdm_on_load,
        )

    for checkpoint_name, tensor in raw_weights:
        # Map the transformers name to the vllm name while keeping the
        # original name available to the caller.
        vllm_name = self.weight_mapper(checkpoint_name)
        # For a pooling model, add the `model.` prefix when the target
        # modules carry it but this weight name does not.
        if (self.is_pool_model
                and self.target_modules[0].startswith("model.")
                and not vllm_name.startswith("model.")):
            vllm_name = "model." + vllm_name
        yield checkpoint_name, vllm_name, tensor
def _get_quantized_weights_iterator(
self,
model_name_or_path: str,
revision: Optional[str],
pre_quant: bool,
load_8bit: bool,
) -> tuple[Generator[tuple[str, torch.Tensor], None, None], dict[str,
Any]]:
"""Get an iterator to the model weights with bitsandbytes quantization,
as well as the quantization state dictionary."""
# only load the bitsandbytes module when needed
try:
import bitsandbytes
if bitsandbytes.__version__ < "0.45.3":
raise ImportError("bitsandbytes version is wrong. Please "
"install bitsandbytes>=0.45.3.")
except ImportError as err:
raise ImportError("Please install bitsandbytes>=0.45.3 via "
"`pip install bitsandbytes>=0.45.3` to use "
"bitsandbytes quantizer.") from err
hf_weights_files, use_safetensors = self._prepare_weights(
model_name_or_path, revision)
quant_state_dict: dict[str, Any] = {}
if pre_quant:
if load_8bit:
return self._quantized_8bit_generator(
hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
else:
return self._quantized_4bit_generator(
hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
return self._unquantized_generator(hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
def _is_8bit_weight_name(self, weight_name: str):
quantized_suffix = {".scb", ".weight_format"}
return any(weight_name.lower().endswith(suffix)
for suffix in quantized_suffix)
def _is_4bit_weight_name(self, weight_name: str):
quantized_suffix = {
"absmax",
"quant_map",
"nested_absmax",
"nested_quant_map",
"bitsandbytes",
}
suffix = weight_name.split(".")[-1]
return any(q_suffix in suffix for q_suffix in quantized_suffix)
def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if not mapped_weight_name.lower().endswith(".scb"):
continue
weight_key = mapped_weight_name.lower().replace(".scb", ".weight")
quant_state_dict[weight_key] = weight_tensor
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_8bit_weight_name(mapped_weight_name):
continue
if mapped_weight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield org_weight_name, weight_tensor
else:
yield org_weight_name, weight_tensor
def _quantized_4bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
from bitsandbytes.functional import QuantState
# First iterate over all quant state weights
weight_iterator = self._hf_weight_iter(hf_weights_files,
use_safetensors)
temp_state_dict = {}
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in weight_iterator:
if not self._is_4bit_weight_name(mapped_weight_name):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
if "quant_state.bitsandbytes" in mapped_weight_name:
temp_state_dict[mapped_weight_name] = weight_tensor.cpu().data
else:
temp_state_dict[mapped_weight_name] = weight_tensor
# Closure to parse quant_state for each prequant weight
def _parse_quant_state(param_name: str,
temp_state_dict: dict) -> QuantState:
quant_state = {}
for k in temp_state_dict:
if param_name + "." in k:
quant_state[k] = temp_state_dict[k]
return QuantState.from_dict(quant_state,
device=current_platform.device_type)
# Second iterate over all prequant and normal weights
# pre quantized weights would have a quant_state
for (
org_weight_name,
mapped_weight_name,
weight_tensor,
) in self._hf_weight_iter(hf_weights_files, use_safetensors):
if self._is_4bit_weight_name(mapped_weight_name):
continue
if (f"{mapped_weight_name}.quant_state.bitsandbytes__nf4"
in temp_state_dict) or (
f"{mapped_weight_name}.quant_state.bitsandbytes__fp4"
in temp_state_dict):
quant_state = _parse_quant_state(mapped_weight_name,
temp_state_dict)
quant_state_dict[mapped_weight_name] = quant_state
yield org_weight_name, weight_tensor
else:
yield org_weight_name, weight_tensor
def _unquantized_generator(self, hf_weights_files, use_safetensors,
                           quant_state_dict) -> Generator:
    """Yield ``(original_name, tensor)``, quantizing target weights on the fly.

    A weight is quantized when its mapped name contains one of
    ``self.target_modules`` and ends with ``.weight``. Such a weight is
    first sliced for the current tensor-parallel rank, moved to CUDA if it
    is not already there, and quantized with ``quantize_4bit`` (nf4). Its
    quant state is stored in ``quant_state_dict`` under the mapped name.
    All other weights are yielded unchanged.
    """
    from bitsandbytes.functional import quantize_4bit
    tp_size = get_tensor_model_parallel_world_size()
    tp_rank = get_tensor_model_parallel_rank()
    for (
            org_weight_name,
            mapped_weight_name,
            weight_tensor,
    ) in self._hf_weight_iter(hf_weights_files, use_safetensors):
        if any(target_module in mapped_weight_name
               for target_module in self.target_modules
               ) and mapped_weight_name.endswith(".weight"):
            # Without sharding
            if any(
                    mapped_weight_name.startswith(module)
                    for module in self.unsharded_weights_modules):
                weight_sub_tensor = weight_tensor
            # Shard by column
            elif any(
                    mapped_weight_name.startswith(module)
                    for module in self.column_sharded_weights_modules):
                # Each rank takes an equal slice of the last dimension.
                total_size = weight_tensor.size(-1)
                start_index = total_size // tp_size * tp_rank
                end_index = total_size // tp_size * (tp_rank + 1)
                weight_sub_tensor = weight_tensor[...,
                                                  start_index:end_index]
            # Weights have fused on disk. In this case, we assume that the
            # weight and module use same name.
            elif any(
                    mapped_weight_name.startswith(module)
                    for module in self.maybe_fused_weights_modules):
                # special case for fused weights
                # get the size of each shard weight tensor
                total_shard_sizes = next(
                    (sizes for module, sizes in
                     self.maybe_fused_weights_modules.items()
                     if mapped_weight_name.startswith(module)))
                total_size = weight_tensor.size(0)
                assert total_size == sum(total_shard_sizes)
                # get the start/end index of each shard weight tensor
                total_start_index = list(
                    itertools.accumulate([0] + total_shard_sizes))[:-1]
                shard_weights_index = [(
                    idx + size // tp_size * tp_rank,
                    idx + size // tp_size * (tp_rank + 1),
                ) for idx, size in zip(total_start_index,
                                       total_shard_sizes)]
                # slice and reorder the weight tensor
                weight_tensor = [
                    weight_tensor[start_index:end_index, ...]
                    for start_index, end_index in shard_weights_index
                ]
                weight_sub_tensor = torch.cat(weight_tensor, dim=0)
            # Shard by row
            else:
                # Each rank takes an equal slice of the first dimension.
                total_size = weight_tensor.size(0)
                start_index = total_size // tp_size * tp_rank
                end_index = total_size // tp_size * (tp_rank + 1)
                weight_sub_tensor = weight_tensor[start_index:end_index,
                                                  ...]
            # bitsandbytes requires data in GPU
            if weight_sub_tensor.is_cuda:
                loaded_weight = weight_sub_tensor
            else:
                loaded_weight = weight_sub_tensor.cuda()
            # remove the following after the issue is fixed:
            # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
            if loaded_weight.is_contiguous() is False:
                loaded_weight = loaded_weight.contiguous()
            # Quantize with float32 as the default torch dtype.
            with set_default_torch_dtype(torch.float32):
                processed_weight, quant_state = quantize_4bit(
                    loaded_weight,
                    compress_statistics=True,
                    quant_type="nf4",
                )
            quant_state_dict[mapped_weight_name] = quant_state
        else:
            # Not a quantization target: pass the tensor through as is.
            processed_weight = weight_tensor
        yield org_weight_name, processed_weight
def _get_bnb_target_modules(self, model: nn.Module) -> None:
for name, module in model.named_modules():
if (isinstance(module, LinearBase) and
hasattr(module.quant_method, "quant_config")):
if modules_info := self.modules_mapping.get_sub_modules(name):
# Map vllm's names to transformers's names.
rep_name, sub_modules = modules_info
for sub_name in sub_modules:
self.target_modules.append(
name.replace(rep_name, sub_name))
# Add original module name even if the module has stacked map,
# in case model has a mixture of disk-merged and disk-splitted
# weights with same last name.
self.target_modules.append(name)
assert (self.target_modules
), "vllm currently does not support BNB quantization for"
f" {type(model).__name__}"
def load_weights(self, model: nn.Module, model_config: ModelConfig) -> None:
    """Load checkpoint weights into ``model`` with BitsAndBytes quantization.

    The method inspects ``model`` to learn how each linear module is
    sharded, streams the weights through ``model.load_weights`` and then
    attaches the collected quantization states to the matching parameters
    as the ``bnb_quant_state`` and ``bnb_shard_offsets`` attributes (plus
    ``matmul_state`` for 8-bit checkpoints).

    Raises:
        AttributeError: If ``model`` lacks ``load_weights`` or
            ``packed_modules_mapping``.
        ValueError: If the checkpoint declares a non-bitsandbytes
            quantization method, if a pre-quantized checkpoint is combined
            with tensor parallelism, if tracked weights were not
            initialized, or if a quantized parameter has no ``pack_factor``.
    """
    if not hasattr(model, "load_weights"):
        raise AttributeError(
            "The required method 'load_weights' is not defined in class"
            f" {type(model).__name__}.")
    if not hasattr(model, "packed_modules_mapping"):
        raise AttributeError(
            f"Model {type(model).__name__} does not support BitsAndBytes "
            "quantization yet. No 'packed_modules_mapping' found.")
    # State read later by `_hf_weight_iter` and `_get_bnb_target_modules`.
    self.is_pool_model=is_pooling_model(model)
    self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
    # For some models like Molmo, we need to use hf_to_vllm_mapper
    # to ensure correct loading of weights.
    if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
        self.weight_mapper = lambda name: hf_to_vllm_mapper._map_name(name)
    # Modules whose weights might have fused on disk
    # we need their output_sizes to make shard in flight correctly with TP
    self.maybe_fused_weights_modules: dict[str, list[int]] = {}
    self._get_bnb_target_modules(model)
    for name, module in model.named_modules():
        # Some modules like `ReplicatedLinear` should not have their weights
        # sharded. The reason for implementing it this way is to avoid new
        # static variable in the model implementation.
        if isinstance(module, (ReplicatedLinear, )):
            self.unsharded_weights_modules.append(name)
        # `QKVParallelLinear` and `MergedColumnParallelLinear` might have
        # fused weights on disk. We need to use the output sizes of these
        # modules to shard the weights correctly.
        elif isinstance(module,
                        (QKVParallelLinear, MergedColumnParallelLinear)):
            self.maybe_fused_weights_modules[name] = module.output_sizes
        # In TP, these weights are partitioned along the column
        # dimension (dim=-1)
        elif isinstance(module, (RowParallelLinear, )):
            self.column_sharded_weights_modules.append(name)
    self.model_type = type(model).__name__
    logger.info("Loading weights with BitsAndBytes quantization. "
                "May take a while ...")
    # A `quantization_config` on the HF config marks a pre-quantized
    # checkpoint; only the bitsandbytes method is accepted here.
    quant_config = getattr(model_config.hf_config, "quantization_config",
                           None)
    pre_quant = False
    if quant_config is not None:
        quant_method = quant_config.get("quant_method")
        if quant_method == "bitsandbytes":
            pre_quant = True
        else:
            raise ValueError(
                f"BitsAndBytes loader does not support {quant_method} "
                "quantization")
    # The quant_states in pre_quantized models cannot work with a split
    # weight tensor. So TP does not work with pre_quantized bnb models.
    if pre_quant and get_tensor_model_parallel_world_size() > 1:
        raise ValueError(
            "Prequant BitsAndBytes models with tensor parallelism is not "
            "supported. Please try with pipeline parallelism.")
    load_8bit = False
    if pre_quant:
        load_8bit = quant_config.get("load_in_8bit", False)
    qweight_iterator, quant_state_dict = (
        self._get_quantized_weights_iterator(model_config.model,
                                             model_config.revision,
                                             pre_quant, load_8bit))
    # Snapshot the parameter names before loading so that missing weights
    # can be reported afterwards.
    weights_to_load = {name for name, _ in model.named_parameters()}
    loaded_weights = model.load_weights(qweight_iterator)
    # Some models may have weights loading tracker unimplemented.
    if loaded_weights is not None:
        weights_not_loaded = weights_to_load - loaded_weights
        if weights_not_loaded:
            raise ValueError("Following weights were not initialized from "
                             f"checkpoint: {weights_not_loaded}")
    torch.cuda.empty_cache()
    param_dict = dict(model.named_parameters())
    # Maps a parameter name to {shard index: quant state}.
    stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
    # TODO: Change this lazy import to normal import
    # after the checks are updated to run on a new version
    from vllm.model_executor.models.utils import is_pp_missing_parameter
    for quant_param_name in quant_state_dict:
        if is_pp_missing_parameter(quant_param_name, model):
            continue
        # Keep the original key: `quant_param_name` may be renamed below.
        non_stacked_param_name = quant_param_name
        shard_index = 0
        for shard_name, (
                weight_name,
                index,
        ) in self.modules_mapping.inverse_packed_mapping.items():
            # Some models, such as MiniCPM V2.5/2.6, contain both
            # module names 'kv_proj' and 'qkv_proj'. To prevent 'kv_proj'
            # from being incorrectly identified as being present in
            # 'vpm.encoder.layers.0.self_attn.qkv_proj.weight
            shard_pos = quant_param_name.find(shard_name)
            can_correct_rename = (shard_pos
                                  > 0) and (quant_param_name[shard_pos - 1]
                                            == ".")
            # If the quant_param_name is packed, it won't occur in the
            # param_dict before renaming.
            new_quant_param_name = quant_param_name.replace(
                shard_name, weight_name)
            need_rename = (quant_param_name not in param_dict) \
                          and (new_quant_param_name in param_dict)
            if can_correct_rename and need_rename:
                shard_index = index
                quant_param_name = new_quant_param_name
                break
        # Models like Clip/Siglip may skip some layers in initialization,
        # causing unused quant_param_name in state_dict.
        if quant_param_name not in param_dict:
            continue
        if quant_param_name not in stacked_quant_state_dict:
            stacked_quant_state_dict[quant_param_name] = {}
        stacked_quant_state_dict[quant_param_name][shard_index] = (
            quant_state_dict[non_stacked_param_name])
    # save quant_states and offsets as the attributes of the parameters
    for param_name, param in param_dict.items():
        if param_name in stacked_quant_state_dict:
            quant_states = stacked_quant_state_dict[param_name]
            set_weight_attrs(param, {"bnb_quant_state": quant_states})
            pack_ratio = getattr(param, "pack_factor", -1)
            if pack_ratio == -1:
                raise ValueError(
                    f"pack_factor not set for parameter {param_name}.")
            # Packed element count of each shard, indexed by shard index.
            num_elements = [0] * len(quant_states)
            for seq, quant_state in quant_states.items():
                num_elements[seq] = (math.prod(quant_state.shape) //
                                     pack_ratio)
            # Cumulative offsets: entry i is where shard i starts.
            offsets = np.concatenate(([0], np.cumsum(num_elements)))
            # Make torch infer_schema happy
            offsets = torch.tensor(offsets).cpu()
            set_weight_attrs(param, {"bnb_shard_offsets": offsets})
            if load_8bit:
                set_weight_attrs(
                    param, {"matmul_state": [None] * len(quant_states)})
def download_model(self, model_config: ModelConfig) -> None:
    """Resolve the weight files of ``model_config``, downloading them if
    they are not available locally."""
    name_or_path = model_config.model
    revision = model_config.revision
    self._prepare_weights(name_or_path, revision)

View File

@@ -0,0 +1,282 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import dataclasses
import glob
import os
import time
from collections.abc import Generator, Iterable
from typing import Optional, cast
import huggingface_hub
import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm import envs
from vllm.config import LoadConfig, LoadFormat, ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.platforms import current_platform
logger = init_logger(__name__)
class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk."""
@dataclasses.dataclass
class Source:
    """A source for weights.

    Describes one checkpoint location. ``_get_weights_iterator`` reads
    these fields to locate the weight files and to prefix the names of
    the weights it yields.
    """

    model_or_path: str
    """The model ID or path."""

    revision: Optional[str]
    """The optional model revision."""

    prefix: str = ""
    """A prefix to prepend to all weights."""

    fall_back_to_pt: bool = True
    """Whether .pt weights can be used."""

    allow_patterns_overrides: Optional[list[str]] = None
    """If defined, weights will load exclusively using these patterns."""
counter_before_loading_weights: float = 0.0
counter_after_loading_weights: float = 0.0
def __init__(self, load_config: LoadConfig):
    """Store ``load_config``; this loader accepts no extra loader config.

    Raises:
        ValueError: If ``load_config.model_loader_extra_config`` is set.
    """
    super().__init__(load_config)
    extra_config = load_config.model_loader_extra_config
    if not extra_config:
        return
    raise ValueError(f"Model loader extra config is not supported for "
                     f"load format {load_config.load_format}")
def _maybe_download_from_modelscope(
        self, model: str, revision: Optional[str]) -> Optional[str]:
    """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

    Returns the path to the model (``model`` itself when that path already
    exists), or None when ModelScope is not in use.
    """
    if not envs.VLLM_USE_MODELSCOPE:
        return None

    # download model from ModelScope hub,
    # lazy import so that modelscope is not required for normal use.
    # pylint: disable=C.
    from modelscope.hub.snapshot_download import snapshot_download

    if os.path.exists(model):
        # The path already exists, so nothing is downloaded.
        return model

    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model, self.load_config.download_dir):
        downloaded_path = snapshot_download(
            model_id=model,
            cache_dir=self.load_config.download_dir,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            revision=revision,
            ignore_file_pattern=self.load_config.ignore_patterns,
        )
    return downloaded_path
def _prepare_weights(
    self,
    model_name_or_path: str,
    revision: Optional[str],
    fall_back_to_pt: bool,
    allow_patterns_overrides: Optional[list[str]],
) -> tuple[str, list[str], bool]:
    """Prepare weights for the model.

    If the model is not local, it will be downloaded.

    Returns:
        ``(folder, weight_files, use_safetensors)``.

    Raises:
        ValueError: If the configured load format is unknown.
        RuntimeError: If no weight file is found.
    """
    modelscope_path = self._maybe_download_from_modelscope(
        model_name_or_path, revision)
    model_name_or_path = modelscope_path or model_name_or_path

    is_local = os.path.isdir(model_name_or_path)
    load_format = self.load_config.load_format
    index_file = SAFE_WEIGHTS_INDEX_NAME
    use_safetensors = False

    # Some quantized models use .pt files for storing the weights.
    if load_format == LoadFormat.AUTO:
        allow_patterns = ["*.safetensors", "*.bin"]
    elif (load_format == LoadFormat.SAFETENSORS
          or load_format == LoadFormat.FASTSAFETENSORS):
        allow_patterns = ["*.safetensors"]
        use_safetensors = True
    elif load_format == LoadFormat.MISTRAL:
        allow_patterns = ["consolidated*.safetensors"]
        index_file = "consolidated.safetensors.index.json"
        use_safetensors = True
    elif load_format == LoadFormat.PT:
        allow_patterns = ["*.pt"]
    elif load_format == LoadFormat.NPCACHE:
        allow_patterns = ["*.bin"]
    else:
        raise ValueError(f"Unknown load_format: {load_format}")

    if fall_back_to_pt:
        allow_patterns = allow_patterns + ["*.pt"]

    # Explicit overrides replace the patterns derived from the format.
    if allow_patterns_overrides is not None:
        allow_patterns = allow_patterns_overrides

    if is_local:
        hf_folder = model_name_or_path
    else:
        hf_folder = download_weights_from_hf(
            model_name_or_path,
            self.load_config.download_dir,
            allow_patterns,
            revision,
            ignore_patterns=self.load_config.ignore_patterns,
        )

    # Stop at the first pattern that matches at least one file.
    hf_weights_files: list[str] = []
    for pattern in allow_patterns:
        hf_weights_files.extend(glob.glob(os.path.join(hf_folder, pattern)))
        if hf_weights_files:
            if pattern == "*.safetensors":
                use_safetensors = True
            break

    if not use_safetensors:
        hf_weights_files = filter_files_not_needed_for_inference(
            hf_weights_files)
    else:
        # For models like Mistral-7B-Instruct-v0.3
        # there are both sharded safetensors files and a consolidated
        # safetensors file. Using both breaks.
        # Here, we download the `model.safetensors.index.json` and filter
        # any files not found in the index.
        if not is_local:
            download_safetensors_index_file_from_hf(
                model_name_or_path,
                index_file,
                self.load_config.download_dir,
                revision,
            )
        hf_weights_files = filter_duplicate_safetensors_files(
            hf_weights_files, hf_folder, index_file)

    if len(hf_weights_files) == 0:
        raise RuntimeError(
            f"Cannot find any model weights with `{model_name_or_path}`")

    return hf_folder, hf_weights_files, use_safetensors
def _get_weights_iterator(
        self, source: "Source"
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Get an iterator for the model weights based on the load format.

    Every yielded name is prefixed with ``source.prefix``.
    """
    hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
        source.model_or_path, source.revision, source.fall_back_to_pt,
        source.allow_patterns_overrides)

    load_format = self.load_config.load_format
    if load_format == LoadFormat.NPCACHE:
        # Currently np_cache only support *.bin checkpoints
        assert use_safetensors is False
        weights = np_cache_weights_iterator(
            source.model_or_path,
            self.load_config.download_dir,
            hf_folder,
            hf_weights_files,
            self.load_config.use_tqdm_on_load,
        )
    elif not use_safetensors:
        weights = pt_weights_iterator(
            hf_weights_files,
            self.load_config.use_tqdm_on_load,
            self.load_config.pt_load_map_location,
        )
    elif load_format == LoadFormat.FASTSAFETENSORS:
        weights = fastsafetensors_weights_iterator(
            hf_weights_files,
            self.load_config.use_tqdm_on_load,
        )
    else:
        weights = safetensors_weights_iterator(
            hf_weights_files,
            self.load_config.use_tqdm_on_load,
        )

    def _step_after_each(iterator: Generator, mark_step):
        # Yield every item and call `mark_step` once the consumer resumes.
        for item in iterator:
            yield item
            mark_step()

    if current_platform.is_tpu():
        # In PyTorch XLA, we should call `xm.mark_step` frequently so that
        # not too many ops are accumulated in the XLA program.
        import torch_xla.core.xla_model as xm

        weights = _step_after_each(weights, xm.mark_step)
    elif current_platform.is_hpu():
        import habana_frameworks.torch.core as htcore

        weights = _step_after_each(weights, htcore.mark_step)

    # Record the start time only the first time an iterator is requested.
    if self.counter_before_loading_weights == 0.0:
        self.counter_before_loading_weights = time.perf_counter()

    # Apply the prefix.
    return ((source.prefix + name, tensor) for (name, tensor) in weights)
def get_all_weights(
    self,
    model_config: ModelConfig,
    model: nn.Module,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield the weights of the primary checkpoint, then those of every
    source listed in ``model.secondary_weights`` (if that attribute exists).
    """
    # Optional per-model overrides, read from attributes of `model`.
    pt_fallback = getattr(model, "fall_back_to_pt_during_load", True)
    pattern_overrides = getattr(model, "allow_patterns_overrides", None)

    primary = DefaultModelLoader.Source(
        model_config.model,
        model_config.revision,
        prefix="",
        fall_back_to_pt=pt_fallback,
        allow_patterns_overrides=pattern_overrides,
    )
    yield from self._get_weights_iterator(primary)

    extra_sources = cast(
        Iterable[DefaultModelLoader.Source],
        getattr(model, "secondary_weights", ()),
    )
    for extra_source in extra_sources:
        yield from self._get_weights_iterator(extra_source)
def download_model(self, model_config: ModelConfig) -> None:
    """Prepare (and download if needed) the weights of ``model_config``,
    allowing the ``.pt`` fallback and no pattern overrides."""
    model_ref, revision = model_config.model, model_config.revision
    self._prepare_weights(model_ref,
                          revision,
                          fall_back_to_pt=True,
                          allow_patterns_overrides=None)
def load_weights(self, model: nn.Module,
                 model_config: ModelConfig) -> None:
    """Load all weights into ``model`` and log how long it took.

    Raises:
        ValueError: For a non-quantized model that reports its loaded
            weights, if some parameters were not initialized.
    """
    expected = {param_name for param_name, _ in model.named_parameters()}
    loaded = model.load_weights(self.get_all_weights(model_config, model))

    self.counter_after_loading_weights = time.perf_counter()
    elapsed = (self.counter_after_loading_weights -
               self.counter_before_loading_weights)
    logger.info("Loading weights took %.2f seconds", elapsed)

    # We only enable strict check for non-quantized models
    # that have loaded weights tracking currently.
    if model_config.quantization is not None or loaded is None:
        return
    missing = expected - loaded
    if missing:
        raise ValueError("Following weights were not initialized from "
                         f"checkpoint: {missing}")

View File

@@ -0,0 +1,27 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch.nn as nn
from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
initialize_dummy_weights)
class DummyModelLoader(BaseModelLoader):
    """Model loader that will set model weights to random values."""

    def __init__(self, load_config: LoadConfig):
        super().__init__(load_config)
        # This loader takes no extra configuration.
        if not load_config.model_loader_extra_config:
            return
        raise ValueError(f"Model loader extra config is not supported for "
                         f"load format {load_config.load_format}")

    def download_model(self, model_config: ModelConfig) -> None:
        """Do nothing: there is nothing to download."""
        return None

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        """Fill ``model`` with dummy weights; ``model_config`` is unused."""
        # NOTE(woosuk): For accurate performance evaluation, we assign
        # random values to the weights.
        initialize_dummy_weights(model)

View File

@@ -0,0 +1,120 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
from collections.abc import Generator
import gguf
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM
from vllm.config import LoadConfig, ModelConfig, VllmConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
from vllm.model_executor.model_loader.weight_utils import (
get_gguf_extra_tensor_names, gguf_quant_weights_iterator)
class GGUFModelLoader(BaseModelLoader):
    """
    Model loader for checkpoints stored as a local GGUF file.

    GGUF tensor names are translated to their Hugging Face parameter names
    before the tensors are handed to ``model.load_weights``.
    """

    def __init__(self, load_config: LoadConfig):
        """Store ``load_config``; extra loader config is rejected.

        Raises:
            ValueError: If ``load_config.model_loader_extra_config`` is set.
        """
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if extra_config:
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _prepare_weights(self, model_name_or_path: str):
        """Return ``model_name_or_path`` unchanged if it is an existing file.

        Raises:
            ValueError: If the path does not point to a file.
        """
        if not os.path.isfile(model_name_or_path):
            raise ValueError(f"{model_name_or_path} is not a file.")
        return model_name_or_path

    def _get_gguf_weights_map(self, model_config: ModelConfig):
        """
        Build the dict that maps GGUF tensor names to HF parameter names.

        GGUF names tensors that come from an HF checkpoint as
        `blk.N.BB.weight` and `blk.N.BB.bias`, where N is the block number
        of a layer and BB identifies the attention/mlp component.
        See "Standardized tensor names" in
        https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.

        Raises:
            RuntimeError: If the (possibly renamed) model type is not a value
                of ``gguf.MODEL_ARCH_NAMES``.
        """
        hf_config = model_config.hf_config
        gguf_model_type = hf_config.model_type
        gguf_to_hf_name_map = {}

        # hack: ggufs have a different name than transformers
        if gguf_model_type == "cohere":
            gguf_model_type = "command-r"
        if gguf_model_type in ("deepseek_v3", "deepseek_v2"):
            gguf_model_type = "deepseek2"
            # GGUF layer map assumes that we will have a merged expert weights
            # so we need to map them manually
            merged_expert_names = (
                ("exp_probs_b.bias", "gate.e_score_correction_bias"),
                ("ffn_down_exps.weight", "experts.0.down_proj.weight"),
                ("ffn_gate_exps.weight", "experts.0.gate_proj.weight"),
                ("ffn_up_exps.weight", "experts.0.up_proj.weight"),
            )
            for idx in range(hf_config.num_hidden_layers):
                for gguf_suffix, hf_suffix in merged_expert_names:
                    gguf_to_hf_name_map[f"blk.{idx}.{gguf_suffix}"] = (
                        f"model.layers.{idx}.mlp.{hf_suffix}")

        # First architecture id whose GGUF name equals the model type.
        arch = next(
            (arch_id for arch_id, arch_name in gguf.MODEL_ARCH_NAMES.items()
             if arch_name == gguf_model_type),
            None,
        )
        if arch is None:
            raise RuntimeError(f"Unknown gguf model_type: {gguf_model_type}")

        name_map = gguf.get_tensor_name_map(arch, hf_config.num_hidden_layers)

        # Instantiate on the meta device: only the parameter names are used.
        with torch.device("meta"):
            dummy_model = AutoModelForCausalLM.from_config(
                hf_config, trust_remote_code=model_config.trust_remote_code)

        for hf_name in dummy_model.state_dict():
            module_name, suffix = hf_name.rsplit(".", 1)
            gguf_name = name_map.get_name(module_name)
            gguf_to_hf_name_map[f"{gguf_name}.{suffix}"] = hf_name
        return gguf_to_hf_name_map

    def _get_weights_iterator(
        self, model_name_or_path: str, gguf_to_hf_name_map: dict[str, str]
    ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Return ``gguf_quant_weights_iterator`` for the file and name map."""
        return gguf_quant_weights_iterator(model_name_or_path,
                                           gguf_to_hf_name_map)

    def download_model(self, model_config: ModelConfig) -> None:
        """Check that ``model_config.model`` is an existing file."""
        self._prepare_weights(model_config.model)

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        """Load the GGUF tensors of ``model_config.model`` into ``model``."""
        gguf_path = self._prepare_weights(model_config.model)
        name_map = self._get_gguf_weights_map(model_config)
        weights = self._get_weights_iterator(gguf_path, name_map)
        model.load_weights(weights)

    def load_model(self, vllm_config: VllmConfig,
                   model_config: ModelConfig) -> nn.Module:
        """Create the model on the target device and load the GGUF weights.

        Returns:
            The initialised model after ``process_weights_after_loading``.
        """
        device_config = vllm_config.device_config
        gguf_path = self._prepare_weights(model_config.model)
        name_map = self._get_gguf_weights_map(model_config)

        # we can only know if tie word embeddings after mapping weights
        extra_tensor_names = get_gguf_extra_tensor_names(gguf_path, name_map)
        if "lm_head.weight" in extra_tensor_names:
            model_config.hf_config.update({"tie_word_embeddings": True})

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config)
            self.load_weights(model, model_config)
            process_weights_after_loading(model, model_config, target_device)
        return model

View File

@@ -0,0 +1,476 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading Neuron models in transformers-neuronx
framework."""
import ast
import copy
import importlib
import os
from typing import Optional
import torch
import torch.nn as nn
from transformers import PretrainedConfig
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import get_quantization_config
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput)
# Lookup from a dtype spelling (string name or torch dtype) to the short
# precision string passed as ``amp=`` when loading a Neuron model below.
# Note that "auto" resolves to "f32".
TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "f32",
"half": "f16",
"float16": "f16",
"bfloat16": "bf16",
"float": "f32",
"float32": "f32",
torch.float16: "f16",
torch.bfloat16: "bf16",
torch.float32: "f32",
}
# Models supported by Neuron.
# Key: architecture name as listed in the HF config. Value: a 3-tuple of
# (module path to import, model class name inside that module, HF model class
# name), unpacked in that order by NeuronCausalLM.load_weights.
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str, str]] = {
"LlamaForCausalLM": ("transformers_neuronx.llama.model",
"LlamaForSampling", "LlamaForCausalLM"),
"MistralForCausalLM": ("transformers_neuronx.mistral.model",
"MistralForSampling", "MistralForCausalLM")
}
class NeuronCausalLM(nn.Module):
    """Adapter around a transformers-neuronx model.

    The wrapped model is created lazily by ``load_weights``; until then
    ``self.model`` is only declared, not assigned.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 on_device_sampling_disabled: bool = False) -> None:
        """Set up the logits processor and, if needed, a host-side sampler.

        Args:
            config: HF config; ``config.vocab_size`` sizes the logits
                processor.
            on_device_sampling_disabled: When true a default ``Sampler`` is
                created and used by ``sample``.
        """
        super().__init__()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)
        self.on_device_sampling_disabled = on_device_sampling_disabled
        if on_device_sampling_disabled:
            # Use default sampler
            self.sampler = Sampler()
        # Lazy initialized
        self.model: nn.Module

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run the wrapped model and return its output unchanged."""
        return self.model(input_ids,
                          cache_ids=positions,
                          start_ids=input_block_ids)

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Pass ``hidden_states`` through the logits processor."""
        return self.logits_processor(None, hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Produce a ``SamplerOutput`` for the current step.

        With on-device sampling disabled the default sampler is used.
        Otherwise ``logits`` is read as already-sampled token ids, consumed
        in order, one per sequence id of each sequence group.
        """
        if self.on_device_sampling_disabled:
            return self.sampler(logits, sampling_metadata)

        # On-device sampling outputs the token ids directly.
        flat_token_ids = logits.flatten()
        cursor = 0
        group_outputs = []
        for group in sampling_metadata.seq_groups:
            group_samples = []
            for seq_id in group.seq_ids:
                token_id = flat_token_ids[cursor].item()
                cursor += 1
                group_samples.append(
                    SequenceOutput(parent_seq_id=seq_id,
                                   output_token=token_id,
                                   logprobs={token_id: Logprob(token_id)}))
            group_outputs.append(
                CompletionSequenceGroupOutput(samples=group_samples,
                                              prompt_logprobs=None))
        return SamplerOutput(outputs=group_outputs)

    def load_weights(self, model_name_or_path: str, **kwargs):
        """Instantiate the neuronx model for this architecture.

        The model class is looked up in ``_NEURON_SUPPORTED_MODELS``, built
        with ``from_pretrained(model_name_or_path, **kwargs)`` and then
        ``to_neuron()`` is called on it.
        """
        arch = _get_model_architecture(self.config)
        module_path, model_cls_name, _hf_cls_name = (
            _NEURON_SUPPORTED_MODELS[arch])
        model_cls = getattr(importlib.import_module(module_path),
                            model_cls_name)
        self.model = model_cls.from_pretrained(model_name_or_path, **kwargs)
        self.model.to_neuron()
class NeuronSpeculationCausalLM(nn.Module):
    """A Neuron-optimized causal language model with speculative decoding."""

    # Sentinel written into token slots that hold no accepted token.
    SPECULATION_TERMINATION_ID = -1

    def __init__(self, speculation_model) -> None:
        """Wrap an already constructed speculative decoder."""
        super().__init__()
        self.model = speculation_model

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Run one speculative iteration and mask the unused token slots.

        Returns:
            The ``tokens`` tensor from ``speculative_iteration`` where every
            slot whose step index is ``>= counts`` has been overwritten
            in place with ``SPECULATION_TERMINATION_ID``.
        """
        tokens, counts = self.model.speculative_iteration(
            input_ids, positions, input_block_ids)

        # Mark the end of accepted speculative tokens for each sequence with
        # the speculation termination id.
        # NOTE(review): the comparison relies on ``counts`` broadcasting
        # against a (batch, steps) grid -- confirm its shape with the decoder.
        num_rows, num_steps = tokens.shape
        step_grid = torch.arange(num_steps).expand(num_rows, -1)
        tokens[step_grid >= counts] = self.SPECULATION_TERMINATION_ID
        return tokens

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[list[SamplerOutput]]:
        """Convert a (batch, steps) token tensor into per-step outputs.

        Iteration stops at the first step in which every sequence holds the
        termination id; that step and all later ones are not emitted.
        """
        batch_size, num_steps = logits.shape
        seq_ids = [
            seq_id for group in sampling_metadata.seq_groups
            for seq_id in group.seq_ids
        ]

        # Organize input tensors by step instead of by sequence.
        tokens_by_step = logits.transpose(0, 1).tolist()

        per_step_outputs = []
        for step in range(num_steps):
            step_tokens = tokens_by_step[step]
            if all(token == self.SPECULATION_TERMINATION_ID
                   for token in step_tokens):
                break
            step_groups = []
            for row in range(batch_size):
                token_id = step_tokens[row]
                sample = SequenceOutput(parent_seq_id=seq_ids[row],
                                        output_token=token_id,
                                        logprobs={token_id: Logprob(token_id)})
                step_groups.append(
                    CompletionSequenceGroupOutput(samples=[sample],
                                                  prompt_logprobs=None))
            per_step_outputs.append(SamplerOutput(outputs=step_groups))
        return per_step_outputs
def _get_model_architecture(config: PretrainedConfig) -> str:
    """Return the first architecture of ``config`` that Neuron supports.

    Args:
        config: HF config whose ``architectures`` attribute is inspected.

    Returns:
        The first entry of ``config.architectures`` that is a key of
        ``_NEURON_SUPPORTED_MODELS``.

    Raises:
        ValueError: If no listed architecture is supported, including when
            ``architectures`` is missing, ``None`` or empty.
    """
    # The attribute can exist with value None, in which case the getattr
    # default is not used; normalise to an empty list so the lookup below
    # ends in the descriptive ValueError instead of a TypeError.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _NEURON_SUPPORTED_MODELS:
            return arch
    raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def _get_buckets(env: str, default_value: list[int]) -> list[int]:
    """Read a comma-separated list of integers from an environment variable.

    Args:
        env: Name of the environment variable.
        default_value: Returned as-is when the variable is not set.

    Returns:
        The parsed integers, with empty or whitespace-only items skipped.

    Raises:
        ValueError: If a non-empty item cannot be converted by ``int``.
    """
    raw_value = os.getenv(env)
    if raw_value is None:
        return default_value
    return [int(item) for item in raw_value.split(",") if item.strip()]
def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig):
    """Generate a neuron config based on vllm config args.

    Returns a dict of default neuron arguments. ``parallel_config`` is
    accepted but not read here.
    """
    from transformers_neuronx.config import ContinuousBatchingConfig
    from transformers_neuronx.constants import LAYOUT_BSH

    batching_config = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)

    # Built unconditionally, so an unknown dtype fails here even when the
    # model is not quantized.
    quant_settings = {
        "dequant_dtype": TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        "quantize_method": "vector_dynamic",
    }

    def build_quant_method(quant_name):
        quant_config_cls = get_quantization_config(quant_name)
        return quant_config_cls.from_config(quant_settings).get_quant_method(
            None, "")

    if model_config.quantization:
        quant = build_quant_method(model_config.quantization)
    else:
        quant = None

    # TODO: Add Paged attention config to the default neuron arguments.
    return {
        "collectives_layout": LAYOUT_BSH,
        "attention_layout": LAYOUT_BSH,
        "fuse_qkv": True,
        "quant": quant,
        "continuous_batching": batching_config,
        "weight_tiling": bool(model_config.quantization),
        "on_device_generation":
        _get_neuron_on_device_generation_config(model_config),
    }
def _get_default_neuron_config_for_speculation(
        model_config: ModelConfig, parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig):
    """Generate a neuron config for speculative decoding based on
    vllm config args.

    ``on_device_generation`` is a deep copy of
    ``model_config.neuron_sampling_params``. ``parallel_config`` is accepted
    but not read here.
    """
    from transformers_neuronx.config import ContinuousBatchingConfig
    from transformers_neuronx.constants import LAYOUT_BSH

    batching_config = ContinuousBatchingConfig(
        batch_size_for_shared_caches=scheduler_config.max_num_seqs)
    return {
        "collectives_layout": LAYOUT_BSH,
        "attention_layout": LAYOUT_BSH,
        "fuse_qkv": True,
        "on_device_embedding": True,
        "continuous_batching": batching_config,
        "on_device_generation":
        copy.deepcopy(model_config.neuron_sampling_params),
    }
def _get_neuron_on_device_generation_config(model_config: ModelConfig):
    """Return a deep copy of the neuron sampling params, or ``None``.

    ``None`` is returned when on-device sampling is disabled for
    ``model_config``.
    """
    if _is_neuron_on_device_sampling_disabled(model_config):
        return None
    return copy.deepcopy(model_config.neuron_sampling_params)
def _is_neuron_on_device_sampling_disabled(model_config: ModelConfig) -> bool:
    """Tell whether on-device sampling is off for ``model_config``.

    It is considered disabled when ``neuron_sampling_params`` is missing
    or falsy.
    """
    sampling_params = getattr(model_config, "neuron_sampling_params", None)
    return not sampling_params
def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    """Merge user overrides into the defaults and build a ``NeuronConfig``.

    Override sections that the neuron library models as dedicated config
    classes (``sparse_attn``, ``kv_cache_quant``, ``continuous_batching``,
    ``quant``, ``on_device_generation``) are given as plain dicts and are
    converted into the matching class. A section whose value is empty or
    falsy is dropped, so the default for that key is kept.

    Neither argument is modified: the function works on shallow copies, so
    the same override dict can safely be passed again (for example when a
    target and a draft model share it).

    Args:
        default_neuron_config: Dict of default ``NeuronConfig`` arguments.
        overridden_neuron_config: Dict of user overrides; ``None`` is
            treated as "no overrides".

    Returns:
        ``NeuronConfig(**merged)`` where overrides win over defaults.
    """
    from transformers_neuronx.config import (ContinuousBatchingConfig,
                                             GenerationConfig,
                                             KVCacheQuantizationConfig,
                                             NeuronConfig, QuantizationConfig,
                                             SparseAttnConfig)

    section_classes = {
        "sparse_attn": SparseAttnConfig,
        "kv_cache_quant": KVCacheQuantizationConfig,
        "continuous_batching": ContinuousBatchingConfig,
        "quant": QuantizationConfig,
        "on_device_generation": GenerationConfig,
    }

    # Copy first: popping from / writing into the caller's dict would leave
    # config objects where the next call expects plain dicts.
    overrides = dict(overridden_neuron_config or {})
    for section_name, section_cls in section_classes.items():
        section_kwargs = overrides.pop(section_name, {})
        if section_kwargs:
            overrides[section_name] = section_cls(**section_kwargs)

    merged = {**default_neuron_config, **overrides}
    return NeuronConfig(**merged)
def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig) -> nn.Module:
    """Build a ``NeuronCausalLM``, load it and return it in eval mode.

    The neuron config is the default config with
    ``model_config.override_neuron_config`` applied. Bucket sizes come from
    the ``NEURON_CONTEXT_LENGTH_BUCKETS`` and ``NEURON_TOKEN_GEN_BUCKETS``
    environment variables, defaulting to ``[max_model_len]``.
    """
    # Create a model instance.
    model = NeuronCausalLM(
        model_config.hf_config,
        _is_neuron_on_device_sampling_disabled(model_config))

    defaults = _get_default_neuron_config(model_config, parallel_config,
                                          scheduler_config)
    neuron_config = _get_neuron_config_after_override(
        defaults, model_config.override_neuron_config)

    # Each call gets its own default list so the two results never alias.
    context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
                                            [scheduler_config.max_model_len])
    n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
                               [scheduler_config.max_model_len])

    load_kwargs = {
        "tp_degree": parallel_config.tensor_parallel_size,
        "amp": TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        "neuron_config": neuron_config,
        "context_length_estimate": context_length_estimates,
        "n_positions": n_positions,
        "batch_size": scheduler_config.max_num_seqs,
    }
    model.load_weights(model_config.model, **load_kwargs)
    return model.eval()
def get_neuron_speculation_model(model_config: ModelConfig,
                                 parallel_config: ParallelConfig,
                                 scheduler_config: SchedulerConfig,
                                 speculation_config: SpeculativeConfig):
    """Initializes a neuron-optimized speculation model for inference.

    This method is only applicable for speculation with a standalone draft
    model. A target and a draft ``NeuronCausalLM`` are loaded and fused
    into a ``FusedSpeculativeDecoder``, which is returned wrapped in a
    ``NeuronSpeculationCausalLM``.
    """
    from transformers_neuronx.fused_speculation import FusedSpeculativeDecoder

    draft_config = speculation_config.draft_model_config

    # For Eagle SD, we need to pass in additional parameters in neuron config.
    is_eagle = getattr(draft_config.hf_config, "is_eagle", False)

    # Create target model instance.
    target_model = NeuronCausalLM(model_config.hf_config)
    target_defaults = _get_default_neuron_config_for_speculation(
        model_config, parallel_config, scheduler_config)
    if is_eagle:
        target_defaults['is_eagle_target'] = True
    target_neuron_config = _get_neuron_config_after_override(
        target_defaults, model_config.override_neuron_config)

    context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
                                            [scheduler_config.max_model_len])
    n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
                               [scheduler_config.max_model_len])

    target_model.load_weights(
        model_config.model,
        tp_degree=parallel_config.tensor_parallel_size,
        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        neuron_config=target_neuron_config,
        context_length_estimate=context_length_estimates,
        n_positions=n_positions,
        batch_size=scheduler_config.max_num_seqs)
    target_model.eval()

    # Create draft model instance.
    draft_model = NeuronCausalLM(draft_config.hf_config)
    draft_defaults = _get_default_neuron_config_for_speculation(
        draft_config, parallel_config, scheduler_config)
    if is_eagle:
        draft_defaults['is_eagle_draft'] = True
        draft_defaults['has_pre_attention_norm'] = False
    draft_neuron_config = _get_neuron_config_after_override(
        draft_defaults, draft_config.override_neuron_config)

    # The draft reuses the bucket lists computed for the target.
    draft_tp_degree = (
        speculation_config.draft_parallel_config.tensor_parallel_size)
    draft_model.load_weights(
        draft_config.model,
        tp_degree=draft_tp_degree,
        amp=TORCH_DTYPE_TO_NEURON_AMP[draft_config.dtype],
        neuron_config=draft_neuron_config,
        context_length_estimate=context_length_estimates,
        n_positions=n_positions,
        batch_size=scheduler_config.max_num_seqs)
    draft_model.eval()

    # Create speculation model instance.
    speculation_model = FusedSpeculativeDecoder(
        draft_model.model, target_model.model,
        speculation_config.num_speculative_tokens)
    speculation_model.to_neuron()

    return NeuronSpeculationCausalLM(speculation_model)
def get_neuron_eagle_speculation_model(model_config: ModelConfig,
                                       parallel_config: ParallelConfig,
                                       scheduler_config: SchedulerConfig,
                                       speculation_config: SpeculativeConfig):
    """Initializes a neuron-optimized EAGLE speculation model for inference.

    A target and a draft ``NeuronCausalLM`` are loaded with the EAGLE flags
    set and combined in an ``EagleSpeculativeDecoder``. Its token tree is
    parsed from ``speculation_config.speculative_token_tree`` with
    ``ast.literal_eval``.
    """
    from transformers_neuronx.eagle_speculation import EagleSpeculativeDecoder

    draft_config = speculation_config.draft_model_config

    # Create target model instance.
    target_model = NeuronCausalLM(model_config.hf_config)
    target_defaults = _get_default_neuron_config_for_speculation(
        model_config, parallel_config, scheduler_config)
    target_defaults['is_eagle_target'] = True
    target_neuron_config = _get_neuron_config_after_override(
        target_defaults, model_config.override_neuron_config)

    context_length_estimates = _get_buckets("NEURON_CONTEXT_LENGTH_BUCKETS",
                                            [scheduler_config.max_model_len])
    n_positions = _get_buckets("NEURON_TOKEN_GEN_BUCKETS",
                               [scheduler_config.max_model_len])

    target_model.load_weights(
        model_config.model,
        tp_degree=parallel_config.tensor_parallel_size,
        amp=TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        neuron_config=target_neuron_config,
        context_length_estimate=context_length_estimates,
        n_positions=n_positions,
        batch_size=scheduler_config.max_num_seqs)
    target_model.eval()

    # Create draft model instance.
    draft_model = NeuronCausalLM(draft_config.hf_config)
    draft_defaults = _get_default_neuron_config_for_speculation(
        draft_config, parallel_config, scheduler_config)
    draft_defaults['is_eagle_draft'] = True
    draft_defaults['has_pre_attention_norm'] = False
    draft_neuron_config = _get_neuron_config_after_override(
        draft_defaults, draft_config.override_neuron_config)

    # The draft reuses the bucket lists computed for the target.
    draft_tp_degree = (
        speculation_config.draft_parallel_config.tensor_parallel_size)
    draft_model.load_weights(
        draft_config.model,
        tp_degree=draft_tp_degree,
        amp=TORCH_DTYPE_TO_NEURON_AMP[draft_config.dtype],
        neuron_config=draft_neuron_config,
        context_length_estimate=context_length_estimates,
        n_positions=n_positions,
        batch_size=scheduler_config.max_num_seqs)
    draft_model.eval()

    token_tree: dict[int, list[int]] = ast.literal_eval(
        speculation_config.speculative_token_tree)

    speculation_model = EagleSpeculativeDecoder(draft_model.model,
                                                target_model.model,
                                                token_tree=token_tree)
    speculation_model.to_neuron()

    return NeuronSpeculationCausalLM(speculation_model)

View File

@@ -0,0 +1,685 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading Neuron models in
neuronx-distributed-inference framework."""
# Disabling yapf because yapf and isort have conflicts for the below imports
# yapf: disable
import copy
import hashlib
import importlib
import multiprocessing
import os
import shutil
from typing import Optional
import torch
import torch.nn as nn
from neuronx_distributed_inference.models.config import (
FusedSpecNeuronConfig, OnDeviceSamplingConfig)
from neuronx_distributed_inference.models.mllama.utils import (
create_vision_mask)
from neuronx_distributed_inference.modules.lora_serving import (
LoraServingConfig)
from neuronx_distributed_inference.utils.hf_adapter import (
load_pretrained_config)
from transformers import AutoModelForCausalLM, AutoTokenizer, PretrainedConfig
from vllm.config import (ModelConfig, ParallelConfig, SchedulerConfig,
SpeculativeConfig)
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.sampler import Sampler, SamplerOutput
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
SequenceOutput)
# yapf: enable
logger = init_logger(__name__)
# Lookup from a dtype spelling (string name or torch dtype) to a full dtype
# name string. Note that "auto" resolves to "float32".
# NOTE(review): no use of this table is visible in this part of the file;
# presumably it feeds the neuron config built by the callers -- confirm.
TORCH_DTYPE_TO_NEURON_AMP = {
"auto": "float32",
"half": "float16",
"float16": "float16",
"bfloat16": "bfloat16",
"float": "float32",
"float32": "float32",
torch.float16: "float16",
torch.bfloat16: "bfloat16",
torch.float32: "float32",
}
# Models supported by Neuronx distributed for inference.
# Key: architecture name as listed in the HF config. Value: a 2-tuple of
# (module path to import, model class name inside that module), resolved with
# importlib in the load_weights methods. "MistralForCausalLM" is served by
# the Llama implementation.
_NEURON_SUPPORTED_MODELS: dict[str, tuple[str, str]] = {
"LlamaForCausalLM":
("neuronx_distributed_inference.models.llama.modeling_llama",
"NeuronLlamaForCausalLM"),
"MistralForCausalLM":
("neuronx_distributed_inference.models.llama.modeling_llama",
"NeuronLlamaForCausalLM"),
"DbrxForCausalLM":
("neuronx_distributed_inference.models.dbrx.modeling_dbrx",
"NeuronDbrxForCausalLM"),
"MixtralForCausalLM":
("neuronx_distributed_inference.models.mixtral.modeling_mixtral",
"NeuronMixtralForCausalLM"),
"MllamaForConditionalGeneration":
("neuronx_distributed_inference.models.mllama.modeling_mllama",
"NeuronMllamaForCausalLM"),
}
class NeuronCausalLM(nn.Module):
    """Adapter around a neuronx-distributed-inference causal LM.

    The wrapped model is created by ``load_weights``; until then
    ``self.model`` is only declared, not assigned.
    """

    def __init__(
        self,
        config: PretrainedConfig,
    ) -> None:
        """Create the logits processor and the host-side sampler."""
        super().__init__()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)
        self.sampler = Sampler()

        # Lazy initialized
        self.model: nn.Module

    def forward(self,
                input_ids: torch.Tensor,
                positions: torch.Tensor,
                input_block_ids: torch.Tensor,
                sampling_params: torch.Tensor,
                prev_hidden: Optional[torch.Tensor] = None,
                adapter_ids: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Run the model on rows sorted by block id and undo the sort.

        Returns:
            ``hidden_states`` of the model output when on-device sampling is
            configured, otherwise the logits of the last position. Rows are
            restored to the caller's order unless the batch has one row.
        """
        # sort block ids sequentially for perf/neuron support reasons
        sorted_block_ids, sort_order = torch.sort(input_block_ids)
        input_ids = torch.index_select(input_ids, 0, sort_order)
        positions = torch.index_select(positions, 0, sort_order)
        sampling_params = torch.index_select(sampling_params, 0, sort_order)

        model_output = self.model(input_ids,
                                  attention_mask=None,
                                  position_ids=positions,
                                  seq_ids=sorted_block_ids,
                                  sampling_params=sampling_params,
                                  prev_hidden=prev_hidden,
                                  adapter_ids=adapter_ids)

        # on-device sampling
        if self.config.neuron_config.on_device_sampling_config:
            result = model_output.hidden_states
        else:
            result = model_output.logits[:, -1, :]

        inverse_order = torch.argsort(sort_order)
        if input_block_ids.shape[0] != 1:
            result = torch.index_select(result, 0, inverse_order)
        return result

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        """Pass ``hidden_states`` through the logits processor."""
        return self.logits_processor(None, hidden_states, sampling_metadata)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        """Produce a ``SamplerOutput`` for the current step.

        Without on-device sampling the host sampler is used. Otherwise
        ``logits`` is read as already-sampled token ids, one per sequence id
        in ``sampling_metadata``.
        """
        if not self.config.neuron_config.on_device_sampling_config:
            return self.sampler(logits, sampling_metadata)

        # on-device sampling
        seq_ids = [
            seq_id for group in sampling_metadata.seq_groups
            for seq_id in group.seq_ids
        ]
        assert len(seq_ids) == logits.shape[0], "batch size mismatch"

        # Organize input tensors by step instead of by sequence.
        token_ids = logits.flatten().tolist()

        group_outputs = []
        for row, seq_id in enumerate(seq_ids):
            token_id = token_ids[row]
            sample = SequenceOutput(parent_seq_id=seq_id,
                                    output_token=token_id,
                                    logprobs={token_id: Logprob(token_id)})
            group_outputs.append(
                CompletionSequenceGroupOutput(samples=[sample],
                                              prompt_logprobs=None))
        return SamplerOutput(outputs=group_outputs)

    def load_weights(self, model_name_or_path: str, **kwargs):
        """Load compiled Neuron artifacts, compiling them first if needed.

        Reads ``kwargs['neuron_config']`` (constructor arguments of the
        model's neuron config class) and ``kwargs['override_neuron_config']``
        (attributes set on the loaded model's neuron config). If loading the
        artifacts raises ``FileNotFoundError`` or ``ValueError`` the model is
        compiled into the artifact directory and loaded from there.
        """
        arch = _get_model_architecture(self.config)
        module_path, model_cls_name = _NEURON_SUPPORTED_MODELS[arch]
        model_cls = getattr(importlib.import_module(module_path),
                            model_cls_name)

        neuron_config = model_cls.get_neuron_config_cls()(
            **kwargs['neuron_config'])
        self.config.neuron_config = neuron_config
        config = model_cls.get_config_cls()(
            neuron_config,
            load_config=load_pretrained_config(model_name_or_path))
        hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
                                    usedforsecurity=False).hexdigest()

        artifacts_from_env = os.getenv("NEURON_COMPILED_ARTIFACTS")
        if artifacts_from_env is not None:
            compiled_model_path = artifacts_from_env
        else:
            if os.path.exists(model_name_or_path):
                artifacts_root = model_name_or_path
            else:
                artifacts_root = os.path.join("local-models",
                                              model_name_or_path)
            compiled_model_path = os.path.join(artifacts_root,
                                               "neuron-compiled-artifacts",
                                               hashed_config)
            # The per-config artifact directory is removed before the load
            # attempt below.
            shutil.rmtree(compiled_model_path, ignore_errors=True)

        try:
            self.model = model_cls(compiled_model_path)
            overrides = kwargs["override_neuron_config"]
            for attr_name, attr_value in overrides.items():
                setattr(self.model.config.neuron_config, attr_name,
                        attr_value)
            self.model.load(compiled_model_path)
            return
        except (FileNotFoundError, ValueError) as e:
            logger.warning("Exception: %s", e)
            logger.warning("Failed to load the model from %s, Recompiling...",
                           compiled_model_path)

        if not os.path.exists(model_name_or_path):
            # Not a local path: fetch the HF model and keep a local copy.
            saved_path = os.path.join("local-models", model_name_or_path)
            AutoModelForCausalLM.from_pretrained(
                model_name_or_path).save_pretrained(saved_path)
            model_name_or_path = saved_path

        self.model = model_cls(model_name_or_path, config)
        self.model.compile(compiled_model_path)
        self.model.load(compiled_model_path)
class NeuronMllamaForCausalLM(nn.Module):
def __init__(self,
             config: PretrainedConfig,
             on_device_sampling_disabled: bool = False) -> None:
    """Set up caches, the logits processor and (optionally) a sampler.

    Args:
        config: HF config; the vocab size of ``config.get_text_config()``
            sizes the logits processor.
        on_device_sampling_disabled: When true a default ``Sampler`` is
            created.
    """
    super().__init__()
    # has_image is the only multimodal input that is used in
    # token-generation
    # This is a cache (on CPU) that saves has_image data per sequence id
    # The number of entries in this cache is <= Batch-Size
    self.has_image_cache: dict[int, torch.Tensor] = {}

    self.config = config
    text_vocab_size = config.get_text_config().vocab_size
    self.logits_processor = LogitsProcessor(text_vocab_size,
                                            logits_as_input=True)

    self.on_device_sampling_disabled = on_device_sampling_disabled
    if on_device_sampling_disabled:
        # Use default sampler
        self.sampler = Sampler()

    # Lazy initialized
    self.model: nn.Module
    self.is_reorder_needed: bool = True
def read_from_has_image_cache(self, seq_ids: torch.Tensor):
    """Collect the cached ``has_image`` value for every id in ``seq_ids``.

    Ids that are not in ``self.has_image_cache`` contribute
    ``torch.tensor([0])``. The collected values are combined with
    ``torch.tensor``.
    """
    cached_values = []
    for seq in seq_ids:
        key = seq.item()
        if key in self.has_image_cache:
            value = self.has_image_cache[key]
        else:
            value = torch.tensor([0])
        cached_values.append(value)
    return torch.tensor(cached_values)
def write_to_has_image_cache(self, seq_ids: torch.Tensor,
                             has_image: torch.Tensor):
    """Store one ``has_image`` entry per sequence id in the cache.

    The i-th id receives ``has_image[i]``; ids beyond the length of
    ``has_image`` receive ``torch.zeros(1)``.
    """
    for position, seq in enumerate(seq_ids):
        key = seq.item()
        if position >= len(has_image):
            self.has_image_cache[key] = torch.zeros(1)
        else:
            self.has_image_cache[key] = has_image[position]
def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
            seq_ids: torch.Tensor, pixel_values: torch.Tensor,
            aspect_ratios: torch.Tensor, num_chunks: torch.Tensor,
            has_image: torch.Tensor, sampling_params) -> torch.Tensor:
    """Run the Mllama model for one prefill or decode step.

    During prefill (more than one token per row) ``has_image`` is written
    to the per-sequence cache. During decode it is read back from the cache
    and ``num_chunks`` / ``aspect_ratios`` are replaced by zero tensors.

    When ``self.is_reorder_needed`` is set, every per-row input, including
    the sequence ids handed to the model, is sorted by sequence id before
    the call and the output rows are put back in the caller's order.

    Returns:
        ``hidden_states`` of the model output when on-device sampling is
        configured, otherwise the logits of the last position.
    """
    # We update the has_image cache during prefill
    # and read the has_image cache during decode
    if input_ids.shape[-1] > 1:  # prefill
        self.write_to_has_image_cache(seq_ids, has_image)
    else:
        has_image = self.read_from_has_image_cache(seq_ids)
        bs = input_ids.shape[0]
        num_chunks = torch.zeros((bs, 1))
        aspect_ratios = torch.zeros((bs, 1, 2))

    input_block_ids = seq_ids
    origin_input_block_ids = seq_ids
    if self.is_reorder_needed:
        # sort block ids sequentially for perf/neuron support reasons
        input_block_ids, sorted_indices = torch.sort(input_block_ids)
        input_ids = torch.index_select(input_ids, 0, sorted_indices)
        positions = torch.index_select(positions, 0, sorted_indices)
        sampling_params = torch.index_select(sampling_params, 0,
                                             sorted_indices)
        pixel_values = torch.index_select(pixel_values, 0, sorted_indices)
        aspect_ratios = torch.index_select(aspect_ratios, 0,
                                           sorted_indices)
        num_chunks = torch.index_select(num_chunks, 0, sorted_indices)
        has_image = torch.index_select(has_image, 0, sorted_indices)

    self.vision_mask = create_vision_mask(input_ids, self.vision_token_id)
    output = self.model(
        input_ids.to(torch.int32),
        attention_mask=None,
        position_ids=positions.to(torch.int32),
        # Pass the ids in the same (sorted) order as the other inputs;
        # the unsorted ``seq_ids`` would pair each row with the wrong
        # sequence whenever the sort above changed the row order.
        seq_ids=input_block_ids.flatten().to(torch.int32),
        pixel_values=pixel_values.to(
            self.config.vision_config.torch_dtype),
        aspect_ratios=aspect_ratios.to(torch.int32),
        vision_mask=self.vision_mask.to(torch.int32),
        sampling_params=sampling_params,
        num_chunks=num_chunks.to(torch.int32),
        has_image=has_image.to(torch.int32),
    )
    if self.config.neuron_config.on_device_sampling_config:
        output = output.hidden_states
    else:
        output = output.logits[:, -1, :]

    if self.is_reorder_needed and origin_input_block_ids.shape[0] != 1:
        # Undo the sort so output rows match the caller's order.
        restored_indices = torch.argsort(sorted_indices)
        output = torch.index_select(output, 0, restored_indices)
    return output
def compute_logits(self, hidden_states: torch.Tensor,
                   sampling_metadata: SamplingMetadata) -> torch.Tensor:
    """Pass ``hidden_states`` through the logits processor."""
    return self.logits_processor(None, hidden_states, sampling_metadata)
def sample(self, hidden_states, sampling_metadata):
    """Produce a ``SamplerOutput`` for the current step.

    With on-device sampling enabled ``hidden_states`` is read as
    already-sampled token ids, consumed in order, one per sequence id of
    each sequence group. Otherwise the host-side sampler created in
    ``__init__`` is invoked.

    Args:
        hidden_states: Token ids (on-device sampling) or the logits to
            sample from (host-side sampling).
        sampling_metadata: Metadata providing ``seq_groups``.

    Returns:
        The ``SamplerOutput`` for this step.
    """
    if self.on_device_sampling_disabled:
        # The sampler takes (logits, sampling_metadata). The leading ``None``
        # used previously belongs to the logits-processor call signature and
        # shifted every argument by one.
        return self.sampler(hidden_states, sampling_metadata)

    with torch.profiler.record_function("sample"):
        hidden_states = hidden_states.flatten()
        res = []
        sample_idx = 0
        for seq_group in sampling_metadata.seq_groups:
            samples = []
            for seq_id in seq_group.seq_ids:
                token_id = hidden_states[sample_idx].item()
                samples.append(
                    SequenceOutput(parent_seq_id=seq_id,
                                   output_token=token_id,
                                   logprobs={token_id: Logprob(token_id)}))
                sample_idx += 1
            res.append(
                CompletionSequenceGroupOutput(samples=samples,
                                              prompt_logprobs=None))
        return SamplerOutput(outputs=res)
def load_weights(self, model_name_or_path: str, **kwargs):
    """Load compiled Neuron artifacts, compiling them first if needed.

    Args:
        model_name_or_path: Local path or model identifier. When it is not
            an existing path, the model is fetched with
            ``AutoModelForCausalLM.from_pretrained`` and saved under
            ``local-models/<model_name_or_path>`` before compiling.
        **kwargs: Must contain ``neuron_config``, a mapping expanded into
            the model class's neuron-config constructor.

    The artifact directory is ``$NEURON_COMPILED_ARTIFACTS`` when that
    variable is set; otherwise it is a ``neuron-compiled-artifacts/<md5>``
    directory keyed by the MD5 of the serialized config. Loading from that
    directory is attempted first; ``FileNotFoundError`` / ``ValueError``
    trigger a recompile in a child process.
    """
    arch = _get_model_architecture(self.config)
    neuronx_module_path, neuronx_model_cls_name = (
        _NEURON_SUPPORTED_MODELS[arch])
    neuronx_module = importlib.import_module(neuronx_module_path)
    neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)

    neuron_config = neuronx_model_cls.get_neuron_config_cls()(
        **kwargs['neuron_config'])
    self.config.neuron_config = neuron_config
    logger.info("neuron_config buckets: %s",
                self.config.neuron_config.buckets)
    config = neuronx_model_cls.get_config_cls()(
        neuron_config,
        load_config=load_pretrained_config(model_name_or_path))

    # MD5 is used only as a cache key here, not for security.
    hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
                                usedforsecurity=False).hexdigest()
    compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
    if compiled_model_path is None:
        if os.path.exists(model_name_or_path):
            compiled_model_path = os.path.join(model_name_or_path,
                                               "neuron-compiled-artifacts",
                                               hashed_config)
        else:
            compiled_model_path = os.path.join("local-models",
                                               model_name_or_path,
                                               "neuron-compiled-artifacts",
                                               hashed_config)

    # Fast path: previously compiled artifacts are available.
    try:
        self.model = neuronx_model_cls(compiled_model_path)
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.vision_token_id = tokenizer(
            "<|image|>", add_special_tokens=False).input_ids[0]
        self.model.load(compiled_model_path)
        return
    except (FileNotFoundError, ValueError):
        logger.warning("Failed to load the model from %s, Recompiling...",
                       compiled_model_path)

    if not os.path.exists(model_name_or_path):
        hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
        saved_path = os.path.join("local-models", model_name_or_path)
        hf_model.save_pretrained(saved_path)
        model_name_or_path = saved_path
    self.model = neuronx_model_cls(model_name_or_path, config)

    # The artifacts are written to compiled_model_path, so that is the
    # location reported here (not the source checkpoint path).
    logger.info("\nCompiling and saving model to %s", compiled_model_path)
    p = multiprocessing.Process(target=compile_model,
                                args=(self, compiled_model_path))
    p.start()
    p.join()
    if p.exitcode != 0:
        # Surface the child's failure; the load below will most likely
        # fail too, and this makes the root cause visible in the logs.
        logger.warning(
            "Model compilation subprocess exited with code %s",
            p.exitcode)

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.save_pretrained(compiled_model_path)
    if p.exitcode == 0:
        logger.info("Successfully compiled and saved the model in %s",
                    compiled_model_path)

    # Read "<|image|>" token_id from the tokenizer
    self.vision_token_id = tokenizer("<|image|>",
                                     add_special_tokens=False).input_ids[0]
    logger.info("\nLoading model from compiled checkpoint...")
    self.model.load(compiled_model_path)
def compile_model(neuron_model, traced_model_path):
    """Compile ``neuron_model.model`` into ``traced_model_path``.

    Kept as a module-level function so it can be handed to
    ``multiprocessing.Process`` as a target.
    """
    wrapped_model = neuron_model.model
    wrapped_model.compile(traced_model_path)
class NeuronSpeculationCausalLM(nn.Module):
    """A Neuron-optimized causal language model with speculative decoding."""

    def __init__(
        self,
        config: PretrainedConfig,
    ) -> None:
        super().__init__()
        self.config = config
        self.logits_processor = LogitsProcessor(config.vocab_size,
                                                logits_as_input=True)
        # Lazy initialized
        self.model: nn.Module

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_block_ids: torch.Tensor,
        sampling_params: torch.Tensor,
    ) -> torch.Tensor:
        """Run one fused-speculation step and return token ids.

        Inputs are reordered by ascending ``input_block_ids`` before the
        model call, and the result is put back in the caller's order when
        more than one block id is present.

        Returns:
            For context encoding (the first column of ``positions`` sums to
            zero): the first column of the first fused output. Otherwise the
            first fused output with every slot at or beyond the per-row
            generated-token count overwritten with ``-1``.
        """
        # sort block ids sequentially for perf/neuron support reasons
        sorted_input_block_ids, sorted_indices = torch.sort(input_block_ids)
        input_ids = torch.index_select(input_ids, 0, sorted_indices)
        positions = torch.index_select(positions, 0, sorted_indices)
        sampling_params = torch.index_select(sampling_params, 0,
                                             sorted_indices)

        output = self.model(input_ids,
                            attention_mask=None,
                            position_ids=positions,
                            seq_ids=sorted_input_block_ids,
                            sampling_params=sampling_params)
        # Inverse permutation that undoes the sort above.
        restored_indices = torch.argsort(sorted_indices)

        # CTX encoding
        if (positions[:, 0]).sum().item() == 0:
            output = output.fused_outputs[0][:, 0:1]
            if input_block_ids.shape[0] != 1:
                output = torch.index_select(output, 0, restored_indices)
            return output

        # Fused Spec (Generation)
        accepted_tokens_with_padding = output.fused_outputs[0]
        next_pos_ids = output.fused_outputs[-1]
        generated_token_counts = next_pos_ids - positions

        assert torch.any(generated_token_counts == 0).item() is False, \
            "NxDI model generated no output for one or more sequences."

        # Mark every step past a row's generated-token count with -1.
        batch_size, steps = accepted_tokens_with_padding.shape
        mask = torch.arange(steps).expand(batch_size,
                                          -1) >= generated_token_counts
        accepted_tokens_with_padding[mask] = -1

        if input_block_ids.shape[0] != 1:
            accepted_tokens_with_padding = torch.index_select(
                accepted_tokens_with_padding, 0, restored_indices)
        return accepted_tokens_with_padding

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[list[SamplerOutput]]:
        """Convert a (batch, steps) tensor of token ids to sampler outputs.

        One ``SamplerOutput`` is produced per step, stopping at the first
        step whose entries are all ``-1``.
        """
        batch_size, num_steps = logits.shape
        seq_ids = [
            seq_id for sg in sampling_metadata.seq_groups
            for seq_id in sg.seq_ids
        ]
        # Organize input tensors by step instead of by sequence.
        accepted_token_ids_by_step = logits.transpose(0, 1)
        accepted_token_ids_by_step = accepted_token_ids_by_step.tolist()

        sampler_output_list = []
        for step_index in range(num_steps):
            if all(token_id == -1
                   for token_id in accepted_token_ids_by_step[step_index]):
                break
            step_output_token_ids = []
            for sequence_index in range(batch_size):
                token_id = accepted_token_ids_by_step[step_index][
                    sequence_index]
                step_output_token_ids.append(
                    CompletionSequenceGroupOutput(samples=[
                        SequenceOutput(parent_seq_id=seq_ids[sequence_index],
                                       output_token=token_id,
                                       logprobs={token_id: Logprob(token_id)})
                    ],
                                                  prompt_logprobs=None))
            sampler_output_list.append(
                SamplerOutput(outputs=step_output_token_ids))
        return sampler_output_list

    def load_weights(self, model_name_or_path: str,
                     draft_model_name_or_path: str, **kwargs):
        """Load (or compile, then load) the fused target + draft model.

        Args:
            model_name_or_path: Local path or identifier of the target model.
            draft_model_name_or_path: Local path or identifier of the draft
                model.
            **kwargs: ``neuron_config`` (required) is expanded into the
                neuron-config constructor. ``override_neuron_config``
                (optional, may be ``None``) is applied attribute-by-attribute
                to a model loaded from existing artifacts.
        """
        arch = _get_model_architecture(self.config)
        neuronx_module_path, neuronx_model_cls_name = (
            _NEURON_SUPPORTED_MODELS[arch])
        neuronx_module = importlib.import_module(neuronx_module_path)
        neuronx_model_cls = getattr(neuronx_module, neuronx_model_cls_name)
        neuron_config = neuronx_model_cls.get_neuron_config_cls()(
            **kwargs['neuron_config'])
        config = neuronx_model_cls.get_config_cls()(
            neuron_config,
            load_config=load_pretrained_config(model_name_or_path))

        # The draft config starts as a copy of the target's neuron config
        # and is then adjusted.
        draft_neuron_config = copy.deepcopy(config.neuron_config)
        if not config.neuron_config.enable_eagle_speculation:
            draft_neuron_config.speculation_length = 0
        draft_neuron_config.trace_tokengen_model = True
        draft_neuron_config.enable_fused_speculation = False
        if getattr(config.neuron_config, "draft_model_modules_to_not_convert",
                   None):
            draft_neuron_config.modules_to_not_convert = (
                draft_neuron_config.draft_model_modules_to_not_convert)
        if config.neuron_config.enable_eagle_speculation:
            draft_neuron_config.is_eagle_draft = True
            draft_neuron_config.sequence_parallel_enabled = False
        draft_config = neuronx_model_cls.get_config_cls()(
            draft_neuron_config,
            load_config=load_pretrained_config(draft_model_name_or_path))
        fused_spec_config = (FusedSpecNeuronConfig(
            neuronx_model_cls._model_cls,
            draft_config=draft_config,
            draft_model_path=draft_model_name_or_path))
        config.fused_spec_config = fused_spec_config
        self.config.neuron_config = neuron_config

        # MD5 is used only as a cache key here, not for security.
        hashed_config = hashlib.md5(config.to_json_string().encode('utf-8'),
                                    usedforsecurity=False).hexdigest()
        if os.getenv("NEURON_COMPILED_ARTIFACTS") is not None:
            compiled_model_path = os.getenv("NEURON_COMPILED_ARTIFACTS")
        elif os.path.exists(model_name_or_path):
            compiled_model_path = os.path.join(model_name_or_path,
                                               "neuron-compiled-artifacts",
                                               hashed_config)
            # NOTE(review): any existing artifacts in the hashed directory
            # are removed here, so the load attempt below is only expected
            # to succeed for the environment-variable path - confirm intent.
            shutil.rmtree(compiled_model_path, ignore_errors=True)
        else:
            compiled_model_path = os.path.join("local-models",
                                               model_name_or_path,
                                               "neuron-compiled-artifacts",
                                               hashed_config)
            shutil.rmtree(compiled_model_path, ignore_errors=True)

        try:
            self.model = neuronx_model_cls(compiled_model_path)
            # The caller forwards model_config.override_neuron_config as-is,
            # which may be None (or absent); treat that as "no overrides"
            # instead of failing with an uncaught AttributeError/KeyError.
            override_neuron_config = (kwargs.get("override_neuron_config")
                                      or {})
            for k, v in override_neuron_config.items():
                setattr(self.model.config.neuron_config, k, v)
            self.model.load(compiled_model_path)
            return
        except (FileNotFoundError, ValueError) as e:
            logger.warning("Exception: %s", e)
            logger.warning("Failed to load the model from %s Recompiling...",
                           compiled_model_path)

        if not os.path.exists(model_name_or_path):
            hf_model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
            saved_path = os.path.join("local-models", model_name_or_path)
            hf_model.save_pretrained(saved_path)
            model_name_or_path = saved_path
        if not os.path.exists(draft_model_name_or_path):
            if draft_model_name_or_path != model_name_or_path:
                hf_model = AutoModelForCausalLM.from_pretrained(
                    draft_model_name_or_path)
                saved_path = os.path.join("local-models",
                                          draft_model_name_or_path)
                hf_model.save_pretrained(saved_path)
                draft_model_name_or_path = saved_path
            else:
                draft_model_name_or_path = model_name_or_path
            config.fused_spec_config.draft_model_path = draft_model_name_or_path
        self.model = neuronx_model_cls(model_name_or_path, config)
        self.model.compile(compiled_model_path)
        self.model.load(compiled_model_path)
def _get_model_architecture(config: PretrainedConfig) -> str:
    """Return the first architecture of ``config`` supported on Neuron.

    Raises:
        ValueError: If none of ``config.architectures`` is a key of
            ``_NEURON_SUPPORTED_MODELS`` (including when the attribute is
            missing, ``None`` or empty).
    """
    # ``architectures`` can be present but set to None; normalise that to an
    # empty list so the caller gets the descriptive ValueError below rather
    # than a TypeError from iterating None.
    architectures = getattr(config, "architectures", None) or []
    for arch in architectures:
        if arch in _NEURON_SUPPORTED_MODELS:
            return arch
    raise ValueError(
        f"Model architectures {architectures} are not supported on Neuron "
        f"for now. Supported architectures: "
        f"{list(_NEURON_SUPPORTED_MODELS.keys())}")
def _get_default_neuron_config(model_config: ModelConfig,
                               parallel_config: ParallelConfig,
                               scheduler_config: SchedulerConfig,
                               lora_serving_config: LoraServingConfig):
    """Build the default neuron-config keyword mapping from vLLM configs.

    Returns a plain ``dict``; callers may update it with overrides before
    expanding it into a neuron-config constructor.
    """
    sampling_cfg = OnDeviceSamplingConfig(dynamic=True, deterministic=False)
    max_seqs = scheduler_config.max_num_seqs
    return {
        "tp_degree": parallel_config.tensor_parallel_size,
        "ctx_batch_size": 1,
        "batch_size": max_seqs,
        "max_context_length": scheduler_config.max_model_len,
        "seq_len": scheduler_config.max_model_len,
        "enable_bucketing": True,
        "is_continuous_batching": True,
        "quantized": False,
        "torch_dtype": TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        "padding_side": "right",
        "on_device_sampling_config": sampling_cfg,
        "sequence_parallel_enabled": True,
        "lora_serving_config": lora_serving_config,
    }
def _get_default_speculation_config(model_config: ModelConfig,
                                    parallel_config: ParallelConfig,
                                    scheduler_config: SchedulerConfig,
                                    speculation_config: SpeculativeConfig):
    """Build the default neuron-config mapping for speculative decoding.

    Returns a plain ``dict`` derived from the given vLLM configs, with fused
    speculation enabled and a fixed on-device sampling configuration
    (``top_k=1``, ``do_sample=False``).
    """
    return {
        "tp_degree": parallel_config.tensor_parallel_size,
        "ctx_batch_size": 1,
        "batch_size": scheduler_config.max_num_seqs,
        "max_context_length": scheduler_config.max_model_len,
        "seq_len": scheduler_config.max_model_len,
        "speculation_length": speculation_config.num_speculative_tokens,
        "trace_tokengen_model": False,
        "enable_fused_speculation": True,
        "enable_bucketing": True,
        "is_continuous_batching": True,
        "quantized": False,
        "torch_dtype": TORCH_DTYPE_TO_NEURON_AMP[model_config.dtype],
        "on_device_sampling_config": {
            "top_k": 1,
            "do_sample": False,
        },
    }
def _get_neuron_config_after_override(default_neuron_config,
                                      overridden_neuron_config):
    """Merge override values into the default neuron config.

    ``default_neuron_config`` is updated in place and returned. A falsy
    ``overridden_neuron_config`` (``None`` or empty) leaves it unchanged.
    """
    overrides = overridden_neuron_config if overridden_neuron_config else {}
    default_neuron_config.update(overrides)
    return default_neuron_config
def get_neuron_model(model_config: ModelConfig,
                     parallel_config: ParallelConfig,
                     scheduler_config: SchedulerConfig,
                     lora_serving_config: LoraServingConfig) -> nn.Module:
    """Initializes a neuron-optimized model for inference.

    Picks the Mllama wrapper for ``MllamaForConditionalGeneration`` and the
    generic causal-LM wrapper otherwise, loads its weights with the default
    neuron config merged with ``model_config.override_neuron_config``, and
    returns the model in eval mode.
    """
    if (_get_model_architecture(model_config.hf_config) ==
            "MllamaForConditionalGeneration"):
        wrapper_cls = NeuronMllamaForCausalLM
    else:
        wrapper_cls = NeuronCausalLM
    model = wrapper_cls(model_config.hf_config)

    defaults = _get_default_neuron_config(model_config, parallel_config,
                                          scheduler_config,
                                          lora_serving_config)
    merged_config = _get_neuron_config_after_override(
        defaults, model_config.override_neuron_config)

    model.load_weights(
        model_config.model,
        neuron_config=merged_config,
        override_neuron_config=model_config.override_neuron_config)
    return model.eval()
def get_neuron_speculation_model(model_config: ModelConfig,
                                 parallel_config: ParallelConfig,
                                 scheduler_config: SchedulerConfig,
                                 speculation_config: SpeculativeConfig):
    """Initializes a neuron-optimized speculation model for inference.

    This model handles speculation using both a draft model and an EAGLE draft.
    """
    model = NeuronSpeculationCausalLM(model_config.hf_config)

    defaults = _get_default_speculation_config(model_config, parallel_config,
                                               scheduler_config,
                                               speculation_config)
    merged_config = _get_neuron_config_after_override(
        defaults, model_config.override_neuron_config)

    model.load_weights(
        model_config.model,
        speculation_config.draft_model_config.model,
        neuron_config=merged_config,
        override_neuron_config=model_config.override_neuron_config)
    return model.eval()

View File

@@ -0,0 +1,109 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117
import glob
import os
from collections.abc import Generator
from typing import Optional
import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
runai_safetensors_weights_iterator)
from vllm.transformers_utils.s3_utils import glob as s3_glob
from vllm.transformers_utils.utils import is_s3
class RunaiModelStreamerLoader(BaseModelLoader):
    """
    Model loader that can load safetensors
    files from local FS or S3 bucket.
    """

    def __init__(self, load_config: LoadConfig):
        """Export streamer settings from the extra config to the environment.

        Only acts when ``load_config.model_loader_extra_config`` is truthy:
        integer ``concurrency`` / ``memory_limit`` values are written to the
        corresponding ``RUNAI_STREAMER_*`` variables, and
        ``RUNAI_STREAMER_S3_ENDPOINT`` is filled from ``AWS_ENDPOINT_URL``
        when the former is unset and the latter is set.
        """
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if not extra_config:
            return

        int_settings = (
            ("concurrency", "RUNAI_STREAMER_CONCURRENCY"),
            ("memory_limit", "RUNAI_STREAMER_MEMORY_LIMIT"),
        )
        for option, env_name in int_settings:
            if option in extra_config and isinstance(
                    extra_config.get(option), int):
                os.environ[env_name] = str(extra_config.get(option))

        streamer_endpoint = os.getenv('RUNAI_STREAMER_S3_ENDPOINT')
        aws_endpoint = os.getenv('AWS_ENDPOINT_URL')
        if streamer_endpoint is None and aws_endpoint is not None:
            os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]) -> list[str]:
        """Prepare weights for the model.

        If the model is not local, it will be downloaded.

        Returns:
            The list of ``*.safetensors`` files found for the model.

        Raises:
            RuntimeError: If no safetensors file is found.
        """
        from_s3 = is_s3(model_name_or_path)
        on_disk = os.path.isdir(model_name_or_path)
        needs_download = not (on_disk or from_s3)
        weights_glob = "*.safetensors"

        if needs_download:
            weights_dir = download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                [weights_glob],
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )
        else:
            weights_dir = model_name_or_path

        if from_s3:
            weight_files = s3_glob(path=weights_dir,
                                   allow_pattern=[weights_glob])
        else:
            weight_files = glob.glob(os.path.join(weights_dir, weights_glob))

        if needs_download:
            download_safetensors_index_file_from_hf(
                model_name_or_path, SAFE_WEIGHTS_INDEX_NAME,
                self.load_config.download_dir, revision)

        if not weight_files:
            raise RuntimeError(
                f"Cannot find any safetensors model weights with "
                f"`{model_name_or_path}`")
        return weight_files

    def _get_weights_iterator(
            self, model_or_path: str,
            revision: str) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Get an iterator for the model weights based on the load format."""
        weight_files = self._prepare_weights(model_or_path, revision)
        return runai_safetensors_weights_iterator(
            weight_files,
            self.load_config.use_tqdm_on_load,
        )

    def download_model(self, model_config: ModelConfig) -> None:
        """Download model if necessary"""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        """Load weights into a model.

        Uses ``model_config.model_weights`` as the source when that
        attribute exists, and ``model_config.model`` otherwise.
        """
        weights_source = getattr(model_config, "model_weights",
                                 model_config.model)
        model.load_weights(
            self._get_weights_iterator(weights_source, model_config.revision))

View File

@@ -0,0 +1,201 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import collections
import glob
import os
from collections.abc import Generator
from typing import Any, Optional
import torch
from torch import nn
from vllm.config import LoadConfig, ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_weights_from_hf, runai_safetensors_weights_iterator)
from vllm.transformers_utils.s3_utils import glob as s3_glob
from vllm.transformers_utils.utils import is_s3
logger = init_logger(__name__)
class ShardedStateLoader(BaseModelLoader):
    """
    Model loader that directly loads each worker's model state dict, which
    enables a fast load path for large tensor-parallel models where each worker
    only needs to read its own shard rather than the entire checkpoint. See
    `examples/offline_inference/save_sharded_state.py` for creating a sharded
    checkpoint.
    """

    DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"

    def __init__(self,
                 load_config: LoadConfig,
                 runai_model_streamer: bool = False):
        """Read the optional ``pattern`` key from the loader extra config.

        Raises:
            ValueError: If the extra config has keys other than ``pattern``.
        """
        super().__init__(load_config)

        self.runai_model_streamer = runai_model_streamer
        extra_config = ({} if load_config.model_loader_extra_config is None
                        else load_config.model_loader_extra_config.copy())
        self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
        if extra_config:
            raise ValueError(f"Unexpected extra config keys for load format "
                             f"{load_config.load_format}: "
                             f"{load_config.model_loader_extra_config.keys()}")

    @staticmethod
    def _filter_subtensors(
        tensors: dict[str, torch.Tensor], ) -> dict[str, torch.Tensor]:
        """
        Filter out all tensors that share the same memory or a subset of the
        memory of another tensor.
        """
        # Group non-empty tensors by (device, storage base pointer).
        same_storage_groups: dict[Any, list[tuple[str, torch.Tensor]]] = (
            collections.defaultdict(list))
        for key, tensor in tensors.items():
            if tensor.numel():
                ptr = tensor.untyped_storage().data_ptr()
                same_storage_groups[tensor.device, ptr].append((key, tensor))

        def get_end_ptr(tensor: torch.Tensor) -> int:
            return tensor.view(-1)[-1].data_ptr() + tensor.element_size()

        result: dict[str, torch.Tensor] = {}
        for group in same_storage_groups.values():
            for k, t in group:
                a, b = t.data_ptr(), get_end_ptr(t)
                for k2, t2 in group:
                    if not t2.is_contiguous():
                        continue
                    a2, b2 = t2.data_ptr(), get_end_ptr(t2)
                    if a < a2 or b2 < b:
                        continue
                    if a2 < a or b < b2 or not t.is_contiguous():
                        break  # t2 covers strictly more memory than t.
                    if k2 < k:
                        # Same tensors, keep the one with the smaller key.
                        break
                else:
                    # No other tensor in the group subsumes t: keep it.
                    result[k] = t
        return result

    def _prepare_weights(self, model_name_or_path: str,
                         revision: Optional[str]):
        """Return a usable checkpoint location, downloading it if needed.

        S3 URIs and existing local directories are returned unchanged;
        anything else is downloaded via ``download_weights_from_hf``.
        """
        if is_s3(model_name_or_path) or os.path.isdir(model_name_or_path):
            return model_name_or_path
        else:
            allow_patterns = ["*.safetensors"]
            return download_weights_from_hf(
                model_name_or_path,
                self.load_config.download_dir,
                allow_patterns,
                revision,
                ignore_patterns=self.load_config.ignore_patterns,
            )

    def download_model(self, model_config: ModelConfig) -> None:
        """Make sure the checkpoint for ``model_config`` is available."""
        self._prepare_weights(model_config.model, model_config.revision)

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        """Copy this rank's sharded checkpoint files into ``model``.

        Raises:
            ValueError: If no file matches this rank's pattern, or if some
                entries of the model's (de-duplicated) state dict were not
                found in the loaded files.
        """
        from vllm.distributed import get_tensor_model_parallel_rank

        model_weights = model_config.model
        if hasattr(model_config, "model_weights"):
            model_weights = model_config.model_weights
        local_model_path = model_weights

        rank = get_tensor_model_parallel_rank()
        pattern = os.path.join(
            local_model_path,
            self.pattern.format(rank=rank, part="*"),
        )

        if is_s3(local_model_path):
            # The wildcard must not be padded with spaces: the pattern is
            # matched literally apart from glob metacharacters, so ' * '
            # would require spaces in the object names and match nothing.
            file_pattern = f"*{self.pattern.format(rank=rank, part='*')}"
            filepaths = s3_glob(path=local_model_path,
                                allow_pattern=[file_pattern])
        else:
            filepaths = glob.glob(pattern)
        if not filepaths:
            # TODO: support un-sharded checkpoints too
            raise ValueError(
                f"Could not find checkpoint files '{pattern}', only "
                f"pre-sharded checkpoints are currently supported!")

        state_dict = self._filter_subtensors(model.state_dict())
        for key, tensor in self.iterate_over_files(filepaths):
            # If loading with LoRA enabled, additional padding may
            # be added to certain parameters. We only load into a
            # narrowed view of the parameter data.
            param_data = state_dict[key].data
            param_shape = state_dict[key].shape
            for dim, size in enumerate(tensor.shape):
                if size < param_shape[dim]:
                    param_data = param_data.narrow(dim, 0, size)
            if tensor.shape != param_shape:
                logger.warning(
                    "loading tensor of shape %s into "
                    "parameter '%s' of shape %s",
                    tensor.shape,
                    key,
                    param_shape,
                )
            param_data.copy_(tensor)
            # Loaded keys are removed so leftovers can be reported below.
            state_dict.pop(key)
        if state_dict:
            raise ValueError(
                f"Missing keys {tuple(state_dict)} in loaded state!")

    def iterate_over_files(
            self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Yield ``(key, tensor)`` pairs from every file in ``paths``."""
        if self.runai_model_streamer:
            yield from runai_safetensors_weights_iterator(paths, True)
        else:
            from safetensors.torch import safe_open
            for path in paths:
                with safe_open(path, framework="pt") as f:
                    for key in f.keys():  # noqa: SIM118
                        tensor = f.get_tensor(key)
                        yield key, tensor

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        path: str,
        pattern: Optional[str] = None,
        max_size: Optional[int] = None,
    ) -> None:
        """Write this rank's state dict to ``path`` as safetensors parts.

        Args:
            model: Module whose (de-duplicated) state dict is saved.
            path: Output directory.
            pattern: File name template with ``{rank}`` and ``{part}``
                fields; defaults to ``DEFAULT_PATTERN``.
            max_size: Optional byte budget per part. A new part is started
                when adding the next tensor would exceed it.
        """
        from safetensors.torch import save_file

        from vllm.distributed import get_tensor_model_parallel_rank

        if pattern is None:
            pattern = ShardedStateLoader.DEFAULT_PATTERN
        rank = get_tensor_model_parallel_rank()
        part_idx = 0
        total_size = 0
        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
        state_dict_part: dict[str, torch.Tensor] = {}
        for key, tensor in state_dict.items():
            param_size = tensor.nelement() * tensor.element_size()
            if max_size is not None and total_size + param_size > max_size:
                # Flush the current part before it would go over budget.
                filename = pattern.format(rank=rank, part=part_idx)
                save_file(
                    state_dict_part,
                    os.path.join(path, filename),
                )
                part_idx += 1
                total_size = 0
                state_dict_part = {}
            state_dict_part[key] = tensor
            total_size += param_size
        if len(state_dict_part) > 0:
            filename = pattern.format(rank=rank, part=part_idx)
            save_file(
                state_dict_part,
                os.path.join(path, filename),
            )

View File

@@ -0,0 +1,600 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
import contextlib
import contextvars
import dataclasses
import io
import json
import os
import threading
import time
from collections.abc import Generator
from dataclasses import dataclass
from functools import partial
from typing import Any, BinaryIO, Optional, Union
import regex as re
import torch
from torch import nn
from torch.utils._python_dispatch import TorchDispatchMode
from transformers import PretrainedConfig
import vllm.envs as envs
from vllm.config import (ModelConfig, ParallelConfig, VllmConfig,
set_current_vllm_config)
from vllm.engine.arg_utils import EngineArgs
from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.utils import FlexibleArgumentParser, PlaceholderModule
try:
from tensorizer import (DecryptionParams, EncryptionParams,
TensorDeserializer, TensorSerializer)
from tensorizer.stream_io import open_stream
from tensorizer.utils import (convert_bytes, get_mem_usage,
no_init_or_tensor)
_read_stream, _write_stream = (partial(
open_stream,
mode=mode,
) for mode in ("rb", "wb+"))
except ImportError:
tensorizer = PlaceholderModule("tensorizer")
DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer")
TensorSerializer = tensorizer.placeholder_attr("TensorSerializer")
open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
_read_stream = tensorizer.placeholder_attr("_read_stream")
_write_stream = tensorizer.placeholder_attr("_write_stream")
__all__ = [
'EncryptionParams', 'DecryptionParams', 'TensorDeserializer',
'TensorSerializer', 'open_stream', 'convert_bytes', 'get_mem_usage',
'no_init_or_tensor', 'TensorizerConfig'
]
logger = init_logger(__name__)
class MetaTensorMode(TorchDispatchMode):
    """Dispatch mode that defaults ``aten::empty`` calls to the meta device.

    Any other op, and any ``aten::empty`` call that already carries a
    ``device`` keyword, is forwarded unchanged.
    """

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if not kwargs:
            kwargs = {}
        if func._schema.name == "aten::empty":
            # Only fills in the device when the caller did not supply one.
            kwargs.setdefault("device", "meta")
        return func(*args, **kwargs)
def meta_tensor_mode(loading_code=None, ):
    """Return the no-init context manager, or run a callable inside it.

    With no argument the context manager itself is returned. With a callable
    the callable is invoked inside the context and its result is returned.

    Raises:
        TypeError: If ``loading_code`` is neither ``None`` nor callable.
    """
    if loading_code is None:
        return _NoInitOrTensorImpl.context_manager()

    if not callable(loading_code):
        raise TypeError(
            "expected a callable to evaluate,"
            " or None if being used as a context manager;"
            f' got an object of type "{type(loading_code).__name__}" instead.')

    with _NoInitOrTensorImpl.context_manager():
        return loading_code()
class _NoInitOrTensorImpl:
    """Context manager that disables ``reset_parameters`` on selected modules.

    While at least one context is open, ``reset_parameters`` of the classes
    in ``_MODULES`` is replaced by a wrapper that does nothing when the
    current context has ``is_active`` set; the body additionally runs under
    ``MetaTensorMode``. The originals are restored when the last open
    context exits.
    """

    _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
    _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)

    # Per-context flag: True while inside context_manager().
    is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active",
                                       default=False)
    # Number of currently open (non-nested) contexts, guarded by the lock.
    _count_active: int = 0
    _count_active_lock = threading.Lock()

    @classmethod
    @contextlib.contextmanager
    def context_manager(cls):
        # Nested use within an already-active context is a no-op.
        if cls.is_active.get():
            yield
            return

        with cls._count_active_lock:
            cls._count_active += 1
            if cls._count_active == 1:
                # First entrant installs the patched reset_parameters.
                for module_cls in cls._MODULES:
                    module_cls.reset_parameters = cls._disable(
                        module_cls.reset_parameters)

        token = cls.is_active.set(True)
        try:
            with MetaTensorMode():
                yield
        finally:
            cls.is_active.reset(token)
            with cls._count_active_lock:
                cls._count_active -= 1
                if cls._count_active == 0:
                    # Last one out restores the original methods.
                    for module_cls, original in cls._MODULE_ORIGINALS:
                        module_cls.reset_parameters = original

    @staticmethod
    def _disable(func):
        """Wrap ``func`` so it is skipped while ``is_active`` is set."""

        def wrapper(*args, **kwargs):
            if _NoInitOrTensorImpl.is_active.get():
                return None
            return func(*args, **kwargs)

        return wrapper
@dataclass
class TensorizerConfig:
    """Settings for reading or writing tensorizer-serialized tensors.

    Either ``tensorizer_uri`` or ``lora_dir`` must be given; when only
    ``lora_dir`` is set the URI becomes ``<lora_dir>/adapter_model.tensors``.
    """

    tensorizer_uri: Union[str, None] = None
    vllm_tensorized: Optional[bool] = False
    verify_hash: Optional[bool] = False
    num_readers: Optional[int] = None
    encryption_keyfile: Optional[str] = None
    s3_access_key_id: Optional[str] = None
    s3_secret_access_key: Optional[str] = None
    s3_endpoint: Optional[str] = None
    model_class: Optional[type[torch.nn.Module]] = None
    hf_config: Optional[PretrainedConfig] = None
    dtype: Optional[Union[str, torch.dtype]] = None
    lora_dir: Optional[str] = None
    # Derived in __post_init__: True when the URI holds a '%0Nd' template.
    _is_sharded: bool = False

    def __post_init__(self):
        # check if the configuration is for a sharded vLLM model
        uri = self.tensorizer_uri
        self._is_sharded = (isinstance(uri, str)
                            and re.search(r'%0\dd', uri) is not None)

        if not self.tensorizer_uri:
            if not self.lora_dir:
                raise ValueError("tensorizer_uri must be provided.")
            self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
        assert self.tensorizer_uri is not None, ("tensorizer_uri must be "
                                                 "provided.")

        # lora_dir is always reset to the directory containing the URI.
        self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
        self.lora_dir = self.tensorizer_dir

    @classmethod
    def as_dict(cls, *args, **kwargs) -> dict[str, Any]:
        """Build a ``TensorizerConfig`` from the arguments, as a dict."""
        return dataclasses.asdict(TensorizerConfig(*args, **kwargs))

    def to_dict(self) -> dict[str, Any]:
        """Return this config's dataclass fields as a dict."""
        return dataclasses.asdict(self)

    def _construct_tensorizer_args(self) -> "TensorizerArgs":
        """Create ``TensorizerArgs`` from the deserialization fields."""
        return TensorizerArgs(  # type: ignore
            tensorizer_uri=self.tensorizer_uri,
            vllm_tensorized=self.vllm_tensorized,
            verify_hash=self.verify_hash,
            num_readers=self.num_readers,
            encryption_keyfile=self.encryption_keyfile,
            s3_access_key_id=self.s3_access_key_id,
            s3_secret_access_key=self.s3_secret_access_key,
            s3_endpoint=self.s3_endpoint,
        )

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        """Reject tensor-parallel use with a URI lacking a rank template."""
        needs_template = parallel_config.tensor_parallel_size > 1
        if needs_template and not self._is_sharded:
            raise ValueError(
                "For a sharded model, tensorizer_uri should include a"
                " string format template like '%04d' to be formatted"
                " with the rank of the shard")

    def verify_with_model_config(self, model_config: "ModelConfig") -> None:
        """Warn when quantization is combined with a tensorizer URI."""
        if (self.tensorizer_uri is not None
                and model_config.quantization is not None):
            logger.warning(
                "Loading a model using Tensorizer with quantization on vLLM"
                " is unstable and may lead to errors.")

    def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
        """Open ``tensorizer_uri`` with the args' stream parameters."""
        if tensorizer_args is None:
            tensorizer_args = self._construct_tensorizer_args()

        stream_kwargs = tensorizer_args.stream_params
        return open_stream(self.tensorizer_uri, **stream_kwargs)
@dataclass
class TensorizerArgs:
tensorizer_uri: Union[io.BufferedIOBase, io.RawIOBase, BinaryIO, str,
bytes, os.PathLike, int]
vllm_tensorized: Optional[bool] = False
verify_hash: Optional[bool] = False
num_readers: Optional[int] = None
encryption_keyfile: Optional[str] = None
s3_access_key_id: Optional[str] = None
s3_secret_access_key: Optional[str] = None
s3_endpoint: Optional[str] = None
"""
Args for the TensorizerAgent class. These are used to configure the behavior
of the TensorDeserializer when loading tensors from a serialized model.
Args:
tensorizer_uri: Path to serialized model tensors. Can be a local file
path or a S3 URI. This is a required field unless lora_dir is
provided and the config is meant to be used for the
`tensorize_lora_adapter` function.
vllm_tensorized: If True, indicates that the serialized model is a
vLLM model. This is used to determine the behavior of the
TensorDeserializer when loading tensors from a serialized model.
It is far faster to deserialize a vLLM model as it utilizes
tensorizer's optimized GPU loading. Note that this is now
deprecated, as serialized vLLM models are now automatically
inferred as vLLM models.
verify_hash: If True, the hashes of each tensor will be verified against
the hashes stored in the metadata. A `HashMismatchError` will be
raised if any of the hashes do not match.
num_readers: Controls how many threads are allowed to read concurrently
from the source file. Default is `None`, which will dynamically set
the number of readers based on the number of available
resources and model size. This greatly increases performance.
encryption_keyfile: File path to a binary file containing a
binary key to use for decryption. `None` (the default) means
no decryption. See the example script in
examples/others/tensorize_vllm_model.py.
s3_access_key_id: The access key for the S3 bucket. Can also be set via
the S3_ACCESS_KEY_ID environment variable.
s3_secret_access_key: The secret access key for the S3 bucket. Can also
be set via the S3_SECRET_ACCESS_KEY environment variable.
s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
S3_ENDPOINT_URL environment variable.
"""
def __post_init__(self):
    """Derive stream and deserializer parameters from the fields.

    Unset S3 credentials/endpoint fall back to the values exposed by
    ``envs``. When an encryption keyfile is configured, it is read through
    ``open_stream`` and converted to ``DecryptionParams``.
    """
    self.file_obj = self.tensorizer_uri

    # Explicit field values win; otherwise use the environment settings.
    if not self.s3_access_key_id:
        self.s3_access_key_id = envs.S3_ACCESS_KEY_ID
    if not self.s3_secret_access_key:
        self.s3_secret_access_key = envs.S3_SECRET_ACCESS_KEY
    if not self.s3_endpoint:
        self.s3_endpoint = envs.S3_ENDPOINT_URL

    self.stream_params = dict(
        s3_access_key_id=self.s3_access_key_id,
        s3_secret_access_key=self.s3_secret_access_key,
        s3_endpoint=self.s3_endpoint,
    )
    self.deserializer_params = dict(
        verify_hash=self.verify_hash,
        encryption=self.encryption_keyfile,
        num_readers=self.num_readers,
    )

    keyfile = self.encryption_keyfile
    if keyfile:
        # Replace the keyfile path with the decryption params built from
        # the key it contains.
        with open_stream(keyfile, **self.stream_params) as key_stream:
            key_bytes = key_stream.read()
            self.deserializer_params['encryption'] = (
                DecryptionParams.from_key(key_bytes))
@staticmethod
def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
    """Register the tensorizer option group on ``parser``.

    Args:
        parser: Parser that receives a new 'tensorizer options' group.

    Returns:
        The same ``parser`` object that was passed in.
    """
    tensorizer_group = parser.add_argument_group(
        'tensorizer options',
        description=('Options for configuring the behavior of the'
                     ' tensorizer deserializer when '
                     'load_format=tensorizer is specified when '
                     'initializing an LLMEngine, either via the CLI '
                     'when running the vLLM OpenAI inference server '
                     'with a JSON string passed to '
                     '--model-loader-extra-config or as arguments given '
                     'to TensorizerConfig when passed to '
                     'model_loader_extra_config in the constructor '
                     'for LLMEngine.'))

    # (flag, add_argument keyword arguments), kept in registration order.
    option_table = (
        ("--tensorizer-uri",
         dict(
             type=str,
             help="Path to serialized model tensors. Can be a local file path,"
             " or an HTTP(S) or S3 URI.",
         )),
        ("--verify-hash",
         dict(
             action="store_true",
             help="If enabled, the hashes of each tensor will be verified"
             " against the hashes stored in the file metadata. An exception"
             " will be raised if any of the hashes do not match.",
         )),
        ("--encryption-keyfile",
         dict(
             type=str,
             default=None,
             help="The file path to a binary file containing a binary key to "
             "use for decryption. Can be a file path or S3 network URI.",
         )),
        ("--num-readers",
         dict(
             default=None,
             type=int,
             help="Controls how many threads are allowed to read concurrently "
             "from the source file. Default is `None`, which will dynamically "
             "set the number of readers based on the available resources "
             "and model size. This greatly increases performance.",
         )),
        ("--s3-access-key-id",
         dict(
             type=str,
             default=None,
             help="The access key for the S3 bucket. Can also be set via the "
             "S3_ACCESS_KEY_ID environment variable.",
         )),
        ("--s3-secret-access-key",
         dict(
             type=str,
             default=None,
             help="The secret access key for the S3 bucket. Can also be set via "
             "the S3_SECRET_ACCESS_KEY environment variable.",
         )),
        ("--s3-endpoint",
         dict(
             type=str,
             default=None,
             help="The endpoint for the S3 bucket. Can also be set via the "
             "S3_ENDPOINT_URL environment variable.",
         )),
    )
    for flag, options in option_table:
        tensorizer_group.add_argument(flag, **options)

    return parser
@classmethod
def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
    """Build an instance from a parsed argument namespace.

    Only dataclass fields of ``cls`` that exist as attributes on ``args``
    are forwarded to the constructor; everything else on ``args`` is
    ignored.
    """
    init_kwargs = {}
    for spec in dataclasses.fields(cls):
        if hasattr(args, spec.name):
            init_kwargs[spec.name] = getattr(args, spec.name)
    return cls(**init_kwargs)
def _check_tensors_on_meta_device(model: nn.Module) -> None:
    """Raise ``ValueError`` if any state-dict tensor is on the meta device.

    Args:
        model: Module whose ``state_dict()`` values are inspected.

    Raises:
        ValueError: At least one tensor reports device type ``'meta'``.
    """
    found_meta = any(entry.device.type == 'meta'
                     for entry in model.state_dict().values())
    if found_meta:
        raise ValueError(
            "The serialized model contains tensors on the meta device,"
            " indicating that some tensors were not loaded properly."
            " Please check that the parameters of the model being"
            " specified match that of the serialized model, such as"
            " its quantization.")
def _resize_lora_embeddings(model: nn.Module):
    """Modify LoRA embedding layers to use bigger tensors
    to allow for adapter added tokens.

    Every ``VocabParallelEmbedding`` whose weight has fewer rows than
    ``num_embeddings_per_partition`` gets a new weight of the full size:
    existing rows are copied over and the remaining rows are zero-filled.
    """
    for layer in model.modules():
        if not isinstance(layer, VocabParallelEmbedding):
            continue

        current_rows = layer.weight.shape[0]
        target_rows = layer.num_embeddings_per_partition
        if current_rows >= target_rows:
            continue

        padded = torch.empty(target_rows,
                             layer.embedding_dim,
                             dtype=layer.weight.dtype,
                             device=layer.weight.device)
        padded[:current_rows].copy_(layer.weight.data)
        padded[current_rows:].fill_(0)
        layer.weight.data = padded
def init_tensorizer_model(tensorizer_config: TensorizerConfig,
                          vllm_config: VllmConfig) -> nn.Module:
    """Instantiate the configured model class under meta-tensor mode.

    The HF config attached to ``tensorizer_config`` has its ``torch_dtype``
    overwritten with ``tensorizer_config.dtype`` before the model class is
    constructed.

    Args:
        tensorizer_config: Must carry non-None ``hf_config`` and
            ``model_class``.
        vllm_config: Passed to the model class and installed as the current
            vLLM config during construction.

    Returns:
        The freshly constructed model.
    """
    hf_config = tensorizer_config.hf_config
    assert hf_config is not None
    hf_config.torch_dtype = tensorizer_config.dtype

    model_cls = tensorizer_config.model_class
    assert model_cls is not None
    # TODO: Do we need to consider old-style model class?
    with meta_tensor_mode(), set_current_vllm_config(vllm_config,
                                                     check_compile=True):
        return model_cls(vllm_config=vllm_config)
def deserialize_tensorizer_model(model: nn.Module,
                                 tensorizer_config: TensorizerConfig) -> None:
    """Load serialized tensors into ``model`` and log timing/memory stats.

    After loading, the model is checked for leftover meta tensors, its
    embedding layers are resized where needed, and the
    ``vllm_tensorized_marker`` attribute is deleted from the model.

    Args:
        model: Module that receives the deserialized tensors in place.
        tensorizer_config: Supplies the URI, dtype and tensorizer arguments.
    """
    args = tensorizer_config._construct_tensorizer_args()
    mem_before = get_mem_usage()
    t_start = time.perf_counter()

    with _read_stream(tensorizer_config.tensorizer_uri,
                      **args.stream_params) as stream:
        with TensorDeserializer(
                stream,
                dtype=tensorizer_config.dtype,
                device=f'cuda:{torch.cuda.current_device()}',
                **args.deserializer_params) as deserializer:
            deserializer.load_into_module(model)
            t_end = time.perf_counter()

    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    elapsed = t_end - t_start
    per_second = convert_bytes(deserializer.total_tensor_bytes / elapsed)
    mem_after = get_mem_usage()
    deserializer.close()

    logger.info("Deserialized %s in %0.2fs, %s/s", total_bytes_str,
                t_end - t_start, per_second)
    logger.info("Memory usage before: %s", mem_before)
    logger.info("Memory usage after: %s", mem_after)

    _check_tensors_on_meta_device(model)
    _resize_lora_embeddings(model)
    del model.vllm_tensorized_marker
def tensorizer_weights_iterator(
    tensorizer_args: "TensorizerArgs"
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield the items of a tensorizer file deserialized onto the CPU.

    A warning is logged when iteration starts, because this path is the
    slower alternative to loading a vLLM-serialized model.

    Args:
        tensorizer_args: Provides the URI plus stream and deserializer
            keyword arguments.

    Yields:
        The items produced by the deserializer's ``items()``.
    """
    logger.warning("Deserializing HuggingFace models is not optimized for "
                   "loading on vLLM, as tensorizer is forced to load to CPU. "
                   "Consider deserializing a vLLM model instead for faster "
                   "load times. See the "
                   "examples/others/tensorize_vllm_model.py example script "
                   "for serializing vLLM models.")

    deserializer_kwargs = tensorizer_args.deserializer_params
    source = open_stream(tensorizer_args.tensorizer_uri,
                         **tensorizer_args.stream_params)
    with TensorDeserializer(source, **deserializer_kwargs,
                            device="cpu") as state:
        yield from state.items()
    del state
def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
    """
    Infer if the model is a vLLM model by checking the weights for
    a vLLM tensorized marker.

    Args:
        tensorizer_config: The TensorizerConfig object containing the
            tensorizer_uri to the serialized model.

    Returns:
        bool: True if the model is a vLLM model, False otherwise.
    """
    tensorizer_args = tensorizer_config._construct_tensorizer_args()
    # The deserializer is only needed for the marker lookup. Use it as a
    # context manager so it is closed on every exit path instead of being
    # left open.
    with TensorDeserializer(open_stream(tensorizer_args.tensorizer_uri,
                                        **tensorizer_args.stream_params),
                            **tensorizer_args.deserializer_params,
                            lazy_load=True) as deserializer:
        if tensorizer_config.vllm_tensorized:
            logger.warning(
                "Please note that newly serialized vLLM models are "
                "automatically "
                "inferred as vLLM models, so setting vllm_tensorized=True is "
                "only necessary for models serialized prior to this change.")
            return True
        return ".vllm_tensorized_marker" in deserializer
def serialize_vllm_model(
    model: nn.Module,
    tensorizer_config: TensorizerConfig,
) -> nn.Module:
    """Serialize ``model`` with tensorizer to the configured URI.

    A ``vllm_tensorized_marker`` parameter on the meta device is registered
    on the model before writing. When the config is sharded, the output URI
    is formatted with the tensor-parallel rank.

    Args:
        model: The module to serialize; it is modified by the marker
            registration.
        tensorizer_config: Supplies the output URI, stream parameters and
            optional encryption keyfile.

    Returns:
        The same ``model`` object.
    """
    model.register_parameter(
        "vllm_tensorized_marker",
        nn.Parameter(torch.tensor((1, ), device="meta"), requires_grad=False))
    tensorizer_args = tensorizer_config._construct_tensorizer_args()

    encryption_params = None
    if (keyfile := tensorizer_config.encryption_keyfile) is not None:
        # Read the key through open_stream with the configured stream
        # parameters, the same way the key is read for decryption, so a
        # keyfile that is not a plain local path can be used here as well.
        with open_stream(keyfile, **tensorizer_args.stream_params) as f:
            key = f.read()
        encryption_params = EncryptionParams(key=key)

    output_file = tensorizer_args.tensorizer_uri
    if tensorizer_config._is_sharded:
        from vllm.distributed import get_tensor_model_parallel_rank
        output_file = output_file % get_tensor_model_parallel_rank()

    with _write_stream(output_file, **tensorizer_args.stream_params) as stream:
        serializer = TensorSerializer(stream, encryption=encryption_params)
        serializer.write_module(model)
        serializer.close()
    logger.info("Successfully serialized model to %s", str(output_file))
    return model
def tensorize_vllm_model(engine_args: EngineArgs,
                         tensorizer_config: TensorizerConfig,
                         generate_keyfile: bool = True):
    """Utility to load a model and then serialize it with Tensorizer

    Intended to be used separately from running a vLLM server since it
    creates its own Engine instance.

    Args:
        engine_args: Used to build the engine configuration and the engine.
        tensorizer_config: Serialization settings, verified against the
            model and parallel configs before anything is created.
        generate_keyfile: When true and an encryption keyfile is configured,
            a random key is written to that keyfile first.
    """
    engine_config = engine_args.create_engine_config()
    tensorizer_config.verify_with_model_config(engine_config.model_config)
    tensorizer_config.verify_with_parallel_config(
        engine_config.parallel_config)

    # generate the encryption key before creating the engine to support sharding
    if generate_keyfile:
        keyfile = tensorizer_config.encryption_keyfile
        if keyfile is not None:
            fresh_params = EncryptionParams.random()
            with _write_stream(
                    keyfile,
                    s3_access_key_id=tensorizer_config.s3_access_key_id,
                    s3_secret_access_key=tensorizer_config.
                    s3_secret_access_key,
                    s3_endpoint=tensorizer_config.s3_endpoint,
            ) as key_stream:
                key_stream.write(fresh_params.key)

    from vllm import LLMEngine
    from vllm.v1.engine.llm_engine import LLMEngine as V1LLMEngine

    if envs.VLLM_USE_V1:
        engine = V1LLMEngine.from_vllm_config(engine_config)
        engine.collective_rpc(
            "save_tensorized_model",
            kwargs=dict(tensorizer_config=tensorizer_config),
        )
    else:
        engine = LLMEngine.from_engine_args(engine_args)
        engine.model_executor.collective_rpc(
            "save_tensorized_model",
            kwargs=dict(tensorizer_config=tensorizer_config),
        )
def tensorize_lora_adapter(lora_path: str,
                           tensorizer_config: TensorizerConfig):
    """
    Uses tensorizer to serialize a LoRA adapter. Assumes that the files
    needed to load a LoRA adapter are a safetensors-format file called
    adapter_model.safetensors and a json config file called adapter_config.json.

    Serializes the files in the tensorizer_config.lora_dir

    Args:
        lora_path: Adapter location, resolved to a directory through
            ``get_adapter_absolute_path``.
        tensorizer_config: Supplies ``lora_dir`` (the output location) and
            the stream parameters.

    Raises:
        ValueError: No ``adapter_model*`` file ending in ``.safetensors`` or
            ``.bin`` was found.
        FileNotFoundError: No ``adapter_config*`` file was found.
    """
    # Import the submodule explicitly: a bare `import safetensors` does not
    # guarantee that `safetensors.torch` has been loaded.
    import safetensors.torch

    from vllm.lora.utils import get_adapter_absolute_path

    lora_dir = get_adapter_absolute_path(lora_path)

    tensor_path = config_path = ""

    for file in os.listdir(lora_dir):
        if file.startswith("adapter_model"):
            tensor_path = lora_dir + "/" + file
        if file.startswith("adapter_config"):
            config_path = lora_dir + "/" + file
        if tensor_path and config_path:
            break

    if tensor_path.endswith(".safetensors"):
        tensors = safetensors.torch.load_file(tensor_path)
    elif tensor_path.endswith(".bin"):
        # NOTE(review): torch.load unpickles the file; only use this on
        # adapters from a trusted source.
        tensors = torch.load(tensor_path)
    else:
        # Exceptions do not apply %-style arguments, so build the message
        # explicitly.
        raise ValueError(f"Unsupported file: {tensor_path}")

    if not config_path:
        # Fail with a clear message instead of the opaque error from
        # open("").
        raise FileNotFoundError(
            f"No adapter_config file found in {lora_dir}")

    with open(config_path) as f:
        config = json.load(f)

    tensorizer_args = tensorizer_config._construct_tensorizer_args()

    with open_stream(f"{tensorizer_config.lora_dir}/adapter_config.json",
                     mode="wb+",
                     **tensorizer_args.stream_params) as f:

        f.write(json.dumps(config).encode("utf-8"))

    lora_uri = (f"{tensorizer_config.lora_dir}"
                f"/adapter_model.tensors")
    with open_stream(lora_uri, mode="wb+",
                     **tensorizer_args.stream_params) as f:
        serializer = TensorSerializer(f)
        serializer.write_state_dict(tensors)
        serializer.close()

    logger.info("Successfully serialized LoRA files to %s",
                str(tensorizer_config.lora_dir))

View File

@@ -0,0 +1,123 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa: SIM117
import copy
from collections.abc import Generator
from typing import Union
import torch
from torch import nn
from vllm.config import LoadConfig, ModelConfig, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.tensorizer import (
TensorizerConfig, deserialize_tensorizer_model, init_tensorizer_model,
is_vllm_tensorized, serialize_vllm_model, tensorizer_weights_iterator)
from vllm.model_executor.model_loader.utils import (get_model_architecture,
initialize_model,
set_default_torch_dtype)
logger = init_logger(__name__)
class TensorizerLoader(BaseModelLoader):
    """Model loader using CoreWeave's tensorizer library."""

    def __init__(self, load_config: LoadConfig):
        """Store the tensorizer config taken from ``load_config``.

        ``model_loader_extra_config`` is used directly when it already is a
        ``TensorizerConfig``; otherwise it is expanded as keyword arguments
        into a new one.
        """
        super().__init__(load_config)
        extra_config = load_config.model_loader_extra_config
        if not isinstance(extra_config, TensorizerConfig):
            extra_config = TensorizerConfig(**extra_config)
        self.tensorizer_config = extra_config

    def _verify_config(self, model_config: ModelConfig,
                       parallel_config: ParallelConfig):
        """Validate the tensorizer config against model and parallel configs."""
        config = self.tensorizer_config
        config.verify_with_model_config(model_config)
        config.verify_with_parallel_config(parallel_config)

    def _get_weights_iterator(
        self, ) -> Generator[tuple[str, torch.Tensor], None, None]:
        """Return the CPU weights iterator for the configured tensorizer file."""
        return tensorizer_weights_iterator(
            self.tensorizer_config._construct_tensorizer_args())

    def _load_model_serialized_cpu(
        self,
        vllm_config: VllmConfig,
    ) -> nn.Module:
        """Load a serialized model with tensorizer to the CPU.

        This is only necessary when the model isn't vLLM-tensorized (see
        examples/others/tensorize_vllm_model.py). This should still
        be faster than default HuggingFace loading, but will be slower than
        loading a vLLM-tensorized model.
        """
        dtype = vllm_config.model_config.dtype
        device = vllm_config.device_config.device
        with set_default_torch_dtype(dtype):
            # Only model construction runs under the device context.
            with torch.device(device):
                model = initialize_model(vllm_config=vllm_config)

            model.load_weights(self._get_weights_iterator())
        return model.eval()

    def download_model(self, model_config: ModelConfig) -> None:
        """Verify the config and open (then close) the tensorizer stream."""
        self.tensorizer_config.verify_with_model_config(model_config)

        with self.tensorizer_config.open_stream():
            pass

    def _patch_tensorizer_config(
            self, model_config: ModelConfig) -> TensorizerConfig:
        """Return a shallow copy of the config filled with model details.

        The copy receives the resolved model class plus ``hf_config`` and
        ``dtype`` from ``model_config``; ``self.tensorizer_config`` itself is
        not reassigned.
        """
        resolved_class = get_model_architecture(model_config)[0]
        patched = copy.copy(self.tensorizer_config)
        patched.model_class = resolved_class
        patched.hf_config = model_config.hf_config
        patched.dtype = model_config.dtype
        return patched

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        """Load serialized model weights with tensorizer.

        Expects a vLLM-tensorized model. See the
        examples/others/tensorize_vllm_model.py example script
        for serializing vLLM models.

        Files without the vLLM marker are loaded through the model's own
        ``load_weights`` using the CPU weights iterator instead."""
        if not is_vllm_tensorized(self.tensorizer_config):
            model.load_weights(self._get_weights_iterator())
            return

        patched = self._patch_tensorizer_config(model_config)
        deserialize_tensorizer_model(model, patched)

    def load_model(self, vllm_config: VllmConfig,
                   model_config: ModelConfig) -> nn.Module:
        """Build the model and load its weights from the tensorizer file.

        With tensor parallelism above 1, the stored ``tensorizer_uri`` is
        formatted in place with this worker's tensor-parallel rank.
        """
        parallel_config = vllm_config.parallel_config
        self._verify_config(model_config, parallel_config)

        if parallel_config.tensor_parallel_size > 1:
            from vllm.distributed import get_tensor_model_parallel_rank

            rank = get_tensor_model_parallel_rank()
            self.tensorizer_config.tensorizer_uri = (
                self.tensorizer_config.tensorizer_uri % rank)

        if not is_vllm_tensorized(self.tensorizer_config):
            return self._load_model_serialized_cpu(vllm_config=vllm_config)

        patched = self._patch_tensorizer_config(model_config)
        model = init_tensorizer_model(tensorizer_config=patched,
                                      vllm_config=vllm_config)
        self.load_weights(model, model_config)
        return model

    @staticmethod
    def save_model(
        model: torch.nn.Module,
        tensorizer_config: Union[TensorizerConfig, dict],
    ) -> None:
        """Serialize ``model``; a dict config is converted to ``TensorizerConfig``."""
        if isinstance(tensorizer_config, dict):
            tensorizer_config = TensorizerConfig(**tensorizer_config)
        serialize_vllm_model(
            model=model,
            tensorizer_config=tensorizer_config,
        )

View File

@@ -0,0 +1,112 @@
# SPDX-License-Identifier: Apache-2.0
import time
from typing import Optional
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
from vllm.config import ModelConfig, VllmConfig
from vllm.distributed.tpu_distributed_utils import get_fqn, shard_model
from vllm.logger import init_logger
from vllm.model_executor.model_loader.default_loader import DefaultModelLoader
from vllm.model_executor.model_loader.utils import (
initialize_model, process_weights_after_loading, set_default_torch_dtype)
logger = init_logger(__name__)
class TPUModelLoader(DefaultModelLoader):
    """
    A TPU model loader for model loading under SPMD mode.
    """

    def load_model(
        self,
        vllm_config: VllmConfig,
        model_config: ModelConfig,
        mesh: Optional[xs.Mesh] = None,
    ) -> nn.Module:
        """Initialize the model on CPU, load weights, then move it to XLA.

        Args:
            vllm_config: Source of the model and load configuration.
            model_config: Replaced right away by ``vllm_config.model_config``;
                the passed value is not used.
            mesh: Optional SPMD mesh forwarded to ``shard_model`` and to the
                post-load check.

        Returns:
            The model in eval mode on the 'xla' device, with its inner
            (language) model wrapped by ``torch.compile``.

        Raises:
            ValueError: Some model parameters were not covered by the loaded
                weights.
        """
        # Initialize model and load weights on CPU. Then, during SPMD partition,
        # weights are sharded and transferred to TPUs.
        self.counter_before_loading_weights = time.perf_counter()
        model_config = vllm_config.model_config
        assert model_config.quantization is None, "Quantization not supported"
        target_device = torch.device('cpu')
        with set_default_torch_dtype(model_config.dtype):
            with target_device:
                model = initialize_model(vllm_config=vllm_config)

            load_format = vllm_config.load_config.load_format
            if load_format != "dummy":
                weights_to_load = {
                    name
                    for name, _ in model.named_parameters()
                }
                all_weights = self.get_all_weights(model_config, model)
                loaded_weights = model.load_weights(all_weights)
                self.counter_after_loading_weights = time.perf_counter()
                logger.info(
                    "Loading weights took %.2f seconds",
                    self.counter_after_loading_weights -
                    self.counter_before_loading_weights)
                # We only enable strict check for non-quantized models
                # that have loaded weights tracking currently.
                if model_config.quantization is None and \
                    loaded_weights is not None:
                    weights_not_loaded = weights_to_load - loaded_weights
                    if weights_not_loaded:
                        raise ValueError(
                            "Following weights were not initialized from "
                            f"checkpoint: {weights_not_loaded}")
            else:
                logger.info("Use dummy weight during weight loading.")

            process_weights_after_loading(model, model_config, target_device)

        counter_before_partition = time.perf_counter()
        model = model.eval()
        model = model.to('xla')
        shard_model(model, mesh)
        counter_after_partition = time.perf_counter()
        logger.info("Partition model took %.2f seconds",
                    counter_after_partition - counter_before_partition)

        # Ensure the model is properly loaded.
        self._check_model_is_loaded(mesh, model)

        # Need to torch compile after model sharding are done. Because the
        # compiler hints ('xs.mark_sharding') are torch ops.
        if not model_config.is_multimodal_model:
            model.model = torch.compile(model.model, backend="openxla")
        else:
            model.language_model.model = \
                torch.compile(model.language_model.model, backend="openxla")
        return model

    def _check_model_is_loaded(self, mesh: Optional[xs.Mesh],
                               model: nn.Module) -> None:
        """
        Ensure the model is properly loaded.
        1. All model parameters and buffers are on XLA device.
        2. Non-SPMD friendly layers are replaced as expected.

        Raises:
            AssertionError: A parameter or buffer is on another device type,
                or a ``QKVParallelLinear`` module remains while ``mesh`` is
                set.
        """
        device = xm.xla_device()
        device_type = str(device.type)

        # NOTE: messages use implicit string concatenation rather than a
        # backslash continuation inside the literal, so the text does not
        # pick up the indentation of the source line.

        # Check parameters
        for name, param in model.named_parameters():
            assert param.device.type == device_type, (
                f"Parameter {name} is on "
                f"{param.device.type} instead of {device_type}")

        # Check buffers
        for name, buffer in model.named_buffers():
            assert buffer.device.type == device_type, (
                f"Buffer {name} is on {buffer.device.type} instead of "
                f"{device_type}")

        for module in model.modules():
            if (mesh is not None) and (get_fqn(module) == 'QKVParallelLinear'):
                raise AssertionError(
                    "QKVParallelLinear should be replaced by "
                    "XlaQKVParallelLinear under SPMD mode.")

View File

@@ -0,0 +1,302 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for selecting and loading models."""
import contextlib
import inspect
import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field
from typing import Optional
import torch
import transformers
from torch import nn
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from vllm.attention import Attention
from vllm.config import (ModelConfig, ModelImpl, VllmConfig,
set_current_vllm_config)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import QKVCrossParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
as_reward_model)
from vllm.utils import is_pin_memory_available
logger = init_logger(__name__)
@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    """Sets the default torch dtype to the given dtype.

    The previous default dtype is restored when the ``with`` block exits,
    including when the block raises.

    Args:
        dtype: The dtype to install as torch's default inside the block.
    """
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Without the finally, an exception in the body would leave the
        # process-wide default dtype changed.
        torch.set_default_dtype(old_dtype)
def initialize_model(
    vllm_config: VllmConfig,
    *,
    prefix: str = "",
    model_class: Optional[type[nn.Module]] = None,
    model_config: Optional[ModelConfig] = None,
) -> nn.Module:
    """Initialize a model with the given configurations.

    Args:
        vllm_config: Installed as the current vLLM config during construction.
        prefix: Forwarded to model classes that accept a ``prefix`` argument.
        model_class: Class to instantiate; resolved from the model config
            when omitted.
        model_config: Defaults to ``vllm_config.model_config``.

    Returns:
        The constructed model instance.
    """
    if model_config is None:
        model_config = vllm_config.model_config
    if model_class is None:
        model_class, _ = get_model_architecture(model_config)

    quant_config = vllm_config.quant_config
    if quant_config is not None:
        configure_quant_config(quant_config, model_class)

    init_param_names = list(inspect.signature(model_class.__init__).parameters)

    if {"vllm_config", "prefix"}.issubset(init_param_names):
        # new-style model class
        with set_current_vllm_config(vllm_config, check_compile=True):
            return model_class(vllm_config=vllm_config, prefix=prefix)

    msg = ("vLLM model class should accept `vllm_config` and `prefix` as "
           "input arguments. Possibly you have an old-style model class"
           " registered from out of tree and it is used for new vLLM version. "
           "Check https://docs.vllm.ai/en/latest/design/arch_overview.html "
           "for the design and update the model class accordingly.")
    warnings.warn(msg, DeprecationWarning, stacklevel=2)

    logger.warning(
        "Trying to guess the arguments for old-style model class %s",
        model_class,
    )

    # try to be compatible with old-style model class: pass only the
    # arguments that the constructor actually declares.
    legacy_kwargs = {}
    if "prefix" in init_param_names:
        legacy_kwargs["prefix"] = prefix
    # (constructor argument, object holding the value, attribute name)
    legacy_sources = (
        ("config", model_config, "hf_config"),
        ("cache_config", vllm_config, "cache_config"),
        ("quant_config", vllm_config, "quant_config"),
        ("lora_config", vllm_config, "lora_config"),
        ("scheduler_config", vllm_config, "scheduler_config"),
    )
    for kwarg_name, owner, attr_name in legacy_sources:
        if kwarg_name in init_param_names:
            legacy_kwargs[kwarg_name] = getattr(owner, attr_name)

    with set_current_vllm_config(vllm_config, check_compile=True):
        return model_class(**legacy_kwargs)
def process_weights_after_loading(model: nn.Module, model_config: ModelConfig,
                                  target_device: torch.device) -> None:
    """Run the post-load weight hooks of the model's submodules.

    A first pass handles ``QKVCrossParallelLinear`` modules and modules with
    a ``QuantizeMethodBase`` quant method; a second pass handles ``Attention``
    modules that define ``process_weights_after_loading``.

    Args:
        model: Model whose submodules are visited.
        model_config: Its ``dtype`` is passed to the attention hooks.
        target_device: Device used by ``device_loading_context`` while a
            quant method processes a module.
    """
    for _, submodule in model.named_modules():
        if isinstance(submodule, QKVCrossParallelLinear):
            # NOTE(Isotr0py): special case for cross QKV layer because
            # q and kv proj aren't registered as submodules intentionally
            submodule.process_weights_after_loading()
        else:
            method = getattr(submodule, "quant_method", None)
            if isinstance(method, QuantizeMethodBase):
                # When quant methods need to process weights after loading
                # (for repacking, quantizing, etc), they expect parameters
                # to be on the global target device. This scope is for the
                # case where cpu offloading is used, where we will move the
                # parameters onto device for processing and back off after.
                with device_loading_context(submodule, target_device):
                    method.process_weights_after_loading(submodule)

    # Currently only used by MLA.
    # NOTE: This intentionally happens after other modules so we can easily
    # decompress the weights for MLA.
    for _, submodule in model.named_modules():
        if not isinstance(submodule, Attention):
            continue
        if not hasattr(submodule, "process_weights_after_loading"):
            continue
        # TODO(lucas): see if there is a way to unify the signatures
        # of process_weights_after_loading
        submodule.process_weights_after_loading(model_config.dtype)
@contextmanager
def device_loading_context(module: torch.nn.Module,
                           target_device: torch.device):
    """Temporarily move a module's CPU parameters to ``target_device``.

    With a CPU target the module is yielded untouched. Otherwise every
    parameter currently on the CPU is moved to ``target_device`` for the
    duration of the block and copied back to a new CPU tensor afterwards;
    parameters that were not moved on entry are left alone.

    Args:
        module: Module whose parameters are moved.
        target_device: Device the parameters should live on inside the block.

    Yields:
        The same ``module`` object.
    """
    if target_device.type == "cpu":
        # If target is CPU, no need to move anything
        yield module
        return

    # Remember where each moved parameter came from.
    moved_from: dict[str, torch.device] = {}
    for param_name, param in module.named_parameters():
        if param.device.type != "cpu":
            # Parameters already on target device are not touched
            continue
        moved_from[param_name] = param.device
        param.data = param.data.to(target_device)

    try:
        yield module

    finally:
        # Restore parameters to their original devices, ignoring new parameters
        use_pinned = is_pin_memory_available()
        for param_name, param in module.named_parameters():
            home = moved_from.get(param_name)
            if home is None:
                # New parameters or parameters already on target device
                # are untouched
                continue
            if home.type != "cpu":
                param.data = param.data.to(home)
                continue
            # `torch.empty_like` does not support `pin_memory` argument
            staged = torch.empty_strided(
                size=param.data.size(),
                stride=param.data.stride(),
                dtype=param.data.dtype,
                layout=param.data.layout,
                device="cpu",
                pin_memory=use_pinned,
            )
            staged.copy_(param.data)
            param.data = staged
def resolve_transformers_arch(model_config: ModelConfig,
                              architectures: list[str]) -> list[str]:
    """Replace architecture names with the Transformers fallback if usable.

    For every entry other than "TransformersForCausalLM", the model class is
    looked up on the ``transformers`` package, or taken from the config's
    ``auto_map`` "AutoModel" entry when it is not found there. Depending on
    ``model_config.model_impl`` the class is checked for backend
    compatibility and the entry is replaced by "TransformersForCausalLM".

    Args:
        model_config: Provides ``hf_config``, ``model``, ``revision`` and
            ``model_impl``.
        architectures: Architecture names; modified in place.

    Returns:
        The same ``architectures`` list.

    Raises:
        ValueError: The model class cannot be found, or it reports that it
            is not backend compatible.
    """
    for i, arch in enumerate(architectures):
        if arch == "TransformersForCausalLM":
            continue
        auto_map: dict[str, str] = getattr(model_config.hf_config, "auto_map",
                                           None) or dict()
        # Make sure that config class is always initialized before model class,
        # otherwise the model class won't be able to access the config class,
        # the expected auto_map should have correct order like:
        # "auto_map": {
        #     "AutoConfig": "<your-repo-name>--<config-name>",
        #     "AutoModel": "<your-repo-name>--<config-name>",
        #     "AutoModelFor<Task>": "<your-repo-name>--<config-name>",
        # },
        # Entries are loaded in sorted key order.
        auto_modules = {
            name:
            get_class_from_dynamic_module(module,
                                          model_config.model,
                                          revision=model_config.revision)
            for name, module in sorted(auto_map.items(), key=lambda x: x[0])
        }
        model_module = getattr(transformers, arch, None)
        if model_module is None:
            # Not found on the transformers package: fall back to the class
            # registered under "AutoModel" in auto_map.
            if "AutoModel" not in auto_map:
                raise ValueError(
                    f"Cannot find model module. '{arch}' is not a registered "
                    "model in the Transformers library (only relevant if the "
                    "model is meant to be in Transformers) and 'AutoModel' is "
                    "not present in the model config's 'auto_map' (relevant "
                    "if the model is custom).")
            model_module = auto_modules["AutoModel"]
        # TODO(Isotr0py): Further clean up these raises.
        # perhaps handled them in _ModelRegistry._raise_for_unsupported?
        if model_config.model_impl == ModelImpl.TRANSFORMERS:
            if not model_module.is_backend_compatible():
                raise ValueError(
                    f"The Transformers implementation of {arch} is not "
                    "compatible with vLLM.")
            architectures[i] = "TransformersForCausalLM"
        if model_config.model_impl == ModelImpl.AUTO:
            if not model_module.is_backend_compatible():
                raise ValueError(
                    f"{arch} has no vLLM implementation and the Transformers "
                    "implementation is not compatible with vLLM. Try setting "
                    "VLLM_USE_V1=0.")
            logger.warning(
                "%s has no vLLM implementation, falling back to Transformers "
                "implementation. Some features may not be supported and "
                "performance may not be optimal.", arch)
            architectures[i] = "TransformersForCausalLM"
    return architectures
def get_model_architecture(
        model_config: ModelConfig) -> tuple[type[nn.Module], str]:
    """Resolve the model class and architecture name for ``model_config``.

    The architecture list from the HF config may be rewritten to the
    Transformers fallback or to the quantized Mixtral class before it is
    resolved through ``ModelRegistry``. For the "embed", "classify" and
    "reward" tasks the resolved class is wrapped by the matching adapter.

    Returns:
        A ``(model_class, architecture_name)`` tuple.
    """
    architectures = getattr(model_config.hf_config, "architectures", [])

    # Special handling for quantized Mixtral.
    # FIXME(woosuk): This is a temporary hack.
    mixtral_supported = [
        "fp8", "compressed-tensors", "gptq_marlin", "awq_marlin", "quark"
    ]

    supported_archs = ModelRegistry.get_supported_archs()
    has_native_impl = any(name in supported_archs for name in architectures)

    use_transformers = (model_config.model_impl == ModelImpl.TRANSFORMERS
                        or (model_config.model_impl != ModelImpl.VLLM
                            and not has_native_impl))
    if use_transformers:
        architectures = resolve_transformers_arch(model_config, architectures)
    elif (model_config.quantization is not None
          and model_config.quantization not in mixtral_supported
          and "MixtralForCausalLM" in architectures):
        architectures = ["QuantMixtralForCausalLM"]

    model_cls, arch = ModelRegistry.resolve_model_cls(architectures)

    # Task name -> adapter that wraps the resolved class.
    task_adapters = {
        "embed": as_embedding_model,
        "classify": as_classification_model,
        "reward": as_reward_model,
    }
    adapt = task_adapters.get(model_config.task)
    if adapt is not None:
        model_cls = adapt(model_cls)

    return model_cls, arch
def get_architecture_class_name(model_config: ModelConfig) -> str:
    """Return only the architecture name resolved for ``model_config``."""
    _model_cls, arch_name = get_model_architecture(model_config)
    return arch_name
@dataclass
class ParamMapping:
    """
    A class to handle parameter mapping for model weight loading.
    It creates a bidirectional mapping between packed parameters and their
    constituent parts.
    """
    # packed parameter name -> names of the parameters fused into it
    packed_mapping: dict[str, list[str]]
    # constituent name -> (packed parameter name, position within it)
    inverse_packed_mapping: dict[str, tuple[str,
                                            int]] = field(default_factory=dict)

    def __post_init__(self):
        """Populate ``inverse_packed_mapping`` from ``packed_mapping``."""
        for fused_name, parts in self.packed_mapping.items():
            # Skip self-contained cases (e.g., {"W_pack": ["W_pack"]})
            is_self_contained = len(parts) == 1 and parts[0] == fused_name
            if is_self_contained:
                continue
            self.inverse_packed_mapping.update({
                part: (fused_name, position)
                for position, part in enumerate(parts)
            })

    def get_sub_modules(self,
                        module_name: str) -> Optional[tuple[str, list[str]]]:
        """Return the first packed entry whose key ``module_name`` ends with.

        Returns:
            ``(packed_name, constituent_names)`` for the first match in
            mapping order, or ``None`` when nothing matches.
        """
        matches = ((suffix, parts)
                   for suffix, parts in self.packed_mapping.items()
                   if module_name.endswith(suffix))
        return next(matches, None)
def configure_quant_config(quant_config: QuantizationConfig,
                           model_class: type[nn.Module]):
    """
    Pass packed_modules_mapping by reference to quant_config so that
    quant_config can properly match fused modules

    Note that model attributes are passed by reference to quant_config,
    enabling them to be updated by model_class.__new__ (ex. chatglm, qwen)

    A warning is logged, and quant_config is left unchanged, when the model
    class does not define `packed_modules_mapping`.
    """
    packed_mapping = getattr(model_class, "packed_modules_mapping", None)
    if packed_mapping is None:
        logger.warning(
            "The model class %s has not defined `packed_modules_mapping`, "
            "this may lead to incorrect mapping of quantized or ignored "
            "modules", model_class.__name__)
        return

    # pass packed_modules_mapping by reference to quant_config
    quant_config.packed_modules_mapping = packed_mapping

View File

@@ -0,0 +1,782 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for downloading and initializing model weights."""
import fnmatch
import glob
import hashlib
import json
import os
import tempfile
import time
from collections import defaultdict
from collections.abc import Generator
from pathlib import Path
from typing import Any, Callable, Optional, Union
import filelock
import gguf
import huggingface_hub.constants
import numpy as np
import torch
from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import (QuantizationConfig,
get_quantization_config)
from vllm.platforms import current_platform
from vllm.utils import PlaceholderModule
try:
from runai_model_streamer import SafetensorsStreamer
except (ImportError, OSError):
# see https://github.com/run-ai/runai-model-streamer/issues/26
# OSError will be raised on arm64 platform
runai_model_streamer = PlaceholderModule(
"runai_model_streamer") # type: ignore[assignment]
SafetensorsStreamer = runai_model_streamer.placeholder_attr(
"SafetensorsStreamer")
try:
from fastsafetensors import SafeTensorsFileLoader, SingleGroup
except ImportError:
fastsafetensors = PlaceholderModule("fastsafetensors")
SafeTensorsFileLoader = fastsafetensors.placeholder_attr(
"SafeTensorsFileLoader")
SingleGroup = fastsafetensors.placeholder_attr("SingleGroup")
logger = init_logger(__name__)
# use system-level temp directory for file locks, so that multiple users
# can share the same lock without error.
# lock files in the temp directory will be automatically deleted when the
# system reboots, so users will not complain about annoying lock files
temp_dir = tempfile.gettempdir()
def enable_hf_transfer():
    """automatically activates hf_transfer

    Does nothing when HF_HUB_ENABLE_HF_TRANSFER is already present in the
    environment or when the hf_transfer package cannot be imported.
    """
    if "HF_HUB_ENABLE_HF_TRANSFER" in os.environ:
        # An explicit setting is present; leave it alone.
        return

    try:
        # enable hf hub transfer if available
        import hf_transfer  # type: ignore # noqa
    except ImportError:
        return

    huggingface_hub.constants.HF_HUB_ENABLE_HF_TRANSFER = True
# Executed once, when this module is imported.
enable_hf_transfer()
class DisabledTqdm(tqdm):
    """A ``tqdm`` subclass that is always constructed with ``disable=True``."""

    def __init__(self, *positional, **options):
        # Forward everything unchanged and force the disabled flag.
        super().__init__(*positional, **options, disable=True)
def get_lock(model_name_or_path: Union[str, Path],
             cache_dir: Optional[str] = None):
    """Return a file lock object for the given model.

    Args:
        model_name_or_path: Model identifier or path; "/" is replaced by "-"
            to form part of the lock file name.
        cache_dir: Directory for the lock file. Falls back to the system
            temp directory when not given (or empty).

    Returns:
        A ``filelock.FileLock`` for a lock file inside the lock directory.
    """
    lock_dir = cache_dir or temp_dir
    model_name_or_path = str(model_name_or_path)
    # Create the lock directory itself. Creating only its parent leaves a
    # missing cache_dir missing, and raises for a single-component relative
    # path because os.path.dirname() then returns "".
    os.makedirs(lock_dir, exist_ok=True)
    model_name = model_name_or_path.replace("/", "-")
    hash_name = hashlib.sha256(model_name.encode()).hexdigest()
    # add hash to avoid conflict with old users' lock files
    lock_file_name = hash_name + model_name + ".lock"
    # mode 0o666 is required for the filelock to be shared across users
    lock = filelock.FileLock(os.path.join(lock_dir, lock_file_name),
                             mode=0o666)
    return lock
def _shared_pointers(tensors):
ptrs = defaultdict(list)
for k, v in tensors.items():
ptrs[v.data_ptr()].append(k)
failing = []
for _, names in ptrs.items():
if len(names) > 1:
failing.append(names)
return failing
def convert_bin_to_safetensor_file(
    pt_filename: str,
    sf_filename: str,
) -> None:
    """Convert a PyTorch checkpoint file into a safetensors file and verify it.

    Args:
        pt_filename: Path of the checkpoint to read with ``torch.load``.
        sf_filename: Path of the safetensors file to write.

    Raises:
        RuntimeError: If the written file is more than 1% larger than the
            source file, or if any reloaded tensor differs from the tensor
            that was saved.
    """
    loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
    if "state_dict" in loaded:
        loaded = loaded["state_dict"]

    # Tensors that share a data pointer are stored once: keep the first name
    # of each aliased group and drop the others.
    for shared_weights in _shared_pointers(loaded):
        for name in shared_weights[1:]:
            loaded.pop(name)

    # For tensors to be contiguous
    loaded = {k: v.contiguous() for k, v in loaded.items()}

    dirname = os.path.dirname(sf_filename)
    # A bare file name has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError; in that case there is nothing to create.
    if dirname:
        os.makedirs(dirname, exist_ok=True)
    save_file(loaded, sf_filename, metadata={"format": "pt"})

    # check file size
    sf_size = os.stat(sf_filename).st_size
    pt_size = os.stat(pt_filename).st_size
    if (sf_size - pt_size) / pt_size > 0.01:
        raise RuntimeError(f"""The file size different is more than 1%:
- {sf_filename}: {sf_size}
- {pt_filename}: {pt_size}
""")

    # check if the tensors are the same
    reloaded = load_file(sf_filename)
    for k, pt_tensor in loaded.items():
        sf_tensor = reloaded[k]
        if not torch.equal(pt_tensor, sf_tensor):
            raise RuntimeError(f"The output tensors do not match for key {k}")
# TODO(woosuk): Move this to other place.
def get_quant_config(model_config: ModelConfig,
                     load_config: LoadConfig) -> QuantizationConfig:
    """Build the quantization config for ``model_config.quantization``.

    Sources are tried in this order:

    1. GGUF has no config file, so an empty config is used.
    2. ``quantization_config`` on the HF config, then on its ``text_config``,
       then ``compression_config`` on the HF config.
    3. In-flight bitsandbytes quantization uses an empty config.
    4. A ``*.json`` file in the model folder (downloaded from the Hub when
       ``model_config.model`` is not a local directory) whose name ends with
       one of ``quant_cls.get_config_filenames()``. If that list is empty,
       ``quant_cls()`` is returned.

    Args:
        model_config: Supplies ``quantization``, ``hf_config``, ``model`` and
            ``revision``.
        load_config: Supplies ``download_dir`` for the Hub download.

    Returns:
        The quantization config instance built by the resolved class.

    Raises:
        ValueError: If no matching config file or more than one is found, or
            if a modelopt config file was not produced by modelopt.
    """
    quant_cls = get_quantization_config(model_config.quantization)

    # GGUF doesn't have config file
    if model_config.quantization == "gguf":
        return quant_cls.from_config({})

    # Read the quantization config from the HF model config, if available.
    hf_quant_config = getattr(model_config.hf_config, "quantization_config",
                              None)
    # some vision model may keep quantization_config in their text_config
    hf_text_config = getattr(model_config.hf_config, "text_config", None)
    if hf_quant_config is None and hf_text_config is not None:
        hf_quant_config = getattr(hf_text_config, "quantization_config", None)
    if hf_quant_config is None:
        # compressed-tensors uses a compressions_config
        hf_quant_config = getattr(model_config.hf_config, "compression_config",
                                  None)
    if hf_quant_config is not None:
        return quant_cls.from_config(hf_quant_config)

    # Inflight BNB quantization
    if model_config.quantization == "bitsandbytes":
        return quant_cls.from_config({})

    is_local = os.path.isdir(model_config.model)
    if not is_local:
        # Download the config files.
        with get_lock(model_config.model, load_config.download_dir):
            hf_folder = snapshot_download(
                model_config.model,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                tqdm_class=DisabledTqdm,
            )
    else:
        hf_folder = model_config.model

    possible_config_filenames = quant_cls.get_config_filenames()

    # If the quantization config is not found, use the default config.
    if not possible_config_filenames:
        return quant_cls()

    config_files = glob.glob(os.path.join(hf_folder, "*.json"))
    quant_config_files = [
        f for f in config_files
        if any(f.endswith(x) for x in possible_config_filenames)
    ]
    if len(quant_config_files) == 0:
        raise ValueError(
            f"Cannot find the config file for {model_config.quantization}")
    if len(quant_config_files) > 1:
        raise ValueError(
            f"Found multiple config files for {model_config.quantization}: "
            f"{quant_config_files}")

    quant_config_file = quant_config_files[0]
    with open(quant_config_file) as f:
        config = json.load(f)

    # NOTE(review): the bitsandbytes branch looks unreachable, because
    # bitsandbytes already returned above; kept to preserve the original
    # logic.
    if model_config.quantization == "bitsandbytes":
        config["adapter_name_or_path"] = model_config.model
    elif (model_config.quantization == "modelopt"
          and config["producer"]["name"] != "modelopt"):
        # Report the file path; the original interpolated the file object,
        # which rendered as a TextIOWrapper repr.
        raise ValueError(
            f"Unsupported quantization config"
            f" found for {model_config.quantization} in {quant_config_file}.")

    return quant_cls.from_config(config)
def get_sparse_attention_config(
    model_config: ModelConfig,
    load_config: LoadConfig,
    sparse_attention_config_filename: str = "sparse_attention_config.json",
) -> dict[str, Any]:
    """Load the sparse attention config shipped with a model, if present.

    When ``model_config.model`` is not a local directory, the repository's
    ``*.json`` files are first fetched with ``snapshot_download`` under the
    per-model file lock.

    Returns:
        The parsed JSON content of the config file, or an empty dict when
        the file does not exist in the model folder.
    """
    model_ref = model_config.model
    if os.path.isdir(model_ref):
        hf_folder = model_ref
    else:
        # Download the config files.
        with get_lock(model_ref, load_config.download_dir):
            hf_folder = snapshot_download(
                model_ref,
                revision=model_config.revision,
                allow_patterns="*.json",
                cache_dir=load_config.download_dir,
                local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                tqdm_class=DisabledTqdm,
            )

    config_file = os.path.join(hf_folder, sparse_attention_config_filename)
    if os.path.exists(config_file):
        # Load the sparse attention config.
        with open(config_file) as f:
            config = json.load(f)
        logger.info("Loaded sparse attention config from %s", config_file)
        return config
    return {}
def download_weights_from_hf(
    model_name_or_path: str,
    cache_dir: Optional[str],
    allow_patterns: list[str],
    revision: Optional[str] = None,
    ignore_patterns: Optional[Union[str, list[str]]] = None,
) -> str:
    """Download model weights from Hugging Face Hub.

    Unless the Hub is in offline mode, the repository is listed first and
    ``allow_patterns`` is narrowed to the first pattern that matches at
    least one file, so only one weight format is downloaded.

    Args:
        model_name_or_path (str): The model name or path.
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        allow_patterns (list[str]): The allowed patterns for the
            weight files, in order of preference.
        revision (Optional[str]): The revision of the model.
        ignore_patterns (Optional[Union[str, list[str]]]): The patterns to
            filter out the weight files. Files matched by any of the patterns
            will be ignored.

    Returns:
        str: The path to the downloaded model weights.
    """
    local_only = huggingface_hub.constants.HF_HUB_OFFLINE
    if not local_only:
        # Look at what is available before downloading anything.
        file_list = HfFileSystem().ls(model_name_or_path,
                                      detail=False,
                                      revision=revision)
        # Keep only the first pattern that has at least one match.
        first_hit = next(
            (pattern for pattern in allow_patterns
             if fnmatch.filter(file_list, pattern)),
            None,
        )
        if first_hit is not None:
            allow_patterns = [first_hit]

    logger.info("Using model weights format %s", allow_patterns)
    # Use file lock to prevent multiple processes from
    # downloading the same model weights at the same time.
    with get_lock(model_name_or_path, cache_dir):
        started = time.perf_counter()
        hf_folder = snapshot_download(
            model_name_or_path,
            allow_patterns=allow_patterns,
            ignore_patterns=ignore_patterns,
            cache_dir=cache_dir,
            tqdm_class=DisabledTqdm,
            revision=revision,
            local_files_only=local_only,
        )
        elapsed = time.perf_counter() - started
        if elapsed > 0.5:
            logger.info("Time spent downloading weights for %s: %.6f seconds",
                        model_name_or_path, elapsed)
    return hf_folder
def download_safetensors_index_file_from_hf(
    model_name_or_path: str,
    index_file: str,
    cache_dir: Optional[str],
    revision: Optional[str] = None,
) -> None:
    """Fetch a model's safetensors index file from Hugging Face Hub.

    A missing index file is not an error, because only some models have
    one; the miss is logged and the function returns normally.

    Args:
        model_name_or_path (str): The model name or path.
        index_file (str): The safetensors index file name
        cache_dir (Optional[str]): The cache directory to store the model
            weights. If None, will use HF defaults.
        revision (Optional[str]): The revision of the model.
    """
    # The file lock keeps several processes from downloading the same
    # file at the same time.
    with get_lock(model_name_or_path, cache_dir):
        request = dict(
            repo_id=model_name_or_path,
            filename=index_file,
            cache_dir=cache_dir,
            revision=revision,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
        )
        try:
            # Download the safetensors index file.
            hf_hub_download(**request)
        # The more specific "not in local cache" error is handled before
        # the general "entry not found" error.
        except huggingface_hub.utils.LocalEntryNotFoundError:
            logger.info("No %s found in local cache.", index_file)
        except huggingface_hub.utils.EntryNotFoundError:
            logger.info("No %s found in remote.", index_file)
# Some models (e.g. Mistral-7B-v0.3) ship both sharded safetensors files and
# a consolidated safetensors file. Handing both to the weight loader breaks
# it, so the index file decides which safetensors files are used.
def filter_duplicate_safetensors_files(hf_weights_files: list[str],
                                       hf_folder: str,
                                       index_file: str) -> list[str]:
    """Keep only the weight files that the safetensors index refers to.

    Args:
        hf_weights_files: Candidate weight file paths.
        hf_folder: Folder that contains the index file and the shards.
        index_file: Name of the index file inside ``hf_folder``.

    Returns:
        ``hf_weights_files`` unchanged when the index file does not exist;
        otherwise the entries that appear in the index's ``weight_map``,
        in their original order.
    """
    index_path = os.path.join(hf_folder, index_file)
    if not os.path.isfile(index_path):
        return hf_weights_files

    # The index maps each state_dict key to the safetensors file that
    # stores it; only the distinct file names matter here.
    with open(index_path) as f:
        weight_map = json.load(f)["weight_map"]
    indexed_files = {
        os.path.join(hf_folder, shard_name)
        for shard_name in weight_map.values()
    }
    return [path for path in hf_weights_files if path in indexed_files]
def filter_files_not_needed_for_inference(
        hf_weights_files: list[str]) -> list[str]:
    """
    Drop training-only artifacts from a list of weight files.

    A path is removed when it ends with one of the trainer file names below.
    See https://github.com/huggingface/transformers/blob/v4.34.0/src/transformers/trainer.py#L227-L233
    """
    training_only_suffixes = (
        "training_args.bin",
        "optimizer.bin",
        "optimizer.pt",
        "scheduler.pt",
        "scaler.pt",
    )
    return [
        path for path in hf_weights_files
        if not path.endswith(training_only_suffixes)
    ]
# Progress-bar template in pure text form, ending with a newline.
# The trailing newline makes it impossible to see the bar animate in place,
# but it avoids garbled output under ray or multiprocessing, which wrap each
# line of output with a prefix.
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501
def enable_tqdm(use_tqdm_on_load: bool):
    """Decide whether this process should display a loading progress bar.

    The bar is shown only when the caller asked for it and the process is
    either not part of an initialized ``torch.distributed`` group or is
    rank 0 of it.
    """
    if not use_tqdm_on_load:
        return use_tqdm_on_load
    if not torch.distributed.is_initialized():
        return True
    return torch.distributed.get_rank() == 0
def np_cache_weights_iterator(
    model_name_or_path: str,
    cache_dir: Optional[str],
    hf_folder: str,
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, tensor)`` pairs from a numpy dump of the checkpoint.

    The dump lives in ``<hf_folder>/np``. If its ``weight_names.json`` is
    missing, every tensor of ``hf_weights_files`` is first saved there as a
    numpy file named after the tensor.
    """
    np_folder = os.path.join(hf_folder, "np")
    os.makedirs(np_folder, exist_ok=True)
    weight_names_file = os.path.join(np_folder, "weight_names.json")

    # The file lock keeps several processes from dumping the same model
    # weights to numpy at the same time.
    with get_lock(model_name_or_path, cache_dir):
        if not os.path.exists(weight_names_file):
            dumped_names: list[str] = []
            shard_bar = tqdm(
                hf_weights_files,
                desc="Loading np_cache checkpoint shards",
                disable=not enable_tqdm(use_tqdm_on_load),
                bar_format=_BAR_FORMAT,
            )
            for shard_path in shard_bar:
                shard_state = torch.load(shard_path,
                                         map_location="cpu",
                                         weights_only=True)
                for name, tensor in shard_state.items():
                    with open(os.path.join(np_folder, name), "wb") as out:
                        np.save(out, tensor.cpu().detach().numpy())
                    dumped_names.append(name)
            # The name list is written last, after all tensors are dumped.
            with open(weight_names_file, "w") as out:
                json.dump(dumped_names, out)

    with open(weight_names_file) as src:
        weight_names = json.load(src)

    for name in weight_names:
        with open(os.path.join(np_folder, name), "rb") as src:
            array = np.load(src)
        yield name, torch.from_numpy(array)
def safetensors_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, tensor)`` for every tensor in the safetensors files.

    Files are opened one at a time and each tensor is fetched by name just
    before it is yielded.
    """
    shard_bar = tqdm(
        hf_weights_files,
        desc="Loading safetensors checkpoint shards",
        disable=not enable_tqdm(use_tqdm_on_load),
        bar_format=_BAR_FORMAT,
    )
    for shard_path in shard_bar:
        with safe_open(shard_path, framework="pt") as handle:
            for tensor_name in handle.keys():  # noqa: SIM118
                yield tensor_name, handle.get_tensor(tensor_name)
def runai_safetensors_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield the tensors of the safetensors files via Runai Model Streamer.

    One streamer is shared by all files; each file is streamed and its
    tensors are yielded before the next file is started.
    """
    with SafetensorsStreamer() as streamer:
        shard_bar = tqdm(
            hf_weights_files,
            desc="Loading safetensors using Runai Model Streamer",
            disable=not enable_tqdm(use_tqdm_on_load),
            bar_format=_BAR_FORMAT,
        )
        for shard_path in shard_bar:
            streamer.stream_file(shard_path)
            for item in streamer.get_tensors():
                yield item
def fastsafetensors_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, tensor)`` pairs loaded with the fastsafetensors library.

    The files are processed in batches of at most one file per rank of the
    process group. Within a batch, the file at position ``i`` is assigned
    to rank ``i``.
    """
    pg = (torch.distributed.group.WORLD
          if torch.distributed.is_initialized() else SingleGroup())
    device = torch.device(f'cuda:{pg.rank()}')

    world_size = pg.size()
    batches = [
        hf_weights_files[start:start + world_size]
        for start in range(0, len(hf_weights_files), world_size)
    ]

    progress = tqdm(
        batches,
        desc="Loading safetensors using Fastsafetensor loader",
        disable=not enable_tqdm(use_tqdm_on_load),
        bar_format=_BAR_FORMAT,
    )
    for batch in progress:
        loader = SafeTensorsFileLoader(pg, device)
        loader.add_filenames(
            {rank: [path] for rank, path in enumerate(batch)})
        try:
            file_buffer = loader.copy_files_to_device()
            try:
                # Snapshot the key list before yielding from the buffer.
                for key in list(file_buffer.key_to_rank_lidx.keys()):
                    yield key, file_buffer.get_tensor(key)
            finally:
                file_buffer.close()
        finally:
            loader.close()
def pt_weights_iterator(
    hf_weights_files: list[str],
    use_tqdm_on_load: bool,
    pt_load_map_location: Union[str, dict[str, str]] = "cpu",
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield the ``(name, tensor)`` items of each bin/pt checkpoint file.

    Each file is loaded with ``torch.load(..., weights_only=True)`` using
    ``pt_load_map_location`` as the map location.
    """
    shard_bar = tqdm(
        hf_weights_files,
        desc="Loading pt checkpoint shards",
        disable=not enable_tqdm(use_tqdm_on_load),
        bar_format=_BAR_FORMAT,
    )
    for shard_path in shard_bar:
        shard_state = torch.load(shard_path,
                                 map_location=pt_load_map_location,
                                 weights_only=True)
        for item in shard_state.items():
            yield item
        # Drop the reference to this shard before loading the next one.
        del shard_state
def get_gguf_extra_tensor_names(
        gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]:
    """Return the HF names of mapped tensors that the GGUF file lacks.

    Args:
        gguf_file: Path of the GGUF file to inspect.
        gguf_to_hf_name_map: Mapping from GGUF tensor name to HF name.

    Returns:
        The HF names whose GGUF key is in the map but not in the file.
    """
    reader = gguf.GGUFReader(gguf_file)
    mapped_keys = set(gguf_to_hf_name_map.keys())
    present_keys = {tensor.name for tensor in reader.tensors}
    return [
        gguf_to_hf_name_map[key] for key in mapped_keys - present_keys
    ]
def gguf_quant_weights_iterator(
    gguf_file: str, gguf_to_hf_name_map: dict[str, str]
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """
    Yield the quantized weights of a GGUF file as torch tensors.

    Two passes are made over the mapped tensors. The first yields a
    ``qweight_type`` entry for every tensor that is not F32; the second
    yields the tensor data, named ``qweight`` instead of ``weight`` for
    tensors that are not F32.
    """
    reader = gguf.GGUFReader(gguf_file)

    # Pass 1: quantization types.
    for tensor in reader.tensors:
        if tensor.name not in gguf_to_hf_name_map:
            continue
        quant_type = tensor.tensor_type
        hf_name = gguf_to_hf_name_map[tensor.name]
        if quant_type.name == "F32":
            continue
        type_entry_name = hf_name.replace("weight", "qweight_type")
        yield type_entry_name, torch.tensor(quant_type)

    # Pass 2: tensor data.
    for tensor in reader.tensors:
        if tensor.name not in gguf_to_hf_name_map:
            continue
        raw = tensor.data
        hf_name = gguf_to_hf_name_map[tensor.name]
        if tensor.tensor_type.name != "F32":
            hf_name = hf_name.replace("weight", "qweight")
        yield hf_name, torch.tensor(raw)
def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
    """Materialize a safetensors PySafeSlice as a ``torch.Tensor``.

    A PySafeSlice can be indexed before the tensor is actually loaded,
    which can reduce how much data is read into memory. It does not offer
    richer operations such as ``.view()`` or ``.t()``, so callers that need
    them must convert first. Inputs that already are tensors are returned
    as-is; anything else is converted by taking the full slice ``x[:]``.
    """
    return x if isinstance(x, torch.Tensor) else x[:]
def default_weight_loader(param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Copy ``loaded_weight`` into ``param`` in place.

    When both hold exactly one element, the value is filled in rather than
    copied, so differing scalar shapes are accepted. Otherwise the sizes
    must match exactly.

    Raises:
        AssertionError: If the sizes differ in the non-scalar case.
    """
    try:
        both_single = param.numel() == 1 and loaded_weight.numel() == 1
        if not both_single:
            assert param.size() == loaded_weight.size(), (
                f"Attempted to load weight ({loaded_weight.size()}) "
                f"into parameter ({param.size()})")
            param.data.copy_(loaded_weight)
        else:
            # Scalars are not always reported with matching shapes, so
            # "broadcast" the single value instead of copying.
            param.data.fill_(loaded_weight.item())
    except Exception:
        # NOTE: This handler exists only as a place to set a breakpoint
        # when debugging weight loading issues.
        raise
def row_parallel_weight_loader(param: torch.Tensor,
                               loaded_weight: torch.Tensor) -> None:
    """Load this rank's shard of a row-parallelized weight.

    For parameters that are not 1-D, the slice for the current tensor
    parallel rank is cut from dimension 0 of ``loaded_weight``, sized by
    the parameter's own extent. 1-D parameters are loaded whole.
    """
    tp_rank = get_tensor_model_parallel_rank()
    if param.dim() != 1:
        rows_per_rank = param.data.shape[0]
        loaded_weight = loaded_weight.narrow(0, tp_rank * rows_per_rank,
                                             rows_per_rank)
    return default_weight_loader(param, loaded_weight)
# Signature shared by weight loaders: (param, loaded_weight) -> None.
LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]


def sharded_weight_loader(shard_axis: int) -> LoaderFunction:
    """Build a loader that copies only this rank's slice along an axis.

    Args:
        shard_axis: Dimension along which the loaded weight is split.

    Returns:
        A loader that narrows ``loaded_weight`` to the current tensor
        parallel rank's slice and then delegates to the default loader.
    """

    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
        tp_rank = get_tensor_model_parallel_rank()
        # The slice length is the parameter's own extent along the axis.
        slice_len = param.data.shape[shard_axis]
        rank_slice = loaded_weight.narrow(shard_axis, tp_rank * slice_len,
                                          slice_len)
        return default_weight_loader(param, rank_slice)

    return loader
def composed_weight_loader(
        loader: LoaderFunction, fn: Callable[[torch.Tensor],
                                             torch.Tensor]) -> LoaderFunction:
    """Wrap ``loader`` so that ``fn`` is applied to the parameter afterwards.

    Args:
        loader: Loader that is run first to fill the parameter.
        fn: Function applied to the loaded parameter; its result is copied
            back into the parameter's data.

    Returns:
        The combined loader.
    """

    def composed_loader(param: torch.Tensor,
                        loaded_weight: torch.Tensor) -> None:
        loader(param, loaded_weight)
        post_processed = fn(param)
        param.data.copy_(post_processed)

    return composed_loader
def initialize_dummy_weights(
    model: torch.nn.Module,
    low: float = -1e-3,
    high: float = 1e-3,
    seed: int = 1234,
) -> None:
    """Fill the floating-point tensors of ``model`` with random values.

    Performance measurements need randomly initialized weights, and the
    values must not produce NaNs in the forward pass. Values between -1e-3
    and 1e-3 were found empirically to work well for most models.

    A fresh generator seeded with ``seed`` is created for every tensor, so
    dummy weights stay consistent even when the model is partitioned across
    several devices. With a fixed seed, the generated values depend only on
    the tensor's number of elements and its data type.
    """
    for param in model.state_dict().values():
        if not torch.is_floating_point(param):
            continue

        if current_platform.is_tpu():
            cpu_generator = torch.Generator(device="cpu")
            cpu_generator.manual_seed(seed)
            # Note: param.uniform_ is not used here because it demands
            # more TPU HBM than copying from a CPU tensor does.
            # Note: torch.rand_like is avoided because it doesn't
            # currently support the generator argument.
            cpu_sample = torch.rand(param.shape,
                                    generator=cpu_generator,
                                    dtype=param.dtype,
                                    layout=param.layout,
                                    requires_grad=param.requires_grad,
                                    device="cpu")
            param.copy_((high - low) * cpu_sample + low)
            torch._sync(param)
            continue

        generator = torch.Generator(device=param.data.device)
        generator.manual_seed(seed)
        target_dtype = param.data.dtype
        if torch.finfo(target_dtype).bits >= 16:
            param.uniform_(low, high, generator=generator)
            continue

        # uniform_ doesn't support < 16-bit datatypes (FP8), so sample in
        # float16 and cast back.
        staged = param.data.to(torch.float16)
        staged = staged.uniform_(low, high,
                                 generator=generator).to(target_dtype)
        param.data.copy_(staged)
def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
    """Translate a checkpoint FP8 k/v scale name to the model's name.

    Names ending in ``.kv_scale`` (deprecated), ``.k_scale`` or ``.v_scale``
    are rewritten to the ``.attn.<scale>`` form. If the rewritten name is
    not a key of ``params_dict``, a warning is logged and None is returned.

    Args:
        name (str): The original loaded checkpoint parameter name.
        params_dict (dict): Dictionary containing the model's named parameters.

    Returns:
        str: The remapped parameter name if successful, or the original name
            if no remapping is needed.
        None: If the remapped name is not found in params_dict.
    """
    if name.endswith(".kv_scale"):
        logger.warning_once(
            "DEPRECATED. Found kv_scale in the checkpoint. "
            "This format is deprecated in favor of separate k_scale and "
            "v_scale tensors and will be removed in a future release. "
            "Functionally, we will remap kv_scale to k_scale and duplicate "
            "k_scale to v_scale")
        # NOTE: the deprecated kv_scale is remapped to k_scale
        candidate = name.replace(".kv_scale", ".attn.k_scale")
        if candidate in params_dict:
            return candidate
        logger.warning_once(
            "Found kv_scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv_scale is not loaded.",  # noqa: E501
            name,
            candidate,
        )
        return None

    modelopt_markers = (".self_attn.k_proj.k_scale",
                        ".self_attn.v_proj.v_scale")
    for suffix in (".k_scale", ".v_scale"):
        if not name.endswith(suffix):
            continue

        if any(marker in name for marker in modelopt_markers):
            # modelopt layout: the scale sits under k_proj / v_proj.
            candidate = name.replace(
                f".self_attn.{suffix[1]}_proj{suffix}",
                f".self_attn.attn{suffix}")
        else:
            candidate = name.replace(suffix, f".attn{suffix}")

        if candidate in params_dict:
            return candidate
        logger.warning_once(
            "Found %s in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). %s is not loaded.",  # noqa: E501
            suffix,
            name,
            candidate,
            suffix,
        )
        return None

    # Not a scale name: hand the name back untouched.
    return name