fix: resolve lint error (#650)
This commit is contained in:
@@ -29,6 +29,7 @@ logger = logging.getLogger("srt.controller")
|
||||
|
||||
class LoadBalanceMethod(Enum):
|
||||
"""Load balance method."""
|
||||
|
||||
ROUND_ROBIN = auto()
|
||||
SHORTEST_QUEUE = auto()
|
||||
|
||||
@@ -44,6 +45,7 @@ class LoadBalanceMethod(Enum):
|
||||
@dataclasses.dataclass
|
||||
class WorkerHandle:
|
||||
"""Store the handle of a data parallel worker."""
|
||||
|
||||
proc: multiprocessing.Process
|
||||
queue: multiprocessing.Queue
|
||||
|
||||
@@ -62,7 +64,8 @@ class ControllerMulti:
|
||||
self.port_args = port_args
|
||||
self.model_overide_args = model_overide_args
|
||||
self.load_balance_method = LoadBalanceMethod.from_str(
|
||||
server_args.load_balance_method)
|
||||
server_args.load_balance_method
|
||||
)
|
||||
|
||||
# Init communication
|
||||
context = zmq.Context()
|
||||
@@ -85,7 +88,9 @@ class ControllerMulti:
|
||||
def start_dp_worker(self, dp_worker_id: int):
|
||||
tp_size = self.server_args.tp_size
|
||||
|
||||
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(duplex=False)
|
||||
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
|
||||
duplex=False
|
||||
)
|
||||
|
||||
gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
|
||||
queue = multiprocessing.Queue()
|
||||
@@ -100,7 +105,7 @@ class ControllerMulti:
|
||||
gpu_ids,
|
||||
dp_worker_id,
|
||||
queue,
|
||||
)
|
||||
),
|
||||
)
|
||||
proc.start()
|
||||
|
||||
@@ -109,10 +114,12 @@ class ControllerMulti:
|
||||
raise RuntimeError(
|
||||
f"Initialization failed. controller_init_state: {controller_init_state}"
|
||||
)
|
||||
self.workers.append(WorkerHandle(
|
||||
proc=proc,
|
||||
queue=queue,
|
||||
))
|
||||
self.workers.append(
|
||||
WorkerHandle(
|
||||
proc=proc,
|
||||
queue=queue,
|
||||
)
|
||||
)
|
||||
|
||||
def round_robin_scheduler(self, input_requests):
|
||||
for r in input_requests:
|
||||
|
||||
@@ -8,7 +8,9 @@ from typing import List
|
||||
import zmq
|
||||
|
||||
from sglang.srt.managers.controller.tp_worker import (
|
||||
broadcast_recv_input, launch_tp_servers, ModelTpServer
|
||||
ModelTpServer,
|
||||
broadcast_recv_input,
|
||||
launch_tp_servers,
|
||||
)
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.utils import kill_parent_process
|
||||
@@ -41,7 +43,9 @@ class ControllerSingle:
|
||||
|
||||
if not self.is_dp_worker:
|
||||
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
||||
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
|
||||
self.recv_from_tokenizer.bind(
|
||||
f"tcp://127.0.0.1:{port_args.controller_port}"
|
||||
)
|
||||
|
||||
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
||||
self.send_to_detokenizer.connect(
|
||||
@@ -128,9 +132,15 @@ def start_controller_process(
|
||||
queue = None
|
||||
|
||||
try:
|
||||
controller = ControllerSingle(server_args, port_args, model_overide_args,
|
||||
gpu_ids, is_data_parallel_worker,
|
||||
dp_worker_id, queue)
|
||||
controller = ControllerSingle(
|
||||
server_args,
|
||||
port_args,
|
||||
model_overide_args,
|
||||
gpu_ids,
|
||||
is_data_parallel_worker,
|
||||
dp_worker_id,
|
||||
queue,
|
||||
)
|
||||
except Exception:
|
||||
pipe_writer.send(get_exception_traceback())
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user