feat(remote_model): support variable remote backend for model loader (#3964)

Signed-off-by: wangyu <wangyu.steph@bytedance.com>
Author: wangyu
Date: 2025-03-14 15:40:44 +08:00
Committed by: GitHub
Parent: 977d7cd26a
Commit: 1ce4878d31
22 changed files with 1055 additions and 9 deletions


@@ -27,6 +27,9 @@ import signal
import threading
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
import zmq
import zmq.asyncio
# Fix a bug of Python threading
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -44,6 +47,8 @@ from sglang.srt.managers.io_struct import (
InitWeightsUpdateGroupReqInput,
ReleaseMemoryOccupationReqInput,
ResumeMemoryOccupationReqInput,
RpcReqInput,
RpcReqOutput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
UpdateWeightsFromTensorReqInput,
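
Judging from how they are used in collective_rpc further down, the two new request/response structs presumably look roughly like the sketch below; only the method, parameters, success, and message fields are visible in this diff, so anything else about the real definitions in sglang/srt/managers/io_struct.py is an assumption.

from dataclasses import dataclass
from typing import Any, Dict, Optional

# Hedged sketch of the new io_struct types, inferred from their use in
# collective_rpc below; the real definitions may carry additional fields.
@dataclass
class RpcReqInput:
    method: str                                  # scheduler method to invoke
    parameters: Optional[Dict[str, Any]] = None  # keyword arguments for that method

@dataclass
class RpcReqOutput:
    success: bool      # True if every scheduler handled the call
    message: str = ""  # error details when success is False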
@@ -57,6 +62,7 @@ from sglang.srt.utils import (
MultiprocessingSerializer,
assert_pkg_version,
configure_logger,
get_zmq_socket,
kill_process_tree,
launch_dummy_health_check_server,
maybe_set_triton_cache_manager,
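
The helper get_zmq_socket imported above is what Engine.__init__ uses below to open its RPC socket. Its implementation is not part of this diff; presumably it is a thin wrapper along the following lines, where the bind-or-connect meaning of the final boolean argument is a guess.

import zmq

# Hedged guess at what a helper like get_zmq_socket does; the actual
# implementation lives in sglang.srt.utils and is not shown in this diff.
def get_zmq_socket(context: zmq.Context, socket_type: int, endpoint: str, bind: bool) -> zmq.Socket:
    socket = context.socket(socket_type)
    if bind:
        socket.bind(endpoint)     # the side that owns the endpoint
    else:
        socket.connect(endpoint)  # the side that dials in
    return socket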
@@ -102,15 +108,25 @@ class Engine:
# Shutdown the subprocesses automatically when the program exits
atexit.register(self.shutdown)
# Allocate ports for inter-process communications
port_args = PortArgs.init_new(server_args)
logger.info(f"{server_args=}")
# Launch subprocesses
tokenizer_manager, scheduler_info = _launch_subprocesses(
-    server_args=server_args
+    server_args=server_args,
+    port_args=port_args,
)
self.server_args = server_args
self.tokenizer_manager = tokenizer_manager
self.scheduler_info = scheduler_info
context = zmq.Context(2)
self.send_to_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, True
)
def generate(
self,
# The input prompt. It can be a single prompt or a batch of prompts.
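
For orientation, here is a minimal self-contained pyzmq sketch of the round trip that the Engine's new send_to_rpc socket performs, with a stand-in peer playing the scheduler side in-process. The ipc endpoint and the bind/connect split are illustrative assumptions; only the DEALER socket type and port_args.rpc_ipc_name appear in the diff above.

import zmq

# Hedged sketch of the Engine <-> scheduler RPC round trip over an IPC endpoint.
ENDPOINT = "ipc:///tmp/sglang_rpc_example"   # placeholder for port_args.rpc_ipc_name
context = zmq.Context(2)

scheduler_side = context.socket(zmq.DEALER)  # stand-in for the scheduler RPC loop
scheduler_side.bind(ENDPOINT)

send_to_rpc = context.socket(zmq.DEALER)     # Engine side, as created in __init__ above
send_to_rpc.connect(ENDPOINT)

send_to_rpc.send_pyobj({"method": "ping", "parameters": {}})  # request from the Engine
print(scheduler_side.recv_pyobj())                            # scheduler receives it
scheduler_side.send_pyobj({"success": True, "message": ""})   # scheduler replies
print(send_to_rpc.recv_pyobj())                               # Engine blocks until the reply arrives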
@@ -350,6 +366,23 @@ class Engine:
self.tokenizer_manager.resume_memory_occupation(obj, None)
)
"""
Execute an RPC call on all scheduler processes.
"""
def collective_rpc(self, method: str, **kwargs):
obj = RpcReqInput(method=method, parameters=kwargs)
self.send_to_rpc.send_pyobj(obj)
recv_req = self.send_to_rpc.recv_pyobj(zmq.BLOCKY)
assert isinstance(recv_req, RpcReqOutput)
assert recv_req.success, recv_req.message
def save_remote_model(self, **kwargs):
self.collective_rpc("save_remote_model", **kwargs)
def save_sharded_model(self, **kwargs):
self.collective_rpc("save_sharded_model", **kwargs)
def _set_envs_and_config(server_args: ServerArgs):
# Set global environments
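
For illustration, a hedged usage sketch of the RPC surface added in the hunk above (collective_rpc, save_remote_model, save_sharded_model). The model path is arbitrary and every keyword argument below is a hypothetical placeholder; the accepted parameters are defined by the scheduler-side handlers, which this diff hunk does not show.

import sglang as sgl

# Hypothetical usage of the new RPC entry points; parameter names are placeholders.
llm = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")

# Fan a single method call out to every scheduler process over the DEALER socket.
llm.collective_rpc("save_remote_model", url="remote://bucket/model")  # hypothetical parameter

# The convenience wrappers route through the same collective_rpc path.
llm.save_sharded_model(path="/tmp/sharded_ckpt")  # hypothetical parameter

llm.shutdown()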
@@ -408,7 +441,9 @@ def _set_envs_and_config(server_args: ServerArgs):
mp.set_start_method("spawn", force=True)
-def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
+def _launch_subprocesses(
+    server_args: ServerArgs, port_args: Optional[PortArgs] = None
+) -> Tuple[TokenizerManager, Dict]:
"""
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
"""
@@ -418,8 +453,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
_set_envs_and_config(server_args)
# Allocate ports for inter-process communications
-port_args = PortArgs.init_new(server_args)
+if port_args is None:
+    port_args = PortArgs.init_new(server_args)
logger.info(f"{server_args=}")
# If using model from www.modelscope.cn, first download the model.
server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(