170 lines
4.8 KiB
Python
170 lines
4.8 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||
|
|
import base64
|
||
|
|
import sys
|
||
|
|
from dataclasses import dataclass
|
||
|
|
from typing import Literal
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
import torch
|
||
|
|
from typing_extensions import assert_never
|
||
|
|
|
||
|
|
from vllm import PoolingRequestOutput
|
||
|
|
|
||
|
|
sys_byteorder = sys.byteorder
|
||
|
|
|
||
|
|
|
||
|
|
EMBED_DTYPE_TO_TORCH_DTYPE = {
|
||
|
|
"float32": torch.float32,
|
||
|
|
"float16": torch.float16,
|
||
|
|
"bfloat16": torch.bfloat16,
|
||
|
|
# I'm not sure if other platforms' CPUs support the fp8 data format.
|
||
|
|
# EMBED_DTYPE only uses the fp8 data representation,
|
||
|
|
# does not use fp8 computation, and only occurs on the CPU.
|
||
|
|
# Apologize for any possible break.
|
||
|
|
"fp8_e4m3": torch.float8_e4m3fn,
|
||
|
|
"fp8_e5m2": torch.float8_e5m2,
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
EMBED_DTYPE_TO_TORCH_DTYPE_VIEW = {
|
||
|
|
"float32": torch.float32,
|
||
|
|
"float16": torch.float16,
|
||
|
|
# numpy does not support bfloat16 and fp8
|
||
|
|
"bfloat16": torch.float16,
|
||
|
|
"fp8_e4m3": torch.uint8,
|
||
|
|
"fp8_e5m2": torch.uint8,
|
||
|
|
}
|
||
|
|
|
||
|
|
EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW = {
|
||
|
|
"float32": np.float32,
|
||
|
|
"float16": np.float16,
|
||
|
|
# numpy does not support bfloat16 and fp8
|
||
|
|
"bfloat16": np.float16,
|
||
|
|
"fp8_e4m3": np.uint8,
|
||
|
|
"fp8_e5m2": np.uint8,
|
||
|
|
}
|
||
|
|
|
||
|
|
ENDIANNESS = ["native", "big", "little"]
|
||
|
|
|
||
|
|
EmbedDType = Literal["float32", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2"]
|
||
|
|
Endianness = Literal["native", "big", "little"]
|
||
|
|
EncodingFormat = Literal["float", "base64", "bytes"]
|
||
|
|
|
||
|
|
|
||
|
|
def tensor2binary(
|
||
|
|
tensor: torch.Tensor, embed_dtype: EmbedDType, endianness: Endianness
|
||
|
|
) -> bytes:
|
||
|
|
assert isinstance(tensor, torch.Tensor)
|
||
|
|
assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
|
||
|
|
assert endianness in ENDIANNESS
|
||
|
|
|
||
|
|
torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
|
||
|
|
torch_view_dtype = EMBED_DTYPE_TO_TORCH_DTYPE_VIEW[embed_dtype]
|
||
|
|
|
||
|
|
np_array = (
|
||
|
|
tensor.to(torch_dtype).flatten().contiguous().view(torch_view_dtype).numpy()
|
||
|
|
)
|
||
|
|
|
||
|
|
if endianness != "native" and endianness != sys_byteorder:
|
||
|
|
np_array = np_array.byteswap()
|
||
|
|
|
||
|
|
return np_array.tobytes()
|
||
|
|
|
||
|
|
|
||
|
|
def binary2tensor(
|
||
|
|
binary: bytes,
|
||
|
|
shape: tuple[int, ...],
|
||
|
|
embed_dtype: EmbedDType,
|
||
|
|
endianness: Endianness,
|
||
|
|
) -> torch.Tensor:
|
||
|
|
assert embed_dtype in EMBED_DTYPE_TO_TORCH_DTYPE
|
||
|
|
assert embed_dtype in EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW
|
||
|
|
assert endianness in ENDIANNESS
|
||
|
|
|
||
|
|
torch_dtype = EMBED_DTYPE_TO_TORCH_DTYPE[embed_dtype]
|
||
|
|
np_dtype = EMBED_DTYPE_TO_NUMPY_DTYPE_VIEW[embed_dtype]
|
||
|
|
|
||
|
|
np_array = np.frombuffer(binary, dtype=np_dtype).reshape(shape)
|
||
|
|
|
||
|
|
if endianness != "native" and endianness != sys_byteorder:
|
||
|
|
np_array = np_array.byteswap()
|
||
|
|
|
||
|
|
return torch.from_numpy(np_array).view(torch_dtype)
|
||
|
|
|
||
|
|
|
||
|
|
def encode_pooling_output(
|
||
|
|
output: PoolingRequestOutput,
|
||
|
|
encoding_format: EncodingFormat,
|
||
|
|
embed_dtype: EmbedDType,
|
||
|
|
endianness: Endianness,
|
||
|
|
) -> list[float] | str | bytes:
|
||
|
|
if encoding_format == "float":
|
||
|
|
return output.outputs.data.tolist()
|
||
|
|
elif encoding_format == "base64":
|
||
|
|
embedding_bytes = tensor2binary(output.outputs.data, embed_dtype, endianness)
|
||
|
|
return base64.b64encode(embedding_bytes).decode("utf-8")
|
||
|
|
elif encoding_format == "bytes":
|
||
|
|
return tensor2binary(output.outputs.data, embed_dtype, endianness)
|
||
|
|
assert_never(encoding_format)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class MetadataItem:
|
||
|
|
index: int
|
||
|
|
embed_dtype: EmbedDType
|
||
|
|
endianness: Endianness
|
||
|
|
start: int
|
||
|
|
end: int
|
||
|
|
shape: tuple[int, ...]
|
||
|
|
|
||
|
|
|
||
|
|
def encode_pooling_bytes(
|
||
|
|
pooling_outputs: list[PoolingRequestOutput],
|
||
|
|
embed_dtype: EmbedDType,
|
||
|
|
endianness: Endianness,
|
||
|
|
):
|
||
|
|
num_prompt_tokens = 0
|
||
|
|
items: list[dict[str, MetadataItem]] = []
|
||
|
|
body = []
|
||
|
|
offset = 0
|
||
|
|
for idx, output in enumerate(pooling_outputs):
|
||
|
|
binary = tensor2binary(
|
||
|
|
tensor=output.outputs.data,
|
||
|
|
embed_dtype=embed_dtype,
|
||
|
|
endianness=endianness,
|
||
|
|
)
|
||
|
|
size = len(binary)
|
||
|
|
|
||
|
|
item = {
|
||
|
|
"index": idx,
|
||
|
|
"embed_dtype": embed_dtype,
|
||
|
|
"endianness": endianness,
|
||
|
|
"start": offset,
|
||
|
|
"end": offset + size,
|
||
|
|
"shape": output.outputs.data.shape,
|
||
|
|
}
|
||
|
|
|
||
|
|
body.append(binary)
|
||
|
|
items.append(item)
|
||
|
|
prompt_token_ids = output.prompt_token_ids
|
||
|
|
num_prompt_tokens += len(prompt_token_ids)
|
||
|
|
offset += size
|
||
|
|
|
||
|
|
usage = {
|
||
|
|
"prompt_tokens": num_prompt_tokens,
|
||
|
|
"total_tokens": num_prompt_tokens,
|
||
|
|
}
|
||
|
|
return body, items, usage
|
||
|
|
|
||
|
|
|
||
|
|
def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
|
||
|
|
items.sort(key=lambda x: x.index)
|
||
|
|
|
||
|
|
tensor_list: list[torch.Tensor] = []
|
||
|
|
for item in items:
|
||
|
|
binary = body[item.start : item.end]
|
||
|
|
tensor = binary2tensor(binary, item.shape, item.embed_dtype, item.endianness)
|
||
|
|
tensor_list.append(tensor)
|
||
|
|
return tensor_list
|