From 9547d6f0d972c56752dba97a1b13a089c04c0b3f Mon Sep 17 00:00:00 2001
From: Angazenn <92204292+Angazenn@users.noreply.github.com>
Date: Fri, 17 Oct 2025 21:56:01 +0800
Subject: [PATCH] [Core] Append padding logic for Attention (#3256)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
This PR adds padding logic for `seq_lens` and `block_tables` when running in the full decode scenario. Before this change, the number of input tokens after padding could exceed the number of tokens described by the corresponding `seq_lens`. For example, in the full decode scenario:
```
input_ids : [1, 3, 0, 0]
seq_lens: [2, 1]
query_start_loc: [0, 1, 2]
```
Here, `input_ids` is padded with 2 tokens while `seq_lens` and `query_start_loc` are not. The mismatch between `input_ids` and `seq_lens`/`query_start_loc` can cause subtle bugs. This PR changes the metadata to:
```
input_ids : [1, 3, 0, 0]
seq_lens: [2, 1, 1, 1]
query_start_loc: [0, 1, 2, 3, 4]
```
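For illustration, here is a minimal, self-contained sketch of the same padding steps (not part of the patch; the tensor values and batch sizes below are made up for the example):
```python
import torch

# Hypothetical decode-only batch: 2 real requests, padded to 4 input tokens
# (e.g. for graph capture). Values here are illustrative only.
num_actual_tokens = 2
num_input_tokens = 4
padded_num_tokens = num_input_tokens - num_actual_tokens

seq_lens = torch.tensor([2, 1], dtype=torch.int32)
block_table = torch.tensor([[7, 0], [3, 0]], dtype=torch.int32)
query_start_loc_cpu = torch.tensor([0, 1, 2], dtype=torch.int32)

# Treat every padded token as a dummy request of length 1.
seq_lens = torch.cat(
    [seq_lens, torch.ones(padded_num_tokens, dtype=seq_lens.dtype)])

# Give each dummy request an all-zero block-table row.
block_table_padding = torch.zeros(
    (padded_num_tokens, ) + block_table.shape[1:], dtype=block_table.dtype)
block_table = torch.cat([block_table, block_table_padding], dim=0)

# Extend query_start_loc so each dummy request spans exactly one token.
query_start_loc_cpu = torch.cat([
    query_start_loc_cpu,
    torch.arange(query_start_loc_cpu[-1] + 1,
                 query_start_loc_cpu[-1] + padded_num_tokens + 1,
                 dtype=query_start_loc_cpu.dtype)
])

print(seq_lens.tolist())             # [2, 1, 1, 1]
print(query_start_loc_cpu.tolist())  # [0, 1, 2, 3, 4]
```
Each padded slot is modeled as a dummy length-1 request with an all-zero block-table row, so downstream attention kernels see mutually consistent `seq_lens`, `block_tables` and `query_start_loc` without special-casing the padding.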
### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: Angazenn
---
 vllm_ascend/attention/attention_v1.py | 23 +++++++++++++++++++++++
 vllm_ascend/attention/mla_v1.py       |  1 +
 vllm_ascend/attention/sfa_v1.py       |  1 +
 vllm_ascend/attention/utils.py        |  4 ++++
 vllm_ascend/worker/model_runner_v1.py |  3 +--
 5 files changed, 30 insertions(+), 2 deletions(-)

diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 7dc309b..fce9e8c 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -216,6 +216,29 @@ class AscendAttentionMetadataBuilder:
         query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                        num_reqs + 1]
+
+        if attn_state == AscendAttentionState.DecodeOnly and \
+                common_attn_metadata.num_input_tokens > num_actual_tokens:
+            padded_num_tokens = common_attn_metadata.num_input_tokens - num_actual_tokens
+            seq_lens = torch.cat([
+                seq_lens,
+                torch.ones(padded_num_tokens,
+                           dtype=seq_lens.dtype,
+                           device=seq_lens.device)
+            ])
+            block_table_padding = torch.zeros(
+                (padded_num_tokens, ) + block_table.shape[1:],
+                dtype=block_table.dtype,
+                device=block_table.device)
+            block_table = torch.cat([block_table, block_table_padding],
+                                    dim=0)
+            query_start_loc_cpu = torch.cat([
+                query_start_loc_cpu,
+                torch.arange(query_start_loc_cpu[-1] + 1,
+                             query_start_loc_cpu[-1] + padded_num_tokens + 1,
+                             dtype=query_start_loc_cpu.dtype,
+                             device=query_start_loc_cpu.device)
+            ])
+
         query_start_loc = query_start_loc_cpu.to(self.device,
                                                  non_blocking=True)
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index e40b5af..36a8f02 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -445,6 +445,7 @@ class AscendMLAMetadataBuilder:
                 cos=cos[:num_decode_tokens, ...])

         return self.metadata_cls(  # type: ignore
+            num_input_tokens=common_attn_metadata.num_input_tokens,
             num_actual_tokens=num_actual_tokens,
             query_lens=query_lens.tolist(),
             slot_mapping=slot_mapping,
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 725a2be..a709946 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -419,6 +419,7 @@ class AscendSFAMetadataBuilder:
                 cos=cos)

         return self.metadata_cls(  # type: ignore
+            num_input_tokens=common_attn_metadata.num_input_tokens,
             num_actual_tokens=num_actual_tokens,
             query_lens=query_lens.tolist(),
             slot_mapping=slot_mapping,
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 61befaa..1ad81c0 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -64,6 +64,10 @@ class AscendCommonAttentionMetadata:

     graph_pad_size: int = -1

+    # num_input_tokens refers to the total number of tokens including
+    # padding tokens. It is used to handle some padding operations.
+    num_input_tokens: int = 0
+
     # NOTE: This is a temporary solution for rotary embedding in MLA
     cos: torch.Tensor = None
     sin: torch.Tensor = None
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 8c61d04..2694374 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1477,6 +1477,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             seq_lens=self.seq_lens_cpu[:num_reqs],
             num_reqs=num_reqs,
             num_actual_tokens=total_num_scheduled_tokens,
+            num_input_tokens=num_input_tokens,
             actual_seq_lengths_q=self.actual_seq_lengths_q,
             # TODO: change this to the right block table for linear attn
             block_table_tensor=blk_table_tensor[:num_reqs],
@@ -1523,8 +1524,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 model=self.get_model(),
                 **extra_attn_metadata_args)

-            if self.vllm_config.model_config.use_mla or self.use_sparse:
-                attn_metadata_i.num_input_tokens = num_input_tokens
             for layer_name in attn_group.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i