From 7ee6c259ff8c9cc29f92c4c68530810c8bfc2b30 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 12 Oct 2024 21:35:30 -0700 Subject: [PATCH] Simplify the event loop and expose `--num-continuous-decode-steps` as an argument (#1652) --- python/sglang/global_config.py | 1 - python/sglang/srt/managers/schedule_batch.py | 16 +++ python/sglang/srt/managers/scheduler.py | 119 +++++++++---------- python/sglang/srt/managers/tp_worker.py | 2 +- python/sglang/srt/server_args.py | 9 ++ 5 files changed, 85 insertions(+), 62 deletions(-) diff --git a/python/sglang/global_config.py b/python/sglang/global_config.py index 17aab788a..5e7290edc 100644 --- a/python/sglang/global_config.py +++ b/python/sglang/global_config.py @@ -19,7 +19,6 @@ class GlobalConfig: self.new_token_ratio_decay = 0.001 # Runtime constants: others - self.num_continue_decode_steps = 10 self.retract_decode_steps = 20 self.flashinfer_workspace_size = os.environ.get( "FLASHINFER_WORKSPACE_SIZE", 384 * 1024 * 1024 diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index afe356b38..869c529e3 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -831,6 +831,22 @@ class ScheduleBatch: sampling_info=self.sampling_info, ) + def copy(self): + return ScheduleBatch( + reqs=self.reqs, + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool=self.token_to_kv_pool, + tree_cache=self.tree_cache, + forward_mode=self.forward_mode, + output_token_ids=self.output_token_ids, + ) + + def __str__(self): + return ( + f"ScheduleBatch(forward_mode={self.forward_mode.name}, " + f"#req={(len(self.reqs))})" + ) + @dataclass class ModelWorkerBatch: diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 52a4840fc..62d1ff9ed 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -20,6 +20,7 @@ import logging import os import time import warnings +from types import SimpleNamespace from typing import List, Optional, Union import torch @@ -106,7 +107,8 @@ class Scheduler: self.send_to_detokenizer = context.socket(zmq.PUSH) self.send_to_detokenizer.connect(f"ipc://{port_args.detokenizer_ipc_name}") else: - self.recv_from_tokenizer = self.send_to_detokenizer = None + self.recv_from_tokenizer = None + self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None) # Init tokenizer self.model_config = ModelConfig( @@ -190,7 +192,6 @@ class Scheduler: # Init running status self.waiting_queue: List[Req] = [] self.running_batch: ScheduleBatch = None - self.out_pyobjs = [] self.decode_forward_ct = 0 self.stream_interval = server_args.stream_interval self.num_generated_tokens = 0 @@ -247,13 +248,30 @@ class Scheduler: @torch.inference_mode() def event_loop(self): + self.last_batch = None + while True: recv_reqs = self.recv_requests() self.process_input_requests(recv_reqs) - self.run_step() + batch = self.get_next_batch_to_run() - self.send_results() + if batch: + result = self.run_batch(batch) + self.process_batch_result(batch, result) + + # Decode multiple steps to reduce the overhead + if batch.forward_mode.is_decode(): + for _ in range(self.server_args.num_continuous_decode_steps - 1): + if not self.running_batch: + break + self.update_running_batch() + if not self.running_batch: + break + result = self.run_batch(batch) + self.process_batch_result(batch, result) + + self.last_batch = batch def recv_requests(self): if self.tp_rank == 0: @@ -286,7 +304,9 @@ class Scheduler: self.abort_request(recv_req) elif isinstance(recv_req, UpdateWeightReqInput): success, message = self.update_weights(recv_req) - self.out_pyobjs.append(UpdateWeightReqOutput(success, message)) + self.send_to_detokenizer.send_pyobj( + UpdateWeightReqOutput(success, message) + ) elif isinstance(recv_req, ProfileReq): if recv_req == ProfileReq.START_PROFILE: self.start_profile() @@ -384,12 +404,6 @@ class Scheduler: self.waiting_queue.append(req) - def send_results(self): - if self.tp_rank == 0: - for obj in self.out_pyobjs: - self.send_to_detokenizer.send_pyobj(obj) - self.out_pyobjs = [] - def print_decode_stats(self): num_used = self.max_total_num_tokens - ( self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size() @@ -427,44 +441,32 @@ class Scheduler: ) exit(1) if crash_on_warning else None - def run_step(self): + def get_next_batch_to_run(self): + # Merge prefill to the running batch + if ( + self.last_batch + and not self.last_batch.forward_mode.is_decode() + and not self.last_batch.is_empty() + ): + if self.running_batch is None: + self.running_batch = self.last_batch + else: + self.running_batch.merge_batch(self.last_batch) + + # Prefill first new_batch = self.get_new_batch_prefill() if new_batch is not None: - # Run a new prefill batch - # replace run_batch with the uncommented line to use pytorch profiler - # result = pytorch_profile( - # "profile_prefill_step", self.run_batch, new_batch, data_size=len(new_batch.reqs) - # ) - result = self.run_batch(new_batch) - self.process_batch_result(new_batch, result) + return new_batch + + # Run decode + if self.running_batch is not None: + self.update_running_batch() + if not self.running_batch: + return None + return self.running_batch else: - if self.running_batch is not None: - # Run a few decode batches continuously for reducing overhead - for _ in range(global_config.num_continue_decode_steps): - batch = self.get_new_batch_decode() - - if batch: - # replace run_batch with the uncommented line to use pytorch profiler - # result = pytorch_profile( - # "profile_decode_step", - # self.run_batch, - # batch, - # data_size=len(batch.reqs), - # ) - result = self.run_batch(batch) - self.process_batch_result(batch, result) - - if self.running_batch.is_empty(): - self.running_batch = None - - if self.running_batch is None: - break - - if self.out_pyobjs and self.running_batch.has_stream: - break - else: - self.check_memory() - self.new_token_ratio = global_config.init_new_token_ratio + self.check_memory() + self.new_token_ratio = global_config.init_new_token_ratio def get_new_batch_prefill(self) -> Optional[ScheduleBatch]: # Handle the cases where prefill is not allowed @@ -607,7 +609,7 @@ class Scheduler: return new_batch - def get_new_batch_decode(self) -> Optional[ScheduleBatch]: + def update_running_batch(self): batch = self.running_batch # Check if decode out of memory @@ -636,11 +638,11 @@ class Scheduler: if jump_forward_reqs: self.batch_is_full = False if batch.is_empty(): - return None + self.running_batch = None + return # Update batch tensors batch.prepare_for_decode() - return batch def run_batch(self, batch: ScheduleBatch): if self.is_generation: @@ -657,16 +659,19 @@ class Scheduler: ) else: next_token_ids = torch.full((batch.batch_size(),), 0) - return logits_output, next_token_ids + ret = logits_output, next_token_ids else: # embedding or reward model assert batch.extend_num_tokens != 0 model_worker_batch = batch.get_model_worker_batch() embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch) - return embeddings + ret = embeddings + return ret def process_batch_result(self, batch: ScheduleBatch, result): if batch.forward_mode.is_decode(): self.process_batch_result_decode(batch, result) + if batch.is_empty(): + self.running_batch = None else: self.process_batch_result_prefill(batch, result) @@ -728,7 +733,7 @@ class Scheduler: ) else: # embedding or reward model assert batch.extend_num_tokens != 0 - embeddings = result + embeddings = result.tolist() # Check finish conditions for i, req in enumerate(batch.reqs): @@ -750,12 +755,6 @@ class Scheduler: self.handle_finished_requests(batch) - if not batch.is_empty(): - if self.running_batch is None: - self.running_batch = batch - else: - self.running_batch.merge_batch(batch) - def process_batch_result_decode(self, batch: ScheduleBatch, result): logits_output, next_token_ids = result if batch.sampling_info.penalizer_orchestrator: @@ -951,7 +950,7 @@ class Scheduler: # Send to detokenizer if output_rids: if self.is_generation: - self.out_pyobjs.append( + self.send_to_detokenizer.send_pyobj( BatchTokenIDOut( output_rids, output_vids, @@ -965,7 +964,7 @@ class Scheduler: ) ) else: # embedding or reward model - self.out_pyobjs.append( + self.send_to_detokenizer.send_pyobj( BatchEmbeddingOut( output_rids, output_embeddings, diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 73c4abe08..390398017 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -118,7 +118,7 @@ class TpModelWorker: def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch): forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner) logits_output = self.model_runner.forward(forward_batch) - embeddings = logits_output.embeddings.tolist() + embeddings = logits_output.embeddings return embeddings def update_weights(self, recv_req: UpdateWeightReqInput): diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2966bed64..c7b2c19e7 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -111,6 +111,7 @@ class ServerArgs: torchao_config: str = "" enable_p2p_check: bool = False triton_attention_reduce_in_fp32: bool = False + num_continuous_decode_steps: int = 1 def __post_init__(self): # Set missing default values @@ -559,6 +560,14 @@ class ServerArgs: help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." "This only affects Triton attention kernels.", ) + parser.add_argument( + "--num-continuous-decode-steps", + type=int, + default=ServerArgs.num_continuous_decode_steps, + help="Run multiple continuous decoding steps to reduce scheduling overhead. " + "This can potentially increase throughput but may also increase time-to-first-token latency. " + "The default value is 1, meaning only run one decoding step at a time.", + ) @classmethod def from_cli_args(cls, args: argparse.Namespace):