[main][refactor] Refactor forward metadata retrieval across DP nodes to reduce redundant padding. (#2062)
Before refactoring cross-DP decoding metadata aggregation, clean up the
token-padding logic.
### What this PR does:
1. First checks whether any DP instance is in the prefill phase.
2. If in the `decode` phase and `torchair_graph_enabled` is true, pads
each DP instance's token count up to the global maximum.
3. If in the `prefill` phase, or in decode phase with graph mode
**disabled**, returns each DP instance’s original token count without
padding.
This reordering removes the previous two-step padding/unpadding flow and
ensures padding only occurs when strictly necessary.
- vLLM version: v0.10.0
- vLLM main:
bd3db7f469
Signed-off-by: yx0716 <jinyx1007@foxmail.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
@@ -348,13 +348,38 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
torch._logging.set_logs(
|
||||
recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
|
||||
|
||||
self.check_batch_sizes_consistency()
|
||||
# NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True
|
||||
self.in_profile_run = False
|
||||
|
||||
# kv role
|
||||
self.is_kv_producer = False
|
||||
self.is_kv_consumer = False
|
||||
if vllm_config.kv_transfer_config is not None:
|
||||
self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer
|
||||
self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer
|
||||
|
||||
def check_batch_sizes_consistency(self) -> None:
|
||||
if not dist.is_initialized():
|
||||
return
|
||||
|
||||
local = torch.tensor(self.torchair_graph_batch_sizes,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
gathered_graph_batch_size = local.clone()
|
||||
dist.all_reduce(gathered_graph_batch_size,
|
||||
group=get_dp_group().cpu_group)
|
||||
expected = local * self.dp_size
|
||||
|
||||
if not torch.equal(gathered_graph_batch_size, expected):
|
||||
diff_idxs = (gathered_graph_batch_size != expected).nonzero(
|
||||
as_tuple=False).flatten().tolist()
|
||||
raise AssertionError(
|
||||
f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n"
|
||||
f"Local (rank {self.dp_rank}): {local.tolist()}\n"
|
||||
f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n"
|
||||
f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}"
|
||||
)
|
||||
|
||||
def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
|
||||
"""Update the cached states and the persistent batch with the scheduler
|
||||
@@ -570,44 +595,58 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
|
||||
def _get_forward_metadata_across_dp(
|
||||
self,
|
||||
maybe_padded_num_tokens: int,
|
||||
num_tokens: int,
|
||||
with_prefill: bool,
|
||||
enable_dbo: bool = False,
|
||||
self, num_tokens: int, with_prefill: bool,
|
||||
enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]:
|
||||
|
||||
# Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo)
|
||||
num_tokens_across_dp = torch.zeros(self.dp_size + 2,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
num_tokens_across_dp[self.dp_rank] = num_tokens
|
||||
num_tokens_across_dp[-2] = int(with_prefill)
|
||||
num_tokens_across_dp[-1] = int(not enable_dbo)
|
||||
dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group)
|
||||
with_prefill = bool(num_tokens_across_dp[-2])
|
||||
enable_dbo = not bool(num_tokens_across_dp[-1])
|
||||
num_tokens_across_dp = num_tokens_across_dp[:-2]
|
||||
return num_tokens_across_dp, with_prefill, enable_dbo
|
||||
|
||||
def _get_forward_metadata_across_dp_and_pad(
|
||||
self, num_tokens: int, with_prefill: bool, enable_dbo: bool
|
||||
) -> tuple[int, Optional[torch.Tensor], bool, bool]:
|
||||
if self.dp_size == 1:
|
||||
return maybe_padded_num_tokens, None, with_prefill, enable_dbo
|
||||
return num_tokens, None, with_prefill, enable_dbo
|
||||
|
||||
num_tokens_across_dp = [0] * self.dp_size * 2
|
||||
num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens
|
||||
num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens
|
||||
forward_metadata = torch.tensor(num_tokens_across_dp +
|
||||
[with_prefill, not enable_dbo],
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group)
|
||||
with_prefill = bool(forward_metadata[-2])
|
||||
if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2:
|
||||
num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
return num_tokens, num_tokens_across_dp, True, enable_dbo
|
||||
|
||||
# NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad.
|
||||
if with_prefill:
|
||||
num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size *
|
||||
2]
|
||||
maybe_padded_num_tokens = num_tokens
|
||||
else:
|
||||
num_tokens_across_dp = forward_metadata[:self.dp_size]
|
||||
|
||||
# NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to
|
||||
# `max_tokens_across_dp`, in other situation it is not necessary.
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item()
|
||||
num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] *
|
||||
if self.is_kv_consumer and self.torchair_graph_enabled and len(
|
||||
self.torchair_graph_batch_sizes
|
||||
) == 1 and not self.in_profile_run:
|
||||
max_num_decode_tokens = self.torchair_graph_batch_sizes[0]
|
||||
num_tokens_across_dp = torch.tensor([max_num_decode_tokens] *
|
||||
self.dp_size,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo
|
||||
|
||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool(
|
||||
forward_metadata[-1])
|
||||
maybe_padded_num_tokens = num_tokens
|
||||
num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp(
|
||||
num_tokens, with_prefill, enable_dbo)
|
||||
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
max_num_token = num_tokens_across_dp.max().item()
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
max_num_token)
|
||||
num_tokens_across_dp = torch.full((self.dp_size, ),
|
||||
maybe_padded_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device="cpu")
|
||||
|
||||
return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo
|
||||
|
||||
def _check_dbo_is_valid(self, query_lens: torch.Tensor,
|
||||
attn_state: AscendAttentionState,
|
||||
@@ -1108,16 +1147,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
attn_state,
|
||||
total_num_scheduled_tokens)
|
||||
|
||||
maybe_padded_num_tokens = total_num_scheduled_tokens
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
total_num_scheduled_tokens)
|
||||
enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(),
|
||||
attn_state,
|
||||
total_num_scheduled_tokens)
|
||||
(padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill,
|
||||
enable_dbo) = self._get_forward_metadata_across_dp(
|
||||
maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill,
|
||||
enable_dbo)
|
||||
enable_dbo) = self._get_forward_metadata_across_dp_and_pad(
|
||||
total_num_scheduled_tokens, with_prefill, enable_dbo)
|
||||
extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo
|
||||
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens
|
||||
|
||||
@@ -1791,16 +1827,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
with_prefill: bool = False,
|
||||
is_torchair_compile: bool = False,
|
||||
) -> torch.Tensor:
|
||||
maybe_padded_num_tokens = num_tokens
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
maybe_padded_num_tokens = self.select_torchair_padded_batch_size(
|
||||
num_tokens)
|
||||
|
||||
# Padding for DP
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
_) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens,
|
||||
num_tokens, with_prefill,
|
||||
False)
|
||||
_) = self._get_forward_metadata_across_dp_and_pad(
|
||||
num_tokens, with_prefill, False)
|
||||
|
||||
# Set num_scheduled_tokens based on num_tokens and max_num_seqs
|
||||
# for dummy run with LoRA so that the num_reqs collectively
|
||||
|
||||
Reference in New Issue
Block a user