From 05bea6883c4b3f2fb7f01287cd8dccefeacd545f Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Sat, 7 Sep 2024 20:46:27 -0700
Subject: [PATCH] Fix some online scheduling delay (#1345)

---
 .../sglang/srt/managers/policy_scheduler.py | 77 +++++++++++--------
 python/sglang/srt/managers/tp_worker.py     | 11 ++-
 2 files changed, 52 insertions(+), 36 deletions(-)

diff --git a/python/sglang/srt/managers/policy_scheduler.py b/python/sglang/srt/managers/policy_scheduler.py
index 3a70bfe54..b58c0e7b3 100644
--- a/python/sglang/srt/managers/policy_scheduler.py
+++ b/python/sglang/srt/managers/policy_scheduler.py
@@ -119,6 +119,7 @@ class PrefillAdder:
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
+        self.rem_total_tokens_ = self.rem_total_tokens
         self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
@@ -153,11 +154,18 @@ class PrefillAdder:
                 for r in running_batch.reqs
             ]
         )
+        self.rem_total_tokens_ -= sum(
+            [
+                r.sampling_params.max_new_tokens - len(r.output_ids)
+                for r in running_batch.reqs
+            ]
+        )
 
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
+        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
@@ -231,43 +239,52 @@ class PrefillAdder:
 
             return None
 
-        if self.req_states is None:
-            self.req_states = []
-            if self.running_batch is not None:
-                for r in self.running_batch.reqs:
+        # Quick Check
+        can_run = False
+        if (
+            req.extend_input_len + req.sampling_params.max_new_tokens
+            <= self.rem_total_tokens
+        ):
+            can_run = True
+
+        if not can_run:
+            if self.req_states is None:
+                self.req_states = []
+                if self.running_batch is not None:
+                    for r in self.running_batch.reqs:
+                        state = get_req_state(r)
+                        if state is not None:
+                            self.req_states.append(state)
+                for r in self.can_run_list:
                     state = get_req_state(r)
                     if state is not None:
                         self.req_states.append(state)
-            for r in self.can_run_list:
-                state = get_req_state(r)
+                state = get_req_state(req)
                 if state is not None:
                     self.req_states.append(state)
-            state = get_req_state(req)
-            if state is not None:
-                self.req_states.append(state)
 
-            self.req_states.sort(key=lambda x: x[0])
-        else:
-            state = get_req_state(req)
-            if state is not None:
-                for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                    if tokens_left >= state[0]:
-                        self.req_states.insert(i, state)
-                        break
-                else:
-                    self.req_states.append(state)
+                self.req_states.sort(key=lambda x: x[0])
+            else:
+                state = get_req_state(req)
+                if state is not None:
+                    for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                        if tokens_left >= state[0]:
+                            self.req_states.insert(i, state)
+                            break
+                    else:
+                        self.req_states.append(state)
 
-        tokens_freed = 0
-        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-            decode_steps = (
-                self.req_states[i + 1][0]
-                if i + 1 < len(self.req_states)
-                else tokens_left
-            )
-            bs = len(self.req_states) - i
-            if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
-                return False
-            tokens_freed += tokens_occupied
+            tokens_freed = 0
+            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+                decode_steps = (
+                    self.req_states[i + 1][0]
+                    if i + 1 < len(self.req_states)
+                    else tokens_left
+                )
+                bs = len(self.req_states) - i
+                if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
+                    return False
+                tokens_freed += tokens_occupied
 
         if req.extend_input_len <= self.rem_chunk_tokens:
             self.can_run_list.append(req)
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index d914a71c2..c2c0e6c2d 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -231,6 +231,7 @@ class ModelTpServer:
                     recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
                 ):
                     self.handle_generate_request(recv_req)
+                    self.do_not_get_new_batch = False
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
                 elif isinstance(recv_req, AbortReq):
@@ -254,12 +255,10 @@ class ModelTpServer:
 
     @torch.inference_mode()
     def forward_step(self):
-        if self.current_inflight_req is not None:
-            self.do_not_get_new_batch = False
-
-        new_batch = (
-            self.get_new_prefill_batch() if not self.do_not_get_new_batch else None
-        )
+        if self.do_not_get_new_batch and self.current_inflight_req is None:
+            new_batch = None
+        else:
+            new_batch = self.get_new_prefill_batch()
         self.do_not_get_new_batch = False
 
         if new_batch is not None:
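
Note: the sketch below is an editor's illustration, not part of the commit. It isolates the admission logic that the policy_scheduler.py hunk introduces: a cheap "quick check" admits a request outright when its prompt plus its worst-case output fits the remaining token budget, and only otherwise falls back to the slower per-request simulation of future decode steps. The PrefillBudget class and its method are hypothetical stand-ins for PrefillAdder and its attributes.

class PrefillBudget:
    """Hypothetical stand-in for PrefillAdder's token accounting."""

    def __init__(self, rem_total_tokens: int):
        # Remaining KV-cache token budget available for new prefill requests.
        self.rem_total_tokens = rem_total_tokens

    def try_admit(self, extend_input_len: int, max_new_tokens: int) -> bool:
        # Quick check: if the prompt plus the worst-case output still fits in
        # the remaining budget, admit the request without any further analysis.
        if extend_input_len + max_new_tokens <= self.rem_total_tokens:
            self.rem_total_tokens -= extend_input_len + max_new_tokens
            return True
        # Slow path (omitted here): sort running requests by tokens left and
        # simulate decode steps to see whether enough memory frees up in time.
        return False


if __name__ == "__main__":
    budget = PrefillBudget(rem_total_tokens=4096)
    print(budget.try_admit(extend_input_len=512, max_new_tokens=256))   # True
    print(budget.try_admit(extend_input_len=4000, max_new_tokens=512))  # False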