"""Checkpoint-engine weight-update example for SGLang.

Usage:
1) Launch the server with wait-for-initial-weights option in one terminal:
    python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7

2) Torchrun this script in another terminal:
    torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
"""

import argparse
import json
import os
import pickle
import time
from collections import defaultdict
from collections.abc import Callable
from contextlib import contextmanager
from typing import Literal

import httpx
import torch
import torch.distributed as dist
from checkpoint_engine.ps import ParameterServer
from loguru import logger
from safetensors import safe_open


@contextmanager
def timer(msg: str):
    """Log the wall-clock duration of the enclosed block."""
    start = time.perf_counter()
    yield
    end = time.perf_counter()
    logger.info(f"{msg} duration: {end - start:.2f} seconds")


def check_sglang_ready(
    endpoint: str, inference_parallel_size: int, uds: str | None = None
):
    """Block until the SGLang server answers ``GET /ping``.

    Only the leader rank of each inference group (rank that is a multiple of
    ``inference_parallel_size``) polls; all other ranks return immediately.

    Args:
        endpoint: Base HTTP URL of the SGLang server.
        inference_parallel_size: Number of ranks per inference group.
        uds: Optional Unix domain socket path for the HTTP transport.
    """
    # FIX: the original read a module-level ``rank`` that only exists when this
    # file runs as a script (defined under ``__main__``); importing and calling
    # this function would raise NameError. Resolve rank from the environment.
    rank = int(os.getenv("RANK", "0"))
    if rank % inference_parallel_size != 0:
        # Not the group leader; nothing to do.
        return
    retry_num = 0
    transport = httpx.HTTPTransport(uds=uds) if uds is not None else None
    with httpx.Client(transport=transport) as client:
        while True:
            try:
                response = client.get(f"{endpoint}/ping", timeout=10)
                response.raise_for_status()
                break
            except (httpx.ConnectError, httpx.HTTPStatusError) as e:
                # Log only every 10th retry to avoid flooding the console.
                if retry_num % 10 == 0:
                    logger.warning(
                        f"fail to check sglang ready, retry {retry_num} times, error: {e}"
                    )
                retry_num += 1
                time.sleep(0.1)


def split_checkpoint_files(
    checkpoint_path: str, rank: int, world_size: int
) -> list[str]:
    """Return this rank's contiguous slice of the ``.safetensors`` files.

    Files are divided evenly (ceiling division) across ``world_size`` ranks,
    so trailing ranks may receive fewer (possibly zero) files.
    """
    checkpoint_files = [
        os.path.join(checkpoint_path, f)
        for f in os.listdir(checkpoint_path)
        if f.endswith(".safetensors")
    ]
    files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size
    return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank]


def split_tensors(
    checkpoint_path: str, rank: int, world_size: int
) -> dict[str, torch.Tensor]:
    """Load this rank's slice of tensors using the safetensors index file.

    The weight map from ``model.safetensors.index.json`` is split evenly
    (ceiling division) across ranks at tensor granularity; each rank opens
    only the files containing its assigned tensors.
    """
    index_fn = os.path.join(checkpoint_path, "model.safetensors.index.json")
    with open(index_fn) as f:
        weight_map: dict[str, str] = json.load(f)["weight_map"]
    weights_per_rank = (len(weight_map) + world_size - 1) // world_size
    # Group this rank's tensor names by the file that stores them.
    fn_tensors: dict[str, list[str]] = defaultdict(list)
    weight_keys = list(weight_map.items())
    for name, file in weight_keys[
        rank * weights_per_rank : (rank + 1) * weights_per_rank
    ]:
        fn_tensors[file].append(name)
    named_tensors = {}
    for file, names in fn_tensors.items():
        with safe_open(os.path.join(checkpoint_path, file), framework="pt") as f:
            for name in names:
                named_tensors[name] = f.get_tensor(name)
    return named_tensors


