Files

791 lines
29 KiB
Python
Raw Permalink Normal View History

2026-01-19 10:38:50 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
2026-01-09 13:34:11 +08:00
import argparse
2026-01-19 10:38:50 +08:00
import contextlib
import contextvars
2026-01-09 13:34:11 +08:00
import dataclasses
2026-01-19 10:38:50 +08:00
import json
2026-01-09 13:34:11 +08:00
import os
2026-01-19 10:38:50 +08:00
import tempfile
import threading
2026-01-09 13:34:11 +08:00
import time
2026-01-19 10:38:50 +08:00
from collections.abc import Generator, MutableMapping
from dataclasses import asdict, dataclass, field, fields
from typing import TYPE_CHECKING, Any, ClassVar, Optional
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
import regex as re
2026-01-09 13:34:11 +08:00
import torch
2026-01-19 10:38:50 +08:00
from huggingface_hub import snapshot_download
2026-01-09 13:34:11 +08:00
from torch import nn
2026-01-19 10:38:50 +08:00
from torch.utils._python_dispatch import TorchDispatchMode
2026-01-09 13:34:11 +08:00
from transformers import PretrainedConfig
import vllm.envs as envs
2026-01-19 10:38:50 +08:00
from vllm.config import ModelConfig, ParallelConfig, VllmConfig, set_current_vllm_config
2026-01-09 13:34:11 +08:00
from vllm.logger import init_logger
2026-01-19 10:38:50 +08:00
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.platforms import current_platform
from vllm.utils.argparse_utils import FlexibleArgumentParser
from vllm.utils.import_utils import PlaceholderModule
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
if TYPE_CHECKING:
from vllm.engine.arg_utils import EngineArgs
2026-01-09 13:34:11 +08:00
# tensorizer is an optional dependency. On ImportError, expose placeholder
# attributes that raise a helpful "tensorizer not installed" error on first
# *use*, so merely importing this module never fails.
try:
    from tensorizer import (
        DecryptionParams,
        EncryptionParams,
        TensorDeserializer,
        TensorSerializer,
    )
    from tensorizer.stream_io import open_stream
    from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor
except ImportError:
    tensorizer = PlaceholderModule("tensorizer")
    DecryptionParams = tensorizer.placeholder_attr("DecryptionParams")
    EncryptionParams = tensorizer.placeholder_attr("EncryptionParams")
    TensorDeserializer = tensorizer.placeholder_attr("TensorDeserializer")
    TensorSerializer = tensorizer.placeholder_attr("TensorSerializer")
    open_stream = tensorizer.placeholder_attr("stream_io.open_stream")
    convert_bytes = tensorizer.placeholder_attr("utils.convert_bytes")
    get_mem_usage = tensorizer.placeholder_attr("utils.get_mem_usage")
    no_init_or_tensor = tensorizer.placeholder_attr("utils.no_init_or_tensor")
2026-01-09 13:34:11 +08:00
# Public API: re-export the tensorizer primitives (or their placeholders)
# alongside this module's own configuration class.
__all__ = [
    "EncryptionParams",
    "DecryptionParams",
    "TensorDeserializer",
    "TensorSerializer",
    "open_stream",
    "convert_bytes",
    "get_mem_usage",
    "no_init_or_tensor",
    "TensorizerConfig",
]

logger = init_logger(__name__)
2026-01-19 10:38:50 +08:00
def is_valid_deserialization_uri(uri: str | None) -> bool:
if uri:
scheme = uri.lower().split("://")[0]
return scheme in {"s3", "http", "https"} or os.path.exists(uri)
return False
def tensorizer_kwargs_arg(value):
    """argparse ``type=`` converter: parse a JSON string into a dict.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not valid JSON or does
            not deserialize to a dictionary.
    """
    try:
        loaded = json.loads(value)
    except json.JSONDecodeError as e:
        # Bug fix: previously a JSONDecodeError escaped here, so argparse
        # printed a traceback instead of a clean usage error. Converting to
        # ArgumentTypeError lets argparse report it properly.
        raise argparse.ArgumentTypeError(
            f"Not deserializable to dict: {value}. serialization_kwargs and "
            f"deserialization_kwargs must be "
            f"deserializable from a JSON string to a dictionary. "
        ) from e
    if not isinstance(loaded, dict):
        raise argparse.ArgumentTypeError(
            f"Not deserializable to dict: {value}. serialization_kwargs and "
            f"deserialization_kwargs must be "
            f"deserializable from a JSON string to a dictionary. "
        )
    return loaded
