352 lines
13 KiB
Python
352 lines
13 KiB
Python
|
|
# SPDX-License-Identifier: Apache-2.0
|
|||
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|||
|
|
from abc import ABC, abstractmethod
|
|||
|
|
|
|||
|
|
import tokenizers
|
|||
|
|
from packaging import version
|
|||
|
|
from tokenizers import Tokenizer
|
|||
|
|
from tokenizers.decoders import DecodeStream
|
|||
|
|
from transformers import PreTrainedTokenizerFast
|
|||
|
|
|
|||
|
|
from vllm.logger import init_logger
|
|||
|
|
from vllm.transformers_utils.detokenizer_utils import (
|
|||
|
|
AnyTokenizer,
|
|||
|
|
convert_prompt_ids_to_tokens,
|
|||
|
|
detokenize_incrementally,
|
|||
|
|
)
|
|||
|
|
from vllm.utils import length_from_prompt_token_ids_or_embeds
|
|||
|
|
from vllm.v1.engine import EngineCoreRequest
|
|||
|
|
|
|||
|
|
# Module-level logger, used below to report detokenization anomalies.
logger = init_logger(__name__)

# Only tokenizers >= 0.21.1 supports DecodeStream used for
# FastIncrementalDetokenizer. With older versions this flag is False and
# IncrementalDetokenizer.from_new_request never selects the fast path.
USE_FAST_DETOKENIZER = version.parse(tokenizers.__version__) >= version.parse("0.21.1")