def req_inference(
    endpoint: str,
    inference_parallel_size: int,
    timeout: float = 300.0,
    uds: str | None = None,
    weight_version: str | None = None,
) -> Callable[[list[tuple[str, str]]], None]:
    """Build the callback that tells SGLang to pull weights over IPC.

    Only the group-leader rank (``src``) issues the HTTP request; it forwards
    the ZMQ socket handles belonging to its inference group.

    Returns:
        A function taking ``(device_uuid, zmq_path)`` pairs for all ranks.
    """
    rank = int(os.getenv("RANK", "0"))
    src = rank // inference_parallel_size * inference_parallel_size

    def req_func(socket_paths: list[tuple[str, str]]):
        if rank == src:
            with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
                resp = client.post(
                    f"{endpoint}/update_weights_from_ipc",
                    json={
                        "zmq_handles": dict(
                            socket_paths[src : src + inference_parallel_size]
                        ),
                        "flush_cache": True,
                        "weight_version": weight_version,
                    },
                    timeout=timeout,
                )
                resp.raise_for_status()

    return req_func


def update_weights(
    ps: ParameterServer,
    checkpoint_name: str,
    checkpoint_files: list[str],
    named_tensors: dict[str, torch.Tensor],
    req_func: Callable[[list[tuple[str, str]]], None],
    inference_parallel_size: int,
    endpoint: str,
    save_metas_file: str | None = None,
    update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
    uds: str | None = None,
):
    """Register a checkpoint and push it to the inference server.

    Depending on ``update_method``, performs a broadcast round, a p2p round
    (with explicit ranks), or both. Optionally pickles the gathered metas to
    ``save_metas_file`` on rank 0 so later workers can ``join``.
    """
    ps.register_checkpoint(
        checkpoint_name, files=checkpoint_files, named_tensors=named_tensors
    )
    ps.init_process_group()
    check_sglang_ready(endpoint, inference_parallel_size, uds)
    dist.barrier()
    with timer("Gather metas"):
        ps.gather_metas(checkpoint_name)
    if save_metas_file and int(os.getenv("RANK", "0")) == 0:
        with open(save_metas_file, "wb") as f:
            pickle.dump(ps.get_metas(), f)

    if update_method in ("broadcast", "all"):
        with timer("Update weights without setting ranks"):
            ps.update(checkpoint_name, req_func)

    if update_method in ("p2p", "all"):
        # FIX: the original guard was ``if update_method:`` which is always
        # truthy for a non-empty string. Per the comment intent, the pause is
        # only needed when a broadcast round just ran (i.e. "all") so its
        # process group has time to be destroyed.
        if update_method == "all":
            # sleep 2s to wait destroy process group
            time.sleep(2)
        with timer("Update weights with setting ranks"):
            ps.update(
                checkpoint_name, req_func, ranks=list(range(inference_parallel_size))
            )


def join(
    ps: ParameterServer,
    checkpoint_name: str,
    load_metas_file: str,
    req_func: Callable[[list[tuple[str, str]]], None],
    inference_parallel_size: int,
    endpoint: str,
    uds: str | None = None,
):
    """Join an existing update using metas previously saved to disk.

    Loads the pickled metas, then performs a p2p update restricted to
    ``range(inference_parallel_size)``.
    """
    assert load_metas_file, "load_metas_file is required"
    with open(load_metas_file, "rb") as f:
        metas = pickle.load(f)
    ps.init_process_group()
    check_sglang_ready(endpoint, inference_parallel_size, uds)
    dist.barrier()
    with timer("Gather metas before join"):
        ps.gather_metas(checkpoint_name)
    ps.load_metas(metas)
    with timer(
        f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p"
    ):
        ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Update weights example")
    parser.add_argument("--checkpoint-path", type=str, default=None)
    parser.add_argument("--save-metas-file", type=str, default=None)
    parser.add_argument("--load-metas-file", type=str, default=None)
    parser.add_argument("--sleep-time", type=int, default=0)
    parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
    parser.add_argument("--inference-parallel-size", type=int, default=8)
    parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
    parser.add_argument("--update-method", type=str, default="broadcast")
    parser.add_argument("--uds", type=str, default=None)
    parser.add_argument("--weight-version", type=str, default=None)
    args = parser.parse_args()
    # FIX: default to "0"/"1" so the script fails gracefully outside torchrun
    # instead of raising TypeError from int(None).
    rank = int(os.getenv("RANK", "0"))
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    req_func = req_inference(
        args.endpoint,
        args.inference_parallel_size,
        uds=args.uds,
        weight_version=args.weight_version,
    )
    ps = ParameterServer(auto_pg=True)
    # NOTE(review): pokes a private attribute of ParameterServer — presumably
    # disables the p2p store by default; confirm against checkpoint-engine API.
    ps._p2p_store = None
    if args.load_metas_file:
        join(
            ps,
            args.checkpoint_name,
            args.load_metas_file,
            req_func,
            args.inference_parallel_size,
            args.endpoint,
            args.uds,
        )
    else:
        # Prefer tensor-level splitting when an index file exists; otherwise
        # fall back to whole-file splitting.
        if os.path.exists(
            os.path.join(args.checkpoint_path, "model.safetensors.index.json")
        ):
            named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
            checkpoint_files = []
        else:
            checkpoint_files = split_checkpoint_files(
                args.checkpoint_path, rank, world_size
            )
            named_tensors = {}
        update_weights(
            ps,
            args.checkpoint_name,
            checkpoint_files,
            named_tensors,
            req_func,
            args.inference_parallel_size,
            args.endpoint,
            args.save_metas_file,
            args.update_method,
            args.uds,
        )
    time.sleep(args.sleep_time)
