Improve process creation (#1534)
python/sglang/srt/managers/controller_multi.py (deleted)
@@ -1,207 +0,0 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-"""
-A controller that manages multiple data parallel workers.
-Each data parallel worker can manage multiple tensor parallel workers.
-"""
-
-import dataclasses
-import logging
-import multiprocessing
-from enum import Enum, auto
-
-import numpy as np
-import zmq
-
-from sglang.srt.managers.controller_single import (
-    start_controller_process as start_controller_process_single,
-)
-from sglang.srt.managers.io_struct import (
-    AbortReq,
-    FlushCacheReq,
-    TokenizedGenerateReqInput,
-)
-from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import configure_logger, kill_parent_process
-from sglang.utils import get_exception_traceback
-
-logger = logging.getLogger(__name__)
-
-
-class LoadBalanceMethod(Enum):
-    """Load balance method."""
-
-    ROUND_ROBIN = auto()
-    SHORTEST_QUEUE = auto()
-
-    @classmethod
-    def from_str(cls, method: str):
-        method = method.upper()
-        try:
-            return cls[method]
-        except KeyError as exc:
-            raise ValueError(f"Invalid load balance method: {method}") from exc
-
-
-@dataclasses.dataclass
-class WorkerHandle:
-    """Store the handle of a data parallel worker."""
-
-    proc: multiprocessing.Process
-    queue: multiprocessing.Queue
-
-
-class ControllerMulti:
-    """A controller that manages multiple data parallel workers."""
-
-    def __init__(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-    ):
-        # Parse args
-        self.server_args = server_args
-        self.port_args = port_args
-        self.load_balance_method = LoadBalanceMethod.from_str(
-            server_args.load_balance_method
-        )
-
-        # Init communication
-        context = zmq.Context()
-        self.recv_from_tokenizer = context.socket(zmq.PULL)
-        self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
-
-        # Dispatch method
-        self.round_robin_counter = 0
-        dispatch_lookup = {
-            LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
-            LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
-        }
-        self.dispatching = dispatch_lookup[self.load_balance_method]
-
-        # Start data parallel workers
-        self.workers = []
-        for i in range(server_args.dp_size):
-            self.start_dp_worker(i)
-
-    def start_dp_worker(self, dp_worker_id: int):
-        tp_size = self.server_args.tp_size
-
-        pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
-            duplex=False
-        )
-
-        gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
-        queue = multiprocessing.Queue()
-        proc = multiprocessing.Process(
-            target=start_controller_process_single,
-            args=(
-                self.server_args,
-                self.port_args,
-                pipe_controller_writer,
-                True,
-                gpu_ids,
-                dp_worker_id,
-                queue,
-            ),
-        )
-        proc.start()
-
-        controller_init_state = pipe_controller_reader.recv()
-        if controller_init_state != "init ok":
-            raise RuntimeError(
-                f"Initialization failed. controller_init_state: {controller_init_state}"
-            )
-        self.workers.append(
-            WorkerHandle(
-                proc=proc,
-                queue=queue,
-            )
-        )
-
-    def round_robin_scheduler(self, input_requests):
-        for r in input_requests:
-            self.workers[self.round_robin_counter].queue.put(r)
-            self.round_robin_counter = (self.round_robin_counter + 1) % len(
-                self.workers
-            )
-
-    def shortest_queue_scheduler(self, input_requests):
-        for r in input_requests:
-            queue_sizes = [worker.queue.qsize() for worker in self.workers]
-            wid = np.argmin(queue_sizes)
-            self.workers[wid].queue.put(r)
-
-    def loop_for_forward(self):
-        while True:
-            recv_reqs = self.recv_requests()
-            self.dispatching(recv_reqs)
-
-    def recv_requests(self):
-        recv_reqs = []
-
-        while True:
-            try:
-                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-            except zmq.ZMQError:
-                break
-
-            if isinstance(recv_req, FlushCacheReq):
-                # TODO(lsyin): apply more specific flushCacheReq
-                for worker in self.workers:
-                    worker.queue.put(recv_req)
-            elif isinstance(recv_req, AbortReq):
-                in_queue = False
-                for i, req in enumerate(recv_reqs):
-                    if req.rid == recv_req.rid:
-                        recv_reqs[i] = recv_req
-                        in_queue = True
-                        break
-                if not in_queue:
-                    # Send abort req to all TP groups
-                    for worker in self.workers:
-                        worker.queue.put(recv_req)
-            elif isinstance(recv_req, TokenizedGenerateReqInput):
-                recv_reqs.append(recv_req)
-            else:
-                logger.error(f"Invalid object: {recv_req}")
-
-        return recv_reqs
-
-
-def start_controller_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer,
-):
-    """Start a controller process."""
-
-    configure_logger(server_args)
-
-    try:
-        controller = ControllerMulti(server_args, port_args)
-    except Exception:
-        pipe_writer.send(get_exception_traceback())
-        raise
-
-    pipe_writer.send("init ok")
-
-    try:
-        controller.loop_for_forward()
-    except Exception:
-        logger.error("Exception in ControllerMulti:\n" + get_exception_traceback())
-    finally:
-        kill_parent_process()
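For reference, the two dispatch policies ControllerMulti implemented reduce to a few lines. A standalone sketch (plain Python, not SGLang code; FakeWorker is a hypothetical stand-in for WorkerHandle):

import queue

class FakeWorker:
    def __init__(self):
        self.queue = queue.Queue()

def round_robin(workers, requests, counter=0):
    # Each request goes to the next worker in a fixed cycle.
    for r in requests:
        workers[counter].queue.put(r)
        counter = (counter + 1) % len(workers)
    return counter

def shortest_queue(workers, requests):
    # Each request goes to the worker with the fewest pending requests.
    for r in requests:
        sizes = [w.queue.qsize() for w in workers]
        workers[sizes.index(min(sizes))].queue.put(r)

workers = [FakeWorker() for _ in range(2)]
round_robin(workers, ["a", "b", "c"])
shortest_queue(workers, ["d"])
print([w.queue.qsize() for w in workers])  # [2, 2]

Note that qsize() is only approximate on a real multiprocessing.Queue, so shortest-queue balancing was a heuristic rather than an exact policy.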
python/sglang/srt/managers/controller_single.py (deleted)
@@ -1,164 +0,0 @@
-"""
-Copyright 2023-2024 SGLang Team
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-"""
-
-"""A controller that manages a group of tensor parallel workers."""
-
-import logging
-import multiprocessing
-from typing import List
-
-import zmq
-
-from sglang.srt.managers.tp_worker import (
-    ModelTpServer,
-    broadcast_recv_input,
-    launch_tp_servers,
-)
-from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import configure_logger, kill_parent_process
-from sglang.utils import get_exception_traceback
-
-logger = logging.getLogger(__name__)
-
-
-class ControllerSingle:
-    """A controller that manages a group of tensor parallel workers."""
-
-    def __init__(
-        self,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-        gpu_ids: List[int],
-        is_data_parallel_worker: bool,
-        dp_worker_id: int,
-        mp_queue: multiprocessing.Queue,
-    ):
-        # Parse args
-        self.tp_size = server_args.tp_size
-        self.is_dp_worker = is_data_parallel_worker
-        self.dp_worker_id = dp_worker_id
-        self.mp_queue = mp_queue
-
-        # Init inter-process communication
-        context = zmq.Context(2)
-
-        if not self.is_dp_worker:
-            self.recv_from_tokenizer = context.socket(zmq.PULL)
-            self.recv_from_tokenizer.bind(
-                f"tcp://127.0.0.1:{port_args.controller_port}"
-            )
-
-        self.send_to_detokenizer = context.socket(zmq.PUSH)
-        self.send_to_detokenizer.connect(
-            f"tcp://127.0.0.1:{port_args.detokenizer_port}"
-        )
-
-        # Launch other tp ranks
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        self.tp_procs = []
-        if tp_size_local > 1:
-            tp_rank_range = range(1, tp_size_local)
-            self.tp_procs = launch_tp_servers(
-                gpu_ids,
-                tp_rank_range,
-                server_args,
-                port_args.nccl_ports[dp_worker_id],
-            )
-
-        # Launch tp rank 0
-        self.tp_server = ModelTpServer(
-            gpu_ids[0],
-            0,
-            server_args,
-            port_args.nccl_ports[dp_worker_id],
-        )
-        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
-
-    def loop_for_forward(self):
-        while True:
-            if not self.is_dp_worker:
-                recv_reqs = self.recv_requests_from_zmq()
-            else:
-                recv_reqs = self.recv_requests_from_mp_queue()
-
-            if self.tp_size > 1:
-                broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
-
-            out_pyobjs = self.tp_server.exposed_step(recv_reqs)
-
-            for obj in out_pyobjs:
-                self.send_to_detokenizer.send_pyobj(obj)
-
-    def recv_requests_from_zmq(self):
-        recv_reqs = []
-        while True:
-            try:
-                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
-            except zmq.ZMQError:
-                break
-            recv_reqs.append(recv_req)
-
-        return recv_reqs
-
-    def recv_requests_from_mp_queue(self):
-        recv_reqs = []
-        while not self.mp_queue.empty():
-            recv_reqs.append(self.mp_queue.get())
-        return recv_reqs
-
-
-def start_controller_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer: multiprocessing.connection.Connection,
-    is_data_parallel_worker: bool = False,
-    gpu_ids: List[int] = None,
-    dp_worker_id: int = None,
-    queue: multiprocessing.connection.Connection = None,
-):
-    """Start a controller process."""
-    if is_data_parallel_worker:
-        logger_prefix = f" DP{dp_worker_id} TP0"
-    else:
-        logger_prefix = " TP0"
-    configure_logger(server_args, prefix=logger_prefix)
-
-    if not is_data_parallel_worker:
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-        dp_worker_id = 0
-        queue = None
-
-    try:
-        controller = ControllerSingle(
-            server_args,
-            port_args,
-            gpu_ids,
-            is_data_parallel_worker,
-            dp_worker_id,
-            queue,
-        )
-    except Exception:
-        pipe_writer.send(get_exception_traceback())
-        raise
-
-    pipe_writer.send("init ok")
-
-    try:
-        controller.loop_for_forward()
-    except Exception:
-        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
-    finally:
-        kill_parent_process()
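Both deleted controllers used the same init handshake, which the new code keeps (the scheduler now sends "ready" instead of "init ok"): the child does its heavy setup, reports the outcome over a one-way pipe, and the parent blocks until it hears back. A minimal standalone sketch:

import multiprocessing

def child_main(pipe_writer):
    try:
        _ = sum(range(1000))  # stand-in for model/controller initialization
    except Exception as e:
        pipe_writer.send(repr(e))
        raise
    pipe_writer.send("init ok")

if __name__ == "__main__":
    reader, writer = multiprocessing.Pipe(duplex=False)
    proc = multiprocessing.Process(target=child_main, args=(writer,))
    proc.start()
    state = reader.recv()  # blocks until the child reports
    if state != "init ok":
        raise RuntimeError(f"Initialization failed: {state}")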
python/sglang/srt/managers/detokenizer_manager.py
@@ -16,6 +16,8 @@ limitations under the License.
 """DetokenizerManager is a process that detokenizes the token ids."""

 import dataclasses
+import logging
+from collections import OrderedDict
 from typing import List

 import zmq
@@ -29,8 +31,11 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import configure_logger, kill_parent_process
 from sglang.utils import find_printable_text, get_exception_traceback

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class DecodeStatus:
@@ -53,8 +58,8 @@ class DetokenizerManager:
     ):
         # Init inter-process communication
         context = zmq.Context(2)
-        self.recv_from_router = context.socket(zmq.PULL)
-        self.recv_from_router.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")
+        self.recv_from_scheduler = context.socket(zmq.PULL)
+        self.recv_from_scheduler.bind(f"tcp://127.0.0.1:{port_args.detokenizer_port}")

         self.send_to_tokenizer = context.socket(zmq.PUSH)
         self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
@@ -68,13 +73,13 @@ class DetokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
         )

-        self.decode_status = {}
+        self.decode_status = LimitedCapacityDict()

-    def handle_loop(self):
+    def event_loop(self):
         """The event loop that handles requests"""

         while True:
-            recv_obj = self.recv_from_router.recv_pyobj()
+            recv_obj = self.recv_from_scheduler.recv_pyobj()

             if isinstance(recv_obj, BatchEmbeddingOut):
                 # If it is embedding model, no detokenization is needed.
@@ -165,15 +170,29 @@
             )


-def start_detokenizer_process(
+class LimitedCapacityDict(OrderedDict):
+    def __init__(self, capacity=1 << 15, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.capacity = capacity
+
+    def __setitem__(self, key, value):
+        if len(self) >= self.capacity:
+            # Remove the oldest element (first item in the dict)
+            self.popitem(last=False)
+        # Set the new item
+        super().__setitem__(key, value)
+
+
+def run_detokenizer_process(
     server_args: ServerArgs,
     port_args: PortArgs,
     pipe_writer,
 ):
+    configure_logger(server_args)
+
     try:
         manager = DetokenizerManager(server_args, port_args)
+        manager.event_loop()
     except Exception:
-        pipe_writer.send(get_exception_traceback())
-        raise
-    pipe_writer.send("init ok")
-    manager.handle_loop()
+        msg = get_exception_traceback()
+        logger.error(msg)
+        kill_parent_process()
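The new LimitedCapacityDict is an OrderedDict with FIFO eviction, which bounds the memory held by per-request decode state where a plain dict grew without limit. A quick usage sketch with a tiny capacity:

from collections import OrderedDict

class LimitedCapacityDict(OrderedDict):
    def __init__(self, capacity=1 << 15, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.capacity = capacity

    def __setitem__(self, key, value):
        if len(self) >= self.capacity:
            self.popitem(last=False)  # evict the oldest entry
        super().__setitem__(key, value)

d = LimitedCapacityDict(capacity=2)
d["a"], d["b"], d["c"] = 1, 2, 3  # inserting "c" evicts "a"
print(list(d.items()))  # [('b', 2), ('c', 3)]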
python/sglang/srt/managers/scheduler.py (new file, 111 lines)
@@ -0,0 +1,111 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""A scheduler that manages a tensor parallel GPU worker."""
+
+import logging
+import multiprocessing
+
+import zmq
+
+from sglang.srt.managers.tp_worker import ModelTpServer
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import broadcast_pyobj, configure_logger, kill_parent_process
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
+
+
+class Scheduler:
+    """A scheduler that manages a tensor parallel GPU worker."""
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+        gpu_id: int,
+        tp_rank: int,
+    ):
+        # Parse args
+        self.tp_rank = tp_rank
+        self.tp_size = server_args.tp_size
+
+        # Init inter-process communication
+        context = zmq.Context(2)
+
+        if self.tp_rank == 0:
+            self.recv_from_tokenizer = context.socket(zmq.PULL)
+            self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.scheduler_port}")
+
+            self.send_to_detokenizer = context.socket(zmq.PUSH)
+            self.send_to_detokenizer.connect(
+                f"tcp://127.0.0.1:{port_args.detokenizer_port}"
+            )
+        else:
+            self.send_to_detokenizer = None
+
+        # Launch a tp server
+        self.tp_server = ModelTpServer(
+            gpu_id=gpu_id,
+            tp_rank=tp_rank,
+            server_args=server_args,
+            nccl_port=port_args.nccl_ports[0],
+        )
+        self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
+
+    def event_loop(self):
+        while True:
+            if self.tp_rank == 0:
+                recv_reqs = self.recv_requests_from_zmq()
+            else:
+                recv_reqs = None
+
+            recv_reqs = broadcast_pyobj(recv_reqs, self.tp_rank, self.tp_cpu_group)
+            out_pyobjs = self.tp_server.exposed_step(recv_reqs)
+
+            if self.tp_rank == 0:
+                for obj in out_pyobjs:
+                    self.send_to_detokenizer.send_pyobj(obj)
+
+    def recv_requests_from_zmq(self):
+        recv_reqs = []
+
+        while True:
+            try:
+                recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+            except zmq.ZMQError:
+                break
+            recv_reqs.append(recv_req)
+
+        return recv_reqs
+
+
+def run_scheduler_process(
+    server_args: ServerArgs,
+    port_args: PortArgs,
+    gpu_id: int,
+    tp_rank: int,
+    pipe_writer: multiprocessing.connection.Connection,
+):
+    configure_logger(server_args, prefix=f" TP{tp_rank}")
+
+    try:
+        scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank)
+        pipe_writer.send("ready")
+        scheduler.event_loop()
+    except Exception:
+        msg = get_exception_traceback()
+        logger.error(msg)
+        kill_parent_process()
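The tokenizer-to-scheduler-to-detokenizer hops are plain ZMQ PUSH/PULL sockets. A self-contained round-trip sketch of the pattern used by Scheduler.recv_requests_from_zmq (the port and payload are illustrative, not SGLang's):

import time
import zmq

context = zmq.Context(2)
pull = context.socket(zmq.PULL)
pull.bind("tcp://127.0.0.1:5555")
push = context.socket(zmq.PUSH)
push.connect("tcp://127.0.0.1:5555")

push.send_pyobj({"rid": "req-0", "input_ids": [1, 2, 3]})
time.sleep(0.1)  # give the message time to arrive in this toy example

# Non-blocking drain: collect everything queued, stop once the socket is empty.
recv_objs = []
while True:
    try:
        recv_objs.append(pull.recv_pyobj(zmq.NOBLOCK))
    except zmq.ZMQError:
        break
print(recv_objs)  # [{'rid': 'req-0', 'input_ids': [1, 2, 3]}]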
python/sglang/srt/managers/tokenizer_manager.py
@@ -88,8 +88,8 @@ class TokenizerManager:
         self.recv_from_detokenizer = context.socket(zmq.PULL)
         self.recv_from_detokenizer.bind(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

-        self.send_to_controller = context.socket(zmq.PUSH)
-        self.send_to_controller.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
+        self.send_to_scheduler = context.socket(zmq.PUSH)
+        self.send_to_scheduler.connect(f"tcp://127.0.0.1:{port_args.scheduler_port}")

         # Read model args
         self.model_path = server_args.model_path
@@ -285,7 +285,7 @@ class TokenizerManager:
             input_ids,
             sampling_params,
         )
-        self.send_to_controller.send_pyobj(tokenized_obj)
+        self.send_to_scheduler.send_pyobj(tokenized_obj)

         # Recv results
         event = asyncio.Event()
@@ -397,7 +397,7 @@ class TokenizerManager:
             input_ids,
             sampling_params,
         )
-        self.send_to_controller.send_pyobj(tokenized_obj)
+        self.send_to_scheduler.send_pyobj(tokenized_obj)

         event = asyncio.Event()
         state = ReqState([], False, event)
@@ -530,14 +530,14 @@ class TokenizerManager:

     def flush_cache(self):
         req = FlushCacheReq()
-        self.send_to_controller.send_pyobj(req)
+        self.send_to_scheduler.send_pyobj(req)

     def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
         del self.rid_to_state[rid]
         req = AbortReq(rid)
-        self.send_to_controller.send_pyobj(req)
+        self.send_to_scheduler.send_pyobj(req)

     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
@@ -554,7 +554,7 @@ class TokenizerManager:
         # wait for the previous generation requests to finish
         while len(self.rid_to_state) > 0:
             await asyncio.sleep(0)
-        self.send_to_controller.send_pyobj(obj)
+        self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         result = await self.model_update_result
         if result.success:
@@ -665,6 +665,7 @@ class TokenizerManager:
     def detokenize_logprob_tokens(
         self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
     ):
+        # TODO(lianmin): This should run on DetokenizerManager
        if not decode_to_text:
            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
python/sglang/srt/managers/tp_worker.py
@@ -17,16 +17,12 @@ limitations under the License.

 import json
 import logging
-import multiprocessing
 import os
-import pickle
 import time
 import warnings
-from typing import Any, List, Optional, Union
+from typing import List, Optional, Union

 import torch
 import torch.distributed
-import torch.distributed as dist

 from sglang.global_config import global_config
 from sglang.srt.configs.model_config import ModelConfig
@@ -58,7 +54,7 @@ from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
-    configure_logger,
+    broadcast_pyobj,
     is_multimodal_model,
     set_random_seed,
     suppress_other_loggers,
@@ -140,7 +136,7 @@ class ModelTpServer:
         )

         # Sync random seed across TP workers
-        server_args.random_seed = broadcast_recv_input(
+        server_args.random_seed = broadcast_pyobj(
             [server_args.random_seed],
             self.tp_rank,
             self.model_runner.tp_group.cpu_group,
@@ -935,82 +931,3 @@ class ModelTpServer:
         else:
             logger.error(message)
         return success, message
-
-
-def run_tp_server(
-    gpu_id: int,
-    tp_rank: int,
-    server_args: ServerArgs,
-    nccl_port: int,
-):
-    """Run a tensor parallel model server."""
-    configure_logger(server_args, prefix=f" TP{tp_rank}")
-
-    try:
-        model_server = ModelTpServer(
-            gpu_id,
-            tp_rank,
-            server_args,
-            nccl_port,
-        )
-        tp_cpu_group = model_server.model_runner.tp_group.cpu_group
-
-        while True:
-            recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
-            model_server.exposed_step(recv_reqs)
-    except Exception:
-        logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
-        raise
-
-
-def launch_tp_servers(
-    gpu_ids: List[int],
-    tp_rank_range: List[int],
-    server_args: ServerArgs,
-    nccl_port: int,
-):
-    """Launch multiple tensor parallel servers."""
-    procs = []
-    for i in tp_rank_range:
-        proc = multiprocessing.Process(
-            target=run_tp_server,
-            args=(gpu_ids[i], i, server_args, nccl_port),
-        )
-        proc.start()
-        procs.append(proc)
-
-    return procs
-
-
-def broadcast_recv_input(
-    data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
-):
-    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
-
-    if rank == 0:
-        if len(data) == 0:
-            tensor_size = torch.tensor([0], dtype=torch.long)
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-        else:
-            serialized_data = pickle.dumps(data)
-            size = len(serialized_data)
-            tensor_data = torch.ByteTensor(list(serialized_data))
-            tensor_size = torch.tensor([size], dtype=torch.long)
-
-            dist.broadcast(tensor_size, src=0, group=dist_group)
-            dist.broadcast(tensor_data, src=0, group=dist_group)
-        return data
-    else:
-        tensor_size = torch.tensor([0], dtype=torch.long)
-        dist.broadcast(tensor_size, src=0, group=dist_group)
-        size = tensor_size.item()
-
-        if size == 0:
-            return []
-
-        tensor_data = torch.empty(size, dtype=torch.uint8)
-        dist.broadcast(tensor_data, src=0, group=dist_group)
-
-        serialized_data = bytes(tensor_data.tolist())
-        data = pickle.loads(serialized_data)
-        return data
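The size-then-bytes broadcast scheme deleted here survives as broadcast_pyobj in sglang/srt/utils.py (see the last hunk below). A runnable two-process sketch of the same pattern over the gloo backend (the address and payload are illustrative):

import pickle

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def worker(rank):
    dist.init_process_group(
        "gloo", init_method="tcp://127.0.0.1:29501", rank=rank, world_size=2
    )
    payload = [{"rid": "req-0"}] if rank == 0 else None
    if rank == 0:
        blob = pickle.dumps(payload)
        size = torch.tensor([len(blob)], dtype=torch.long)
        data = torch.ByteTensor(list(blob))
    else:
        size = torch.tensor([0], dtype=torch.long)
    dist.broadcast(size, src=0)  # broadcast the length first...
    if rank != 0:
        data = torch.empty(size.item(), dtype=torch.uint8)
    dist.broadcast(data, src=0)  # ...then the pickled bytes
    print(rank, pickle.loads(bytes(data.tolist())))
    dist.destroy_process_group()

if __name__ == "__main__":
    mp.spawn(worker, nprocs=2)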
python/sglang/srt/model_executor/model_runner.py
@@ -135,8 +135,8 @@ class ModelRunner:
         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)

-        if self.server_args.nccl_init_addr:
-            nccl_init_method = f"tcp://{self.server_args.nccl_init_addr}"
+        if self.server_args.dist_init_addr:
+            nccl_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
             nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
python/sglang/srt/server.py
@@ -43,20 +43,14 @@ from fastapi.responses import JSONResponse, Response, StreamingResponse
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.constrained import disable_cache
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.controller_multi import (
-    start_controller_process as start_controller_process_multi,
-)
-from sglang.srt.managers.controller_single import launch_tp_servers
-from sglang.srt.managers.controller_single import (
-    start_controller_process as start_controller_process_single,
-)
-from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
+from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     GenerateReqInput,
     RewardReqInput,
     UpdateWeightReqInput,
 )
+from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -82,8 +76,7 @@ from sglang.srt.utils import (
     is_hip,
     kill_child_process,
     maybe_set_triton_cache_manager,
-    prepare_model,
-    prepare_tokenizer,
+    prepare_model_and_tokenizer,
     set_ulimit,
 )
 from sglang.utils import get_exception_traceback
@@ -303,8 +296,8 @@ def launch_server(
     """Launch an HTTP server."""
     global tokenizer_manager

     # Configure global environment
     configure_logger(server_args)
     server_args.check_server_args()
     _set_envs_and_config(server_args)
@@ -317,81 +310,60 @@ def launch_server(
     ports = server_args.additional_ports
     port_args = PortArgs(
         tokenizer_port=ports[0],
-        controller_port=ports[1],
+        scheduler_port=ports[1],
         detokenizer_port=ports[2],
         nccl_ports=ports[3:],
     )
     logger.info(f"{server_args=}")

-    # Use model from www.modelscope.cn, first download the model.
-    server_args.model_path = prepare_model(server_args.model_path)
-    server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)
-
-    # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1 and server_args.node_rank != 0:
-        tp_size_local = server_args.tp_size // server_args.nnodes
-        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
-        tp_rank_range = list(
-            range(
-                server_args.node_rank * tp_size_local,
-                (server_args.node_rank + 1) * tp_size_local,
-            )
-        )
-        procs = launch_tp_servers(
-            gpu_ids,
-            tp_rank_range,
-            server_args,
-            ports[3],
-        )
-
-        try:
-            for p in procs:
-                p.join()
-        finally:
-            kill_child_process(os.getpid(), including_parent=False)
-        return
-
-    # Launch processes
-    pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
-
-    if server_args.dp_size == 1:
-        start_controller_process = start_controller_process_single
-    else:
-        start_controller_process = start_controller_process_multi
-    proc_controller = mp.Process(
-        target=start_controller_process,
-        args=(server_args, port_args, pipe_controller_writer),
-    )
-    proc_controller.start()
+    # If using model from www.modelscope.cn, first download the model.
+    server_args.model_path, server_args.tokenizer_path = prepare_model_and_tokenizer(
+        server_args.model_path, server_args.tokenizer_path
+    )
+
+    # Launch tensor parallel scheduler processes
+    scheduler_procs = []
+    scheduler_pipe_readers = []
+    tp_size_per_node = server_args.tp_size // server_args.nnodes
+    tp_rank_range = range(
+        tp_size_per_node * server_args.node_rank,
+        tp_size_per_node * (server_args.node_rank + 1),
+    )
+    for tp_rank in tp_rank_range:
+        reader, writer = mp.Pipe(duplex=False)
+        gpu_id = tp_rank % tp_size_per_node
+        proc = mp.Process(
+            target=run_scheduler_process,
+            args=(server_args, port_args, gpu_id, tp_rank, writer),
+        )
+        proc.start()
+        scheduler_procs.append(proc)
+        scheduler_pipe_readers.append(reader)
+
+    if server_args.node_rank >= 1:
+        # For other nodes, they do not need to run tokenizer or detokenizer,
+        # so they can just wait here.
+        while True:
+            pass

     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)
-    proc_detoken = mp.Process(
-        target=start_detokenizer_process,
+    # Launch detokenizer process
+    detoken_proc = mp.Process(
+        target=run_detokenizer_process,
         args=(
             server_args,
             port_args,
             pipe_detoken_writer,
         ),
     )
-    proc_detoken.start()
+    detoken_proc.start()

     # Launch tokenizer process
     tokenizer_manager = TokenizerManager(server_args, port_args)
     if server_args.chat_template:
         load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)

-    # Wait for the model to finish loading
-    controller_init_state = pipe_controller_reader.recv()
-    detoken_init_state = pipe_detoken_reader.recv()
-
-    if controller_init_state != "init ok" or detoken_init_state != "init ok":
-        proc_controller.kill()
-        proc_detoken.kill()
-        raise RuntimeError(
-            "Initialization failed. "
-            f"controller_init_state: {controller_init_state}, "
-            f"detoken_init_state: {detoken_init_state}"
-        )
-    assert proc_controller.is_alive() and proc_detoken.is_alive()
+    # Wait for model to finish loading
+    for i in range(len(scheduler_pipe_readers)):
+        scheduler_pipe_readers[i].recv()

     # Add api key authorization
     if server_args.api_key:
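As a quick sanity check on the new mapping (tp_rank spans all nodes, gpu_id is local to each node), the arithmetic from the hunk above for tp_size=8 over two nodes:

tp_size, nnodes = 8, 2
tp_size_per_node = tp_size // nnodes  # 4
for node_rank in range(nnodes):
    tp_rank_range = range(
        tp_size_per_node * node_rank, tp_size_per_node * (node_rank + 1)
    )
    print(node_rank, [(tp_rank, tp_rank % tp_size_per_node) for tp_rank in tp_rank_range])
# 0 [(0, 0), (1, 1), (2, 2), (3, 3)]
# 1 [(4, 0), (5, 1), (6, 2), (7, 3)]

So node 0 runs TP ranks 0-3 on its local GPUs 0-3, and node 1 runs TP ranks 4-7 on its local GPUs 0-3.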
@@ -404,7 +376,7 @@ def launch_server(
     t.start()

     try:
-        # Listen for requests
+        # Listen for HTTP requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -451,9 +423,7 @@ def _set_envs_and_config(server_args: ServerArgs):
         "at https://docs.flashinfer.ai/installation.html.",
     )

-    if is_hip():
-        # to figure out a better method of not using fork later
-        mp.set_start_method("spawn", force=True)
+    mp.set_start_method("spawn", force=True)


 def _wait_and_warmup(server_args, pipe_finish_writer, pid):
@@ -517,7 +487,7 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):

     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
-        pipe_finish_writer.send("init ok")
+        pipe_finish_writer.send("ready")


 class Runtime:
@@ -564,7 +534,7 @@ class Runtime:
         except EOFError:
             init_state = ""

-        if init_state != "init ok":
+        if init_state != "ready":
             self.shutdown()
             raise RuntimeError(
                 "Initialization failed. Please see the error messages above."
python/sglang/srt/server_args.py
@@ -78,9 +78,9 @@ class ServerArgs:
     load_balance_method: str = "round_robin"

     # Distributed args
-    nccl_init_addr: Optional[str] = None
+    dist_init_addr: Optional[str] = None
     nnodes: int = 1
-    node_rank: Optional[int] = None
+    node_rank: int = 0

     # Model override args in JSON
     json_model_override_args: str = "{}"
@@ -426,14 +426,17 @@ class ServerArgs:

         # Multi-node distributed serving args
         parser.add_argument(
-            "--nccl-init-addr",
+            "--dist-init-addr",
+            "--nccl-init-addr",  # For backward compatibility. This will be removed in the future.
             type=str,
-            help="The nccl init address of multi-node server.",
+            help="The host address for initializing distributed backend (e.g., `192.168.0.2:25000`).",
         )
         parser.add_argument(
             "--nnodes", type=int, default=ServerArgs.nnodes, help="The number of nodes."
         )
-        parser.add_argument("--node-rank", type=int, help="The node rank.")
+        parser.add_argument(
+            "--node-rank", type=int, default=ServerArgs.node_rank, help="The node rank."
+        )

         # Model override args
         parser.add_argument(
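With the renamed flag, a two-node launch would look like the sketch below. Only --dist-init-addr, --nnodes, and --node-rank come from this diff; the model path, address, and other flags are illustrative, and --nccl-init-addr still works as an alias for now:

# node 0
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct \
    --tp 8 --nnodes 2 --node-rank 0 --dist-init-addr 192.168.0.2:25000
# node 1
python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct \
    --tp 8 --nnodes 2 --node-rank 1 --dist-init-addr 192.168.0.2:25000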
@@ -583,6 +586,11 @@ class ServerArgs:
             and (self.lora_paths is None or self.disable_radix_cache)
         ), "compatibility of lora and cuda graph and radix attention is in progress"

+        assert self.dp_size == 1, (
+            "The support for data parallelism is temporarily disabled during refactor. "
+            "Please use sglang<=0.3.2 or wait for later updates."
+        )
+

 def prepare_server_args(argv: List[str]) -> ServerArgs:
     """
@@ -604,9 +612,13 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:

 @dataclasses.dataclass
 class PortArgs:
+    # The port for tokenizer to receive inputs from detokenizer (zmq)
     tokenizer_port: int
-    controller_port: int
+    # The port for scheduler to receive inputs from tokenizer (zmq)
+    scheduler_port: int
+    # The port for detokenizer to receive inputs from scheduler (zmq)
     detokenizer_port: int
+    # The port for nccl initialization for multiple TP groups (torch.dist)
     nccl_ports: List[int]
python/sglang/srt/utils.py
@@ -16,13 +16,12 @@ limitations under the License.
 """Common utilities."""

 import base64
-import fcntl
 import logging
 import os
+import pickle
 import random
 import resource
 import socket
-import struct
 import time
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
@@ -36,7 +35,6 @@ import torch.distributed as dist
 from fastapi.responses import JSONResponse
 from packaging import version as pkg_version
-from torch import nn
 from torch.nn.parameter import Parameter
 from triton.runtime.cache import (
     FileCacheManager,
     default_cache_dir,
@@ -539,89 +537,6 @@ class CustomCacheManager(FileCacheManager):
             raise RuntimeError("Could not create or locate cache dir")


-def get_ip_address(ifname):
-    """
-    Get the IP address of a network interface.
-
-    :param ifname: Name of the network interface (e.g., 'eth0')
-    :return: IP address of the network interface
-    """
-    s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-    ip_address = fcntl.ioctl(
-        s.fileno(),
-        0x8915,  # SIOCGIFADDR
-        struct.pack("256s", bytes(ifname[:15], "utf-8")),
-    )[20:24]
-    return socket.inet_ntoa(ip_address)
-
-
-def send_addrs_to_rank_0(model_port_args, server_args):
-    assert server_args.node_rank != 0 and server_args.dp_size == 1
-
-    ifname = os.environ.get(
-        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
-    )
-    ip_addr = get_ip_address(ifname)
-
-    num_tp_ports = server_args.tp_size // server_args.nnodes
-    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
-    ip_addr = [int(x) for x in ip_addr.split(".")]
-    addrs_tensor = torch.tensor(
-        ip_addr + model_port_args.model_tp_ports, dtype=torch.int
-    )
-
-    init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(
-        backend="gloo",
-        init_method=init_method,
-        rank=server_args.node_rank,
-        world_size=server_args.nnodes,
-    )
-    dist.send(addrs_tensor, dst=0)
-    print(
-        f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
-    )
-
-    dist.barrier()
-    dist.destroy_process_group()
-
-
-def receive_addrs(model_port_args, server_args):
-    assert server_args.node_rank == 0 and server_args.dp_size == 1
-
-    ifname = os.environ.get(
-        "SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
-    )
-    ip_addr = get_ip_address(ifname)
-
-    num_tp_ports = server_args.tp_size // server_args.nnodes
-    model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
-
-    init_method = f"tcp://{server_args.nccl_init_addr}"
-    dist.init_process_group(
-        backend="gloo",
-        init_method=init_method,
-        rank=server_args.node_rank,
-        world_size=server_args.nnodes,
-    )
-
-    for src_rank in range(1, server_args.nnodes):
-        tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
-        dist.recv(tensor, src=src_rank)
-        ip = ".".join([str(x) for x in tensor[:4].tolist()])
-        ports = tensor[4:].tolist()
-        model_port_args.model_tp_ips[
-            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
-        ] = [ip] * num_tp_ports
-        model_port_args.model_tp_ports[
-            num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
-        ] = ports
-        print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
-
-    dist.barrier()
-    dist.destroy_process_group()


 def set_ulimit(target_soft_limit=65535):
     resource_type = resource.RLIMIT_NOFILE
     current_soft, current_hard = resource.getrlimit(resource_type)
@@ -645,24 +560,16 @@ def add_api_key_middleware(app, api_key: str):
         return await call_next(request)


-def prepare_model(model_path: str):
+def prepare_model_and_tokenizer(model_path: str, tokenizer_path: str):
     if "SGLANG_USE_MODELSCOPE" in os.environ:
         if not os.path.exists(model_path):
             from modelscope import snapshot_download

-            return snapshot_download(model_path)
-    return model_path
-
-
-def prepare_tokenizer(tokenizer_path: str):
-    if "SGLANG_USE_MODELSCOPE" in os.environ:
-        if not os.path.exists(tokenizer_path):
-            from modelscope import snapshot_download
-
-            return snapshot_download(
+            model_path = snapshot_download(model_path)
+            tokenizer_path = snapshot_download(
                 tokenizer_path, ignore_patterns=["*.bin", "*.safetensors"]
             )
-    return tokenizer_path
+    return model_path, tokenizer_path


 def configure_logger(server_args, prefix: str = ""):
@@ -704,3 +611,37 @@ def set_weight_attrs(
     for key, value in weight_attrs.items():
         assert not hasattr(weight, key), f"Overwriting existing tensor attribute: {key}"
         setattr(weight, key, value)
+
+
+def broadcast_pyobj(
+    data: List[Any], rank: int, dist_group: torch.distributed.ProcessGroup
+):
+    """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+    if rank == 0:
+        if len(data) == 0:
+            tensor_size = torch.tensor([0], dtype=torch.long)
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+        else:
+            serialized_data = pickle.dumps(data)
+            size = len(serialized_data)
+            tensor_data = torch.ByteTensor(list(serialized_data))
+            tensor_size = torch.tensor([size], dtype=torch.long)
+
+            dist.broadcast(tensor_size, src=0, group=dist_group)
+            dist.broadcast(tensor_data, src=0, group=dist_group)
+        return data
+    else:
+        tensor_size = torch.tensor([0], dtype=torch.long)
+        dist.broadcast(tensor_size, src=0, group=dist_group)
+        size = tensor_size.item()
+
+        if size == 0:
+            return []
+
+        tensor_data = torch.empty(size, dtype=torch.uint8)
+        dist.broadcast(tensor_data, src=0, group=dist_group)
+
+        serialized_data = bytes(tensor_data.tolist())
+        data = pickle.loads(serialized_data)
+        return data