Remove sampling info events and overlap thread file (#11300)

This commit is contained in:
Liangsheng Yin
2025-10-07 21:34:25 +08:00
committed by GitHub
parent 79d3495177
commit 501dfa6b42
9 changed files with 13 additions and 393 deletions

View File

@@ -1012,22 +1012,9 @@ class Scheduler(
result = self.run_batch(batch)
self.result_queue.append((batch.copy(), result))
if self.last_batch is None:
# Create a dummy first batch to start the pipeline for overlap schedule.
# It is now used for triggering the sampling_info_done event.
tmp_batch = ScheduleBatch(
reqs=None,
forward_mode=ForwardMode.DUMMY_FIRST,
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
)
self.process_batch_result(tmp_batch, None)
if self.last_batch:
# Process the results of the last batch
tmp_batch, tmp_result = self.result_queue.popleft()
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
# When the server is idle, do self-check and re-init some states
@@ -2100,7 +2087,7 @@ class Scheduler(
self.record_batch_in_overlap(model_worker_batch)
# Sampling info will be modified during forward
model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
model_worker_batch.sampling_info = (
model_worker_batch.sampling_info.copy_for_forward()
)
@@ -2219,9 +2206,6 @@ class Scheduler(
if self.enable_overlap:
if result.copy_done is not None:
result.copy_done.synchronize()
self.set_next_batch_sampling_info_done(batch)
elif batch.forward_mode.is_dummy_first():
self.set_next_batch_sampling_info_done(batch)
self.maybe_send_health_check_signal()
@@ -2431,13 +2415,6 @@ class Scheduler(
self._add_request_to_queue(req)
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
# Finalize the sampling info carried by `batch` and set its
# `sampling_info_done` flag.
#
# Args:
#     batch: a ScheduleBatch whose `next_batch_sampling_info` attribute is
#         read; a falsy value means there is nothing to finalize.
#
# NOTE(review): indentation was lost in this diff view, so the nesting of the
# statements below cannot be read from the text. Presumably the stream
# synchronize belongs to the grammar branch and the flag is set whenever
# sampling info is present -- confirm against the repository.
# NOTE(review): the hunk header (-2431,13 +2415,6) suggests this method is
# among the lines deleted by the commit -- confirm before reusing it.
def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
if batch.next_batch_sampling_info:
if batch.next_batch_sampling_info.grammars is not None:
# Grammars are attached: presumably this rebuilds the vocab mask used for
# constrained decoding -- verify against update_regex_vocab_mask().
batch.next_batch_sampling_info.update_regex_vocab_mask()
# Presumably blocks until work queued on the default stream has finished,
# so the updated state is visible to the consumer -- TODO confirm.
self.default_stream.synchronize()
# Presumably an Event-like object that another thread waits on -- confirm.
batch.next_batch_sampling_info.sampling_info_done.set()
def watchdog_thread(self):
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
self.watchdog_last_forward_ct = 0