diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index c933fa7..b2f730a 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -348,13 +348,38 @@ class NPUModelRunner(LoRAModelRunnerMixin): torch._logging.set_logs( recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES) + self.check_batch_sizes_consistency() # NOTE: we need to use `in_profile_run` to determine whether `enable_force_load_balance` is True self.in_profile_run = False # kv role self.is_kv_producer = False + self.is_kv_consumer = False if vllm_config.kv_transfer_config is not None: self.is_kv_producer = vllm_config.kv_transfer_config.is_kv_producer + self.is_kv_consumer = vllm_config.kv_transfer_config.is_kv_consumer + + def check_batch_sizes_consistency(self) -> None: + if not dist.is_initialized(): + return + + local = torch.tensor(self.torchair_graph_batch_sizes, + device="cpu", + dtype=torch.int32) + gathered_graph_batch_size = local.clone() + dist.all_reduce(gathered_graph_batch_size, + group=get_dp_group().cpu_group) + expected = local * self.dp_size + + if not torch.equal(gathered_graph_batch_size, expected): + diff_idxs = (gathered_graph_batch_size != expected).nonzero( + as_tuple=False).flatten().tolist() + raise AssertionError( + f"[Graph BatchSize Mismatch] Found mismatches at indices {diff_idxs}.\n" + f"Local (rank {self.dp_rank}): {local.tolist()}\n" + f"Sum over ranks: {gathered_graph_batch_size.tolist()}\n" + f"Expected if all equal: {[v * self.dp_size for v in local.tolist()]}" + ) def _update_states(self, scheduler_output: "SchedulerOutput") -> None: """Update the cached states and the persistent batch with the scheduler @@ -570,44 +595,58 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_batch.refresh_sampling_metadata() def _get_forward_metadata_across_dp( - self, - maybe_padded_num_tokens: int, - num_tokens: int, - with_prefill: bool, - enable_dbo: bool = False, + self, num_tokens: int, with_prefill: bool, + enable_dbo: bool) -> tuple[torch.Tensor, bool, bool]: + + # Compose: all_reduce metadata (num_tokens of each rank, with_prefill, enable_dbo) + num_tokens_across_dp = torch.zeros(self.dp_size + 2, + dtype=torch.int32, + device="cpu") + num_tokens_across_dp[self.dp_rank] = num_tokens + num_tokens_across_dp[-2] = int(with_prefill) + num_tokens_across_dp[-1] = int(not enable_dbo) + dist.all_reduce(num_tokens_across_dp, group=get_dp_group().cpu_group) + with_prefill = bool(num_tokens_across_dp[-2]) + enable_dbo = not bool(num_tokens_across_dp[-1]) + num_tokens_across_dp = num_tokens_across_dp[:-2] + return num_tokens_across_dp, with_prefill, enable_dbo + + def _get_forward_metadata_across_dp_and_pad( + self, num_tokens: int, with_prefill: bool, enable_dbo: bool ) -> tuple[int, Optional[torch.Tensor], bool, bool]: if self.dp_size == 1: - return maybe_padded_num_tokens, None, with_prefill, enable_dbo + return num_tokens, None, with_prefill, enable_dbo - num_tokens_across_dp = [0] * self.dp_size * 2 - num_tokens_across_dp[self.dp_rank] = maybe_padded_num_tokens - num_tokens_across_dp[self.dp_size + self.dp_rank] = num_tokens - forward_metadata = torch.tensor(num_tokens_across_dp + - [with_prefill, not enable_dbo], - device="cpu", - dtype=torch.int32) - dist.all_reduce(forward_metadata, group=get_dp_group().cpu_group) - with_prefill = bool(forward_metadata[-2]) + if self.is_kv_producer and not envs_ascend.VLLM_ASCEND_ENABLE_CHUNK_MC2: + num_tokens_across_dp = torch.tensor([num_tokens] * self.dp_size, + device="cpu", + dtype=torch.int32) + return num_tokens, num_tokens_across_dp, True, enable_dbo - # NOTE: when with_prefill is false before all_reduce and true after all_reduce, we need to revert pad. - if with_prefill: - num_tokens_across_dp = forward_metadata[self.dp_size:self.dp_size * - 2] - maybe_padded_num_tokens = num_tokens - else: - num_tokens_across_dp = forward_metadata[:self.dp_size] - - # NOTE: when in torchair_graph_mode, we need to pad local_num_tokens to - # `max_tokens_across_dp`, in other situation it is not necessary. - if self.torchair_graph_enabled and not with_prefill: - maybe_padded_num_tokens = torch.max(num_tokens_across_dp).item() - num_tokens_across_dp = torch.tensor([maybe_padded_num_tokens] * + if self.is_kv_consumer and self.torchair_graph_enabled and len( + self.torchair_graph_batch_sizes + ) == 1 and not self.in_profile_run: + max_num_decode_tokens = self.torchair_graph_batch_sizes[0] + num_tokens_across_dp = torch.tensor([max_num_decode_tokens] * self.dp_size, device="cpu", dtype=torch.int32) + return max_num_decode_tokens, num_tokens_across_dp, False, enable_dbo - return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, not bool( - forward_metadata[-1]) + maybe_padded_num_tokens = num_tokens + num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp( + num_tokens, with_prefill, enable_dbo) + + if self.torchair_graph_enabled and not with_prefill: + max_num_token = num_tokens_across_dp.max().item() + maybe_padded_num_tokens = self.select_torchair_padded_batch_size( + max_num_token) + num_tokens_across_dp = torch.full((self.dp_size, ), + maybe_padded_num_tokens, + dtype=torch.int32, + device="cpu") + + return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo def _check_dbo_is_valid(self, query_lens: torch.Tensor, attn_state: AscendAttentionState, @@ -1108,16 +1147,13 @@ class NPUModelRunner(LoRAModelRunnerMixin): attn_state, total_num_scheduled_tokens) - maybe_padded_num_tokens = total_num_scheduled_tokens - if self.torchair_graph_enabled and not with_prefill: - maybe_padded_num_tokens = self.select_torchair_padded_batch_size( - total_num_scheduled_tokens) + enable_dbo = self._check_dbo_is_valid(self.query_lens.tolist(), + attn_state, + total_num_scheduled_tokens) (padded_num_tokens_across_dp, num_tokens_across_dp, with_prefill, - enable_dbo) = self._get_forward_metadata_across_dp( - maybe_padded_num_tokens, total_num_scheduled_tokens, with_prefill, - enable_dbo) + enable_dbo) = self._get_forward_metadata_across_dp_and_pad( + total_num_scheduled_tokens, with_prefill, enable_dbo) extra_builder_kwargs['enable_dbo_across_dp'] = enable_dbo - if self.torchair_graph_enabled and not with_prefill: graph_pad_size = padded_num_tokens_across_dp - total_num_scheduled_tokens @@ -1791,16 +1827,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): with_prefill: bool = False, is_torchair_compile: bool = False, ) -> torch.Tensor: - maybe_padded_num_tokens = num_tokens - if self.torchair_graph_enabled and not with_prefill: - maybe_padded_num_tokens = self.select_torchair_padded_batch_size( - num_tokens) - # Padding for DP (num_tokens, num_tokens_across_dp, with_prefill, - _) = self._get_forward_metadata_across_dp(maybe_padded_num_tokens, - num_tokens, with_prefill, - False) + _) = self._get_forward_metadata_across_dp_and_pad( + num_tokens, with_prefill, False) # Set num_scheduled_tokens based on num_tokens and max_num_seqs # for dummy run with LoRA so that the num_reqs collectively