feat(remote_model): support variable remote backend for model loader (#3964)

Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
wangyu
2025-03-14 15:40:44 +08:00
committed by GitHub
parent 977d7cd26a
commit 1ce4878d31
22 changed files with 1055 additions and 9 deletions

View File

@@ -32,6 +32,7 @@ import psutil
import setproctitle
import torch
import zmq
from torch.distributed import barrier
from sglang.global_config import global_config
from sglang.srt.configs.model_config import ModelConfig
@@ -59,6 +60,8 @@ from sglang.srt.managers.io_struct import (
ReleaseMemoryOccupationReqOutput,
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
RpcReqInput,
RpcReqOutput,
SetInternalStateReq,
SetInternalStateReqOutput,
TokenizedEmbeddingReqInput,
@@ -193,8 +196,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
self.send_to_detokenizer = get_zmq_socket(
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
)
self.recv_from_rpc = get_zmq_socket(
context, zmq.DEALER, port_args.rpc_ipc_name, False
)
else:
self.recv_from_tokenizer = None
self.recv_from_rpc = None
self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
@@ -376,6 +384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
(ProfileReq, self.profile),
(GetInternalStateReq, self.get_internal_state),
(SetInternalStateReq, self.set_internal_state),
(RpcReqInput, self.handle_rpc_request),
]
)
@@ -549,6 +558,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
except zmq.ZMQError:
break
recv_reqs.append(recv_req)
while True:
try:
recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
except zmq.ZMQError:
break
recv_reqs.append(recv_rpc)
else:
recv_reqs = None
@@ -600,7 +616,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
output = self._request_dispatcher(recv_req)
if output is not None:
self.send_to_tokenizer.send_pyobj(output)
if isinstance(output, RpcReqOutput):
if self.recv_from_rpc is not None:
self.recv_from_rpc.send_pyobj(output)
else:
self.send_to_tokenizer.send_pyobj(output)
def handle_generate_request(
self,
@@ -1492,6 +1512,47 @@ class Scheduler(SchedulerOutputProcessorMixin):
server_args=global_server_args_dict,
)
def handle_rpc_request(self, recv_req: RpcReqInput):
    """Dispatch an RPC request to the scheduler method it names.

    The attribute named by ``recv_req.method`` is looked up on this
    scheduler and called with ``recv_req.parameters`` as its only
    argument. Any ``Exception`` raised by the lookup or by the call is
    logged and reported in the result instead of propagating.

    Args:
        recv_req: Request carrying the target method name (``method``)
            and the single argument passed to it (``parameters``).

    Returns:
        ``RpcReqOutput(success, message)``. ``message`` is the empty
        string on success, otherwise ``str()`` of the caught exception.
    """
    logger.info(
        f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
    )
    # Named `error` rather than `exec` so the builtin is not shadowed.
    error = None
    try:
        # NOTE(review): the method name comes straight from the request and
        # can reach any attribute of the scheduler -- confirm that the RPC
        # channel is only exposed to trusted callers.
        func = getattr(self, recv_req.method)
        func(recv_req.parameters)
    except Exception as e:
        error = e
        logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

    # The torch.distributed barrier is reached on both the success and the
    # failure path, because the exception above is swallowed.
    barrier()

    # Compare against None explicitly: an exception object whose class
    # defines a falsy __bool__/__len__ must still report its message.
    failed = error is not None
    return RpcReqOutput(not failed, str(error) if failed else "")
def save_remote_model(self, params):
    """Save the model through the model runner to ``params["url"]``.

    Args:
        params: Mapping that must contain the key ``"url"``; its value is
            forwarded unchanged to ``model_runner.save_remote_model``.
    """
    target_url = params["url"]

    # A TpModelWorkerClient exposes the underlying worker as `.worker`;
    # any other tp_worker is used directly.
    tp_worker = self.tp_worker
    actual_worker = (
        tp_worker.worker
        if isinstance(tp_worker, TpModelWorkerClient)
        else tp_worker
    )
    actual_worker.model_runner.save_remote_model(target_url)
def save_sharded_model(self, params):
    """Save the model in sharded form through the model runner.

    Args:
        params: Mapping that must contain the keys ``"path"``,
            ``"pattern"`` and ``"max_size"``; each value is forwarded as
            the keyword argument of the same name to
            ``model_runner.save_sharded_model``.
    """
    # Unwrap a TpModelWorkerClient to the worker it holds; any other
    # tp_worker is used directly.
    target = self.tp_worker
    if isinstance(target, TpModelWorkerClient):
        target = target.worker

    save = target.model_runner.save_sharded_model
    save(
        path=params["path"],
        pattern=params["pattern"],
        max_size=params["max_size"],
    )
def abort_request(self, recv_req: AbortReq):
# Delete requests in the waiting queue
to_del = []