168 lines
6.1 KiB
Python
168 lines
6.1 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
from typing import TYPE_CHECKING, Tuple
|
|
|
|
import torch
|
|
|
|
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS
|
|
from sglang.srt.managers.io_struct import (
|
|
DestroyWeightsUpdateGroupReqInput,
|
|
DestroyWeightsUpdateGroupReqOutput,
|
|
GetWeightsByNameReqInput,
|
|
GetWeightsByNameReqOutput,
|
|
InitWeightsUpdateGroupReqInput,
|
|
InitWeightsUpdateGroupReqOutput,
|
|
ReleaseMemoryOccupationReqInput,
|
|
ReleaseMemoryOccupationReqOutput,
|
|
ResumeMemoryOccupationReqInput,
|
|
ResumeMemoryOccupationReqOutput,
|
|
UpdateWeightFromDiskReqInput,
|
|
UpdateWeightFromDiskReqOutput,
|
|
UpdateWeightsFromDistributedReqInput,
|
|
UpdateWeightsFromDistributedReqOutput,
|
|
UpdateWeightsFromTensorReqInput,
|
|
UpdateWeightsFromTensorReqOutput,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from sglang.srt.managers.scheduler import Scheduler
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SchedulerUpdateWeightsMixin:
    """Scheduler mixin handling weight-update and memory-occupation requests.

    The methods below read the following attributes from the host object:
    ``tp_worker``, ``draft_worker``, ``tp_cpu_group``, ``memory_saver_adapter``,
    ``offload_tags`` and ``stashed_model_static_state``, and call its
    ``flush_cache()`` method. None of them are defined in this mixin.
    """

    def _flush_cache_after_weight_update(
        self, success: bool, message: str, flush_cache: bool = True
    ) -> None:
        """Shared post-processing for the ``update_weights_from_*`` handlers.

        Args:
            success: Whether the worker reported a successful weight update.
            message: Worker-provided status text; logged at error level when
                ``success`` is false.
            flush_cache: Whether the cache should be flushed after a
                successful update.

        Raises:
            AssertionError: If a requested cache flush reports failure.
        """
        if not success:
            logger.error(message)
            return
        if flush_cache:
            flush_cache_success = self.flush_cache()
            assert flush_cache_success, "Cache flush failed after updating weights"

    def update_weights_from_disk(
        self, recv_req: UpdateWeightFromDiskReqInput
    ) -> UpdateWeightFromDiskReqOutput:
        """In-place update of the weights from disk.

        The cache is always flushed after a successful update; failures are
        logged and reported back through the returned output object.
        """
        success, message = self.tp_worker.update_weights_from_disk(recv_req)
        self._flush_cache_after_weight_update(success, message)
        # The third positional field is always reported as 0 from here; its
        # meaning is defined by UpdateWeightFromDiskReqOutput.
        return UpdateWeightFromDiskReqOutput(success, message, 0)

    def init_weights_update_group(
        self, recv_req: InitWeightsUpdateGroupReqInput
    ) -> InitWeightsUpdateGroupReqOutput:
        """Initialize the online model parameter update group."""
        success, message = self.tp_worker.init_weights_update_group(recv_req)
        return InitWeightsUpdateGroupReqOutput(success, message)

    def destroy_weights_update_group(
        self, recv_req: DestroyWeightsUpdateGroupReqInput
    ) -> DestroyWeightsUpdateGroupReqOutput:
        """Destroy the online model parameter update group."""
        success, message = self.tp_worker.destroy_weights_update_group(recv_req)
        return DestroyWeightsUpdateGroupReqOutput(success, message)

    def update_weights_from_distributed(
        self,
        recv_req: UpdateWeightsFromDistributedReqInput,
    ) -> UpdateWeightsFromDistributedReqOutput:
        """Update the online model parameter.

        The cache is flushed after a successful update only when
        ``recv_req.flush_cache`` is truthy.
        """
        success, message = self.tp_worker.update_weights_from_distributed(recv_req)
        if success:
            # Only touch recv_req.flush_cache on the success path, as before.
            self._flush_cache_after_weight_update(
                success, message, recv_req.flush_cache
            )
        else:
            self._flush_cache_after_weight_update(success, message)
        return UpdateWeightsFromDistributedReqOutput(success, message)

    def update_weights_from_tensor(
        self, recv_req: UpdateWeightsFromTensorReqInput
    ) -> UpdateWeightsFromTensorReqOutput:
        """Update the online model parameter from tensors.

        The cache is flushed after a successful update only when
        ``recv_req.flush_cache`` is truthy. A barrier on ``tp_cpu_group`` is
        issued before returning, on both the success and the failure path.
        """
        success, message = self.tp_worker.update_weights_from_tensor(recv_req)
        if success:
            # Only touch recv_req.flush_cache on the success path, as before.
            self._flush_cache_after_weight_update(
                success, message, recv_req.flush_cache
            )
        else:
            self._flush_cache_after_weight_update(success, message)
        torch.distributed.barrier(group=self.tp_cpu_group)
        return UpdateWeightsFromTensorReqOutput(success, message)

    def get_weights_by_name(
        self, recv_req: GetWeightsByNameReqInput
    ) -> GetWeightsByNameReqOutput:
        """Fetch a parameter through the TP worker and wrap it in the reply type."""
        parameter = self.tp_worker.get_weights_by_name(recv_req)
        return GetWeightsByNameReqOutput(parameter)

    def release_memory_occupation(
        self: Scheduler, recv_req: ReleaseMemoryOccupationReqInput
    ) -> ReleaseMemoryOccupationReqOutput:
        """Pause the GPU memory regions named by ``recv_req.tags``.

        When ``tags`` is ``None`` or empty, both the weights and the KV cache
        are released. Every requested tag is recorded in ``self.offload_tags``.
        """
        tags = recv_req.tags
        if not tags:
            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

        for tag in tags:
            self.offload_tags.add(tag)

        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
            self.flush_cache()

        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            # Keep a cloned copy of the model buffers so that
            # resume_memory_occupation can write them back.
            # NOTE(review): presumably buffer contents are not preserved while
            # the weights region is paused -- confirm against the memory saver.
            self.stashed_model_static_state = _export_static_state(
                self.tp_worker.model_runner.model
            )
            torch.distributed.barrier(group=self.tp_cpu_group)
            self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_WEIGHTS)

        return ReleaseMemoryOccupationReqOutput()

    def resume_memory_occupation(
        self: Scheduler, recv_req: ResumeMemoryOccupationReqInput
    ) -> ResumeMemoryOccupationReqOutput:
        """Resume the GPU memory regions named by ``recv_req.tags``.

        When ``tags`` is ``None`` or empty, both the weights and the KV cache
        are resumed. Every requested tag is removed from ``self.offload_tags``
        first, so a tag that is not currently recorded there makes ``remove``
        raise before any memory is touched.
        """
        tags = recv_req.tags
        if not tags:
            tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

        for tag in tags:
            self.offload_tags.remove(tag)

        if GPU_MEMORY_TYPE_WEIGHTS in tags:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
            torch.distributed.barrier(group=self.tp_cpu_group)
            # Write the buffers stashed by release_memory_occupation back into
            # the model, then drop the stash.
            _import_static_state(
                self.tp_worker.model_runner.model,
                self.stashed_model_static_state,
            )
            del self.stashed_model_static_state

        if GPU_MEMORY_TYPE_KV_CACHE in tags:
            self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_KV_CACHE)

        return ResumeMemoryOccupationReqOutput()

    def save_remote_model(self: Scheduler, params):
        """Save the model (and the draft model, if any) to remote URLs.

        Args:
            params: Mapping with a required ``"url"`` key and, when a draft
                worker is present, a ``"draft_url"`` key.

        Raises:
            KeyError: If ``"url"`` is missing.
            AssertionError: If a draft worker is present but ``"draft_url"``
                is missing or ``None``.
        """
        url = params["url"]

        # Validate the whole request before saving anything, so a missing
        # draft_url cannot leave the main model saved and the draft model not.
        has_draft = self.draft_worker is not None
        draft_url = None
        if has_draft:
            draft_url = params.get("draft_url", None)
            assert (
                draft_url is not None
            ), "draft_url must be provided when draft model is enabled"

        self.tp_worker.model_runner.save_remote_model(url)

        if has_draft:
            self.draft_worker.model_runner.save_remote_model(draft_url)

    def save_sharded_model(self: Scheduler, params):
        """Save the model in sharded form via the TP worker's model runner.

        Args:
            params: Mapping with required ``"path"``, ``"pattern"`` and
                ``"max_size"`` keys, forwarded as keyword arguments.
        """
        self.tp_worker.model_runner.save_sharded_model(
            path=params["path"],
            pattern=params["pattern"],
            max_size=params["max_size"],
        )
|
|
|
|
|
|
def _export_static_state(model):
|
|
return dict(
|
|
buffers=[
|
|
(name, buffer.detach().clone()) for name, buffer in model.named_buffers()
|
|
]
|
|
)
|
|
|
|
|
|
def _import_static_state(model, static_params):
|
|
self_named_buffers = dict(model.named_buffers())
|
|
for name, tensor in static_params["buffers"]:
|
|
self_named_buffers[name][...] = tensor
|