[Perf] Delete redundant operations in model_runner and forward_context (#3677)

### What this PR does / why we need it?

Remove redundant operations from `model_runner` and `forward_context`.
This optimization can significantly reduce the idle time (bubble) before
decoding when running models with small parameter counts (e.g.,
Qwen/Qwen2.5-0.5B).

When tested on 800I A2, the bubble is reduced from 3.8 ms to 2.8 ms:
Before
<img width="1655" height="696" alt="image"
src="https://github.com/user-attachments/assets/d7608e52-2438-46dd-8fc9-391fd6274495"
/>

After
<img width="1607" height="774" alt="image"
src="https://github.com/user-attachments/assets/56daf081-2dba-4d2e-99d4-e055187d9806"
/>

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main:
https://github.com/vllm-project/vllm/commit/releases/v0.11.1

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
realliujiaxu
2025-10-29 15:59:55 +08:00
committed by GitHub
parent 0d1859af08
commit 74191864b7
5 changed files with 34 additions and 25 deletions

View File

@@ -136,7 +136,7 @@ from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
AscendSocVersion, ProfileExecuteDuration,
enable_sp, get_ascend_soc_version, is_310p,
is_enable_nz, lmhead_tp_enable,
is_enable_nz, is_moe_model, lmhead_tp_enable,
prefill_context_parallel_enable,
vllm_version_is)
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
@@ -515,11 +515,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.in_profile_run = False
self._init_mc2_tokens_capacity()
self.reserved_mc2_mask = torch.zeros(
self.mc2_tokens_capacity,
dtype=torch.bool,
device=self.device,
)
if is_moe_model(vllm_config):
self.reserved_mc2_mask = torch.zeros(
self.mc2_tokens_capacity,
dtype=torch.bool,
device=self.device,
)
else:
self.reserved_mc2_mask = None
self.dynamic_eplb = self.ascend_config.dynamic_eplb or self.ascend_config.expert_map_record_path
if self.dynamic_eplb:
EPLBParamUtils.check_dynamic_eplb(self.ascend_config.dynamic_eplb)
@@ -1497,9 +1500,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.query_lens = torch.from_numpy(num_scheduled_tokens)
# Copy the tensors to the NPU.
self.input_ids[:total_num_scheduled_tokens].copy_(
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
self.positions_cpu[total_num_scheduled_tokens:num_input_tokens].zero_()
self.positions[:num_input_tokens].copy_(
self.positions_cpu[:num_input_tokens], non_blocking=True)
@@ -1521,16 +1522,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self._update_graph_pad_size(with_prefill, maybe_padded_num_tokens)
attn_metadata: dict[str, Any] = {}
# Prepare input_ids
token_indices = (positions_np +
req_indices * self.input_batch.token_ids_cpu.shape[1])
torch.index_select(self.input_batch.token_ids_cpu_tensor.flatten(),
0,
torch.from_numpy(token_indices),
out=self.input_ids_cpu[:total_num_scheduled_tokens])
# Copy the tensors to the NPU.
self._prepare_input_ids(total_num_scheduled_tokens, cu_num_tokens)
# _prepare_inputs may reorder the batch, so we must gather
# multi-modal outputs after that to ensure the correct order
if self.is_multimodal_model:
@@ -2075,7 +2066,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
)
def _select_moe_comm_method(self, num_tokens: int,
with_prefill: bool) -> MoECommType:
with_prefill: bool) -> Optional[MoECommType]:
"""1. If expert parallel is not enabled, we use all-gather since MC2 and all-to-all
are designed for expert parallelism.
2. If expert parallel is enabled, we need to consider the soc version and the
@@ -2098,6 +2089,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
Returns:
MoECommType: The selected MoE communication method.
"""
if not is_moe_model(self.vllm_config):
return None
soc_version = get_ascend_soc_version()
quant_type = getattr(self.vllm_config.model_config.hf_config,
'moe_quantize', None)