[Feature] Support loading weights from ckpt engine worker (#11755)

Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
Signed-off-by: Cruz Zhao <CruzZhao@linux.alibaba.com>
Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com>
Co-authored-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
Co-authored-by: Cruz Zhao <CruzZhao@linux.alibaba.com>
Co-authored-by: Xuchun Shang <xuchun.shang@gmail.com>
Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
Teng Ma
2025-10-24 00:23:30 +08:00
committed by GitHub
parent b0b4f71679
commit 96a5e4dd79
15 changed files with 552 additions and 0 deletions

View File

@@ -89,6 +89,7 @@ test = [
"sentence_transformers",
"tabulate",
]
checkpoint-engine = ["checkpoint-engine==0.1.2"]
all = []
dev = ["sglang[test]"]

View File

@@ -0,0 +1,142 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Checkpoint-engine integration for SGLang.
This module provides weight update functionality via IPC for checkpoint-engine compatibility.
"""
import logging
from typing import Callable, Dict, Optional
import torch
import zmq
try:
from checkpoint_engine.worker import update_weights_from_ipc
except ImportError:
raise ImportError(
"checkpoint-engine is not installed. "
"Please install it with: pip install sglang[checkpoint-engine]"
)
logger = logging.getLogger(__name__)
class SGLangCheckpointEngineWorkerExtension:
    """Interface expected by checkpoint-engine for IPC weight updates.

    Subclasses supply the device identity and the weight-loader callable;
    ``update_weights_from_ipc`` then drives the transfer through the
    checkpoint-engine library.
    """

    def __init__(self):
        # The ZMQ context is created lazily on the first update call.
        self._zmq_ctx: Optional[zmq.Context] = None

    def get_device_uuid(self) -> str:
        """Get the UUID of current device."""
        # Concrete subclasses (the SGLang integration) must provide this.
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_device_id(self) -> int:
        """Get the device ID."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_model_loader(self) -> Callable:
        """Get the model weight loader function."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_post_hook(self) -> Optional[Callable]:
        """Get the post-processing hook after weight loading."""
        # Base implementation: no post-processing.
        return None

    def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
        """Receive updated weights over IPC.

        Args:
            zmq_handles: Dict mapping device UUID to ZMQ socket path
        """
        # Reuse one context across updates; create it only when needed.
        self._zmq_ctx = self._zmq_ctx or zmq.Context()
        device_uuid = self.get_device_uuid()
        device_id = self.get_device_id()
        if device_uuid not in zmq_handles:
            raise ValueError(
                f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
            )
        # Delegate the actual tensor transfer to checkpoint-engine.
        update_weights_from_ipc(
            self._zmq_ctx,
            zmq_handles[device_uuid],
            device_id=device_id,
            run=self.get_model_loader(),
            post_hook=self.get_post_hook(),
        )
class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
    """Concrete checkpoint-engine worker wired to a SGLang model runner.

    Supplies the CUDA device identity, the model's ``load_weights`` entry
    point, and a post-load hook that finalizes quantized modules.
    """

    def __init__(self, model_runner):
        super().__init__()
        # The runner owns the live model whose weights get replaced.
        self.model_runner = model_runner

    def get_device_uuid(self) -> str:
        """Get the UUID of current device."""
        device_id = torch.cuda.current_device()
        try:
            return f"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}"
        except AssertionError as e:
            # torch raises AssertionError when device properties are unavailable.
            raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e

    def get_device_id(self) -> int:
        """Get the device ID."""
        return torch.cuda.current_device()

    def get_model_loader(self) -> Callable:
        """Get the model weight loader function."""
        return self.model_runner.model.load_weights

    def get_post_hook(self) -> Optional[Callable]:
        """Get the post-processing hook after weight loading."""

        def post_hook():
            # Mirror DefaultModelLoader's finalization after weights land.
            try:
                from sglang.srt.model_loader.loader import device_loading_context

                model = self.model_runner.model
                target_device = torch.device("cuda", torch.cuda.current_device())
                # Let every quantized module re-derive its packed weights.
                for _, module in model.named_modules():
                    quant_method = getattr(module, "quant_method", None)
                    if quant_method is None:
                        continue
                    # Temporarily ensure parameters are resident on the target
                    # device while quantization post-processing runs.
                    with device_loading_context(module, target_device):
                        quant_method.process_weights_after_loading(module)
                # Model-specific finishing step, when the model defines one.
                if hasattr(model, "post_load_weights"):
                    model.post_load_weights()
            except Exception as e:
                # Best-effort: a failed post-hook is logged, not fatal.
                logger.warning(f"Post-hook processing failed: {e}")

        return post_hook

View File

@@ -59,6 +59,7 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
@@ -649,6 +650,21 @@ class Engine(EngineBase):
request=None,
)
def update_weights_from_ipc(
    self,
    zmq_handles: Dict[str, str],
    flush_cache: bool = True,
):
    """Update weights from IPC for checkpoint-engine integration.

    Args:
        zmq_handles: mapping of device UUID -> ZMQ socket path.
        flush_cache: whether to flush the cache after the weight update.

    Returns:
        The (success, message) pair produced by the tokenizer manager.
    """
    req = UpdateWeightsFromIPCReqInput(
        zmq_handles=zmq_handles,
        flush_cache=flush_cache,
    )
    # The engine is a synchronous facade; drive the async manager to completion.
    return asyncio.get_event_loop().run_until_complete(
        self.tokenizer_manager.update_weights_from_ipc(req, None)
    )
def _set_envs_and_config(server_args: ServerArgs):
# Set global environments

View File

@@ -96,6 +96,7 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromTensorReqInput,
UpdateWeightVersionReqInput,
VertexGenerateReqInput,
@@ -129,6 +130,7 @@ logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
WAIT_WEIGHTS_READY_TIMEOUT = int(os.getenv("SGLANG_WAIT_WEIGHTS_READY_TIMEOUT", 120))
# Store global states
@@ -838,6 +840,27 @@ async def update_weights_from_distributed(
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
@app.post("/update_weights_from_ipc")
async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Request):
    """Update the weights from IPC (Inter-Process Communication) for checkpoint-engine integration."""
    success, message = await _global_state.tokenizer_manager.update_weights_from_ipc(
        obj, request
    )
    # Update weight version if provided and weights update was successful
    if success and obj.weight_version is not None:
        _update_weight_version_if_provided(obj.weight_version)
        message += f" Weight version updated to {obj.weight_version}."
    content = {"success": success, "message": message}
    if not success:
        return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
    # A successful update means initial weights are present, so a server
    # started with --checkpoint-engine-wait-weights-before-ready can proceed.
    # (Setting unconditionally replaces the redundant `is False` check.)
    _global_state.tokenizer_manager.initial_weights_loaded = True
    return ORJSONResponse(content)
@app.post("/update_weight_version")
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
"""Update the weight version. This operation requires no active requests."""
@@ -1530,6 +1553,8 @@ def _wait_and_warmup(
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
launch_callback: Optional[Callable[[], None]] = None,
):
if server_args.checkpoint_engine_wait_weights_before_ready:
_wait_weights_ready()
if not server_args.skip_server_warmup:
if not _execute_server_warmup(
server_args,
@@ -1552,3 +1577,24 @@ def _wait_and_warmup(
if launch_callback is not None:
launch_callback()
def _wait_weights_ready():
    """Poll once per second until initial weights are loaded, or time out."""
    timeout = WAIT_WEIGHTS_READY_TIMEOUT
    start_time = time.time()
    manager = _global_state.tokenizer_manager
    for _ in range(timeout):
        if manager.initial_weights_loaded:
            logger.info(
                f"Weights are ready after {time.time() - start_time:.2f} seconds"
            )
            return
        time.sleep(1)
    # Timed out: log loudly but let startup continue so the operator can act.
    logger.error(
        f"Weights are not ready after waiting {timeout} seconds. "
        f"Consider increasing SGLANG_WAIT_WEIGHTS_READY_TIMEOUT environment variable. "
        f"Current status: initial_weights_loaded={manager.initial_weights_loaded}"
    )

View File

@@ -1080,6 +1080,24 @@ class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
backend: str = "nccl"
# Now UpdateWeightsFromIPCReqInput and UpdateWeightsFromIPCReqOutput
# are only used by Checkpoint Engine (https://github.com/MoonshotAI/checkpoint-engine)
@dataclass
class UpdateWeightsFromIPCReqInput(BaseReq):
    """Request to update model weights via checkpoint-engine IPC (ZMQ sockets)."""

    # ZMQ socket paths for each device UUID
    zmq_handles: Dict[str, str]
    # Whether to flush cache after weight update
    flush_cache: bool = True
    # Optional: Update weight version along with weights
    weight_version: Optional[str] = None
@dataclass
class UpdateWeightsFromIPCReqOutput(BaseReq):
    """Result of an IPC weight update."""

    # True when the update was applied successfully
    success: bool
    # Detail or error description for logging / the HTTP response
    message: str
@dataclass
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
success: bool

View File

@@ -109,6 +109,7 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.mm_utils import init_embedding_cache
@@ -530,6 +531,7 @@ class Scheduler(
self.update_weights_from_distributed,
),
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
(UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
(GetWeightsByNameReqInput, self.get_weights_by_name),
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),

View File

@@ -21,6 +21,8 @@ from sglang.srt.managers.io_struct import (
UpdateWeightFromDiskReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromIPCReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
@@ -80,6 +82,18 @@ class SchedulerUpdateWeightsMixin:
torch.distributed.barrier(group=self.tp_cpu_group)
return UpdateWeightsFromTensorReqOutput(success, message)
def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
    """Update the online model parameter from IPC for checkpoint-engine integration."""
    success, message = self.tp_worker.update_weights_from_ipc(recv_req)
    if not success:
        logger.error(message)
    elif recv_req.flush_cache:
        # Cached KV entries were computed with the old weights; drop them.
        flush_cache_success = self.flush_cache()
        assert flush_cache_success, "Cache flush failed after updating weights"
    # Keep all TP ranks in lockstep regardless of outcome.
    torch.distributed.barrier(group=self.tp_cpu_group)
    return UpdateWeightsFromIPCReqOutput(success, message)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.tp_worker.get_weights_by_name(recv_req)
return GetWeightsByNameReqOutput(parameter)

View File

@@ -63,6 +63,8 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqOutput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromDistributedReqOutput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromIPCReqOutput,
UpdateWeightsFromTensorReqInput,
UpdateWeightsFromTensorReqOutput,
)
@@ -169,6 +171,9 @@ class TokenizerCommunicatorMixin:
self.update_weights_from_tensor_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.update_weights_from_ipc_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.get_weights_by_name_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
@@ -235,6 +240,10 @@ class TokenizerCommunicatorMixin:
UpdateWeightsFromTensorReqOutput,
self.update_weights_from_tensor_communicator.handle_recv,
),
(
UpdateWeightsFromIPCReqOutput,
self.update_weights_from_ipc_communicator.handle_recv,
),
(
GetWeightsByNameReqOutput,
self.get_weights_by_name_communicator.handle_recv,
@@ -442,6 +451,28 @@ class TokenizerCommunicatorMixin:
result = (await self.update_weights_from_tensor_communicator(obj))[0]
return result.success, result.message
async def update_weights_from_ipc(
    self,
    obj: UpdateWeightsFromIPCReqInput,
    request: Optional[fastapi.Request] = None,
) -> Tuple[bool, str]:
    """Update weights via IPC for checkpoint-engine integration.

    Args:
        obj: the IPC update request (zmq handles, flush flag, version).
        request: originating HTTP request, if any (unused here).

    Returns:
        (success, message) — failures are reported in-band, never raised.
    """
    self.auto_create_handle_loop()
    try:
        # For now, we only support single data parallel instance.
        # NOTE: explicit raise instead of `assert` so the check survives
        # `python -O` (asserts are stripped under optimization); the message
        # is unchanged, so callers see the same (False, error) result.
        if not (self.server_args.dp_size == 1 or self.server_args.enable_dp_attention):
            raise ValueError(
                "dp_size must be 1 or dp attention must be enabled for update weights from IPC"
            )
        logger.info("Starting IPC weight update")
        # Holding the writer lock means weight sync cannot run while
        # inference requests are in progress.
        async with self.model_update_lock.writer_lock:
            result = (await self.update_weights_from_ipc_communicator(obj))[0]
        return result.success, result.message
    except Exception as e:
        error_msg = f"IPC weight update failed: {str(e)}"
        logger.error(error_msg)
        return False, error_msg
async def load_lora_adapter(
self: TokenizerManager,
obj: LoadLoRAAdapterReqInput,

View File

@@ -284,6 +284,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
self.gracefully_exit = False
self.last_receive_tstamp = 0
# Initial weights status
self.initial_weights_loaded = True
if server_args.checkpoint_engine_wait_weights_before_ready:
self.initial_weights_loaded = False
# Dumping
self.dump_requests_folder = "" # By default do not dump
self.dump_requests_threshold = 1000

View File

@@ -32,6 +32,7 @@ from sglang.srt.managers.io_struct import (
UnloadLoRAAdapterReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromIPCReqInput,
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
@@ -164,6 +165,11 @@ class BaseTpWorker(ABC):
)
return success, message
def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
    """Update weights from IPC for checkpoint-engine integration.

    Thin delegation: the model runner performs the actual transfer and
    returns a (success, message) pair, which is passed through unchanged.
    """
    return self.model_runner.update_weights_from_ipc(recv_req)
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
parameter = self.model_runner.get_weights_by_name(
recv_req.name, recv_req.truncate_size

View File

@@ -2387,6 +2387,23 @@ class ModelRunner:
)
ShardedStateLoader.save_model(self.model, path, pattern, max_size)
def update_weights_from_ipc(self, recv_req):
    """Update weights from IPC for checkpoint-engine integration.

    Returns a (success, message) pair; all failures — including a missing
    checkpoint-engine installation — are reported in-band, never raised.
    """
    try:
        # Imported lazily so servers without checkpoint-engine still start.
        from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
            SGLangCheckpointEngineWorkerExtensionImpl,
        )

        # Bridge this runner's model into the checkpoint-engine worker protocol.
        extension = SGLangCheckpointEngineWorkerExtensionImpl(self)
        extension.update_weights_from_ipc(recv_req.zmq_handles)
        return True, "IPC weight update completed successfully"
    except ImportError as e:
        return False, f"IPC weight update failed: ImportError {e}"
    except Exception as e:
        logger.error(f"IPC weight update failed: {e}")
        return False, str(e)
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
params_dict = dict(model.named_parameters())

View File

@@ -208,6 +208,7 @@ class ServerArgs:
skip_server_warmup: bool = False
warmups: Optional[str] = None
nccl_port: Optional[int] = None
checkpoint_engine_wait_weights_before_ready: bool = False
# Quantization and data type
dtype: str = "auto"
@@ -1704,6 +1705,12 @@ class ServerArgs:
default=ServerArgs.nccl_port,
help="The port for NCCL distributed environment setup. Defaults to a random port.",
)
parser.add_argument(
"--checkpoint-engine-wait-weights-before-ready",
action="store_true",
help="If set, the server will wait for initial weights to be loaded via checkpoint-engine or other update methods "
"before serving inference requests.",
)
# Quantization and data type
parser.add_argument(

View File

@@ -2275,6 +2275,11 @@ def launch_dummy_health_check_server(host, port, enable_metrics):
app = FastAPI()
@app.get("/ping")
async def ping():
"""Could be used by the checkpoint-engine update script to confirm the server is up."""
return Response(status_code=200)
@app.get("/health")
async def health():
"""Check the health of the http server."""