Sync from v0.13
This commit is contained in:
514
vllm/v1/serial_utils.py
Normal file
514
vllm/v1/serial_utils.py
Normal file
@@ -0,0 +1,514 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import dataclasses
|
||||
import importlib
|
||||
import pickle
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
from inspect import isclass
|
||||
from types import FunctionType
|
||||
from typing import Any, TypeAlias, get_type_hints
|
||||
|
||||
import cloudpickle
|
||||
import msgspec
|
||||
import numpy as np
|
||||
import torch
|
||||
import zmq
|
||||
from msgspec import msgpack
|
||||
from pydantic import GetCoreSchemaHandler
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal.inputs import (
|
||||
BaseMultiModalField,
|
||||
MultiModalBatchedField,
|
||||
MultiModalFieldConfig,
|
||||
MultiModalFieldElem,
|
||||
MultiModalFlatField,
|
||||
MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
MultiModalSharedField,
|
||||
NestedTensors,
|
||||
)
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.utils import tensor_data
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
CUSTOM_TYPE_PICKLE = 1
|
||||
CUSTOM_TYPE_CLOUDPICKLE = 2
|
||||
CUSTOM_TYPE_RAW_VIEW = 3
|
||||
|
||||
# MultiModalField class serialization type map.
|
||||
# These need to list all possible field types and match them
|
||||
# to factory methods in `MultiModalFieldConfig`.
|
||||
MMF_CLASS_TO_FACTORY: dict[type[BaseMultiModalField], str] = {
|
||||
MultiModalFlatField: "flat",
|
||||
MultiModalSharedField: "shared",
|
||||
MultiModalBatchedField: "batched",
|
||||
}
|
||||
|
||||
bytestr: TypeAlias = bytes | bytearray | memoryview | zmq.Frame
|
||||
|
||||
|
||||
def _log_insecure_serialization_warning():
|
||||
logger.warning_once(
|
||||
"Allowing insecure serialization using pickle due to "
|
||||
"VLLM_ALLOW_INSECURE_SERIALIZATION=1"
|
||||
)
|
||||
|
||||
|
||||
def _typestr(val: Any) -> tuple[str, str] | None:
|
||||
if val is None:
|
||||
return None
|
||||
t = type(val)
|
||||
return t.__module__, t.__qualname__
|
||||
|
||||
|
||||
def _encode_type_info_recursive(obj: Any) -> Any:
|
||||
"""Recursively encode type information for nested structures of
|
||||
lists/dicts."""
|
||||
if obj is None:
|
||||
return None
|
||||
if type(obj) is list:
|
||||
return [_encode_type_info_recursive(item) for item in obj]
|
||||
if type(obj) is dict:
|
||||
return {k: _encode_type_info_recursive(v) for k, v in obj.items()}
|
||||
return _typestr(obj)
|
||||
|
||||
|
||||
def _decode_type_info_recursive(
|
||||
type_info: Any, data: Any, convert_fn: Callable[[Sequence[str], Any], Any]
|
||||
) -> Any:
|
||||
"""Recursively decode type information for nested structures of
|
||||
lists/dicts."""
|
||||
if type_info is None:
|
||||
return data
|
||||
if isinstance(type_info, dict):
|
||||
assert isinstance(data, dict)
|
||||
return {
|
||||
k: _decode_type_info_recursive(type_info[k], data[k], convert_fn)
|
||||
for k in type_info
|
||||
}
|
||||
if isinstance(type_info, list) and (
|
||||
# Exclude serialized tensors/numpy arrays.
|
||||
len(type_info) != 2 or not isinstance(type_info[0], str)
|
||||
):
|
||||
assert isinstance(data, list)
|
||||
return [
|
||||
_decode_type_info_recursive(ti, d, convert_fn)
|
||||
for ti, d in zip(type_info, data)
|
||||
]
|
||||
return convert_fn(type_info, data)
|
||||
|
||||
|
||||
class UtilityResult:
|
||||
"""Wrapper for special handling when serializing/deserializing."""
|
||||
|
||||
def __init__(self, r: Any = None):
|
||||
self.result = r
|
||||
|
||||
|
||||
class MsgpackEncoder:
|
||||
"""Encoder with custom torch tensor and numpy array serialization.
|
||||
|
||||
Note that unlike vanilla `msgspec` Encoders, this interface is generally
|
||||
not thread-safe when encoding tensors / numpy arrays.
|
||||
|
||||
By default, arrays below 256B are serialized inline Larger will get sent
|
||||
via dedicated messages. Note that this is a per-tensor limit.
|
||||
"""
|
||||
|
||||
def __init__(self, size_threshold: int | None = None):
|
||||
if size_threshold is None:
|
||||
size_threshold = envs.VLLM_MSGPACK_ZERO_COPY_THRESHOLD
|
||||
self.encoder = msgpack.Encoder(enc_hook=self.enc_hook)
|
||||
# This is used as a local stash of buffers that we can then access from
|
||||
# our custom `msgspec` hook, `enc_hook`. We don't have a way to
|
||||
# pass custom data to the hook otherwise.
|
||||
self.aux_buffers: list[bytestr] | None = None
|
||||
self.size_threshold = size_threshold
|
||||
if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
_log_insecure_serialization_warning()
|
||||
|
||||
def encode(self, obj: Any) -> Sequence[bytestr]:
|
||||
try:
|
||||
self.aux_buffers = bufs = [b""]
|
||||
bufs[0] = self.encoder.encode(obj)
|
||||
# This `bufs` list allows us to collect direct pointers to backing
|
||||
# buffers of tensors and np arrays, and return them along with the
|
||||
# top-level encoded buffer instead of copying their data into the
|
||||
# new buffer.
|
||||
return bufs
|
||||
finally:
|
||||
self.aux_buffers = None
|
||||
|
||||
def encode_into(self, obj: Any, buf: bytearray) -> Sequence[bytestr]:
|
||||
try:
|
||||
self.aux_buffers = [buf]
|
||||
bufs = self.aux_buffers
|
||||
self.encoder.encode_into(obj, buf)
|
||||
return bufs
|
||||
finally:
|
||||
self.aux_buffers = None
|
||||
|
||||
def enc_hook(self, obj: Any) -> Any:
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return self._encode_tensor(obj)
|
||||
|
||||
# Fall back to pickle for object or void kind ndarrays.
|
||||
if isinstance(obj, np.ndarray) and obj.dtype.kind not in ("O", "V"):
|
||||
return self._encode_ndarray(obj)
|
||||
|
||||
if isinstance(obj, slice):
|
||||
# We are assuming only int-based values will be used here.
|
||||
return tuple(
|
||||
int(v) if v is not None else None
|
||||
for v in (obj.start, obj.stop, obj.step)
|
||||
)
|
||||
|
||||
if isinstance(obj, MultiModalKwargsItem):
|
||||
return self._encode_mm_item(obj)
|
||||
|
||||
if isinstance(obj, MultiModalKwargsItems):
|
||||
return self._encode_mm_items(obj)
|
||||
|
||||
if isinstance(obj, UtilityResult):
|
||||
result = obj.result
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
return None, result
|
||||
# Since utility results are not strongly typed, we recursively
|
||||
# encode type information for nested structures of lists/dicts
|
||||
# to help with correct msgspec deserialization.
|
||||
return _encode_type_info_recursive(result), result
|
||||
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
raise TypeError(
|
||||
f"Object of type {type(obj)} is not serializable"
|
||||
"Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 to allow "
|
||||
"fallback to pickle-based serialization."
|
||||
)
|
||||
|
||||
if isinstance(obj, FunctionType):
|
||||
# `pickle` is generally faster than cloudpickle, but can have
|
||||
# problems serializing methods.
|
||||
return msgpack.Ext(CUSTOM_TYPE_CLOUDPICKLE, cloudpickle.dumps(obj))
|
||||
|
||||
return msgpack.Ext(
|
||||
CUSTOM_TYPE_PICKLE, pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
)
|
||||
|
||||
def _encode_ndarray(
|
||||
self, obj: np.ndarray
|
||||
) -> tuple[str, tuple[int, ...], int | memoryview]:
|
||||
assert self.aux_buffers is not None
|
||||
# If the array is non-contiguous, we need to copy it first
|
||||
arr_data = obj.data if obj.flags.c_contiguous else obj.tobytes()
|
||||
if not obj.shape or obj.nbytes < self.size_threshold:
|
||||
# Encode small arrays and scalars inline. Using this extension type
|
||||
# ensures we can avoid copying when decoding.
|
||||
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
|
||||
else:
|
||||
# Otherwise encode index of backing buffer to avoid copy.
|
||||
data = len(self.aux_buffers)
|
||||
self.aux_buffers.append(arr_data)
|
||||
|
||||
# We serialize the ndarray as a tuple of native types.
|
||||
# The data is either inlined if small, or an index into a list of
|
||||
# backing buffers that we've stashed in `aux_buffers`.
|
||||
return obj.dtype.str, obj.shape, data
|
||||
|
||||
def _encode_tensor(
|
||||
self, obj: torch.Tensor
|
||||
) -> tuple[str, tuple[int, ...], int | memoryview]:
|
||||
assert self.aux_buffers is not None
|
||||
# view the tensor as a contiguous 1D array of bytes
|
||||
arr_data = tensor_data(obj)
|
||||
if obj.nbytes < self.size_threshold:
|
||||
# Smaller tensors are encoded inline, just like ndarrays.
|
||||
data = msgpack.Ext(CUSTOM_TYPE_RAW_VIEW, arr_data)
|
||||
else:
|
||||
# Otherwise encode index of backing buffer to avoid copy.
|
||||
data = len(self.aux_buffers)
|
||||
self.aux_buffers.append(arr_data)
|
||||
dtype = str(obj.dtype).removeprefix("torch.")
|
||||
return dtype, obj.shape, data
|
||||
|
||||
def _encode_mm_items(self, items: MultiModalKwargsItems) -> dict[str, Any]:
|
||||
return {
|
||||
modality: [self._encode_mm_item(item) for item in itemlist]
|
||||
for modality, itemlist in items.items()
|
||||
}
|
||||
|
||||
def _encode_mm_item(self, item: MultiModalKwargsItem) -> list[dict[str, Any]]:
|
||||
return [self._encode_mm_field_elem(elem) for elem in item.values()]
|
||||
|
||||
def _encode_mm_field_elem(self, elem: MultiModalFieldElem) -> dict[str, Any]:
|
||||
return {
|
||||
"modality": elem.modality,
|
||||
"key": elem.key,
|
||||
"data": (
|
||||
None if elem.data is None else self._encode_nested_tensors(elem.data)
|
||||
),
|
||||
"field": self._encode_mm_field(elem.field),
|
||||
}
|
||||
|
||||
def _encode_nested_tensors(self, nt: NestedTensors) -> Any:
|
||||
if isinstance(nt, torch.Tensor):
|
||||
return self._encode_tensor(nt)
|
||||
if isinstance(nt, (int, float)):
|
||||
# Although it violates NestedTensors type, MultiModalKwargs
|
||||
# values are sometimes floats.
|
||||
return nt
|
||||
return [self._encode_nested_tensors(x) for x in nt]
|
||||
|
||||
def _encode_mm_field(self, field: BaseMultiModalField):
|
||||
# Figure out the factory name for the field type.
|
||||
name = MMF_CLASS_TO_FACTORY.get(field.__class__)
|
||||
if not name:
|
||||
raise TypeError(f"Unsupported field type: {field.__class__}")
|
||||
|
||||
# We just need to copy all of the field values in order
|
||||
# which will be then used to reconstruct the field.
|
||||
factory_kw = {f.name: getattr(field, f.name) for f in dataclasses.fields(field)}
|
||||
return name, factory_kw
|
||||
|
||||
|
||||
class MsgpackDecoder:
|
||||
"""Decoder with custom torch tensor and numpy array serialization.
|
||||
|
||||
Note that unlike vanilla `msgspec` Decoders, this interface is generally
|
||||
not thread-safe when encoding tensors / numpy arrays.
|
||||
"""
|
||||
|
||||
def __init__(self, t: Any | None = None, share_mem: bool = True):
|
||||
self.share_mem = share_mem
|
||||
self.pin_tensors = is_pin_memory_available()
|
||||
args = () if t is None else (t,)
|
||||
self.decoder = msgpack.Decoder(
|
||||
*args, ext_hook=self.ext_hook, dec_hook=self.dec_hook
|
||||
)
|
||||
self.aux_buffers: Sequence[bytestr] = ()
|
||||
if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
_log_insecure_serialization_warning()
|
||||
|
||||
def decode(self, bufs: bytestr | Sequence[bytestr]) -> Any:
|
||||
if isinstance(bufs, bytestr): # type: ignore
|
||||
return self.decoder.decode(bufs)
|
||||
|
||||
self.aux_buffers = bufs
|
||||
try:
|
||||
return self.decoder.decode(bufs[0])
|
||||
finally:
|
||||
self.aux_buffers = ()
|
||||
|
||||
def dec_hook(self, t: type, obj: Any) -> Any:
|
||||
# Given native types in `obj`, convert to type `t`.
|
||||
if isclass(t):
|
||||
if issubclass(t, np.ndarray):
|
||||
return self._decode_ndarray(obj)
|
||||
if issubclass(t, torch.Tensor):
|
||||
return self._decode_tensor(obj)
|
||||
if t is slice:
|
||||
return slice(*obj)
|
||||
if issubclass(t, MultiModalKwargsItem):
|
||||
return self._decode_mm_item(obj)
|
||||
if issubclass(t, MultiModalKwargsItems):
|
||||
return self._decode_mm_items(obj)
|
||||
if t is UtilityResult:
|
||||
return self._decode_utility_result(obj)
|
||||
return obj
|
||||
|
||||
def _decode_utility_result(self, obj: Any) -> UtilityResult:
|
||||
result_type, result = obj
|
||||
if result_type is not None:
|
||||
if not envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
raise TypeError(
|
||||
"VLLM_ALLOW_INSECURE_SERIALIZATION must "
|
||||
"be set to use custom utility result types"
|
||||
)
|
||||
# Use recursive decoding to handle nested structures
|
||||
result = _decode_type_info_recursive(
|
||||
result_type, result, self._convert_result
|
||||
)
|
||||
return UtilityResult(result)
|
||||
|
||||
def _convert_result(self, result_type: Sequence[str], result: Any) -> Any:
|
||||
if result_type is None:
|
||||
return result
|
||||
mod_name, name = result_type
|
||||
mod = importlib.import_module(mod_name)
|
||||
result_type = getattr(mod, name)
|
||||
return msgspec.convert(result, result_type, dec_hook=self.dec_hook)
|
||||
|
||||
def _decode_ndarray(self, arr: Any) -> np.ndarray:
|
||||
dtype, shape, data = arr
|
||||
# zero-copy decode. We assume the ndarray will not be kept around,
|
||||
# as it now locks the whole received message buffer in memory.
|
||||
buffer = self.aux_buffers[data] if isinstance(data, int) else data
|
||||
arr = np.frombuffer(buffer, dtype=dtype)
|
||||
if not self.share_mem:
|
||||
arr = arr.copy()
|
||||
return arr.reshape(shape)
|
||||
|
||||
def _decode_tensor(self, arr: Any) -> torch.Tensor:
|
||||
dtype, shape, data = arr
|
||||
is_aux = isinstance(data, int)
|
||||
buffer = self.aux_buffers[data] if is_aux else data
|
||||
buffer = buffer if isinstance(buffer, memoryview) else memoryview(buffer)
|
||||
torch_dtype = getattr(torch, dtype)
|
||||
assert isinstance(torch_dtype, torch.dtype)
|
||||
if not buffer.nbytes: # torch.frombuffer doesn't like empty buffers
|
||||
assert 0 in shape
|
||||
return torch.empty(shape, dtype=torch_dtype)
|
||||
# Create uint8 array
|
||||
arr = torch.frombuffer(buffer, dtype=torch.uint8)
|
||||
# Clone ensures tensor is backed by pytorch-owned memory for safe
|
||||
# future async CPU->GPU transfer.
|
||||
# Pin larger tensors for more efficient CPU->GPU transfer.
|
||||
if not is_aux:
|
||||
arr = arr.clone()
|
||||
elif not self.share_mem:
|
||||
arr = arr.pin_memory() if self.pin_tensors else arr.clone()
|
||||
# Convert back to proper shape & type
|
||||
return arr.view(torch_dtype).view(shape)
|
||||
|
||||
def _decode_mm_items(self, obj: dict[str, Any]) -> MultiModalKwargsItems:
|
||||
return MultiModalKwargsItems(
|
||||
{
|
||||
modality: [self._decode_mm_item(item) for item in itemlist]
|
||||
for modality, itemlist in obj.items()
|
||||
}
|
||||
)
|
||||
|
||||
def _decode_mm_item(self, obj: list[Any]) -> MultiModalKwargsItem:
|
||||
return MultiModalKwargsItem.from_elems(
|
||||
[self._decode_mm_field_elem(v) for v in obj]
|
||||
)
|
||||
|
||||
def _decode_mm_field_elem(self, obj: dict[str, Any]) -> MultiModalFieldElem:
|
||||
if obj["data"] is not None:
|
||||
obj["data"] = self._decode_nested_tensors(obj["data"])
|
||||
|
||||
# Reconstruct the field processor using MultiModalFieldConfig
|
||||
factory_meth_name, factory_kw = obj["field"]
|
||||
factory_meth = getattr(MultiModalFieldConfig, factory_meth_name)
|
||||
|
||||
# Special case: decode the union "slices" field of
|
||||
# MultiModalFlatField
|
||||
if factory_meth_name == "flat":
|
||||
factory_kw["slices"] = self._decode_nested_slices(factory_kw["slices"])
|
||||
|
||||
obj["field"] = factory_meth("", **factory_kw).field
|
||||
return MultiModalFieldElem(**obj)
|
||||
|
||||
def _decode_nested_tensors(self, obj: Any) -> NestedTensors:
|
||||
if isinstance(obj, (int, float)):
|
||||
# Although it violates NestedTensors type, MultiModalKwargs
|
||||
# values are sometimes floats.
|
||||
return obj
|
||||
if not isinstance(obj, list):
|
||||
raise TypeError(f"Unexpected NestedTensors contents: {type(obj)}")
|
||||
if obj and isinstance(obj[0], str):
|
||||
return self._decode_tensor(obj)
|
||||
return [self._decode_nested_tensors(x) for x in obj]
|
||||
|
||||
def _decode_nested_slices(self, obj: Any) -> Any:
|
||||
assert isinstance(obj, (list, tuple))
|
||||
if obj and not isinstance(obj[0], (list, tuple)):
|
||||
return slice(*obj)
|
||||
return [self._decode_nested_slices(x) for x in obj]
|
||||
|
||||
def ext_hook(self, code: int, data: memoryview) -> Any:
|
||||
if code == CUSTOM_TYPE_RAW_VIEW:
|
||||
return data
|
||||
|
||||
if envs.VLLM_ALLOW_INSECURE_SERIALIZATION:
|
||||
if code == CUSTOM_TYPE_PICKLE:
|
||||
return pickle.loads(data)
|
||||
if code == CUSTOM_TYPE_CLOUDPICKLE:
|
||||
return cloudpickle.loads(data)
|
||||
|
||||
raise NotImplementedError(f"Extension type code {code} is not supported")
|
||||
|
||||
|
||||
def run_method(
|
||||
obj: Any,
|
||||
method: str | bytes | Callable,
|
||||
args: tuple[Any, ...],
|
||||
kwargs: dict[str, Any],
|
||||
) -> Any:
|
||||
"""
|
||||
Run a method of an object with the given arguments and keyword arguments.
|
||||
If the method is string, it will be converted to a method using getattr.
|
||||
If the method is serialized bytes and will be deserialized using
|
||||
cloudpickle.
|
||||
If the method is a callable, it will be called directly.
|
||||
"""
|
||||
if isinstance(method, bytes):
|
||||
func = partial(cloudpickle.loads(method), obj)
|
||||
elif isinstance(method, str):
|
||||
try:
|
||||
func = getattr(obj, method)
|
||||
except AttributeError:
|
||||
raise NotImplementedError(
|
||||
f"Method {method!r} is not implemented."
|
||||
) from None
|
||||
else:
|
||||
func = partial(method, obj) # type: ignore
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
class PydanticMsgspecMixin:
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source_type: Any, handler: GetCoreSchemaHandler
|
||||
) -> core_schema.CoreSchema:
|
||||
"""
|
||||
Make msgspec.Struct compatible with Pydantic, respecting defaults.
|
||||
Handle JSON=>msgspec.Struct. Used when exposing msgspec.Struct to the
|
||||
API as input or in `/docs`. Note this is cached by Pydantic and not
|
||||
called on every validation.
|
||||
"""
|
||||
msgspec_fields = {f.name: f for f in msgspec.structs.fields(source_type)}
|
||||
type_hints = get_type_hints(source_type)
|
||||
|
||||
# Build the Pydantic typed_dict_field for each msgspec field
|
||||
fields = {}
|
||||
for name, hint in type_hints.items():
|
||||
msgspec_field = msgspec_fields[name]
|
||||
|
||||
# typed_dict_field using the handler to get the schema
|
||||
field_schema = handler(hint)
|
||||
|
||||
# Add default value to the schema.
|
||||
if msgspec_field.default_factory is not msgspec.NODEFAULT:
|
||||
wrapped_schema = core_schema.with_default_schema(
|
||||
schema=field_schema,
|
||||
default_factory=msgspec_field.default_factory,
|
||||
)
|
||||
fields[name] = core_schema.typed_dict_field(wrapped_schema)
|
||||
elif msgspec_field.default is not msgspec.NODEFAULT:
|
||||
wrapped_schema = core_schema.with_default_schema(
|
||||
schema=field_schema,
|
||||
default=msgspec_field.default,
|
||||
)
|
||||
fields[name] = core_schema.typed_dict_field(wrapped_schema)
|
||||
else:
|
||||
# No default, so Pydantic will treat it as required
|
||||
fields[name] = core_schema.typed_dict_field(field_schema)
|
||||
return core_schema.no_info_after_validator_function(
|
||||
cls._validate_msgspec,
|
||||
core_schema.typed_dict_schema(fields),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _validate_msgspec(cls, value: Any) -> Any:
|
||||
"""Validate and convert input to msgspec.Struct instance."""
|
||||
if isinstance(value, cls):
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
return cls(**value)
|
||||
return msgspec.convert(value, type=cls)
|
||||
Reference in New Issue
Block a user