Support colocating requests (#7973)
This commit is contained in:
@@ -26,6 +26,7 @@ import zmq
|
|||||||
|
|
||||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||||
from sglang.srt.managers.io_struct import (
|
from sglang.srt.managers.io_struct import (
|
||||||
|
BlockReqInput,
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
)
|
)
|
||||||
@@ -282,6 +283,9 @@ class DataParallelController:
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
self.dispatching(recv_req)
|
self.dispatching(recv_req)
|
||||||
|
elif isinstance(recv_req, BlockReqInput):
|
||||||
|
for worker in self.workers:
|
||||||
|
worker.send_pyobj(recv_req)
|
||||||
else:
|
else:
|
||||||
# Send other control messages to first worker of tp group
|
# Send other control messages to first worker of tp group
|
||||||
for worker in self.workers[:: self.control_message_step]:
|
for worker in self.workers[:: self.control_message_step]:
|
||||||
|
|||||||
@@ -1103,3 +1103,13 @@ class LoRAUpdateResult:
|
|||||||
|
|
||||||
|
|
||||||
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult
|
||||||
|
|
||||||
|
|
||||||
|
class BlockReqType(Enum):
|
||||||
|
BLOCK = 1
|
||||||
|
UNBLOCK = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BlockReqInput:
|
||||||
|
type: BlockReqType
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ from sglang.srt.managers.schedule_policy import (
|
|||||||
PrefillAdder,
|
PrefillAdder,
|
||||||
SchedulePolicy,
|
SchedulePolicy,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.managers.scheduler_input_blocker import SchedulerInputBlocker
|
||||||
from sglang.srt.managers.scheduler_output_processor_mixin import (
|
from sglang.srt.managers.scheduler_output_processor_mixin import (
|
||||||
SchedulerOutputProcessorMixin,
|
SchedulerOutputProcessorMixin,
|
||||||
)
|
)
|
||||||
@@ -504,6 +505,12 @@ class Scheduler(
|
|||||||
)
|
)
|
||||||
self.init_profier()
|
self.init_profier()
|
||||||
|
|
||||||
|
self.input_blocker = (
|
||||||
|
SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
|
||||||
|
if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
|
||||||
# Init metrics stats
|
# Init metrics stats
|
||||||
self.init_metrics(tp_rank, pp_rank, dp_rank)
|
self.init_metrics(tp_rank, pp_rank, dp_rank)
|
||||||
self.init_kv_events(server_args.kv_events_config)
|
self.init_kv_events(server_args.kv_events_config)
|
||||||
@@ -1035,6 +1042,9 @@ class Scheduler(
|
|||||||
else:
|
else:
|
||||||
recv_reqs = None
|
recv_reqs = None
|
||||||
|
|
||||||
|
if self.input_blocker is not None:
|
||||||
|
recv_reqs = self.input_blocker.handle(recv_reqs)
|
||||||
|
|
||||||
if self.server_args.enable_dp_attention:
|
if self.server_args.enable_dp_attention:
|
||||||
if self.attn_tp_rank == 0:
|
if self.attn_tp_rank == 0:
|
||||||
work_reqs = [
|
work_reqs = [
|
||||||
|
|||||||
106
python/sglang/srt/managers/scheduler_input_blocker.py
Normal file
106
python/sglang/srt/managers/scheduler_input_blocker.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
# Copyright 2023-2024 SGLang Team
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from enum import Enum, auto
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from sglang.srt.managers.io_struct import BlockReqInput, BlockReqType
|
||||||
|
from sglang.srt.poll_based_barrier import PollBasedBarrier
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SchedulerInputBlocker:
|
||||||
|
def __init__(self, noop: bool):
|
||||||
|
self._state = _State.UNBLOCKED
|
||||||
|
self._pending_reqs = []
|
||||||
|
self._noop = noop
|
||||||
|
self._global_unblock_barrier = PollBasedBarrier(noop=noop)
|
||||||
|
|
||||||
|
def handle(self, recv_reqs: Optional[List[Any]]):
|
||||||
|
assert (recv_reqs is None) == self._noop
|
||||||
|
|
||||||
|
if not self._noop:
|
||||||
|
output_reqs = []
|
||||||
|
for recv_req in recv_reqs:
|
||||||
|
output_reqs += self._handle_recv_req(recv_req)
|
||||||
|
|
||||||
|
global_arrived_unblock_barrier = (
|
||||||
|
self._global_unblock_barrier.poll_global_arrived()
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
self._state == _State.GLOBAL_UNBLOCK_BARRIER
|
||||||
|
and global_arrived_unblock_barrier
|
||||||
|
):
|
||||||
|
output_reqs += self._handle_arrive_unblock_barrier()
|
||||||
|
|
||||||
|
if not self._noop:
|
||||||
|
return output_reqs
|
||||||
|
|
||||||
|
def _handle_recv_req(self, recv_req):
|
||||||
|
if isinstance(recv_req, BlockReqInput):
|
||||||
|
if recv_req.type == BlockReqType.BLOCK:
|
||||||
|
self._execute_block_req()
|
||||||
|
return []
|
||||||
|
elif recv_req.type == BlockReqType.UNBLOCK:
|
||||||
|
self._execute_unblock_req()
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
raise NotImplementedError(f"{recv_req=}")
|
||||||
|
else:
|
||||||
|
if self._state == _State.UNBLOCKED:
|
||||||
|
return [recv_req]
|
||||||
|
else:
|
||||||
|
self._pending_reqs.append(recv_req)
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _execute_block_req(self):
|
||||||
|
logger.info("Handle block req")
|
||||||
|
self._change_state(original=_State.UNBLOCKED, target=_State.BLOCKED)
|
||||||
|
|
||||||
|
def _execute_unblock_req(self):
|
||||||
|
logger.info("Handle unblock req")
|
||||||
|
self._change_state(
|
||||||
|
original=_State.BLOCKED, target=_State.GLOBAL_UNBLOCK_BARRIER
|
||||||
|
)
|
||||||
|
self._global_unblock_barrier.local_arrive()
|
||||||
|
|
||||||
|
def _handle_arrive_unblock_barrier(self):
|
||||||
|
logger.info(f"Arrived at unblock barrier ({len(self._pending_reqs)=})")
|
||||||
|
self._change_state(
|
||||||
|
original=_State.GLOBAL_UNBLOCK_BARRIER, target=_State.UNBLOCKED
|
||||||
|
)
|
||||||
|
output_reqs = [*self._pending_reqs]
|
||||||
|
self._pending_reqs.clear()
|
||||||
|
return output_reqs
|
||||||
|
|
||||||
|
def _change_state(self, original: "_State", target: "_State"):
|
||||||
|
assert self._state == original, f"{self._state=} {original=} {target=}"
|
||||||
|
self._state = target
|
||||||
|
|
||||||
|
|
||||||
|
class _State(Enum):
|
||||||
|
UNBLOCKED = auto()
|
||||||
|
BLOCKED = auto()
|
||||||
|
GLOBAL_UNBLOCK_BARRIER = auto()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def input_blocker_guard_region(send_to_scheduler):
|
||||||
|
send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.BLOCK))
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
send_to_scheduler.send_pyobj(BlockReqInput(BlockReqType.UNBLOCK))
|
||||||
@@ -27,6 +27,7 @@ import threading
|
|||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from contextlib import nullcontext
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from typing import (
|
from typing import (
|
||||||
@@ -69,6 +70,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
BatchMultimodalOut,
|
BatchMultimodalOut,
|
||||||
BatchStrOut,
|
BatchStrOut,
|
||||||
BatchTokenIDOut,
|
BatchTokenIDOut,
|
||||||
|
BlockReqType,
|
||||||
CloseSessionReqInput,
|
CloseSessionReqInput,
|
||||||
ConfigureLoggingReq,
|
ConfigureLoggingReq,
|
||||||
EmbeddingReqInput,
|
EmbeddingReqInput,
|
||||||
@@ -114,6 +116,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
)
|
)
|
||||||
from sglang.srt.managers.mm_utils import TensorTransportMode
|
from sglang.srt.managers.mm_utils import TensorTransportMode
|
||||||
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
|
||||||
|
from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
|
||||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||||
@@ -819,12 +822,21 @@ class TokenizerManager:
|
|||||||
rids.append(tmp_obj.rid)
|
rids.append(tmp_obj.rid)
|
||||||
else:
|
else:
|
||||||
# Sequential tokenization and processing
|
# Sequential tokenization and processing
|
||||||
for i in range(batch_size):
|
with (
|
||||||
tmp_obj = obj[i]
|
input_blocker_guard_region(send_to_scheduler=self.send_to_scheduler)
|
||||||
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
|
||||||
state = self._send_one_request(tmp_obj, tokenized_obj, created_time)
|
else nullcontext()
|
||||||
generators.append(self._wait_one_response(tmp_obj, state, request))
|
):
|
||||||
rids.append(tmp_obj.rid)
|
for i in range(batch_size):
|
||||||
|
tmp_obj = obj[i]
|
||||||
|
tokenized_obj = await self._tokenize_one_request(tmp_obj)
|
||||||
|
state = self._send_one_request(
|
||||||
|
tmp_obj, tokenized_obj, created_time
|
||||||
|
)
|
||||||
|
generators.append(
|
||||||
|
self._wait_one_response(tmp_obj, state, request)
|
||||||
|
)
|
||||||
|
rids.append(tmp_obj.rid)
|
||||||
else:
|
else:
|
||||||
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
# FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
|
||||||
if batch_size > 128:
|
if batch_size > 128:
|
||||||
|
|||||||
31
python/sglang/srt/poll_based_barrier.py
Normal file
31
python/sglang/srt/poll_based_barrier.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.distributed import get_world_group
|
||||||
|
|
||||||
|
|
||||||
|
class PollBasedBarrier:
|
||||||
|
def __init__(self, noop: bool = False):
|
||||||
|
self._noop = noop
|
||||||
|
self._local_arrived = False
|
||||||
|
|
||||||
|
def local_arrive(self):
|
||||||
|
assert not self._local_arrived
|
||||||
|
self._local_arrived = True
|
||||||
|
|
||||||
|
def poll_global_arrived(self) -> bool:
|
||||||
|
global_arrived = self._compute_global_arrived()
|
||||||
|
output = self._local_arrived and global_arrived
|
||||||
|
if output:
|
||||||
|
self._local_arrived = False
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _compute_global_arrived(self) -> bool:
|
||||||
|
local_arrived = self._noop or self._local_arrived
|
||||||
|
global_arrived = torch.tensor(local_arrived)
|
||||||
|
# Can optimize if bottleneck
|
||||||
|
torch.distributed.all_reduce(
|
||||||
|
global_arrived,
|
||||||
|
torch.distributed.ReduceOp.MIN,
|
||||||
|
group=get_world_group().cpu_group,
|
||||||
|
)
|
||||||
|
return global_arrived.item()
|
||||||
Reference in New Issue
Block a user