[Feature] Support loading weights from ckpt engine worker (#11755)
Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com> Signed-off-by: Cruz Zhao <CruzZhao@linux.alibaba.com> Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com> Co-authored-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com> Co-authored-by: Cruz Zhao <CruzZhao@linux.alibaba.com> Co-authored-by: Xuchun Shang <xuchun.shang@gmail.com> Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -89,6 +89,7 @@ test = [
|
||||
"sentence_transformers",
|
||||
"tabulate",
|
||||
]
|
||||
checkpoint-engine = ["checkpoint-engine==0.1.2"]
|
||||
all = []
|
||||
dev = ["sglang[test]"]
|
||||
|
||||
|
||||
142
python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py
Normal file
142
python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Checkpoint-engine integration for SGLang.
|
||||
This module provides weight update functionality via IPC for checkpoint-engine compatibility.
|
||||
"""
|
||||
import logging
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
# checkpoint-engine is an optional dependency; fail fast at import time with
# an actionable install hint when it is missing.
try:
    from checkpoint_engine.worker import update_weights_from_ipc
except ImportError as e:
    # Chain the original error (`from e`) so the real root cause — e.g. a
    # broken partial install vs. a missing package — stays visible.
    raise ImportError(
        "checkpoint-engine is not installed. "
        "Please install it with: pip install sglang[checkpoint-engine]"
    ) from e
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SGLangCheckpointEngineWorkerExtension:
    """Worker extension that lets SGLang receive weights over checkpoint-engine IPC.

    Subclasses supply the device/loader accessors; this base class only drives
    the IPC handshake in :meth:`update_weights_from_ipc`.
    """

    def __init__(self):
        # Created lazily on the first IPC update so that constructing the
        # extension allocates no ZMQ resources.
        self._zmq_ctx: Optional[zmq.Context] = None

    def get_device_uuid(self) -> str:
        """Return the UUID string identifying the current device."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_device_id(self) -> int:
        """Return the local device index."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_model_loader(self) -> Callable:
        """Return the callable that loads named weights into the model."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_post_hook(self) -> Optional[Callable]:
        """Return an optional callback to run after the weights are loaded."""
        return None

    def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
        """Receive updated weights via checkpoint-engine IPC.

        Args:
            zmq_handles: Dict mapping device UUID to ZMQ socket path.

        Raises:
            ValueError: If this device's UUID has no socket entry.
        """
        if self._zmq_ctx is None:
            self._zmq_ctx = zmq.Context()
        device_uuid = self.get_device_uuid()
        device_id = self.get_device_id()
        if device_uuid not in zmq_handles:
            raise ValueError(
                f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
            )
        # Delegate the actual transfer to checkpoint-engine's worker routine.
        update_weights_from_ipc(
            self._zmq_ctx,
            zmq_handles[device_uuid],
            device_id=device_id,
            run=self.get_model_loader(),
            post_hook=self.get_post_hook(),
        )
|
||||
|
||||
|
||||
class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
    """Concrete checkpoint-engine worker extension bound to SGLang's model runner.

    Resolves devices through torch.cuda and loads weights through the
    runner's model object.
    """

    def __init__(self, model_runner):
        super().__init__()
        self.model_runner = model_runner

    def get_device_uuid(self) -> str:
        """Return the current CUDA device's UUID in "GPU-<uuid>" form."""
        device_id = torch.cuda.current_device()
        try:
            return f"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}"
        except AssertionError as e:
            # torch asserts internally on an invalid/unavailable device.
            raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e

    def get_device_id(self) -> int:
        """Return the current CUDA device index."""
        return torch.cuda.current_device()

    def get_model_loader(self) -> Callable:
        """Return the model's bound ``load_weights`` method."""
        return self.model_runner.model.load_weights

    def get_post_hook(self) -> Optional[Callable]:
        """Return a callback that finalizes (quantized) weights after loading."""

        def post_hook():
            # Mirror DefaultModelLoader's post-processing: let each module's
            # quant method fix up the freshly loaded weights on-device, then
            # give the model a chance to run its own hook.
            try:
                from sglang.srt.model_loader.loader import device_loading_context

                for _, module in self.model_runner.model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is None:
                        continue
                    target_device = torch.device(
                        "cuda", torch.cuda.current_device()
                    )
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
                if hasattr(self.model_runner.model, "post_load_weights"):
                    self.model_runner.model.post_load_weights()
            except Exception as e:
                # Best-effort: the weight load itself succeeded, so only warn.
                logger.warning(f"Post-hook processing failed: {e}")

        return post_hook
|
||||
@@ -59,6 +59,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromIPCReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
|
||||
@@ -649,6 +650,21 @@ class Engine(EngineBase):
|
||||
request=None,
|
||||
)
|
||||
|
||||
def update_weights_from_ipc(
    self,
    zmq_handles: Dict[str, str],
    flush_cache: bool = True,
):
    """Update weights from IPC for checkpoint-engine integration.

    Args:
        zmq_handles: Mapping from device UUID to ZMQ socket path.
        flush_cache: Whether to flush the cache after updating weights.

    Returns:
        The (success, message) result from the tokenizer manager.
    """
    req = UpdateWeightsFromIPCReqInput(
        zmq_handles=zmq_handles,
        flush_cache=flush_cache,
    )
    # Engine exposes a synchronous API, so drive the async call to completion
    # on the current event loop.
    loop = asyncio.get_event_loop()
    return loop.run_until_complete(
        self.tokenizer_manager.update_weights_from_ipc(req, None)
    )
