Misc clean up; Remove the support of jump forward (#4032)
@@ -296,7 +296,6 @@ class Req:
         # 1: surr_offset
         # 2: read_offset
         # 3: last token
-        self.vid = 0  # version id to sync decode status with in detokenizer_manager
         self.surr_offset = None  # Surrounding offset to defeat the cleanup algorithm
         self.read_offset = None
         self.decoded_text = ""
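Reviewer note: the `surr_offset`/`read_offset` pair kept by this hunk drives incremental detokenization: each step decodes a small window of surrounding tokens so that multi-byte UTF-8 characters spanning token boundaries are never emitted half-finished. A minimal sketch of the pattern, assuming a HuggingFace-style `tokenizer` (illustrative only, not the repo's exact helper):

```python
REPLACEMENT_CHAR = "\ufffd"  # decoders emit U+FFFD for incomplete UTF-8

def incremental_decode(tokenizer, all_ids, surr_offset, read_offset):
    """Decode only the new tail, anchored on a window of surrounding tokens."""
    surr_text = tokenizer.decode(all_ids[surr_offset:read_offset])
    full_text = tokenizer.decode(all_ids[surr_offset:])
    if full_text.endswith(REPLACEMENT_CHAR):
        return False, ""  # partial character at the end; wait for more tokens
    return True, full_text[len(surr_text):]  # emit just the new increment
```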
@@ -357,11 +356,6 @@ class Req:
         ) = None
         self.hidden_states = []

-        # Logprobs (internal values)
-        # The tokens are prefilled but need to be considered as decode tokens
-        # and should be updated for the decode logprobs
-        self.last_update_decode_tokens = 0
-
         # Embedding (return values)
         self.embedding = None

@@ -500,68 +494,6 @@ class Req:
                 self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
                 return

-    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        if self.origin_input_text is None:
-            # Recovering text can only use unpadded ids
-            self.origin_input_text = self.tokenizer.decode(
-                self.origin_input_ids_unpadded
-            )
-
-        all_text = self.origin_input_text + self.decoded_text + jump_forward_str
-        all_ids = self.tokenizer.encode(all_text)
-        if not all_ids:
-            logger.warning("Encoded all_text resulted in empty all_ids")
-            return False
-
-        prompt_tokens = len(self.origin_input_ids_unpadded)
-        if prompt_tokens > len(all_ids):
-            logger.warning("prompt_tokens is larger than encoded all_ids")
-            return False
-
-        if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
-            # TODO(lsyin): fix token fusion
-            logger.warning(
-                "Token fusion between input and output; try to avoid this by removing the space at the end of the input."
-            )
-            return False
-
-        old_output_ids = self.output_ids
-        self.output_ids = all_ids[prompt_tokens:]
-        self.decoded_text = self.decoded_text + jump_forward_str
-        self.surr_offset = prompt_tokens
-        self.read_offset = len(all_ids)
-
-        # NOTE: A trick to reduce the overhead of decoding surrounding tokens
-        for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
-            surr_text_ = self.tokenizer.decode(
-                all_ids[self.read_offset - i : self.read_offset]
-            )
-            if not surr_text_.endswith("�"):
-                self.surr_offset = self.read_offset - i
-                break
-
-        # Update the inner state of the grammar
-        self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
-
-        if self.return_logprob:
-            # For the logprobs of the fast-forwarded part
-            k = 0
-            for i, old_id in enumerate(old_output_ids):
-                if old_id == self.output_ids[i]:
-                    k = k + 1
-                else:
-                    break
-            self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
-            self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
-            self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
-            self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
-            self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
-            self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
-            self.logprob_start_len = prompt_tokens + k
-            self.last_update_decode_tokens = len(self.output_ids) - k
-
-        return True
-
     def reset_for_retract(self):
         self.prefix_indices = []
         self.last_node = None
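Reviewer note: the heart of the removed method is its logprob accounting. After re-encoding `origin_input_text + decoded_text + jump_forward_str`, only the longest common prefix `k` of the old and new output ids keeps its logprobs; everything after `k` is marked for recomputation via `logprob_start_len` and `last_update_decode_tokens`. A self-contained sketch of that accounting, with hypothetical numbers:

```python
def truncate_logprobs(old_ids, new_ids, prompt_tokens):
    # Longest common prefix of old and re-tokenized output ids; logprobs up
    # to k are still valid, everything after k must be recomputed.
    k = 0
    for old_id, new_id in zip(old_ids, new_ids):
        if old_id != new_id:
            break
        k += 1
    logprob_start_len = prompt_tokens + k         # recompute from this position
    last_update_decode_tokens = len(new_ids) - k  # ids awaiting decode logprobs
    return k, logprob_start_len, last_update_decode_tokens

# Worked example: 10 prompt tokens; retokenization changed the 4th output id.
print(truncate_logprobs([5, 6, 7, 8], [5, 6, 7, 9, 3], prompt_tokens=10))
# -> (3, 13, 2)
```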
@@ -574,8 +506,6 @@ class Req:
         self.is_chunked = 0
         self.req_pool_idx = None

-        self.last_update_decode_tokens = 0
-
     def __repr__(self):
         return (
             f"Req(rid={self.rid}, "
@@ -672,7 +602,6 @@ class ScheduleBatch:
         enable_overlap: bool,
         spec_algorithm: SpeculativeAlgorithm,
         enable_custom_logit_processor: bool,
-        return_hidden_states: bool = False,
     ):
         return cls(
             reqs=reqs,
@@ -687,7 +616,7 @@ class ScheduleBatch:
             device=req_to_token_pool.device,
             spec_algorithm=spec_algorithm,
             enable_custom_logit_processor=enable_custom_logit_processor,
-            return_hidden_states=return_hidden_states,
+            return_hidden_states=any(req.return_hidden_states for req in reqs),
         )

     def batch_size(self):
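Reviewer note: with the constructor parameter gone, the batch-level flag is now derived from the requests themselves. A standalone illustration of the new semantics, using hypothetical `Req`/`Batch` stand-ins rather than the real classes:

```python
from dataclasses import dataclass
from typing import List

@dataclass
class Req:  # stand-in for the real Req
    return_hidden_states: bool = False

@dataclass
class Batch:  # stand-in for ScheduleBatch
    reqs: List[Req]
    return_hidden_states: bool = False

    @classmethod
    def init_new(cls, reqs: List[Req]) -> "Batch":
        # The batch returns hidden states iff any request asked for them.
        return cls(
            reqs=reqs,
            return_hidden_states=any(r.return_hidden_states for r in reqs),
        )

batch = Batch.init_new([Req(), Req(return_hidden_states=True)])
assert batch.return_hidden_states  # one opted-in request flips the batch flag
```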
@@ -1091,59 +1020,6 @@ class ScheduleBatch:

         return retracted_reqs, new_estimate_ratio

-    def check_for_jump_forward(self, pad_input_ids_func):
-        jump_forward_reqs = []
-        keep_indices = set(i for i in range(len(self.reqs)))
-
-        for i, req in enumerate(self.reqs):
-            if req.grammar is not None:
-                jump_helper = req.grammar.try_jump_forward(req.tokenizer)
-                if jump_helper:
-                    suffix_ids, _ = jump_helper
-
-                    # Current ids, for cache and revert
-                    cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
-                    cur_output_ids = req.output_ids
-
-                    req.output_ids.extend(suffix_ids)
-                    decode_res, new_text = req.get_next_inc_detokenization()
-                    if not decode_res:
-                        req.output_ids = cur_output_ids
-                        continue
-
-                    (
-                        jump_forward_str,
-                        next_state,
-                    ) = req.grammar.jump_forward_str_state(jump_helper)
-
-                    # Make the incrementally decoded text part of jump_forward_str
-                    # so that the UTF-8 will not be corrupted
-                    jump_forward_str = new_text + jump_forward_str
-                    if not req.jump_forward_and_retokenize(
-                        jump_forward_str, next_state
-                    ):
-                        req.output_ids = cur_output_ids
-                        continue
-
-                    # The decode status has diverged from detokenizer_manager
-                    req.vid += 1
-
-                    # Insert the old request into tree_cache
-                    self.tree_cache.cache_finished_req(req, cur_all_ids)
-
-                    # Re-apply image padding
-                    if req.image_inputs is not None:
-                        req.origin_input_ids = pad_input_ids_func(
-                            req.origin_input_ids_unpadded, req.image_inputs
-                        )
-
-                    jump_forward_reqs.append(req)
-                    keep_indices.remove(i)
-
-        self.filter_batch(keep_indices=list(keep_indices))
-
-        return jump_forward_reqs
-
     def prepare_encoder_info_decode(self):
         # Reset the encoder cached status
         self.encoder_cached = [True] * len(self.reqs)
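Reviewer note: for context on what was deleted, jump-forward let a grammar-constrained request skip token-by-token decoding whenever the grammar admitted exactly one continuation (e.g. fixed JSON punctuation and keys): the forced string was appended, the request retokenized, and the request pulled out of the running batch. A toy sketch of the core idea; the real grammar backends walk a compiled FSM instead of this set-of-strings stand-in:

```python
# Toy deterministic "grammar": the set of outputs it would accept in full.
ALLOWED = ['{"name": "alice"}', '{"name": "bob"}']

def forced_continuation(prefix: str) -> str:
    """Greedily extend `prefix` while every allowed string that starts with
    it agrees on the next character -- the essence of a jump-forward."""
    forced = ""
    while True:
        pos = len(prefix) + len(forced)
        nexts = {
            s[pos] for s in ALLOWED if s.startswith(prefix + forced) and len(s) > pos
        }
        if len(nexts) != 1:
            return forced  # ambiguous (or finished): hand back to the model
        forced += nexts.pop()

print(forced_continuation('{"'))           # -> 'name": "'
print(forced_continuation('{"name": "a'))  # -> 'lice"}'
```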