[Feature] Support loading weights from ckpt engine worker (#11755)
Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com> Signed-off-by: Cruz Zhao <CruzZhao@linux.alibaba.com> Signed-off-by: Xuchun Shang <xuchun.shang@gmail.com> Co-authored-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com> Co-authored-by: Cruz Zhao <CruzZhao@linux.alibaba.com> Co-authored-by: Xuchun Shang <xuchun.shang@gmail.com> Co-authored-by: Shangming Cai <csmthu@gmail.com>
This commit is contained in:
241
examples/checkpoint_engine/update.py
Normal file
241
examples/checkpoint_engine/update.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
Usage:
|
||||
1) Launch the server in one terminal with the --checkpoint-engine-wait-weights-before-ready option, so that it waits for the initial weights:
|
||||
python -m sglang.launch_server --model-path /workspace/Qwen/Qwen3-4B/ --tensor-parallel-size 2 --port 19730 --load-format dummy --checkpoint-engine-wait-weights-before-ready --mem-fraction-static 0.7
|
||||
|
||||
2) Torchrun this script in another terminal:
|
||||
torchrun --nproc-per-node 2 update.py --update-method broadcast --checkpoint-path /workspace/Qwen/Qwen3-4B/ --inference-parallel-size 2
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from contextlib import contextmanager
|
||||
from typing import Literal
|
||||
|
||||
import httpx
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from checkpoint_engine.ps import ParameterServer
|
||||
from loguru import logger
|
||||
from safetensors import safe_open
|
||||
|
||||
|
||||
@contextmanager
def timer(msg: str):
    """Context manager that logs how long the wrapped block took.

    The elapsed wall-clock time (measured with ``time.perf_counter``) is
    reported through ``logger.info`` once the block finishes normally.
    Nothing is logged if the block raises, because the code after ``yield``
    is not reached in that case.

    Args:
        msg: Label placed in front of the reported duration.
    """
    began = time.perf_counter()
    yield
    elapsed = time.perf_counter() - began
    logger.info(f"{msg} duration: {elapsed:.2f} seconds")
|
||||
|
||||
|
||||
def check_sglang_ready(
    endpoint: str, inference_parallel_size: int, uds: str | None = None
):
    """Block until ``GET {endpoint}/ping`` succeeds.

    Only the first rank of each group of ``inference_parallel_size``
    consecutive ranks performs the polling; every other rank returns
    immediately. Polling retries forever with a 0.1 s pause between attempts
    and emits a warning on every 10th failed attempt.

    Args:
        endpoint: Base URL of the server, e.g. ``http://localhost:19730``.
        inference_parallel_size: Number of ranks per inference group.
        uds: Optional Unix domain socket path used as the HTTP transport.
    """
    # Read the rank from the environment (same convention as req_inference)
    # instead of relying on a module-level ``rank`` that only exists when the
    # file is executed as a script.
    rank = int(os.getenv("RANK", 0))
    src = rank // inference_parallel_size * inference_parallel_size
    if rank != src:
        return

    transport = httpx.HTTPTransport(uds=uds) if uds is not None else None
    retry_num = 0
    with httpx.Client(transport=transport) as client:
        while True:
            try:
                response = client.get(f"{endpoint}/ping", timeout=10)
                response.raise_for_status()
                break
            except (
                httpx.ConnectError,
                httpx.TimeoutException,
                httpx.HTTPStatusError,
            ) as e:
                # A timeout while the server is still starting up is treated
                # like a refused connection: keep waiting instead of crashing.
                if retry_num % 10 == 0:
                    logger.warning(
                        f"fail to check sglang ready, retry {retry_num} times, error: {e}"
                    )
                retry_num += 1
                time.sleep(0.1)
|
||||
|
||||
|
||||
def split_checkpoint_files(
    checkpoint_path: str, rank: int, world_size: int
) -> list[str]:
    """Return the slice of ``*.safetensors`` files assigned to ``rank``.

    The files found directly in ``checkpoint_path`` are split into
    ``world_size`` contiguous chunks of ``ceil(n_files / world_size)`` files;
    ranks past the end of the list receive an empty list.

    Args:
        checkpoint_path: Directory containing the checkpoint shards.
        rank: Index of the calling process.
        world_size: Total number of processes sharing the files.

    Returns:
        Full paths of the files owned by ``rank``.
    """
    # os.listdir returns entries in arbitrary order; sort so that every rank
    # computes the same partition and no file is duplicated or skipped.
    checkpoint_files = sorted(
        os.path.join(checkpoint_path, name)
        for name in os.listdir(checkpoint_path)
        if name.endswith(".safetensors")
    )
    files_per_rank = (len(checkpoint_files) + world_size - 1) // world_size
    return checkpoint_files[rank * files_per_rank : (rank + 1) * files_per_rank]
