[v0.11.0][Perf] Delete redundant operations in model_runner and forward_context (#3775)

Cherry-pick of https://github.com/vllm-project/vllm-ascend/pull/3677.

Remove redundant operations from `model_runner` and `forward_context`.
This noticeably reduces the idle time (bubble) before decoding for models
with small parameter counts (e.g., Qwen/Qwen2.5-0.5B), where per-step
host-side overhead makes up a larger share of each step.
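
Most of the removed work is per-step host-side overhead whose answer never changes while a model is being served (attribute introspection, MoE-only buffer setup on dense models), so it can be computed once and reused. A minimal sketch of the pattern with hypothetical names; the concrete version in this PR is the `has_layer_idx` helper added to the utils file below:

```python
from typing import Optional

import torch

_CHECK_RESULT: Optional[bool] = None  # populated on the first step, reused afterwards


def model_exposes_start_layer(model: torch.nn.Module) -> bool:
    """Hypothetical helper: cache a per-model introspection result instead of
    re-running hasattr() checks on every forward step."""
    global _CHECK_RESULT
    if model is None:
        return False
    if _CHECK_RESULT is None:
        _CHECK_RESULT = hasattr(model, "model") and hasattr(model.model, "start_layer")
    return _CHECK_RESULT
```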

In testing on an 800I A2, the bubble shrinks from 3.8 ms to 2.8 ms:
Before
<img width="1655" height="696" alt="image"
src="https://github.com/user-attachments/assets/d7608e52-2438-46dd-8fc9-391fd6274495"
/>

After
<img width="1607" height="774" alt="image"
src="https://github.com/user-attachments/assets/56daf081-2dba-4d2e-99d4-e055187d9806"
/>

### What this PR does / why we need it?
Removes redundant per-step operations from `model_runner` and `forward_context`, as described above; see the profiling screenshots for the measured effect on the pre-decode bubble.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing unit tests for `_select_moe_comm_method` were updated to patch `is_moe_model` (see the test diff below), and the improvement was verified by profiling on 800I A2 (screenshots above).

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
Commit 29bd9235ed (parent 75de3fa172), authored by realliujiaxu and committed via GitHub on 2025-10-29 15:58:53 +08:00.
5 changed files with 34 additions and 25 deletions.

File 1 of 5 (unit tests for `_select_moe_comm_method`):

@@ -68,6 +68,8 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel,
     with patch('vllm_ascend.worker.model_runner_v1.get_ascend_soc_version',
                return_value=soc_version), \
          patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
-               return_value=True):
+               return_value=True), \
+         patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
+               return_value=True):
         # Bind the real method to the mock object

@@ -102,6 +104,8 @@ def test_select_moe_comm_method_unsupported_soc():
                return_value=unsupported_soc), \
          patch('vllm_ascend.worker.model_runner_v1.is_global_first_rank',
                return_value=True), \
+         patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
+               return_value=True), \
          pytest.raises(ValueError, match=f"Unsupported soc_version: {unsupported_soc}"):
         NPUModelRunner._select_moe_comm_method(mock_runner, 100, False)
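
The two hunks above patch `is_moe_model` to return True so the existing SOC-selection assertions keep exercising the MoE path. Not part of this diff, but a companion check for the new dense-model early return could look like this sketch (imports mirror the existing test module; the test name is hypothetical):

```python
from unittest.mock import MagicMock, patch

from vllm_ascend.worker.model_runner_v1 import NPUModelRunner


def test_select_moe_comm_method_dense_model():
    # For a dense model the runner should skip SOC/quantization inspection
    # entirely and return None (see the _select_moe_comm_method hunk below).
    mock_runner = MagicMock()
    with patch('vllm_ascend.worker.model_runner_v1.is_moe_model',
               return_value=False):
        assert NPUModelRunner._select_moe_comm_method(mock_runner, 100,
                                                      False) is None
```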

File 2 of 5 (`set_ascend_forward_context`):

@@ -11,7 +11,8 @@ from vllm.forward_context import (BatchDescriptor, get_forward_context,
                                   set_forward_context)

 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.utils import enable_sp, is_moe_model, version_check
+from vllm_ascend.utils import (enable_sp, has_layer_idx, is_moe_model,
+                               version_check)

 if TYPE_CHECKING:
     from vllm_ascend.ops.weight_prefetch import WeightPrefetchMethod

@@ -136,9 +137,7 @@ def set_ascend_forward_context(
         # set layer_idx to enable optimization features that depend on this information.
         # This is only applicable to models that contain these necessary attributes.
         forward_context.layer_idx = None
-        if model_instance is not None and \
-            hasattr(model_instance, "model") and \
-                hasattr(model_instance.model, "start_layer"):
+        if has_layer_idx(model_instance):
             forward_context.layer_idx = model_instance.model.start_layer

         # TODO(rjg-lyh): refactor mlp weight prefetch method

File 3 of 5 (MoE comm method registry, `get_moe_comm_method`):

@@ -39,7 +39,7 @@ _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}

 def get_moe_comm_method(
         moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
-    return _MoECommMethods.get(moe_comm_type)
+    return _MoECommMethods.get(moe_comm_type, None)


 def setup_moe_comm_method(moe_config):

File 4 of 5 (`vllm_ascend.utils`):

@@ -58,6 +58,7 @@ _DEFAULT_BUFFER_SIZE = 200
 _MIN_DP_BUFFER_SIZE = 50
 _IS_MOE_MODEL = None
 _ENABLE_SP = None
+_HAS_LAYER_IDX = None


 def is_310p():

@@ -785,3 +786,14 @@ def version_check():
     if full_date >= "20250919":
         return True
     return False
+
+
+def has_layer_idx(model_instance: torch.nn.Module) -> bool:
+    if model_instance is None:
+        return False
+
+    global _HAS_LAYER_IDX
+    if _HAS_LAYER_IDX is None:
+        _HAS_LAYER_IDX = hasattr(model_instance, "model") and \
+            hasattr(model_instance.model, "start_layer")
+    return _HAS_LAYER_IDX
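
Note that `_HAS_LAYER_IDX` is cached per process, not per model: the first call pins the answer, which is fine because a vLLM worker process serves a single model. A small illustration of that behaviour, assuming the helper lands in `vllm_ascend.utils` as the forward-context import above suggests:

```python
import types

import torch
from vllm_ascend.utils import has_layer_idx


class _WithStartLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Mimic the attribute shape the helper looks for.
        self.model = types.SimpleNamespace(start_layer=0)


# In a fresh process, the first call runs the hasattr() checks and caches True.
assert has_layer_idx(_WithStartLayer()) is True
# Later calls return the cached answer, even for a structurally different module.
assert has_layer_idx(torch.nn.Linear(2, 2)) is True
```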

File 5 of 5 (`vllm_ascend.worker.model_runner_v1`):

@@ -131,7 +131,7 @@ from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                AscendSocVersion, ProfileExecuteDuration,
                                enable_sp, get_ascend_soc_version, is_310p,
-                               is_enable_nz, lmhead_tp_enable)
+                               is_enable_nz, is_moe_model, lmhead_tp_enable)
 from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch

 if TYPE_CHECKING:

@@ -470,11 +470,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.in_profile_run = False

         self._init_mc2_tokens_capacity()
-        self.reserved_mc2_mask = torch.zeros(
-            self.mc2_tokens_capacity,
-            dtype=torch.bool,
-            device=self.device,
-        )
+        if is_moe_model(vllm_config):
+            self.reserved_mc2_mask = torch.zeros(
+                self.mc2_tokens_capacity,
+                dtype=torch.bool,
+                device=self.device,
+            )
+        else:
+            self.reserved_mc2_mask = None
         self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
         if self.dynamic_eplb:
             EPLBParamUtils.check_dynamic_eplb(self.ascend_config.dynamic_eplb)
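
With this change `reserved_mc2_mask` is `None` for dense models, so anything that consumes it has to tolerate the missing buffer. A hedged caller-side sketch (the helper name is illustrative, not from this diff):

```python
from typing import Optional

import torch


def slice_reserved_mc2_mask(reserved_mc2_mask: Optional[torch.Tensor],
                            num_tokens: int) -> Optional[torch.Tensor]:
    # Dense models no longer allocate the MC2 mask, so treat None as "not needed".
    if reserved_mc2_mask is None:
        return None
    return reserved_mc2_mask[:num_tokens]
```
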
@@ -1341,9 +1344,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self.query_lens = torch.from_numpy(num_scheduled_tokens)

         # Copy the tensors to the NPU.
-        self.input_ids[:total_num_scheduled_tokens].copy_(
-            self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
+        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)

         self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
         self.positions[:num_input_tokens].copy_(
             self.positions_cpu[:num_input_tokens], non_blocking=True)
@@ -1364,16 +1365,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
         attn_metadata: dict[str, Any] = {}

-        # Prepare input_ids
-        token_indices = (positions_np +
-                         req_indices * self.input_batch.token_ids_cpu.shape[1])
-        torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
-                           0,
-                           torch.from_numpy(token_indices),
-                           out=self.input_ids_cpu[:total_num_scheduled_tokens])
-        # Copy the tensors to the NPU.
-        self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
-
         # _prepare_inputs may reorder the batch, so we must gather
         # multi-modal outputs after that to ensure the correct order
         if self.is_multimodal_model:
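
The block removed here duplicated work that the `_prepare_input_ids` call earlier in the same function (previous hunk) already covers: a flat gather of the scheduled token ids from the per-request CPU buffer, followed by the host-to-device copy. A simplified sketch of that gather with illustrative parameter names (the real `_prepare_input_ids` also receives `cu_num_tokens`):

```python
import numpy as np
import torch


def gather_scheduled_token_ids(token_ids_cpu: torch.Tensor,  # [max_num_reqs, max_model_len]
                               positions_np: np.ndarray,     # per-token position within its request
                               req_indices: np.ndarray) -> torch.Tensor:
    # Each scheduled token lives at (req_indices[i], positions_np[i]) in the 2-D
    # CPU buffer; flattening the buffer turns the lookup into one index_select.
    flat_indices = torch.from_numpy(
        positions_np + req_indices * token_ids_cpu.shape[1])
    return torch.index_select(token_ids_cpu.flatten(), 0, flat_indices)
```
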
@@ -1835,7 +1826,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         )

     def _select_moe_comm_method(self, num_tokens: int,
-                                with_prefill: bool) -> MoECommType:
+                                with_prefill: bool) -> Optional[MoECommType]:
         """1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
         are designed for expert parallelism.
         2. If expert parallel is enabled, we need to consider the soc version and the

@@ -1858,6 +1849,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         Returns:
             MoECommType: The selected MoE communication method.
         """
+        if not is_moe_model(self.vllm_config):
+            return None
+
         soc_version = get_ascend_soc_version()
         quant_type = getattr(self.vllm_config.model_config.hf_config,
                              'moe_quantize', None)