[Feature] Support loading weights from ckpt engine worker (#11755)
Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com> Signed-off-by: Cruz Zhao <CruzZhao@linux.alibaba.com> Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com> Co-authored-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com> Co-authored-by: Cruz Zhao <CruzZhao@linux.alibaba.com> Co-authored-by: Xuchun Shang <xuchun.shang@gmail.com> Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -100,4 +100,5 @@ SGLang supports various environment variables that can be used to configure its
|
||||
|
||||
| Environment Variable | Description | Default Value |
|
||||
| --- | --- | --- |
|
||||
| `SGLANG_WAIT_WEIGHTS_READY_TIMEOUT` | Timeout period for waiting on weights | `120` |
|
||||
| `SGLANG_DISABLE_OUTLINES_DISK_CACHE` | Disable Outlines disk cache | `true` |
|
||||
|
||||
241
examples/checkpoint_engine/update.py
Normal file
241
examples/checkpoint_engine/update.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
Usage:
|
||||
1) Launch the server with wait-for-initial-weights option in one terminal:
|
||||
python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7
|
||||
|
||||
2) Torchrun this script in another terminal:
|
||||
torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from checkpoint_engine.ps import ParameterServer
|
||||
from loguru import logger
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
@contextmanager
|
||||
def timer(msg: str):
|
||||
start = time.perf_counter()
|
||||
yield
|
||||
end = time.perf_counter()
|
||||
logger.info(f"{msg} duration: {end - start:.2f} seconds")
|
||||
|
||||
|
||||
def check_sglang_ready(
|
||||
endpoint: str, inference_parallel_size: int, uds: str | None = None
|
||||
):
|
||||
if rank != rank // inference_parallel_size * inference_parallel_size:
|
||||
return
|
||||
retry_num = 0
|
||||
transport = None
|
||||
if uds is not None:
|
||||
transport = httpx.HTTPTransport(uds=uds)
|
||||
with httpx.Client(transport=transport) as client:
|
||||
while True:
|
||||
try:
|
||||
response = client.get(f"{endpoint}/ping", timeout=10)
|
||||
response.raise_for_status()
|
||||
break
|
||||
except (httpx.ConnectError, httpx.HTTPStatusError) as e:
|
||||
if retry_num % 10 == 0:
|
||||
logger.warning(
|
||||
f"fail to check sglang ready, retry {retry_num} times, error: {e}"
|
||||
)
|
||||
retry_num += 1
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def split_checkpoint_files(
|
||||
checkpoint_path: str, rank: int, world_size: int
|
||||
) -> list[str]:
|
||||
checkpoint_files = [
|
||||
os.path.join(checkpoint_path, f)
|
||||
for f in filter(
|
||||
lambda x: x.endswith(".safetensors"), os.listdir(checkpoint_path)
|
||||
)
|
||||
]
|
||||
files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size
|
||||
return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank]
|
||||
|
||||
|
||||
def split_tensors(
|
||||
checkpoint_path: str, rank: int, world_size: int
|
||||
) -> dict[str, torch.Tensor]:
|
||||
index_fn = os.path.join(checkpoint_path, "model.safetensors.index.json")
|
||||
with open(index_fn) as f:
|
||||
weight_map: dict[str, str] = json.load(f)["weight_map"]
|
||||
weights_per_rank = (len(weight_map) + world_size - 1) // world_size
|
||||
fn_tensors: dict[str, list[str]] = defaultdict(list)
|
||||
weight_keys = list(weight_map.items())
|
||||
for name, file in weight_keys[
|
||||
rank * weights_per_rank : (rank + 1) * weights_per_rank
|
||||
]:
|
||||
fn_tensors[file].append(name)
|
||||
named_tensors = {}
|
||||
for file, names in fn_tensors.items():
|
||||
with safe_open(os.path.join(checkpoint_path, file), framework="pt") as f:
|
||||
for name in names:
|
||||
named_tensors[name] = f.get_tensor(name)
|
||||
return named_tensors
|
||||
|
||||
|
||||
def req_inference(
|
||||
endpoint: str,
|
||||
inference_parallel_size: int,
|
||||
timeout: float = 300.0,
|
||||
uds: str | None = None,
|
||||
weight_version: str | None = None,
|
||||
) -> Callable[[list[tuple[str, str]]], None]:
|
||||
rank = int(os.getenv("RANK", 0))
|
||||
src = rank // inference_parallel_size * inference_parallel_size
|
||||
|
||||
def req_func(socket_paths: list[tuple[str, str]]):
|
||||
if rank == src:
|
||||
with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
|
||||
resp = client.post(
|
||||
f"{endpoint}/update_weights_from_ipc",
|
||||
json={
|
||||
"zmq_handles": dict(
|
||||
socket_paths[src : src + inference_parallel_size]
|
||||
),
|
||||
"flush_cache": True,
|
||||
"weight_version": weight_version,
|
||||
},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
|
||||
return req_func
|
||||
|
||||
|
||||
def update_weights(
|
||||
ps: ParameterServer,
|
||||
checkpoint_name: str,
|
||||
checkpoint_files: list[str],
|
||||
named_tensors: dict[str, torch.Tensor],
|
||||
req_func: Callable[[list[tuple[str, str]]], None],
|
||||
inference_parallel_size: int,
|
||||
endpoint: str,
|
||||
save_metas_file: str | None = None,
|
||||
update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
|
||||
uds: str | None = None,
|
||||
):
|
||||
ps.register_checkpoint(
|
||||
checkpoint_name, files=checkpoint_files, named_tensors=named_tensors
|
||||
)
|
||||
ps.init_process_group()
|
||||
check_sglang_ready(endpoint, inference_parallel_size, uds)
|
||||
dist.barrier()
|
||||
with timer("Gather metas"):
|
||||
ps.gather_metas(checkpoint_name)
|
||||
if save_metas_file and int(os.getenv("RANK")) == 0:
|
||||
with open(save_metas_file, "wb") as f:
|
||||
pickle.dump(ps.get_metas(), f)
|
||||
|
||||
if update_method == "broadcast" or update_method == "all":
|
||||
with timer("Update weights without setting ranks"):
|
||||
ps.update(checkpoint_name, req_func)
|
||||
|
||||
if update_method == "p2p" or update_method == "all":
|
||||
if update_method:
|
||||
# sleep 2s to wait destroy process group
|
||||
time.sleep(2)
|
||||
with timer("Update weights with setting ranks"):
|
||||
ps.update(
|
||||
checkpoint_name, req_func, ranks=list(range(inference_parallel_size))
|
||||
)
|
||||
|
||||
|
||||
def join(
|
||||
ps: ParameterServer,
|
||||
checkpoint_name: str,
|
||||
load_metas_file: str,
|
||||
req_func: Callable[[list[tuple[str, str]]], None],
|
||||
inference_parallel_size: int,
|
||||
endpoint: str,
|
||||
uds: str | None = None,
|
||||
):
|
||||
assert load_metas_file, "load_metas_file is required"
|
||||
with open(load_metas_file, "rb") as f:
|
||||
metas = pickle.load(f)
|
||||
ps.init_process_group()
|
||||
check_sglang_ready(endpoint, inference_parallel_size, uds)
|
||||
dist.barrier()
|
||||
with timer("Gather metas before join"):
|
||||
ps.gather_metas(checkpoint_name)
|
||||
ps.load_metas(metas)
|
||||
with timer(
|
||||
f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p"
|
||||
):
|
||||
ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Update weights example")
|
||||
parser.add_argument("--checkpoint-path", type=str, default=None)
|
||||
parser.add_argument("--save-metas-file", type=str, default=None)
|
||||
parser.add_argument("--load-metas-file", type=str, default=None)
|
||||
parser.add_argument("--sleep-time", type=int, default=0)
|
||||
parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
|
||||
parser.add_argument("--inference-parallel-size", type=int, default=8)
|
||||
parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
|
||||
parser.add_argument("--update-method", type=str, default="broadcast")
|
||||
parser.add_argument("--uds", type=str, default=None)
|
||||
parser.add_argument("--weight-version", type=str, default=None)
|
||||
args = parser.parse_args()
|
||||
rank = int(os.getenv("RANK"))
|
||||
world_size = int(os.getenv("WORLD_SIZE"))
|
||||
req_func = req_inference(
|
||||
args.endpoint,
|
||||
args.inference_parallel_size,
|
||||
uds=args.uds,
|
||||
weight_version=args.weight_version,
|
||||
)
|
||||
ps = ParameterServer(auto_pg=True)
|
||||
ps._p2p_store = None
|
||||
if args.load_metas_file:
|
||||
join(
|
||||
ps,
|
||||
args.checkpoint_name,
|
||||
args.load_metas_file,
|
||||
req_func,
|
||||
args.inference_parallel_size,
|
||||
args.endpoint,
|
||||
args.uds,
|
||||
)
|
||||
else:
|
||||
if os.path.exists(
|
||||
os.path.join(args.checkpoint_path, "model.safetensors.index.json")
|
||||
):
|
||||
named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
|
||||
checkpoint_files = []
|
||||
else:
|
||||
checkpoint_files = split_checkpoint_files(
|
||||
args.checkpoint_path, rank, world_size
|
||||
)
|
||||
named_tensors = {}
|
||||
update_weights(
|
||||
ps,
|
||||
args.checkpoint_name,
|
||||
checkpoint_files,
|
||||
named_tensors,
|
||||
req_func,
|
||||
args.inference_parallel_size,
|
||||
args.endpoint,
|
||||
args.save_metas_file,
|
||||
args.update_method,
|
||||
args.uds,
|
||||
)
|
||||
time.sleep(args.sleep_time)
|
||||
@@ -89,6 +89,7 @@ test = [
|
||||
"sentence_transformers",
|
||||
"tabulate",
|
||||
]
|
||||
checkpoint-engine = ["checkpoint-engine==0.1.2"]
|
||||
all = []
|
||||
dev = ["sglang[test]"]
|
||||
|
||||
|
||||
142
python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py
Normal file
142
python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py
Normal file
@@ -0,0 +1,142 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""
|
||||
Checkpoint-engine integration for SGLang.
|
||||
This module provides weight update functionality via IPC for checkpoint-engine compatibility.
|
||||
"""
|
||||
import logging
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
import torch
|
||||
import zmq
|
||||
|
||||
try:
|
||||
from checkpoint_engine.worker import update_weights_from_ipc
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"checkpoint-engine is not installed. "
|
||||
"Please install it with: pip install sglang[checkpoint-engine]"
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SGLangCheckpointEngineWorkerExtension:
|
||||
"""
|
||||
Worker extension for SGLang to support checkpoint-engine IPC weight updates.
|
||||
This class provides the interface needed for checkpoint-engine integration.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._zmq_ctx: Optional[zmq.Context] = None
|
||||
|
||||
def get_device_uuid(self) -> str:
|
||||
"""Get the UUID of current device."""
|
||||
# We need to implement this to get the device UUID
|
||||
# This will be overridden when integrated into SGLang's worker
|
||||
raise NotImplementedError(
|
||||
"This method should be overridden by SGLang integration"
|
||||
)
|
||||
|
||||
def get_device_id(self) -> int:
|
||||
"""Get the device ID."""
|
||||
raise NotImplementedError(
|
||||
"This method should be overridden by SGLang integration"
|
||||
)
|
||||
|
||||
def get_model_loader(self) -> Callable:
|
||||
"""Get the model weight loader function."""
|
||||
raise NotImplementedError(
|
||||
"This method should be overridden by SGLang integration"
|
||||
)
|
||||
|
||||
def get_post_hook(self) -> Optional[Callable]:
|
||||
"""Get the post-processing hook after weight loading."""
|
||||
return None
|
||||
|
||||
def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
|
||||
"""
|
||||
Update weights from IPC communication.
|
||||
Args:
|
||||
zmq_handles: Dict mapping device UUID to ZMQ socket path
|
||||
"""
|
||||
if self._zmq_ctx is None:
|
||||
self._zmq_ctx = zmq.Context()
|
||||
device_uuid = self.get_device_uuid()
|
||||
device_id = self.get_device_id()
|
||||
if device_uuid not in zmq_handles:
|
||||
raise ValueError(
|
||||
f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
|
||||
)
|
||||
update_weights_from_ipc(
|
||||
self._zmq_ctx,
|
||||
zmq_handles[device_uuid],
|
||||
device_id=device_id,
|
||||
run=self.get_model_loader(),
|
||||
post_hook=self.get_post_hook(),
|
||||
)
|
||||
|
||||
|
||||
class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
|
||||
"""
|
||||
Implementation of SGLangCheckpointEngineWorkerExtension that integrates with SGLang's model runner.
|
||||
This class provides the concrete implementation for checkpoint-engine IPC weight updates.
|
||||
"""
|
||||
|
||||
def __init__(self, model_runner):
|
||||
super().__init__()
|
||||
self.model_runner = model_runner
|
||||
|
||||
def get_device_uuid(self) -> str:
|
||||
"""Get the UUID of current device."""
|
||||
# Get device UUID for current device
|
||||
device_id = torch.cuda.current_device()
|
||||
try:
|
||||
return f"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}"
|
||||
except AssertionError as e:
|
||||
raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e
|
||||
|
||||
def get_device_id(self) -> int:
|
||||
"""Get the device ID."""
|
||||
return torch.cuda.current_device()
|
||||
|
||||
def get_model_loader(self) -> Callable:
|
||||
"""Get the model weight loader function."""
|
||||
return self.model_runner.model.load_weights
|
||||
|
||||
def get_post_hook(self) -> Optional[Callable]:
|
||||
"""Get the post-processing hook after weight loading."""
|
||||
|
||||
def post_hook():
|
||||
# Perform post-processing after weight loading similar to DefaultModelLoader
|
||||
try:
|
||||
from sglang.srt.model_loader.loader import device_loading_context
|
||||
|
||||
# Process quantization methods after loading weights
|
||||
for _, module in self.model_runner.model.named_modules():
|
||||
quant_method = getattr(module, "quant_method", None)
|
||||
if quant_method is not None:
|
||||
# Move parameters to device if needed for quantization processing
|
||||
target_device = torch.device(
|
||||
"cuda", torch.cuda.current_device()
|
||||
)
|
||||
with device_loading_context(module, target_device):
|
||||
quant_method.process_weights_after_loading(module)
|
||||
# Call model-specific post-loading hook if available
|
||||
if hasattr(self.model_runner.model, "post_load_weights"):
|
||||
self.model_runner.model.post_load_weights()
|
||||
except Exception as e:
|
||||
logger.warning(f"Post-hook processing failed: {e}")
|
||||
|
||||
return post_hook
|
||||
@@ -59,6 +59,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromIPCReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
|
||||
@@ -649,6 +650,21 @@ class Engine(EngineBase):
|
||||
request=None,
|
||||
)
|
||||
|
||||
def update_weights_from_ipc(
|
||||
self,
|
||||
zmq_handles: Dict[str, str],
|
||||
flush_cache: bool = True,
|
||||
):
|
||||
"""Update weights from IPC for checkpoint-engine integration."""
|
||||
obj = UpdateWeightsFromIPCReqInput(
|
||||
zmq_handles=zmq_handles,
|
||||
flush_cache=flush_cache,
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
return loop.run_until_complete(
|
||||
self.tokenizer_manager.update_weights_from_ipc(obj, None)
|
||||
)
|
||||
|
||||
|
||||
def _set_envs_and_config(server_args: ServerArgs):
|
||||
# Set global environments
|
||||
|
||||
@@ -96,6 +96,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromIPCReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
UpdateWeightVersionReqInput,
|
||||
VertexGenerateReqInput,
|
||||
@@ -129,6 +130,7 @@ logger = logging.getLogger(__name__)
|
||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||
|
||||
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
||||
WAIT_WEIGHTS_READY_TIMEOUT = int(os.getenv("SGLANG_WAIT_WEIGHTS_READY_TIMEOUT", 120))
|
||||
|
||||
|
||||
# Store global states
|
||||
@@ -838,6 +840,27 @@ async def update_weights_from_distributed(
|
||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
@app.post("/update_weights_from_ipc")
|
||||
async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Request):
|
||||
"""Update the weights from IPC (Inter-Process Communication) for checkpoint-engine integration."""
|
||||
success, message = await _global_state.tokenizer_manager.update_weights_from_ipc(
|
||||
obj, request
|
||||
)
|
||||
|
||||
# Update weight version if provided and weights update was successful
|
||||
if success and obj.weight_version is not None:
|
||||
_update_weight_version_if_provided(obj.weight_version)
|
||||
message += f" Weight version updated to {obj.weight_version}."
|
||||
|
||||
content = {"success": success, "message": message}
|
||||
if success:
|
||||
if _global_state.tokenizer_manager.initial_weights_loaded is False:
|
||||
_global_state.tokenizer_manager.initial_weights_loaded = True
|
||||
return ORJSONResponse(content)
|
||||
else:
|
||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
|
||||
@app.post("/update_weight_version")
|
||||
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
|
||||
"""Update the weight version. This operation requires no active requests."""
|
||||
@@ -1530,6 +1553,8 @@ def _wait_and_warmup(
|
||||
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
|
||||
launch_callback: Optional[Callable[[], None]] = None,
|
||||
):
|
||||
if server_args.checkpoint_engine_wait_weights_before_ready:
|
||||
_wait_weights_ready()
|
||||
if not server_args.skip_server_warmup:
|
||||
if not _execute_server_warmup(
|
||||
server_args,
|
||||
@@ -1552,3 +1577,24 @@ def _wait_and_warmup(
|
||||
|
||||
if launch_callback is not None:
|
||||
launch_callback()
|
||||
|
||||
|
||||
def _wait_weights_ready():
|
||||
"""Wait for weights to be ready within the specified timeout."""
|
||||
timeout = WAIT_WEIGHTS_READY_TIMEOUT
|
||||
start_time = time.time()
|
||||
|
||||
for _ in range(timeout):
|
||||
if _global_state.tokenizer_manager.initial_weights_loaded:
|
||||
logger.info(
|
||||
f"Weights are ready after {time.time() - start_time:.2f} seconds"
|
||||
)
|
||||
return
|
||||
time.sleep(1)
|
||||
|
||||
# Timeout reached without weights being ready
|
||||
logger.error(
|
||||
f"Weights are not ready after waiting {timeout} seconds. "
|
||||
f"Consider increasing SGLANG_WAIT_WEIGHTS_READY_TIMEOUT environment variable. "
|
||||
f"Current status: initial_weights_loaded={_global_state.tokenizer_manager.initial_weights_loaded}"
|
||||
)
|
||||
|
||||
@@ -1080,6 +1080,24 @@ class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
|
||||
backend: str = "nccl"
|
||||
|
||||
|
||||
# Now UpdateWeightsFromIPCReqInput and UpdateWeightsFromIPCReqOutput
|
||||
# are only used by Checkpoint Engine (https://github.com/MoonshotAI/checkpoint-engine)
|
||||
@dataclass
|
||||
class UpdateWeightsFromIPCReqInput(BaseReq):
|
||||
# ZMQ socket paths for each device UUID
|
||||
zmq_handles: Dict[str, str]
|
||||
# Whether to flush cache after weight update
|
||||
flush_cache: bool = True
|
||||
# Optional: Update weight version along with weights
|
||||
weight_version: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class UpdateWeightsFromIPCReqOutput(BaseReq):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
|
||||
success: bool
|
||||
|
||||
@@ -109,6 +109,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UnloadLoRAAdapterReqOutput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromIPCReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.mm_utils import init_embedding_cache
|
||||
@@ -530,6 +531,7 @@ class Scheduler(
|
||||
self.update_weights_from_distributed,
|
||||
),
|
||||
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
|
||||
(UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
|
||||
(GetWeightsByNameReqInput, self.get_weights_by_name),
|
||||
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
|
||||
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
||||
|
||||
@@ -21,6 +21,8 @@ from sglang.srt.managers.io_struct import (
|
||||
UpdateWeightFromDiskReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
UpdateWeightsFromIPCReqInput,
|
||||
UpdateWeightsFromIPCReqOutput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
)
|
||||
@@ -80,6 +82,18 @@ class SchedulerUpdateWeightsMixin:
|
||||
torch.distributed.barrier(group=self.tp_cpu_group)
|
||||
return UpdateWeightsFromTensorReqOutput(success, message)
|
||||
|
||||
def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
|
||||
"""Update the online model parameter from IPC for checkpoint-engine integration."""
|
||||
success, message = self.tp_worker.update_weights_from_ipc(recv_req)
|
||||
if success:
|
||||
if recv_req.flush_cache:
|
||||
flush_cache_success = self.flush_cache()
|
||||
assert flush_cache_success, "Cache flush failed after updating weights"
|
||||
else:
|
||||
logger.error(message)
|
||||
torch.distributed.barrier(group=self.tp_cpu_group)
|
||||
return UpdateWeightsFromIPCReqOutput(success, message)
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
||||
return GetWeightsByNameReqOutput(parameter)
|
||||
|
||||
@@ -63,6 +63,8 @@ from sglang.srt.managers.io_struct import (
|
||||
UnloadLoRAAdapterReqOutput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromDistributedReqOutput,
|
||||
UpdateWeightsFromIPCReqInput,
|
||||
UpdateWeightsFromIPCReqOutput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
)
|
||||
@@ -169,6 +171,9 @@ class TokenizerCommunicatorMixin:
|
||||
self.update_weights_from_tensor_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.update_weights_from_ipc_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
self.get_weights_by_name_communicator = _Communicator(
|
||||
self.send_to_scheduler, server_args.dp_size
|
||||
)
|
||||
@@ -235,6 +240,10 @@ class TokenizerCommunicatorMixin:
|
||||
UpdateWeightsFromTensorReqOutput,
|
||||
self.update_weights_from_tensor_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
UpdateWeightsFromIPCReqOutput,
|
||||
self.update_weights_from_ipc_communicator.handle_recv,
|
||||
),
|
||||
(
|
||||
GetWeightsByNameReqOutput,
|
||||
self.get_weights_by_name_communicator.handle_recv,
|
||||
@@ -442,6 +451,28 @@ class TokenizerCommunicatorMixin:
|
||||
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
|
||||
async def update_weights_from_ipc(
|
||||
self,
|
||||
obj: UpdateWeightsFromIPCReqInput,
|
||||
request: Optional[fastapi.Request] = None,
|
||||
) -> Tuple[bool, str]:
|
||||
"""Update weights via IPC for checkpoint-engine integration."""
|
||||
self.auto_create_handle_loop()
|
||||
try:
|
||||
# For now, we only support single data parallel instance
|
||||
assert (
|
||||
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
||||
), "dp_size must be 1 or dp attention must be enabled for update weights from IPC"
|
||||
logger.info("Starting IPC weight update")
|
||||
# This means that weight sync cannot run while requests are in progress.
|
||||
async with self.model_update_lock.writer_lock:
|
||||
result = (await self.update_weights_from_ipc_communicator(obj))[0]
|
||||
return result.success, result.message
|
||||
except Exception as e:
|
||||
error_msg = f"IPC weight update failed: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
return False, error_msg
|
||||
|
||||
async def load_lora_adapter(
|
||||
self: TokenizerManager,
|
||||
obj: LoadLoRAAdapterReqInput,
|
||||
|
||||
@@ -284,6 +284,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
self.gracefully_exit = False
|
||||
self.last_receive_tstamp = 0
|
||||
|
||||
# Initial weights status
|
||||
self.initial_weights_loaded = True
|
||||
if server_args.checkpoint_engine_wait_weights_before_ready:
|
||||
self.initial_weights_loaded = False
|
||||
|
||||
# Dumping
|
||||
self.dump_requests_folder = "" # By default do not dump
|
||||
self.dump_requests_threshold = 1000
|
||||
|
||||
@@ -32,6 +32,7 @@ from sglang.srt.managers.io_struct import (
|
||||
UnloadLoRAAdapterReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromIPCReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
@@ -164,6 +165,11 @@ class BaseTpWorker(ABC):
|
||||
)
|
||||
return success, message
|
||||
|
||||
def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
|
||||
"""Update weights from IPC for checkpoint-engine integration."""
|
||||
success, message = self.model_runner.update_weights_from_ipc(recv_req)
|
||||
return success, message
|
||||
|
||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||
parameter = self.model_runner.get_weights_by_name(
|
||||
recv_req.name, recv_req.truncate_size
|
||||
|
||||
@@ -2387,6 +2387,23 @@ class ModelRunner:
|
||||
)
|
||||
ShardedStateLoader.save_model(self.model, path, pattern, max_size)
|
||||
|
||||
def update_weights_from_ipc(self, recv_req):
|
||||
"""Update weights from IPC for checkpoint-engine integration."""
|
||||
try:
|
||||
from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
|
||||
SGLangCheckpointEngineWorkerExtensionImpl,
|
||||
)
|
||||
|
||||
# Create a worker extension that integrates with SGLang's model
|
||||
worker = SGLangCheckpointEngineWorkerExtensionImpl(self)
|
||||
worker.update_weights_from_ipc(recv_req.zmq_handles)
|
||||
return True, "IPC weight update completed successfully"
|
||||
except ImportError as e:
|
||||
return False, f"IPC weight update failed: ImportError {e}"
|
||||
except Exception as e:
|
||||
logger.error(f"IPC weight update failed: {e}")
|
||||
return False, str(e)
|
||||
|
||||
|
||||
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||
params_dict = dict(model.named_parameters())
|
||||
|
||||
@@ -208,6 +208,7 @@ class ServerArgs:
|
||||
skip_server_warmup: bool = False
|
||||
warmups: Optional[str] = None
|
||||
nccl_port: Optional[int] = None
|
||||
checkpoint_engine_wait_weights_before_ready: bool = False
|
||||
|
||||
# Quantization and data type
|
||||
dtype: str = "auto"
|
||||
@@ -1704,6 +1705,12 @@ class ServerArgs:
|
||||
default=ServerArgs.nccl_port,
|
||||
help="The port for NCCL distributed environment setup. Defaults to a random port.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoint-engine-wait-weights-before-ready",
|
||||
action="store_true",
|
||||
help="If set, the server will wait for initial weights to be loaded via checkpoint-engine or other update methods "
|
||||
"before serving inference requests.",
|
||||
)
|
||||
|
||||
# Quantization and data type
|
||||
parser.add_argument(
|
||||
|
||||
@@ -2275,6 +2275,11 @@ def launch_dummy_health_check_server(host, port, enable_metrics):
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/ping")
|
||||
async def ping():
|
||||
"""Could be used by the checkpoint-engine update script to confirm the server is up."""
|
||||
return Response(status_code=200)
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
"""Check the health of the http server."""
|
||||
|
||||
Reference in New Issue
Block a user