class MetaTensorMode(TorchDispatchMode):
    """Dispatch mode that retargets bare ``aten::empty`` calls to the meta
    device, so tensor allocations made while the mode is active carry no
    real storage. Calls that explicitly pass a device are left untouched."""

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        call_kwargs = dict(kwargs) if kwargs else {}
        # Only redirect allocations whose device was left unspecified.
        if func._schema.name == "aten::empty" and "device" not in call_kwargs:
            call_kwargs["device"] = "meta"
        return func(*args, **call_kwargs)
def meta_tensor_mode(
    loading_code=None,
):
    """Dual-purpose helper around :class:`_NoInitOrTensorImpl`.

    With no argument, returns a context manager that suppresses module
    weight initialization and allocates tensors on the meta device. Given
    a callable, evaluates it under that context and returns its result.
    """
    if loading_code is None:
        return _NoInitOrTensorImpl.context_manager()
    if not callable(loading_code):
        raise TypeError(
            "expected a callable to evaluate,"
            " or None if being used as a context manager;"
            f' got an object of type "{type(loading_code).__name__}" instead.'
        )
    with _NoInitOrTensorImpl.context_manager():
        return loading_code()
class _NoInitOrTensorImpl:
    # Implementation detail behind meta_tensor_mode(): while active,
    # reset_parameters() on common torch modules is disabled and unspecified
    # tensor allocations are redirected to the meta device (via
    # MetaTensorMode), so module construction allocates no real storage and
    # runs no weight initialization.

    # Modules whose reset_parameters() gets patched out while active.
    _MODULES = (torch.nn.Linear, torch.nn.Embedding, torch.nn.LayerNorm)
    # Snapshot of the original methods, for restoration on exit.
    _MODULE_ORIGINALS = tuple((m, m.reset_parameters) for m in _MODULES)

    # Per-context activation flag, consulted by the _disable() wrappers so
    # other threads/contexts keep normal initialization behavior.
    is_active = contextvars.ContextVar("_NoInitOrTensorImpl.is_active", default=False)
    # Global activation count across threads: the monkey-patch is installed
    # on the 0 -> 1 transition and removed on 1 -> 0.
    _count_active: int = 0
    _count_active_lock = threading.Lock()

    @classmethod
    @contextlib.contextmanager
    def context_manager(cls):
        """Activate no-init/meta-tensor behavior for the enclosed block.

        Re-entrant: entering while already active in this context is a
        no-op, so nested uses don't double-patch or double-count.
        """
        if cls.is_active.get():
            yield
            return

        with cls._count_active_lock:
            cls._count_active += 1
            if cls._count_active == 1:
                for mod in cls._MODULES:
                    mod.reset_parameters = cls._disable(mod.reset_parameters)
        reset_token = cls.is_active.set(True)
        try:
            with MetaTensorMode():
                yield
        finally:
            cls.is_active.reset(reset_token)
            with cls._count_active_lock:
                cls._count_active -= 1
                if cls._count_active == 0:
                    # Last active context anywhere: restore the originals.
                    for mod, original in cls._MODULE_ORIGINALS:
                        mod.reset_parameters = original

    @staticmethod
    def _disable(func):
        # Wrap reset_parameters so it becomes a no-op while this context is
        # active, but defers to the original implementation otherwise.
        def wrapper(*args, **kwargs):
            if not _NoInitOrTensorImpl.is_active.get():
                return func(*args, **kwargs)

        return wrapper
