Misc clean up; Remove the support of jump forward (#4032)

This commit is contained in:
Lianmin Zheng
2025-03-03 07:02:14 -08:00
committed by GitHub
parent 110e006673
commit 935cda944b
41 changed files with 396 additions and 426 deletions

View File

@@ -57,7 +57,6 @@ DETOKENIZER_MAX_STATES = int(os.environ.get("SGLANG_DETOKENIZER_MAX_STATES", 1 <
class DecodeStatus:
"""Store the status of incremental decoding."""
vid: int
decoded_text: str
decode_ids: List[int]
surr_offset: int
@@ -143,10 +142,8 @@ class DetokenizerManager:
read_ids, surr_ids = [], []
for i in range(bs):
rid = recv_obj.rids[i]
vid = recv_obj.vids[i]
if rid not in self.decode_status or self.decode_status[rid].vid != vid:
if rid not in self.decode_status:
s = DecodeStatus(
vid=vid,
decoded_text=recv_obj.decoded_texts[i],
decode_ids=recv_obj.decode_ids[i],
surr_offset=0,

View File

@@ -376,8 +376,6 @@ class BatchTokenIDOut:
# The finish reason
finished_reasons: List[BaseFinishReason]
# For incremental decoding
# The version id to sync decode status with in detokenizer_manager
vids: List[int]
decoded_texts: List[str]
decode_ids: List[int]
read_offsets: List[int]

View File

@@ -296,7 +296,6 @@ class Req:
# 1: surr_offset
# 2: read_offset
# 3: last token
self.vid = 0 # version id to sync decode status with in detokenizer_manager
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
self.read_offset = None
self.decoded_text = ""
@@ -357,11 +356,6 @@ class Req:
) = None
self.hidden_states = []
# Logprobs (internal values)
# The tokens is prefilled but need to be considered as decode tokens
# and should be updated for the decode logprobs
self.last_update_decode_tokens = 0
# Embedding (return values)
self.embedding = None
@@ -500,68 +494,6 @@ class Req:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
return
def jump_forward_and_retokenize(self, jump_forward_str, next_state):
if self.origin_input_text is None:
# Recovering text can only use unpadded ids
self.origin_input_text = self.tokenizer.decode(
self.origin_input_ids_unpadded
)
all_text = self.origin_input_text + self.decoded_text + jump_forward_str
all_ids = self.tokenizer.encode(all_text)
if not all_ids:
logger.warning("Encoded all_text resulted in empty all_ids")
return False
prompt_tokens = len(self.origin_input_ids_unpadded)
if prompt_tokens > len(all_ids):
logger.warning("prompt_tokens is larger than encoded all_ids")
return False
if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
# TODO(lsyin): fix token fusion
logger.warning(
"Token fusion between input and output, try to avoid this by removing the space at the end of the input."
)
return False
old_output_ids = self.output_ids
self.output_ids = all_ids[prompt_tokens:]
self.decoded_text = self.decoded_text + jump_forward_str
self.surr_offset = prompt_tokens
self.read_offset = len(all_ids)
# NOTE: A trick to reduce the surrouding tokens decoding overhead
for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
surr_text_ = self.tokenizer.decode(
all_ids[self.read_offset - i : self.read_offset]
)
if not surr_text_.endswith("<EFBFBD>"):
self.surr_offset = self.read_offset - i
break
# update the inner state of the grammar
self.grammar.jump_and_retokenize(old_output_ids, self.output_ids, next_state)
if self.return_logprob:
# For fast-forward part's logprobs
k = 0
for i, old_id in enumerate(old_output_ids):
if old_id == self.output_ids[i]:
k = k + 1
else:
break
self.output_token_logprobs_val = self.output_token_logprobs_val[:k]
self.output_token_logprobs_idx = self.output_token_logprobs_idx[:k]
self.output_top_logprobs_val = self.output_top_logprobs_val[:k]
self.output_top_logprobs_idx = self.output_top_logprobs_idx[:k]
self.output_token_ids_logprobs_val = self.output_token_ids_logprobs_val[:k]
self.output_token_ids_logprobs_idx = self.output_token_ids_logprobs_idx[:k]
self.logprob_start_len = prompt_tokens + k
self.last_update_decode_tokens = len(self.output_ids) - k
return True
def reset_for_retract(self):
self.prefix_indices = []
self.last_node = None
@@ -574,8 +506,6 @@ class Req:
self.is_chunked = 0
self.req_pool_idx = None
self.last_update_decode_tokens = 0
def __repr__(self):
return (
f"Req(rid={self.rid}, "
@@ -672,7 +602,6 @@ class ScheduleBatch:
enable_overlap: bool,
spec_algorithm: SpeculativeAlgorithm,
enable_custom_logit_processor: bool,
return_hidden_states: bool = False,
):
return cls(
reqs=reqs,
@@ -687,7 +616,7 @@ class ScheduleBatch:
device=req_to_token_pool.device,
spec_algorithm=spec_algorithm,
enable_custom_logit_processor=enable_custom_logit_processor,
return_hidden_states=return_hidden_states,
return_hidden_states=any(req.return_hidden_states for req in reqs),
)
def batch_size(self):
@@ -1091,59 +1020,6 @@ class ScheduleBatch:
return retracted_reqs, new_estimate_ratio
def check_for_jump_forward(self, pad_input_ids_func):
jump_forward_reqs = []
keep_indices = set(i for i in range(len(self.reqs)))
for i, req in enumerate(self.reqs):
if req.grammar is not None:
jump_helper = req.grammar.try_jump_forward(req.tokenizer)
if jump_helper:
suffix_ids, _ = jump_helper
# Current ids, for cache and revert
cur_all_ids = tuple(req.origin_input_ids + req.output_ids)[:-1]
cur_output_ids = req.output_ids
req.output_ids.extend(suffix_ids)
decode_res, new_text = req.get_next_inc_detokenization()
if not decode_res:
req.output_ids = cur_output_ids
continue
(
jump_forward_str,
next_state,
) = req.grammar.jump_forward_str_state(jump_helper)
# Make the incrementally decoded text part of jump_forward_str
# so that the UTF-8 will not corrupt
jump_forward_str = new_text + jump_forward_str
if not req.jump_forward_and_retokenize(
jump_forward_str, next_state
):
req.output_ids = cur_output_ids
continue
# The decode status has diverged from detokenizer_manager
req.vid += 1
# insert the old request into tree_cache
self.tree_cache.cache_finished_req(req, cur_all_ids)
# re-applying image padding
if req.image_inputs is not None:
req.origin_input_ids = pad_input_ids_func(
req.origin_input_ids_unpadded, req.image_inputs
)
jump_forward_reqs.append(req)
keep_indices.remove(i)
self.filter_batch(keep_indices=list(keep_indices))
return jump_forward_reqs
def prepare_encoder_info_decode(self):
# Reset the encoder cached status
self.encoder_cached = [True] * len(self.reqs)

View File

@@ -150,7 +150,6 @@ class Scheduler:
self.tp_rank = tp_rank
self.tp_size = server_args.tp_size
self.schedule_policy = server_args.schedule_policy
self.disable_jump_forward = server_args.disable_jump_forward
self.lora_paths = server_args.lora_paths
self.max_loras_per_batch = server_args.max_loras_per_batch
self.enable_overlap = not server_args.disable_overlap_schedule
@@ -251,9 +250,6 @@ class Scheduler:
self.enable_overlap = False
logger.info("Overlap scheduler is disabled for multimodal models.")
if self.enable_overlap:
self.disable_jump_forward = True
# Launch a tensor parallel worker
if self.enable_overlap:
TpWorkerClass = TpModelWorkerClient
@@ -1024,11 +1020,8 @@ class Scheduler:
if self.running_batch is not None
else set([])
)
return_hidden_states = False
# Get requests from the waiting queue to a new prefill batch
for req in self.waiting_queue:
if req.return_hidden_states:
return_hidden_states = True
if (
self.lora_paths
and len(
@@ -1114,7 +1107,6 @@ class Scheduler:
self.enable_overlap,
self.spec_algorithm,
self.server_args.enable_custom_logit_processor,
return_hidden_states,
)
new_batch.prepare_for_extend()
@@ -1168,14 +1160,6 @@ class Scheduler:
self.min_new_token_ratio,
)
# Check for jump-forward
if not self.disable_jump_forward and batch.has_grammar:
jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func)
self._extend_requests_to_queue(jump_forward_reqs)
if batch.is_empty():
self.batch_is_full = False
return None
if batch.batch_size() < initial_bs:
self.batch_is_full = False
@@ -1530,8 +1514,6 @@ class Scheduler:
prefill (e.g., computing input token logprobs).
"""
assert output.input_token_logprobs is not None
# It is for jump decoding that will be deprecated.
assert req.last_update_decode_tokens == 0
if req.input_token_logprobs is None:
req.input_token_logprobs = []
if req.temp_input_top_logprobs_val is None:
@@ -1658,50 +1640,12 @@ class Scheduler:
self.add_input_logprob_return_values(
i, req, output, pt, num_input_logprobs, last_prefill_chunk=True
)
if req.last_update_decode_tokens != 0:
# Some decode tokens are re-computed in an extend batch
req.output_token_logprobs_val.extend(
output.input_token_logprobs[
pt
+ num_input_logprobs
- 1
- req.last_update_decode_tokens : pt
+ num_input_logprobs
- 1
],
)
req.output_token_logprobs_idx.extend(
req.fill_ids[
len(req.fill_ids)
- req.last_update_decode_tokens : len(req.fill_ids)
]
)
if req.top_logprobs_num > 0:
if req.last_update_decode_tokens != 0:
req.output_top_logprobs_val.extend(
output.input_top_logprobs_val[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs_idx.extend(
output.input_top_logprobs_idx[i][-req.last_update_decode_tokens :]
)
req.output_top_logprobs_val.append(output.next_token_top_logprobs_val[i])
req.output_top_logprobs_idx.append(output.next_token_top_logprobs_idx[i])
if req.token_ids_logprob is not None:
if req.last_update_decode_tokens != 0:
req.output_token_ids_logprobs_val.extend(
output.input_token_ids_logprobs_val[i][
-req.last_update_decode_tokens :
]
)
req.output_token_ids_logprobs_idx.extend(
output.input_token_ids_logprobs_idx[i][
-req.last_update_decode_tokens :
]
)
req.output_token_ids_logprobs_val.append(
output.next_token_token_ids_logprobs_val[i]
)
@@ -1719,7 +1663,6 @@ class Scheduler:
finished_reasons: List[BaseFinishReason] = []
if self.is_generation:
vids = []
decoded_texts = []
decode_ids_list = []
read_offsets = []
@@ -1786,7 +1729,6 @@ class Scheduler:
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None
)
vids.append(req.vid)
decoded_texts.append(req.decoded_text)
decode_ids, read_offset = req.init_incremental_detokenize()
decode_ids_list.append(decode_ids)
@@ -1842,7 +1784,6 @@ class Scheduler:
BatchTokenIDOut(
rids,
finished_reasons,
vids,
decoded_texts,
decode_ids_list,
read_offsets,