From 58d1082e392cabbf26c404cb7ec18e4cb51b99e9 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sun, 6 Oct 2024 03:24:04 -0700 Subject: [PATCH] Clean up event loop (#1586) --- python/sglang/srt/managers/scheduler.py | 425 ++++++++++++------------ 1 file changed, 220 insertions(+), 205 deletions(-) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 7f764260c..c667020fa 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -228,20 +228,14 @@ class Scheduler: self.new_token_ratio_decay = global_config.new_token_ratio_decay self.batch_is_full = False + @torch.inference_mode() def event_loop(self): while True: - # Receive requests - if self.tp_rank == 0: - recv_reqs = self.recv_requests_from_zmq() - else: - recv_reqs = None + recv_reqs = self.recv_requests() + self.process_input_requests(recv_reqs) - # Process requests - recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group) - self.process_requests(recv_reqs) - - # Forward - self.forward_step() + # Run one step + self.run_step() # Send results if self.tp_rank == 0: @@ -249,19 +243,23 @@ class Scheduler: self.send_to_detokenizer.send_pyobj(obj) self.out_pyobjs = [] - def recv_requests_from_zmq(self): - recv_reqs = [] + def recv_requests(self): + if self.tp_rank == 0: + recv_reqs = [] - while True: - try: - recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) - except zmq.ZMQError: - break - recv_reqs.append(recv_req) + while True: + try: + recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) + except zmq.ZMQError: + break + recv_reqs.append(recv_req) + else: + recv_reqs = None + recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group) return recv_reqs - def process_requests(self, recv_reqs: List): + def process_input_requests(self, recv_reqs: List): for recv_req in recv_reqs: if isinstance(recv_req, TokenizedGenerateReqInput): self.handle_generate_request(recv_req) @@ -279,83 +277,6 @@ class Scheduler: else: raise ValueError(f"Invalid request: {recv_req}") - @torch.inference_mode() - def forward_step(self): - if ( - self.batch_is_full or len(self.waiting_queue) == 0 - ) and self.current_inflight_req is None: - new_batch = None - else: - new_batch = self.get_new_prefill_batch() - - if new_batch is not None: - # Run a new prefill batch - self.forward_prefill_batch(new_batch) - - if not new_batch.is_empty(): - if self.running_batch is None: - self.running_batch = new_batch - else: - self.running_batch.merge_batch(new_batch) - else: - # Run a decode batch - if self.running_batch is not None: - # Run a few decode batches continuously for reducing overhead - for _ in range(global_config.num_continue_decode_steps): - self.num_generated_tokens += len(self.running_batch.reqs) - self.forward_decode_batch(self.running_batch) - - # Print stats - if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0: - self.print_decode_stats() - - if self.running_batch.is_empty(): - self.running_batch = None - break - - if self.out_pyobjs and self.running_batch.has_stream: - break - else: - self.check_memory() - self.new_token_ratio = global_config.init_new_token_ratio - - def print_decode_stats(self): - num_used = self.max_total_num_tokens - ( - self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() - ) - throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic) - self.num_generated_tokens = 0 - self.last_stats_tic = time.time() - logger.info( - f"Decode batch. 
" - f"#running-req: {len(self.running_batch.reqs)}, " - f"#token: {num_used}, " - f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"gen throughput (token/s): {throughput:.2f}, " - f"#queue-req: {len(self.waiting_queue)}" - ) - - def check_memory(self): - available_size = ( - self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() - ) - if available_size != self.max_total_num_tokens: - warnings.warn( - "Warning: " - f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n" - "KV cache pool leak detected!" - ) - exit(1) if crash_on_warning else None - - if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size: - warnings.warn( - "Warning: " - f"available req slots={len(self.req_to_token_pool.free_slots)}, " - f"total slots={self.req_to_token_pool.size}\n" - "Memory pool leak detected!" - ) - exit(1) if crash_on_warning else None - def handle_generate_request( self, recv_req: TokenizedGenerateReqInput, @@ -445,7 +366,88 @@ class Scheduler: self.waiting_queue.append(req) - def get_new_prefill_batch(self) -> Optional[ScheduleBatch]: + def run_step(self): + new_batch = self.get_new_batch_prefill() + + if new_batch is not None: + # Run a new prefill batch + result = self.run_batch(new_batch) + self.process_batch_result(new_batch, result) + + if not new_batch.is_empty(): + if self.running_batch is None: + self.running_batch = new_batch + else: + self.running_batch.merge_batch(new_batch) + else: + # Run a decode batch + if self.running_batch is not None: + # Run a few decode batches continuously for reducing overhead + for _ in range(global_config.num_continue_decode_steps): + batch = self.get_new_batch_decode() + + if batch: + result = self.run_batch(batch) + self.process_batch_result(batch, result) + + # Print stats + if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0: + self.print_decode_stats() + + if self.running_batch.is_empty(): + self.running_batch = None + break + + if self.out_pyobjs and self.running_batch.has_stream: + break + else: + self.check_memory() + self.new_token_ratio = global_config.init_new_token_ratio + + def print_decode_stats(self): + num_used = self.max_total_num_tokens - ( + self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() + ) + throughput = self.num_generated_tokens / (time.time() - self.last_stats_tic) + self.num_generated_tokens = 0 + self.last_stats_tic = time.time() + logger.info( + f"Decode batch. " + f"#running-req: {len(self.running_batch.reqs)}, " + f"#token: {num_used}, " + f"token usage: {num_used / self.max_total_num_tokens:.2f}, " + f"gen throughput (token/s): {throughput:.2f}, " + f"#queue-req: {len(self.waiting_queue)}" + ) + + def check_memory(self): + available_size = ( + self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() + ) + if available_size != self.max_total_num_tokens: + warnings.warn( + "Warning: " + f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n" + "KV cache pool leak detected!" + ) + exit(1) if crash_on_warning else None + + if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size: + warnings.warn( + "Warning: " + f"available req slots={len(self.req_to_token_pool.free_slots)}, " + f"total slots={self.req_to_token_pool.size}\n" + "Memory pool leak detected!" 
+ ) + exit(1) if crash_on_warning else None + + def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: + # Handle the cases where prefill is not allowed + if ( + self.batch_is_full or len(self.waiting_queue) == 0 + ) and self.current_inflight_req is None: + return None + running_bs = ( len(self.running_batch.reqs) if self.running_batch is not None else 0 ) @@ -456,8 +458,8 @@ class Scheduler: # Get priority queue prefix_computed = self.policy.calc_priority(self.waiting_queue) + # Prefill policy num_mixed_running = running_bs if self.is_mixed_chunk else 0 - adder = PrefillAdder( self.tree_cache, self.running_batch, @@ -517,6 +519,8 @@ class Scheduler: if len(can_run_list) == 0: return None + self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list] + # Print stats if self.tp_rank == 0: if isinstance(self.tree_cache, RadixCache): @@ -544,7 +548,7 @@ class Scheduler: f"#cached-token: {adder.log_hit_tokens}, " f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " - f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}" + f"#queue-req: {len(self.waiting_queue) + has_inflight}" ) else: logger.info( @@ -555,41 +559,97 @@ class Scheduler: f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, " f"token usage: {num_used / self.max_total_num_tokens:.2f}, " f"#running-req: {running_bs}, " - f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}" + f"#queue-req: {len(self.waiting_queue) + has_inflight}" ) - # Return the new batch + # Create a new batch new_batch = ScheduleBatch.init_new( can_run_list, self.req_to_token_pool, self.token_to_kv_pool, self.tree_cache, ) - self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list] - return new_batch - - def forward_prefill_batch(self, batch: ScheduleBatch): - # Build batch tensors - batch.prepare_for_extend(self.model_config.vocab_size) + new_batch.prepare_for_extend(self.model_config.vocab_size) + # Mixed-style chunked prefill decoding_reqs = [] if self.is_mixed_chunk and self.running_batch is not None: self.running_batch.prepare_for_decode() - batch.mix_with_running(self.running_batch) + new_batch.mix_with_running(self.running_batch) decoding_reqs = self.running_batch.reqs self.running_batch = None + new_batch.decoding_reqs = decoding_reqs + return new_batch + + def get_new_batch_decode(self) -> Optional[ScheduleBatch]: + batch = self.running_batch + + # Check if decode out of memory + if not batch.check_decode_mem(): + old_ratio = self.new_token_ratio + + retracted_reqs, new_token_ratio = batch.retract_decode() + self.new_token_ratio = new_token_ratio + + logger.info( + "Decode out of memory happened. 
" + f"#retracted_reqs: {len(retracted_reqs)}, " + f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" + ) + self.waiting_queue.extend(retracted_reqs) + else: + self.new_token_ratio = max( + self.new_token_ratio - self.new_token_ratio_decay, + self.min_new_token_ratio, + ) + + # Check for jump-forward + if not self.disable_regex_jump_forward: + jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func) + self.waiting_queue.extend(jump_forward_reqs) + if batch.is_empty(): + return None + + # Update batch tensors + self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) + batch.prepare_for_decode() + return batch + + def run_batch(self, batch: ScheduleBatch): if self.is_generation: - # Forward and sample the next tokens - if batch.extend_num_tokens != 0: + if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0: model_worker_batch = batch.get_model_worker_batch() logits_output, next_token_ids = self.tp_worker.forward_batch_generation( model_worker_batch ) - batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( - next_token_ids - ) + else: + logits_output = None + if self.tokenizer is not None: + next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) + else: + next_token_ids = [0] * len(batch.reqs) + return logits_output, next_token_ids + else: # embedding or reward model + assert batch.extend_num_tokens != 0 + model_worker_batch = batch.get_model_worker_batch() + embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch) + return embeddings + def process_batch_result(self, batch: ScheduleBatch, result): + if batch.forward_mode.is_decode(): + self.process_batch_result_decode(batch, result) + else: + self.process_batch_result_prefill(batch, result) + + def process_batch_result_prefill(self, batch: ScheduleBatch, result): + if self.is_generation: + logits_output, next_token_ids = result + batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( + next_token_ids + ) + + if logits_output: # Move logprobs to cpu if logits_output.next_token_logprobs is not None: logits_output.next_token_logprobs = ( @@ -607,16 +667,7 @@ class Scheduler: logits_output.normalized_prompt_logprobs.tolist() ) - next_token_ids = next_token_ids.tolist() - else: - if self.tokenizer is None: - next_token_ids = [] - for req in batch.reqs: - next_token_ids.append( - next(iter(req.sampling_params.stop_token_ids)) - ) - else: - next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs) + next_token_ids = next_token_ids.tolist() # Check finish conditions logprob_pt = 0 @@ -634,7 +685,7 @@ class Scheduler: if req.finished(): self.tree_cache.cache_finished_req(req) - elif req not in decoding_reqs: + elif req not in batch.decoding_reqs: # To reduce overhead, only cache prefill reqs self.tree_cache.cache_unfinished_req(req) @@ -646,10 +697,9 @@ class Scheduler: logprob_pt += self.add_logprob_return_values( i, req, logprob_pt, next_token_ids, logits_output ) - else: + else: # embedding or reward model assert batch.extend_num_tokens != 0 - model_worker_batch = batch.get_model_worker_batch() - embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch) + embeddings = result # Check finish conditions for i, req in enumerate(batch.reqs): @@ -671,6 +721,45 @@ class Scheduler: self.handle_finished_requests(batch) + def process_batch_result_decode(self, batch: ScheduleBatch, result): + logits_output, next_token_ids = result + batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( + next_token_ids + ) + 
self.num_generated_tokens += len(batch.reqs) + + # Move logprobs to cpu + if logits_output.next_token_logprobs is not None: + next_token_logprobs = logits_output.next_token_logprobs[ + torch.arange(len(next_token_ids), device=next_token_ids.device), + next_token_ids, + ].tolist() + + next_token_ids = next_token_ids.tolist() + + # Check finish condition + for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): + req.completion_tokens_wo_jump_forward += 1 + req.output_ids.append(next_token_id) + req.check_finished() + + if req.regex_fsm is not None: + req.regex_fsm_state = req.regex_fsm.get_next_state( + req.regex_fsm_state, next_token_id + ) + + if req.finished(): + self.tree_cache.cache_finished_req(req) + + if req.return_logprob: + req.output_token_logprobs.append( + (next_token_logprobs[i], next_token_id) + ) + if req.top_logprobs_num > 0: + req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) + + self.handle_finished_requests(batch) + def add_logprob_return_values( self, i: int, @@ -744,80 +833,6 @@ class Scheduler: return num_input_logprobs - def forward_decode_batch(self, batch: ScheduleBatch): - # Check if decode out of memory - if not batch.check_decode_mem(): - old_ratio = self.new_token_ratio - - retracted_reqs, new_token_ratio = batch.retract_decode() - self.new_token_ratio = new_token_ratio - - logger.info( - "Decode out of memory happened. " - f"#retracted_reqs: {len(retracted_reqs)}, " - f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}" - ) - self.waiting_queue.extend(retracted_reqs) - else: - self.new_token_ratio = max( - self.new_token_ratio - self.new_token_ratio_decay, - self.min_new_token_ratio, - ) - - # Check for jump-forward - if not self.disable_regex_jump_forward: - jump_forward_reqs = batch.check_for_jump_forward(self.pad_input_ids_func) - self.waiting_queue.extend(jump_forward_reqs) - if batch.is_empty(): - return - - # Update batch tensors - self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30) - batch.prepare_for_decode() - - # Forward and sample the next tokens - model_worker_batch = batch.get_model_worker_batch() - logits_output, next_token_ids = self.tp_worker.forward_batch_generation( - model_worker_batch - ) - batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( - next_token_ids - ) - - # Move logprobs to cpu - if logits_output.next_token_logprobs is not None: - next_token_logprobs = logits_output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=next_token_ids.device), - next_token_ids, - ].tolist() - - next_token_ids = next_token_ids.tolist() - - # Check finish condition - has_finished = False - for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)): - req.completion_tokens_wo_jump_forward += 1 - req.output_ids.append(next_token_id) - req.check_finished() - - if req.regex_fsm is not None: - req.regex_fsm_state = req.regex_fsm.get_next_state( - req.regex_fsm_state, next_token_id - ) - - if req.finished(): - self.tree_cache.cache_finished_req(req) - has_finished = True - - if req.return_logprob: - req.output_token_logprobs.append( - (next_token_logprobs[i], next_token_id) - ) - if req.top_logprobs_num > 0: - req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) - - self.handle_finished_requests(batch) - def handle_finished_requests(self, batch: ScheduleBatch): output_rids = [] output_meta_info = [] @@ -829,7 +844,7 @@ class Scheduler: output_read_offsets = [] output_skip_special_tokens = [] output_spaces_between_special_tokens = [] - 
-        else:  # for embedding model
+        else:  # embedding or reward model
             output_embeddings = []
 
         unfinished_indices = []
@@ -886,7 +901,7 @@ class Scheduler:
                         req.normalized_prompt_logprob,
                     )
                 output_meta_info.append(meta_info)
-            else:  # for embedding model
+            else:  # embedding or reward model
                 output_embeddings.append(req.embedding)
                 meta_info = {
                     "prompt_tokens": len(req.origin_input_ids),
@@ -909,7 +924,7 @@ class Scheduler:
                     output_finished_reason,
                 )
             )
-        else:  # for embedding model
+        else:  # embedding or reward model
             self.out_pyobjs.append(
                 BatchEmbeddingOut(
                     output_rids,
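
For readers skimming the diff, here is the per-iteration control flow condensed from the added `event_loop` and its new helpers. This is a summary sketch, not a drop-in replacement for the methods above; the comments paraphrase the helper bodies introduced in this patch.

    @torch.inference_mode()
    def event_loop(self):
        while True:
            # Rank 0 polls the tokenizer's ZMQ socket with zmq.NOBLOCK; the
            # result is shared with all TP ranks via broadcast_pyobj inside
            # recv_requests(), then dispatched by process_input_requests().
            recv_reqs = self.recv_requests()
            self.process_input_requests(recv_reqs)

            # One scheduling step: try get_new_batch_prefill(); if no prefill
            # work exists, run up to global_config.num_continue_decode_steps
            # decode iterations, each via get_new_batch_decode() followed by
            # run_batch() and process_batch_result().
            self.run_step()

            # Only rank 0 ships results to the detokenizer.
            if self.tp_rank == 0:
                for obj in self.out_pyobjs:
                    self.send_to_detokenizer.send_pyobj(obj)
                self.out_pyobjs = []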