|
||||
|
||||
|
||||
def _set_envs_and_config(server_args: ServerArgs):
|
||||
# Set global environments
|
||||
|
||||
@@ -96,6 +96,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromIPCReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
UpdateWeightVersionReqInput,
|
||||
VertexGenerateReqInput,
|
||||
@@ -129,6 +130,7 @@ logger = logging.getLogger(__name__)
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
||||
WAIT_WEIGHTS_READY_TIMEOUT = int(os.getenv("SGLANG_WAIT_WEIGHTS_READY_TIMEOUT", 120))
|
||||
|
||||
|
||||
# Store global states
|
||||
@@ -838,6 +840,27 @@ async def update_weights_from_distributed(
|
||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
@app.post("/update_weights_from_ipc")
async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Request):
    """Update the weights from IPC (Inter-Process Communication) for checkpoint-engine integration."""
    success, message = await _global_state.tokenizer_manager.update_weights_from_ipc(
        obj, request
    )

    # Optionally bump the weight version once the update has succeeded.
    if success and obj.weight_version is not None:
        _update_weight_version_if_provided(obj.weight_version)
        message += f" Weight version updated to {obj.weight_version}."

    content = {"success": success, "message": message}
    if not success:
        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
    # The first successful IPC update marks the initial weights as loaded,
    # unblocking a server started with --checkpoint-engine-wait-weights-before-ready.
    if _global_state.tokenizer_manager.initial_weights_loaded is False:
        _global_state.tokenizer_manager.initial_weights_loaded = True
    return ORJSONResponse(content)