logger = logging.getLogger(__name__)


class SGLangCheckpointEngineWorkerExtension:
    """Abstract worker hook for checkpoint-engine IPC weight updates.

    Subclasses supply the device identity and the weight-loading callables;
    this base class drives the actual IPC transfer.
    """

    def __init__(self):
        # Lazily-created ZMQ context, shared across updates.
        self._zmq_ctx: Optional[zmq.Context] = None

    def get_device_uuid(self) -> str:
        """Return the UUID string identifying the current device."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_device_id(self) -> int:
        """Return the integer index of the current device."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_model_loader(self) -> Callable:
        """Return the callable that loads named tensors into the model."""
        raise NotImplementedError(
            "This method should be overridden by SGLang integration"
        )

    def get_post_hook(self) -> Optional[Callable]:
        """Return an optional callable to run after weights are loaded."""
        return None

    def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
        """Receive weights over the ZMQ handle mapped to this device's UUID.

        Args:
            zmq_handles: Mapping from device UUID to ZMQ socket path.

        Raises:
            ValueError: If this device's UUID has no entry in ``zmq_handles``.
        """
        if self._zmq_ctx is None:
            self._zmq_ctx = zmq.Context()
        uuid = self.get_device_uuid()
        try:
            handle = zmq_handles[uuid]
        except KeyError:
            raise ValueError(
                f"Device UUID {uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
            ) from None
        update_weights_from_ipc(
            self._zmq_ctx,
            handle,
            device_id=self.get_device_id(),
            run=self.get_model_loader(),
            post_hook=self.get_post_hook(),
        )


class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
    """Concrete extension bound to an SGLang ``model_runner``.

    Supplies CUDA device identity and the model's weight loader for
    checkpoint-engine IPC weight updates.
    """

    def __init__(self, model_runner):
        super().__init__()
        self.model_runner = model_runner

    def get_device_uuid(self) -> str:
        """Return the current CUDA device's UUID in ``GPU-<uuid>`` form."""
        dev = torch.cuda.current_device()
        try:
            return f"GPU-{torch.cuda.get_device_properties(dev).uuid!s}"
        except AssertionError as e:
            # get_device_properties asserts on invalid device indices.
            raise ValueError(f"Failed to get GPU UUID for device {dev}") from e

    def get_device_id(self) -> int:
        """Return the current CUDA device index."""
        return torch.cuda.current_device()

    def get_model_loader(self) -> Callable:
        """Return the model's ``load_weights`` entry point."""
        return self.model_runner.model.load_weights

    def get_post_hook(self) -> Optional[Callable]:
        """Return the post-load hook mirroring DefaultModelLoader's steps."""

        def post_hook():
            # Re-run quantization weight processing and model post-load hooks,
            # mirroring what DefaultModelLoader does after a disk load.
            try:
                from sglang.srt.model_loader.loader import device_loading_context

                model = self.model_runner.model
                device = torch.device("cuda", torch.cuda.current_device())
                for _, submodule in model.named_modules():
                    qm = getattr(submodule, "quant_method", None)
                    if qm is None:
                        continue
                    # Quantization may need parameters resident on-device.
                    with device_loading_context(submodule, device):
                        qm.process_weights_after_loading(submodule)
                if hasattr(model, "post_load_weights"):
                    model.post_load_weights()
            except Exception as e:
                # Best-effort: a failed post-hook is logged, not fatal.
                logger.warning(f"Post-hook processing failed: {e}")

        return post_hook
