Support incremental streaming of logprob/token_ids between scheduler and detokenizer (#6225)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
This commit is contained in:
Lianmin Zheng
2025-05-12 14:33:38 -07:00
committed by GitHub
parent f1c896007a
commit d18c6b3358
9 changed files with 257 additions and 86 deletions

View File

@@ -530,10 +530,6 @@ class Scheduler(
)
def init_metrics(self):
# The largest prefill length of a single request
self._largest_prefill_len: int = 0
# The largest context length (prefill + generation) of a single request
self._largest_prefill_decode_len: int = 0
self.last_gen_throughput: float = 0.0
self.last_input_throughput: float = 0.0
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
@@ -1122,9 +1118,6 @@ class Scheduler(
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
self._largest_prefill_len = max(
self._largest_prefill_len, adder.log_input_tokens
)
num_new_seq = len(can_run_list)
f = (
@@ -1601,14 +1594,9 @@ class Scheduler(
elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_last_batch_result(launch_done)
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.set_next_batch_sampling_info_done(batch)
elif batch.forward_mode.is_dummy_first():
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()
batch.next_batch_sampling_info.sampling_info_done.set()
self.set_next_batch_sampling_info_done(batch)
if self.return_health_check_ct:
# Return some signal for the health check.
@@ -1776,6 +1764,13 @@ class Scheduler(
self._extend_requests_to_queue(self.grammar_queue[:num_ready_reqs])
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
    """Mark the next batch's sampling info as ready for the sampler.

    If the batch carries grammar-constrained requests, first rebuild the
    regex vocab mask and wait for the current device stream to finish so
    the mask is fully written before consumers see the done event.
    No-op when the batch has no pending sampling info.
    """
    sampling_info = batch.next_batch_sampling_info
    if not sampling_info:
        return
    if sampling_info.grammars is not None:
        # The mask update launches device work; synchronize so it is
        # visible before signaling completion.
        sampling_info.update_regex_vocab_mask()
        self.current_stream.synchronize()
    sampling_info.sampling_info_done.set()
def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
self.watchdog_last_forward_ct = 0