Files
enginex-bi_150-vllm/utils/serial_utils.py

170 lines
4.8 KiB
Python
Raw Normal View History

2026-03-05 18:06:10 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import base64
import sys
from dataclasses import dataclass
from typing import Literal
import numpy as np
import torch
from typing_extensions import assert_never
from vllm import PoolingRequestOutput
sys_byteorder = sys.byteorder
EMBED_DTYPE_TO_TORCH_DTYPE = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
# I'm not sure if other platforms' CPUs support the fp8 data format.
# EMBED_DTYPE only uses the fp8 data representation,
# does not use fp8 computation, and only occurs on the CPU.
# Apologize for any possible break.
"fp8_e4m3": torch.float8_e4m3fn,
"fp8_e5m2": torch.float8_e5m2,
}
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = {
"float32": torch.float32,
"float16": torch.float16,
# numpy does not support bfloat16 and fp8
"bfloat16": torch.float16,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
}
EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW = {
"float32": np.float32,
"float16": np.float16,
# numpy does not support bfloat16 and fp8
"bfloat16": np.float16,
"fp8_e4m3": np.uint8,
"fp8_e5m2": np.uint8,
}
ENDIANNESS = ["native", "big", "little"]
EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes"]
def tensor2binary(
    tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness
) -> bytes:
    """Serialize a tensor into raw bytes with the requested dtype and byte order.

    The tensor is detached, moved to the CPU, cast to the torch dtype registered
    for ``embed_dtype``, flattened, and dumped as one contiguous byte string.
    The shape is not part of the output; callers must carry it separately.

    Args:
        tensor: Tensor to serialize.
        embed_dtype: Key of ``EMBED_DTYPE_TO_TORCH_DTYPE`` selecting the
            on-the-wire element type.
        endianness: ``"native"`` keeps the host byte order; ``"big"`` or
            ``"little"`` swaps bytes only when it differs from the host order.

    Returns:
        The flattened tensor elements as bytes.

    Raises:
        AssertionError: If ``tensor`` is not a ``torch.Tensor`` or either
            option is not a supported value.
    """
    assert isinstance(tensor, torch.Tensor)
    assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
    assert endianness in ENDIANNESS

    torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
    torch_view_dtype = EMBED_DTYPE_TO_TORCH_DTYPE_VIEW[embed_dtype]

    # Tensor.numpy() only accepts CPU tensors that do not require grad, so
    # detach and move to the CPU first. This also guarantees that the dtype
    # cast (including fp8) runs on the CPU rather than on an accelerator.
    cpu_tensor = tensor.detach().cpu()

    # numpy has no bfloat16/fp8, so the bits are reinterpreted through a
    # same-width dtype before crossing over to numpy.
    np_array = (
        cpu_tensor.to(torch_dtype)
        .flatten()
        .contiguous()
        .view(torch_view_dtype)
        .numpy()
    )

    if endianness != "native" and endianness != sys_byteorder:
        np_array = np_array.byteswap()
    return np_array.tobytes()
def binary2tensor(
    binary: bytes,
    shape: tuple[int, ...],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> torch.Tensor:
    """Rebuild a tensor from raw bytes of a known dtype, shape, and byte order.

    Args:
        binary: Raw element bytes.
        shape: Shape to give the decoded array.
        embed_dtype: Key of ``EMBED_DTYPE_TO_TORCH_DTYPE`` naming the element
            type the bytes were written in.
        endianness: Byte order of ``binary``. ``"native"`` means host order;
            ``"big"`` or ``"little"`` triggers a byteswap only when it differs
            from the host order.

    Returns:
        A tensor of dtype ``EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]`` backed by
        writable memory.

    Raises:
        AssertionError: If either option is not a supported value.
        ValueError: Raised by numpy when the buffer size does not fit the
            element size or the requested shape.
    """
    assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
    assert embed_dtype in EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW
    assert endianness in ENDIANNESS

    torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
    np_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype]

    np_array = np.frombuffer(binary, dtype=np_dtype).reshape(shape)

    if endianness != "native" and endianness != sys_byteorder:
        # byteswap() returns a fresh, writable array.
        np_array = np_array.byteswap()
    elif not np_array.flags.writeable:
        # np.frombuffer over an immutable buffer (e.g. bytes) yields a
        # read-only array. torch.from_numpy warns on such arrays and the
        # resulting tensor would alias immutable memory, so take a private copy.
        np_array = np_array.copy()

    # Reinterpret the same-width stand-in dtype as the real torch dtype.
    return torch.from_numpy(np_array).view(torch_dtype)
