update
This commit is contained in:
98
vllm/utils/serial_utils.py
Normal file
98
vllm/utils/serial_utils.py
Normal file
@@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import io
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Literal, get_args

import numpy as np
import numpy.typing as npt
import pybase64
import torch

# Host byte order ("little" or "big"), cached once at import time; used to
# decide whether an explicitly requested endianness requires a byte swap.
sys_byteorder = sys.byteorder
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class DTypeInfo:
    """Groups the dtype handles needed to (de)serialize one embedding dtype.

    numpy has no bfloat16 or fp8 dtypes, so tensors are reinterpreted
    through a byte-compatible "view" dtype before crossing into numpy.
    """

    # The real payload dtype of the tensor.
    torch_dtype: torch.dtype

    # Byte-compatible torch dtype used with ``Tensor.view`` so the data can
    # be handed to numpy (e.g. bfloat16 -> float16, fp8 -> uint8).
    torch_view_dtype: torch.dtype
    # numpy dtype matching ``torch_view_dtype`` item-for-item.
    numpy_view_dtype: npt.DTypeLike

    @property
    def nbytes(self) -> int:
        """Size in bytes of a single element of ``torch_dtype``."""
        return self.torch_dtype.itemsize
|
||||
|
||||
|
||||
# Wire-format dtype names accepted by the serialization helpers below.
EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
# Byte order of the serialized payload; "native" means the host's order.
Endianness = Literal["native", "big", "little"]
# Supported response encodings for embedding payloads.
EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"]

# NOTE: the fp8 entries only rely on fp8 as a storage representation on the
# CPU — no fp8 arithmetic is performed — so they should work even on
# platforms without hardware fp8 support.
# NOTE: numpy does not support bfloat16 or fp8, so those entries round-trip
# through a byte-compatible view dtype (float16 / uint8 respectively).
EMBED_DTYPES: Mapping[EmbedDType, DTypeInfo] = {
    "float32": DTypeInfo(torch.float32, torch.float32, np.float32),
    "float16": DTypeInfo(torch.float16, torch.float16, np.float16),
    "bfloat16": DTypeInfo(torch.bfloat16, torch.float16, np.float16),
    "fp8_e4m3": DTypeInfo(torch.float8_e4m3fn, torch.uint8, np.uint8),
    "fp8_e5m2": DTypeInfo(torch.float8_e5m2, torch.uint8, np.uint8),
}
# All valid ``Endianness`` values, extracted from the Literal alias.
ENDIANNESS: tuple[Endianness, ...] = get_args(Endianness)
|
||||
|
||||
|
||||
def tensor2base64(x: torch.Tensor) -> str:
    """Serialize *x* with ``torch.save`` and return the base64-encoded bytes.

    The result is an ASCII string suitable for embedding in JSON responses.
    """
    buffer = io.BytesIO()
    torch.save(x, buffer)
    # getvalue() returns the full buffer contents regardless of position,
    # so no explicit seek is needed.
    raw = buffer.getvalue()
    buffer.close()
    return pybase64.b64encode(raw).decode("utf-8")
|
||||
|
||||
|
||||
def tensor2binary(
    tensor: torch.Tensor,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> bytes:
    """Convert *tensor* to raw bytes in the requested dtype and byte order.

    The tensor is cast to ``embed_dtype``, flattened, reinterpreted through a
    numpy-compatible view dtype (bfloat16/fp8 have no numpy equivalent), and
    byte-swapped when *endianness* names an explicit order that differs from
    the host's.
    """
    assert isinstance(tensor, torch.Tensor)
    assert embed_dtype in EMBED_DTYPES
    assert endianness in ENDIANNESS

    info = EMBED_DTYPES[embed_dtype]

    # Cast and flatten first; contiguous() guarantees a linear memory layout
    # before the dtype reinterpretation.
    converted = tensor.to(info.torch_dtype).flatten().contiguous()
    np_array = converted.view(info.torch_view_dtype).numpy()

    # "native" and the host's own order need no swap; anything else does.
    if endianness not in ("native", sys_byteorder):
        np_array = np_array.byteswap()

    return np_array.tobytes()
|
||||
|
||||
|
||||
def binary2tensor(
    binary: bytes,
    shape: tuple[int, ...],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> torch.Tensor:
    """Rebuild a tensor from raw bytes produced by ``tensor2binary``.

    *binary* is interpreted through the numpy-compatible view dtype for
    ``embed_dtype``, reshaped to *shape*, byte-swapped when *endianness*
    names an explicit order differing from the host's, and finally viewed
    back as the real torch dtype.
    """
    assert embed_dtype in EMBED_DTYPES
    assert endianness in ENDIANNESS

    info = EMBED_DTYPES[embed_dtype]

    array = np.frombuffer(binary, dtype=info.numpy_view_dtype)
    array = array.reshape(shape)

    # Mirror of the swap rule in tensor2binary: only swap for an explicit
    # non-host byte order.
    if endianness not in ("native", sys_byteorder):
        array = array.byteswap()

    return torch.from_numpy(array).view(info.torch_dtype)
|
||||
Reference in New Issue
Block a user