diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 6bc93ea40..18ca20409 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -193,16 +193,6 @@ class Scheduler:
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.policy = SchedulePolicy(self.schedule_policy, self.tree_cache)
 
-        if self.server_args.enable_overlap_schedule:
-
-            def cache_finished_req(req):
-                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
-                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
-
-        else:
-            cache_finished_req = self.tree_cache.cache_finished_req
-        self.cache_finished_req = cache_finished_req
-
         # Init running status
         self.waiting_queue: List[Req] = []
         self.running_batch: Optional[ScheduleBatch] = None
@@ -245,6 +235,7 @@ class Scheduler:
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
         self.batch_is_full = False
 
+        # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
             self.profiler = None
         else:
@@ -261,6 +252,25 @@ class Scheduler:
                 with_stack=True,
             )
 
+        # Init states for overlap schedule
+        if self.server_args.enable_overlap_schedule:
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
+            self.resolve_next_token_ids = (
+                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
+            )
+
+            def cache_finished_req(req):
+                free_delta = int(self.running_batch and req in self.cur_batch.reqs)
+                self.tree_cache.cache_finished_req(req, free_delta=free_delta)
+
+            self.cache_finished_req = cache_finished_req
+        else:
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
+            self.resolve_next_token_ids = lambda bid, x: x.tolist()
+            self.cache_finished_req = self.tree_cache.cache_finished_req
+
     @torch.inference_mode()
     def event_loop_normal(self):
         self.last_batch = None
@@ -712,7 +722,7 @@ class Scheduler:
         if self.is_generation:
             if batch.forward_mode.is_decode() or batch.extend_num_tokens != 0:
                 model_worker_batch = batch.get_model_worker_batch()
-                logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
+                logits_output, next_token_ids = self.forward_batch_generation(
                     model_worker_batch
                 )
             else:
@@ -724,12 +734,12 @@ class Scheduler:
                 else:
                     next_token_ids = torch.full((batch.batch_size(),), 0)
             batch.output_ids = next_token_ids
-            ret = logits_output, next_token_ids
+            ret = logits_output, next_token_ids, model_worker_batch.bid
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
             model_worker_batch = batch.get_model_worker_batch()
             embeddings = self.tp_worker.forward_batch_embedding(model_worker_batch)
-            ret = embeddings
+            ret = embeddings, model_worker_batch.bid
         return ret
 
     def process_batch_result(self, batch: ScheduleBatch, result):
@@ -742,7 +752,7 @@ class Scheduler:
 
     def process_batch_result_prefill(self, batch: ScheduleBatch, result):
         if self.is_generation:
-            logits_output, next_token_ids = result
+            logits_output, next_token_ids, bid = result
             if batch.return_logprob:
                 # Move logprobs to cpu
                 if logits_output.next_token_logprobs is not None:
@@ -761,7 +771,7 @@ class Scheduler:
                         logits_output.normalized_prompt_logprobs.tolist()
                     )
 
-            next_token_ids = next_token_ids.tolist()
+            next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
 
            # Check finish conditions
            logprob_pt = 0
@@ -790,7 +800,8 @@ class Scheduler:
             )
         else:  # embedding or reward model
             assert batch.extend_num_tokens != 0
-            embeddings = result.tolist()
+            embeddings, bid = result
+            embeddings = embeddings.tolist()
 
             # Check finish conditions
             for i, req in enumerate(batch.reqs):
@@ -811,7 +822,7 @@ class Scheduler:
         self.stream_output(batch.reqs)
 
     def process_batch_result_decode(self, batch: ScheduleBatch, result):
-        logits_output, next_token_ids = result
+        logits_output, next_token_ids, bid = result
         self.num_generated_tokens += len(batch.reqs)
 
         # Move logprobs to cpu
@@ -821,7 +832,7 @@ class Scheduler:
                 next_token_ids,
             ].tolist()
 
-        next_token_ids = next_token_ids.tolist()
+        next_token_ids = self.resolve_next_token_ids(bid, next_token_ids)
 
         # Check finish condition
         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
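
The scheduler half of the change boils down to two moves: run_batch now returns the batch id (bid) next to outputs that may still be placeholders, and the process_batch_result_* paths turn placeholders into real token ids through self.resolve_next_token_ids only at the point of consumption. A minimal sketch of the event-loop ordering this enables (illustrative only; get_next_batch_to_run is an assumed helper and the real loop body is outside this diff):

def overlap_event_loop(scheduler):
    # Dispatch batch i, then process batch i-1 while the GPU runs batch i.
    last_batch, last_result = None, None
    while True:
        batch = scheduler.get_next_batch_to_run()  # assumed helper
        scheduler.cur_batch = batch
        result = scheduler.run_batch(batch) if batch else None  # returns (..., bid)
        if last_batch is not None:
            # resolve_next_token_ids blocks here, at the last possible moment,
            # so the forward pass for `batch` is already in flight.
            scheduler.process_batch_result(last_batch, last_result)
        last_batch, last_result = batch, result
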
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 390398017..d416aa64a 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -17,6 +17,11 @@ limitations under the License.
 
 import json
 import logging
+import threading
+import time
+from queue import Queue
+
+import torch
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -75,6 +80,7 @@ class TpModelWorker:
             tokenizer_mode=server_args.tokenizer_mode,
             trust_remote_code=server_args.trust_remote_code,
         )
+        self.device = self.model_runner.device
 
         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
@@ -100,6 +106,9 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)
 
+        if server_args.enable_overlap_schedule:
+            self.init_overlap_status()
+
     def get_token_and_memory_info(self):
         return (
             self.max_total_num_tokens,
@@ -109,6 +118,83 @@ class TpModelWorker:
             self.max_prefill_tokens,
             self.max_running_requests,
             self.max_req_input_len,
             self.random_seed,
         )
 
+    def init_overlap_status(self):
+        self.future_logits_output_dict = dict()
+        self.future_logits_output_ct = 0
+        self.future_token_ids_ct = 0
+        self.future_token_ids_map = torch.empty(
+            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+        )
+        self.future_token_ids_limit = self.max_running_requests * 3
+        self.future_token_ids_output = dict()
+
+        self.future_event_map = dict()
+        self.forward_queue = Queue()
+        self.forward_stream = torch.cuda.Stream()
+        self.forward_thread = threading.Thread(
+            target=self.forward_thread_func,
+        )
+        self.forward_thread.start()
+
+    def forward_thread_func(self):
+        with torch.cuda.stream(self.forward_stream):
+            self.forward_thread_func_()
+
+    @torch.inference_mode()
+    def forward_thread_func_(self):
+        while True:
+            tic1 = time.time()
+            model_worker_batch, future_logits_output, future_next_token_ids = (
+                self.forward_queue.get()
+            )
+
+            # Resolve future tokens in the input
+            # logger.info(f"raw input {model_worker_batch.input_ids=}")
+            tic2 = time.time()
+            resolved_input_ids = model_worker_batch.input_ids
+            future_mask = resolved_input_ids < 0
+            resolved_input_ids[future_mask] = self.future_token_ids_map[
+                -resolved_input_ids[future_mask]
+            ]
+            # logger.info(f"resolved input {model_worker_batch.input_ids=}")
+
+            # Run forward
+            logits_output, next_token_ids = self.forward_batch_generation(
+                model_worker_batch
+            )
+
+            # Set future values
+            if model_worker_batch.return_logprob:
+                self.future_logits_output_dict[future_logits_output] = logits_output
+
+            # logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
+            self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
+                torch.int32
+            )
+            # logger.info("Set event")
+            self.future_token_ids_output[model_worker_batch.bid] = (
+                next_token_ids.tolist()
+            )
+            self.future_event_map[model_worker_batch.bid].set()
+
+            if False:
+                tic3 = time.time()
+                self.acc_time_with_waiting += tic3 - tic1
+                self.acc_time_without_waiting += tic3 - tic2
+                if self.forward_queue.qsize() == 0:
+                    logger.info(
+                        f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
+                    )
+
+    def resolve_future_token_ids(self, bid: int):
+        self.future_event_map[bid].wait()
+        ret = self.future_token_ids_output[bid]
+        del self.future_event_map[bid]
+        return ret
+
+    def resolve_future_logits_output(self, future_obj):
+        return self.future_logits_output_dict.pop(future_obj)
+
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
@@ -121,6 +207,31 @@ class TpModelWorker:
         embeddings = logits_output.embeddings
         return embeddings
 
+    def forward_batch_generation_non_blocking(
+        self, model_worker_batch: ModelWorkerBatch
+    ):
+        # Allocate output future objects
+        future_logits_output = self.future_logits_output_ct
+        self.future_logits_output_ct += 1
+
+        bs = len(model_worker_batch.seq_lens)
+        future_next_token_ids = -torch.arange(
+            self.future_token_ids_ct + 1,
+            self.future_token_ids_ct + 1 + bs,
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.future_token_ids_ct = (
+            self.future_token_ids_ct + bs
+        ) % self.future_token_ids_limit
+        ret = future_logits_output, future_next_token_ids
+
+        self.future_event_map[model_worker_batch.bid] = threading.Event()
+        self.forward_queue.put(
+            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
+        )
+        return ret
+
     def update_weights(self, recv_req: UpdateWeightReqInput):
         success, message = self.model_runner.update_weights(
             recv_req.model_path, recv_req.load_format
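
The pivot of the worker change is future_token_ids_map: a reserved slot s is handed out as the negative id -s before the token exists, and forward_thread_func_ recovers s by negating whatever input ids it finds below zero. A self-contained illustration of that resolution step (toy sizes; dtypes here are assumptions):

import torch

future_token_ids_map = torch.zeros(16, dtype=torch.int32)

# Two decode requests were scheduled before their input tokens existed, so
# slots 1 and 2 were reserved for them. Slot 0 is unusable (-0 == 0 would be
# indistinguishable from real token id 0), which is why
# forward_batch_generation_non_blocking starts its arange at ct + 1.
input_ids = torch.tensor([-1, -2])

# ... the previous forward finishes and writes the tokens it sampled:
future_token_ids_map[1:3] = torch.tensor([7, 9], dtype=torch.int32)

# Resolution, as in forward_thread_func_: negate the negative ids to index the map.
future_mask = input_ids < 0
input_ids[future_mask] = future_token_ids_map[-input_ids[future_mask]].long()
print(input_ids.tolist())  # [7, 9]

Note the counter wraps at future_token_ids_limit (3x max_running_requests) while the map holds 5x, presumably to leave headroom so slots belonging to batches still in flight past the wrap point are not reused before forward_thread_func_ reads them.
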
diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py
index e398ab4b0..f0b7abbe3 100644
--- a/python/sglang/srt/server.py
+++ b/python/sglang/srt/server.py
@@ -447,7 +447,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
 
     # Set ulimit
     set_ulimit()
@@ -528,7 +528,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
             kill_child_process(pid, including_parent=False)
             return
 
-    # print(f"{res.json()=}")
+    print(f"{res.json()=}")
     logger.info("The server is fired up and ready to roll!")
 
     if pipe_finish_writer is not None:
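
The one-line CUDA_DEVICE_MAX_CONNECTIONS change is load-bearing rather than cosmetic: the variable bounds the number of hardware work queues between host and device, and at 1 the new forward_stream and the default stream would funnel into a single queue and serialize exactly the overlap this diff creates. It is read when the CUDA context is created, so a standalone reproduction (a sketch, not sglang's code) has to set it before the first CUDA call:

import os

os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"  # before any CUDA initialization

import torch  # the setting takes effect when the CUDA context is created

# Work submitted to these two streams can now be scheduled concurrently.
stream_a, stream_b = torch.cuda.Stream(), torch.cuda.Stream()
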