diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 10a76d53d..e9a8dc3d3 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -91,6 +91,7 @@ class Scheduler:
         port_args: PortArgs,
         gpu_id: int,
         tp_rank: int,
+        dp_rank: Optional[int],
     ):
         # Parse args
         self.server_args = server_args
@@ -144,13 +145,24 @@ class Scheduler:
 
         # Launch a tensor parallel worker
         self.tp_worker = TpModelWorker(
+            server_args=server_args,
             gpu_id=gpu_id,
             tp_rank=tp_rank,
-            server_args=server_args,
+            dp_rank=dp_rank,
             nccl_port=port_args.nccl_port,
         )
-        self.tp_cpu_group = self.tp_worker.model_runner.tp_group.cpu_group
-        self.device = self.tp_worker.device
+
+        # Init states for overlap schedule
+        if self.server_args.enable_overlap_schedule:
+            self.forward_batch_generation = (
+                self.tp_worker.forward_batch_generation_non_blocking
+            )
+            self.resolve_next_token_ids = (
+                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
+            )
+        else:
+            self.forward_batch_generation = self.tp_worker.forward_batch_generation
+            self.resolve_next_token_ids = lambda bid, x: x.tolist()
 
         # Get token and memory info from the model worker
         (
@@ -159,11 +171,11 @@ class Scheduler:
             self.max_running_requests,
             self.max_req_input_len,
             self.random_seed,
+            self.device,
         ) = self.tp_worker.get_token_and_memory_info()
+        self.tp_cpu_group = self.tp_worker.get_tp_cpu_group()
+        self.pad_input_ids_func = self.tp_worker.get_pad_input_ids_func()
         set_random_seed(self.random_seed)
-        self.pad_input_ids_func = getattr(
-            self.tp_worker.model_runner.model, "pad_input_ids", None
-        )
 
         # Print debug info
         logger.info(
@@ -173,9 +185,8 @@ class Scheduler:
             f"context_len={self.model_config.context_len}"
         )
 
-        # Init cache
-        self.req_to_token_pool = self.tp_worker.model_runner.req_to_token_pool
-        self.token_to_kv_pool = self.tp_worker.model_runner.token_to_kv_pool
+        # Init memory pool and cache
+        self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
 
         if (
             server_args.chunked_prefill_size is not None
@@ -253,20 +264,6 @@ class Scheduler:
             with_stack=True,
         )
 
-        # Init states for overlap schedule
-        if self.server_args.enable_overlap_schedule:
-            self.forward_batch_generation = (
-                self.tp_worker.forward_batch_generation_non_blocking
-            )
-            self.resolve_next_token_ids = (
-                lambda bid, x: self.tp_worker.resolve_future_token_ids(bid)
-            )
-            self.cache_finished_req = self.tree_cache.cache_finished_req
-        else:
-            self.forward_batch_generation = self.tp_worker.forward_batch_generation
-            self.resolve_next_token_ids = lambda bid, x: x.tolist()
-            self.cache_finished_req = self.tree_cache.cache_finished_req
-
     @torch.inference_mode()
     def event_loop_normal(self):
         self.last_batch = None
@@ -779,7 +776,7 @@ class Scheduler:
             req.check_finished()
 
             if req.finished():
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)
             elif not batch.decoding_reqs or req not in batch.decoding_reqs:
                 self.tree_cache.cache_unfinished_req(req)
 
@@ -808,7 +805,7 @@ class Scheduler:
             req.check_finished()
 
             if req.finished():
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)
             else:
                 self.tree_cache.cache_unfinished_req(req)
 
@@ -845,7 +842,7 @@ class Scheduler:
             )
 
             if req.finished():
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)
 
             if req.return_logprob:
                 req.output_token_logprobs.append(
@@ -1069,7 +1066,7 @@ class Scheduler:
         for req in self.running_batch.reqs:
             if req.rid == recv_req.rid and not req.finished():
                 req.finished_reason = FINISH_ABORT()
-                self.cache_finished_req(req)
+                self.tree_cache.cache_finished_req(req)
                 break
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1112,7 +1109,7 @@ def run_scheduler_process(
     suppress_other_loggers()
 
     try:
-        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
         pipe_writer.send("ready")
         if server_args.enable_overlap_schedule:
             scheduler.event_loop_overlap()
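The scheduler side of the change does two things: the overlap-schedule dispatch moves from the end of `__init__` up to just after the worker is constructed, and the `self.cache_finished_req` alias is dropped, since both branches bound it to the same `self.tree_cache.cache_finished_req`; call sites now invoke the tree cache directly. The sketch below illustrates the dispatch pattern being preserved here. `MiniScheduler`, `DummyWorker`, and the list-based token ids are illustrative assumptions for this sketch; only the attribute names (`forward_batch_generation`, `resolve_next_token_ids`, `enable_overlap_schedule`) come from the diff, and the real code works on tensors rather than lists.

```python
# Minimal, runnable sketch of the dispatch set up in Scheduler.__init__.
# DummyWorker and MiniScheduler are hypothetical stand-ins, not the real
# TpModelWorker/Scheduler classes.
from typing import Dict, List


class DummyWorker:
    def __init__(self) -> None:
        self.future_ids: Dict[int, List[int]] = {}

    def forward_batch_generation(self, batch: dict) -> List[int]:
        # Blocking path: the real token ids are available immediately.
        return [1, 2, 3]

    def forward_batch_generation_non_blocking(self, batch: dict) -> List[int]:
        # Overlap path: return placeholders now, resolve the real ids later.
        self.future_ids[batch["bid"]] = [1, 2, 3]
        return [-1, -1, -1]

    def resolve_future_token_ids(self, bid: int) -> List[int]:
        return self.future_ids.pop(bid)


class MiniScheduler:
    def __init__(self, worker: DummyWorker, enable_overlap_schedule: bool) -> None:
        # Bind the forward/resolve pair once at init, so the event loops
        # can call them without re-checking the flag on every batch.
        if enable_overlap_schedule:
            self.forward_batch_generation = (
                worker.forward_batch_generation_non_blocking
            )
            self.resolve_next_token_ids = (
                lambda bid, x: worker.resolve_future_token_ids(bid)
            )
        else:
            self.forward_batch_generation = worker.forward_batch_generation
            self.resolve_next_token_ids = lambda bid, x: x


sched = MiniScheduler(DummyWorker(), enable_overlap_schedule=True)
placeholders = sched.forward_batch_generation({"bid": 0})
print(sched.resolve_next_token_ids(0, placeholders))  # [1, 2, 3]
```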
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 814b2d2cc..bf9155f1e 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -20,6 +20,7 @@ import logging
 import threading
 import time
 from queue import Queue
+from typing import Optional
 
 import torch
 
@@ -40,9 +41,10 @@ class TpModelWorker:
 
     def __init__(
         self,
+        server_args: ServerArgs,
         gpu_id: int,
         tp_rank: int,
-        server_args: ServerArgs,
+        dp_rank: Optional[int],
         nccl_port: int,
     ):
         # Parse args
@@ -116,6 +118,19 @@ class TpModelWorker:
             self.max_running_requests,
             self.max_req_input_len,
             self.random_seed,
+            self.device,
+        )
+
+    def get_pad_input_ids_func(self):
+        return getattr(self.model_runner.model, "pad_input_ids", None)
+
+    def get_tp_cpu_group(self):
+        return self.model_runner.tp_group.cpu_group
+
+    def get_memory_pool(self):
+        return (
+            self.model_runner.req_to_token_pool,
+            self.model_runner.token_to_kv_pool,
         )
 
     def init_overlap_status(self):
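On the worker side, `dp_rank` is threaded through `run_scheduler_process` → `Scheduler` → `TpModelWorker`, and `TpModelWorker` grows narrow accessors (`get_tp_cpu_group`, `get_pad_input_ids_func`, `get_memory_pool`, plus `device` in the `get_token_and_memory_info` tuple) so that `Scheduler` no longer reaches through `tp_worker.model_runner`. A minimal sketch of this facade pattern follows; `RunnerStub` and `WorkerFacade` are assumptions standing in for `ModelRunner` and `TpModelWorker`, and only the getter names come from the diff.

```python
# Illustrative facade over private internals: callers get narrow accessors
# instead of reaching into model_runner directly.
from typing import Callable, Optional, Tuple


class RunnerStub:
    """Stand-in for ModelRunner; holds the pools the scheduler needs."""

    def __init__(self) -> None:
        self.req_to_token_pool = object()
        self.token_to_kv_pool = object()


class WorkerFacade:
    def __init__(self) -> None:
        self.model_runner = RunnerStub()

    def get_pad_input_ids_func(self) -> Optional[Callable]:
        # Returns None when the underlying model defines no pad_input_ids
        # hook, mirroring the getattr(..., None) in the diff.
        return getattr(self.model_runner, "pad_input_ids", None)

    def get_memory_pool(self) -> Tuple[object, object]:
        return (
            self.model_runner.req_to_token_pool,
            self.model_runner.token_to_kv_pool,
        )


worker = WorkerFacade()
req_to_token_pool, token_to_kv_pool = worker.get_memory_pool()
print(worker.get_pad_input_ids_func())  # None for this stub
```

The payoff of the facade is that `Scheduler` depends only on `TpModelWorker`'s public surface, which keeps the worker free to reorganize `model_runner` internals without touching scheduler code.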