[Fix] Fix major performance bug in certain cases (#1563)

Co-authored-by: hnyls2002 <hnyls2002@gmail.com>
This commit is contained in:
Ying Sheng
2024-10-04 01:51:11 -07:00
committed by GitHub
parent 2432ad40c6
commit 04b262cd91
5 changed files with 50 additions and 18 deletions

View File

@@ -222,7 +222,7 @@ class Scheduler:
)
self.new_token_ratio = self.min_new_token_ratio
self.new_token_ratio_decay = global_config.new_token_ratio_decay
self.do_not_get_new_batch = False
self.batch_is_full = False
def event_loop(self):
while True:
@@ -261,12 +261,10 @@ class Scheduler:
for recv_req in recv_reqs:
if isinstance(recv_req, TokenizedGenerateReqInput):
self.handle_generate_request(recv_req)
self.do_not_get_new_batch = False
elif isinstance(
recv_req, (TokenizedEmbeddingReqInput, TokenizedRewardReqInput)
):
self.handle_embedding_request(recv_req)
self.do_not_get_new_batch = False
elif isinstance(recv_req, FlushCacheReq):
self.flush_cache()
elif isinstance(recv_req, AbortReq):
@@ -279,11 +277,12 @@ class Scheduler:
@torch.inference_mode()
def forward_step(self):
if self.do_not_get_new_batch and self.current_inflight_req is None:
if (
self.batch_is_full or len(self.waiting_queue) == 0
) and self.current_inflight_req is None:
new_batch = None
else:
new_batch = self.get_new_prefill_batch()
self.do_not_get_new_batch = False
if new_batch is not None:
# Run a new prefill batch
@@ -447,6 +446,7 @@ class Scheduler:
len(self.running_batch.reqs) if self.running_batch is not None else 0
)
if running_bs >= self.max_running_requests:
self.batch_is_full = True
return None
# Get priority queue
@@ -490,9 +490,11 @@ class Scheduler:
)
> self.max_loras_per_batch
):
self.batch_is_full = True
break
if adder.no_remaining_tokens():
self.batch_is_full = True
break
req.init_next_round_input(None if prefix_computed else self.tree_cache)
res = adder.add_one_req(req)
@@ -500,6 +502,7 @@ class Scheduler:
not res
or running_bs + len(adder.can_run_list) >= self.max_running_requests
):
self.batch_is_full = True
break
can_run_list = adder.can_run_list
@@ -810,9 +813,6 @@ class Scheduler:
if req.top_logprobs_num > 0:
req.output_top_logprobs.append(logits_output.output_top_logprobs[i])
if not has_finished:
self.do_not_get_new_batch = True
self.handle_finished_requests(batch)
def handle_finished_requests(self, batch: ScheduleBatch):
@@ -833,6 +833,8 @@ class Scheduler:
for i, req in enumerate(batch.reqs):
if not req.finished() and req is not self.current_inflight_req:
unfinished_indices.append(i)
else:
self.batch_is_full = False
if req.finished() or (
req.stream