fix: resolve lint error (#650)

This commit is contained in:
zhyncs
2024-07-18 20:33:21 +10:00
committed by GitHub
parent 5960a6e505
commit 9c5cac2450
3 changed files with 30 additions and 13 deletions

View File

@@ -10,6 +10,6 @@ Briefly describe the changes made in this PR.
## Checklist
1. Ensure pre-commit or other linting tools are used to fix potential lint issues.
1. Ensure pre-commit (`pre-commit run --all-files`) or other linting tools are used to fix potential lint issues.
2. Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness.
3. Modify documentation as needed, such as docstrings or example tutorials.

View File

@@ -29,6 +29,7 @@ logger = logging.getLogger("srt.controller")
class LoadBalanceMethod(Enum):
"""Load balance method."""
ROUND_ROBIN = auto()
SHORTEST_QUEUE = auto()
@@ -44,6 +45,7 @@ class LoadBalanceMethod(Enum):
@dataclasses.dataclass
class WorkerHandle:
"""Store the handle of a data parallel worker."""
proc: multiprocessing.Process
queue: multiprocessing.Queue
@@ -62,7 +64,8 @@ class ControllerMulti:
self.port_args = port_args
self.model_overide_args = model_overide_args
self.load_balance_method = LoadBalanceMethod.from_str(
server_args.load_balance_method)
server_args.load_balance_method
)
# Init communication
context = zmq.Context()
@@ -85,7 +88,9 @@ class ControllerMulti:
def start_dp_worker(self, dp_worker_id: int):
tp_size = self.server_args.tp_size
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(duplex=False)
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
duplex=False
)
gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
queue = multiprocessing.Queue()
@@ -100,7 +105,7 @@ class ControllerMulti:
gpu_ids,
dp_worker_id,
queue,
)
),
)
proc.start()
@@ -109,10 +114,12 @@ class ControllerMulti:
raise RuntimeError(
f"Initialization failed. controller_init_state: {controller_init_state}"
)
self.workers.append(WorkerHandle(
proc=proc,
queue=queue,
))
self.workers.append(
WorkerHandle(
proc=proc,
queue=queue,
)
)
def round_robin_scheduler(self, input_requests):
for r in input_requests:

View File

@@ -8,7 +8,9 @@ from typing import List
import zmq
from sglang.srt.managers.controller.tp_worker import (
broadcast_recv_input, launch_tp_servers, ModelTpServer
ModelTpServer,
broadcast_recv_input,
launch_tp_servers,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.utils import kill_parent_process
@@ -41,7 +43,9 @@ class ControllerSingle:
if not self.is_dp_worker:
self.recv_from_tokenizer = context.socket(zmq.PULL)
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
self.recv_from_tokenizer.bind(
f"tcp://127.0.0.1:{port_args.controller_port}"
)
self.send_to_detokenizer = context.socket(zmq.PUSH)
self.send_to_detokenizer.connect(
@@ -128,9 +132,15 @@ def start_controller_process(
queue = None
try:
controller = ControllerSingle(server_args, port_args, model_overide_args,
gpu_ids, is_data_parallel_worker,
dp_worker_id, queue)
controller = ControllerSingle(
server_args,
port_args,
model_overide_args,
gpu_ids,
is_data_parallel_worker,
dp_worker_id,
queue,
)
except Exception:
pipe_writer.send(get_exception_traceback())
raise