+59,7 @@ from sglang.srt.managers.io_struct import ( UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromIPCReqInput, UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter @@ -649,6 +650,21 @@ class Engine(EngineBase): request=None, ) + def update_weights_from_ipc( + self, + zmq_handles: Dict[str, str], + flush_cache: bool = True, + ): + """Update weights from IPC for checkpoint-engine integration.""" + obj = UpdateWeightsFromIPCReqInput( + zmq_handles=zmq_handles, + flush_cache=flush_cache, + ) + loop = asyncio.get_event_loop() + return loop.run_until_complete( + self.tokenizer_manager.update_weights_from_ipc(obj, None) + ) + def _set_envs_and_config(server_args: ServerArgs): # Set global environments diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 129793252..6a0bb0aae 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -96,6 +96,7 @@ from sglang.srt.managers.io_struct import ( UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromIPCReqInput, UpdateWeightsFromTensorReqInput, UpdateWeightVersionReqInput, VertexGenerateReqInput, @@ -129,6 +130,7 @@ logger = logging.getLogger(__name__) asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20)) +WAIT_WEIGHTS_READY_TIMEOUT = int(os.getenv("SGLANG_WAIT_WEIGHTS_READY_TIMEOUT", 120)) # Store global states @@ -838,6 +840,27 @@ async def update_weights_from_distributed( return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) +@app.post("/update_weights_from_ipc") +async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Request): + """Update the weights from IPC (Inter-Process Communication) for checkpoint-engine 
integration.""" + success, message = await _global_state.tokenizer_manager.update_weights_from_ipc( + obj, request + ) + + # Update weight version if provided and weights update was successful + if success and obj.weight_version is not None: + _update_weight_version_if_provided(obj.weight_version) + message += f" Weight version updated to {obj.weight_version}." + + content = {"success": success, "message": message} + if success: + if _global_state.tokenizer_manager.initial_weights_loaded is False: + _global_state.tokenizer_manager.initial_weights_loaded = True + return ORJSONResponse(content) + else: + return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST) + + @app.post("/update_weight_version") async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request): """Update the weight version. This operation requires no active requests.""" @@ -1530,6 +1553,8 @@ def _wait_and_warmup( pipe_finish_writer: Optional[multiprocessing.connection.Connection], launch_callback: Optional[Callable[[], None]] = None, ): + if server_args.checkpoint_engine_wait_weights_before_ready: + _wait_weights_ready() if not server_args.skip_server_warmup: if not _execute_server_warmup( server_args, @@ -1552,3 +1577,24 @@ def _wait_and_warmup( if launch_callback is not None: launch_callback() + + +def _wait_weights_ready(): + """Wait for weights to be ready within the specified timeout.""" + timeout = WAIT_WEIGHTS_READY_TIMEOUT + start_time = time.time() + + for _ in range(timeout): + if _global_state.tokenizer_manager.initial_weights_loaded: + logger.info( + f"Weights are ready after {time.time() - start_time:.2f} seconds" + ) + return + time.sleep(1) + + # Timeout reached without weights being ready + logger.error( + f"Weights are not ready after waiting {timeout} seconds. " + f"Consider increasing SGLANG_WAIT_WEIGHTS_READY_TIMEOUT environment variable. 
" + f"Current status: initial_weights_loaded={_global_state.tokenizer_manager.initial_weights_loaded}" + ) diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 849204aad..72516334e 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -1080,6 +1080,24 @@ class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq): backend: str = "nccl" +# Now UpdateWeightsFromIPCReqInput and UpdateWeightsFromIPCReqOutput +# are only used by Checkpoint Engine (https://github.com/MoonshotAI/checkpoint-engine) +@dataclass +class UpdateWeightsFromIPCReqInput(BaseReq): + # ZMQ socket paths for each device UUID + zmq_handles: Dict[str, str] + # Whether to flush cache after weight update + flush_cache: bool = True + # Optional: Update weight version along with weights + weight_version: Optional[str] = None + + +@dataclass +class UpdateWeightsFromIPCReqOutput(BaseReq): + success: bool + message: str + + @dataclass class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq): success: bool diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index ed1fe91e9..293851111 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -109,6 +109,7 @@ from sglang.srt.managers.io_struct import ( UnloadLoRAAdapterReqOutput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromIPCReqInput, UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.mm_utils import init_embedding_cache @@ -530,6 +531,7 @@ class Scheduler( self.update_weights_from_distributed, ), (UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor), + (UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc), (GetWeightsByNameReqInput, self.get_weights_by_name), (ReleaseMemoryOccupationReqInput, self.release_memory_occupation), (ResumeMemoryOccupationReqInput, self.resume_memory_occupation), diff --git 
a/python/sglang/srt/managers/scheduler_update_weights_mixin.py b/python/sglang/srt/managers/scheduler_update_weights_mixin.py index 7552bcce0..647648fe1 100644 --- a/python/sglang/srt/managers/scheduler_update_weights_mixin.py +++ b/python/sglang/srt/managers/scheduler_update_weights_mixin.py @@ -21,6 +21,8 @@ from sglang.srt.managers.io_struct import ( UpdateWeightFromDiskReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, + UpdateWeightsFromIPCReqInput, + UpdateWeightsFromIPCReqOutput, UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqOutput, ) @@ -80,6 +82,18 @@ class SchedulerUpdateWeightsMixin: torch.distributed.barrier(group=self.tp_cpu_group) return UpdateWeightsFromTensorReqOutput(success, message) + def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput): + """Update the online model parameter from IPC for checkpoint-engine integration.""" + success, message = self.tp_worker.update_weights_from_ipc(recv_req) + if success: + if recv_req.flush_cache: + flush_cache_success = self.flush_cache() + assert flush_cache_success, "Cache flush failed after updating weights" + else: + logger.error(message) + torch.distributed.barrier(group=self.tp_cpu_group) + return UpdateWeightsFromIPCReqOutput(success, message) + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.tp_worker.get_weights_by_name(recv_req) return GetWeightsByNameReqOutput(parameter) diff --git a/python/sglang/srt/managers/tokenizer_communicator_mixin.py b/python/sglang/srt/managers/tokenizer_communicator_mixin.py index c0283d05d..81f8cb97c 100644 --- a/python/sglang/srt/managers/tokenizer_communicator_mixin.py +++ b/python/sglang/srt/managers/tokenizer_communicator_mixin.py @@ -63,6 +63,8 @@ from sglang.srt.managers.io_struct import ( UnloadLoRAAdapterReqOutput, UpdateWeightsFromDistributedReqInput, UpdateWeightsFromDistributedReqOutput, + UpdateWeightsFromIPCReqInput, + UpdateWeightsFromIPCReqOutput, 
UpdateWeightsFromTensorReqInput, UpdateWeightsFromTensorReqOutput, ) @@ -169,6 +171,9 @@ class TokenizerCommunicatorMixin: self.update_weights_from_tensor_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) + self.update_weights_from_ipc_communicator = _Communicator( + self.send_to_scheduler, server_args.dp_size + ) self.get_weights_by_name_communicator = _Communicator( self.send_to_scheduler, server_args.dp_size ) @@ -235,6 +240,10 @@ class TokenizerCommunicatorMixin: UpdateWeightsFromTensorReqOutput, self.update_weights_from_tensor_communicator.handle_recv, ), + ( + UpdateWeightsFromIPCReqOutput, + self.update_weights_from_ipc_communicator.handle_recv, + ), ( GetWeightsByNameReqOutput, self.get_weights_by_name_communicator.handle_recv, @@ -442,6 +451,28 @@ class TokenizerCommunicatorMixin: result = (await self.update_weights_from_tensor_communicator(obj))[0] return result.success, result.message + async def update_weights_from_ipc( + self, + obj: UpdateWeightsFromIPCReqInput, + request: Optional[fastapi.Request] = None, + ) -> Tuple[bool, str]: + """Update weights via IPC for checkpoint-engine integration.""" + self.auto_create_handle_loop() + try: + # For now, we only support single data parallel instance + assert ( + self.server_args.dp_size == 1 or self.server_args.enable_dp_attention + ), "dp_size must be 1 or dp attention must be enabled for update weights from IPC" + logger.info("Starting IPC weight update") + # This means that weight sync cannot run while requests are in progress. 
+ async with self.model_update_lock.writer_lock: + result = (await self.update_weights_from_ipc_communicator(obj))[0] + return result.success, result.message + except Exception as e: + error_msg = f"IPC weight update failed: {str(e)}" + logger.error(error_msg) + return False, error_msg + async def load_lora_adapter( self: TokenizerManager, obj: LoadLoRAAdapterReqInput, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 63eaaa268..f9005cad8 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -284,6 +284,11 @@ class TokenizerManager(TokenizerCommunicatorMixin): self.gracefully_exit = False self.last_receive_tstamp = 0 + # Initial weights status + self.initial_weights_loaded = True + if server_args.checkpoint_engine_wait_weights_before_ready: + self.initial_weights_loaded = False + # Dumping self.dump_requests_folder = "" # By default do not dump self.dump_requests_threshold = 1000 diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 0a623d4a2..f4daf3679 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -32,6 +32,7 @@ from sglang.srt.managers.io_struct import ( UnloadLoRAAdapterReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, + UpdateWeightsFromIPCReqInput, UpdateWeightsFromTensorReqInput, ) from sglang.srt.managers.schedule_batch import ModelWorkerBatch @@ -164,6 +165,11 @@ class BaseTpWorker(ABC): ) return success, message + def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput): + """Update weights from IPC for checkpoint-engine integration.""" + success, message = self.model_runner.update_weights_from_ipc(recv_req) + return success, message + def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput): parameter = self.model_runner.get_weights_by_name( recv_req.name, 
recv_req.truncate_size diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index ca67da75e..35fc8afbd 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -2387,6 +2387,23 @@ class ModelRunner: ) ShardedStateLoader.save_model(self.model, path, pattern, max_size) + def update_weights_from_ipc(self, recv_req): + """Update weights from IPC for checkpoint-engine integration.""" + try: + from sglang.srt.checkpoint_engine.checkpoint_engine_worker import ( + SGLangCheckpointEngineWorkerExtensionImpl, + ) + + # Create a worker extension that integrates with SGLang's model + worker = SGLangCheckpointEngineWorkerExtensionImpl(self) + worker.update_weights_from_ipc(recv_req.zmq_handles) + return True, "IPC weight update completed successfully" + except ImportError as e: + return False, f"IPC weight update failed: ImportError {e}" + except Exception as e: + logger.error(f"IPC weight update failed: {e}") + return False, str(e) + def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]): params_dict = dict(model.named_parameters()) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 2f992cecd..7935de6f6 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -208,6 +208,7 @@ class ServerArgs: skip_server_warmup: bool = False warmups: Optional[str] = None nccl_port: Optional[int] = None + checkpoint_engine_wait_weights_before_ready: bool = False # Quantization and data type dtype: str = "auto" @@ -1704,6 +1705,12 @@ class ServerArgs: default=ServerArgs.nccl_port, help="The port for NCCL distributed environment setup. 
Defaults to a random port.", ) + parser.add_argument( + "--checkpoint-engine-wait-weights-before-ready", + action="store_true", + help="If set, the server will wait for initial weights to be loaded via checkpoint-engine or other update methods " + "before serving inference requests.", + ) # Quantization and data type parser.add_argument( diff --git a/python/sglang/srt/utils/common.py b/python/sglang/srt/utils/common.py index 7c2f573e4..cac2046cc 100644 --- a/python/sglang/srt/utils/common.py +++ b/python/sglang/srt/utils/common.py @@ -2275,6 +2275,11 @@ def launch_dummy_health_check_server(host, port, enable_metrics): app = FastAPI() + @app.get("/ping") + async def ping(): + """Could be used by the checkpoint-engine update script to confirm the server is up.""" + return Response(status_code=200) + @app.get("/health") async def health(): """Check the health of the http server."""