feat(remote_model): support variable remote backend for model loader (#3964)
Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
@@ -27,6 +27,9 @@ import signal
|
||||
import threading
|
||||
from typing import AsyncIterator, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
# Fix a bug of Python threading
|
||||
setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
|
||||
|
||||
@@ -44,6 +47,8 @@ from sglang.srt.managers.io_struct import (
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
ReleaseMemoryOccupationReqInput,
|
||||
ResumeMemoryOccupationReqInput,
|
||||
RpcReqInput,
|
||||
RpcReqOutput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
UpdateWeightsFromTensorReqInput,
|
||||
@@ -57,6 +62,7 @@ from sglang.srt.utils import (
|
||||
MultiprocessingSerializer,
|
||||
assert_pkg_version,
|
||||
configure_logger,
|
||||
get_zmq_socket,
|
||||
kill_process_tree,
|
||||
launch_dummy_health_check_server,
|
||||
maybe_set_triton_cache_manager,
|
||||
@@ -102,15 +108,25 @@ class Engine:
|
||||
# Shutdown the subprocesses automatically when the program exits
|
||||
atexit.register(self.shutdown)
|
||||
|
||||
# Allocate ports for inter-process communications
|
||||
port_args = PortArgs.init_new(server_args)
|
||||
logger.info(f"{server_args=}")
|
||||
|
||||
# Launch subprocesses
|
||||
tokenizer_manager, scheduler_info = _launch_subprocesses(
|
||||
server_args=server_args
|
||||
server_args=server_args,
|
||||
port_args=port_args,
|
||||
)
|
||||
|
||||
self.server_args = server_args
|
||||
self.tokenizer_manager = tokenizer_manager
|
||||
self.scheduler_info = scheduler_info
|
||||
|
||||
context = zmq.Context(2)
|
||||
self.send_to_rpc = get_zmq_socket(
|
||||
context, zmq.DEALER, port_args.rpc_ipc_name, True
|
||||
)
|
||||
|
||||
def generate(
|
||||
self,
|
||||
# The input prompt. It can be a single prompt or a batch of prompts.
|
||||
@@ -350,6 +366,23 @@ class Engine:
|
||||
self.tokenizer_manager.resume_memory_occupation(obj, None)
|
||||
)
|
||||
|
||||
def collective_rpc(self, method: str, **kwargs):
    """Execute an RPC call on all scheduler processes.

    Sends the request over the DEALER socket created in ``__init__``
    (``self.send_to_rpc``) and blocks until the reply arrives, then
    asserts the call succeeded.

    Args:
        method: Name of the scheduler method to invoke.
        **kwargs: Keyword arguments forwarded to that method.

    Raises:
        AssertionError: If the reply is not an ``RpcReqOutput`` or the
            remote call reported failure (``recv_req.message`` carries
            the failure reason).
    """
    obj = RpcReqInput(method=method, parameters=kwargs)
    self.send_to_rpc.send_pyobj(obj)
    # NOTE(review): the original passed zmq.BLOCKY as the recv flags.
    # ZMQ_BLOCKY is a *context* option, not a recv flag; its bit
    # pattern does not include NOBLOCK, so the call was effectively a
    # plain blocking receive already. Make that explicit.
    recv_req = self.send_to_rpc.recv_pyobj()
    assert isinstance(recv_req, RpcReqOutput)
    assert recv_req.success, recv_req.message
|
||||
|
||||
def save_remote_model(self, **kwargs):
    """Persist the current model weights to a remote backend.

    All keyword arguments are forwarded unchanged to the schedulers'
    ``save_remote_model`` RPC handler.
    """
    self.collective_rpc("save_remote_model", **kwargs)
|
||||
|
||||
def save_sharded_model(self, **kwargs):
    """Save the model weights in a sharded layout.

    All keyword arguments are forwarded unchanged to the schedulers'
    ``save_sharded_model`` RPC handler.
    """
    self.collective_rpc("save_sharded_model", **kwargs)
|
||||
|
||||
|
||||
def _set_envs_and_config(server_args: ServerArgs):
|
||||
# Set global environments
|
||||
@@ -408,7 +441,9 @@ def _set_envs_and_config(server_args: ServerArgs):
|
||||
mp.set_start_method("spawn", force=True)
|
||||
|
||||
|
||||
def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dict]:
|
||||
def _launch_subprocesses(
|
||||
server_args: ServerArgs, port_args: Optional[PortArgs] = None
|
||||
) -> Tuple[TokenizerManager, Dict]:
|
||||
"""
|
||||
Launch the TokenizerManager in the main process, the Scheduler in a subprocess, and the DetokenizerManager in another subprocess.
|
||||
"""
|
||||
@@ -418,8 +453,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
|
||||
_set_envs_and_config(server_args)
|
||||
|
||||
# Allocate ports for inter-process communications
|
||||
port_args = PortArgs.init_new(server_args)
|
||||
logger.info(f"{server_args=}")
|
||||
if port_args is None:
|
||||
port_args = PortArgs.init_new(server_args)
|
||||
logger.info(f"{server_args=}")
|
||||
|
||||
# If using model from www.modelscope.cn, first download the model.
|
||||
server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
|
||||
|
||||
Reference in New Issue
Block a user