Remove dp balance metadata and minimum token balance. (#11170)
This commit is contained in:
@@ -812,7 +812,6 @@ def _launch_subprocesses(
|
||||
pp_rank,
|
||||
None,
|
||||
writer,
|
||||
None,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -120,11 +120,8 @@ message GenerateRequest {
|
||||
// Data parallel routing
|
||||
int32 data_parallel_rank = 16;
|
||||
|
||||
// For load balancing
|
||||
int32 dp_balance_id = 17;
|
||||
|
||||
// Whether client wants streaming response
|
||||
bool stream = 18;
|
||||
bool stream = 17;
|
||||
}
|
||||
|
||||
message TokenizedInput {
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -82,7 +82,7 @@ class DisaggregatedParams(_message.Message):
|
||||
def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...
|
||||
|
||||
class GenerateRequest(_message.Message):
|
||||
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "dp_balance_id", "stream")
|
||||
__slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "stream")
|
||||
REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
TOKENIZED_FIELD_NUMBER: _ClassVar[int]
|
||||
MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
|
||||
@@ -99,7 +99,6 @@ class GenerateRequest(_message.Message):
|
||||
INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int]
|
||||
LORA_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
|
||||
DP_BALANCE_ID_FIELD_NUMBER: _ClassVar[int]
|
||||
STREAM_FIELD_NUMBER: _ClassVar[int]
|
||||
request_id: str
|
||||
tokenized: TokenizedInput
|
||||
@@ -117,9 +116,8 @@ class GenerateRequest(_message.Message):
|
||||
input_embeds: _containers.RepeatedScalarFieldContainer[float]
|
||||
lora_id: str
|
||||
data_parallel_rank: int
|
||||
dp_balance_id: int
|
||||
stream: bool
|
||||
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., dp_balance_id: _Optional[int] = ..., stream: bool = ...) -> None: ...
|
||||
def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., stream: bool = ...) -> None: ...
|
||||
|
||||
class TokenizedInput(_message.Message):
|
||||
__slots__ = ("original_text", "input_ids")
|
||||
|
||||
@@ -17,14 +17,11 @@ import faulthandler
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import signal
|
||||
import struct
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from enum import Enum, auto
|
||||
from multiprocessing import shared_memory
|
||||
from typing import Dict, List
|
||||
from typing import List
|
||||
|
||||
import psutil
|
||||
import setproctitle
|
||||
@@ -39,7 +36,6 @@ from sglang.srt.managers.io_struct import (
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.managers.utils import DPBalanceMeta
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.utils import (
|
||||
@@ -108,15 +104,9 @@ class DPBudget:
|
||||
class DataParallelController:
|
||||
"""A controller that dispatches requests to multiple data parallel workers."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
server_args: ServerArgs,
|
||||
port_args: PortArgs,
|
||||
dp_balance_meta: DPBalanceMeta,
|
||||
) -> None:
|
||||
def __init__(self, server_args: ServerArgs, port_args: PortArgs) -> None:
|
||||
# for dp balance
|
||||
self.global_balance_id = 0
|
||||
self.balance_meta = dp_balance_meta
|
||||
|
||||
# Parse args
|
||||
self.max_total_num_tokens = None
|
||||
@@ -322,7 +312,6 @@ class DataParallelController:
|
||||
pp_rank,
|
||||
dp_rank,
|
||||
writer,
|
||||
self.balance_meta,
|
||||
),
|
||||
)
|
||||
with memory_saver_adapter.configure_subprocess():
|
||||
@@ -370,31 +359,11 @@ class DataParallelController:
|
||||
if self.maybe_external_dp_rank_routing(req):
|
||||
return
|
||||
|
||||
# This variable corresponds to the balance_id in TokenizedGenerateReqInput.
|
||||
# We use it to control the number of onfly tokens (requests dispatched to workers but not yet received).
|
||||
def get_next_global_balance_id() -> int:
|
||||
INT32_MAX = 2147483647
|
||||
current_id = self.global_balance_id
|
||||
self.global_balance_id = (self.global_balance_id + 1) % INT32_MAX
|
||||
return current_id
|
||||
|
||||
req.dp_balance_id = get_next_global_balance_id()
|
||||
with self.balance_meta.mutex:
|
||||
# 1. local_tokens represents the tokens currently inferring on the worker,
|
||||
# while onfly refers to the requests dispatched by the dispatcher but not yet received by the scheduler.
|
||||
onfly_info = self.balance_meta.get_shared_onfly()
|
||||
local_tokens = self.balance_meta.get_shared_local_tokens()
|
||||
total_tokens = [
|
||||
local_token + sum(onfly_dict.values())
|
||||
for local_token, onfly_dict in zip(local_tokens, onfly_info)
|
||||
]
|
||||
target_worker = total_tokens.index(min(total_tokens))
|
||||
onfly_info[target_worker][req.dp_balance_id] = len(req.input_ids)
|
||||
# 2. write the new onfly info to the shm
|
||||
self.balance_meta.set_shared_onfly_info(onfly_info)
|
||||
|
||||
# logger.info(f"dp workers {local_tokens=}, {onfly_info=}, {target_worker=}")
|
||||
self.workers[target_worker].send_pyobj(req)
|
||||
logger.warning(
|
||||
"The 'minimum_tokens' load balancing method is deprecated for now and will introduced later."
|
||||
"Fall back to 'round_robin_scheduler'"
|
||||
)
|
||||
self.round_robin_scheduler(req)
|
||||
|
||||
def event_loop(self):
|
||||
while True:
|
||||
@@ -416,12 +385,9 @@ def run_data_parallel_controller_process(
|
||||
faulthandler.enable()
|
||||
configure_logger(server_args)
|
||||
parent_process = psutil.Process().parent()
|
||||
balance_meta = DPBalanceMeta(server_args.dp_size)
|
||||
|
||||
try:
|
||||
controller = DataParallelController(
|
||||
server_args, port_args, dp_balance_meta=balance_meta
|
||||
)
|
||||
controller = DataParallelController(server_args, port_args)
|
||||
pipe_writer.send(
|
||||
{
|
||||
"status": "ready",
|
||||
@@ -440,6 +406,3 @@ def run_data_parallel_controller_process(
|
||||
traceback = get_exception_traceback()
|
||||
logger.error(f"DataParallelController hit an exception: {traceback}")
|
||||
parent_process.send_signal(signal.SIGQUIT)
|
||||
finally:
|
||||
# we need to destruct mp.Manager() in balance_meta
|
||||
balance_meta.destructor()
|
||||
|
||||
@@ -606,9 +606,6 @@ class TokenizedGenerateReqInput:
|
||||
# For data parallel rank routing
|
||||
data_parallel_rank: Optional[int] = None
|
||||
|
||||
# For dp balance
|
||||
dp_balance_id: int = -1
|
||||
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
@@ -778,8 +775,6 @@ class TokenizedEmbeddingReqInput:
|
||||
sampling_params: SamplingParams
|
||||
# For data parallel rank routing
|
||||
data_parallel_rank: Optional[int] = None
|
||||
# For dp balance
|
||||
dp_balance_id: int = -1
|
||||
# Priority for the request
|
||||
priority: Optional[int] = None
|
||||
|
||||
|
||||
@@ -145,7 +145,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
|
||||
from sglang.srt.managers.session_controller import Session
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
|
||||
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
|
||||
from sglang.srt.managers.utils import validate_input_length
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
|
||||
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
|
||||
from sglang.srt.mem_cache.radix_cache import RadixCache
|
||||
@@ -271,7 +271,6 @@ class Scheduler(
|
||||
moe_ep_rank: int,
|
||||
pp_rank: int,
|
||||
dp_rank: Optional[int],
|
||||
dp_balance_meta: Optional[DPBalanceMeta] = None,
|
||||
):
|
||||
# Parse args
|
||||
self.server_args = server_args
|
||||
@@ -600,7 +599,6 @@ class Scheduler(
|
||||
|
||||
# Init metrics stats
|
||||
self.init_metrics(tp_rank, pp_rank, dp_rank)
|
||||
self.init_dp_balance(dp_balance_meta)
|
||||
|
||||
if self.enable_kv_cache_events:
|
||||
self.init_kv_events(server_args.kv_events_config)
|
||||
@@ -1270,8 +1268,6 @@ class Scheduler(
|
||||
self,
|
||||
recv_req: TokenizedGenerateReqInput,
|
||||
):
|
||||
self.maybe_update_dp_balance_data(recv_req)
|
||||
|
||||
# Create a new request
|
||||
if (
|
||||
recv_req.session_params is None
|
||||
@@ -1797,7 +1793,6 @@ class Scheduler(
|
||||
|
||||
# Handle DP attention
|
||||
if need_dp_attn_preparation:
|
||||
self.maybe_handle_dp_balance_data()
|
||||
ret = self.prepare_mlp_sync_batch(ret)
|
||||
|
||||
return ret
|
||||
@@ -2803,7 +2798,6 @@ def run_scheduler_process(
|
||||
pp_rank: int,
|
||||
dp_rank: Optional[int],
|
||||
pipe_writer,
|
||||
balance_meta: Optional[DPBalanceMeta] = None,
|
||||
):
|
||||
# Generate the logger prefix
|
||||
prefix = ""
|
||||
@@ -2852,7 +2846,6 @@ def run_scheduler_process(
|
||||
moe_ep_rank,
|
||||
pp_rank,
|
||||
dp_rank,
|
||||
dp_balance_meta=balance_meta,
|
||||
)
|
||||
pipe_writer.send(
|
||||
{
|
||||
|
||||
@@ -12,7 +12,6 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
||||
from sglang.srt.managers.schedule_policy import PrefillAdder
|
||||
from sglang.srt.managers.scheduler import Req, ScheduleBatch
|
||||
from sglang.srt.managers.utils import DPBalanceMeta
|
||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
@@ -64,16 +63,6 @@ class SchedulerMetricsMixin:
|
||||
labels["dp_rank"] = dp_rank
|
||||
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
|
||||
|
||||
def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
|
||||
self.balance_meta = dp_balance_meta
|
||||
if (
|
||||
self.server_args.enable_dp_attention
|
||||
and self.server_args.load_balance_method == "minimum_tokens"
|
||||
):
|
||||
assert dp_balance_meta is not None
|
||||
|
||||
self.recv_dp_balance_id_this_term = []
|
||||
|
||||
def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
|
||||
if self.enable_kv_cache_events:
|
||||
self.kv_event_publisher = EventPublisherFactory.create(
|
||||
@@ -319,91 +308,6 @@ class SchedulerMetricsMixin:
|
||||
batch = KVEventBatch(ts=time.time(), events=events)
|
||||
self.kv_event_publisher.publish(batch)
|
||||
|
||||
def maybe_update_dp_balance_data(
|
||||
self: Scheduler, recv_req: TokenizedGenerateReqInput
|
||||
):
|
||||
if (
|
||||
self.server_args.enable_dp_attention
|
||||
and self.server_args.load_balance_method == "minimum_tokens"
|
||||
):
|
||||
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
|
||||
|
||||
def maybe_handle_dp_balance_data(self: Scheduler):
|
||||
if (
|
||||
self.server_args.load_balance_method == "minimum_tokens"
|
||||
and self.forward_ct % 40 == 0
|
||||
):
|
||||
holding_tokens = self.get_load().num_tokens
|
||||
|
||||
new_recv_dp_balance_id_list, holding_token_list = (
|
||||
self.gather_dp_balance_info(holding_tokens)
|
||||
)
|
||||
|
||||
self.recv_dp_balance_id_this_term.clear()
|
||||
if self.tp_rank == 0: # only first worker write info
|
||||
self.write_shared_dp_balance_info(
|
||||
new_recv_dp_balance_id_list, holding_token_list
|
||||
)
|
||||
|
||||
def gather_dp_balance_info(
|
||||
self: Scheduler, holding_tokens_list
|
||||
) -> Union[None, List[List[int]]]:
|
||||
"""gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
|
||||
recv_list = self.recv_dp_balance_id_this_term
|
||||
assert len(recv_list) <= 511, (
|
||||
"The number of requests received this round is too large. "
|
||||
"Please increase gather_tensor_size and onfly_info_size."
|
||||
)
|
||||
# The maximum size of the tensor used for gathering data from all workers.
|
||||
gather_tensor_size = 512
|
||||
|
||||
# recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
|
||||
recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
|
||||
recv_tensor[0] = holding_tokens_list
|
||||
recv_tensor[1] = len(recv_list) # The first element is the length of the list.
|
||||
recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
|
||||
|
||||
if self.tp_rank == 0:
|
||||
gathered_list = [
|
||||
torch.zeros(gather_tensor_size, dtype=torch.int32)
|
||||
for _ in range(self.balance_meta.num_workers)
|
||||
]
|
||||
else:
|
||||
gathered_list = None
|
||||
|
||||
torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
|
||||
|
||||
gathered_id_list_per_worker = None
|
||||
if self.tp_rank == 0:
|
||||
gathered_id_list_per_worker = []
|
||||
holding_tokens_list = []
|
||||
for tensor in gathered_list:
|
||||
holding_tokens_list.append(tensor[0].item())
|
||||
list_length = tensor[1].item()
|
||||
gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
|
||||
|
||||
return gathered_id_list_per_worker, holding_tokens_list
|
||||
|
||||
def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
|
||||
meta = self.balance_meta
|
||||
|
||||
with meta.mutex:
|
||||
onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
|
||||
assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
|
||||
# 1.Check if the rid received by each worker this round is present in onfly.
|
||||
# If it is, remove the corresponding onfly item.
|
||||
worker_id = 0
|
||||
for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
|
||||
for new_recv_rid in new_recv_rids:
|
||||
assert (
|
||||
new_recv_rid in on_fly_reqs
|
||||
), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
|
||||
del on_fly_reqs[new_recv_rid]
|
||||
worker_id += 1
|
||||
# 2. Atomically write local_tokens and onfly into shm under the mutex
|
||||
meta.set_shared_onfly_info(onfly_list)
|
||||
meta.set_shared_local_tokens(local_tokens)
|
||||
|
||||
def calculate_utilization(self):
|
||||
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
||||
self.stats.utilization = -1
|
||||
|
||||
@@ -96,46 +96,3 @@ def get_logprob_from_pp_outputs(
|
||||
]
|
||||
|
||||
return logits_output, extend_input_len_per_req, extend_logprob_start_len_per_req
|
||||
|
||||
|
||||
class DPBalanceMeta:
|
||||
"""
|
||||
This class is used in the scheduler and the dp controller
|
||||
"""
|
||||
|
||||
def __init__(self, num_workers: int):
|
||||
self.num_workers = num_workers
|
||||
self._manager = mp.Manager()
|
||||
self.mutex = self._manager.Lock()
|
||||
|
||||
init_local_tokens = [0] * self.num_workers
|
||||
init_onfly_info = [self._manager.dict() for _ in range(self.num_workers)]
|
||||
|
||||
self.shared_state = self._manager.Namespace()
|
||||
self.shared_state.local_tokens = self._manager.list(init_local_tokens)
|
||||
self.shared_state.onfly_info = self._manager.list(init_onfly_info)
|
||||
|
||||
def destructor(self):
|
||||
# we must destructor this class manually
|
||||
self._manager.shutdown()
|
||||
|
||||
def get_shared_onfly(self) -> List[Dict[int, int]]:
|
||||
return [dict(d) for d in self.shared_state.onfly_info]
|
||||
|
||||
def set_shared_onfly_info(self, data: List[Dict[int, int]]):
|
||||
self.shared_state.onfly_info = data
|
||||
|
||||
def get_shared_local_tokens(self) -> List[int]:
|
||||
return list(self.shared_state.local_tokens)
|
||||
|
||||
def set_shared_local_tokens(self, data: List[int]):
|
||||
self.shared_state.local_tokens = data
|
||||
|
||||
def __getstate__(self):
|
||||
state = self.__dict__.copy()
|
||||
del state["_manager"]
|
||||
return state
|
||||
|
||||
def __setstate__(self, state):
|
||||
self.__dict__.update(state)
|
||||
self._manager = None
|
||||
|
||||
Reference in New Issue
Block a user