class BlockReqType(Enum):
    """Kind of a scheduler input-blocking control message."""

    # Stop admitting new requests; queue them instead.
    BLOCK = 1
    # Resume admitting requests (after a global barrier).
    UNBLOCK = 2


@dataclass
class BlockReqInput:
    """Control message telling the scheduler to block or unblock its input.

    Broadcast to every worker of the data-parallel group (unlike other
    control messages, which go only to the first worker of each TP group).
    """

    # NOTE: field is named `type` to match the wire format; it shadows the
    # builtin only inside this class body.
    type: BlockReqType
metrics stats self.init_metrics(tp_rank, pp_rank, dp_rank) self.init_kv_events(server_args.kv_events_config) @@ -1035,6 +1042,9 @@ class Scheduler( else: recv_reqs = None + if self.input_blocker is not None: + recv_reqs = self.input_blocker.handle(recv_reqs) + if self.server_args.enable_dp_attention: if self.attn_tp_rank == 0: work_reqs = [ diff --git a/python/sglang/srt/managers/scheduler_input_blocker.py b/python/sglang/srt/managers/scheduler_input_blocker.py new file mode 100644 index 000000000..60ae8d5d6 --- /dev/null +++ b/python/sglang/srt/managers/scheduler_input_blocker.py @@ -0,0 +1,106 @@ +# Copyright 2023-2024 SGLang Team +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ==============================================================================
import logging
from contextlib import contextmanager
from enum import Enum, auto
from typing import Any, List, Optional

from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
from sglang.srt.poll_based_barrier import PollBasedBarrier

logger = logging.getLogger(__name__)


class _State(Enum):
    # Requests flow straight through to the scheduler.
    UNBLOCKED = auto()
    # A BLOCK message was seen; requests are queued in `_pending_reqs`.
    BLOCKED = auto()
    # An UNBLOCK message was seen locally; waiting until every rank has
    # also seen it before releasing the queued requests.
    GLOBAL_UNBLOCK_BARRIER = auto()


class SchedulerInputBlocker:
    """Gate for the scheduler's incoming request stream.

    While blocked, ordinary requests are buffered; on unblock, release is
    deferred behind a poll-based global barrier so that every rank resumes
    in the same scheduling iteration.

    Args:
        noop: True on ranks that do not receive requests themselves
            (attn_tp_rank != 0); such ranks still participate in the
            collective barrier but never buffer or forward anything.
    """

    def __init__(self, noop: bool):
        self._state = _State.UNBLOCKED
        self._pending_reqs: List[Any] = []
        self._noop = noop
        self._global_unblock_barrier = PollBasedBarrier(noop=noop)

    def handle(self, recv_reqs: Optional[List[Any]]):
        """Filter one batch of received requests through the gate.

        Returns the requests that may be processed now (possibly including
        previously buffered ones), or None on noop ranks.
        """
        # Noop ranks are exactly the ones that receive no requests.
        assert (recv_reqs is None) == self._noop

        released: List[Any] = []
        if not self._noop:
            for req in recv_reqs:
                released.extend(self._handle_recv_req(req))

        # The barrier poll is a collective; it must run on every rank in
        # every iteration, regardless of the local state.
        barrier_done = self._global_unblock_barrier.poll_global_arrived()
        if self._state == _State.GLOBAL_UNBLOCK_BARRIER and barrier_done:
            released.extend(self._handle_arrive_unblock_barrier())

        if self._noop:
            return None
        return released

    def _handle_recv_req(self, recv_req):
        """Route one request: control messages mutate state, others pass/queue."""
        if not isinstance(recv_req, BlockReqInput):
            if self._state == _State.UNBLOCKED:
                return [recv_req]
            self._pending_reqs.append(recv_req)
            return []

        if recv_req.type == BlockReqType.BLOCK:
            self._execute_block_req()
        elif recv_req.type == BlockReqType.UNBLOCK:
            self._execute_unblock_req()
        else:
            raise NotImplementedError(f"{recv_req=}")
        return []

    def _execute_block_req(self):
        logger.info("Handle block req")
        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)

    def _execute_unblock_req(self):
        logger.info("Handle unblock req")
        self._change_state(
            original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
        )
        self._global_unblock_barrier.local_arrive()

    def _handle_arrive_unblock_barrier(self):
        """All ranks reached the barrier: release the buffered requests."""
        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
        self._change_state(
            original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
        )
        released = list(self._pending_reqs)
        self._pending_reqs.clear()
        return released

    def _change_state(self, original: "_State", target: "_State"):
        # Transitions are strict; a mismatch means messages arrived out of
        # order (e.g. UNBLOCK without a preceding BLOCK).
        assert self._state == original, f"{self._state=} {original=} {target=}"
        self._state = target


@contextmanager
def input_blocker_guard_region(send_to_scheduler):
    """Block scheduler input for the duration of the with-body.

    Sends BLOCK on entry and always sends UNBLOCK on exit, even if the
    body raises.
    """
    send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.BLOCK))
    try:
        yield
    finally:
        send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.UNBLOCK))
class PollBasedBarrier:
    """A non-blocking distributed barrier checked by polling.

    Each rank calls `local_arrive()` once, then repeatedly calls
    `poll_global_arrived()`; the poll returns True exactly once, on the
    iteration where every rank has arrived, and re-arms the barrier.

    Args:
        noop: if True, this rank always counts as arrived in the global
            reduction (used for ranks that never produce a local event).
    """

    def __init__(self, noop: bool = False):
        self._noop = noop
        self._local_arrived = False

    def local_arrive(self):
        """Mark this rank as arrived; must not be called twice per cycle."""
        assert not self._local_arrived
        self._local_arrived = True

    def poll_global_arrived(self) -> bool:
        """Return True once when all ranks have arrived, then reset."""
        # NOTE: the reduction is a collective, so it must execute on every
        # poll on every rank — do not short-circuit on `_local_arrived`.
        everyone_arrived = self._compute_global_arrived()
        done = self._local_arrived and everyone_arrived
        if done:
            # Re-arm for the next block/unblock cycle.
            self._local_arrived = False
        return done

    def _compute_global_arrived(self) -> bool:
        """MIN-reduce the local arrival flag over the CPU process group."""
        flag = torch.tensor(self._noop or self._local_arrived)
        # Can optimize if bottleneck
        torch.distributed.all_reduce(
            flag,
            torch.distributed.ReduceOp.MIN,
            group=get_world_group().cpu_group,
        )
        return flag.item()