Sync from v0.13
This commit is contained in:
214
vllm/utils/serial_utils.py
Normal file
214
vllm/utils/serial_utils.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import base64
|
||||
import io
|
||||
import math
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing_extensions import assert_never
|
||||
|
||||
# Import the real class only for type checkers; at runtime the symbol is an
# alias for ``Any`` so this module stays importable without vllm's full
# import graph.
if TYPE_CHECKING:
    from vllm import PoolingRequestOutput
else:
    PoolingRequestOutput = Any

# Byte order of the host CPU ("little" or "big"); cached once and compared
# against the caller-requested endianness to decide whether to byteswap.
sys_byteorder = sys.byteorder


# Wire-format dtype name -> torch dtype used when casting before serialization.
EMBED_DTYPE_TO_TORCH_DTYPE = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
    # I'm not sure if other platforms' CPUs support the fp8 data format.
    # EMBED_DTYPE only uses the fp8 data representation,
    # does not use fp8 computation, and only occurs on the CPU.
    # Apologize for any possible break.
    "fp8_e4m3": torch.float8_e4m3fn,
    "fp8_e5m2": torch.float8_e5m2,
}

# Bytes per element for each wire-format dtype; used to compute byte offsets
# when laying several serialized tensors back to back in one buffer.
EMBED_DTYPE_TO_N_BYTES = {
    "float32": 4,
    "float16": 2,
    "bfloat16": 2,
    "fp8_e4m3": 1,
    "fp8_e5m2": 1,
}


# Same-width torch dtype used to *reinterpret* (not convert) tensor storage so
# it can cross the torch->numpy bridge: numpy has no bfloat16 or fp8, so those
# are viewed as a 2-byte / 1-byte stand-in of identical size.
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = {
    "float32": torch.float32,
    "float16": torch.float16,
    # numpy does not support bfloat16 and fp8
    "bfloat16": torch.float16,
    "fp8_e4m3": torch.uint8,
    "fp8_e5m2": torch.uint8,
}

# Numpy counterpart of the view table above, used on the deserialization side
# when reading raw bytes back with ``np.frombuffer``.
EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW = {
    "float32": np.float32,
    "float16": np.float16,
    # numpy does not support bfloat16 and fp8
    "bfloat16": np.float16,
    "fp8_e4m3": np.uint8,
    "fp8_e5m2": np.uint8,
}

# Runtime-checkable list mirroring the ``Endianness`` Literal below.
ENDIANNESS = ["native", "big", "little"]

# Accepted string values for the public serialization API.
EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
Endianness = Literal["native", "big", "little"]
EncodingFormat = Literal["float", "base64", "bytes", "bytes_only"]
|
||||
|
||||
|
||||
def tensor2base64(x: torch.Tensor) -> str:
    """Serialize *x* with ``torch.save`` and return the payload as base64 text."""
    with io.BytesIO() as buffer:
        torch.save(x, buffer)
        payload = buffer.getvalue()

    return base64.b64encode(payload).decode("utf-8")
|
||||
|
||||
|
||||
def tensor2binary(
    tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness
) -> bytes:
    """Serialize *tensor* to raw element bytes in the requested dtype/order.

    The tensor is cast to ``embed_dtype``, flattened, reinterpreted through a
    same-width numpy-compatible view dtype (numpy has no bfloat16/fp8), and
    byte-swapped only when ``endianness`` names a concrete order different
    from the host CPU's.

    Args:
        tensor: Tensor to serialize. It is detached and moved to CPU first,
            so autograd-tracked and accelerator-resident tensors are accepted.
        embed_dtype: A key of ``EMBED_DTYPE_TO_TORCH_DTYPE``.
        endianness: "native", "big", or "little".

    Returns:
        The packed element bytes. No header is emitted; the caller must carry
        the shape separately (see ``binary2tensor``).
    """
    # NOTE: asserts are kept (not converted to raises) to preserve the
    # exception types existing callers may rely on.
    assert isinstance(tensor, torch.Tensor)
    assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
    assert endianness in ENDIANNESS

    torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
    torch_view_dtype = EMBED_DTYPE_TO_TORCH_DTYPE_VIEW[embed_dtype]

    # detach() + cpu(): ``.numpy()`` raises on tensors that require grad or
    # live on an accelerator; serialization never needs either.
    np_array = (
        tensor.detach()
        .cpu()
        .to(torch_dtype)
        .flatten()
        .contiguous()
        .view(torch_view_dtype)
        .numpy()
    )

    # "native" always means "emit host order as-is"; otherwise swap only when
    # the requested order differs from the host's.
    if endianness != "native" and endianness != sys_byteorder:
        np_array = np_array.byteswap()

    return np_array.tobytes()
