fix: resolve lint error (#650)
This commit is contained in:
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -10,6 +10,6 @@ Briefly describe the changes made in this PR.
|
|||||||
|
|
||||||
## Checklist
|
## Checklist
|
||||||
|
|
||||||
1. Ensure pre-commit or other linting tools are used to fix potential lint issues.
|
1. Ensure pre-commit `pre-commit run --all-files` or other linting tools are used to fix potential lint issues.
|
||||||
2. Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness.
|
2. Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness.
|
||||||
3. Modify documentation as needed, such as docstrings or example tutorials.
|
3. Modify documentation as needed, such as docstrings or example tutorials.
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ logger = logging.getLogger("srt.controller")
|
|||||||
|
|
||||||
class LoadBalanceMethod(Enum):
|
class LoadBalanceMethod(Enum):
|
||||||
"""Load balance method."""
|
"""Load balance method."""
|
||||||
|
|
||||||
ROUND_ROBIN = auto()
|
ROUND_ROBIN = auto()
|
||||||
SHORTEST_QUEUE = auto()
|
SHORTEST_QUEUE = auto()
|
||||||
|
|
||||||
@@ -44,6 +45,7 @@ class LoadBalanceMethod(Enum):
|
|||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class WorkerHandle:
|
class WorkerHandle:
|
||||||
"""Store the handle of a data parallel worker."""
|
"""Store the handle of a data parallel worker."""
|
||||||
|
|
||||||
proc: multiprocessing.Process
|
proc: multiprocessing.Process
|
||||||
queue: multiprocessing.Queue
|
queue: multiprocessing.Queue
|
||||||
|
|
||||||
@@ -62,7 +64,8 @@ class ControllerMulti:
|
|||||||
self.port_args = port_args
|
self.port_args = port_args
|
||||||
self.model_overide_args = model_overide_args
|
self.model_overide_args = model_overide_args
|
||||||
self.load_balance_method = LoadBalanceMethod.from_str(
|
self.load_balance_method = LoadBalanceMethod.from_str(
|
||||||
server_args.load_balance_method)
|
server_args.load_balance_method
|
||||||
|
)
|
||||||
|
|
||||||
# Init communication
|
# Init communication
|
||||||
context = zmq.Context()
|
context = zmq.Context()
|
||||||
@@ -85,7 +88,9 @@ class ControllerMulti:
|
|||||||
def start_dp_worker(self, dp_worker_id: int):
|
def start_dp_worker(self, dp_worker_id: int):
|
||||||
tp_size = self.server_args.tp_size
|
tp_size = self.server_args.tp_size
|
||||||
|
|
||||||
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(duplex=False)
|
pipe_controller_reader, pipe_controller_writer = multiprocessing.Pipe(
|
||||||
|
duplex=False
|
||||||
|
)
|
||||||
|
|
||||||
gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
|
gpu_ids = list(range(dp_worker_id * tp_size, (dp_worker_id + 1) * tp_size))
|
||||||
queue = multiprocessing.Queue()
|
queue = multiprocessing.Queue()
|
||||||
@@ -100,7 +105,7 @@ class ControllerMulti:
|
|||||||
gpu_ids,
|
gpu_ids,
|
||||||
dp_worker_id,
|
dp_worker_id,
|
||||||
queue,
|
queue,
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
proc.start()
|
proc.start()
|
||||||
|
|
||||||
@@ -109,10 +114,12 @@ class ControllerMulti:
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Initialization failed. controller_init_state: {controller_init_state}"
|
f"Initialization failed. controller_init_state: {controller_init_state}"
|
||||||
)
|
)
|
||||||
self.workers.append(WorkerHandle(
|
self.workers.append(
|
||||||
proc=proc,
|
WorkerHandle(
|
||||||
queue=queue,
|
proc=proc,
|
||||||
))
|
queue=queue,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
def round_robin_scheduler(self, input_requests):
|
def round_robin_scheduler(self, input_requests):
|
||||||
for r in input_requests:
|
for r in input_requests:
|
||||||
|
|||||||
@@ -8,7 +8,9 @@ from typing import List
|
|||||||
import zmq
|
import zmq
|
||||||
|
|
||||||
from sglang.srt.managers.controller.tp_worker import (
|
from sglang.srt.managers.controller.tp_worker import (
|
||||||
broadcast_recv_input, launch_tp_servers, ModelTpServer
|
ModelTpServer,
|
||||||
|
broadcast_recv_input,
|
||||||
|
launch_tp_servers,
|
||||||
)
|
)
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
from sglang.srt.utils import kill_parent_process
|
from sglang.srt.utils import kill_parent_process
|
||||||
@@ -41,7 +43,9 @@ class ControllerSingle:
|
|||||||
|
|
||||||
if not self.is_dp_worker:
|
if not self.is_dp_worker:
|
||||||
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
self.recv_from_tokenizer = context.socket(zmq.PULL)
|
||||||
self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.controller_port}")
|
self.recv_from_tokenizer.bind(
|
||||||
|
f"tcp://127.0.0.1:{port_args.controller_port}"
|
||||||
|
)
|
||||||
|
|
||||||
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
self.send_to_detokenizer = context.socket(zmq.PUSH)
|
||||||
self.send_to_detokenizer.connect(
|
self.send_to_detokenizer.connect(
|
||||||
@@ -128,9 +132,15 @@ def start_controller_process(
|
|||||||
queue = None
|
queue = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
controller = ControllerSingle(server_args, port_args, model_overide_args,
|
controller = ControllerSingle(
|
||||||
gpu_ids, is_data_parallel_worker,
|
server_args,
|
||||||
dp_worker_id, queue)
|
port_args,
|
||||||
|
model_overide_args,
|
||||||
|
gpu_ids,
|
||||||
|
is_data_parallel_worker,
|
||||||
|
dp_worker_id,
|
||||||
|
queue,
|
||||||
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pipe_writer.send(get_exception_traceback())
|
pipe_writer.send(get_exception_traceback())
|
||||||
raise
|
raise
|
||||||
|
|||||||
Reference in New Issue
Block a user