From 24f3e1511cc289b1b7e3e94e4ee19ab559a5e7f9 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 14 Oct 2024 05:25:00 -0700 Subject: [PATCH] [Minor] Improve style (#1666) --- python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/managers/schedule_policy.py | 2 +- python/sglang/srt/managers/scheduler.py | 52 ++++++++++--------- 3 files changed, 29 insertions(+), 26 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index bb70366b0..7f0c243b7 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -203,6 +203,7 @@ class Req: self.prefix_indices = [] self.extend_input_len = 0 self.last_node = None + self.is_inflight_req = 0 # Logprobs (arguments) self.return_logprob = False diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 1068250c1..74a3a621c 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -45,7 +45,7 @@ class SchedulePolicy: def calc_priority(self, waiting_queue: List[Req]): # Compute matched prefix length prefix_computed = False - if self.policy in ["lpm", "dfs-weight"]: + if self.policy == "lpm" or self.policy == "dfs-weight": for r in waiting_queue: # NOTE: the prefix_indices must always be aligned with last_node r.prefix_indices, r.last_node = self.tree_cache.match_prefix( diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 0c5049844..796bca849 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -194,7 +194,7 @@ class Scheduler: # Init running status self.waiting_queue: List[Req] = [] - self.running_batch: ScheduleBatch = None + self.running_batch: Optional[ScheduleBatch] = None self.decode_forward_ct = 0 self.stream_interval = server_args.stream_interval self.num_generated_tokens = 0 @@ -273,6 +273,9 @@ class Scheduler: break result = self.run_batch(batch) self.process_batch_result(batch, result) + else: + self.check_memory() + self.new_token_ratio = global_config.init_new_token_ratio self.last_batch = batch @@ -468,8 +471,6 @@ class Scheduler: # Check memory if self.running_batch is None: - self.check_memory() - self.new_token_ratio = global_config.init_new_token_ratio return # Run decode @@ -489,9 +490,7 @@ class Scheduler: ) and self.current_inflight_req is None: return None - running_bs = ( - len(self.running_batch.reqs) if self.running_batch is not None else 0 - ) + running_bs = len(self.running_batch.reqs) if self.running_batch else 0 if running_bs >= self.max_running_requests: self.batch_is_full = True return None @@ -512,7 +511,7 @@ class Scheduler: ) has_inflight = self.current_inflight_req is not None - if self.current_inflight_req is not None: + if has_inflight: self.current_inflight_req.init_next_round_input( None if prefix_computed else self.tree_cache ) @@ -520,7 +519,7 @@ class Scheduler: self.current_inflight_req ) - if self.lora_paths is not None: + if self.lora_paths: lora_set = ( set([req.lora_path for req in self.running_batch.reqs]) if self.running_batch is not None @@ -529,7 +528,7 @@ class Scheduler: for req in self.waiting_queue: if ( - self.lora_paths is not None + self.lora_paths and len( lora_set | set([req.lora_path for req in adder.can_run_list]) @@ -551,16 +550,20 @@ class Scheduler: self.batch_is_full = True break + # Update waiting queue can_run_list = adder.can_run_list + if len(can_run_list) == 0: + return None + self.waiting_queue = [ + x for x in self.waiting_queue if x not in set(can_run_list) + ] if adder.new_inflight_req is not None: assert self.current_inflight_req is None self.current_inflight_req = adder.new_inflight_req - if len(can_run_list) == 0: - return None - - self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list] + if self.current_inflight_req: + self.current_inflight_req.is_inflight_req += 1 # Print stats if self.tp_rank == 0: @@ -613,13 +616,13 @@ class Scheduler: new_batch.prepare_for_extend(self.model_config.vocab_size) # Mixed-style chunked prefill - decoding_reqs = [] if self.is_mixed_chunk and self.running_batch is not None: self.running_batch.prepare_for_decode() new_batch.mix_with_running(self.running_batch) - decoding_reqs = self.running_batch.reqs + new_batch.decoding_reqs = self.running_batch.reqs self.running_batch = None - new_batch.decoding_reqs = decoding_reqs + else: + new_batch.decoding_reqs = None return new_batch @@ -738,12 +741,12 @@ class Scheduler: if req.finished(): self.tree_cache.cache_finished_req(req) - elif req not in batch.decoding_reqs: - # To reduce overhead, only cache prefill reqs + elif not batch.decoding_reqs or req not in batch.decoding_reqs: self.tree_cache.cache_unfinished_req(req) - if req is self.current_inflight_req: + if req.is_inflight_req > 0: # Inflight request would get a new req idx + req.is_inflight_req -= 1 self.req_to_token_pool.free(req.req_pool_idx) if req.return_logprob: @@ -768,8 +771,9 @@ class Scheduler: else: self.tree_cache.cache_unfinished_req(req) - if req is self.current_inflight_req: + if req.is_inflight_req > 0: # Inflight request would get a new req idx + req.is_inflight_req -= 1 self.req_to_token_pool.free(req.req_pool_idx) self.stream_output(batch) @@ -906,13 +910,11 @@ class Scheduler: else: # embedding or reward model output_embeddings = [] + is_stream_iter = self.decode_forward_ct % self.stream_interval == 0 + for req in batch.reqs: if req.finished() or ( - req.stream - and ( - self.decode_forward_ct % self.stream_interval == 0 - or len(req.output_ids) == 1 - ) + req.stream and (is_stream_iter or len(req.output_ids) == 1) ): output_rids.append(req.rid) output_finished_reason.append(req.finished_reason)