|
||||
|
||||
|
||||
def split_tensors(
    checkpoint_path: str, rank: int, world_size: int
) -> dict[str, torch.Tensor]:
    """Load the tensors assigned to ``rank`` according to the safetensors index.

    Reads ``model.safetensors.index.json`` in ``checkpoint_path``, splits the
    entries of its ``weight_map`` (in file order) into ``world_size``
    contiguous chunks, and loads the chunk belonging to ``rank``. Each shard
    file is opened once and only the needed tensors are read from it.

    Args:
        checkpoint_path: Directory holding the index file and the shards.
        rank: Index of the calling process.
        world_size: Total number of processes sharing the tensors.

    Returns:
        Mapping from tensor name to the loaded tensor; empty when the rank's
        chunk is empty.
    """
    index_path = os.path.join(checkpoint_path, "model.safetensors.index.json")
    with open(index_path) as index_file:
        weight_map: dict[str, str] = json.load(index_file)["weight_map"]

    chunk_size = (len(weight_map) + world_size - 1) // world_size
    first = rank * chunk_size
    my_entries = list(weight_map.items())[first : first + chunk_size]

    # Group the tensor names by the shard file that stores them.
    names_by_file: dict[str, list[str]] = defaultdict(list)
    for tensor_name, shard_file in my_entries:
        names_by_file[shard_file].append(tensor_name)

    loaded = {}
    for shard_file, tensor_names in names_by_file.items():
        shard_path = os.path.join(checkpoint_path, shard_file)
        with safe_open(shard_path, framework="pt") as shard:
            for tensor_name in tensor_names:
                loaded[tensor_name] = shard.get_tensor(tensor_name)
    return loaded
|
||||
|
||||
|
||||
def req_inference(
    endpoint: str,
    inference_parallel_size: int,
    timeout: float = 300.0,
    uds: str | None = None,
    weight_version: str | None = None,
) -> Callable[[list[tuple[str, str]]], None]:
    """Build the callback that asks the server to update weights over IPC.

    The rank is read from the ``RANK`` environment variable (default 0) when
    this factory is called. The returned callback is a no-op on every rank
    except the first rank of its group of ``inference_parallel_size`` ranks.

    Args:
        endpoint: Base URL of the server.
        inference_parallel_size: Number of ranks per inference group.
        timeout: Timeout in seconds for the POST request.
        uds: Optional Unix domain socket path used as the HTTP transport.
        weight_version: Value sent as ``weight_version`` in the request body.

    Returns:
        A function taking the list of ``(key, value)`` socket path pairs and
        posting the group's slice to ``/update_weights_from_ipc``.
    """
    rank = int(os.getenv("RANK", 0))
    src = rank // inference_parallel_size * inference_parallel_size

    def req_func(socket_paths: list[tuple[str, str]]):
        # Only the leading rank of the group talks to the server.
        if rank != src:
            return
        transport = httpx.HTTPTransport(uds=uds)
        with httpx.Client(transport=transport) as client:
            group_handles = dict(socket_paths[src : src + inference_parallel_size])
            payload = {
                "zmq_handles": group_handles,
                "flush_cache": True,
                "weight_version": weight_version,
            }
            resp = client.post(
                f"{endpoint}/update_weights_from_ipc",
                json=payload,
                timeout=timeout,
            )
            resp.raise_for_status()

    return req_func
|
||||
|
||||
|
||||
def update_weights(
    ps: ParameterServer,
    checkpoint_name: str,
    checkpoint_files: list[str],
    named_tensors: dict[str, torch.Tensor],
    req_func: Callable[[list[tuple[str, str]]], None],
    inference_parallel_size: int,
    endpoint: str,
    save_metas_file: str | None = None,
    update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
    uds: str | None = None,
):
    """Register a checkpoint with the parameter server and push it to the server.

    Steps, in order: register the checkpoint, initialise the process group,
    wait for the inference server to answer, synchronise all ranks, gather
    metas, optionally pickle the metas to ``save_metas_file`` on rank 0, then
    run the update(s) selected by ``update_method``.

    Args:
        ps: Parameter server instance used for all checkpoint operations.
        checkpoint_name: Name under which the checkpoint is registered.
        checkpoint_files: Checkpoint files owned by this rank.
        named_tensors: Already-loaded tensors owned by this rank.
        req_func: Callback passed to ``ps.update``.
        inference_parallel_size: Number of ranks per inference group.
        endpoint: Base URL of the inference server.
        save_metas_file: If set, rank 0 pickles ``ps.get_metas()`` to this path.
        update_method: ``"broadcast"`` calls ``ps.update`` without ``ranks``,
            ``"p2p"`` calls it with ``ranks=range(inference_parallel_size)``,
            ``"all"`` does both in that order. Any other value performs no
            update.
        uds: Optional Unix domain socket path for the readiness check.
    """
    ps.register_checkpoint(
        checkpoint_name, files=checkpoint_files, named_tensors=named_tensors
    )
    ps.init_process_group()
    check_sglang_ready(endpoint, inference_parallel_size, uds)
    dist.barrier()
    with timer("Gather metas"):
        ps.gather_metas(checkpoint_name)

    # Default to rank 0 when RANK is unset, matching req_inference.
    if save_metas_file and int(os.getenv("RANK", 0)) == 0:
        with open(save_metas_file, "wb") as f:
            pickle.dump(ps.get_metas(), f)

    if update_method in ("broadcast", "all"):
        with timer("Update weights without setting ranks"):
            ps.update(checkpoint_name, req_func)

    if update_method in ("p2p", "all"):
        # sleep 2s to wait destroy process group
        # NOTE(review): the previous inner guard `if update_method:` was always
        # true in this branch, so the sleep is kept unconditional. It may have
        # been intended only after a preceding broadcast update ("all") - confirm.
        time.sleep(2)
        with timer("Update weights with setting ranks"):
            ps.update(
                checkpoint_name, req_func, ranks=list(range(inference_parallel_size))
            )
