84 lines
3.5 KiB
Python
84 lines
3.5 KiB
Python
|
|
|
|
from vllm.sampling_params import SamplingParams
|
|
from vllm.sequence import VLLM_INVALID_TOKEN_ID
|
|
from vllm.transformers_utils.detokenizer import Detokenizer
|
|
from vllm.transformers_utils.detokenizer_utils import convert_prompt_ids_to_tokens, detokenize_incrementally
|
|
from vllm.zero_overhead.sequence import ZeroOverheadSequence
|
|
|
|
|
|
class ZeroOverheadDetokenizer(Detokenizer):
    """``Detokenizer`` subclass for ``ZeroOverheadSequence`` objects.

    Overrides ``decode_sequence_inplace`` so that only the first
    ``seq.get_prompt_len() + seq.effective_output_len`` token ids of the
    sequence take part in incremental detokenization.
    """

    def __init__(self, tokenizer_group):
        # No extra state of its own; everything is set up by the base class.
        super().__init__(tokenizer_group)

    def decode_sequence_inplace(self, seq: ZeroOverheadSequence,
                                prms: SamplingParams) -> int:
        """Incrementally detokenize the newest effective token of ``seq``.

        The sequence's token ids are cut down to
        ``seq.get_prompt_len() + seq.effective_output_len`` entries, and the
        last id of that prefix is treated as the token produced this step.

        Side effects on ``seq``:
          * ``tokens`` / ``prefix_offset`` / ``read_offset`` are seeded from
            the prompt on first use, then advanced by this step's tokens.
          * ``output_text`` has the newly decoded text appended.
          * In ``seq.output_logprobs[-1]``, the sampled token's
            ``decoded_token`` is set to the newly decoded text; every other
            entry whose ``decoded_token`` is ``None`` (and whose id is not
            ``VLLM_INVALID_TOKEN_ID``) gets its own decoded text.

        Args:
            seq: The sequence to decode; mutated in place.
            prms: Sampling parameters; ``skip_special_tokens`` and
                ``spaces_between_special_tokens`` are forwarded to the
                detokenization helpers.

        Returns:
            The number of characters appended to ``seq.output_text``.
        """
        # Only the effective prefix of the token ids is decoded.
        # NOTE(review): presumably ids past effective_output_len are entries
        # not yet committed by zero-overhead scheduling — confirm against
        # ZeroOverheadSequence.
        effective_len = seq.get_prompt_len() + seq.effective_output_len
        token_ids = seq.get_token_ids()[:effective_len]
        newest_id = token_ids[-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        # Seed the incremental state from the prompt once, so the per-logprob
        # decoding below can reuse it instead of recomputing it each time.
        if seq.tokens is None:
            seq.tokens, seq.prefix_offset, seq.read_offset = (
                convert_prompt_ids_to_tokens(
                    tokenizer=tokenizer,
                    prompt_ids=token_ids[:-1],
                    skip_special_tokens=prms.skip_special_tokens,
                ))

        def detokenize(ids):
            # Decode ``ids`` against the sequence's *current* incremental
            # state; that state is only advanced at the end of the method.
            return detokenize_incrementally(
                tokenizer=tokenizer,
                all_input_ids=ids,
                prev_tokens=seq.tokens,
                prefix_offset=seq.prefix_offset,
                read_offset=seq.read_offset,
                skip_special_tokens=prms.skip_special_tokens,
                spaces_between_special_tokens=(
                    prms.spaces_between_special_tokens),
            )

        (fresh_tokens, fresh_text, next_prefix,
         next_read) = detokenize(token_ids)

        # Attach decoded text to the candidate tokens of the latest step.
        # NOTE(review): this indexes output_logprobs[-1] rather than a
        # position derived from effective_output_len — confirm the two stay
        # aligned for zero-overhead sequences.
        step_logprobs = seq.output_logprobs[-1]
        if step_logprobs:
            context_ids = token_ids[:-1]
            for candidate_id, candidate in step_logprobs.items():
                if candidate_id == newest_id:
                    # The sampled token: reuse the text decoded above.
                    candidate.decoded_token = fresh_text
                elif (candidate.decoded_token is None
                      and candidate_id != VLLM_INVALID_TOKEN_ID):
                    _, candidate_text, _, _ = detokenize(
                        context_ids + [candidate_id])
                    candidate.decoded_token = candidate_text

        # Commit the new incremental state only after the logprob decoding,
        # which must observe the state from before this step.
        seq.tokens.extend(fresh_tokens)
        seq.prefix_offset = next_prefix
        seq.read_offset = next_read
        seq.output_text += fresh_text

        return len(fresh_text)