65 lines
2.8 KiB
Python
65 lines
2.8 KiB
Python
|
|
from typing import Union
|
|
from vllm.sequence import Sequence
|
|
from typing import Sequence as GenericSequence
|
|
|
|
|
|
class ZeroOverheadSequence(Sequence):
|
|
def __init__(self, seq_id, inputs, block_size, eos_token_id = None, lora_request = None, prompt_adapter_request = None):
|
|
super().__init__(seq_id, inputs, block_size, eos_token_id, lora_request, prompt_adapter_request)
|
|
self.effective_output_len : int = 0
|
|
|
|
def fix_last_token_id(self, token_id: int) -> None:
|
|
effect_offset = self.effective_output_len - len(self.data.output_token_ids)
|
|
if effect_offset < 0:
|
|
self.data._output_token_ids[effect_offset] = token_id
|
|
if len(self.data._new_appended_tokens) >= effect_offset * -1:
|
|
self.data._new_appended_tokens[effect_offset] = token_id
|
|
self.data._cached_all_token_ids[effect_offset] = token_id
|
|
self.effective_output_len += 1
|
|
|
|
def remove_last_place_holder(self, count):
|
|
self.data._output_token_ids = self.data._output_token_ids[:-1 * count]
|
|
self.data._new_appended_tokens = self.data._new_appended_tokens[:-1 * count]
|
|
self.data._cached_all_token_ids = self.data._cached_all_token_ids[:-1 * count]
|
|
self.data._num_computed_tokens -= count
|
|
|
|
def zero_overhead_get_output_token_ids(self) -> tuple[int, ...]:
|
|
return self.data.output_token_ids[:self.effective_output_len]
|
|
|
|
def zero_overhead_get_output_len(self) -> int:
|
|
return self.effective_output_len
|
|
|
|
def zero_overhead_get_last_token_id(self) -> int:
|
|
if self.effective_output_len == 0:
|
|
return self.data._prompt_token_ids[-1]
|
|
return self.data._output_token_ids[self.effective_output_len - 1]
|
|
|
|
def zero_overhead_get_len(self) -> int:
|
|
return self.effective_output_len + len(self.data._prompt_token_ids)
|
|
|
|
def get_output_token_ids_to_return(
|
|
self, delta: bool) -> Union[GenericSequence[int], int]:
|
|
"""If delta is True, only new tokens since the last call to
|
|
this method are returned"""
|
|
if not delta:
|
|
return self.zero_overhead_get_output_token_ids()
|
|
|
|
output_len = self.zero_overhead_get_output_len()
|
|
|
|
# Get the number of new tokens
|
|
num_new_tokens = output_len - self._last_output_token_ids_offset
|
|
self._last_output_token_ids_offset = output_len
|
|
|
|
# Return new tokens
|
|
if num_new_tokens == 1:
|
|
# Optimization for single decode token case
|
|
# (which is what we have most of the time)
|
|
return self.data._cached_all_token_ids[self.effective_output_len - 1]
|
|
|
|
if num_new_tokens == 0:
|
|
return []
|
|
|
|
effect_offset = self.effective_output_len - len(self.data.output_token_ids)
|
|
return self.data._cached_all_token_ids[-num_new_tokens : effect_offset]
|