Files
bi_150-vllm/vllm/entrypoints/pooling/utils.py

125 lines
3.1 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from dataclasses import dataclass
from typing import Any
import pybase64
import torch
from vllm.outputs import PoolingRequestOutput
from vllm.utils.serial_utils import (
EMBED_DTYPES,
EmbedDType,
Endianness,
binary2tensor,
tensor2binary,
)
@dataclass
class MetadataItem:
    """Locate and describe one encoded tensor inside a shared byte body.

    ``decode_pooling_output`` slices the body with ``start``/``end`` and
    hands ``shape``, ``embed_dtype`` and ``endianness`` to ``binary2tensor``.
    """

    # Position of this item in its batch; decoding sorts items by it.
    index: int
    # Key used to look up the element type in EMBED_DTYPES.
    embed_dtype: EmbedDType
    # Byte order that was used when the tensor was serialized.
    endianness: Endianness
    # Slice bounds into the body, used as ``body[start:end]``.
    start: int
    end: int
    # Shape passed to the decoder when rebuilding the tensor.
    shape: tuple[int, ...]
def build_metadata_items(
    embed_dtype: EmbedDType,
    endianness: Endianness,
    shape: tuple[int, ...],
    n_request: int,
) -> list[MetadataItem]:
    """Build metadata for ``n_request`` same-shaped items laid out back to back.

    Args:
        embed_dtype: Key into ``EMBED_DTYPES``; its ``nbytes`` gives the
            size of a single element.
        endianness: Byte order recorded on every item.
        shape: Shape shared by all items; the element count is the product
            of its dimensions.
        n_request: Number of items to describe.

    Returns:
        One ``MetadataItem`` per request, with ``index`` running from 0 and
        consecutive, non-overlapping ``[start, end)`` byte ranges.
    """
    bytes_per_element = EMBED_DTYPES[embed_dtype].nbytes
    # Every item occupies the same number of bytes: elements * bytes/element.
    item_nbytes = math.prod(shape) * bytes_per_element

    metadata: list[MetadataItem] = []
    for position in range(n_request):
        begin = position * item_nbytes
        metadata.append(
            MetadataItem(
                index=position,
                embed_dtype=embed_dtype,
                endianness=endianness,
                start=begin,
                end=begin + item_nbytes,
                shape=shape,
            )
        )
    return metadata
def encode_pooling_output_float(output: PoolingRequestOutput) -> list[float]:
    """Return the pooled data of ``output`` converted via ``tolist()``."""
    pooled = output.outputs.data
    return pooled.tolist()
def encode_pooling_output_binary(
    output: PoolingRequestOutput,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> bytes:
    """Serialize the pooled data of ``output`` with ``tensor2binary``.

    Args:
        output: Pooling result whose ``outputs.data`` is serialized.
        embed_dtype: Element type forwarded to ``tensor2binary``.
        endianness: Byte order forwarded to ``tensor2binary``.

    Returns:
        The bytes produced by ``tensor2binary``.
    """
    pooled = output.outputs.data
    return tensor2binary(pooled, embed_dtype, endianness)
def encode_pooling_output_base64(
    output: PoolingRequestOutput,
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> str:
    """Serialize the pooled data of ``output`` and return it as base64 text.

    Args:
        output: Pooling result whose ``outputs.data`` is serialized.
        embed_dtype: Element type forwarded to ``tensor2binary``.
        endianness: Byte order forwarded to ``tensor2binary``.

    Returns:
        The base64 encoding of the serialized bytes, decoded as UTF-8.
    """
    raw = tensor2binary(output.outputs.data, embed_dtype, endianness)
    encoded = pybase64.b64encode(raw)
    return encoded.decode("utf-8")
def encode_pooling_bytes(
    pooling_outputs: list[PoolingRequestOutput],
    embed_dtype: EmbedDType,
    endianness: Endianness,
) -> tuple[list[bytes], list[dict[str, Any]], dict[str, Any]]:
    """Serialize a batch of pooling outputs into byte chunks plus metadata.

    Args:
        pooling_outputs: Results to serialize, in the order they should be
            laid out.
        embed_dtype: Element type forwarded to ``tensor2binary``.
        endianness: Byte order forwarded to ``tensor2binary``.

    Returns:
        A ``(body, items, usage)`` tuple:

        * ``body`` - one ``bytes`` chunk per output, in input order.
        * ``items`` - one dict per output with the same fields as
          ``MetadataItem``; ``start``/``end`` are offsets into the
          concatenation of the chunks.
        * ``usage`` - dict with ``prompt_tokens`` and ``total_tokens``, both
          set to the summed length of every output's ``prompt_token_ids``.
    """
    chunks: list[bytes] = []
    metadata: list[dict[str, Any]] = []
    prompt_token_total = 0
    # Running byte offset of the next chunk within the concatenated body.
    cursor = 0

    for position, pooled in enumerate(pooling_outputs):
        tensor = pooled.outputs.data
        payload = tensor2binary(
            tensor=tensor,
            embed_dtype=embed_dtype,
            endianness=endianness,
        )
        next_cursor = cursor + len(payload)

        chunks.append(payload)
        # Dictionary form of MetadataItem
        metadata.append(
            {
                "index": position,
                "embed_dtype": embed_dtype,
                "endianness": endianness,
                "start": cursor,
                "end": next_cursor,
                "shape": tensor.shape,
            }
        )

        prompt_token_total += len(pooled.prompt_token_ids)
        cursor = next_cursor

    # Dictionary form of UsageInfo
    usage = {
        "prompt_tokens": prompt_token_total,
        "total_tokens": prompt_token_total,
    }
    return chunks, metadata, usage
def decode_pooling_output(items: list[MetadataItem], body: bytes) -> list[torch.Tensor]:
    """Rebuild the tensors described by ``items`` from a shared byte body.

    Args:
        items: Metadata entries; each supplies the slice bounds, shape,
            dtype key and byte order for one tensor.
        body: Byte buffer that the ``start``/``end`` offsets index into.

    Returns:
        The decoded tensors, ordered by ascending ``index`` regardless of
        the order of ``items``.
    """
    tensors: list[torch.Tensor] = []
    for meta in sorted(items, key=lambda entry: entry.index):
        chunk = body[meta.start : meta.end]
        tensors.append(
            binary2tensor(chunk, meta.shape, meta.embed_dtype, meta.endianness)
        )
    return tensors