# Error string from https://github.com/huggingface/tokenizers/blob/909fdde2a4ffedd9295206f705eb612be2a91b12/tokenizers/src/tokenizer/mod.rs#L1042
# FastIncrementalDetokenizer._protected_step compares the start of an
# exception's message against this prefix to decide whether the decode
# stream can be reset and retried rather than re-raising.
INVALID_PREFIX_ERR_MSG = "Invalid prefix encountered"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class IncrementalDetokenizer:
    """No-op detokenizer that records token ids and never produces text.

    It doubles as the common interface for the concrete detokenizers and as
    the implementation handed out when no tokenizer is available.
    """

    def __init__(self):
        # Every token id received so far, in arrival order.
        self.token_ids: list[int] = []

    @property
    def output_token_ids(self) -> list[int]:
        """Return the accumulated token ids (the same list object, not a copy)."""
        return self.token_ids

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
        """Record ``new_token_ids``; there is no text, so no stop string can match.

        ``stop_terminated`` is accepted for interface compatibility and is
        ignored here. Always returns ``None``.
        """
        self.token_ids.extend(new_token_ids)
        return None

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return the empty string: this implementation never detokenizes."""
        return ""

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer | None,
        request: EngineCoreRequest,
    ) -> "IncrementalDetokenizer":
        """Pick and build the detokenizer implementation for ``request``.

        * ``tokenizer is None``: plain ``IncrementalDetokenizer`` (no text).
        * ``USE_FAST_DETOKENIZER`` is set and the tokenizer is a
          ``PreTrainedTokenizerFast``: ``FastIncrementalDetokenizer``.
        * otherwise: ``SlowIncrementalDetokenizer``.

        ``request.sampling_params`` must not be ``None``.
        """
        assert request.sampling_params is not None

        # Without a tokenizer there is nothing to decode with.
        if tokenizer is None:
            return IncrementalDetokenizer()

        # The DecodeStream-based path needs both a recent enough tokenizers
        # library and a fast (Rust-backed) tokenizer instance.
        fast_path_available = USE_FAST_DETOKENIZER and isinstance(
            tokenizer, PreTrainedTokenizerFast
        )
        if not fast_path_available:
            # Slow python-based incremental detokenization.
            return SlowIncrementalDetokenizer(tokenizer, request)

        return FastIncrementalDetokenizer(tokenizer, request)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class BaseIncrementalDetokenizer(IncrementalDetokenizer, ABC):
    """Shared stop-string and output-text bookkeeping for real detokenizers.

    Subclasses only supply ``decode_next``, which turns a single token id
    into the text it contributes.
    """

    def __init__(self, request: EngineCoreRequest):
        super().__init__()

        sampling = request.sampling_params
        assert sampling is not None

        # Stop strings: normalise None / a single string / a list to a list.
        raw_stop = sampling.stop
        stop_list: list[str]
        if raw_stop is None:
            stop_list = []
        elif isinstance(raw_stop, str):
            stop_list = [raw_stop]
        else:
            stop_list = raw_stop
        self.stop = stop_list
        self.min_tokens = sampling.min_tokens
        self.include_stop_str_in_output = sampling.include_stop_str_in_output

        # Number of trailing chars withheld from streamed output when stop
        # strings are to be excluded: one less than the longest stop string.
        held_chars = 0
        if self.stop and not self.include_stop_str_in_output:
            held_chars = max(map(len, self.stop)) - 1
        self.stop_buffer_length = held_chars

        # End offset of the text already returned by delta-mode reads.
        self._last_output_text_offset: int = 0

        # Generation data
        self.output_text = ""

    def update(self, new_token_ids: list[int], stop_terminated: bool) -> str | None:
        """
        Consume newly generated token ids:

        1) Detokenize them one by one, appending to ``output_text``.
        2) Check the newly produced text against the stop strings.

        When ``stop_terminated`` is set and stop strings are excluded from
        the output, the last id is recorded in ``token_ids`` but not decoded.

        Return the matched stop string, or None.
        """
        # Nothing new: skip detokenization entirely.
        if not new_token_ids:
            return None

        undecoded_last_id = None
        if stop_terminated and not self.include_stop_str_in_output:
            # Hold the final token back from detokenization.
            undecoded_last_id = new_token_ids[-1]
            new_token_ids = new_token_ids[:-1]

        # 1) Detokenize the new token ids incrementally.
        # TODO(woosuk): This method becomes very inefficient when the number of
        # new_token_ids is more than 1. We need to optimize this.
        scan_start = len(self.output_text)
        for token_id in new_token_ids:
            self.token_ids.append(token_id)
            self.output_text += self.decode_next(token_id)
            # Support min_tokens, see https://github.com/vllm-project/vllm/pull/22014
            # Text produced while still within min_tokens is excluded from
            # the stop-string scan by moving the scan start past it.
            if self.min_tokens and len(self.output_token_ids) <= self.min_tokens:
                scan_start = len(self.output_text)

        if undecoded_last_id is not None:
            # The held-back token is still part of the token id history.
            self.token_ids.append(undecoded_last_id)

        # 2) Evaluate stop strings.
        if not self.stop or len(self.output_token_ids) <= self.min_tokens:
            return None

        hit = check_stop_strings(
            output_text=self.output_text,
            new_char_count=len(self.output_text) - scan_start,
            stop=self.stop,
            include_in_output=self.include_stop_str_in_output,
        )
        if hit is None:
            return None

        matched_stop, cut_at = hit
        if cut_at != -1:
            self.output_text = self.output_text[:cut_at]
        return matched_stop

    @abstractmethod
    def decode_next(self, next_token_id: int) -> str:
        """Return the text contributed by ``next_token_id``."""
        raise NotImplementedError

    def get_next_output_text(self, finished: bool, delta: bool) -> str:
        """Return output text, withholding the stop buffer unless finished.

        With ``delta`` False the whole visible text is returned. With
        ``delta`` True only the text added since the previous delta call is
        returned, and the internal offset advances.
        """
        # Once the sequence is finished nothing needs to be held back.
        withheld = 0 if finished else self.stop_buffer_length

        if not delta:
            if withheld:
                return self.output_text[:-withheld]
            return self.output_text

        visible_end = len(self.output_text) - withheld
        start = self._last_output_text_offset
        if start >= visible_end:
            # No new visible text since the last call.
            return ""

        self._last_output_text_offset = visible_end
        return self.output_text[start:visible_end]
|
|||
|
|
|
|||
|
|
|
|||
|
|
class FastIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Incremental detokenizer built on the tokenizers library's DecodeStream."""

    def __init__(self, tokenizer: PreTrainedTokenizerFast, request: EngineCoreRequest):
        """Create the decode stream and prime it with the tail of the prompt.

        ``request.sampling_params`` must not be ``None``.
        """
        super().__init__(request)

        sampling_params = request.sampling_params
        assert sampling_params is not None

        self.request_id = request.request_id
        self.skip_special_tokens = sampling_params.skip_special_tokens
        self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)

        # Underlying tokenizers.Tokenizer wrapped by the HF fast tokenizer.
        self.tokenizer: Tokenizer = tokenizer._tokenizer

        # Find a safe place to start: the shortest prompt suffix of at least
        # 4 (and fewer than 24) tokens whose decoded text contains no U+FFFD
        # replacement character. If none qualifies, or the prompt has at
        # most 4 tokens, the whole prompt is used.
        # NOTE(review): U+FFFD presumably marks a suffix that begins in the
        # middle of a multi-byte character; confirm against tokenizers.
        prompt_token_ids = request.prompt_token_ids or []
        prompt_suffix = prompt_token_ids
        prompt_len = len(prompt_suffix)
        if prompt_len > 4:
            for i in range(4, min(prompt_len + 1, 24)):
                suffix = prompt_token_ids[-i:]
                # Written as an escape so the check survives re-encoding of
                # this source file.
                if "\ufffd" not in self.tokenizer.decode(suffix):
                    prompt_suffix = suffix
                    break

        # Prime the stream; the text it returns for prompt tokens is dropped.
        for tid in prompt_suffix:
            self._protected_step(tid)

        # Spacing only needs special handling when special tokens are kept
        # and the request asked for no spaces between them.
        self.spaces_between_special_tokens = (
            sampling_params.skip_special_tokens
            or sampling_params.spaces_between_special_tokens
        )

        if not self.spaces_between_special_tokens:
            # Store dict of added token ids so that we can suppress
            # the spaces between them. The mapping is cached as an attribute
            # on the tokenizer object so that it is built only once.
            if (
                added_token_ids := getattr(self.tokenizer, "added_token_ids", None)
            ) is None:
                self.tokenizer.added_token_ids = added_token_ids = {
                    tid: tok.content
                    for tid, tok in self.tokenizer.get_added_tokens_decoder().items()
                }

            if added_token_ids:
                self.last_special = False
                self.added_token_ids = added_token_ids
            else:
                # No added tokens.
                self.spaces_between_special_tokens = True

    def decode_next(self, next_token_id: int) -> str:
        """Advance the stream by one token and return the text it yields.

        Returns "" when the stream yields nothing for this token. When space
        suppression is active and both this token and the previous one are
        added tokens, the token's raw content is returned instead of the
        stream output.
        """
        token = self._protected_step(next_token_id)

        if not self.spaces_between_special_tokens:
            special_token = self.added_token_ids.get(next_token_id)
            is_special = special_token is not None
            if is_special and self.last_special:
                # Return raw token string without any prefixed spaces.
                token = special_token
            self.last_special = is_special

        return token or ""

    def _protected_step(self, next_token_id: int) -> str | None:
        """Step the decode stream, recovering from known failure modes.

        Returns the stream's output for the token, or ``None`` when the
        token id was rejected with ``OverflowError`` / ``TypeError``. Any
        exception whose message does not start with
        ``INVALID_PREFIX_ERR_MSG`` is re-raised unchanged.
        """
        try:
            token = self.stream.step(self.tokenizer, next_token_id)
        except (OverflowError, TypeError):
            # Handle rare observed overflow, still to be diagnosed.
            # See https://github.com/vllm-project/vllm/issues/21951.
            logger.exception("Encountered invalid token id: %r", next_token_id)
            token = None
        except Exception as e:
            if not str(e).startswith(INVALID_PREFIX_ERR_MSG):
                # Bare raise keeps the original traceback intact.
                raise
            # Recover from edge case where tokenizer can produce non-monotonic,
            # invalid UTF-8 output, which breaks the internal state of
            # tokenizers' DecodeStream.
            # See https://github.com/vllm-project/vllm/issues/17448.
            logger.warning(
                "Encountered invalid prefix detokenization error"
                " for request %s, resetting decode stream.",
                self.request_id,
            )
            self.stream = DecodeStream(skip_special_tokens=self.skip_special_tokens)
            token = self.stream.step(self.tokenizer, next_token_id)
        return token
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SlowIncrementalDetokenizer(BaseIncrementalDetokenizer):
    """Pure-Python incremental detokenizer built on detokenize_incrementally."""

    def __init__(self, tokenizer: AnyTokenizer, request: EngineCoreRequest):
        """Set up detokenization state, seeded from the request's prompt.

        ``request.sampling_params`` must not be ``None``.
        """
        super().__init__(request)

        self.tokenizer = tokenizer
        params = request.sampling_params
        assert params is not None

        self.prompt_len = length_from_prompt_token_ids_or_embeds(
            request.prompt_token_ids, request.prompt_embeds
        )

        # Metadata for incremental detokenization.
        if request.prompt_token_ids is not None:
            self.tokens, self.prefix_offset, self.read_offset = (
                convert_prompt_ids_to_tokens(
                    tokenizer=tokenizer,
                    prompt_ids=request.prompt_token_ids,
                    skip_special_tokens=params.skip_special_tokens,
                )
            )
        else:
            # Prompt embedding requests cannot be detokenized, in general.
            # Use empty-string placeholders, one per prompt position.
            self.tokens = [""] * self.prompt_len
            self.prefix_offset = 0
            # Must be `read_offset`: decode_next reads this attribute, so a
            # misspelt name here makes the first decode raise AttributeError.
            self.read_offset = 0

        # token_ids holds prompt ids followed by generated ids; prompts given
        # as embeddings are represented by zero placeholders.
        self.token_ids.extend(request.prompt_token_ids or [0] * self.prompt_len)

        self.skip_special_tokens = params.skip_special_tokens
        self.spaces_between_special_tokens = params.spaces_between_special_tokens

    @property
    def output_token_ids(self) -> list[int]:
        """Return only the generated ids, i.e. ``token_ids`` minus the prompt."""
        return (
            self.token_ids
            if not self.prompt_len
            else (self.token_ids[self.prompt_len :])
        )

    def decode_next(self, next_token_id: int) -> str:
        """Decode the newest token and return the text it contributes.

        ``next_token_id`` is not used directly: the helper is given all of
        ``self.token_ids``, to which the caller has already appended the id.
        """
        new_tokens, decoded_text, prefix_offset, read_offset = detokenize_incrementally(
            tokenizer=self.tokenizer,
            all_input_ids=self.token_ids,
            prev_tokens=self.tokens,
            prefix_offset=self.prefix_offset,
            read_offset=self.read_offset,
            skip_special_tokens=self.skip_special_tokens,
            spaces_between_special_tokens=self.spaces_between_special_tokens,
        )

        # Carry the helper's updated state over to the next call.
        self.tokens.extend(new_tokens)
        self.prefix_offset = prefix_offset
        self.read_offset = read_offset

        return decoded_text
