feat: mtp support dp-attention (#6081)

Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
Author: u4lr451
Date: 2025-06-17 15:33:28 +08:00 (committed by GitHub)
Parent: 8a10c4c3d9
Commit: 10d60cd41b
22 changed files with 641 additions and 151 deletions


@@ -1350,6 +1350,29 @@ class Scheduler(
             self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()
 
+    def coordinate_spec_dp_attn_batch(self, new_batch: Optional[ScheduleBatch]):
+        """Coordinate the DP attention batch."""
+
+        local_info = torch.tensor(
+            [
+                (new_batch is not None),
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 1),
+            dtype=torch.int64,
+        )
+        torch.distributed.all_gather_into_tensor(
+            global_info.flatten(),
+            local_info,
+            group=self.tp_cpu_group,
+        )
+        any_new_batch = any(
+            global_info[:, 0, 0].tolist()
+        )  # Any DP worker has forward batch
+        return any_new_batch
+
     def get_next_batch_to_run(self) -> Optional[ScheduleBatch]:
         # Merge the prefill batch into the running batch
         chunked_req_to_exclude = set()
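Note on the hunk above: the new method all-gathers a single int64 flag over the TP CPU group, so every DP rank learns whether any rank produced a new prefill batch. A minimal standalone sketch of the same flag exchange (hypothetical helper name; the process group and world size are assumptions, not the Scheduler's actual wiring):

import torch
import torch.distributed as dist

def any_rank_has_batch(has_local_batch: bool, cpu_group) -> bool:
    # Encode the local condition as a 0/1 int64 flag.
    local_flag = torch.tensor([int(has_local_batch)], dtype=torch.int64)
    # One slot per rank in the (assumed CPU/gloo) group.
    world_size = dist.get_world_size(group=cpu_group)
    all_flags = torch.empty(world_size, dtype=torch.int64)
    # Every rank contributes its flag and receives everyone else's.
    dist.all_gather_into_tensor(all_flags, local_flag, group=cpu_group)
    # True if any rank in the group has a batch to run.
    return bool(all_flags.any())

In the diff itself the gathered tensor is shaped (dp_size, attn_tp_size, 1) and only attention-TP rank 0 of each DP group is read (global_info[:, 0, 0]), since ranks within one attention-TP group agree on the answer.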
@@ -1383,7 +1406,14 @@ class Scheduler(
                 self.running_batch.merge_batch(self.last_batch)
 
         new_batch = self.get_new_batch_prefill()
-        if new_batch is not None:
+
+        # TODO(ch-wan): minor refactor is needed here to improve readability
+        any_new_batch = (
+            self.server_args.enable_dp_attention
+            and not self.spec_algorithm.is_none()
+            and self.coordinate_spec_dp_attn_batch(new_batch)
+        )
+        if new_batch is not None or any_new_batch:
             # Run prefill first if possible
             ret = new_batch
         else:
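The effect of the widened condition is that, when DP attention and a speculative algorithm are both enabled, a rank takes the prefill path as soon as any DP peer has a new batch, even if its own new_batch is None, keeping all ranks in lockstep for the collectives issued during the forward pass. A hedged restatement of the gate (hypothetical helper, not code from the diff):

def should_take_prefill_path(new_batch, server_args, spec_algorithm, any_dp_peer_has_batch) -> bool:
    # Mirrors the condition added above: the cross-rank flag only matters when
    # DP attention is on and a speculative algorithm (e.g. MTP/EAGLE) is active.
    any_new_batch = (
        server_args.enable_dp_attention
        and not spec_algorithm.is_none()
        and any_dp_peer_has_batch
    )
    return new_batch is not None or any_new_batch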
@@ -1732,8 +1762,6 @@ class Scheduler(
             num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
-            if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
-                num_tokens = num_tokens * speculative_num_draft_tokens
             num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
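For context, the two deleted lines used to inflate the decode token count by the draft length when an EAGLE-style speculative algorithm was active. With assumed example values (not from the diff), the old behaviour was:

# Illustration only; the numbers below are assumptions.
batch_size = 8                       # local_batch.batch_size()
speculative_num_draft_tokens = 4     # draft tokens verified per decode step
num_tokens = batch_size * speculative_num_draft_tokens  # 32 under the removed lines
# After this commit the scaling is no longer applied at this point.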
@@ -1809,6 +1837,7 @@ class Scheduler(
             local_batch.global_num_tokens_for_logprob = (
                 global_num_tokens_for_logprob
             )
+            local_batch.is_extend_in_batch = any(is_extend_in_batch)
             local_batch.tbo_split_seq_index = tbo_split_seq_index
             local_batch.global_forward_mode = global_forward_mode
@@ -1816,6 +1845,7 @@ class Scheduler(
             if not disable_cuda_graph:
                 local_batch.can_run_dp_cuda_graph = can_cuda_graph
 
+        # TODO(ch-wan): refactor: any(is_extend_in_batch) now is a part of local_batch. Remove it from here.
         return local_batch, any(is_extend_in_batch)
 
     def get_idle_batch(self):
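The last two hunks carry the same flag twice: it is now stored on the batch (local_batch.is_extend_in_batch) and still returned alongside it, which is exactly what the TODO asks to clean up. A sketch of that redundancy from a caller's point of view (the call site and method name are assumptions for illustration, not the diff's actual caller):

# Hypothetical call site; the real caller in scheduler.py may differ.
local_batch, is_extend_in_batch = self.prepare_dp_attn_batch(new_batch)
if local_batch is not None:
    # After this commit both expressions carry the same information.
    assert is_extend_in_batch == local_batch.is_extend_in_batch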