2026-01-09 13:34:11 +08:00
@dataclass
class TensorizerConfig(MutableMapping):
    """Configuration class for Tensorizer settings.

    These settings configure the behavior of model serialization and
    deserialization using Tensorizer.

    Attributes:
        tensorizer_uri: Path to serialized model tensors. Can be a local file
            path or a S3 URI. This is a required field unless lora_dir is
            provided and the config is meant to be used for the
            `tensorize_lora_adapter` function. Unless a `tensorizer_dir` or
            `lora_dir` is passed to this object's initializer, this is
            a required argument.
        tensorizer_dir: Path to a directory containing serialized model tensors,
            and all other potential model artifacts to load the model, such as
            configs and tokenizer files. Can be passed instead of
            `tensorizer_uri` where the `model.tensors` file will be assumed
            to be in this directory.
        vllm_tensorized: If True, indicates that the serialized model is a
            vLLM model. This is used to determine the behavior of the
            TensorDeserializer when loading tensors from a serialized model.
            It is far faster to deserialize a vLLM model as it utilizes
            tensorizer's optimized GPU loading. Note that this is now
            deprecated, as serialized vLLM models are now automatically
            inferred as vLLM models.
        verify_hash: If True, the hashes of each tensor will be verified
            against the hashes stored in the metadata. A `HashMismatchError`
            will be raised if any of the hashes do not match.
        num_readers: Controls how many threads are allowed to read concurrently
            from the source file. Default is `None`, which will dynamically set
            the number of readers based on the number of available
            resources and model size. This greatly increases performance.
        encryption_keyfile: File path to a binary file containing a
            binary key to use for decryption. `None` (the default) means
            no decryption. See the example script in
            examples/others/tensorize_vllm_model.py.
        s3_access_key_id: The access key for the S3 bucket. Can also be set via
            the S3_ACCESS_KEY_ID environment variable.
        s3_secret_access_key: The secret access key for the S3 bucket. Can also
            be set via the S3_SECRET_ACCESS_KEY environment variable.
        s3_endpoint: The endpoint for the S3 bucket. Can also be set via the
            S3_ENDPOINT_URL environment variable.
        lora_dir: Path to a directory containing LoRA adapter artifacts for
            serialization or deserialization. When serializing LoRA adapters
            this is the only necessary parameter to pass to this object's
            initializer.
    """

    tensorizer_uri: str | None = None
    tensorizer_dir: str | None = None
    vllm_tensorized: bool | None = None
    verify_hash: bool | None = None
    num_readers: int | None = None
    encryption_keyfile: str | None = None
    s3_access_key_id: str | None = None
    s3_secret_access_key: str | None = None
    s3_endpoint: str | None = None
    lora_dir: str | None = None
    stream_kwargs: dict[str, Any] | None = None
    serialization_kwargs: dict[str, Any] | None = None
    deserialization_kwargs: dict[str, Any] | None = None
    # Loader-populated state (init=False): set by the model loader after
    # construction, never passed by callers, and dropped by to_serializable().
    _extra_serialization_attrs: dict[str, Any] | None = field(init=False, default=None)
    model_class: type[torch.nn.Module] | None = field(init=False, default=None)
    hf_config: PretrainedConfig | None = field(init=False, default=None)
    dtype: str | torch.dtype | None = field(init=False, default=None)
    _is_sharded: bool = field(init=False, default=False)
    # Cached field-name tuple/frozenset backing the MutableMapping interface;
    # assigned at module import time, right after this class definition.
    _fields: ClassVar[tuple[str, ...]]
    _keys: ClassVar[frozenset[str]]

    def __post_init__(self):
        # check if the configuration is for a sharded vLLM model:
        # a '%0Nd' printf-style template in the URI marks per-rank shards.
        self._is_sharded = (
            isinstance(self.tensorizer_uri, str)
            and re.search(r"%0\dd", self.tensorizer_uri) is not None
        )
        if self.tensorizer_dir and self.lora_dir:
            raise ValueError(
                "Only one of tensorizer_dir or lora_dir may be specified. "
                "Use lora_dir exclusively when serializing LoRA adapters, "
                "and tensorizer_dir or tensorizer_uri otherwise."
            )
        if self.tensorizer_dir and self.tensorizer_uri:
            logger.warning_once(
                "Provided both tensorizer_dir and tensorizer_uri. "
                "Inferring tensorizer_dir from tensorizer_uri as the "
                "latter takes precedence."
            )
            self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
        # Resolve the URI from lora_dir or tensorizer_dir when not given
        # directly; otherwise derive the directory from the URI.
        if not self.tensorizer_uri:
            if self.lora_dir:
                self.tensorizer_uri = f"{self.lora_dir}/adapter_model.tensors"
            elif self.tensorizer_dir:
                self.tensorizer_uri = f"{self.tensorizer_dir}/model.tensors"
            else:
                raise ValueError(
                    "Unable to resolve tensorizer_uri. "
                    "A valid tensorizer_uri or tensorizer_dir "
                    "must be provided for deserialization, and a "
                    "valid tensorizer_uri, tensorizer_uri, or "
                    "lora_dir for serialization."
                )
        else:
            self.tensorizer_dir = os.path.dirname(self.tensorizer_uri)
        if not self.serialization_kwargs:
            self.serialization_kwargs = {}
        if not self.deserialization_kwargs:
            self.deserialization_kwargs = {}

    def to_serializable(self) -> dict[str, Any]:
        # Due to TensorizerConfig needing to be msgpack-serializable, it needs
        # support for morphing back and forth between itself and its dict
        # representation

        # TensorizerConfig's representation as a dictionary is meant to be
        # linked to TensorizerConfig in such a way that the following is
        # technically initializable:

        # TensorizerConfig(**my_tensorizer_cfg.to_serializable())

        # This means the dict must not retain non-initializable parameters
        # and post-init attribute states

        # Also don't want to retain private and unset parameters, so only retain
        # not None values and public attributes
        raw_tc_dict = asdict(self)
        blacklisted = []

        # NOTE(review): asdict() always includes every dataclass field, so
        # both membership checks below are always true and "tensorizer_dir"
        # is unconditionally blacklisted (twice). These look like they were
        # intended as truthiness checks on the *values* — confirm upstream.
        if "tensorizer_uri" in raw_tc_dict and "tensorizer_dir" in raw_tc_dict:
            blacklisted.append("tensorizer_dir")
        if "tensorizer_dir" in raw_tc_dict and "lora_dir" in raw_tc_dict:
            blacklisted.append("tensorizer_dir")

        tc_dict = {}
        for k, v in raw_tc_dict.items():
            if (
                k not in blacklisted
                and k not in tc_dict
                and not k.startswith("_")
                and v is not None
            ):
                tc_dict[k] = v

        return tc_dict

    def _construct_tensorizer_args(self) -> "TensorizerArgs":
        """Build the flattened runtime-argument view of this config."""
        return TensorizerArgs(self)  # type: ignore

    def verify_with_parallel_config(
        self,
        parallel_config: "ParallelConfig",
    ) -> None:
        # TP > 1 requires one artifact per rank, selected via a printf-style
        # template (e.g. '%04d') in tensorizer_uri.
        if parallel_config.tensor_parallel_size > 1 and not self._is_sharded:
            raise ValueError(
                "For a sharded model, tensorizer_uri should include a"
                " string format template like '%04d' to be formatted"
                " with the rank of the shard"
            )

    def verify_with_model_config(self, model_config: "ModelConfig") -> None:
        if model_config.quantization is not None and self.tensorizer_uri is not None:
            logger.warning(
                "Loading a model using Tensorizer with quantization on vLLM"
                " is unstable and may lead to errors."
            )

    def open_stream(self, tensorizer_args: Optional["TensorizerArgs"] = None):
        """Open a tensorizer stream for this config's tensorizer_uri.

        Note: this method intentionally shadows the module-level
        ``open_stream``, which it delegates to.
        """
        if tensorizer_args is None:
            tensorizer_args = self._construct_tensorizer_args()

        return open_stream(self.tensorizer_uri, **tensorizer_args.stream_kwargs)

    # --- MutableMapping interface -------------------------------------------
    # Lets the config be treated as a dict (iterated, updated, unpacked),
    # which the serializable round-trip above relies on.

    def keys(self):
        return self._keys

    def __len__(self):
        # Counts *all* dataclass fields, including the private/init=False ones.
        return len(fields(self))

    def __iter__(self):
        return iter(self._fields)

    def __getitem__(self, item: str) -> Any:
        if item not in self.keys():
            raise KeyError(item)
        return getattr(self, item)

    def __setitem__(self, key: str, value: Any) -> None:
        if key not in self.keys():
            # Disallow modifying invalid keys
            raise KeyError(key)
        setattr(self, key, value)

    def __delitem__(self, key, /):
        if key not in self.keys():
            raise KeyError(key)
        # Removes the instance attribute; lookup then falls back to the class
        # default where one exists.
        delattr(self, key)
