From caa71e50cafedd47231e9df11ba933de3bfcf0a0 Mon Sep 17 00:00:00 2001
From: LICO67373 <110013619+LICO1314@users.noreply.github.com>
Date: Mon, 23 Mar 2026 15:47:42 +0800
Subject: [PATCH] [Perf] Simplify FIA prefill context merge path (#7293)

### What this PR does / why we need it?
This PR simplifies and hardens MLA prefill context merging in
`vllm_ascend/attention/mla_v1.py` after the FIA migration by building
`out_list`/`lse_list` directly (without temporary chunk buffers or
`cat`/`stack`/`split`) and by using `reshape` to safely flatten
non-contiguous tensors. A reference sketch of the underlying LSE merge is
included after the diff.

### Does this PR introduce _any_ user-facing change?
No. This is an internal refactor/stability improvement only; there are no
API/interface behavior changes.

### How was this patch tested?
- Verified tensor shape/data flow for the `npu_attention_update` inputs
  (`out_list`/`lse_list`) after the refactor.
- Confirmed no lint errors in the modified file.
- CI unit-test coverage on the attention/MLA paths is used for validation.

vLLM version: `v0.17.0`
vLLM main: `vllm-project/vllm@4034c3d`

---------

Signed-off-by: lico67373 <918688502@qq.com>
---
 vllm_ascend/attention/mla_v1.py | 58 ++++++++++++---------------------
 1 file changed, 21 insertions(+), 37 deletions(-)

diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index dd6b018b..aa93d460 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -1055,8 +1055,19 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         actual_seq_lengths_q = prefill_metadata.actual_seq_lengths_q
 
-        chunk_outputs = []
-        chunk_lses = []
+        if iters == 0:
+            return prefix_output, prefix_lse
+
+        num_tokens = q_nope.size(0)
+        D = self.v_head_dim
+        H = self.num_heads
+
+        if prefix_lse.dim() == 2:
+            prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
+        prefix_output = prefix_output.to(torch.float32)
+        prefix_lse = prefix_lse.to(torch.float32)
+        out_list = [prefix_output.reshape(num_tokens * H, D)]
+        lse_list = [prefix_lse.reshape(num_tokens * H)]
 
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
@@ -1105,42 +1116,15 @@ class AscendMLAImpl(MLAAttentionImpl):
                 actual_seq_lengths=actual_seq_lengths_q,
                 actual_seq_lengths_kv=actual_seq_lengths_kv,
             )
-            chunk_outputs.append(chunk_out)
-            chunk_lses.append(chunk_lse)
+            if chunk_lse.dim() == 2:
+                chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1)
+            chunk_out = chunk_out.to(torch.float32)
+            chunk_lse = chunk_lse.to(torch.float32)
+            out_list.append(chunk_out.reshape(num_tokens * H, D))
+            lse_list.append(chunk_lse.reshape(num_tokens * H))
 
-        if len(chunk_outputs) > 0:
-            num_tokens = q_nope.size(0)
-            D = self.v_head_dim
-            H = self.num_heads
-
-            # Normalize prefix output/lse to [num_tokens, H, D] and [num_tokens, H, 1]
-            prefix_output = prefix_output.to(torch.float32)
-            prefix_lse = prefix_lse.to(torch.float32)
-            if prefix_lse.dim() == 2:
-                prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
-
-            # Concat output and lse: [num_tokens, H, D+1]
-            all_out_lse = [torch.cat([prefix_output, prefix_lse], dim=-1)]
-            for chunk_out, chunk_lse in zip(chunk_outputs, chunk_lses):
-                chunk_out = chunk_out.to(torch.float32)
-                chunk_lse = chunk_lse.to(torch.float32)
-                if chunk_lse.dim() == 2:
-                    chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1)
-                all_out_lse.append(torch.cat([chunk_out, chunk_lse], dim=-1))
-
-            # Stack and split: [N, num_tokens, H, D+1]
-            all_out_lse = torch.stack(all_out_lse, dim=0)
-            N = all_out_lse.size(0)
-            out_flat, lse_flat = torch.split(all_out_lse, [D, 1], dim=-1)
-
-            # Flatten and unbind for npu_attention_update
-            out_list = out_flat.view(N, num_tokens * H, D).unbind(0)
-            lse_list = lse_flat.view(N, num_tokens * H).unbind(0)
-
-            output_final, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
-            return output_final.view(num_tokens, H, D), None
-
-        return prefix_output, prefix_lse
+        output_final, _ = torch_npu.npu_attention_update(tuple(lse_list), tuple(out_list), 0)
+        return output_final.view(num_tokens, H, D), None
 
     def _forward_prefill(
         self,
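
For reference, the sketch below illustrates the log-sum-exp (LSE) merge that the
flattened `out_list`/`lse_list` feed into. It is a plain-PyTorch illustration of
the math only, not the patched code path: `torch_npu.npu_attention_update` is
assumed to perform an equivalent combination on NPU, and the helper name
`merge_partial_attention` is made up for this example. Shapes follow the patch:
each output entry is `[num_tokens * H, D]` and each LSE entry is
`[num_tokens * H]`.

```python
import torch


def merge_partial_attention(out_a: torch.Tensor, lse_a: torch.Tensor,
                            out_b: torch.Tensor, lse_b: torch.Tensor):
    """Merge two partial attention results computed over disjoint key sets.

    out_*: [num_tokens * H, D] float32 partial outputs
    lse_*: [num_tokens * H]    float32 log-sum-exp of the softmax denominators
    """
    # Combined log-sum-exp of both partial denominators.
    lse = torch.logaddexp(lse_a, lse_b)
    # Rescale each partial output by its share of the combined weight.
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return w_a * out_a + w_b * out_b, lse


# Toy shape check mirroring the patch layout: num_tokens=4, H=2, D=8.
num_tokens, H, D = 4, 2, 8
out_a, out_b = torch.randn(num_tokens * H, D), torch.randn(num_tokens * H, D)
lse_a, lse_b = torch.randn(num_tokens * H), torch.randn(num_tokens * H)
merged, merged_lse = merge_partial_attention(out_a, lse_a, out_b, lse_b)
assert merged.shape == (num_tokens * H, D)
assert merged_lse.shape == (num_tokens * H,)
```

Applied pairwise across the prefix and each context chunk, this is why the
refactor only needs per-chunk `out`/`lse` tensors in a flat `[num_tokens * H, ...]`
layout rather than the previous `cat`/`stack`/`split` staging buffers.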