Let bench_one_batch support enable_dp_attention (#4058)

This commit is contained in:
fzyzcjy
2025-04-09 14:44:25 +08:00
committed by GitHub
parent 76c48a0913
commit 61970b08d8
2 changed files with 49 additions and 7 deletions

View File

@@ -1466,14 +1466,36 @@ class Scheduler(
self.send_to_tokenizer.send_pyobj(HealthCheckOutput())
def prepare_dp_attn_batch(self, local_batch: ScheduleBatch):
return self.prepare_dp_attn_batch_raw(
local_batch,
dp_size=self.server_args.dp_size,
attn_tp_size=self.attn_tp_size,
tp_cpu_group=self.tp_cpu_group,
get_idle_batch=self.get_idle_batch,
disable_cuda_graph=self.server_args.disable_cuda_graph,
spec_algorithm=self.spec_algorithm,
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
)
@staticmethod
def prepare_dp_attn_batch_raw(
local_batch: ScheduleBatch,
dp_size,
attn_tp_size: int,
tp_cpu_group,
get_idle_batch,
disable_cuda_graph: bool,
spec_algorithm,
speculative_num_draft_tokens,
):
# Check if other DP workers have running batches
if local_batch is None:
num_tokens = 0
global_num_tokens_for_logprob = 0
elif local_batch.forward_mode.is_decode():
num_tokens = local_batch.batch_size()
if not self.spec_algorithm.is_none() and self.spec_algorithm.is_eagle():
num_tokens = num_tokens * self.server_args.speculative_num_draft_tokens
if not spec_algorithm.is_none() and spec_algorithm.is_eagle():
num_tokens = num_tokens * speculative_num_draft_tokens
global_num_tokens_for_logprob = num_tokens
else:
num_tokens = local_batch.extend_num_tokens
@@ -1492,7 +1514,7 @@ class Scheduler(
else:
can_cuda_graph = 0
if not self.spec_algorithm.is_none():
if not spec_algorithm.is_none():
# TODO(sang): Support cuda graph when idle batch is there.
if local_batch is None or local_batch.forward_mode.is_idle():
can_cuda_graph = 0
@@ -1510,13 +1532,13 @@ class Scheduler(
dtype=torch.int64,
)
global_info = torch.empty(
(self.server_args.dp_size, self.attn_tp_size, 4),
(dp_size, attn_tp_size, 4),
dtype=torch.int64,
)
torch.distributed.all_gather_into_tensor(
global_info.flatten(),
local_info,
group=self.tp_cpu_group,
group=tp_cpu_group,
)
global_num_tokens = global_info[:, 0, 0].tolist()
can_cuda_graph = min(global_info[:, 0, 1].tolist())
@@ -1524,14 +1546,14 @@ class Scheduler(
is_extend_in_batch = global_info[:, 0, 3].tolist()
if local_batch is None and max(global_num_tokens) > 0:
local_batch = self.get_idle_batch()
local_batch = get_idle_batch()
if local_batch is not None:
local_batch.global_num_tokens = global_num_tokens
local_batch.global_num_tokens_for_logprob = global_num_tokens_for_logprob
# Check forward mode for cuda graph
if not self.server_args.disable_cuda_graph:
if not disable_cuda_graph:
local_batch.can_run_dp_cuda_graph = can_cuda_graph
return local_batch, any(is_extend_in_batch)