def launch_draft_worker(
    self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
):
    """Create the speculative-decoding draft worker for this scheduler.

    Selects the worker implementation from ``self.spec_algorithm`` (EAGLE,
    standalone, or n-gram) and stores it on ``self.draft_worker``; when no
    speculative algorithm is configured, ``self.draft_worker`` is set to
    ``None``.

    Args:
        gpu_id: GPU device id the worker runs on.
        tp_rank: Tensor-parallel rank of this scheduler.
        moe_ep_rank: MoE expert-parallel rank.
        server_args: Global server configuration object.
        port_args: Port configuration; only ``port_args.nccl_port`` is read.
        dp_rank: Data-parallel rank (may be ``None``).
    """
    # Imports stay inside each branch so only the selected backend's
    # module is loaded. All three workers share the same constructor
    # signature, so we pick the class first and build it exactly once
    # instead of repeating the call in every branch.
    if self.spec_algorithm.is_eagle():
        from sglang.srt.speculative.eagle_worker import EAGLEWorker as worker_cls
    elif self.spec_algorithm.is_standalone():
        from sglang.srt.speculative.standalone_worker import (
            StandaloneWorker as worker_cls,
        )
    elif self.spec_algorithm.is_ngram():
        from sglang.srt.speculative.ngram_worker import NGRAMWorker as worker_cls
    else:
        # No speculative decoding configured.
        self.draft_worker = None
        return

    self.draft_worker = worker_cls(
        gpu_id=gpu_id,
        tp_rank=tp_rank,
        moe_ep_rank=moe_ep_rank,
        server_args=server_args,
        nccl_port=port_args.nccl_port,
        target_worker=self.tp_worker,
        dp_rank=dp_rank,
    )