Remove sampling info events and overlap thread file (#11300)
This commit is contained in:
@@ -1012,22 +1012,9 @@ class Scheduler(
|
||||
result = self.run_batch(batch)
|
||||
self.result_queue.append((batch.copy(), result))
|
||||
|
||||
if self.last_batch is None:
|
||||
# Create a dummy first batch to start the pipeline for overlap schedule.
|
||||
# It is now used for triggering the sampling_info_done event.
|
||||
tmp_batch = ScheduleBatch(
|
||||
reqs=None,
|
||||
forward_mode=ForwardMode.DUMMY_FIRST,
|
||||
next_batch_sampling_info=self.tp_worker.cur_sampling_info,
|
||||
)
|
||||
self.process_batch_result(tmp_batch, None)
|
||||
|
||||
if self.last_batch:
|
||||
# Process the results of the last batch
|
||||
tmp_batch, tmp_result = self.result_queue.popleft()
|
||||
tmp_batch.next_batch_sampling_info = (
|
||||
self.tp_worker.cur_sampling_info if batch else None
|
||||
)
|
||||
self.process_batch_result(tmp_batch, tmp_result)
|
||||
elif batch is None:
|
||||
# When the server is idle, do self-check and re-init some states
|
||||
@@ -2100,7 +2087,7 @@ class Scheduler(
|
||||
self.record_batch_in_overlap(model_worker_batch)
|
||||
|
||||
# Sampling info will be modified during forward
|
||||
model_worker_batch.sampling_info = self.tp_worker.cur_sampling_info = (
|
||||
model_worker_batch.sampling_info = (
|
||||
model_worker_batch.sampling_info.copy_for_forward()
|
||||
)
|
||||
|
||||
@@ -2219,9 +2206,6 @@ class Scheduler(
|
||||
if self.enable_overlap:
|
||||
if result.copy_done is not None:
|
||||
result.copy_done.synchronize()
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
elif batch.forward_mode.is_dummy_first():
|
||||
self.set_next_batch_sampling_info_done(batch)
|
||||
|
||||
self.maybe_send_health_check_signal()
|
||||
|
||||
@@ -2431,13 +2415,6 @@ class Scheduler(
|
||||
self._add_request_to_queue(req)
|
||||
self.grammar_queue = self.grammar_queue[num_ready_reqs:]
|
||||
|
||||
def set_next_batch_sampling_info_done(self, batch: ScheduleBatch):
|
||||
if batch.next_batch_sampling_info:
|
||||
if batch.next_batch_sampling_info.grammars is not None:
|
||||
batch.next_batch_sampling_info.update_regex_vocab_mask()
|
||||
self.default_stream.synchronize()
|
||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||
|
||||
def watchdog_thread(self):
|
||||
"""A watch dog thread that will try to kill the server itself if one forward batch takes too long."""
|
||||
self.watchdog_last_forward_ct = 0
|
||||
|
||||
Reference in New Issue
Block a user