Support colocating requests (#7973)
This commit is contained in:
@@ -26,6 +26,7 @@ import zmq
|
||||
|
||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||
from sglang.srt.managers.io_struct import (
|
||||
BlockReqInput,
|
||||
TokenizedEmbeddingReqInput,
|
||||
TokenizedGenerateReqInput,
|
||||
)
|
||||
@@ -282,6 +283,9 @@ class DataParallelController:
|
||||
),
|
||||
):
|
||||
self.dispatching(recv_req)
|
||||
elif isinstance(recv_req, BlockReqInput):
|
||||
for worker in self.workers:
|
||||
worker.send_pyobj(recv_req)
|
||||
else:
|
||||
# Send other control messages to first worker of tp group
|
||||
for worker in self.workers[:: self.control_message_step]:
|
||||
|
||||
@@ -1103,3 +1103,13 @@ class LoRAUpdateResult:
|
||||
|
||||
|
||||
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
||||
|
||||
|
||||
class BlockReqType(Enum):
    """Kind of input-blocking control message sent to the scheduler."""

    # Start buffering ordinary requests.
    BLOCK = 1
    # Release buffered requests once all ranks agree.
    UNBLOCK = 2
|
||||
|
||||
|
||||
@dataclass
class BlockReqInput:
    """IPC message asking the scheduler to block or unblock its input stream."""

    # Whether to block or unblock; see BlockReqType.
    type: BlockReqType
