Improve overlap scheduling (#5788)
This commit is contained in:
@@ -645,6 +645,7 @@ class Scheduler(
|
||||
self.cur_batch = batch
|
||||
|
||||
if batch:
|
||||
batch.launch_done = threading.Event()
|
||||
result = self.run_batch(batch)
|
||||
self.result_queue.append((batch.copy(), result))
|
||||
|
||||
@@ -656,7 +657,7 @@ class Scheduler(
|
||||
forward_mode=ForwardMode.DUMMY_FIRST,
|
||||
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
||||
)
|
||||
self.process_batch_result(tmp_batch, None)
|
||||
self.process_batch_result(tmp_batch, None, batch.launch_done)
|
||||
|
||||
if self.last_batch:
|
||||
# Process the results of the last batch
|
||||
@@ -664,7 +665,10 @@ class Scheduler(
|
||||
tmp_batch.next_batch_sampling_info = (
|
||||
self.tp_worker.cur_sampling_info if batch else None
|
||||
)
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
# NOTE: we should use current launched batch's launch_done event Instead of the last batch's
|
||||
self.process_batch_result(
|
||||
tmp_batch, tmp_result, batch.launch_done if batch else None
|
||||
)
|
||||
elif batch is None:
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
self.check_memory()
|
||||
@@ -1417,14 +1421,15 @@ class Scheduler(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||
launch_done: Optional[threading.Event] = None,
|
||||
):
|
||||
if batch.forward_mode.is_decode():
|
||||
self.process_batch_result_decode(batch, result)
|
||||
self.process_batch_result_decode(batch, result, launch_done)
|
||||
elif batch.forward_mode.is_extend():
|
||||
self.process_batch_result_prefill(batch, result)
|
||||
self.process_batch_result_prefill(batch, result, launch_done)
|
||||
elif batch.forward_mode.is_idle():
|
||||
if self.enable_overlap:
|
||||
self.tp_worker.resolve_batch_result(result.bid)
|
||||
self.tp_worker.resolve_last_batch_result(launch_done)
|
||||
if batch.next_batch_sampling_info:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
self.current_stream.synchronize()
|
||||
|
||||
Reference in New Issue
Block a user