|
||||
|
||||
|
||||
@app.post("/update_weight_version")
|
||||
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
|
||||
"""Update the weight version. This operation requires no active requests."""
|
||||
@@ -1530,6 +1553,8 @@ def _wait_and_warmup(
|
||||
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
|
||||
launch_callback: Optional[Callable[[], None]] = None,
|
||||
):
|
||||
if server_args.checkpoint_engine_wait_weights_before_ready:
|
||||
_wait_weights_ready()
|
||||
if not server_args.skip_server_warmup:
|
||||
if not _execute_server_warmup(
|
||||
server_args,
|
||||
@@ -1552,3 +1577,24 @@ def _wait_and_warmup(
|
||||
|
||||
if launch_callback is not None:
|
||||
launch_callback()
|
||||
|
||||
|
||||
def _wait_weights_ready():
    """Block until the tokenizer manager reports the initial weights as loaded.

    Polls once per second for up to SGLANG_WAIT_WEIGHTS_READY_TIMEOUT seconds.
    On timeout it logs an error but does not raise, so the server still starts
    and can accept a late weight push.
    """
    timeout = WAIT_WEIGHTS_READY_TIMEOUT
    start_time = time.time()

    # Poll against a wall-clock deadline rather than counting iterations:
    # each iteration costs slightly more than the 1s sleep (flag check,
    # logging), so iteration counting would overshoot the configured timeout.
    while time.time() - start_time < timeout:
        if _global_state.tokenizer_manager.initial_weights_loaded:
            logger.info(
                f"Weights are ready after {time.time() - start_time:.2f} seconds"
            )
            return
        time.sleep(1)

    # Timeout reached without weights being ready
    logger.error(
        f"Weights are not ready after waiting {timeout} seconds. "
        f"Consider increasing SGLANG_WAIT_WEIGHTS_READY_TIMEOUT environment variable. "
        f"Current status: initial_weights_loaded={_global_state.tokenizer_manager.initial_weights_loaded}"
    )
|
||||
|
||||
@@ -1080,6 +1080,24 @@ class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
|
||||
backend: str = "nccl"
|
||||
|
||||
|
||||
# UpdateWeightsFromIPCReqInput and UpdateWeightsFromIPCReqOutput are only used
# by Checkpoint Engine (https://github.com/MoonshotAI/checkpoint-engine).
@dataclass
class UpdateWeightsFromIPCReqInput(BaseReq):
    """Request to update model weights via checkpoint-engine IPC."""

    # ZMQ socket paths for each device UUID
    zmq_handles: Dict[str, str]
    # Whether to flush cache after weight update
    flush_cache: bool = True
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
class UpdateWeightsFromIPCReqOutput(BaseReq):
    """Result of an IPC weight update."""

    # Whether the update completed successfully.
    success: bool
    # Human-readable status or error detail.
    message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
|
||||
success: bool
|
||||
|
||||
@@ -109,6 +109,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UnloadLoRAAdapterReqOutput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromIPCReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.mm_utils import init_embedding_cache
|
||||
@@ -530,6 +531,7 @@ class Scheduler(
|
||||
self.update_weights_from_distributed,
|
||||
),
|
||||
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
|
||||
(UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
|
||||
(GetWeightsByNameReqInput, self.get_weights_by_name),
|
||||
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
|
||||
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
||||
|
||||
@@ -21,6 +21,8 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
UpdateWeightsFromIPCReqInput,
|
||||
UpdateWeightsFromIPCReqOutput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
)
|
||||
@@ -80,6 +82,18 @@ class SchedulerUpdateWeightsMixin:
|
||||
torch.distributed.barrier(group=self.tp_cpu_group)
|
||||
return UpdateWeightsFromTensorReqOutput(success, message)
|
||||
|
||||
def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
    """Update the online model parameter from IPC for checkpoint-engine integration."""
    success, message = self.tp_worker.update_weights_from_ipc(recv_req)
    if not success:
        logger.error(message)
    elif recv_req.flush_cache:
        # Cached prefill state is stale once the weights change.
        flush_cache_success = self.flush_cache()
        assert flush_cache_success, "Cache flush failed after updating weights"
    # Keep all TP ranks in lockstep before replying to the tokenizer manager.
    torch.distributed.barrier(group=self.tp_cpu_group)
    return UpdateWeightsFromIPCReqOutput(success, message)
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
||||
return GetWeightsByNameReqOutput(parameter)
|
||||
|
||||
@@ -63,6 +63,8 @@ from sglang.srt.managers.io_struct import (
|
||||
UnloadLoRAAdapterReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
UpdateWeightsFromIPCReqInput,
|
||||
UpdateWeightsFromIPCReqOutput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
)
|
||||
@@ -169,6 +171,9 @@ class TokenizerCommunicatorMixin:
|
||||
self.update_weights_from_tensor_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.update_weights_from_ipc_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.get_weights_by_name_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
@@ -235,6 +240,10 @@ class TokenizerCommunicatorMixin:
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
self.update_weights_from_tensor_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
UpdateWeightsFromIPCReqOutput,
|
||||
self.update_weights_from_ipc_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
GetWeightsByNameReqOutput,
|
||||
self.get_weights_by_name_communicator.handle_recv,
|
||||
@@ -442,6 +451,28 @@ class TokenizerCommunicatorMixin:
|
||||
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
|
||||
async def update_weights_from_ipc(
    self,
    obj: UpdateWeightsFromIPCReqInput,
    request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
    """Update weights via IPC for checkpoint-engine integration.

    Returns:
        (success, message) describing the update outcome; failures are
        reported rather than raised.
    """
    self.auto_create_handle_loop()
    try:
        # Only a single data-parallel instance (or DP attention) is supported.
        assert (
            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
        ), "dp_size must be 1 or dp attention must be enabled for update weights from IPC"
        logger.info("Starting IPC weight update")
        # Holding the writer lock means weight sync cannot run while
        # generation requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_ipc_communicator(obj))[0]
            return result.success, result.message
    except Exception as e:
        error_msg = f"IPC weight update failed: {str(e)}"
        logger.error(error_msg)
        return False, error_msg
|
||||
|
||||
async def load_lora_adapter(
|
||||
self: TokenizerManager,
|
||||
obj: LoadLoRAAdapterReqInput,
|
||||
|
||||
@@ -284,6 +284,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
self.gracefully_exit = False
|
||||
self.last_receive_tstamp = 0
|
||||
|
||||
# Initial weights status
|
||||
self.initial_weights_loaded = True
|
||||
if server_args.checkpoint_engine_wait_weights_before_ready:
|
||||
self.initial_weights_loaded = False
|
||||
|
||||
# Dumping
|
||||
self.dump_requests_folder = "" # By default do not dump
|
||||
self.dump_requests_threshold = 1000
|
||||
|
||||
@@ -32,6 +32,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromIPCReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
@@ -164,6 +165,11 @@ class BaseTpWorker(ABC):
|
||||
)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
    """Update weights from IPC for checkpoint-engine integration.

    Thin delegation: forwards the request to the model runner and returns
    its (success, message) pair unchanged.
    """
    return self.model_runner.update_weights_from_ipc(recv_req)
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.model_runner.get_weights_by_name(
|
||||
recv_req.name, recv_req.truncate_size
|
||||
|
||||
@@ -2387,6 +2387,23 @@ class ModelRunner:
|
||||
)
|
||||
ShardedStateLoader.save_model(self.model, path, pattern, max_size)
|
||||
|
||||
def update_weights_from_ipc(self, recv_req):
    """Update weights from IPC for checkpoint-engine integration.

    Returns:
        A (success, message) tuple; errors are captured and reported
        rather than propagated.
    """
    try:
        from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
            SGLangCheckpointEngineWorkerExtensionImpl,
        )

        # Bridge this runner's model into the checkpoint-engine worker API
        # and drive the IPC transfer.
        extension = SGLangCheckpointEngineWorkerExtensionImpl(self)
        extension.update_weights_from_ipc(recv_req.zmq_handles)
        return True, "IPC weight update completed successfully"
    except ImportError as e:
        # checkpoint-engine (or its SGLang glue module) is optional.
        return False, f"IPC weight update failed: ImportError {e}"
    except Exception as e:
        logger.error(f"IPC weight update failed: {e}")
        return False, str(e)
|
||||
|
||||
|
||||
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(model.named_parameters())
|
||||
|
||||
@@ -208,6 +208,7 @@ class ServerArgs:
|
||||
skip_server_warmup: bool = False
|
||||
warmups: Optional[str] = None
|
||||
nccl_port: Optional[int] = None
|
||||
checkpoint_engine_wait_weights_before_ready: bool = False
|
||||
|
||||
# Quantization and data type
|
||||
dtype: str = "auto"
|
||||
@@ -1704,6 +1705,12 @@ class ServerArgs:
|
||||
default=ServerArgs.nccl_port,
|
||||
help="The port for NCCL distributed environment setup. Defaults to a random port.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint-engine-wait-weights-before-ready",
|
||||
action="store_true",
|
||||
help="If set, the server will wait for initial weights to be loaded via checkpoint-engine or other update methods "
|
||||
"before serving inference requests.",
|
||||
)
|
||||
|
||||
# Quantization and data type
|
||||
parser.add_argument(
|
||||
|
||||
@@ -2275,6 +2275,11 @@ def launch_dummy_health_check_server(host, port, enable_metrics):
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/ping")
|
||||
async def ping():
|
||||
"""Could be used by the checkpoint-engine update script to confirm the server is up."""
|
||||
return Response(status_code=200)
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Check the health of the http server."""
|
||||
|
||||
Reference in New Issue
Block a user