[Perf] Simplify FIA prefill context merge path (#7293)

### What this PR does / why we need it?
This PR simplifies and hardens MLA prefill context merging in
`vllm_ascend/attention/mla_v1.py` after the FIA migration: it builds
`out_list`/`lse_list` directly (no temporary chunk buffers and no
`cat`/`stack`/`split` round-trips) and uses `reshape` to safely flatten
non-contiguous tensors.
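
For illustration, a minimal plain-PyTorch sketch of the new merge flow is shown
below. The function name, the `chunk_results` iterable, and the final
LSE-weighted combine are assumptions made for this sketch; in the actual code
the combine is performed by `torch_npu.npu_attention_update` on an Ascend NPU,
not by the reference softmax shown here.

```python
import torch

def merge_prefill_context_sketch(prefix_output, prefix_lse, chunk_results,
                                 num_tokens, H, D):
    """Illustrative merge of the prefix result with chunked-context results.

    prefix_output: [num_tokens, H, D]; prefix_lse: [num_tokens, H, 1] (already
    normalized); chunk_results: iterable of (chunk_out, chunk_lse) pairs with
    the same layout. Names and shapes are assumptions for this sketch.
    """
    # Build out_list/lse_list directly, one entry per partial result, without
    # temporary chunk buffers or cat/stack/split round-trips.
    out_list = [prefix_output.to(torch.float32).reshape(num_tokens * H, D)]
    lse_list = [prefix_lse.to(torch.float32).reshape(num_tokens * H)]
    for chunk_out, chunk_lse in chunk_results:
        # reshape (unlike view) also handles non-contiguous tensors, e.g.
        # after a transpose, by copying only when necessary.
        out_list.append(chunk_out.to(torch.float32).reshape(num_tokens * H, D))
        lse_list.append(chunk_lse.to(torch.float32).reshape(num_tokens * H))

    # Reference combine: weight each partial output by softmax over its LSE.
    # The real path hands lse_list/out_list to torch_npu.npu_attention_update.
    lse = torch.stack(lse_list, dim=0)                  # [N, num_tokens * H]
    out = torch.stack(out_list, dim=0)                  # [N, num_tokens * H, D]
    weights = torch.softmax(lse, dim=0).unsqueeze(-1)   # [N, num_tokens * H, 1]
    merged = (weights * out).sum(dim=0)                 # [num_tokens * H, D]
    return merged.view(num_tokens, H, D)
```

The `reshape` calls are the safety piece: `Tensor.view` fails when the
requested shape is not compatible with the tensor's strides (for example when
flattening a transposed LSE tensor), while `reshape` falls back to a copy when
needed.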

### Does this PR introduce _any_ user-facing change?
No. This is an internal refactor and stability improvement only; there are no
API or interface behavior changes.

### How was this patch tested?
- Verified tensor shape/data flow for the `npu_attention_update` inputs
  (`out_list`/`lse_list`) after the refactor (see the shape-check sketch after
  this list).
- Confirmed no lint errors in the modified file.
- Existing CI unit-test coverage of the attention/MLA paths is used for
  validation.
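
The shape verification above was done by inspection; a hypothetical
assertion-style check of the `npu_attention_update` inputs could look like the
sketch below (the helper name and its use are not part of the repository).

```python
import torch

def check_merge_inputs(out_list, lse_list, num_tokens, H, D):
    """Hypothetical sanity check for tensors handed to npu_attention_update."""
    assert len(out_list) == len(lse_list), "one LSE tensor per partial output"
    for out, lse in zip(out_list, lse_list):
        assert out.dtype == torch.float32 and lse.dtype == torch.float32
        assert out.shape == (num_tokens * H, D)   # flattened tokens x heads
        assert lse.shape == (num_tokens * H,)
```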

vLLM version: `v0.17.0`  
vLLM main: `vllm-project/vllm@4034c3d`

---------

Signed-off-by: lico67373 <918688502@qq.com>
Commit: caa71e50ca (parent: da866cc168)
Author: LICO67373
Date: 2026-03-23 15:47:42 +08:00
Committed by: GitHub


`vllm_ascend/attention/mla_v1.py`:

```diff
@@ -1055,8 +1055,19 @@ class AscendMLAImpl(MLAAttentionImpl):
         actual_seq_lengths_q = prefill_metadata.actual_seq_lengths_q
-        chunk_outputs = []
-        chunk_lses = []
+        if iters == 0:
+            return prefix_output, prefix_lse
+        num_tokens = q_nope.size(0)
+        D = self.v_head_dim
+        H = self.num_heads
+        if prefix_lse.dim() == 2:
+            prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
+        prefix_output = prefix_output.to(torch.float32)
+        prefix_lse = prefix_lse.to(torch.float32)
+        out_list = [prefix_output.reshape(num_tokens * H, D)]
+        lse_list = [prefix_lse.reshape(num_tokens * H)]
         for i in range(iters):
             toks = prefill_metadata.chunked_context.seq_tot[i]
@@ -1105,42 +1116,15 @@ class AscendMLAImpl(MLAAttentionImpl):
                 actual_seq_lengths=actual_seq_lengths_q,
                 actual_seq_lengths_kv=actual_seq_lengths_kv,
             )
-            chunk_outputs.append(chunk_out)
-            chunk_lses.append(chunk_lse)
+            if chunk_lse.dim() == 2:
+                chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1)
+            chunk_out = chunk_out.to(torch.float32)
+            chunk_lse = chunk_lse.to(torch.float32)
+            out_list.append(chunk_out.reshape(num_tokens * H, D))
+            lse_list.append(chunk_lse.reshape(num_tokens * H))
-        if len(chunk_outputs) > 0:
-            num_tokens = q_nope.size(0)
-            D = self.v_head_dim
-            H = self.num_heads
-            # Normalize prefix output/lse to [num_tokens, H, D] and [num_tokens, H, 1]
-            prefix_output = prefix_output.to(torch.float32)
-            prefix_lse = prefix_lse.to(torch.float32)
-            if prefix_lse.dim() == 2:
-                prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
-            # Concat output and lse: [num_tokens, H, D+1]
-            all_out_lse = [torch.cat([prefix_output, prefix_lse], dim=-1)]
-            for chunk_out, chunk_lse in zip(chunk_outputs, chunk_lses):
-                chunk_out = chunk_out.to(torch.float32)
-                chunk_lse = chunk_lse.to(torch.float32)
-                if chunk_lse.dim() == 2:
-                    chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1)
-                all_out_lse.append(torch.cat([chunk_out, chunk_lse], dim=-1))
-            # Stack and split: [N, num_tokens, H, D+1]
-            all_out_lse = torch.stack(all_out_lse, dim=0)
-            N = all_out_lse.size(0)
-            out_flat, lse_flat = torch.split(all_out_lse, [D, 1], dim=-1)
-            # Flatten and unbind for npu_attention_update
-            out_list = out_flat.view(N, num_tokens * H, D).unbind(0)
-            lse_list = lse_flat.view(N, num_tokens * H).unbind(0)
-            output_final, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
-            return output_final.view(num_tokens, H, D), None
-        return prefix_output, prefix_lse
+        output_final, _ = torch_npu.npu_attention_update(tuple(lse_list), tuple(out_list), 0)
+        return output_final.view(num_tokens, H, D), None
     def _forward_prefill(
         self,
```