2026-01-09 13:34:11 +08:00
2026-01-19 10:38:50 +08:00
# Cache the dataclass field names once so the MutableMapping methods
# (keys/__iter__/membership tests) don't re-invoke dataclasses.fields()
# on every lookup.
TensorizerConfig._fields = tuple(f.name for f in fields(TensorizerConfig))
TensorizerConfig._keys = frozenset(TensorizerConfig._fields)
2026-01-09 13:34:11 +08:00
@dataclass
class TensorizerArgs:
    """Flattened, runtime-ready view of a :class:`TensorizerConfig`.

    Resolves S3 credentials from the environment and expands the config into
    the ``stream_kwargs`` / ``deserialization_kwargs`` dicts that tensorizer's
    ``open_stream`` and ``TensorDeserializer`` expect.

    NOTE: although decorated with ``@dataclass``, this class defines its own
    ``__init__``; ``dataclass`` does not overwrite an existing ``__init__``,
    so the field declarations below exist only so ``dataclasses.fields()``
    (used by ``from_cli_args``) reports the CLI-settable attribute names.
    """

    tensorizer_uri: str | None = None
    tensorizer_dir: str | None = None
    encryption_keyfile: str | None = None

    def __init__(self, tensorizer_config: TensorizerConfig):
        # Copy every key/value exposed by the config's mapping interface
        # onto this object.
        for k, v in tensorizer_config.items():
            setattr(self, k, v)
        self.file_obj = tensorizer_config.tensorizer_uri
        # S3 credentials fall back to environment variables when unset.
        self.s3_access_key_id = (
            tensorizer_config.s3_access_key_id or envs.S3_ACCESS_KEY_ID
        )
        self.s3_secret_access_key = (
            tensorizer_config.s3_secret_access_key or envs.S3_SECRET_ACCESS_KEY
        )
        self.s3_endpoint = tensorizer_config.s3_endpoint or envs.S3_ENDPOINT_URL

        # Keyword arguments for tensorizer's open_stream(); user-supplied
        # stream_kwargs override the defaults assembled here.
        self.stream_kwargs = {
            "s3_access_key_id": tensorizer_config.s3_access_key_id,
            "s3_secret_access_key": tensorizer_config.s3_secret_access_key,
            "s3_endpoint": tensorizer_config.s3_endpoint,
            **(tensorizer_config.stream_kwargs or {}),
        }

        # Keyword arguments for TensorDeserializer(); user-supplied
        # deserialization_kwargs override the defaults assembled here.
        self.deserialization_kwargs = {
            "verify_hash": tensorizer_config.verify_hash,
            "encryption": tensorizer_config.encryption_keyfile,
            "num_readers": tensorizer_config.num_readers,
            **(tensorizer_config.deserialization_kwargs or {}),
        }

        if self.encryption_keyfile:
            # Read the raw key material (possibly from S3) and replace the
            # keyfile path with a DecryptionParams object, which is what
            # TensorDeserializer actually consumes.
            with open_stream(
                tensorizer_config.encryption_keyfile,
                **self.stream_kwargs,
            ) as stream:
                key = stream.read()
                decryption_params = DecryptionParams.from_key(key)
                self.deserialization_kwargs["encryption"] = decryption_params

    @staticmethod
    def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
        """Tensorizer CLI arguments"""

        # Tensorizer options arg group
        group = parser.add_argument_group(
            "tensorizer options",
            description=(
                "Options for configuring the behavior of the"
                " tensorizer deserializer when "
                "load_format=tensorizer is specified when "
                "initializing an LLMEngine, either via the CLI "
                "when running the vLLM OpenAI inference server "
                "with a JSON string passed to "
                "--model-loader-extra-config or as arguments given "
                "to TensorizerConfig when passed to "
                "model_loader_extra_config in the constructor "
                "for LLMEngine."
            ),
        )
        group.add_argument(
            "--tensorizer-uri",
            type=str,
            help="Path to serialized model tensors. Can be a local file path,"
            " or an HTTP(S) or S3 URI.",
        )
        group.add_argument(
            "--verify-hash",
            action="store_true",
            help="If enabled, the hashes of each tensor will be verified"
            " against the hashes stored in the file metadata. An exception"
            " will be raised if any of the hashes do not match.",
        )
        group.add_argument(
            "--encryption-keyfile",
            type=str,
            default=None,
            help="The file path to a binary file containing a binary key to "
            "use for decryption. Can be a file path or S3 network URI.",
        )
        group.add_argument(
            "--num-readers",
            default=None,
            type=int,
            help="Controls how many threads are allowed to read concurrently "
            "from the source file. Default is `None`, which will dynamically "
            "set the number of readers based on the available resources "
            "and model size. This greatly increases performance.",
        )
        group.add_argument(
            "--s3-access-key-id",
            type=str,
            default=None,
            help="The access key for the S3 bucket. Can also be set via the "
            "S3_ACCESS_KEY_ID environment variable.",
        )
        group.add_argument(
            "--s3-secret-access-key",
            type=str,
            default=None,
            help="The secret access key for the S3 bucket. Can also be set via "
            "the S3_SECRET_ACCESS_KEY environment variable.",
        )
        group.add_argument(
            "--s3-endpoint",
            type=str,
            default=None,
            help="The endpoint for the S3 bucket. Can also be set via the "
            "S3_ENDPOINT_URL environment variable.",
        )
        return parser

    @classmethod
    def from_cli_args(cls, args: argparse.Namespace) -> "TensorizerArgs":
        """Build TensorizerArgs from parsed CLI args, taking only the
        attributes declared as dataclass fields above."""
        attrs = [attr.name for attr in dataclasses.fields(cls)]
        tensorizer_args = cls(
            **{attr: getattr(args, attr) for attr in attrs if hasattr(args, attr)}
        )
        return tensorizer_args
