[feat] support minimum token load balance in dp attention (#7379)

This commit is contained in:
Guanhua Wang
2025-08-03 15:46:47 +08:00
committed by GitHub
parent b0add2da00
commit f7b2853ff8
8 changed files with 271 additions and 6 deletions

View File

@@ -126,7 +126,7 @@ from sglang.srt.managers.scheduler_update_weights_mixin import (
from sglang.srt.managers.session_controller import Session
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
from sglang.srt.managers.utils import validate_input_length
from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
@@ -203,6 +203,7 @@ class Scheduler(
moe_ep_rank: int,
pp_rank: int,
dp_rank: Optional[int],
dp_balance_meta: Optional[DPBalanceMeta] = None,
):
# Parse args
self.server_args = server_args
@@ -522,6 +523,15 @@ class Scheduler(
]
)
self.balance_meta = dp_balance_meta
if (
server_args.enable_dp_attention
and server_args.load_balance_method == "minimum_tokens"
):
assert dp_balance_meta is not None
self.recv_dp_balance_id_this_term = []
def init_tokenizer(self):
server_args = self.server_args
@@ -1049,6 +1059,12 @@ class Scheduler(
self,
recv_req: TokenizedGenerateReqInput,
):
if (
self.server_args.enable_dp_attention
and self.server_args.load_balance_method == "minimum_tokens"
):
self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
# Create a new request
if (
recv_req.session_params is None
@@ -1459,6 +1475,11 @@ class Scheduler(
# Handle DP attention
if need_dp_attn_preparation:
if (
self.server_args.load_balance_method == "minimum_tokens"
and self.forward_ct % 40 == 0
):
self.handle_dp_balance_data(ret)
ret = self.prepare_mlp_sync_batch(ret)
return ret
@@ -1786,6 +1807,86 @@ class Scheduler(
disable_overlap_schedule=self.server_args.disable_overlap_schedule,
)
def handle_dp_balance_data(self, local_batch: ScheduleBatch) -> None:
    """Synchronize DP load-balance bookkeeping across workers.

    Gathers, onto tp_rank 0, each worker's current token load plus the
    request ids (dp_balance_id) it received since the last sync, then has
    rank 0 reconcile those ids against the shared "onfly" table and publish
    the fresh per-worker token counts to shared memory.

    Args:
        local_batch: the batch being prepared; accepted for call-site
            symmetry but not read here (the data comes from scheduler state).
    """

    def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
        """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance

        NOTE(review): despite the ``_list`` name, ``holding_tokens_list`` is
        the local worker's scalar token count on input; it is rebound to a
        per-worker list on rank 0 before returning. The actual return value
        is a 2-tuple ``(gathered_id_list_per_worker, holding_tokens)`` —
        meaningful only on tp_rank 0, ``(None, <scalar>)`` elsewhere; the
        declared annotation above is stale.
        """
        recv_list = self.recv_dp_balance_id_this_term
        # 511 = gather_tensor_size - 1 slot reserved for the length header
        # below (slot 0 carries holding tokens, slot 1 the id count).
        assert len(recv_list) <= 511, (
            "The number of requests received this round is too large. "
            "Please increase gather_tensor_size and onfly_info_size."
        )
        # The maximum size of the tensor used for gathering data from all workers.
        gather_tensor_size = 512
        # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
        recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
        recv_tensor[0] = holding_tokens_list
        recv_tensor[1] = len(
            recv_list
        )  # Slot 1 holds how many ids follow, so rank 0 can slice them back out.
        recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
            recv_list, dtype=torch.int32
        )

        # Only the gather destination (rank 0) allocates receive buffers;
        # torch.distributed.gather requires gather_list=None on other ranks.
        if self.tp_rank == 0:
            gathered_list = [
                torch.zeros(gather_tensor_size, dtype=torch.int32)
                for _ in range(self.balance_meta.num_workers)
            ]
        else:
            gathered_list = None

        torch.distributed.gather(
            recv_tensor, gathered_list, group=self.tp_cpu_group
        )

        gathered_id_list_per_worker = None
        if self.tp_rank == 0:
            # Unpack each worker's fixed-layout tensor back into
            # (holding tokens, list of received dp_balance_ids).
            gathered_id_list_per_worker = []
            holding_tokens_list = []
            for tensor in gathered_list:
                holding_tokens_list.append(tensor[0].item())
                list_length = tensor[1].item()
                gathered_id_list_per_worker.append(
                    tensor[2 : list_length + 2].tolist()
                )

        return gathered_id_list_per_worker, holding_tokens_list

    def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens) -> None:
        """Reconcile received ids with the shared onfly table and publish loads.

        Runs only on tp_rank 0. ``new_recv_rid_lists`` is one id-list per
        worker; ``local_tokens`` is the per-worker holding-token list.
        """
        meta = self.balance_meta

        # Mutex guards the shared-memory onfly table against the process
        # that enqueues new requests (presumably the tokenizer/router side
        # — confirm against DPBalanceMeta's writers).
        with meta.mutex:
            onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
            assert len(new_recv_rid_lists) == len(
                onfly_list
            ), "num_worker not equal"
            # 1.Check if the rid received by each worker this round is present in onfly.
            # If it is, remove the corresponding onfly item.
            worker_id = 0
            for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
                for new_recv_rid in new_recv_rids:
                    # Every id a scheduler received must have been recorded
                    # as in-flight by the sender; a miss means the shared
                    # state diverged.
                    assert (
                        new_recv_rid in on_fly_reqs
                    ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
                    del on_fly_reqs[new_recv_rid]
                worker_id += 1

            # 2. Atomically write local_tokens and onfly into shm under the mutex
            meta.set_shared_onfly_info(onfly_list)
            meta.set_shared_local_tokens(local_tokens)

    holding_tokens = self.get_load()

    new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
        holding_tokens
    )

    # Ids gathered above are consumed; start accumulating for the next term.
    self.recv_dp_balance_id_this_term.clear()
    if self.tp_rank == 0:  # only first worker write info
        write_shared_dp_balance_info(
            new_recv_dp_balance_id_list, holding_token_list
        )
@staticmethod
def prepare_mlp_sync_batch_raw(
local_batch: ScheduleBatch,
@@ -2394,6 +2495,7 @@ def run_scheduler_process(
pp_rank: int,
dp_rank: Optional[int],
pipe_writer,
balance_meta: Optional[DPBalanceMeta] = None,
):
# Generate the prefix
prefix = ""
@@ -2427,7 +2529,14 @@ def run_scheduler_process(
# Create a scheduler and run the event loop
try:
scheduler = Scheduler(
server_args, port_args, gpu_id, tp_rank, moe_ep_rank, pp_rank, dp_rank
server_args,
port_args,
gpu_id,
tp_rank,
moe_ep_rank,
pp_rank,
dp_rank,
dp_balance_meta=balance_meta,
)
pipe_writer.send(
{