125 lines
3.1 KiB
Python
125 lines
3.1 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
import math
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
|
|
import pybase64
|
|
import torch
|
|
|
|
from vllm.outputs import PoolingRequestOutput
|
|
from vllm.utils.serial_utils import (
|
|
EMBED_DTYPES,
|
|
EmbedDType,
|
|
Endianness,
|
|
binary2tensor,
|
|
tensor2binary,
|
|
)
|
|
|
|
|
|
@dataclass
class MetadataItem:
    """Location and decoding info for one serialized tensor in a binary payload."""

    # Position of this output within the original request batch.
    index: int
    # Numeric dtype the tensor was serialized with (key into EMBED_DTYPES).
    embed_dtype: EmbedDType
    # Byte order used when serializing the tensor.
    endianness: Endianness
    # Byte offsets [start, end) of this tensor's slice in the payload.
    start: int
    end: int
    # Tensor shape, needed to reconstruct from the flat byte slice.
    shape: tuple[int, ...]
|
|
|
|
|
|
def build_metadata_items(
    embed_dtype: EmbedDType,
    endianness: Endianness,
    shape: tuple[int, ...],
    n_request: int,
) -> list[MetadataItem]:
    """Build metadata for *n_request* equally shaped tensors laid out
    back-to-back in a single binary payload.

    Every tensor occupies the same number of bytes, so byte offsets
    advance by a fixed stride per request index.
    """
    stride = math.prod(shape) * EMBED_DTYPES[embed_dtype].nbytes

    items: list[MetadataItem] = []
    for idx in range(n_request):
        offset = idx * stride
        items.append(
            MetadataItem(
                index=idx,
                embed_dtype=embed_dtype,
                endianness=endianness,
                start=offset,
                end=offset + stride,
                shape=shape,
            )
        )
    return items
|
|
|
|
|
|
def encode_pooling_output_float(output: PoolingRequestOutput) -> list[float]:
    """Convert a pooling output's tensor into a plain Python list of floats."""
    tensor = output.outputs.data
    return tensor.tolist()
|
|
|
|
|
|
def encode_pooling_output_binary(
    output: PoolingRequestOutput,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> bytes:
    """Serialize a pooling output's tensor into raw bytes with the
    requested dtype and byte order."""
    tensor = output.outputs.data
    return tensor2binary(tensor, embed_dtype, endianness)
|
|
|
|
|
|
def encode_pooling_output_base64(
    output: PoolingRequestOutput,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> str:
    """Serialize a pooling output's tensor and return it as a base64 string."""
    raw = tensor2binary(output.outputs.data, embed_dtype, endianness)
    encoded = pybase64.b64encode(raw)
    return encoded.decode("utf-8")
|
|
|
|
|
|
def encode_pooling_bytes(
    pooling_outputs: list[PoolingRequestOutput],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> tuple[list[bytes], list[dict[str, Any]], dict[str, Any]]:
    """Serialize a batch of pooling outputs into binary chunks plus metadata.

    Returns:
        A tuple of ``(chunks, metadata, usage)`` where each metadata dict
        mirrors ``MetadataItem`` and records the byte range ``[start, end)``
        its chunk would occupy in the concatenated payload, and ``usage``
        mirrors ``UsageInfo`` with the total prompt-token count.
    """
    chunks: list[bytes] = []
    metadata: list[dict[str, Any]] = []
    token_count = 0
    cursor = 0

    for position, request_output in enumerate(pooling_outputs):
        tensor = request_output.outputs.data
        chunk = tensor2binary(
            tensor=tensor,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )
        next_cursor = cursor + len(chunk)

        # Dictionary form of MetadataItem
        metadata.append(
            dict(
                index=position,
                embed_dtype=embed_dtype,
                endianness=endianness,
                start=cursor,
                end=next_cursor,
                shape=tensor.shape,
            )
        )
        chunks.append(chunk)

        token_count += len(request_output.prompt_token_ids)
        cursor = next_cursor

    # Dictionary form of UsageInfo
    usage = dict(
        prompt_tokens=token_count,
        total_tokens=token_count,
    )

    return chunks, metadata, usage
|
|
|
|
|
|
def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
    """Reconstruct tensors from a concatenated binary payload.

    Items are processed in ascending ``index`` order, so the returned list
    matches the original request order regardless of input ordering.
    """
    ordered = sorted(items, key=lambda item: item.index)

    tensors: list[torch.Tensor] = []
    for item in ordered:
        segment = body[item.start : item.end]
        tensors.append(
            binary2tensor(segment, item.shape, item.embed_dtype, item.endianness)
        )
    return tensors
|