|
|||
|
|
|
|||
|
|
|
|||
|
|
def check_stop_strings(
|
|||
|
|
output_text: str,
|
|||
|
|
new_char_count: int,
|
|||
|
|
stop: list[str],
|
|||
|
|
include_in_output: bool,
|
|||
|
|
) -> tuple[str, int] | None:
|
|||
|
|
"""Check if any stop strings are matched and truncate sequence
|
|||
|
|
output text accordingly.
|
|||
|
|
|
|||
|
|
Returns tuple (stop_string, offset) if matched or else None.
|
|||
|
|
|
|||
|
|
Where stop_string is the matched stop string and offset is the
|
|||
|
|
length to which output_text should be truncated, or -1 for no
|
|||
|
|
truncation.
|
|||
|
|
"""
|
|||
|
|
if not new_char_count or not stop:
|
|||
|
|
return None
|
|||
|
|
|
|||
|
|
for stop_str in stop:
|
|||
|
|
stop_string_len = len(stop_str)
|
|||
|
|
# Avoid searching already-searched text.
|
|||
|
|
stop_index = output_text.find(stop_str, 1 - new_char_count - stop_string_len)
|
|||
|
|
if stop_index == -1:
|
|||
|
|
continue
|
|||
|
|
|
|||
|
|
if include_in_output:
|
|||
|
|
# Truncate to end of stop string.
|
|||
|
|
stop_index += stop_string_len
|
|||
|
|
if stop_index >= len(output_text):
|
|||
|
|
# No truncation required.
|
|||
|
|
return stop_str, -1
|
|||
|
|
|
|||
|
|
# Truncate the output text to either the beginning
|
|||
|
|
# or end of the stop string.
|
|||
|
|
return stop_str, stop_index
|
|||
|
|
return None
|