|
||||
|
||||
|
||||
def binary2tensor(
    binary: bytes,
    shape: tuple[int, ...],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> torch.Tensor:
    """Inverse of ``tensor2binary``: rebuild a tensor from raw element bytes.

    The buffer is read through a same-width numpy view dtype (numpy has no
    bfloat16/fp8), byte-swapped if it was written in a non-host order, and
    finally reinterpreted as the real torch dtype.
    """
    assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
    assert embed_dtype in EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW
    assert endianness in ENDIANNESS

    view = np.frombuffer(
        binary, dtype=EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype]
    ).reshape(shape)

    # Bytes written as "native" (or already in host order) need no swap.
    if endianness not in ("native", sys_byteorder):
        view = view.byteswap()

    return torch.from_numpy(view).view(EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype])
|
||||
|
||||
|
||||
def encode_pooling_output(
    output: PoolingRequestOutput,
    encoding_format: EncodingFormat,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> list[float] | str | bytes:
    """Render one pooling result in the requested wire encoding.

    "float" yields a plain Python list, "base64" a base64 string over the
    ``tensor2binary`` payload, and "bytes"/"bytes_only" the raw payload.
    """
    data = output.outputs.data

    if encoding_format == "float":
        return data.tolist()
    if encoding_format == "base64":
        raw = tensor2binary(data, embed_dtype, endianness)
        return base64.b64encode(raw).decode("utf-8")
    if encoding_format in ("bytes", "bytes_only"):
        return tensor2binary(data, embed_dtype, endianness)
    assert_never(encoding_format)
|
||||
|
||||
|
||||
@dataclass
class MetadataItem:
    """Locates one request's serialized tensor inside a concatenated body."""

    # Position of this result in the original request batch.
    index: int
    # Wire dtype the tensor was serialized with.
    embed_dtype: EmbedDType
    # Byte order of the serialized payload.
    endianness: Endianness
    # Half-open byte range [start, end) of this payload within the body.
    start: int
    end: int
    # Original tensor shape, needed to undo the flatten on decode.
    shape: tuple[int, ...]
|
||||
|
||||
|
||||
def build_metadata_items(
    embed_dtype: EmbedDType,
    endianness: Endianness,
    shape: tuple[int, ...],
    n_request: int,
):
    """Lay out ``n_request`` equally-shaped serialized tensors back to back.

    Every payload spans ``prod(shape) * bytes_per_element`` bytes, so item
    ``i`` starts exactly where item ``i - 1`` ended.
    """
    stride = math.prod(shape) * EMBED_DTYPE_TO_N_BYTES[embed_dtype]

    items = []
    for i in range(n_request):
        items.append(
            MetadataItem(
                index=i,
                embed_dtype=embed_dtype,
                endianness=endianness,
                start=i * stride,
                end=(i + 1) * stride,
                shape=shape,
            )
        )

    return items
|
||||
|
||||
|
||||
def encode_pooling_bytes(
    pooling_outputs: list[PoolingRequestOutput],
    embed_dtype: EmbedDType,
    endianness: Endianness,
):
    """Serialize a batch of pooling outputs into one concatenated payload.

    Returns a 3-tuple:
        body: per-request byte strings, in input order (caller concatenates).
        items: per-request metadata dicts mirroring ``MetadataItem``'s fields
            (index / embed_dtype / endianness / start / end / shape).
        usage: token accounting; pooling produces no completion tokens, so
            ``total_tokens == prompt_tokens``.
    """
    num_prompt_tokens = 0
    # Fix: values here are ints/strings/shapes, not MetadataItem instances —
    # the previous ``list[dict[str, MetadataItem]]`` annotation was wrong.
    items: list[dict[str, Any]] = []
    body: list[bytes] = []
    offset = 0
    for idx, output in enumerate(pooling_outputs):
        binary = tensor2binary(
            tensor=output.outputs.data,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )
        size = len(binary)

        # Half-open byte range [start, end) of this payload within the
        # concatenated body.
        item = {
            "index": idx,
            "embed_dtype": embed_dtype,
            "endianness": endianness,
            "start": offset,
            "end": offset + size,
            "shape": output.outputs.data.shape,
        }

        body.append(binary)
        items.append(item)
        prompt_token_ids = output.prompt_token_ids
        num_prompt_tokens += len(prompt_token_ids)
        offset += size

    usage = {
        "prompt_tokens": num_prompt_tokens,
        "total_tokens": num_prompt_tokens,
    }
    return body, items, usage
|
||||
|
||||
|
||||
def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
    """Rebuild the per-request tensors from a concatenated payload.

    NOTE: sorts *items* by ``index`` in place (part of the existing contract),
    so results come back in request order regardless of input order.
    """
    items.sort(key=lambda item: item.index)

    return [
        binary2tensor(
            body[item.start : item.end],
            item.shape,
            item.embed_dtype,
            item.endianness,
        )
        for item in items
    ]
|
||||
Reference in New Issue
Block a user