init src 0.9.2
This commit is contained in:
181
vllm/transformers_utils/detokenizer.py
Normal file
181
vllm/transformers_utils/detokenizer.py
Normal file
@@ -0,0 +1,181 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from vllm.sequence import (VLLM_INVALID_TOKEN_ID, Logprob, SamplingParams,
|
||||
Sequence, SequenceGroup)
|
||||
|
||||
from .detokenizer_utils import (convert_prompt_ids_to_tokens,
|
||||
detokenize_incrementally)
|
||||
from .tokenizer import AnyTokenizer
|
||||
from .tokenizer_group import TokenizerGroup
|
||||
|
||||
|
||||
class Detokenizer:
|
||||
"""Provides methods to decode the output of a model into text."""
|
||||
|
||||
def __init__(self, tokenizer_group: TokenizerGroup, mode="auto"):
|
||||
self.mode = mode
|
||||
if self.mode != "cpm":
|
||||
self.tokenizer_group = tokenizer_group
|
||||
else:
|
||||
self.tokenizer = tokenizer_group
|
||||
|
||||
def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
|
||||
"""Returns the HF tokenizer to use for a given sequence."""
|
||||
return self.tokenizer_group.get_lora_tokenizer(sequence.lora_request)
|
||||
|
||||
def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup,
|
||||
prompt_logprobs: list[Optional[dict[
|
||||
int, Logprob]]],
|
||||
position_offset: int) -> None:
|
||||
"""Decodes the logprobs for the prompt of a sequence group.
|
||||
|
||||
Args:
|
||||
seq_group: The sequence group to decode.
|
||||
prompt_logprobs: The logprobs to decode.
|
||||
position_offset: Offset of the first index of the logprobs
|
||||
relative to the start of the sequence (for chunked prefill).
|
||||
|
||||
Returns:
|
||||
The prompt logprobs with the decoded tokens.
|
||||
"""
|
||||
prms = seq_group.sampling_params
|
||||
assert prms is not None
|
||||
|
||||
# We can pick any sequence for the prompt.
|
||||
seq = seq_group.get_seqs()[0]
|
||||
# Only prompt, without the generated token.
|
||||
all_token_ids = seq.get_token_ids()
|
||||
prompt_token_ids = all_token_ids[:-1]
|
||||
if self.mode != "cpm":
|
||||
tokenizer = self.get_tokenizer_for_seq(seq)
|
||||
else:
|
||||
tokenizer = self.tokenizer
|
||||
prefix_offset = 0
|
||||
read_offset = 0
|
||||
next_iter_prefix_offset = 0
|
||||
next_iter_read_offset = 0
|
||||
next_iter_tokens: list[str] = []
|
||||
prev_tokens = None
|
||||
|
||||
for token_position_in_logprob, prompt_logprobs_for_token in enumerate(
|
||||
prompt_logprobs):
|
||||
|
||||
# Absolute token position equals the index in the logprobs
|
||||
# list plus the offset of the entire logprobs list relative
|
||||
# to the start of the sequence.
|
||||
token_position = token_position_in_logprob + position_offset
|
||||
if not prompt_logprobs_for_token:
|
||||
continue
|
||||
for token_id, sample_logprob in prompt_logprobs_for_token.items():
|
||||
if (sample_logprob.decoded_token is None
|
||||
and token_id != VLLM_INVALID_TOKEN_ID):
|
||||
prompt_token_ids_with_token = (
|
||||
prompt_token_ids[:token_position] + [token_id])
|
||||
(new_tokens, new_text, new_prefix_offset,
|
||||
new_read_offset) = detokenize_incrementally(
|
||||
tokenizer=tokenizer,
|
||||
all_input_ids=prompt_token_ids_with_token,
|
||||
prev_tokens=prev_tokens,
|
||||
prefix_offset=prefix_offset,
|
||||
read_offset=read_offset,
|
||||
skip_special_tokens=prms.skip_special_tokens,
|
||||
spaces_between_special_tokens=prms.
|
||||
spaces_between_special_tokens,
|
||||
mode=self.mode,
|
||||
)
|
||||
|
||||
sample_logprob.decoded_token = new_text
|
||||
|
||||
# Use the offsets & prev tokens corresponding to
|
||||
# real tokens to ensure detokenization is consistent
|
||||
# actual with prompt.
|
||||
if token_id == all_token_ids[token_position]:
|
||||
next_iter_prefix_offset = new_prefix_offset
|
||||
next_iter_read_offset = new_read_offset
|
||||
next_iter_tokens = new_tokens
|
||||
|
||||
# Advance to the next token position.
|
||||
prefix_offset = next_iter_prefix_offset
|
||||
read_offset = next_iter_read_offset
|
||||
if prev_tokens is None:
|
||||
prev_tokens = next_iter_tokens.copy()
|
||||
else:
|
||||
prev_tokens.extend(next_iter_tokens)
|
||||
|
||||
def decode_sequence_inplace(self, seq: Sequence,
|
||||
prms: SamplingParams) -> int:
|
||||
"""Decodes the new token for a sequence. In-place operation.
|
||||
|
||||
Args:
|
||||
seq: The sequence to decode.
|
||||
prms: The sampling parameters used to generate the sequence.
|
||||
|
||||
Returns:
|
||||
The number of characters added to the output text.
|
||||
"""
|
||||
all_input_ids = seq.get_token_ids()
|
||||
token_id_generated_this_iteration = all_input_ids[-1]
|
||||
if self.mode != "cpm":
|
||||
tokenizer = self.get_tokenizer_for_seq(seq)
|
||||
else:
|
||||
tokenizer = self.tokenizer
|
||||
|
||||
# Convert prompt token IDs to tokens if necessary.
|
||||
# Do it here so that we don't have to repeat this
|
||||
# computation for each logprob.
|
||||
if seq.tokens is None:
|
||||
(seq.tokens, seq.prefix_offset,
|
||||
seq.read_offset) = convert_prompt_ids_to_tokens(
|
||||
tokenizer=tokenizer,
|
||||
prompt_ids=all_input_ids[:-1],
|
||||
skip_special_tokens=prms.skip_special_tokens,
|
||||
)
|
||||
|
||||
(new_tokens, new_decoded_token_text, prefix_offset,
|
||||
read_offset) = detokenize_incrementally(
|
||||
tokenizer=tokenizer,
|
||||
all_input_ids=all_input_ids,
|
||||
prev_tokens=seq.tokens,
|
||||
prefix_offset=seq.prefix_offset,
|
||||
read_offset=seq.read_offset,
|
||||
skip_special_tokens=prms.skip_special_tokens,
|
||||
spaces_between_special_tokens=prms.spaces_between_special_tokens,
|
||||
mode=self.mode,
|
||||
)
|
||||
|
||||
# Decode logprobs
|
||||
logprobs = seq.output_logprobs[-1]
|
||||
if logprobs:
|
||||
previous_tokens = all_input_ids[:-1]
|
||||
for token_id, sample_logprob in logprobs.items():
|
||||
# If the token was generated this iteration,
|
||||
# use the provided text.
|
||||
if token_id == token_id_generated_this_iteration:
|
||||
sample_logprob.decoded_token = new_decoded_token_text
|
||||
continue
|
||||
|
||||
if (sample_logprob.decoded_token is None
|
||||
and token_id != VLLM_INVALID_TOKEN_ID):
|
||||
all_input_ids_with_logprob = previous_tokens + [token_id]
|
||||
(_, new_text, _, _) = detokenize_incrementally(
|
||||
tokenizer=tokenizer,
|
||||
all_input_ids=all_input_ids_with_logprob,
|
||||
prev_tokens=seq.tokens,
|
||||
prefix_offset=seq.prefix_offset,
|
||||
read_offset=seq.read_offset,
|
||||
skip_special_tokens=prms.skip_special_tokens,
|
||||
spaces_between_special_tokens=prms.
|
||||
spaces_between_special_tokens,
|
||||
mode=self.mode,
|
||||
)
|
||||
sample_logprob.decoded_token = new_text
|
||||
|
||||
seq.tokens.extend(new_tokens)
|
||||
seq.prefix_offset = prefix_offset
|
||||
seq.read_offset = read_offset
|
||||
seq.output_text += new_decoded_token_text
|
||||
|
||||
return len(new_decoded_token_text)
|
||||
Reference in New Issue
Block a user