[1/N] DP-refactor: move dp balance code into scheduler's mixin class (#10004)
This commit is contained in:
@@ -500,6 +500,7 @@ class Scheduler(
|
|||||||
# Init metrics stats
|
# Init metrics stats
|
||||||
self.init_metrics(tp_rank, pp_rank, dp_rank)
|
self.init_metrics(tp_rank, pp_rank, dp_rank)
|
||||||
self.init_kv_events(server_args.kv_events_config)
|
self.init_kv_events(server_args.kv_events_config)
|
||||||
|
self.init_dp_balance(dp_balance_meta)
|
||||||
|
|
||||||
# Init disaggregation
|
# Init disaggregation
|
||||||
self.disaggregation_mode = DisaggregationMode(
|
self.disaggregation_mode = DisaggregationMode(
|
||||||
@@ -545,15 +546,6 @@ class Scheduler(
|
|||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.balance_meta = dp_balance_meta
|
|
||||||
if (
|
|
||||||
server_args.enable_dp_attention
|
|
||||||
and server_args.load_balance_method == "minimum_tokens"
|
|
||||||
):
|
|
||||||
assert dp_balance_meta is not None
|
|
||||||
|
|
||||||
self.recv_dp_balance_id_this_term = []
|
|
||||||
|
|
||||||
def init_tokenizer(self):
|
def init_tokenizer(self):
|
||||||
server_args = self.server_args
|
server_args = self.server_args
|
||||||
self.is_generation = self.model_config.is_generation
|
self.is_generation = self.model_config.is_generation
|
||||||
@@ -1126,11 +1118,7 @@ class Scheduler(
|
|||||||
self,
|
self,
|
||||||
recv_req: TokenizedGenerateReqInput,
|
recv_req: TokenizedGenerateReqInput,
|
||||||
):
|
):
|
||||||
if (
|
self.maybe_update_dp_balance_data(recv_req)
|
||||||
self.server_args.enable_dp_attention
|
|
||||||
and self.server_args.load_balance_method == "minimum_tokens"
|
|
||||||
):
|
|
||||||
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
|
|
||||||
|
|
||||||
# Create a new request
|
# Create a new request
|
||||||
if (
|
if (
|
||||||
@@ -1568,11 +1556,7 @@ class Scheduler(
|
|||||||
|
|
||||||
# Handle DP attention
|
# Handle DP attention
|
||||||
if need_dp_attn_preparation:
|
if need_dp_attn_preparation:
|
||||||
if (
|
self.maybe_handle_dp_balance_data()
|
||||||
self.server_args.load_balance_method == "minimum_tokens"
|
|
||||||
and self.forward_ct % 40 == 0
|
|
||||||
):
|
|
||||||
self.handle_dp_balance_data(ret)
|
|
||||||
ret = self.prepare_mlp_sync_batch(ret)
|
ret = self.prepare_mlp_sync_batch(ret)
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
@@ -1897,86 +1881,6 @@ class Scheduler(
|
|||||||
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
|
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
|
||||||
)
|
)
|
||||||
|
|
||||||
def handle_dp_balance_data(self, local_batch: ScheduleBatch):
|
|
||||||
def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
|
|
||||||
"""gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
|
|
||||||
recv_list = self.recv_dp_balance_id_this_term
|
|
||||||
assert len(recv_list) <= 511, (
|
|
||||||
"The number of requests received this round is too large. "
|
|
||||||
"Please increase gather_tensor_size and onfly_info_size."
|
|
||||||
)
|
|
||||||
# The maximum size of the tensor used for gathering data from all workers.
|
|
||||||
gather_tensor_size = 512
|
|
||||||
|
|
||||||
# recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
|
|
||||||
recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
|
|
||||||
recv_tensor[0] = holding_tokens_list
|
|
||||||
recv_tensor[1] = len(
|
|
||||||
recv_list
|
|
||||||
) # The first element is the length of the list.
|
|
||||||
recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
|
|
||||||
recv_list, dtype=torch.int32
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.tp_rank == 0:
|
|
||||||
gathered_list = [
|
|
||||||
torch.zeros(gather_tensor_size, dtype=torch.int32)
|
|
||||||
for _ in range(self.balance_meta.num_workers)
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
gathered_list = None
|
|
||||||
|
|
||||||
torch.distributed.gather(
|
|
||||||
recv_tensor, gathered_list, group=self.tp_cpu_group
|
|
||||||
)
|
|
||||||
|
|
||||||
gathered_id_list_per_worker = None
|
|
||||||
if self.tp_rank == 0:
|
|
||||||
gathered_id_list_per_worker = []
|
|
||||||
holding_tokens_list = []
|
|
||||||
for tensor in gathered_list:
|
|
||||||
holding_tokens_list.append(tensor[0].item())
|
|
||||||
list_length = tensor[1].item()
|
|
||||||
gathered_id_list_per_worker.append(
|
|
||||||
tensor[2 : list_length + 2].tolist()
|
|
||||||
)
|
|
||||||
|
|
||||||
return gathered_id_list_per_worker, holding_tokens_list
|
|
||||||
|
|
||||||
def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
|
|
||||||
meta = self.balance_meta
|
|
||||||
|
|
||||||
with meta.mutex:
|
|
||||||
onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
|
|
||||||
assert len(new_recv_rid_lists) == len(
|
|
||||||
onfly_list
|
|
||||||
), "num_worker not equal"
|
|
||||||
# 1.Check if the rid received by each worker this round is present in onfly.
|
|
||||||
# If it is, remove the corresponding onfly item.
|
|
||||||
worker_id = 0
|
|
||||||
for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
|
|
||||||
for new_recv_rid in new_recv_rids:
|
|
||||||
assert (
|
|
||||||
new_recv_rid in on_fly_reqs
|
|
||||||
), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
|
|
||||||
del on_fly_reqs[new_recv_rid]
|
|
||||||
worker_id += 1
|
|
||||||
# 2. Atomically write local_tokens and onfly into shm under the mutex
|
|
||||||
meta.set_shared_onfly_info(onfly_list)
|
|
||||||
meta.set_shared_local_tokens(local_tokens)
|
|
||||||
|
|
||||||
holding_tokens = self.get_load()
|
|
||||||
|
|
||||||
new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
|
|
||||||
holding_tokens
|
|
||||||
)
|
|
||||||
|
|
||||||
self.recv_dp_balance_id_this_term.clear()
|
|
||||||
if self.tp_rank == 0: # only first worker write info
|
|
||||||
write_shared_dp_balance_info(
|
|
||||||
new_recv_dp_balance_id_list, holding_token_list
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def prepare_mlp_sync_batch_raw(
|
def prepare_mlp_sync_batch_raw(
|
||||||
local_batch: ScheduleBatch,
|
local_batch: ScheduleBatch,
|
||||||
|
|||||||
@@ -1,15 +1,24 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import List, Optional
|
from typing import TYPE_CHECKING, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
|
from sglang.srt.disaggregation.kv_events import EventPublisherFactory, KVEventBatch
|
||||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||||
|
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
||||||
from sglang.srt.managers.schedule_policy import PrefillAdder
|
from sglang.srt.managers.schedule_policy import PrefillAdder
|
||||||
from sglang.srt.managers.scheduler import Req, ScheduleBatch
|
from sglang.srt.managers.scheduler import Req, ScheduleBatch
|
||||||
|
from sglang.srt.managers.utils import DPBalanceMeta
|
||||||
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
from sglang.srt.metrics.collector import SchedulerMetricsCollector, SchedulerStats
|
||||||
from sglang.srt.utils import get_bool_env_var
|
from sglang.srt.utils import get_bool_env_var
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from sglang.srt.managers.scheduler import Scheduler
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
|
RECORD_STEP_TIME = get_bool_env_var("SGLANG_RECORD_STEP_TIME")
|
||||||
@@ -28,7 +37,9 @@ class KvMetrics:
|
|||||||
|
|
||||||
|
|
||||||
class SchedulerMetricsMixin:
|
class SchedulerMetricsMixin:
|
||||||
def init_metrics(self, tp_rank: int, pp_rank: int, dp_rank: Optional[int]):
|
def init_metrics(
|
||||||
|
self: Scheduler, tp_rank: int, pp_rank: int, dp_rank: Optional[int]
|
||||||
|
):
|
||||||
self.last_gen_throughput: float = 0.0
|
self.last_gen_throughput: float = 0.0
|
||||||
self.last_input_throughput: float = 0.0
|
self.last_input_throughput: float = 0.0
|
||||||
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
|
self.step_time_dict = defaultdict(list) # Dict[batch size -> step time]
|
||||||
@@ -50,14 +61,24 @@ class SchedulerMetricsMixin:
|
|||||||
labels["dp_rank"] = dp_rank
|
labels["dp_rank"] = dp_rank
|
||||||
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
|
self.metrics_collector = SchedulerMetricsCollector(labels=labels)
|
||||||
|
|
||||||
def init_kv_events(self, kv_events_config: Optional[str]):
|
def init_dp_balance(self: Scheduler, dp_balance_meta: Optional[DPBalanceMeta]):
|
||||||
|
self.balance_meta = dp_balance_meta
|
||||||
|
if (
|
||||||
|
self.server_args.enable_dp_attention
|
||||||
|
and self.server_args.load_balance_method == "minimum_tokens"
|
||||||
|
):
|
||||||
|
assert dp_balance_meta is not None
|
||||||
|
|
||||||
|
self.recv_dp_balance_id_this_term = []
|
||||||
|
|
||||||
|
def init_kv_events(self: Scheduler, kv_events_config: Optional[str]):
|
||||||
if self.enable_kv_cache_events:
|
if self.enable_kv_cache_events:
|
||||||
self.kv_event_publisher = EventPublisherFactory.create(
|
self.kv_event_publisher = EventPublisherFactory.create(
|
||||||
kv_events_config, self.attn_dp_rank
|
kv_events_config, self.attn_dp_rank
|
||||||
)
|
)
|
||||||
|
|
||||||
def log_prefill_stats(
|
def log_prefill_stats(
|
||||||
self,
|
self: Scheduler,
|
||||||
adder: PrefillAdder,
|
adder: PrefillAdder,
|
||||||
can_run_list: List[Req],
|
can_run_list: List[Req],
|
||||||
running_bs: int,
|
running_bs: int,
|
||||||
@@ -138,7 +159,7 @@ class SchedulerMetricsMixin:
|
|||||||
self._publish_kv_events()
|
self._publish_kv_events()
|
||||||
|
|
||||||
def log_decode_stats(
|
def log_decode_stats(
|
||||||
self, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
|
self: Scheduler, can_run_cuda_graph: bool, running_batch: ScheduleBatch = None
|
||||||
):
|
):
|
||||||
batch = running_batch or self.running_batch
|
batch = running_batch or self.running_batch
|
||||||
|
|
||||||
@@ -220,7 +241,7 @@ class SchedulerMetricsMixin:
|
|||||||
self._emit_kv_metrics()
|
self._emit_kv_metrics()
|
||||||
self._publish_kv_events()
|
self._publish_kv_events()
|
||||||
|
|
||||||
def _emit_kv_metrics(self):
|
def _emit_kv_metrics(self: Scheduler):
|
||||||
kv_metrics = KvMetrics()
|
kv_metrics = KvMetrics()
|
||||||
kv_metrics.request_active_slots = self.stats.num_running_reqs
|
kv_metrics.request_active_slots = self.stats.num_running_reqs
|
||||||
kv_metrics.request_total_slots = self.max_running_requests
|
kv_metrics.request_total_slots = self.max_running_requests
|
||||||
@@ -236,9 +257,94 @@ class SchedulerMetricsMixin:
|
|||||||
if not self.send_metrics_from_scheduler.closed:
|
if not self.send_metrics_from_scheduler.closed:
|
||||||
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
|
self.send_metrics_from_scheduler.send_pyobj(kv_metrics)
|
||||||
|
|
||||||
def _publish_kv_events(self):
|
def _publish_kv_events(self: Scheduler):
|
||||||
if self.enable_kv_cache_events:
|
if self.enable_kv_cache_events:
|
||||||
events = self.tree_cache.take_events()
|
events = self.tree_cache.take_events()
|
||||||
if events:
|
if events:
|
||||||
batch = KVEventBatch(ts=time.time(), events=events)
|
batch = KVEventBatch(ts=time.time(), events=events)
|
||||||
self.kv_event_publisher.publish(batch)
|
self.kv_event_publisher.publish(batch)
|
||||||
|
|
||||||
|
def maybe_update_dp_balance_data(
|
||||||
|
self: Scheduler, recv_req: TokenizedGenerateReqInput
|
||||||
|
):
|
||||||
|
if (
|
||||||
|
self.server_args.enable_dp_attention
|
||||||
|
and self.server_args.load_balance_method == "minimum_tokens"
|
||||||
|
):
|
||||||
|
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
|
||||||
|
|
||||||
|
def maybe_handle_dp_balance_data(self: Scheduler):
|
||||||
|
if (
|
||||||
|
self.server_args.load_balance_method == "minimum_tokens"
|
||||||
|
and self.forward_ct % 40 == 0
|
||||||
|
):
|
||||||
|
holding_tokens = self.get_load()
|
||||||
|
|
||||||
|
new_recv_dp_balance_id_list, holding_token_list = (
|
||||||
|
self.gather_dp_balance_info(holding_tokens)
|
||||||
|
)
|
||||||
|
|
||||||
|
self.recv_dp_balance_id_this_term.clear()
|
||||||
|
if self.tp_rank == 0: # only first worker write info
|
||||||
|
self.write_shared_dp_balance_info(
|
||||||
|
new_recv_dp_balance_id_list, holding_token_list
|
||||||
|
)
|
||||||
|
|
||||||
|
def gather_dp_balance_info(
|
||||||
|
self: Scheduler, holding_tokens_list
|
||||||
|
) -> Union[None, List[List[int]]]:
|
||||||
|
"""gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
|
||||||
|
recv_list = self.recv_dp_balance_id_this_term
|
||||||
|
assert len(recv_list) <= 511, (
|
||||||
|
"The number of requests received this round is too large. "
|
||||||
|
"Please increase gather_tensor_size and onfly_info_size."
|
||||||
|
)
|
||||||
|
# The maximum size of the tensor used for gathering data from all workers.
|
||||||
|
gather_tensor_size = 512
|
||||||
|
|
||||||
|
# recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
|
||||||
|
recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
|
||||||
|
recv_tensor[0] = holding_tokens_list
|
||||||
|
recv_tensor[1] = len(recv_list) # The first element is the length of the list.
|
||||||
|
recv_tensor[2 : len(recv_list) + 2] = torch.tensor(recv_list, dtype=torch.int32)
|
||||||
|
|
||||||
|
if self.tp_rank == 0:
|
||||||
|
gathered_list = [
|
||||||
|
torch.zeros(gather_tensor_size, dtype=torch.int32)
|
||||||
|
for _ in range(self.balance_meta.num_workers)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
gathered_list = None
|
||||||
|
|
||||||
|
torch.distributed.gather(recv_tensor, gathered_list, group=self.tp_cpu_group)
|
||||||
|
|
||||||
|
gathered_id_list_per_worker = None
|
||||||
|
if self.tp_rank == 0:
|
||||||
|
gathered_id_list_per_worker = []
|
||||||
|
holding_tokens_list = []
|
||||||
|
for tensor in gathered_list:
|
||||||
|
holding_tokens_list.append(tensor[0].item())
|
||||||
|
list_length = tensor[1].item()
|
||||||
|
gathered_id_list_per_worker.append(tensor[2 : list_length + 2].tolist())
|
||||||
|
|
||||||
|
return gathered_id_list_per_worker, holding_tokens_list
|
||||||
|
|
||||||
|
def write_shared_dp_balance_info(self: Scheduler, new_recv_rid_lists, local_tokens):
|
||||||
|
meta = self.balance_meta
|
||||||
|
|
||||||
|
with meta.mutex:
|
||||||
|
onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
|
||||||
|
assert len(new_recv_rid_lists) == len(onfly_list), "num_worker not equal"
|
||||||
|
# 1.Check if the rid received by each worker this round is present in onfly.
|
||||||
|
# If it is, remove the corresponding onfly item.
|
||||||
|
worker_id = 0
|
||||||
|
for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
|
||||||
|
for new_recv_rid in new_recv_rids:
|
||||||
|
assert (
|
||||||
|
new_recv_rid in on_fly_reqs
|
||||||
|
), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
|
||||||
|
del on_fly_reqs[new_recv_rid]
|
||||||
|
worker_id += 1
|
||||||
|
# 2. Atomically write local_tokens and onfly into shm under the mutex
|
||||||
|
meta.set_shared_onfly_info(onfly_list)
|
||||||
|
meta.set_shared_local_tokens(local_tokens)
|
||||||
|
|||||||
Reference in New Issue
Block a user