v1.0
This commit is contained in:
208
logprobs.py
Normal file
208
logprobs.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from collections.abc import Iterable, Iterator, MutableSequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import overload
|
||||
|
||||
import vllm.envs as envs
|
||||
|
||||
|
||||
# We use dataclass for now because it is used for
# openai server output, and msgspec is not serializable.
# TODO(sang): Fix it.
@dataclass
class Logprob:
    """Info for supporting OpenAI compatible logprobs and token ranks.

    Attributes:
        logprob: The logprob of the chosen token.
        rank: The vocab rank of the chosen token (>=1).
        decoded_token: The decoded string of the chosen token
            (None when detokenized text is unavailable).
    """

    logprob: float
    rank: int | None = None
    decoded_token: str | None = None


# Top logprobs for a single sequence position: {token_id -> Logprob}.
LogprobsOnePosition = dict[int, Logprob]
@dataclass
|
||||
class FlatLogprobs(MutableSequence[LogprobsOnePosition]):
|
||||
"""
|
||||
Flat logprobs of a request into multiple primitive type lists.
|
||||
|
||||
Compared to list[dict[int, Logprob]], this data structure reduced GC
|
||||
overhead significantly. As it flattened logprob information for
|
||||
all positions and ranks in to multiple primitive type lists (i.e.
|
||||
logprobs, token_ids, ranks per token_ids, decoded_tokens).
|
||||
So regardless of the sequence length and top_logprobs setup,
|
||||
FlatLogprobs would only introduce a constant amount of objects.
|
||||
|
||||
As each position might contains different amount of ranks,
|
||||
start_indices_per_position would be used to access the logprob ranges
|
||||
for different positions.
|
||||
|
||||
NOTE: To reduce the migration overhead and improve backward compatibility,
|
||||
we support the key Sequence APIs of list, so it could act as
|
||||
list[LogprobsOnePosition]
|
||||
"""
|
||||
|
||||
# Start / end indices to indicate the range of logprobs for each position.
|
||||
start_indices: list[int] = field(default_factory=list)
|
||||
end_indices: list[int] = field(default_factory=list)
|
||||
|
||||
# Flatten Logprob information for (each position, rank).
|
||||
# For position <i>, the logprobs are ranged
|
||||
# from self.start_indices[i] to self.end_indices[i] (exclusive).
|
||||
token_ids: list[int] = field(default_factory=list)
|
||||
logprobs: list[float] = field(default_factory=list)
|
||||
ranks: list[int | None] = field(default_factory=list)
|
||||
decoded_tokens: list[str | None] = field(default_factory=list)
|
||||
|
||||
def append(self, logprobs_one_position: LogprobsOnePosition | None) -> None:
|
||||
"""Appends the container with logprobs for the next position"""
|
||||
self.start_indices.append(len(self.logprobs))
|
||||
if logprobs_one_position:
|
||||
for token_id, logprob in logprobs_one_position.items():
|
||||
self.token_ids.append(token_id)
|
||||
self.logprobs.append(logprob.logprob)
|
||||
self.ranks.append(logprob.rank)
|
||||
self.decoded_tokens.append(logprob.decoded_token)
|
||||
self.end_indices.append(len(self.logprobs))
|
||||
|
||||
def append_fast(
|
||||
self,
|
||||
token_ids: list[int],
|
||||
logprobs: list[float],
|
||||
ranks: itertools.chain[int],
|
||||
decoded_tokens: Iterable[str | None],
|
||||
) -> None:
|
||||
"""
|
||||
Appends logprobs for the next position without creating
|
||||
the intermediate logprob dictionary.
|
||||
"""
|
||||
self.start_indices.append(len(self.logprobs))
|
||||
for token_id, logprob, rank, decoded_token in zip(
|
||||
token_ids, logprobs, ranks, decoded_tokens
|
||||
):
|
||||
self.token_ids.append(token_id)
|
||||
self.logprobs.append(logprob)
|
||||
self.ranks.append(rank)
|
||||
self.decoded_tokens.append(decoded_token)
|
||||
self.end_indices.append(len(self.logprobs))
|
||||
|
||||
def extend(self, logprobs_multi_positions) -> None:
|
||||
"""Extends the container with logprobs for the next multiple positions"""
|
||||
for logprobs_one_position in logprobs_multi_positions:
|
||||
self.append(logprobs_one_position)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Gets number of positions stored in the container"""
|
||||
return len(self.start_indices)
|
||||
|
||||
@overload
|
||||
def __getitem__(self, position: int) -> LogprobsOnePosition: ...
|
||||
|
||||
@overload
|
||||
def __getitem__(self, s: slice, /) -> "FlatLogprobs": ...
|
||||
|
||||
def __getitem__(self, index: int | slice):
|
||||
"""Extracts logprobs of a given position or slice"""
|
||||
if isinstance(index, int):
|
||||
return {
|
||||
self.token_ids[i]: Logprob(
|
||||
logprob=self.logprobs[i],
|
||||
rank=self.ranks[i],
|
||||
decoded_token=self.decoded_tokens[i],
|
||||
)
|
||||
for i in range(self.start_indices[index], self.end_indices[index])
|
||||
}
|
||||
elif isinstance(index, slice):
|
||||
min_index = self.start_indices[index][0]
|
||||
max_index = self.end_indices[index][-1]
|
||||
return FlatLogprobs(
|
||||
# Shift updated start_indices and end_indices to
|
||||
# be 0-indexed
|
||||
start_indices=[i - min_index for i in self.start_indices[index]],
|
||||
end_indices=[i - min_index for i in self.end_indices[index]],
|
||||
token_ids=self.token_ids[min_index:max_index],
|
||||
logprobs=self.logprobs[min_index:max_index],
|
||||
ranks=self.ranks[min_index:max_index],
|
||||
decoded_tokens=self.decoded_tokens[min_index:max_index],
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Invalid index type: {type(index)}")
|
||||
|
||||
def __setitem__(self, item, value) -> None:
|
||||
raise TypeError("Cannot set logprobs in FlatLogprobs")
|
||||
|
||||
def __delitem__(self, item) -> None:
|
||||
raise TypeError("Cannot delete logprobs from FlatLogprobs")
|
||||
|
||||
def insert(self, item) -> None:
|
||||
raise TypeError("Cannot insert logprobs to FlatLogprobs")
|
||||
|
||||
def __iter__(self) -> Iterator[LogprobsOnePosition]:
|
||||
"""
|
||||
Iterates the container and yields LogprobsOnePosition for
|
||||
each position.
|
||||
"""
|
||||
for i in range(0, len(self.start_indices)):
|
||||
yield self.__getitem__(i)
|
||||
|
||||
|
||||
# Prompt logprobs for a request: one {token_id -> Logprob} mapping per
# position. An entry is None if the corresponding position doesn't require
# a prompt logprob.
PromptLogprobs = FlatLogprobs | list[LogprobsOnePosition | None]
# Sample logprobs for a request: one {token_id -> Logprob} mapping per
# position.
SampleLogprobs = FlatLogprobs | list[LogprobsOnePosition]
def create_prompt_logprobs() -> PromptLogprobs:
    """Creates a container to store prompt logprobs for a request.

    The flat representation is used when VLLM_FLAT_LOGPROBS is enabled;
    otherwise a plain list is returned.
    """
    if envs.VLLM_FLAT_LOGPROBS:
        container: PromptLogprobs = FlatLogprobs()
    else:
        container = []
    # The first prompt token has no logprob, so it is recorded as None.
    container.append(None)
    return container
def create_sample_logprobs() -> SampleLogprobs:
    """Creates a container to store decode logprobs for a request.

    The flat representation is used when VLLM_FLAT_LOGPROBS is enabled;
    otherwise a plain list is returned.
    """
    if envs.VLLM_FLAT_LOGPROBS:
        return FlatLogprobs()
    return []
def append_logprobs_for_next_position(
    request_logprobs: PromptLogprobs | SampleLogprobs,
    token_ids: list[int],
    logprobs: list[float],
    decoded_tokens: Iterable[str | None],
    rank: int,
    num_logprobs: int,
) -> None:
    """Appends logprobs for the next position.

    Args:
        request_logprobs: Destination container (flat or list-of-dict form).
        token_ids: Token ids for this position, aligned with logprobs.
        logprobs: Logprob values, aligned with token_ids.
        decoded_tokens: Decoded token strings, aligned with token_ids.
        rank: Vocab rank of the first (chosen) token.
        num_logprobs: Number of top logprobs kept; -1 means all of them.
    """
    if num_logprobs == -1:
        num_logprobs = len(logprobs)
    # No special case is needed for the sampled token also being in the
    # topk: writing duplicated data into a dictionary twice is the same
    # as doing it once.
    all_ranks = itertools.chain((rank,), range(1, num_logprobs + 1))

    if isinstance(request_logprobs, FlatLogprobs):
        request_logprobs.append_fast(token_ids, logprobs, all_ranks, decoded_tokens)
        return

    position_logprobs = {
        token_id: Logprob(
            logprob=logprob,
            rank=token_rank,
            decoded_token=token,
        )
        for token_id, logprob, token_rank, token in zip(
            token_ids, logprobs, all_ranks, decoded_tokens
        )
    }
    request_logprobs.append(position_logprobs)
Reference in New Issue
Block a user