[Perf] Simplify FIA prefill context merge path (#7293)
### What this PR does / why we need it? This PR simplifies and hardens MLA prefill context merging in `vllm_ascend/attention/mla_v1.py` after FIA migration by directly building `out_list/lse_list` (without temporary chunk buffers or `cat/stack/split`) and using `reshape` for safe flattening of non-contiguous tensors. ### Does this PR introduce _any_ user-facing change? No. This is an internal refactor/stability improvement only; no API/interface behavior changes. ### How was this patch tested? - Verified tensor shape/data flow for `npu_attention_update` inputs (`out_list/lse_list`) after refactor. - Confirmed no lint errors in the modified file. - CI UT coverage on attention/MLA paths is used for validation. vLLM version: `v0.17.0` vLLM main: `vllm-project/vllm@4034c3d` --------- Signed-off-by: lico67373 <918688502@qq.com>
This commit is contained in:
@@ -1055,8 +1055,19 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
|
||||
actual_seq_lengths_q = prefill_metadata.actual_seq_lengths_q
|
||||
|
||||
chunk_outputs = []
|
||||
chunk_lses = []
|
||||
if iters == 0:
|
||||
return prefix_output, prefix_lse
|
||||
|
||||
num_tokens = q_nope.size(0)
|
||||
D = self.v_head_dim
|
||||
H = self.num_heads
|
||||
|
||||
if prefix_lse.dim() == 2:
|
||||
prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
|
||||
prefix_output = prefix_output.to(torch.float32)
|
||||
prefix_lse = prefix_lse.to(torch.float32)
|
||||
out_list = [prefix_output.reshape(num_tokens * H, D)]
|
||||
lse_list = [prefix_lse.reshape(num_tokens * H)]
|
||||
|
||||
for i in range(iters):
|
||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||
@@ -1105,42 +1116,15 @@ class AscendMLAImpl(MLAAttentionImpl):
|
||||
actual_seq_lengths=actual_seq_lengths_q,
|
||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||
)
|
||||
chunk_outputs.append(chunk_out)
|
||||
chunk_lses.append(chunk_lse)
|
||||
if chunk_lse.dim() == 2:
|
||||
chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1)
|
||||
chunk_out = chunk_out.to(torch.float32)
|
||||
chunk_lse = chunk_lse.to(torch.float32)
|
||||
out_list.append(chunk_out.reshape(num_tokens * H, D))
|
||||
lse_list.append(chunk_lse.reshape(num_tokens * H))
|
||||
|
||||
if len(chunk_outputs) > 0:
|
||||
num_tokens = q_nope.size(0)
|
||||
D = self.v_head_dim
|
||||
H = self.num_heads
|
||||
|
||||
# Normalize prefix output/lse to [num_tokens, H, D] and [num_tokens, H, 1]
|
||||
prefix_output = prefix_output.to(torch.float32)
|
||||
prefix_lse = prefix_lse.to(torch.float32)
|
||||
if prefix_lse.dim() == 2:
|
||||
prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
|
||||
|
||||
# Concat output and lse: [num_tokens, H, D+1]
|
||||
all_out_lse = [torch.cat([prefix_output, prefix_lse], dim=-1)]
|
||||
for chunk_out, chunk_lse in zip(chunk_outputs, chunk_lses):
|
||||
chunk_out = chunk_out.to(torch.float32)
|
||||
chunk_lse = chunk_lse.to(torch.float32)
|
||||
if chunk_lse.dim() == 2:
|
||||
chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1)
|
||||
all_out_lse.append(torch.cat([chunk_out, chunk_lse], dim=-1))
|
||||
|
||||
# Stack and split: [N, num_tokens, H, D+1]
|
||||
all_out_lse = torch.stack(all_out_lse, dim=0)
|
||||
N = all_out_lse.size(0)
|
||||
out_flat, lse_flat = torch.split(all_out_lse, [D, 1], dim=-1)
|
||||
|
||||
# Flatten and unbind for npu_attention_update
|
||||
out_list = out_flat.view(N, num_tokens * H, D).unbind(0)
|
||||
lse_list = lse_flat.view(N, num_tokens * H).unbind(0)
|
||||
|
||||
output_final, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
|
||||
return output_final.view(num_tokens, H, D), None
|
||||
|
||||
return prefix_output, prefix_lse
|
||||
output_final, _ = torch_npu.npu_attention_update(tuple(lse_list), tuple(out_list), 0)
|
||||
return output_final.view(num_tokens, H, D), None
|
||||
|
||||
def _forward_prefill(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user