diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py
index 50f49e229..3027f704d 100644
--- a/python/sglang/srt/managers/scheduler.py
+++ b/python/sglang/srt/managers/scheduler.py
@@ -500,6 +500,7 @@ class Scheduler(
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
+        self.init_dp_balance(dp_balance_meta)
 
         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -545,15 +546,6 @@ class Scheduler(
             ]
         )
 
-        self.balance_meta = dp_balance_meta
-        if (
-            server_args.enable_dp_attention
-            and server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-
-        self.recv_dp_balance_id_this_term = []
-
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
@@ -1126,11 +1118,7 @@ class Scheduler(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+        self.maybe_update_dp_balance_data(recv_req)
 
         # Create a new request
         if (
@@ -1568,11 +1556,7 @@ class Scheduler(
 
         # Handle DP attention
        if need_dp_attn_preparation:
-            if (
-                self.server_args.load_balance_method == "minimum_tokens"
-                and self.forward_ct % 40 == 0
-            ):
-                self.handle_dp_balance_data(ret)
+            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)
 
         return ret
@@ -1897,86 +1881,6 @@ class Scheduler(
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )
 
-    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
-        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
-            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-            recv_list = self.recv_dp_balance_id_this_term
-            assert len(recv_list) <= 511, (
-                "The number of requests received this round is too large. "
-                "Please increase gather_tensor_size and onfly_info_size."
-            )
-            # The maximum size of the tensor used for gathering data from all workers.
-            gather_tensor_size = 512
-
-            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-            recv_tensor[0] = holding_tokens_list
-            recv_tensor[1] = len(
-                recv_list
-            )  # The first element is the length of the list.
-            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
-                recv_list, dtype=torch.int32
-            )
-
-            if self.tp_rank == 0:
-                gathered_list = [
-                    torch.zeros(gather_tensor_size, dtype=torch.int32)
-                    for _ in range(self.balance_meta.num_workers)
-                ]
-            else:
-                gathered_list = None
-
-            torch.distributed.gather(
-                recv_tensor, gathered_list, group=self.tp_cpu_group
-            )
-
-            gathered_id_list_per_worker = None
-            if self.tp_rank == 0:
-                gathered_id_list_per_worker = []
-                holding_tokens_list = []
-                for tensor in gathered_list:
-                    holding_tokens_list.append(tensor[0].item())
-                    list_length = tensor[1].item()
-                    gathered_id_list_per_worker.append(
-                        tensor[2 : list_length + 2].tolist()
-                    )
-
-            return gathered_id_list_per_worker, holding_tokens_list
-
-        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
-            meta = self.balance_meta
-
-            with meta.mutex:
-                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-                assert len(new_recv_rid_lists) == len(
-                    onfly_list
-                ), "num_worker not equal"
-                # 1.Check if the rid received by each worker this round is present in onfly.
-                # If it is, remove the corresponding onfly item.
-                worker_id = 0
-                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                    for new_recv_rid in new_recv_rids:
-                        assert (
-                            new_recv_rid in on_fly_reqs
-                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                        del on_fly_reqs[new_recv_rid]
-                    worker_id += 1
-                # 2. Atomically write local_tokens and onfly into shm under the mutex
-                meta.set_shared_onfly_info(onfly_list)
-                meta.set_shared_local_tokens(local_tokens)
-
-        holding_tokens = self.get_load()
-
-        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
-            holding_tokens
-        )
-
-        self.recv_dp_balance_id_this_term.clear()
-        if self.tp_rank == 0:  # only first worker write info
-            write_shared_dp_balance_info(
-                new_recv_dp_balance_id_list, holding_token_list
-            )
-
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
diff --git a/python/sglang/srt/managers/scheduler_metrics_mixin.py b/python/sglang/srt/managers/scheduler_metrics_mixin.py
index ccc61bd98..342cc83da 100644
--- a/python/sglang/srt/managers/scheduler_metrics_mixin.py
+++ b/python/sglang/srt/managers/scheduler_metrics_mixin.py
@@ -1,15 +1,24 @@
+from __future__ import annotations
+
 import logging
 import time
 from collections import defaultdict
-from typing import List, Optional
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
+
+import torch
 
 from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
 from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_policy import PrefillAdder
 from sglang.srt.managers.scheduler import Req, ScheduleBatch
+from sglang.srt.managers.utils import DPBalanceMeta
 from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
 from sglang.srt.utils import get_bool_env_var
 
+if TYPE_CHECKING:
+    from sglang.srt.managers.scheduler import Scheduler
+
 logger = logging.getLogger(__name__)
 
 RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
@@ -28,7 +37,9 @@ class KvMetrics:
 
 
 class SchedulerMetricsMixin:
-    def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
+    def init_metrics(
+        self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
+    ):
         self.last_gen_throughput: float = 0.0
         self.last_input_throughput: float = 0.0
         self.step_time_dict = defaultdict(list)  # Dict[batch size -> step time]
@@ -50,14 +61,24 @@ class SchedulerMetricsMixin:
             labels["dp_rank"] = dp_rank
         self.metrics_collector = SchedulerMetricsCollector(labels=labels)
 
-    def init_kv_events(self, kv_events_config: Optional[str]):
+    def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
+        self.balance_meta = dp_balance_meta
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            assert dp_balance_meta is not None
+
+        self.recv_dp_balance_id_this_term = []
+
+    def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
         if self.enable_kv_cache_events:
             self.kv_event_publisher = EventPublisherFactory.create(
                 kv_events_config, self.attn_dp_rank
             )
 
     def log_prefill_stats(
-        self,
+        self: Scheduler,
         adder: PrefillAdder,
         can_run_list: List[Req],
         running_bs: int,
@@ -138,7 +159,7 @@ class SchedulerMetricsMixin:
         self._publish_kv_events()
 
     def log_decode_stats(
-        self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
+        self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
     ):
         batch = running_batch or self.running_batch
 
@@ -220,7 +241,7 @@ class SchedulerMetricsMixin:
         self._emit_kv_metrics()
         self._publish_kv_events()
 
-    def _emit_kv_metrics(self):
+    def _emit_kv_metrics(self: Scheduler):
         kv_metrics = KvMetrics()
         kv_metrics.request_active_slots = self.stats.num_running_reqs
         kv_metrics.request_total_slots = self.max_running_requests
@@ -236,9 +257,94 @@ class SchedulerMetricsMixin:
         if not self.send_metrics_from_scheduler.closed:
             self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
 
-    def _publish_kv_events(self):
+    def _publish_kv_events(self: Scheduler):
         if self.enable_kv_cache_events:
             events = self.tree_cache.take_events()
             if events:
                 batch = KVEventBatch(ts=time.time(), events=events)
                 self.kv_event_publisher.publish(batch)
+
+    def maybe_update_dp_balance_data(
+        self: Scheduler, recv_req: TokenizedGenerateReqInput
+    ):
+        if (
+            self.server_args.enable_dp_attention
+            and self.server_args.load_balance_method == "minimum_tokens"
+        ):
+            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+
+    def maybe_handle_dp_balance_data(self: Scheduler):
+        if (
+            self.server_args.load_balance_method == "minimum_tokens"
+            and self.forward_ct % 40 == 0
+        ):
+            holding_tokens = self.get_load()
+
+            new_recv_dp_balance_id_list, holding_token_list = (
+                self.gather_dp_balance_info(holding_tokens)
+            )
+
+            self.recv_dp_balance_id_this_term.clear()
+            if self.tp_rank == 0:  # only first worker write info
+                self.write_shared_dp_balance_info(
+                    new_recv_dp_balance_id_list, holding_token_list
+                )
+
+    def gather_dp_balance_info(
+        self: Scheduler, holding_tokens_list
+    ) -> Union[None, List[List[int]]]:
+        """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
+        recv_list = self.recv_dp_balance_id_this_term
+        assert len(recv_list) <= 511, (
+            "The number of requests received this round is too large. "
+            "Please increase gather_tensor_size and onfly_info_size."
+        )
+        # The maximum size of the tensor used for gathering data from all workers.
+        gather_tensor_size = 512
+
+        # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
+        recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
+        recv_tensor[0] = holding_tokens_list
+        recv_tensor[1] = len(recv_list)  # The first element is the length of the list.
+        recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
+
+        if self.tp_rank == 0:
+            gathered_list = [
+                torch.zeros(gather_tensor_size, dtype=torch.int32)
+                for _ in range(self.balance_meta.num_workers)
+            ]
+        else:
+            gathered_list = None
+
+        torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
+
+        gathered_id_list_per_worker = None
+        if self.tp_rank == 0:
+            gathered_id_list_per_worker = []
+            holding_tokens_list = []
+            for tensor in gathered_list:
+                holding_tokens_list.append(tensor[0].item())
+                list_length = tensor[1].item()
+                gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
+
+        return gathered_id_list_per_worker, holding_tokens_list
+
+    def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
+        meta = self.balance_meta
+
+        with meta.mutex:
+            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
+            assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
+            # 1.Check if the rid received by each worker this round is present in onfly.
+            # If it is, remove the corresponding onfly item.
+            worker_id = 0
+            for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
+                for new_recv_rid in new_recv_rids:
+                    assert (
+                        new_recv_rid in on_fly_reqs
+                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
+                    del on_fly_reqs[new_recv_rid]
+                worker_id += 1
+            # 2. Atomically write local_tokens and onfly into shm under the mutex
+            meta.set_shared_onfly_info(onfly_list)
+            meta.set_shared_local_tokens(local_tokens)
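
Note on the gather layout: the sketch below, kept outside the diff, restates the packing convention that gather_dp_balance_info relies on for its fixed-size gather tensor (slot 0 = the worker's held tokens, slot 1 = the number of request ids received this term, slots 2..N+1 = the ids, remainder zero padding). The helper names pack_balance_tensor and unpack_balance_tensor are illustrative only and do not exist in the PR.

import torch
from typing import List, Tuple

GATHER_TENSOR_SIZE = 512  # mirrors gather_tensor_size in the diff above


def pack_balance_tensor(holding_tokens: int, recv_ids: List[int]) -> torch.Tensor:
    # Layout: | holding_tokens | len(recv_ids) | recv_ids... | zero padding |
    assert len(recv_ids) <= GATHER_TENSOR_SIZE - 2
    t = torch.zeros(GATHER_TENSOR_SIZE, dtype=torch.int32)
    t[0] = holding_tokens
    t[1] = len(recv_ids)
    t[2 : len(recv_ids) + 2] = torch.tensor(recv_ids, dtype=torch.int32)
    return t


def unpack_balance_tensor(t: torch.Tensor) -> Tuple[int, List[int]]:
    # Inverse of pack_balance_tensor: recover the load and the id list.
    holding_tokens = int(t[0].item())
    num_ids = int(t[1].item())
    return holding_tokens, t[2 : num_ids + 2].tolist()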