Add a new event loop (#1677)

This commit is contained in:
Lianmin Zheng
2024-10-16 01:33:20 -07:00
committed by GitHub
parent a5114b6f91
commit 9116b2896f
9 changed files with 161 additions and 25 deletions

View File

@@ -736,6 +736,10 @@ class ScheduleBatch:
self.input_ids = self.output_ids
self.seq_lens.add_(1)
self.output_ids = None
if self.sampling_info.penalizer_orchestrator:
self.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
self.input_ids
)
# Alloc mem
bs = len(self.reqs)

View File

@@ -20,6 +20,7 @@ import logging
import os
import time
import warnings
from collections import deque
from types import SimpleNamespace
from typing import List, Optional, Union
@@ -192,9 +193,20 @@ class Scheduler:
self.tree_cache_metrics = {"total": 0, "hit": 0}
self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
if self.server_args.enable_overlap_schedule:
def cache_finished_req(req):
free_delta = int(self.running_batch and req in self.cur_batch.reqs)
self.tree_cache.cache_finished_req(req, free_delta=free_delta)
else:
cache_finished_req = self.tree_cache.cache_finished_req
self.cache_finished_req = cache_finished_req
# Init running status
self.waiting_queue: List[Req] = []
self.running_batch: Optional[ScheduleBatch] = None
self.cur_batch: Optional[ScheduleBatch] = None
self.decode_forward_ct = 0
self.stream_interval = server_args.stream_interval
self.num_generated_tokens = 0
@@ -279,6 +291,32 @@ class Scheduler:
self.last_batch = batch
@torch.inference_mode()
def event_loop_overlap(self):
result_queue = deque()
self.last_batch = None
self.running_batch = None
while True:
recv_reqs = self.recv_requests()
self.process_input_requests(recv_reqs)
batch = self.get_next_batch_to_run()
self.cur_batch = batch
if batch:
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))
if self.last_batch:
tmp_batch, tmp_result = result_queue.popleft()
self.process_batch_result(tmp_batch, tmp_result)
elif batch is None:
self.check_memory()
self.new_token_ratio = global_config.init_new_token_ratio
self.last_batch = batch
def recv_requests(self):
if self.tp_rank == 0:
recv_reqs = []
@@ -705,11 +743,6 @@ class Scheduler:
def process_batch_result_prefill(self, batch: ScheduleBatch, result):
if self.is_generation:
logits_output, next_token_ids = result
if batch.sampling_info.penalizer_orchestrator:
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
if batch.return_logprob:
# Move logprobs to cpu
if logits_output.next_token_logprobs is not None:
@@ -742,7 +775,7 @@ class Scheduler:
req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
self.cache_finished_req(req)
elif not batch.decoding_reqs or req not in batch.decoding_reqs:
self.tree_cache.cache_unfinished_req(req)
@@ -771,7 +804,7 @@ class Scheduler:
req.check_finished()
if req.finished():
self.tree_cache.cache_finished_req(req)
self.cache_finished_req(req)
else:
self.tree_cache.cache_unfinished_req(req)
@@ -779,10 +812,6 @@ class Scheduler:
def process_batch_result_decode(self, batch: ScheduleBatch, result):
logits_output, next_token_ids = result
if batch.sampling_info.penalizer_orchestrator:
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
self.num_generated_tokens += len(batch.reqs)
# Move logprobs to cpu
@@ -796,6 +825,9 @@ class Scheduler:
# Check finish condition
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if self.server_args.enable_overlap_schedule and req.finished():
continue
req.completion_tokens_wo_jump_forward += 1
req.output_ids.append(next_token_id)
req.check_finished()
@@ -806,7 +838,7 @@ class Scheduler:
)
if req.finished():
self.tree_cache.cache_finished_req(req)
self.cache_finished_req(req)
if req.return_logprob:
req.output_token_logprobs.append(
@@ -1027,7 +1059,7 @@ class Scheduler:
for req in self.running_batch.reqs:
if req.rid == recv_req.rid and not req.finished():
req.finished_reason = FINISH_ABORT()
self.tree_cache.cache_finished_req(req)
self.cache_finished_req(req)
break
def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1072,7 +1104,10 @@ def run_scheduler_process(
try:
scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
pipe_writer.send("ready")
scheduler.event_loop_normal()
if server_args.enable_overlap_schedule:
scheduler.event_loop_overlap()
else:
scheduler.event_loop_normal()
except Exception:
msg = get_exception_traceback()
logger.error(msg)