Improve overlap scheduling (#5788)

This commit is contained in:
Liangsheng Yin
2025-04-28 11:19:16 +08:00
committed by GitHub
parent f0365820e8
commit 40d9b8acce
6 changed files with 61 additions and 23 deletions

View File

@@ -645,6 +645,7 @@ class Scheduler(
self.cur_batch = batch
if batch:
batch.launch_done = threading.Event()
result = self.run_batch(batch)
self.result_queue.append((batch.copy(), result))
@@ -656,7 +657,7 @@ class Scheduler(
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.process_batch_result(tmp_batch, None)
self.process_batch_result(tmp_batch, None, batch.launch_done)
if self.last_batch:
# Process the results of the last batch
@@ -664,7 +665,10 @@ class Scheduler(
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
self.process_batch_result(tmp_batch, tmp_result)
# NOTE: we should use current launched batch's launch_done event Instead of the last batch's
self.process_batch_result(
tmp_batch, tmp_result, batch.launch_done if batch else None
)
elif batch is None:
# When the server is idle, do self-check and re-init some states
self.check_memory()
@@ -1417,14 +1421,15 @@ class Scheduler(
self,
batch: ScheduleBatch,
result: Union[GenerationBatchResult, EmbeddingBatchResult],
launch_done: Optional[threading.Event] = None,
):
if batch.forward_mode.is_decode():
self.process_batch_result_decode(batch, result)
self.process_batch_result_decode(batch, result, launch_done)
elif batch.forward_mode.is_extend():
self.process_batch_result_prefill(batch, result)
self.process_batch_result_prefill(batch, result, launch_done)
elif batch.forward_mode.is_idle():
if self.enable_overlap:
self.tp_worker.resolve_batch_result(result.bid)
self.tp_worker.resolve_last_batch_result(launch_done)
if batch.next_batch_sampling_info:
batch.next_batch_sampling_info.update_regex_vocab_mask()
self.current_stream.synchronize()