Remove tp_worker.worker (#11548)
@@ -468,9 +468,7 @@ class Scheduler(

         # Hybrid memory pool
         self.is_hybrid = self.tp_worker.is_hybrid
-        self.is_hybrid_gdn = (
-            self.tp_worker.worker.model_runner.hybrid_gdn_config is not None
-        )
+        self.is_hybrid_gdn = self.tp_worker.model_runner.hybrid_gdn_config is not None

         if self.is_hybrid:
             self.sliding_window_size = self.tp_worker.sliding_window_size
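Every call site in this hunk drops the intermediate `.worker` hop: `self.tp_worker.model_runner` replaces `self.tp_worker.worker.model_runner`. For that to keep working whether `tp_worker` is a plain `TpModelWorker` or a `TpModelWorkerClient` wrapper, the wrapper has to expose `model_runner` itself. A minimal sketch of that delegation pattern, assuming a forwarding property (the class below is illustrative, not the actual SGLang implementation):

# Hypothetical sketch: a wrapper forwards model_runner to its inner
# worker, so callers can write tp_worker.model_runner for either class.
class WorkerWrapper:
    def __init__(self, inner):
        self._inner = inner  # e.g. a TpModelWorker instance

    @property
    def model_runner(self):
        # Delegate instead of forcing callers to spell out the inner hop.
        return self._inner.model_runner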
@@ -1882,7 +1880,7 @@ class Scheduler(
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            if self.tp_worker.worker.model_runner.mambaish_config is not None:
+            if self.tp_worker.model_runner.mambaish_config is not None:
                 self.req_to_token_pool.free(
                     self.chunked_req.req_pool_idx, free_mamba_cache=False
                 )
@@ -2686,9 +2684,7 @@ class Scheduler(
         ret = vars(get_global_server_args())
         ret["last_gen_throughput"] = self.last_gen_throughput
         ret["memory_usage"] = {
-            "weight": round(
-                self.tp_worker.worker.model_runner.weight_load_mem_usage, 2
-            ),
+            "weight": round(self.tp_worker.model_runner.weight_load_mem_usage, 2),
             "kvcache": round(
                 self.token_to_kv_pool_allocator.get_kvcache().mem_usage, 2
             ),
@@ -2696,7 +2692,7 @@ class Scheduler(
         }

         ret["memory_usage"]["graph"] = round(
-            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+            self.tp_worker.model_runner.graph_mem_usage, 2
         )

         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
@@ -1,5 +1,7 @@
 from __future__ import annotations

 import logging
-from typing import Tuple
+from typing import TYPE_CHECKING, Tuple

 import torch
@@ -23,6 +25,9 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqOutput,
 )

+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Scheduler
+
 logger = logging.getLogger(__name__)
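The `TYPE_CHECKING` guard makes `Scheduler` visible to the type checker only: combined with `from __future__ import annotations`, the mixin methods below can annotate `self: Scheduler` without importing the scheduler module at runtime, which would be a circular import. The general idiom as a standalone example (module path is hypothetical, not SGLang's):

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static analyzers, never executed at runtime.
    from myapp.scheduler import Scheduler  # hypothetical path

class UpdateWeightsMixin:
    def release_memory(self: Scheduler) -> None:
        # The checker now treats self as a Scheduler, so attributes the
        # mixin itself never defines (e.g. self.tp_worker) type-check.
        ...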
@@ -79,7 +84,9 @@ class SchedulerUpdateWeightsMixin:
         parameter = self.tp_worker.get_weights_by_name(recv_req)
         return GetWeightsByNameReqOutput(parameter)

-    def release_memory_occupation(self, recv_req: ReleaseMemoryOccupationReqInput):
+    def release_memory_occupation(
+        self: Scheduler, recv_req: ReleaseMemoryOccupationReqInput
+    ):
         tags = recv_req.tags

         if tags is None or len(tags) == 0:
@@ -94,14 +101,16 @@ class SchedulerUpdateWeightsMixin:

         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.stashed_model_static_state = _export_static_state(
-                self.tp_worker.worker.model_runner.model
+                self.tp_worker.model_runner.model
             )
             torch.distributed.barrier(self.tp_cpu_group)
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

         return ReleaseMemoryOccupationReqOutput()

-    def resume_memory_occupation(self, recv_req: ResumeMemoryOccupationReqInput):
+    def resume_memory_occupation(
+        self: Scheduler, recv_req: ResumeMemoryOccupationReqInput
+    ):
         tags = recv_req.tags

         if tags is None or len(tags) == 0:
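Around the barrier, the release path stashes the model's "static" state before GPU weight memory is paused, and the resume path (next hunk) imports it back. A minimal sketch of what such export/import helpers could look like, assuming they snapshot module buffers to CPU (the bodies are an assumption, not SGLang's actual `_export_static_state`/`_import_static_state`):

import torch

def export_static_state(model: torch.nn.Module) -> dict:
    # Keep CPU copies of buffers (running stats, caches, ...) so they
    # survive while the GPU weight memory is released.
    return {
        name: buf.detach().to("cpu", copy=True)
        for name, buf in model.named_buffers()
    }

def import_static_state(model: torch.nn.Module, stashed: dict) -> None:
    # Copy the stashed values back into the freshly restored buffers.
    for name, buf in model.named_buffers():
        if name in stashed:
            buf.copy_(stashed[name].to(buf.device))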
@@ -114,7 +123,7 @@ class SchedulerUpdateWeightsMixin:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
             torch.distributed.barrier(self.tp_cpu_group)
             _import_static_state(
-                self.tp_worker.worker.model_runner.model,
+                self.tp_worker.model_runner.model,
                 self.stashed_model_static_state,
             )
             del self.stashed_model_static_state
@@ -124,24 +133,20 @@ class SchedulerUpdateWeightsMixin:

         return ResumeMemoryOccupationReqOutput()

-    def save_remote_model(self, params):
+    def save_remote_model(self: Scheduler, params):
         url = params["url"]

-        worker = self.tp_worker.worker
-        worker.model_runner.save_remote_model(url)
+        self.tp_worker.model_runner.save_remote_model(url)

         if self.draft_worker is not None:
             draft_url = params.get("draft_url", None)
             assert (
                 draft_url is not None
             ), "draft_url must be provided when draft model is enabled"
-            draft_worker = self.draft_worker.worker
-            draft_worker.model_runner.save_remote_model(draft_url)
+            self.draft_worker.model_runner.save_remote_model(draft_url)

-    def save_sharded_model(self, params):
-        worker = self.tp_worker.worker
-
-        worker.model_runner.save_sharded_model(
+    def save_sharded_model(self: Scheduler, params):
+        self.tp_worker.model_runner.save_sharded_model(
             path=params["path"],
             pattern=params["pattern"],
             max_size=params["max_size"],
@@ -168,9 +168,6 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)

-        # A reference make this class has the same member as TpModelWorkerClient
-        self.worker = self
-
         self.hicache_layer_transfer_counter = None

     def register_hicache_layer_transfer_counter(self, counter: LayerDoneCounter):
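The removed self-alias was the whole reason the `.worker` hop existed: `TpModelWorker` pointed `worker` back at itself so it exposed the same member as the `TpModelWorkerClient` wrapper, letting callers uniformly write `tp_worker.worker.model_runner`. A toy illustration of why the alias becomes dead code once call sites address the object directly (standalone example, not SGLang code):

class Worker:
    def __init__(self):
        self.value = 42
        self.worker = self  # self-alias: worker.worker is worker

w = Worker()
assert w.worker.value == w.value  # both paths reach the same attribute
# Once every caller writes w.value directly, the alias serves no one
# and can be deleted, which is what this commit does.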