2026-01-19 10:38:50 +08:00
def _check_tensors_on_meta_device(model: nn.Module) -> None:
for tensor in model.state_dict().values():
if tensor.device.type == "meta":
raise ValueError(
"The serialized model contains tensors on the meta device,"
" indicating that some tensors were not loaded properly."
" Please check that the parameters of the model being"
" specified match that of the serialized model, such as"
" its quantization."
)
def _resize_lora_embeddings(model: nn.Module):
    """Modify LoRA embedding layers to use bigger tensors
    to allow for adapter added tokens."""
    for module in model.modules():
        if not isinstance(module, VocabParallelEmbedding):
            continue
        rows = module.weight.shape[0]
        if rows >= module.num_embeddings_per_partition:
            continue
        # Allocate the full-size partition, copy the existing rows in, and
        # zero-fill the tail reserved for adapter-added tokens.
        padded = torch.empty(
            module.num_embeddings_per_partition,
            module.embedding_dim,
            dtype=module.weight.dtype,
            device=module.weight.device,
        )
        padded[:rows].copy_(module.weight.data)
        padded[rows:].fill_(0)
        module.weight.data = padded
def init_tensorizer_model(
    tensorizer_config: TensorizerConfig, vllm_config: VllmConfig
) -> nn.Module:
    """Instantiate the model's skeleton on the meta device (no real weight
    storage, no initialization) so tensorizer can stream weights into it.

    Requires that the loader has already populated tensorizer_config's
    hf_config, dtype, and model_class attributes.
    """
    assert tensorizer_config.hf_config is not None
    model_args = tensorizer_config.hf_config
    model_args.dtype = tensorizer_config.dtype
    assert tensorizer_config.model_class is not None
    # TODO: Do we need to consider old-style model class?
    with meta_tensor_mode(), set_current_vllm_config(vllm_config, check_compile=True):
        return tensorizer_config.model_class(vllm_config=vllm_config)