|
||||
|
||||
|
||||
def join(
    ps: ParameterServer,
    checkpoint_name: str,
    load_metas_file: str,
    req_func: Callable[[list[tuple[str, str]]], None],
    inference_parallel_size: int,
    endpoint: str,
    uds: str | None = None,
):
    """Update weights using metas previously pickled to ``load_metas_file``.

    Loads the pickled metas, initialises the process group, waits for the
    inference server, synchronises all ranks, gathers metas, loads the saved
    metas into the parameter server, and finally calls ``ps.update`` with
    ``ranks=range(inference_parallel_size)``.

    Args:
        ps: Parameter server instance.
        checkpoint_name: Name of the checkpoint to update from.
        load_metas_file: Path of the pickle file holding the saved metas.
        req_func: Callback passed to ``ps.update``.
        inference_parallel_size: Number of ranks per inference group.
        endpoint: Base URL of the inference server.
        uds: Optional Unix domain socket path for the readiness check.

    Raises:
        ValueError: If ``load_metas_file`` is empty or ``None``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not load_metas_file:
        raise ValueError("load_metas_file is required")

    # SECURITY: pickle.load executes arbitrary code embedded in the file;
    # only pass files produced by a trusted run of this script.
    with open(load_metas_file, "rb") as f:
        metas = pickle.load(f)

    ps.init_process_group()
    check_sglang_ready(endpoint, inference_parallel_size, uds)
    dist.barrier()
    with timer("Gather metas before join"):
        ps.gather_metas(checkpoint_name)
    ps.load_metas(metas)
    with timer(
        f"Update weights with setting ranks as range(0, {inference_parallel_size}) by using p2p"
    ):
        ps.update(checkpoint_name, req_func, ranks=list(range(inference_parallel_size)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the CLI, then either join with saved metas
    # (--load-metas-file) or register a checkpoint from --checkpoint-path and
    # push it to the inference server.
    parser = argparse.ArgumentParser(description="Update weights example")
    parser.add_argument("--checkpoint-path", type=str, default=None)
    parser.add_argument("--save-metas-file", type=str, default=None)
    parser.add_argument("--load-metas-file", type=str, default=None)
    parser.add_argument("--sleep-time", type=int, default=0)
    parser.add_argument("--endpoint", type=str, default="http://localhost:19730")
    parser.add_argument("--inference-parallel-size", type=int, default=8)
    parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
    # Restrict to the values update_weights understands; anything else would
    # silently perform no update.
    parser.add_argument(
        "--update-method",
        type=str,
        default="broadcast",
        choices=("broadcast", "p2p", "all"),
    )
    parser.add_argument("--uds", type=str, default=None)
    parser.add_argument("--weight-version", type=str, default=None)
    args = parser.parse_args()

    # Without saved metas the weights must come from a checkpoint directory;
    # fail with a usage error instead of os.path.join(None, ...) below.
    if not args.load_metas_file and not args.checkpoint_path:
        parser.error("--checkpoint-path is required unless --load-metas-file is given")

    # RANK / WORLD_SIZE are read from the environment (the usage example
    # launches this script with torchrun).
    try:
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
    except KeyError as missing:
        parser.error(
            f"environment variable {missing} is not set; launch this script with torchrun"
        )

    req_func = req_inference(
        args.endpoint,
        args.inference_parallel_size,
        uds=args.uds,
        weight_version=args.weight_version,
    )
    ps = ParameterServer(auto_pg=True)
    # NOTE(review): overrides a private attribute of ParameterServer; the
    # reason is not visible in this file - confirm it is still required.
    ps._p2p_store = None

    if args.load_metas_file:
        join(
            ps,
            args.checkpoint_name,
            args.load_metas_file,
            req_func,
            args.inference_parallel_size,
            args.endpoint,
            args.uds,
        )
    else:
        index_file = os.path.join(args.checkpoint_path, "model.safetensors.index.json")
        if os.path.exists(index_file):
            # An index file exists: split individual tensors across ranks.
            named_tensors = split_tensors(args.checkpoint_path, rank, world_size)
            checkpoint_files = []
        else:
            # No index file: split whole .safetensors files across ranks.
            checkpoint_files = split_checkpoint_files(
                args.checkpoint_path, rank, world_size
            )
            named_tensors = {}
        update_weights(
            ps,
            args.checkpoint_name,
            checkpoint_files,
            named_tensors,
            req_func,
            args.inference_parallel_size,
            args.endpoint,
            args.save_metas_file,
            args.update_method,
            args.uds,
        )
    time.sleep(args.sleep_time)
|
||||
Reference in New Issue
Block a user