Improve DP attention (#4390)

Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Authored by Lianmin Zheng on 2025-03-13 08:23:56 -07:00, committed by GitHub
parent f141298a3c
commit 8e66fbecee
9 changed files with 345 additions and 226 deletions


@@ -997,7 +997,7 @@ class Scheduler(SchedulerOutputProcessorMixin):
         # Handle DP attention
         if self.server_args.enable_dp_attention:
-            ret = self.prepare_dp_attn_batch(ret)
+            ret, _ = self.prepare_dp_attn_batch(ret)

         return ret
@@ -1269,39 +1269,72 @@ class Scheduler(SchedulerOutputProcessorMixin):
         # Check if other DP workers have running batches
         if local_batch is None:
             num_tokens = 0
+            global_num_tokens_for_logprob = 0
         elif local_batch.forward_mode.is_decode():
             num_tokens = local_batch.batch_size()
+            if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
+                num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
+            global_num_tokens_for_logprob = num_tokens
         else:
             num_tokens = local_batch.extend_num_tokens
+            global_num_tokens_for_logprob = sum(
+                [
+                    # We should have at least 1 token for sample in every case.
+                    max(extend_len - logprob_start_len, 1)
+                    for logprob_start_len, extend_len in zip(
+                        local_batch.extend_logprob_start_lens, local_batch.extend_lens
+                    )
+                ]
+            )

-        local_num_tokens = torch.tensor([num_tokens], dtype=torch.int64)
-        global_num_tokens = torch.empty(self.tp_size, dtype=torch.int64)
+        if local_batch is None or local_batch.forward_mode.is_decode_or_idle():
+            can_cuda_graph = 1
+        else:
+            can_cuda_graph = 0
+
+        if not self.spec_algorithm.is_none():
+            # TODO(sang): Support cuda graph when idle batch is there.
+            if local_batch is None or local_batch.forward_mode.is_idle():
+                can_cuda_graph = 0
+
+        is_extend_in_batch = (
+            local_batch.forward_mode.is_extend() if local_batch else False
+        )
+        local_info = torch.tensor(
+            [
+                num_tokens,
+                can_cuda_graph,
+                global_num_tokens_for_logprob,
+                is_extend_in_batch,
+            ],
+            dtype=torch.int64,
+        )
+        global_info = torch.empty(
+            (self.server_args.dp_size, self.attn_tp_size, 4),
+            dtype=torch.int64,
+        )
         torch.distributed.all_gather_into_tensor(
-            global_num_tokens,
-            local_num_tokens,
+            global_info.flatten(),
+            local_info,
             group=self.tp_cpu_group,
         )
+        global_num_tokens = global_info[:, 0, 0].tolist()
+        can_cuda_graph = min(global_info[:, 0, 1].tolist())
+        global_num_tokens_for_logprob = global_info[:, 0, 2].tolist()
+        is_extend_in_batch = global_info[:, 0, 3].tolist()

-        if local_batch is None and global_num_tokens.max().item() > 0:
+        if local_batch is None and max(global_num_tokens) > 0:
             local_batch = self.get_idle_batch()

         if local_batch is not None:
-            local_batch.global_num_tokens = global_num_tokens.tolist()
+            local_batch.global_num_tokens = global_num_tokens
+            local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob

             # Check forward mode for cuda graph
             if not self.server_args.disable_cuda_graph:
-                forward_mode_state = torch.tensor(
-                    (1 if local_batch.forward_mode.is_decode_or_idle() else 0),
-                    dtype=torch.int32,
-                )
-                torch.distributed.all_reduce(
-                    forward_mode_state,
-                    op=torch.distributed.ReduceOp.MIN,
-                    group=self.tp_cpu_group,
-                )
-                local_batch.can_run_dp_cuda_graph = forward_mode_state.item() == 1
+                local_batch.can_run_dp_cuda_graph = can_cuda_graph

-        return local_batch
+        return local_batch, any(is_extend_in_batch)

     def get_idle_batch(self):
         idle_batch = ScheduleBatch.init_new(
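
Not part of the commit, a minimal sketch of the collective pattern the new code adopts: each DP rank packs its per-step scalars (token count, CUDA-graph eligibility, logprob token count, extend flag) into one int64 tensor, and a single all_gather_into_tensor on a CPU (gloo) process group replaces the separate all-gather plus ReduceOp.MIN all-reduce. The helper name gather_batch_info and the single-process demo setup are illustrative assumptions, not SGLang APIs.

# Minimal sketch (assumed helper, not SGLang code): pack per-rank scalars into
# one int64 tensor and exchange them with a single all_gather_into_tensor call.
import os

import torch
import torch.distributed as dist


def gather_batch_info(num_tokens, can_cuda_graph, num_tokens_for_logprob,
                      is_extend, dp_size, attn_tp_size, group=None):
    # Booleans become 0/1 when packed into the int64 tensor.
    local_info = torch.tensor(
        [num_tokens, can_cuda_graph, num_tokens_for_logprob, is_extend],
        dtype=torch.int64,
    )
    # One row per (dp rank, attention-tp rank); columns mirror local_info.
    global_info = torch.empty((dp_size, attn_tp_size, 4), dtype=torch.int64)
    dist.all_gather_into_tensor(global_info.flatten(), local_info, group=group)
    # Read attention-tp rank 0 of every DP rank, one column per quantity.
    return (
        global_info[:, 0, 0].tolist(),            # num_tokens per DP rank
        min(global_info[:, 0, 1].tolist()),       # 1 only if every rank can use CUDA graphs
        global_info[:, 0, 2].tolist(),            # logprob token counts per DP rank
        [bool(x) for x in global_info[:, 0, 3]],  # extend (prefill) batch per DP rank
    )


if __name__ == "__main__":
    # Single-process gloo group so the sketch runs as-is.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(gather_batch_info(128, True, 128, False, dp_size=1, attn_tp_size=1))
    dist.destroy_process_group()

Packing the values means one gloo round trip per scheduler step instead of several, and taking min() over the gathered can_cuda_graph column reproduces the behavior of the removed ReduceOp.MIN all-reduce.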