def deserialize_tensorizer_model(
    model: nn.Module, tensorizer_config: TensorizerConfig
) -> None:
    """Stream serialized weights from tensorizer_config.tensorizer_uri into
    *model* in place, logging throughput and memory usage.

    Raises:
        ValueError: if the URI is not a local file or s3/http(s) URI, or if
            any tensor is still on the meta device after loading.
    """
    tensorizer_args = tensorizer_config._construct_tensorizer_args()
    if not is_valid_deserialization_uri(tensorizer_config.tensorizer_uri):
        raise ValueError(
            f"{tensorizer_config.tensorizer_uri} is not a valid "
            f"tensorizer URI. Please check that the URI is correct. "
            f"It must either point to a local existing file, or have a "
            f"S3, HTTP or HTTPS scheme."
        )
    before_mem = get_mem_usage()
    start = time.perf_counter()
    with (
        open_stream(
            tensorizer_config.tensorizer_uri, mode="rb", **tensorizer_args.stream_kwargs
        ) as stream,
        TensorDeserializer(
            stream,
            dtype=tensorizer_config.dtype,
            # Load directly onto the current accelerator (XPU or CUDA).
            device=f"xpu:{torch.xpu.current_device()}"
            if current_platform.is_xpu()
            else f"cuda:{torch.cuda.current_device()}",
            **tensorizer_args.deserialization_kwargs,
        ) as deserializer,
    ):
        deserializer.load_into_module(model)
        end = time.perf_counter()

    total_bytes_str = convert_bytes(deserializer.total_tensor_bytes)
    duration = end - start
    per_second = convert_bytes(deserializer.total_tensor_bytes / duration)
    after_mem = get_mem_usage()
    deserializer.close()
    logger.info(
        "Deserialized %s in %0.2fs, %s/s", total_bytes_str, end - start, per_second
    )
    logger.info("Memory usage before: %s", before_mem)
    logger.info("Memory usage after: %s", after_mem)

    # Sanity checks: everything materialized, and LoRA vocab embeddings
    # padded out to their full partition size.
    _check_tensors_on_meta_device(model)
    _resize_lora_embeddings(model)
    # Remove the marker parameter added at serialization time (see
    # serialize_vllm_model) so it doesn't linger on the loaded model.
    del model.vllm_tensorized_marker
