diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 1f42543ab..e19f83f24 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -468,9 +468,7 @@ class Scheduler(
 
         # Hybrid memory pool
         self.is_hybrid = self.tp_worker.is_hybrid
-        self.is_hybrid_gdn = (
-            self.tp_worker.worker.model_runner.hybrid_gdn_config is not None
-        )
+        self.is_hybrid_gdn = self.tp_worker.model_runner.hybrid_gdn_config is not None
 
         if self.is_hybrid:
             self.sliding_window_size = self.tp_worker.sliding_window_size
@@ -1882,7 +1880,7 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            if self.tp_worker.worker.model_runner.mambaish_config is not None:
+            if self.tp_worker.model_runner.mambaish_config is not None:
                 self.req_to_token_pool.free(
                     self.chunked_req.req_pool_idx, free_mamba_cache=False
                 )
@@ -2686,9 +2684,7 @@ class Scheduler(
         ret = vars(get_global_server_args())
         ret["last_gen_throughput"] = self.last_gen_throughput
         ret["memory_usage"] = {
-            "weight": round(
-                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
-            ),
+            "weight": round(self.tp_worker.model_runner.weight_load_mem_usage, 2),
             "kvcache": round(
                 self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
             ),
@@ -2696,7 +2692,7 @@ class Scheduler(
         }
 
         ret["memory_usage"]["graph"] = round(
-            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+            self.tp_worker.model_runner.graph_mem_usage, 2
         )
 
         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
diff --git a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
index fdb7acd64..7552bcce0 100644
--- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py
+++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import logging
-from typing import Tuple
+from typing import TYPE_CHECKING, Tuple
 
 import torch
 
@@ -23,6 +25,9 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqOutput,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Scheduler
+
 logger = logging.getLogger(__name__)
 
 
@@ -79,7 +84,9 @@ class SchedulerUpdateWeightsMixin:
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return GetWeightsByNameReqOutput(parameter)
 
-    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+    def release_memory_occupation(
+        self: Scheduler, recv_req: ReleaseMemoryOccupationReqInput
+    ):
         tags = recv_req.tags
 
         if tags is None or len(tags) == 0:
@@ -94,14 +101,16 @@ class SchedulerUpdateWeightsMixin:
 
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.stashed_model_static_state = _export_static_state(
-                self.tp_worker.worker.model_runner.model
+                self.tp_worker.model_runner.model
             )
             torch.distributed.barrier(self.tp_cpu_group)
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)
 
         return ReleaseMemoryOccupationReqOutput()
 
-    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+    def resume_memory_occupation(
+        self: Scheduler, recv_req: ResumeMemoryOccupationReqInput
+    ):
         tags = recv_req.tags
 
         if tags is None or len(tags) == 0:
@@ -114,7 +123,7 @@ class SchedulerUpdateWeightsMixin:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
             torch.distributed.barrier(self.tp_cpu_group)
             _import_static_state(
-                self.tp_worker.worker.model_runner.model,
+                self.tp_worker.model_runner.model,
                 self.stashed_model_static_state,
             )
             del self.stashed_model_static_state
@@ -124,24 +133,20 @@ class SchedulerUpdateWeightsMixin:
 
         return ResumeMemoryOccupationReqOutput()
 
-    def save_remote_model(self, params):
+    def save_remote_model(self: Scheduler, params):
         url = params["url"]
 
-        worker = self.tp_worker.worker
-        worker.model_runner.save_remote_model(url)
+        self.tp_worker.model_runner.save_remote_model(url)
 
         if self.draft_worker is not None:
             draft_url = params.get("draft_url", None)
             assert (
                 draft_url is not None
             ), "draft_url must be provided when draft model is enabled"
-            draft_worker = self.draft_worker.worker
-            draft_worker.model_runner.save_remote_model(draft_url)
+            self.draft_worker.model_runner.save_remote_model(draft_url)
 
-    def save_sharded_model(self, params):
-        worker = self.tp_worker.worker
-
-        worker.model_runner.save_sharded_model(
+    def save_sharded_model(self: Scheduler, params):
+        self.tp_worker.model_runner.save_sharded_model(
             path=params["path"],
             pattern=params["pattern"],
             max_size=params["max_size"],
diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py
index 267810c06..3485a0357 100644
--- a/python/sglang/srt/managers/tp_worker.py
+++ b/python/sglang/srt/managers/tp_worker.py
@@ -168,9 +168,6 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)
 
-        # A reference make this class has the same member as TpModelWorkerClient
-        self.worker = self
-
         self.hicache_layer_transfer_counter = None
 
     def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):