diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index 2001674..845e88e 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -17,6 +17,8 @@ # Adapted from vllm-project/vllm/vllm/worker/gpu_model_runner.py # +from typing import Optional + import torch from vllm.config import VllmConfig @@ -27,3 +29,29 @@ class NPUTorchairModelRunner(NPUModelRunner): def __init__(self, vllm_config: VllmConfig, device: torch.device): super().__init__(vllm_config, device) + + def _get_forward_metadata_across_dp_and_pad( + self, num_tokens: int, with_prefill: bool, enable_dbo: bool + ) -> tuple[int, Optional[torch.Tensor], bool, bool]: + if self.dp_size == 1: + if not with_prefill: + maybe_padded_num_tokens = self.select_torchair_padded_batch_size( + num_tokens) + return maybe_padded_num_tokens, None, with_prefill, enable_dbo + return num_tokens, None, with_prefill, enable_dbo + + num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp( + num_tokens, with_prefill, enable_dbo) + + if not with_prefill: + max_num_token = num_tokens_across_dp.max().item() + maybe_padded_num_tokens = self.select_torchair_padded_batch_size( + max_num_token) + num_tokens_across_dp = torch.full((self.dp_size, ), + maybe_padded_num_tokens, + dtype=torch.int32, + device="cpu") + else: + maybe_padded_num_tokens = num_tokens + + return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index ba1657c..d3a2985 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -640,26 +640,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): self, num_tokens: int, with_prefill: bool, enable_dbo: bool ) -> tuple[int, Optional[torch.Tensor], bool, bool]: if self.dp_size == 1: - if self.torchair_graph_enabled and not with_prefill: - maybe_padded_num_tokens = self.select_torchair_padded_batch_size( - num_tokens) - return maybe_padded_num_tokens, None, with_prefill, enable_dbo return num_tokens, None, with_prefill, enable_dbo - maybe_padded_num_tokens = num_tokens num_tokens_across_dp, with_prefill, enable_dbo = self._get_forward_metadata_across_dp( num_tokens, with_prefill, enable_dbo) - - if self.torchair_graph_enabled and not with_prefill: - max_num_token = num_tokens_across_dp.max().item() - maybe_padded_num_tokens = self.select_torchair_padded_batch_size( - max_num_token) - num_tokens_across_dp = torch.full((self.dp_size, ), - maybe_padded_num_tokens, - dtype=torch.int32, - device="cpu") - - return maybe_padded_num_tokens, num_tokens_across_dp, with_prefill, enable_dbo + return num_tokens, num_tokens_across_dp, with_prefill, enable_dbo def _check_dbo_is_valid(self, query_lens: torch.Tensor, attn_state: AscendAttentionState,