2026-01-09 13:34:11 +08:00
def tensorizer_weights_iterator(
    tensorizer_args: "TensorizerArgs",
) -> Generator[tuple[str, torch.Tensor], None, None]:
    """Yield ``(name, tensor)`` pairs from a (non-vLLM) tensorizer artifact,
    deserializing everything to CPU.

    Args:
        tensorizer_args: Runtime arguments carrying the URI plus stream and
            deserialization keyword arguments.
    """
    logger.warning(
        "Deserializing HuggingFace models is not optimized for "
        "loading on vLLM, as tensorizer is forced to load to CPU. "
        "Consider deserializing a vLLM model instead for faster "
        "load times. See the "
        "examples/others/tensorize_vllm_model.py example script "
        "for serializing vLLM models."
    )

    deserializer_args = tensorizer_args.deserialization_kwargs
    stream_kwargs = tensorizer_args.stream_kwargs
    stream = open_stream(tensorizer_args.tensorizer_uri, **stream_kwargs)
    # Bug fix: the stream was previously never closed. The try/finally
    # guarantees closure even if the consumer abandons the generator early
    # (generator finalization runs the finally block).
    try:
        with TensorDeserializer(stream, **deserializer_args, device="cpu") as state:
            yield from state.items()
            # Drop the reference promptly so deserialized tensors can be freed.
            del state
    finally:
        stream.close()
2026-01-19 10:38:50 +08:00
def is_vllm_tensorized(tensorizer_config: "TensorizerConfig") -> bool:
    """
    Infer if the model is a vLLM model by checking the weights for
    a vLLM tensorized marker.

    Args:
        tensorizer_config: The TensorizerConfig object containing the
            tensorizer_uri to the serialized model.

    Returns:
        bool: True if the model is a vLLM model, False otherwise.
    """
    # Fast path: the (deprecated) explicit flag short-circuits before any
    # stream is opened, so we no longer construct a deserializer needlessly.
    if tensorizer_config.vllm_tensorized:
        logger.warning(
            "Please note that newly serialized vLLM models are automatically "
            "inferred as vLLM models, so setting vllm_tensorized=True is "
            "only necessary for models serialized prior to this change."
        )
        return True

    tensorizer_args = tensorizer_config._construct_tensorizer_args()
    # Bug fix: the deserializer (and its underlying stream) was previously
    # never closed; contextlib.closing guarantees deterministic cleanup.
    with contextlib.closing(
        TensorDeserializer(
            open_stream(
                tensorizer_args.tensorizer_uri, **tensorizer_args.stream_kwargs
            ),
            **tensorizer_args.deserialization_kwargs,
            lazy_load=True,
        )
    ) as deserializer:
        return ".vllm_tensorized_marker" in deserializer
def serialize_extra_artifacts(
    tensorizer_args: TensorizerArgs, served_model_name: str | list[str] | None
) -> None:
    """Download the model's non-weight artifacts (configs, tokenizer files)
    from the HF Hub and copy them next to the serialized tensors in
    tensorizer_args.tensorizer_dir.

    Raises:
        ValueError: if served_model_name is not a single string (it must
            name one HF repo to snapshot).
    """
    if not isinstance(served_model_name, str):
        raise ValueError(
            f"served_model_name must be a str for serialize_extra_artifacts, "
            f"not {type(served_model_name)}."
        )

    with tempfile.TemporaryDirectory() as tmpdir:
        # Skip weight files and repo metadata -- only the auxiliary
        # artifacts (configs, tokenizer, etc.) are wanted here.
        snapshot_download(
            served_model_name,
            local_dir=tmpdir,
            ignore_patterns=[
                "*.pt",
                "*.safetensors",
                "*.bin",
                "*.cache",
                "*.gitattributes",
                "*.md",
            ],
        )
        for artifact in os.scandir(tmpdir):
            if not artifact.is_file():
                continue
            # Copy each artifact through open_stream so the destination may
            # be local or an S3 URI alike.
            with (
                open(artifact.path, "rb") as f,
                open_stream(
                    f"{tensorizer_args.tensorizer_dir}/{artifact.name}",
                    mode="wb+",
                    **tensorizer_args.stream_kwargs,
                ) as stream,
            ):
                logger.info("Writing artifact %s", artifact.name)
                stream.write(f.read())
def serialize_vllm_model(
    model: nn.Module,
    tensorizer_config: TensorizerConfig,
    model_config: "ModelConfig",
) -> nn.Module:
    """Serialize *model* with Tensorizer to tensorizer_config's URI,
    optionally encrypting, and copy auxiliary HF artifacts alongside.

    Returns the (marker-augmented) model unchanged otherwise.
    """
    # The meta-device marker parameter lets is_vllm_tensorized() recognize
    # this artifact as a vLLM-serialized model later.
    model.register_parameter(
        "vllm_tensorized_marker",
        nn.Parameter(torch.tensor((1,), device="meta"), requires_grad=False),
    )
    tensorizer_args = tensorizer_config._construct_tensorizer_args()

    encryption_params = None
    if (keyfile := tensorizer_config.encryption_keyfile) is not None:
        with open(keyfile, "rb") as f:
            key = f.read()
            encryption_params = EncryptionParams(key=key)

    output_file = tensorizer_args.tensorizer_uri
    if tensorizer_config._is_sharded:
        from vllm.distributed import get_tensor_model_parallel_rank

        # Fill the '%0Nd' URI template with this shard's TP rank.
        output_file = output_file % get_tensor_model_parallel_rank()

    with open_stream(
        output_file, mode="wb+", **tensorizer_args.stream_kwargs
    ) as stream:
        serializer = TensorSerializer(
            stream,
            encryption=encryption_params,
            **tensorizer_config.serialization_kwargs,
        )
        serializer.write_module(model)
        serializer.close()

    serialize_extra_artifacts(tensorizer_args, model_config.served_model_name)

    logger.info("Successfully serialized model to %s", str(output_file))
    return model