def encode_pooling_output(
output: PoolingRequestOutput,
encoding_format: EncodingFormat,
embed_dtype: EmbedDType,
endianness: Endianness,
) -> list[float] | str | bytes:
if encoding_format == "float":
return output.outputs.data.tolist()
elif encoding_format == "base64":
embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness)
return base64.b64encode(embedding_bytes).decode("utf-8")
elif encoding_format == "bytes":
return tensor2binary(output.outputs.data, embed_dtype, endianness)
assert_never(encoding_format)
@dataclass
class MetadataItem:
    """Describes where one serialized tensor sits inside a binary payload."""

    # Position of the tensor in the original batch; used to order items.
    index: int
    # Element type and byte order the tensor bytes were written with.
    embed_dtype: EmbedDType
    endianness: Endianness
    # Slice bounds into the payload: the bytes are ``body[start:end]``.
    start: int
    end: int
    # Shape to restore after decoding the flat bytes.
    shape: tuple[int, ...]
def encode_pooling_bytes(
    pooling_outputs: list[PoolingRequestOutput],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> tuple[list[bytes], list[dict[str, object]], dict[str, int]]:
    """Serialize a batch of pooling results into byte chunks plus metadata.

    Args:
        pooling_outputs: Pooling results; each ``outputs.data`` tensor is
            serialized with ``tensor2binary``.
        embed_dtype: Element type used for every tensor.
        endianness: Byte order used for every tensor.

    Returns:
        A ``(body, items, usage)`` tuple where ``body`` is the list of
        per-output byte chunks in input order, ``items`` holds one metadata
        dict per output (``index``, ``embed_dtype``, ``endianness``, ``start``,
        ``end``, ``shape``) with ``start``/``end`` being running byte offsets
        across the chunks, and ``usage`` reports the summed prompt token count
        under both ``prompt_tokens`` and ``total_tokens``.
    """
    num_prompt_tokens = 0
    # Plain dicts, not MetadataItem instances: values are ints, strings and a
    # shape, keyed by the MetadataItem field names.
    items: list[dict[str, object]] = []
    body: list[bytes] = []
    offset = 0

    for idx, output in enumerate(pooling_outputs):
        data = output.outputs.data
        binary = tensor2binary(
            tensor=data,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )
        end = offset + len(binary)
        items.append(
            {
                "index": idx,
                "embed_dtype": embed_dtype,
                "endianness": endianness,
                "start": offset,
                "end": end,
                "shape": data.shape,
            }
        )
        body.append(binary)
        num_prompt_tokens += len(output.prompt_token_ids)
        offset = end

    usage = {
        "prompt_tokens": num_prompt_tokens,
        "total_tokens": num_prompt_tokens,
    }
    return body, items, usage
def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
    """Decode the tensors described by ``items`` out of a binary payload.

    Args:
        items: Metadata entries giving each tensor's byte range, shape, dtype
            and byte order. The list is not modified.
        body: Payload the ``start``/``end`` offsets index into (presumably the
            concatenation of the chunks produced on the encoding side --
            confirm against the caller).

    Returns:
        The decoded tensors, ordered by ascending ``item.index``.
    """
    # sorted() instead of list.sort() so the caller's list is left untouched.
    ordered = sorted(items, key=lambda item: item.index)
    return [
        binary2tensor(
            body[item.start : item.end],
            item.shape,
            item.embed_dtype,
            item.endianness,
        )
        for item in ordered
    ]