"""
Copyright 2023-2024 SGLang Team
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

"""A controller that dispatches requests to multiple data parallel workers."""
|
|
|
|
|
|
|
|
|
|
import logging
|
|
|
|
|
import multiprocessing as mp
|
|
|
|
|
from enum import Enum, auto
|
|
|
|
|
|
|
|
|
|
import zmq
|
|
|
|
|
|
|
|
|
|
from sglang.srt.managers.io_struct import (
|
|
|
|
|
TokenizedEmbeddingReqInput,
|
|
|
|
|
TokenizedGenerateReqInput,
|
|
|
|
|
)
|
|
|
|
|
from sglang.srt.managers.scheduler import run_scheduler_process
|
|
|
|
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
|
|
|
|
from sglang.srt.utils import (
|
|
|
|
|
configure_logger,
|
2024-10-25 23:07:07 -07:00
|
|
|
get_zmq_socket,
|
2024-10-11 07:22:48 -07:00
|
|
|
kill_parent_process,
|
|
|
|
|
suppress_other_loggers,
|
|
|
|
|
)
|
|
|
|
|
from sglang.utils import get_exception_traceback
|
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LoadBalanceMethod(Enum):
|
|
|
|
|
"""Load balance method."""
|
|
|
|
|
|
|
|
|
|
ROUND_ROBIN = auto()
|
|
|
|
|
SHORTEST_QUEUE = auto()
|
|
|
|
|
|
|
|
|
|
@classmethod
|
|
|
|
|
def from_str(cls, method: str):
|
|
|
|
|
method = method.upper()
|
|
|
|
|
try:
|
|
|
|
|
return cls[method]
|
|
|
|
|
except KeyError as exc:
|
|
|
|
|
raise ValueError(f"Invalid load balance method: {method}") from exc
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DataParallelController:
|
|
|
|
|
"""A controller that dispatches requests to multiple data parallel workers."""
|
|
|
|
|
|
|
|
|
|
def __init__(self, server_args, port_args) -> None:
|
|
|
|
|
# Parse args
|
|
|
|
|
self.server_args = server_args
|
|
|
|
|
self.port_args = port_args
|
|
|
|
|
self.load_balance_method = LoadBalanceMethod.from_str(
|
|
|
|
|
server_args.load_balance_method
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# Init inter-process communication
|
|
|
|
|
self.context = zmq.Context(1 + server_args.dp_size)
|
2024-10-25 23:07:07 -07:00
|
|
|
self.recv_from_tokenizer = get_zmq_socket(
|
|
|
|
|
self.context, zmq.PULL, port_args.scheduler_input_ipc_name
|
|
|
|
|
)
|
2024-10-11 07:22:48 -07:00
|
|
|
|
|
|
|
|
# Dispatch method
|
|
|
|
|
self.round_robin_counter = 0
|
|
|
|
|
dispatch_lookup = {
|
|
|
|
|
LoadBalanceMethod.ROUND_ROBIN: self.round_robin_scheduler,
|
|
|
|
|
LoadBalanceMethod.SHORTEST_QUEUE: self.shortest_queue_scheduler,
|
|
|
|
|
}
|
|
|
|
|
self.dispatching = dispatch_lookup[self.load_balance_method]
|
|
|
|
|
|
|
|
|
|
# Start data parallel workers
|
|
|
|
|
base_gpu_id = 0
|
|
|
|
|
self.workers = []
|
2024-11-16 17:01:43 +08:00
|
|
|
scheduler_pipe_readers = []
|
2024-10-11 07:22:48 -07:00
|
|
|
for dp_rank in range(server_args.dp_size):
|
|
|
|
|
tmp_port_args = PortArgs.init_new(server_args)
|
2024-11-16 00:30:39 -08:00
|
|
|
tmp_port_args.tokenizer_ipc_name = port_args.tokenizer_ipc_name
|
2024-10-11 07:22:48 -07:00
|
|
|
tmp_port_args.detokenizer_ipc_name = port_args.detokenizer_ipc_name
|
|
|
|
|
|
2024-11-16 17:01:43 +08:00
|
|
|
if server_args.enable_dp_attention:
|
|
|
|
|
# Share workers for DP and TP
|
|
|
|
|
send_to, reader = self.launch_tensor_parallel_process(
|
|
|
|
|
server_args,
|
|
|
|
|
tmp_port_args,
|
|
|
|
|
base_gpu_id,
|
|
|
|
|
dp_rank,
|
|
|
|
|
)
|
|
|
|
|
base_gpu_id += 1
|
|
|
|
|
scheduler_pipe_readers.append(reader)
|
|
|
|
|
else:
|
|
|
|
|
send_to = self.launch_tensor_parallel_group(
|
|
|
|
|
server_args,
|
|
|
|
|
tmp_port_args,
|
|
|
|
|
base_gpu_id,
|
|
|
|
|
dp_rank,
|
|
|
|
|
)
|
|
|
|
|
base_gpu_id += server_args.tp_size
|
2024-10-11 07:22:48 -07:00
|
|
|
self.workers.append(send_to)
|
2024-11-16 17:01:43 +08:00
|
|
|
|
|
|
|
|
for reader in scheduler_pipe_readers:
|
|
|
|
|
reader.recv()
|
2024-10-11 07:22:48 -07:00
|
|
|
|
|
|
|
|
def launch_tensor_parallel_group(
|
|
|
|
|
self,
|
|
|
|
|
server_args: ServerArgs,
|
|
|
|
|
port_args: PortArgs,
|
|
|
|
|
base_gpu_id: int,
|
|
|
|
|
dp_rank: int,
|
|
|
|
|
):
|
|
|
|
|
# Launch tensor parallel scheduler processes
|
|
|
|
|
scheduler_procs = []
|
|
|
|
|
scheduler_pipe_readers = []
|
|
|
|
|
tp_size_per_node = server_args.tp_size // server_args.nnodes
|
|
|
|
|
tp_rank_range = range(
|
|
|
|
|
tp_size_per_node * server_args.node_rank,
|
|
|
|
|
tp_size_per_node * (server_args.node_rank + 1),
|
|
|
|
|
)
|
|
|
|
|
for tp_rank in tp_rank_range:
|
|
|
|
|
reader, writer = mp.Pipe(duplex=False)
|
|
|
|
|
gpu_id = base_gpu_id + tp_rank % tp_size_per_node
|
|
|
|
|
proc = mp.Process(
|
|
|
|
|
target=run_scheduler_process,
|
|
|
|
|
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
|
|
|
|
|
)
|
|
|
|
|
proc.start()
|
|
|
|
|
scheduler_procs.append(proc)
|
|
|
|
|
scheduler_pipe_readers.append(reader)
|
|
|
|
|
|
2024-10-25 23:07:07 -07:00
|
|
|
send_to = get_zmq_socket(
|
|
|
|
|
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
|
|
|
|
)
|
2024-10-11 07:22:48 -07:00
|
|
|
|
|
|
|
|
# Wait for model to finish loading
|
|
|
|
|
for i in range(len(scheduler_pipe_readers)):
|
|
|
|
|
scheduler_pipe_readers[i].recv()
|
|
|
|
|
|
|
|
|
|
return send_to
|
|
|
|
|
|
2024-11-16 17:01:43 +08:00
|
|
|
def launch_tensor_parallel_process(
|
|
|
|
|
self,
|
|
|
|
|
server_args: ServerArgs,
|
|
|
|
|
port_args: PortArgs,
|
|
|
|
|
base_gpu_id: int,
|
|
|
|
|
dp_rank: int,
|
|
|
|
|
):
|
|
|
|
|
reader, writer = mp.Pipe(duplex=False)
|
|
|
|
|
gpu_id = base_gpu_id
|
|
|
|
|
tp_rank = dp_rank
|
|
|
|
|
proc = mp.Process(
|
|
|
|
|
target=run_scheduler_process,
|
|
|
|
|
args=(server_args, port_args, gpu_id, tp_rank, dp_rank, writer),
|
|
|
|
|
)
|
|
|
|
|
proc.start()
|
|
|
|
|
send_to = get_zmq_socket(
|
|
|
|
|
self.context, zmq.PUSH, port_args.scheduler_input_ipc_name
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
return send_to, reader
|
|
|
|
|
|
2024-10-11 07:22:48 -07:00
|
|
|
def round_robin_scheduler(self, req):
|
|
|
|
|
self.workers[self.round_robin_counter].send_pyobj(req)
|
|
|
|
|
self.round_robin_counter = (self.round_robin_counter + 1) % len(self.workers)
|
|
|
|
|
|
|
|
|
|
def shortest_queue_scheduler(self, input_requests):
|
|
|
|
|
raise NotImplementedError()
|
|
|
|
|
|
|
|
|
|
def event_loop(self):
|
|
|
|
|
while True:
|
|
|
|
|
while True:
|
|
|
|
|
try:
|
|
|
|
|
recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
|
|
|
|
|
except zmq.ZMQError:
|
|
|
|
|
break
|
|
|
|
|
|
|
|
|
|
if isinstance(
|
|
|
|
|
recv_req,
|
|
|
|
|
(
|
|
|
|
|
TokenizedGenerateReqInput,
|
|
|
|
|
TokenizedEmbeddingReqInput,
|
|
|
|
|
),
|
|
|
|
|
):
|
|
|
|
|
self.dispatching(recv_req)
|
|
|
|
|
else:
|
|
|
|
|
# Send other control messages to all workers
|
|
|
|
|
for worker in self.workers:
|
2024-10-23 10:46:29 -07:00
|
|
|
worker.send_pyobj(recv_req)
|
2024-10-11 07:22:48 -07:00
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_data_parallel_controller_process(
|
|
|
|
|
server_args: ServerArgs,
|
|
|
|
|
port_args: PortArgs,
|
|
|
|
|
pipe_writer,
|
|
|
|
|
):
|
|
|
|
|
configure_logger(server_args)
|
|
|
|
|
suppress_other_loggers()
|
|
|
|
|
|
|
|
|
|
try:
|
|
|
|
|
controller = DataParallelController(server_args, port_args)
|
|
|
|
|
pipe_writer.send("ready")
|
|
|
|
|
controller.event_loop()
|
|
|
|
|
except Exception:
|
|
|
|
|
msg = get_exception_traceback()
|
|
|
|
|
logger.error(msg)
|
|
|
|
|
kill_parent_process()
|