def tensorize_vllm_model(
    engine_args: "EngineArgs",
    tensorizer_config: TensorizerConfig,
    generate_keyfile: bool = True,
):
    """Utility to load a model and then serialize it with Tensorizer

    Intended to be used separately from running a vLLM server since it
    creates its own Engine instance.
    """
    engine_config = engine_args.create_engine_config()
    tensorizer_config.verify_with_model_config(engine_config.model_config)
    tensorizer_config.verify_with_parallel_config(engine_config.parallel_config)

    # generate the encryption key before creating the engine to support
    # sharding: every rank must encrypt with the same key.
    if (
        generate_keyfile
        and (keyfile := tensorizer_config.encryption_keyfile) is not None
    ):
        encryption_params = EncryptionParams.random()
        with open_stream(
            keyfile,
            mode="wb+",
            s3_access_key_id=tensorizer_config.s3_access_key_id,
            s3_secret_access_key=tensorizer_config.s3_secret_access_key,
            s3_endpoint=tensorizer_config.s3_endpoint,
        ) as stream:
            stream.write(encryption_params.key)

    # Import here to avoid a circular import at module load time.
    from vllm.v1.engine.llm_engine import LLMEngine

    engine = LLMEngine.from_vllm_config(engine_config)
    # Each worker serializes its own (possibly sharded) weights; the config
    # is sent as a plain dict since it must be msgpack-serializable.
    engine.collective_rpc(
        "save_tensorized_model",
        kwargs={"tensorizer_config": tensorizer_config.to_serializable()},
    )
def tensorize_lora_adapter(lora_path: str, tensorizer_config: TensorizerConfig):
    """
    Uses tensorizer to serialize a LoRA adapter. Assumes that the files
    needed to load a LoRA adapter are a safetensors-format file called
    adapter_model.safetensors and a json config file called adapter_config.json.

    Serializes the files in the tensorizer_config.tensorizer_dir
    """
    import safetensors

    from vllm.lora.utils import get_adapter_absolute_path

    lora_dir = get_adapter_absolute_path(lora_path)

    # Locate the weight and config files regardless of extension
    # (adapter_model.safetensors / adapter_model.bin, adapter_config.json).
    tensor_path = config_path = ""
    for file in os.listdir(lora_dir):
        if file.startswith("adapter_model"):
            tensor_path = lora_dir + "/" + file
        if file.startswith("adapter_config"):
            config_path = lora_dir + "/" + file
        if tensor_path and config_path:
            break

    if tensor_path.endswith(".safetensors"):
        tensors = safetensors.torch.load_file(tensor_path)
    elif tensor_path.endswith(".bin"):
        # NOTE(review): torch.load unpickles arbitrary objects -- only use
        # with trusted adapter checkpoints.
        tensors = torch.load(tensor_path)
    else:
        # Bug fix: the message was previously passed logging-style as
        # ("Unsupported file: %s", tensor_path), so the path was never
        # interpolated into the exception text.
        raise ValueError(f"Unsupported file: {tensor_path}")

    with open(config_path) as f:
        config = json.load(f)

    tensorizer_args = tensorizer_config._construct_tensorizer_args()

    # Write the adapter config next to the serialized tensors so the adapter
    # can later be loaded from tensorizer_dir alone.
    with open_stream(
        f"{tensorizer_config.tensorizer_dir}/adapter_config.json",
        mode="wb+",
        **tensorizer_args.stream_kwargs,
    ) as f:
        f.write(json.dumps(config).encode("utf-8"))

    lora_uri = f"{tensorizer_config.tensorizer_dir}/adapter_model.tensors"
    with open_stream(lora_uri, mode="wb+", **tensorizer_args.stream_kwargs) as f:
        serializer = TensorSerializer(f)
        serializer.write_state_dict(tensors)
        serializer.close()

    logger.info(
        "Successfully serialized LoRA files to %s",
        str(tensorizer_config.tensorizer_dir),
    )