diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 70ce381be..f66c30eef 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1703,18 +1703,12 @@ class Scheduler( def save_remote_model(self, params): url = params["url"] - if isinstance(self.tp_worker, TpModelWorkerClient): - worker = self.tp_worker.worker - else: - worker = self.tp_worker + worker = self.tp_worker.worker worker.model_runner.save_remote_model(url) def save_sharded_model(self, params): - if isinstance(self.tp_worker, TpModelWorkerClient): - worker = self.tp_worker.worker - else: - worker = self.tp_worker + worker = self.tp_worker.worker worker.model_runner.save_sharded_model( path=params["path"], diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 21cf7837d..712da3961 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -53,6 +53,8 @@ class TpModelWorker: req_to_token_pool: Optional[ReqToTokenPool] = None, token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None, ): + self.worker = self + # Parse args self.tp_rank = tp_rank