[Feature] Support loading weights from ckpt engine worker (#11755)
Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com> Signed-off-by: Cruz Zhao <CruzZhao@linux.alibaba.com> Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com> Co-authored-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com> Co-authored-by: Cruz Zhao <CruzZhao@linux.alibaba.com> Co-authored-by: Xuchun Shang <xuchun.shang@gmail.com> Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
@@ -100,4 +100,5 @@ SGLang supports various environment variables that can be used to configure its
|
|||||||
|
|
||||||
| Environment Variable | Description | Default Value |
|
| Environment Variable | Description | Default Value |
|
||||||
| --- | --- | --- |
|
| --- | --- | --- |
|
||||||
|
| `SGLANG_WAIT_WEIGHTS_READY_TIMEOUT` | Timeout period for waiting on weights | `120` |
|
||||||
| `SGLANG_DISABLE_OUTLINES_DISK_CACHE` | Disable Outlines disk cache | `true` |
|
| `SGLANG_DISABLE_OUTLINES_DISK_CACHE` | Disable Outlines disk cache | `true` |
|
||||||
|
|||||||
241
examples/checkpoint_engine/update.py
Normal file
241
examples/checkpoint_engine/update.py
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
1) Launch the server with wait-for-initial-weights option in one terminal:
|
||||||
|
python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7
|
||||||
|
|
||||||
|
2) Torchrun this script in another terminal:
|
||||||
|
torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
|
||||||
|
"""
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from collections.abc import Callable
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from checkpoint_engine.ps import ParameterServer
|
||||||
|
from loguru import logger
|
||||||
|
from safetensors import safe_open
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def timer(msg: str):
|
||||||
|
start = time.perf_counter()
|
||||||
|
yield
|
||||||
|
end = time.perf_counter()
|
||||||
|
logger.info(f"{msg} duration: {end - start:.2f} seconds")
|
||||||
|
|
||||||
|
|
||||||
|
def check_sglang_ready(
|
||||||
|
endpoint: str, inference_parallel_size: int, uds: str | None = None
|
||||||
|
):
|
||||||
|
if rank != rank // inference_parallel_size * inference_parallel_size:
|
||||||
|
return
|
||||||
|
retry_num = 0
|
||||||
|
transport = None
|
||||||
|
if uds is not None:
|
||||||
|
transport = httpx.HTTPTransport(uds=uds)
|
||||||
|
with httpx.Client(transport=transport) as client:
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
response = client.get(f"{endpoint}/ping", timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
break
|
||||||
|
except (httpx.ConnectError, httpx.HTTPStatusError) as e:
|
||||||
|
if retry_num % 10 == 0:
|
||||||
|
logger.warning(
|
||||||
|
f"fail to check sglang ready, retry {retry_num} times, error: {e}"
|
||||||
|
)
|
||||||
|
retry_num += 1
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
|
||||||
|
def split_checkpoint_files(
|
||||||
|
checkpoint_path: str, rank: int, world_size: int
|
||||||
|
) -> list[str]:
|
||||||
|
checkpoint_files = [
|
||||||
|
os.path.join(checkpoint_path, f)
|
||||||
|
for f in filter(
|
||||||
|
lambda x: x.endswith(".safetensors"), os.listdir(checkpoint_path)
|
||||||
|
)
|
||||||
|
]
|
||||||
|
files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size
|
||||||
|
return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank]
|
||||||
|
|
||||||
|
|
||||||
|
def split_tensors(
|
||||||
|
checkpoint_path: str, rank: int, world_size: int
|
||||||
|
) -> dict[str, torch.Tensor]:
|
||||||
|
index_fn = os.path.join(checkpoint_path, "model.safetensors.index.json")
|
||||||
|
with open(index_fn) as f:
|
||||||
|
weight_map: dict[str, str] = json.load(f)["weight_map"]
|
||||||
|
weights_per_rank = (len(weight_map) + world_size - 1) // world_size
|
||||||
|
fn_tensors: dict[str, list[str]] = defaultdict(list)
|
||||||
|
weight_keys = list(weight_map.items())
|
||||||
|
for name, file in weight_keys[
|
||||||
|
rank * weights_per_rank : (rank + 1) * weights_per_rank
|
||||||
|
]:
|
||||||
|
fn_tensors[file].append(name)
|
||||||
|
named_tensors = {}
|
||||||
|
for file, names in fn_tensors.items():
|
||||||
|
with safe_open(os.path.join(checkpoint_path, file), framework="pt") as f:
|
||||||
|
for name in names:
|
||||||
|
named_tensors[name] = f.get_tensor(name)
|
||||||
|
return named_tensors
|
||||||
|
|
||||||
|
|
||||||
|
def req_inference(
|
||||||
|
endpoint: str,
|
||||||
|
inference_parallel_size: int,
|
||||||
|
timeout: float = 300.0,
|
||||||
|
uds: str | None = None,
|
||||||
|
weight_version: str | None = None,
|
||||||
|
) -> Callable[[list[tuple[str, str]]], None]:
|
||||||
|
rank = int(os.getenv("RANK", 0))
|
||||||
|
src = rank // inference_parallel_size * inference_parallel_size
|
||||||
|
|
||||||
|
def req_func(socket_paths: list[tuple[str, str]]):
|
||||||
|
if rank == src:
|
||||||
|
with httpx.Client(transport=httpx.HTTPTransport(uds=uds)) as client:
|
||||||
|
resp = client.post(
|
||||||
|
f"{endpoint}/update_weights_from_ipc",
|
||||||
|
json={
|
||||||
|
"zmq_handles": dict(
|
||||||
|
socket_paths[src : src + inference_parallel_size]
|
||||||
|
),
|
||||||
|
"flush_cache": True,
|
||||||
|
"weight_version": weight_version,
|
||||||
|
},
|
||||||
|
timeout=timeout,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
|
||||||
|
return req_func
|
||||||
|
|
||||||
|
|
||||||
|
def update_weights(
|
||||||
|
ps: ParameterServer,
|
||||||
|
checkpoint_name: str,
|
||||||
|
checkpoint_files: list[str],
|
||||||
|
named_tensors: dict[str, torch.Tensor],
|
||||||
|
req_func: Callable[[list[tuple[str, str]]], None],
|
||||||
|
inference_parallel_size: int,
|
||||||
|
endpoint: str,
|
||||||
|
save_metas_file: str | None = None,
|
||||||
|
update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
|
||||||
|
uds: str | None = None,
|
||||||
|
):
|
||||||
|
ps.register_checkpoint(
|
||||||
|
checkpoint_name, files=checkpoint_files, named_tensors=named_tensors
|
||||||
|
)
|
||||||
|
ps.init_process_group()
|
||||||
|
check_sglang_ready(endpoint, inference_parallel_size, uds)
|
||||||
|
dist.barrier()
|
||||||
|
with timer("Gather metas"):
|
||||||
|
ps.gather_metas(checkpoint_name)
|
||||||
|
if save_metas_file and int(os.getenv("RANK")) == 0:
|
||||||
|
with open(save_metas_file, "wb") as f:
|
||||||
|
pickle.dump(ps.get_metas(), f)
|
||||||
|
|
||||||
|
if update_method == "broadcast" or update_method == "all":
|
||||||
|
with timer("Update weights without setting ranks"):
|
||||||
|
ps.update(checkpoint_name, req_func)
|
||||||
|
|
||||||
|
if update_method == "p2p" or update_method == "all":
|
||||||
|
if update_method:
|
||||||
|
# sleep 2s to wait destroy process group
|
||||||
|
time.sleep(2)
|
||||||
|
with timer("Update weights with setting ranks"):
|
||||||
|
ps.update(
|
||||||
|
checkpoint_name, req_func, ranks=list(range(inference_parallel_size))
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def join(
|
||||||
|
ps: ParameterServer,
|
||||||
|
checkpoint_name: str,
|
||||||
|
load_metas_file: str,
|
||||||
|
req_func: Callable[[list[tuple[str, str]]], None],
|
||||||
|
inference_parallel_size: int,
|
||||||
|
endpoint: str,
|
||||||
|
uds: str | None = None,
|
||||||
|
):
|
||||||
|
assert load_metas_file, "load_metas_file is required"
|
||||||
|
with open(load_metas_file, "rb") as f:
|
||||||
|
metas = pickle.load(f)
|
||||||
|
ps.init_process_group()
|
||||||
|
check_sglang_ready(endpoint, inference_parallel_size, uds)
|
||||||
|
dist.barrier()
|
||||||
|
with timer("Gather metas before join"):
|
||||||
|
ps.gather_metas(checkpoint_name)
|
||||||
|
ps.load_metas(metas)
|
||||||
|
with timer(
|
||||||
|
f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p"
|
||||||
|
):
|
||||||
|
ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Update weights example")
|
||||||
|
parser.add_argument("--checkpoint-path", type=str, default=None)
|
||||||
|
parser.add_argument("--save-metas-file", type=str, default=None)
|
||||||
|
parser.add_argument("--load-metas-file", type=str, default=None)
|
||||||
|
parser.add_argument("--sleep-time", type=int, default=0)
|
||||||
|
parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
|
||||||
|
parser.add_argument("--inference-parallel-size", type=int, default=8)
|
||||||
|
parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
|
||||||
|
parser.add_argument("--update-method", type=str, default="broadcast")
|
||||||
|
parser.add_argument("--uds", type=str, default=None)
|
||||||
|
parser.add_argument("--weight-version", type=str, default=None)
|
||||||
|
args = parser.parse_args()
|
||||||
|
rank = int(os.getenv("RANK"))
|
||||||
|
world_size = int(os.getenv("WORLD_SIZE"))
|
||||||
|
req_func = req_inference(
|
||||||
|
args.endpoint,
|
||||||
|
args.inference_parallel_size,
|
||||||
|
uds=args.uds,
|
||||||
|
weight_version=args.weight_version,
|
||||||
|
)
|
||||||
|
ps = ParameterServer(auto_pg=True)
|
||||||
|
ps._p2p_store = None
|
||||||
|
if args.load_metas_file:
|
||||||
|
join(
|
||||||
|
ps,
|
||||||
|
args.checkpoint_name,
|
||||||
|
args.load_metas_file,
|
||||||
|
req_func,
|
||||||
|
args.inference_parallel_size,
|
||||||
|
args.endpoint,
|
||||||
|
args.uds,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if os.path.exists(
|
||||||
|
os.path.join(args.checkpoint_path, "model.safetensors.index.json")
|
||||||
|
):
|
||||||
|
named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
|
||||||
|
checkpoint_files = []
|
||||||
|
else:
|
||||||
|
checkpoint_files = split_checkpoint_files(
|
||||||
|
args.checkpoint_path, rank, world_size
|
||||||
|
)
|
||||||
|
named_tensors = {}
|
||||||
|
update_weights(
|
||||||
|
ps,
|
||||||
|
args.checkpoint_name,
|
||||||
|
checkpoint_files,
|
||||||
|
named_tensors,
|
||||||
|
req_func,
|
||||||
|
args.inference_parallel_size,
|
||||||
|
args.endpoint,
|
||||||
|
args.save_metas_file,
|
||||||
|
args.update_method,
|
||||||
|
args.uds,
|
||||||
|
)
|
||||||
|
time.sleep(args.sleep_time)
|
||||||
@@ -89,6 +89,7 @@ test = [
|
|||||||
"sentence_transformers",
|
"sentence_transformers",
|
||||||
"tabulate",
|
"tabulate",
|
||||||
]
|
]
|
||||||
|
checkpoint-engine = ["checkpoint-engine==0.1.2"]
|
||||||
all = []
|
all = []
|
||||||
dev = ["sglang[test]"]
|
dev = ["sglang[test]"]
|
||||||
|
|
||||||
|
|||||||
142
python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py
Normal file
142
python/sglang/srt/checkpoint_engine/checkpoint_engine_worker.py
Normal file
@@ -0,0 +1,142 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
"""
|
||||||
|
Checkpoint-engine integration for SGLang.
|
||||||
|
This module provides weight update functionality via IPC for checkpoint-engine compatibility.
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
from typing import Callable, Dict, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import zmq
|
||||||
|
|
||||||
|
try:
|
||||||
|
from checkpoint_engine.worker import update_weights_from_ipc
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"checkpoint-engine is not installed. "
|
||||||
|
"Please install it with: pip install sglang[checkpoint-engine]"
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SGLangCheckpointEngineWorkerExtension:
|
||||||
|
"""
|
||||||
|
Worker extension for SGLang to support checkpoint-engine IPC weight updates.
|
||||||
|
This class provides the interface needed for checkpoint-engine integration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._zmq_ctx: Optional[zmq.Context] = None
|
||||||
|
|
||||||
|
def get_device_uuid(self) -> str:
|
||||||
|
"""Get the UUID of current device."""
|
||||||
|
# We need to implement this to get the device UUID
|
||||||
|
# This will be overridden when integrated into SGLang's worker
|
||||||
|
raise NotImplementedError(
|
||||||
|
"This method should be overridden by SGLang integration"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_device_id(self) -> int:
|
||||||
|
"""Get the device ID."""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"This method should be overridden by SGLang integration"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_model_loader(self) -> Callable:
|
||||||
|
"""Get the model weight loader function."""
|
||||||
|
raise NotImplementedError(
|
||||||
|
"This method should be overridden by SGLang integration"
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_post_hook(self) -> Optional[Callable]:
|
||||||
|
"""Get the post-processing hook after weight loading."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
def update_weights_from_ipc(self, zmq_handles: Dict[str, str]):
|
||||||
|
"""
|
||||||
|
Update weights from IPC communication.
|
||||||
|
Args:
|
||||||
|
zmq_handles: Dict mapping device UUID to ZMQ socket path
|
||||||
|
"""
|
||||||
|
if self._zmq_ctx is None:
|
||||||
|
self._zmq_ctx = zmq.Context()
|
||||||
|
device_uuid = self.get_device_uuid()
|
||||||
|
device_id = self.get_device_id()
|
||||||
|
if device_uuid not in zmq_handles:
|
||||||
|
raise ValueError(
|
||||||
|
f"Device UUID {device_uuid} not found in zmq_handles: {list(zmq_handles.keys())}"
|
||||||
|
)
|
||||||
|
update_weights_from_ipc(
|
||||||
|
self._zmq_ctx,
|
||||||
|
zmq_handles[device_uuid],
|
||||||
|
device_id=device_id,
|
||||||
|
run=self.get_model_loader(),
|
||||||
|
post_hook=self.get_post_hook(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SGLangCheckpointEngineWorkerExtensionImpl(SGLangCheckpointEngineWorkerExtension):
|
||||||
|
"""
|
||||||
|
Implementation of SGLangCheckpointEngineWorkerExtension that integrates with SGLang's model runner.
|
||||||
|
This class provides the concrete implementation for checkpoint-engine IPC weight updates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_runner):
|
||||||
|
super().__init__()
|
||||||
|
self.model_runner = model_runner
|
||||||
|
|
||||||
|
def get_device_uuid(self) -> str:
|
||||||
|
"""Get the UUID of current device."""
|
||||||
|
# Get device UUID for current device
|
||||||
|
device_id = torch.cuda.current_device()
|
||||||
|
try:
|
||||||
|
return f"GPU-{torch.cuda.get_device_properties(device_id).uuid!s}"
|
||||||
|
except AssertionError as e:
|
||||||
|
raise ValueError(f"Failed to get GPU UUID for device {device_id}") from e
|
||||||
|
|
||||||
|
def get_device_id(self) -> int:
|
||||||
|
"""Get the device ID."""
|
||||||
|
return torch.cuda.current_device()
|
||||||
|
|
||||||
|
def get_model_loader(self) -> Callable:
|
||||||
|
"""Get the model weight loader function."""
|
||||||
|
return self.model_runner.model.load_weights
|
||||||
|
|
||||||
|
def get_post_hook(self) -> Optional[Callable]:
|
||||||
|
"""Get the post-processing hook after weight loading."""
|
||||||
|
|
||||||
|
def post_hook():
|
||||||
|
# Perform post-processing after weight loading similar to DefaultModelLoader
|
||||||
|
try:
|
||||||
|
from sglang.srt.model_loader.loader import device_loading_context
|
||||||
|
|
||||||
|
# Process quantization methods after loading weights
|
||||||
|
for _, module in self.model_runner.model.named_modules():
|
||||||
|
quant_method = getattr(module, "quant_method", None)
|
||||||
|
if quant_method is not None:
|
||||||
|
# Move parameters to device if needed for quantization processing
|
||||||
|
target_device = torch.device(
|
||||||
|
"cuda", torch.cuda.current_device()
|
||||||
|
)
|
||||||
|
with device_loading_context(module, target_device):
|
||||||
|
quant_method.process_weights_after_loading(module)
|
||||||
|
# Call model-specific post-loading hook if available
|
||||||
|
if hasattr(self.model_runner.model, "post_load_weights"):
|
||||||
|
self.model_runner.model.post_load_weights()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Post-hook processing failed: {e}")
|
||||||
|
|
||||||
|
return post_hook
|
||||||
@@ -59,6 +59,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UnloadLoRAAdapterReqInput,
|
UnloadLoRAAdapterReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
UpdateWeightsFromIPCReqInput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
|
from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerRouter
|
||||||
@@ -649,6 +650,21 @@ class Engine(EngineBase):
|
|||||||
request=None,
|
request=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def update_weights_from_ipc(
|
||||||
|
self,
|
||||||
|
zmq_handles: Dict[str, str],
|
||||||
|
flush_cache: bool = True,
|
||||||
|
):
|
||||||
|
"""Update weights from IPC for checkpoint-engine integration."""
|
||||||
|
obj = UpdateWeightsFromIPCReqInput(
|
||||||
|
zmq_handles=zmq_handles,
|
||||||
|
flush_cache=flush_cache,
|
||||||
|
)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
return loop.run_until_complete(
|
||||||
|
self.tokenizer_manager.update_weights_from_ipc(obj, None)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _set_envs_and_config(server_args: ServerArgs):
|
def _set_envs_and_config(server_args: ServerArgs):
|
||||||
# Set global environments
|
# Set global environments
|
||||||
|
|||||||
@@ -96,6 +96,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UnloadLoRAAdapterReqInput,
|
UnloadLoRAAdapterReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
UpdateWeightsFromIPCReqInput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
UpdateWeightVersionReqInput,
|
UpdateWeightVersionReqInput,
|
||||||
VertexGenerateReqInput,
|
VertexGenerateReqInput,
|
||||||
@@ -129,6 +130,7 @@ logger = logging.getLogger(__name__)
|
|||||||
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
|
||||||
|
|
||||||
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
|
||||||
|
WAIT_WEIGHTS_READY_TIMEOUT = int(os.getenv("SGLANG_WAIT_WEIGHTS_READY_TIMEOUT", 120))
|
||||||
|
|
||||||
|
|
||||||
# Store global states
|
# Store global states
|
||||||
@@ -838,6 +840,27 @@ async def update_weights_from_distributed(
|
|||||||
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
|
@app.post("/update_weights_from_ipc")
|
||||||
|
async def update_weights_from_ipc(obj: UpdateWeightsFromIPCReqInput, request: Request):
|
||||||
|
"""Update the weights from IPC (Inter-Process Communication) for checkpoint-engine integration."""
|
||||||
|
success, message = await _global_state.tokenizer_manager.update_weights_from_ipc(
|
||||||
|
obj, request
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update weight version if provided and weights update was successful
|
||||||
|
if success and obj.weight_version is not None:
|
||||||
|
_update_weight_version_if_provided(obj.weight_version)
|
||||||
|
message += f" Weight version updated to {obj.weight_version}."
|
||||||
|
|
||||||
|
content = {"success": success, "message": message}
|
||||||
|
if success:
|
||||||
|
if _global_state.tokenizer_manager.initial_weights_loaded is False:
|
||||||
|
_global_state.tokenizer_manager.initial_weights_loaded = True
|
||||||
|
return ORJSONResponse(content)
|
||||||
|
else:
|
||||||
|
return ORJSONResponse(content, status_code=HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
|
|
||||||
@app.post("/update_weight_version")
|
@app.post("/update_weight_version")
|
||||||
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
|
async def update_weight_version(obj: UpdateWeightVersionReqInput, request: Request):
|
||||||
"""Update the weight version. This operation requires no active requests."""
|
"""Update the weight version. This operation requires no active requests."""
|
||||||
@@ -1530,6 +1553,8 @@ def _wait_and_warmup(
|
|||||||
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
|
pipe_finish_writer: Optional[multiprocessing.connection.Connection],
|
||||||
launch_callback: Optional[Callable[[], None]] = None,
|
launch_callback: Optional[Callable[[], None]] = None,
|
||||||
):
|
):
|
||||||
|
if server_args.checkpoint_engine_wait_weights_before_ready:
|
||||||
|
_wait_weights_ready()
|
||||||
if not server_args.skip_server_warmup:
|
if not server_args.skip_server_warmup:
|
||||||
if not _execute_server_warmup(
|
if not _execute_server_warmup(
|
||||||
server_args,
|
server_args,
|
||||||
@@ -1552,3 +1577,24 @@ def _wait_and_warmup(
|
|||||||
|
|
||||||
if launch_callback is not None:
|
if launch_callback is not None:
|
||||||
launch_callback()
|
launch_callback()
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_weights_ready():
|
||||||
|
"""Wait for weights to be ready within the specified timeout."""
|
||||||
|
timeout = WAIT_WEIGHTS_READY_TIMEOUT
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
for _ in range(timeout):
|
||||||
|
if _global_state.tokenizer_manager.initial_weights_loaded:
|
||||||
|
logger.info(
|
||||||
|
f"Weights are ready after {time.time() - start_time:.2f} seconds"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
# Timeout reached without weights being ready
|
||||||
|
logger.error(
|
||||||
|
f"Weights are not ready after waiting {timeout} seconds. "
|
||||||
|
f"Consider increasing SGLANG_WAIT_WEIGHTS_READY_TIMEOUT environment variable. "
|
||||||
|
f"Current status: initial_weights_loaded={_global_state.tokenizer_manager.initial_weights_loaded}"
|
||||||
|
)
|
||||||
|
|||||||
@@ -1080,6 +1080,24 @@ class InitWeightsSendGroupForRemoteInstanceReqInput(BaseReq):
|
|||||||
backend: str = "nccl"
|
backend: str = "nccl"
|
||||||
|
|
||||||
|
|
||||||
|
# Now UpdateWeightsFromIPCReqInput and UpdateWeightsFromIPCReqOutput
|
||||||
|
# are only used by Checkpoint Engine (https://github.com/MoonshotAI/checkpoint-engine)
|
||||||
|
@dataclass
|
||||||
|
class UpdateWeightsFromIPCReqInput(BaseReq):
|
||||||
|
# ZMQ socket paths for each device UUID
|
||||||
|
zmq_handles: Dict[str, str]
|
||||||
|
# Whether to flush cache after weight update
|
||||||
|
flush_cache: bool = True
|
||||||
|
# Optional: Update weight version along with weights
|
||||||
|
weight_version: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UpdateWeightsFromIPCReqOutput(BaseReq):
|
||||||
|
success: bool
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
|
class InitWeightsSendGroupForRemoteInstanceReqOutput(BaseReq):
|
||||||
success: bool
|
success: bool
|
||||||
|
|||||||
@@ -109,6 +109,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UnloadLoRAAdapterReqOutput,
|
UnloadLoRAAdapterReqOutput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
UpdateWeightsFromIPCReqInput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.mm_utils import init_embedding_cache
|
from sglang.srt.managers.mm_utils import init_embedding_cache
|
||||||
@@ -530,6 +531,7 @@ class Scheduler(
|
|||||||
self.update_weights_from_distributed,
|
self.update_weights_from_distributed,
|
||||||
),
|
),
|
||||||
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
|
(UpdateWeightsFromTensorReqInput, self.update_weights_from_tensor),
|
||||||
|
(UpdateWeightsFromIPCReqInput, self.update_weights_from_ipc),
|
||||||
(GetWeightsByNameReqInput, self.get_weights_by_name),
|
(GetWeightsByNameReqInput, self.get_weights_by_name),
|
||||||
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
|
(ReleaseMemoryOccupationReqInput, self.release_memory_occupation),
|
||||||
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UpdateWeightFromDiskReqOutput,
|
UpdateWeightFromDiskReqOutput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
UpdateWeightsFromDistributedReqOutput,
|
UpdateWeightsFromDistributedReqOutput,
|
||||||
|
UpdateWeightsFromIPCReqInput,
|
||||||
|
UpdateWeightsFromIPCReqOutput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
UpdateWeightsFromTensorReqOutput,
|
UpdateWeightsFromTensorReqOutput,
|
||||||
)
|
)
|
||||||
@@ -80,6 +82,18 @@ class SchedulerUpdateWeightsMixin:
|
|||||||
torch.distributed.barrier(group=self.tp_cpu_group)
|
torch.distributed.barrier(group=self.tp_cpu_group)
|
||||||
return UpdateWeightsFromTensorReqOutput(success, message)
|
return UpdateWeightsFromTensorReqOutput(success, message)
|
||||||
|
|
||||||
|
def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
|
||||||
|
"""Update the online model parameter from IPC for checkpoint-engine integration."""
|
||||||
|
success, message = self.tp_worker.update_weights_from_ipc(recv_req)
|
||||||
|
if success:
|
||||||
|
if recv_req.flush_cache:
|
||||||
|
flush_cache_success = self.flush_cache()
|
||||||
|
assert flush_cache_success, "Cache flush failed after updating weights"
|
||||||
|
else:
|
||||||
|
logger.error(message)
|
||||||
|
torch.distributed.barrier(group=self.tp_cpu_group)
|
||||||
|
return UpdateWeightsFromIPCReqOutput(success, message)
|
||||||
|
|
||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||||
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
parameter = self.tp_worker.get_weights_by_name(recv_req)
|
||||||
return GetWeightsByNameReqOutput(parameter)
|
return GetWeightsByNameReqOutput(parameter)
|
||||||
|
|||||||
@@ -63,6 +63,8 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UnloadLoRAAdapterReqOutput,
|
UnloadLoRAAdapterReqOutput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
UpdateWeightsFromDistributedReqOutput,
|
UpdateWeightsFromDistributedReqOutput,
|
||||||
|
UpdateWeightsFromIPCReqInput,
|
||||||
|
UpdateWeightsFromIPCReqOutput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
UpdateWeightsFromTensorReqOutput,
|
UpdateWeightsFromTensorReqOutput,
|
||||||
)
|
)
|
||||||
@@ -169,6 +171,9 @@ class TokenizerCommunicatorMixin:
|
|||||||
self.update_weights_from_tensor_communicator = _Communicator(
|
self.update_weights_from_tensor_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
|
self.update_weights_from_ipc_communicator = _Communicator(
|
||||||
|
self.send_to_scheduler, server_args.dp_size
|
||||||
|
)
|
||||||
self.get_weights_by_name_communicator = _Communicator(
|
self.get_weights_by_name_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
@@ -235,6 +240,10 @@ class TokenizerCommunicatorMixin:
|
|||||||
UpdateWeightsFromTensorReqOutput,
|
UpdateWeightsFromTensorReqOutput,
|
||||||
self.update_weights_from_tensor_communicator.handle_recv,
|
self.update_weights_from_tensor_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
|
(
|
||||||
|
UpdateWeightsFromIPCReqOutput,
|
||||||
|
self.update_weights_from_ipc_communicator.handle_recv,
|
||||||
|
),
|
||||||
(
|
(
|
||||||
GetWeightsByNameReqOutput,
|
GetWeightsByNameReqOutput,
|
||||||
self.get_weights_by_name_communicator.handle_recv,
|
self.get_weights_by_name_communicator.handle_recv,
|
||||||
@@ -442,6 +451,28 @@ class TokenizerCommunicatorMixin:
|
|||||||
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
result = (await self.update_weights_from_tensor_communicator(obj))[0]
|
||||||
return result.success, result.message
|
return result.success, result.message
|
||||||
|
|
||||||
|
async def update_weights_from_ipc(
|
||||||
|
self,
|
||||||
|
obj: UpdateWeightsFromIPCReqInput,
|
||||||
|
request: Optional[fastapi.Request] = None,
|
||||||
|
) -> Tuple[bool, str]:
|
||||||
|
"""Update weights via IPC for checkpoint-engine integration."""
|
||||||
|
self.auto_create_handle_loop()
|
||||||
|
try:
|
||||||
|
# For now, we only support single data parallel instance
|
||||||
|
assert (
|
||||||
|
self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
|
||||||
|
), "dp_size must be 1 or dp attention must be enabled for update weights from IPC"
|
||||||
|
logger.info("Starting IPC weight update")
|
||||||
|
# This means that weight sync cannot run while requests are in progress.
|
||||||
|
async with self.model_update_lock.writer_lock:
|
||||||
|
result = (await self.update_weights_from_ipc_communicator(obj))[0]
|
||||||
|
return result.success, result.message
|
||||||
|
except Exception as e:
|
||||||
|
error_msg = f"IPC weight update failed: {str(e)}"
|
||||||
|
logger.error(error_msg)
|
||||||
|
return False, error_msg
|
||||||
|
|
||||||
async def load_lora_adapter(
|
async def load_lora_adapter(
|
||||||
self: TokenizerManager,
|
self: TokenizerManager,
|
||||||
obj: LoadLoRAAdapterReqInput,
|
obj: LoadLoRAAdapterReqInput,
|
||||||
|
|||||||
@@ -284,6 +284,11 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
|||||||
self.gracefully_exit = False
|
self.gracefully_exit = False
|
||||||
self.last_receive_tstamp = 0
|
self.last_receive_tstamp = 0
|
||||||
|
|
||||||
|
# Initial weights status
|
||||||
|
self.initial_weights_loaded = True
|
||||||
|
if server_args.checkpoint_engine_wait_weights_before_ready:
|
||||||
|
self.initial_weights_loaded = False
|
||||||
|
|
||||||
# Dumping
|
# Dumping
|
||||||
self.dump_requests_folder = "" # By default do not dump
|
self.dump_requests_folder = "" # By default do not dump
|
||||||
self.dump_requests_threshold = 1000
|
self.dump_requests_threshold = 1000
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UnloadLoRAAdapterReqInput,
|
UnloadLoRAAdapterReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
UpdateWeightsFromIPCReqInput,
|
||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||||
@@ -164,6 +165,11 @@ class BaseTpWorker(ABC):
|
|||||||
)
|
)
|
||||||
return success, message
|
return success, message
|
||||||
|
|
||||||
|
def update_weights_from_ipc(self, recv_req: UpdateWeightsFromIPCReqInput):
|
||||||
|
"""Update weights from IPC for checkpoint-engine integration."""
|
||||||
|
success, message = self.model_runner.update_weights_from_ipc(recv_req)
|
||||||
|
return success, message
|
||||||
|
|
||||||
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
def get_weights_by_name(self, recv_req: GetWeightsByNameReqInput):
|
||||||
parameter = self.model_runner.get_weights_by_name(
|
parameter = self.model_runner.get_weights_by_name(
|
||||||
recv_req.name, recv_req.truncate_size
|
recv_req.name, recv_req.truncate_size
|
||||||
|
|||||||
@@ -2387,6 +2387,23 @@ class ModelRunner:
|
|||||||
)
|
)
|
||||||
ShardedStateLoader.save_model(self.model, path, pattern, max_size)
|
ShardedStateLoader.save_model(self.model, path, pattern, max_size)
|
||||||
|
|
||||||
|
def update_weights_from_ipc(self, recv_req):
|
||||||
|
"""Update weights from IPC for checkpoint-engine integration."""
|
||||||
|
try:
|
||||||
|
from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
|
||||||
|
SGLangCheckpointEngineWorkerExtensionImpl,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create a worker extension that integrates with SGLang's model
|
||||||
|
worker = SGLangCheckpointEngineWorkerExtensionImpl(self)
|
||||||
|
worker.update_weights_from_ipc(recv_req.zmq_handles)
|
||||||
|
return True, "IPC weight update completed successfully"
|
||||||
|
except ImportError as e:
|
||||||
|
return False, f"IPC weight update failed: ImportError {e}"
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"IPC weight update failed: {e}")
|
||||||
|
return False, str(e)
|
||||||
|
|
||||||
|
|
||||||
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
|
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
|
||||||
params_dict = dict(model.named_parameters())
|
params_dict = dict(model.named_parameters())
|
||||||
|
|||||||
@@ -208,6 +208,7 @@ class ServerArgs:
|
|||||||
skip_server_warmup: bool = False
|
skip_server_warmup: bool = False
|
||||||
warmups: Optional[str] = None
|
warmups: Optional[str] = None
|
||||||
nccl_port: Optional[int] = None
|
nccl_port: Optional[int] = None
|
||||||
|
checkpoint_engine_wait_weights_before_ready: bool = False
|
||||||
|
|
||||||
# Quantization and data type
|
# Quantization and data type
|
||||||
dtype: str = "auto"
|
dtype: str = "auto"
|
||||||
@@ -1704,6 +1705,12 @@ class ServerArgs:
|
|||||||
default=ServerArgs.nccl_port,
|
default=ServerArgs.nccl_port,
|
||||||
help="The port for NCCL distributed environment setup. Defaults to a random port.",
|
help="The port for NCCL distributed environment setup. Defaults to a random port.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint-engine-wait-weights-before-ready",
|
||||||
|
action="store_true",
|
||||||
|
help="If set, the server will wait for initial weights to be loaded via checkpoint-engine or other update methods "
|
||||||
|
"before serving inference requests.",
|
||||||
|
)
|
||||||
|
|
||||||
# Quantization and data type
|
# Quantization and data type
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
|||||||
@@ -2275,6 +2275,11 @@ def launch_dummy_health_check_server(host, port, enable_metrics):
|
|||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/ping")
|
||||||
|
async def ping():
|
||||||
|
"""Could be used by the checkpoint-engine update script to confirm the server is up."""
|
||||||
|
return Response(status_code=200)
|
||||||
|
|
||||||
@app.get("/health")
|
@app.get("/health")
|
||||||
async def health():
|
async def health():
|
||||||
"""Check the health of the http server."""
|
"""Check the health of the http server."""
|
||||||
|
|||||||
Reference in New Issue
Block a user