fix: resolve lint error (#650)

This commit is contained in:
zhyncs
2024-07-18 20:33:21 +10:00
committed by GitHub
parent 5960a6e505
commit 9c5cac2450
3 changed files with 30 additions and 13 deletions

View File

@@ -29,6 +29,7 @@ logger = logging.getLogger("srt.controller")
class LoadBalanceMethod(Enum):
"""Load balance method."""
ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
@@ -44,6 +45,7 @@ class LoadBalanceMethod(Enum):
@dataclasses.dataclass
class WorkerHandle:
"""Store the handle of a data parallel worker."""
proc: multiprocessing.Process
queue: multiprocessing.Queue
@@ -62,7 +64,8 @@ class ControllerMulti:
self.port_args = port_args
self.model_overide_args = model_overide_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method)
server_args.load_balance_method
)
# Init communication
context = zmq.Context()
@@ -85,7 +88,9 @@ class ControllerMulti:
def start_dp_worker(self, dp_worker_id: int):
tp_size = self.server_args.tp_size
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(duplex=False)
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
duplex=False
)
gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
queue = multiprocessing.Queue()
@@ -100,7 +105,7 @@ class ControllerMulti:
gpu_ids,
dp_worker_id,
queue,
)
),
)
proc.start()
@@ -109,10 +114,12 @@ class ControllerMulti:
raise RuntimeError(
f"Initialization failed. controller_init_state: {controller_init_state}"
)
self.workers.append(WorkerHandle(
proc=proc,
queue=queue,
))
self.workers.append(
WorkerHandle(
proc=proc,
queue=queue,
)
)
def round_robin_scheduler(self, input_requests):
for r in input_requests:

View File

@@ -8,7 +8,9 @@ from typing import List
import zmq
from sglang.srt.managers.controller.tp_worker import (
broadcast_recv_input, launch_tp_servers, ModelTpServer
ModelTpServer,
broadcast_recv_input,
launch_tp_servers,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
@@ -41,7 +43,9 @@ class ControllerSingle:
if not self.is_dp_worker:
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
self.recv_from_tokenizer.bind(
f"tcp://127.0.0.1:{port_args.controller_port}"
)
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(
@@ -128,9 +132,15 @@ def start_controller_process(
queue = None
try:
controller = ControllerSingle(server_args, port_args, model_overide_args,
gpu_ids, is_data_parallel_worker,
dp_worker_id, queue)
controller = ControllerSingle(
server_args,
port_args,
model_overide_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
queue,
)
except Exception:
pipe_writer.send(get_exception_traceback())
raise