From 9c5cac24506a230f487659a97b7cf09c920bb480 Mon Sep 17 00:00:00 2001 From: zhyncs Date: Thu, 18 Jul 2024 20:33:21 +1000 Subject: [PATCH] fix: resolve lint error (#650) --- .github/pull_request_template.md | 2 +- .../srt/managers/controller/manager_multi.py | 21 ++++++++++++------- .../srt/managers/controller/manager_single.py | 20 +++++++++++++----- 3 files changed, 30 insertions(+), 13 deletions(-) diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 147086400..20f4a10bc 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -10,6 +10,6 @@ Briefly describe the changes made in this PR. ## Checklist -1. Ensure pre-commit or other linting tools are used to fix potential lint issues. +1. Ensure pre-commit `pre-commit run --all-files` or other linting tools are used to fix potential lint issues. 2. Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness. 3. Modify documentation as needed, such as docstrings or example tutorials. diff --git a/python/sglang/srt/managers/controller/manager_multi.py b/python/sglang/srt/managers/controller/manager_multi.py index 188ee0e20..f24cbc116 100644 --- a/python/sglang/srt/managers/controller/manager_multi.py +++ b/python/sglang/srt/managers/controller/manager_multi.py @@ -29,6 +29,7 @@ logger = logging.getLogger("srt.controller") class LoadBalanceMethod(Enum): """Load balance method.""" + ROUND_ROBIN = auto() SHORTEST_QUEUE = auto() @@ -44,6 +45,7 @@ class LoadBalanceMethod(Enum): @dataclasses.dataclass class WorkerHandle: """Store the handle of a data parallel worker.""" + proc: multiprocessing.Process queue: multiprocessing.Queue @@ -62,7 +64,8 @@ class ControllerMulti: self.port_args = port_args self.model_overide_args = model_overide_args self.load_balance_method = LoadBalanceMethod.from_str( - server_args.load_balance_method) + server_args.load_balance_method + ) # Init communication context = zmq.Context() @@ -85,7 +88,9 @@ class ControllerMulti: def start_dp_worker(self, dp_worker_id: int): tp_size = self.server_args.tp_size - pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(duplex=False) + pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe( + duplex=False + ) gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size)) queue = multiprocessing.Queue() @@ -100,7 +105,7 @@ class ControllerMulti: gpu_ids, dp_worker_id, queue, - ) + ), ) proc.start() @@ -109,10 +114,12 @@ class ControllerMulti: raise RuntimeError( f"Initialization failed. controller_init_state: {controller_init_state}" ) - self.workers.append(WorkerHandle( - proc=proc, - queue=queue, - )) + self.workers.append( + WorkerHandle( + proc=proc, + queue=queue, + ) + ) def round_robin_scheduler(self, input_requests): for r in input_requests: diff --git a/python/sglang/srt/managers/controller/manager_single.py b/python/sglang/srt/managers/controller/manager_single.py index 9326945f9..e9eff6876 100644 --- a/python/sglang/srt/managers/controller/manager_single.py +++ b/python/sglang/srt/managers/controller/manager_single.py @@ -8,7 +8,9 @@ from typing import List import zmq from sglang.srt.managers.controller.tp_worker import ( - broadcast_recv_input, launch_tp_servers, ModelTpServer + ModelTpServer, + broadcast_recv_input, + launch_tp_servers, ) from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.utils import kill_parent_process @@ -41,7 +43,9 @@ class ControllerSingle: if not self.is_dp_worker: self.recv_from_tokenizer = context.socket(zmq.PULL) - self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}") + self.recv_from_tokenizer.bind( + f"tcp://127.0.0.1:{port_args.controller_port}" + ) self.send_to_detokenizer = context.socket(zmq.PUSH) self.send_to_detokenizer.connect( @@ -128,9 +132,15 @@ def start_controller_process( queue = None try: - controller = ControllerSingle(server_args, port_args, model_overide_args, - gpu_ids, is_data_parallel_worker, - dp_worker_id, queue) + controller = ControllerSingle( + server_args, + port_args, + model_overide_args, + gpu_ids, + is_data_parallel_worker, + dp_worker_id, + queue, + ) except Exception: pipe_writer.send(get_exception_traceback()) raise