fix torchair execute issue on padding data, and mtp padding logic (#1160)

### What this PR does / why we need it?
The former PR https://github.com/vllm-project/vllm-ascend/pull/736
select the valid token inside the `input_ids` and `position_ids` breaks
the necessary padding required by torchair. In this PR, we pending the
pad logic after the multimodal part.


Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
This commit is contained in:
Pleaplusone
2025-06-10 22:20:40 +08:00
committed by GitHub
parent 95414bae70
commit 291c216898
2 changed files with 9 additions and 6 deletions

View File

@@ -376,7 +376,10 @@ class AscendMLAMetadataBuilder:
seq_lens = seq_lens[:self._num_decode_tokens]
input_positions = input_positions[:self._num_decode_tokens]
block_table = block_table[:self._num_decode_tokens, ...]
if use_torchair_graph and self.runner.attn_state == AscendAttentionState.DecodeOnly:
if use_torchair_graph and self.runner.attn_state in [
AscendAttentionState.DecodeOnly,
AscendAttentionState.SpecDecoding
]:
num_seqs = len(seq_lens)
if graph_pad_size != 0:
pad_value = 1

View File

@@ -943,11 +943,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
input_ids = self.input_ids[:num_input_tokens]
if (envs_ascend.VLLM_ENABLE_MC2
or self.torchair_graph_enabled) and not with_prefill:
input_ids = self.input_ids[:padded_batch_size]
positions = self.positions[:padded_batch_size]
# prepare the MRoPE for mllm if using multimodal
num_input_tokens = total_num_scheduled_tokens
# _prepare_inputs may reorder the batch, so we must gather multi
@@ -985,6 +980,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else:
positions = self.positions[:num_input_tokens]
if (envs_ascend.VLLM_ENABLE_MC2
or self.torchair_graph_enabled) and not with_prefill:
input_ids = self.input_ids[:padded_batch_size]
positions = self.positions[:padded_batch_size]
# Run forward pass
with set_forward_context(attn_metadata,
self.vllm_config,