feat(remote_model): support variable remote backend for model loader (#3964)
Signed-off-by: wangyu <wangyu.steph@bytedance.com>
This commit is contained in:
@@ -32,6 +32,7 @@ import psutil
|
||||
import setproctitle
|
||||
import torch
|
||||
import zmq
|
||||
from torch.distributed import barrier
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
@@ -59,6 +60,8 @@ from sglang.srt.managers.io_struct import (
|
||||
ReleaseMemoryOccupationReqOutput,
|
||||
ResumeMemoryOccupationReqInput,
|
||||
ResumeMemoryOccupationReqOutput,
|
||||
RpcReqInput,
|
||||
RpcReqOutput,
|
||||
SetInternalStateReq,
|
||||
SetInternalStateReqOutput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
@@ -193,8 +196,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
self.send_to_detokenizer = get_zmq_socket(
|
||||
context, zmq.PUSH, port_args.detokenizer_ipc_name, False
|
||||
)
|
||||
|
||||
self.recv_from_rpc = get_zmq_socket(
|
||||
context, zmq.DEALER, port_args.rpc_ipc_name, False
|
||||
)
|
||||
else:
|
||||
self.recv_from_tokenizer = None
|
||||
self.recv_from_rpc = None
|
||||
self.send_to_tokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
||||
self.send_to_detokenizer = SimpleNamespace(send_pyobj=lambda x: None)
|
||||
|
||||
@@ -376,6 +384,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
(ProfileReq, self.profile),
|
||||
(GetInternalStateReq, self.get_internal_state),
|
||||
(SetInternalStateReq, self.set_internal_state),
|
||||
(RpcReqInput, self.handle_rpc_request),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -549,6 +558,13 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
except zmq.ZMQError:
|
||||
break
|
||||
recv_reqs.append(recv_req)
|
||||
|
||||
while True:
|
||||
try:
|
||||
recv_rpc = self.recv_from_rpc.recv_pyobj(zmq.NOBLOCK)
|
||||
except zmq.ZMQError:
|
||||
break
|
||||
recv_reqs.append(recv_rpc)
|
||||
else:
|
||||
recv_reqs = None
|
||||
|
||||
@@ -600,7 +616,11 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
|
||||
output = self._request_dispatcher(recv_req)
|
||||
if output is not None:
|
||||
self.send_to_tokenizer.send_pyobj(output)
|
||||
if isinstance(output, RpcReqOutput):
|
||||
if self.recv_from_rpc is not None:
|
||||
self.recv_from_rpc.send_pyobj(output)
|
||||
else:
|
||||
self.send_to_tokenizer.send_pyobj(output)
|
||||
|
||||
def handle_generate_request(
|
||||
self,
|
||||
@@ -1492,6 +1512,47 @@ class Scheduler(SchedulerOutputProcessorMixin):
|
||||
server_args=global_server_args_dict,
|
||||
)
|
||||
|
||||
def handle_rpc_request(self, recv_req: RpcReqInput):
    """Run the scheduler method named by an RPC request and report the outcome.

    The method is looked up by name on this scheduler and called with
    ``recv_req.parameters`` as its single argument. Its return value is
    discarded; only success or failure is reported back.

    Args:
        recv_req: Request carrying ``method`` (attribute name to call on
            ``self``) and ``parameters`` (passed through unchanged).

    Returns:
        ``RpcReqOutput(success, message)``, where ``message`` is ``""`` on
        success and ``str(exception)`` on failure.
    """
    logger.info(
        f"handle_rpc_request: {recv_req.method}, param: {recv_req.parameters}"
    )

    # `error` rather than `exec`, so the builtin `exec` is not shadowed.
    error = None
    try:
        # NOTE(review): the method name is taken straight from the request,
        # so any attribute of the scheduler can be invoked this way. Consider
        # an explicit allow-list if RPC senders are not fully trusted.
        func = getattr(self, recv_req.method)
        func(recv_req.parameters)
    except Exception as e:
        error = e
        logger.error(f"Failed to call rpc {recv_req.method}: {str(e)}")

    # Reached on both the success and the failure path: the barrier sits
    # outside the try block so it is entered even when the call raised.
    barrier()

    # Identity check instead of truthiness: an exception object that happens
    # to be falsy must still be reported as a failure message.
    success = error is None
    return RpcReqOutput(success, "" if error is None else str(error))
|
||||
|
||||
def save_remote_model(self, params):
    """Ask the model runner to save the model to a remote location.

    Args:
        params: Mapping that must contain the key ``"url"``; its value is
            forwarded to ``model_runner.save_remote_model``.
    """
    remote_url = params["url"]

    # A TpModelWorkerClient exposes the worker that owns the model runner
    # via `.worker`; any other tp_worker is used directly.
    tp_worker = self.tp_worker
    runner_owner = (
        tp_worker.worker
        if isinstance(tp_worker, TpModelWorkerClient)
        else tp_worker
    )
    runner_owner.model_runner.save_remote_model(remote_url)
|
||||
|
||||
def save_sharded_model(self, params):
    """Ask the model runner to save the model as shards.

    Args:
        params: Mapping that must contain the keys ``"path"``,
            ``"pattern"`` and ``"max_size"``; each is forwarded as the
            keyword argument of the same name to
            ``model_runner.save_sharded_model``.
    """
    # A TpModelWorkerClient exposes the worker that owns the model runner
    # via `.worker`; any other tp_worker is used directly.
    tp_worker = self.tp_worker
    runner_owner = (
        tp_worker.worker
        if isinstance(tp_worker, TpModelWorkerClient)
        else tp_worker
    )

    save_sharded = runner_owner.model_runner.save_sharded_model
    save_sharded(
        **{key: params[key] for key in ("path", "pattern", "max_size")}
    )
|
||||
|
||||
def abort_request(self, recv_req: AbortReq):
|
||||
# Delete requests in the waiting queue
|
||||
to_del = []
|
||||
|
||||
Reference in New Issue
Block a user