|
||||
|
||||
@@ -123,6 +123,7 @@ from sglang.srt.managers.schedule_policy import (
|
||||
PrefillAdder,
|
||||
SchedulePolicy,
|
||||
)
|
||||
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
|
||||
from sglang.srt.managers.scheduler_output_processor_mixin import (
|
||||
SchedulerOutputProcessorMixin,
|
||||
)
|
||||
@@ -504,6 +505,12 @@ class Scheduler(
|
||||
)
|
||||
self.init_profier()
|
||||
|
||||
self.input_blocker = (
|
||||
SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
|
||||
if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
|
||||
else None
|
||||
)
|
||||
|
||||
# Init metrics stats
|
||||
self.init_metrics(tp_rank, pp_rank, dp_rank)
|
||||
self.init_kv_events(server_args.kv_events_config)
|
||||
@@ -1035,6 +1042,9 @@ class Scheduler(
|
||||
else:
|
||||
recv_reqs = None
|
||||
|
||||
if self.input_blocker is not None:
|
||||
recv_reqs = self.input_blocker.handle(recv_reqs)
|
||||
|
||||
if self.server_args.enable_dp_attention:
|
||||
if self.attn_tp_rank == 0:
|
||||
work_reqs = [
|
||||
|
||||
106
python/sglang/srt/managers/scheduler_input_blocker.py
Normal file
106
python/sglang/srt/managers/scheduler_input_blocker.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, auto
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
|
||||
from sglang.srt.poll_based_barrier import PollBasedBarrier
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SchedulerInputBlocker:
    """Gate for the scheduler's incoming request stream.

    On ``BlockReqInput(BLOCK)`` the blocker starts buffering ordinary
    requests; on ``BlockReqInput(UNBLOCK)`` it arrives at a global barrier
    and, once every rank has arrived, releases the buffered requests in
    their original order.

    Ranks constructed with ``noop=True`` never receive requests themselves
    and only participate in the global barrier.
    """

    def __init__(self, noop: bool):
        # Position in the UNBLOCKED -> BLOCKED -> GLOBAL_UNBLOCK_BARRIER
        # -> UNBLOCKED cycle.
        self._state = _State.UNBLOCKED
        # Requests buffered while blocked, in arrival order.
        self._pending_reqs = []
        self._noop = noop
        # Every rank (including noop ranks) must reach this barrier before
        # the buffered requests are released.
        self._global_unblock_barrier = PollBasedBarrier(noop=noop)

    def handle(self, recv_reqs: Optional[List[Any]]):
        """Filter one batch of received requests.

        Returns the requests that may be processed now on non-noop ranks,
        and None on noop ranks (mirroring the ``recv_reqs`` argument).
        """
        assert (recv_reqs is None) == self._noop

        # Fix: always bind output_reqs. Previously it was only assigned
        # inside the `not self._noop` branch, so the barrier branch below
        # would raise NameError if it ever fired on a noop rank.
        output_reqs = []
        if not self._noop:
            for recv_req in recv_reqs:
                output_reqs += self._handle_recv_req(recv_req)

        # Poll unconditionally: the underlying all-reduce is a collective,
        # so every rank must participate on every call.
        global_arrived_unblock_barrier = (
            self._global_unblock_barrier.poll_global_arrived()
        )
        if (
            self._state == _State.GLOBAL_UNBLOCK_BARRIER
            and global_arrived_unblock_barrier
        ):
            output_reqs += self._handle_arrive_unblock_barrier()

        if not self._noop:
            return output_reqs
        # noop ranks keep the original contract of returning None.

    def _handle_recv_req(self, recv_req):
        """Route one request; return the (possibly empty) list to emit now."""
        if isinstance(recv_req, BlockReqInput):
            if recv_req.type == BlockReqType.BLOCK:
                self._execute_block_req()
                return []
            elif recv_req.type == BlockReqType.UNBLOCK:
                self._execute_unblock_req()
                return []
            else:
                raise NotImplementedError(f"{recv_req=}")
        else:
            if self._state == _State.UNBLOCKED:
                return [recv_req]
            else:
                # Blocked (or waiting at the barrier): buffer for later.
                self._pending_reqs.append(recv_req)
                return []

    def _execute_block_req(self):
        """Start buffering incoming requests."""
        logger.info("Handle block req")
        self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)

    def _execute_unblock_req(self):
        """Stop blocking locally and arrive at the global barrier."""
        logger.info("Handle unblock req")
        self._change_state(
            original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
        )
        self._global_unblock_barrier.local_arrive()

    def _handle_arrive_unblock_barrier(self):
        """All ranks have arrived: release the buffered requests."""
        logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
        self._change_state(
            original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
        )
        output_reqs = [*self._pending_reqs]
        self._pending_reqs.clear()
        return output_reqs

    def _change_state(self, original: "_State", target: "_State"):
        """Transition to ``target``, asserting the expected current state."""
        assert self._state == original, f"{self._state=} {original=} {target=}"
        self._state = target
|
||||
|
||||
|
||||
class _State(Enum):
|
||||
UNBLOCKED = auto()
|
||||
BLOCKED = auto()
|
||||
GLOBAL_UNBLOCK_BARRIER = auto()
|
||||
|
||||
|
||||
@contextmanager
def input_blocker_guard_region(send_to_scheduler):
    """Block scheduler input for the duration of the ``with`` body.

    Sends a BLOCK message on entry and always sends the matching UNBLOCK
    on exit, even if the body raises.
    """

    def _send(req_type):
        send_to_scheduler.send_pyobj(BlockReqInput(req_type))

    _send(BlockReqType.BLOCK)
    try:
        yield
    finally:
        _send(BlockReqType.UNBLOCK)
|
||||
@@ -27,6 +27,7 @@ import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import (
|
||||
@@ -69,6 +70,7 @@ from sglang.srt.managers.io_struct import (
|
||||
BatchMultimodalOut,
|
||||
BatchStrOut,
|
||||
BatchTokenIDOut,
|
||||
BlockReqType,
|
||||
CloseSessionReqInput,
|
||||
ConfigureLoggingReq,
|
||||
EmbeddingReqInput,
|
||||
@@ -114,6 +116,7 @@ from sglang.srt.managers.io_struct import (
|
||||
)
|
||||
from sglang.srt.managers.mm_utils import TensorTransportMode
|
||||
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
||||
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
|
||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
@@ -819,12 +822,21 @@ class TokenizerManager:
|
||||
rids.append(tmp_obj.rid)
|
||||
else:
|
||||
# Sequential tokenization and processing
|
||||
for i in range(batch_size):
|
||||
tmp_obj = obj[i]
|
||||
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
||||
state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
||||
generators.append(self._wait_one_response(tmp_obj, state, request))
|
||||
rids.append(tmp_obj.rid)
|
||||
with (
|
||||
input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
|
||||
if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
|
||||
else nullcontext()
|
||||
):
|
||||
for i in range(batch_size):
|
||||
tmp_obj = obj[i]
|
||||
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
||||
state = self._send_one_request(
|
||||
tmp_obj, tokenized_obj, created_time
|
||||
)
|
||||
generators.append(
|
||||
self._wait_one_response(tmp_obj, state, request)
|
||||
)
|
||||
rids.append(tmp_obj.rid)
|
||||
else:
|
||||
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
||||
if batch_size > 128:
|
||||
|
||||
31
python/sglang/srt/poll_based_barrier.py
Normal file
31
python/sglang/srt/poll_based_barrier.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import torch
|
||||
|
||||
from sglang.srt.distributed import get_world_group
|
||||
|
||||
|
||||
class PollBasedBarrier:
    """A global barrier that is polled instead of blocking.

    Each participating rank calls :meth:`local_arrive` once, then keeps
    calling :meth:`poll_global_arrived`; the poll returns True exactly once,
    when every rank has arrived, and resets the local state for reuse.
    Ranks created with ``noop=True`` always vote "arrived".
    """

    def __init__(self, noop: bool = False):
        self._noop = noop
        self._local_arrived = False

    def local_arrive(self):
        """Mark this rank as arrived; must be called once per barrier cycle."""
        assert not self._local_arrived
        self._local_arrived = True

    def poll_global_arrived(self) -> bool:
        """Return True once everyone (including this rank) has arrived."""
        # The all-reduce inside _compute_global_arrived is a collective, so
        # it must run on every poll — do not short-circuit on the local flag.
        everyone_arrived = self._compute_global_arrived()
        done = self._local_arrived and everyone_arrived
        if done:
            # Reset so the barrier can be reused for the next cycle.
            self._local_arrived = False
        return done

    def _compute_global_arrived(self) -> bool:
        # noop ranks always vote True.
        vote = torch.tensor(self._noop or self._local_arrived)
        # MIN-reduce over the world group's CPU group: the result is True
        # only when every rank voted True. Can optimize if bottleneck.
        torch.distributed.all_reduce(
            vote,
            torch.distributed.ReduceOp.MIN,
            group=get_world_group().cpu_group,
        )
        return vote.item()
|
||||
Reference in New Issue
Block a user