init src 0.9.2
84  vllm/zero_overhead/tokenizer.py  Normal file
@@ -0,0 +1,84 @@
from vllm.sampling_params import SamplingParams
from vllm.sequence import VLLM_INVALID_TOKEN_ID
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.detokenizer_utils import (
    convert_prompt_ids_to_tokens, detokenize_incrementally)
from vllm.zero_overhead.sequence import ZeroOverheadSequence


class ZeroOverheadDetokenizer(Detokenizer):

    def __init__(self, tokenizer_group):
        super().__init__(tokenizer_group)

    def decode_sequence_inplace(self, seq: ZeroOverheadSequence,
                                prms: SamplingParams) -> int:
        """Decodes the new token for a sequence. In-place operation.

        Args:
            seq: The sequence to decode.
            prms: The sampling parameters used to generate the sequence.

        Returns:
            The number of characters added to the output text.
        """
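        # NOTE (assumption): under zero-overhead (asynchronous) scheduling
        # the sequence's token id list can run ahead of what has actually
        # been produced; effective_output_len presumably marks the validated
        # output length, so the slice below drops any not-yet-valid ids.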
        eff_length = seq.get_prompt_len() + seq.effective_output_len
        all_input_ids = seq.get_token_ids()[:eff_length]

        token_id_generated_this_iteration = all_input_ids[-1]
        tokenizer = self.get_tokenizer_for_seq(seq)

        # Convert prompt token IDs to tokens if necessary.
        # Do it here so that we don't have to repeat this
        # computation for each logprob.
        if seq.tokens is None:
            (seq.tokens, seq.prefix_offset,
             seq.read_offset) = convert_prompt_ids_to_tokens(
                 tokenizer=tokenizer,
                 prompt_ids=all_input_ids[:-1],
                 skip_special_tokens=prms.skip_special_tokens,
             )
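
        # Incrementally detokenize only the newly generated token, reusing
        # the cached tokens and offsets so earlier text is never re-decoded.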
        (new_tokens, new_decoded_token_text, prefix_offset,
         read_offset) = detokenize_incrementally(
             tokenizer=tokenizer,
             all_input_ids=all_input_ids,
             prev_tokens=seq.tokens,
             prefix_offset=seq.prefix_offset,
             read_offset=seq.read_offset,
             skip_special_tokens=prms.skip_special_tokens,
             spaces_between_special_tokens=prms.spaces_between_special_tokens,
         )

        # Decode logprobs
        logprobs = seq.output_logprobs[-1]
        if logprobs:
            previous_tokens = all_input_ids[:-1]
            for token_id, sample_logprob in logprobs.items():
                # If the token was generated this iteration,
                # use the provided text.
                if token_id == token_id_generated_this_iteration:
                    sample_logprob.decoded_token = new_decoded_token_text
                    continue
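
                # Otherwise this is an unsampled top-k candidate: decode it
                # as if it had been generated in place of the sampled token.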
                if (sample_logprob.decoded_token is None
                        and token_id != VLLM_INVALID_TOKEN_ID):
                    all_input_ids_with_logprob = previous_tokens + [token_id]
                    (_, new_text, _, _) = detokenize_incrementally(
                        tokenizer=tokenizer,
                        all_input_ids=all_input_ids_with_logprob,
                        prev_tokens=seq.tokens,
                        prefix_offset=seq.prefix_offset,
                        read_offset=seq.read_offset,
                        skip_special_tokens=prms.skip_special_tokens,
                        spaces_between_special_tokens=prms.
                        spaces_between_special_tokens,
                    )
                    sample_logprob.decoded_token = new_text
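
        # Commit the updated incremental-detokenization state so the next
        # call resumes from here.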
        seq.tokens.extend(new_tokens)
        seq.prefix_offset = prefix_offset
        seq.read_offset = read_offset
        seq.output_text += new_decoded_token_text

        return len(new_decoded_token_text)