[Perf] Simplify FIA prefill context merge path (#7293)
### What this PR does / why we need it? This PR simplifies and hardens MLA prefill context merging in `vllm_ascend/attention/mla_v1.py` after the FIA migration by directly building `out_list`/`lse_list` (without temporary chunk buffers or `cat`/`stack`/`split`) and using `reshape` for safe flattening of non-contiguous tensors. ### Does this PR introduce _any_ user-facing change? No. This is an internal refactor/stability improvement only; no API/interface behavior changes. ### How was this patch tested? - Verified tensor shape/data flow for the `npu_attention_update` inputs (`out_list`/`lse_list`) after the refactor. - Confirmed no lint errors in the modified file. - Relied on existing CI unit-test coverage of the attention/MLA paths for validation. vLLM version: `v0.17.0` vLLM main: `vllm-project/vllm@4034c3d` --------- Signed-off-by: lico67373 <918688502@qq.com>
This commit is contained in:
@@ -1055,8 +1055,19 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
actual_seq_lengths_q = prefill_metadata.actual_seq_lengths_q
|
actual_seq_lengths_q = prefill_metadata.actual_seq_lengths_q
|
||||||
|
|
||||||
chunk_outputs = []
|
if iters == 0:
|
||||||
chunk_lses = []
|
return prefix_output, prefix_lse
|
||||||
|
|
||||||
|
num_tokens = q_nope.size(0)
|
||||||
|
D = self.v_head_dim
|
||||||
|
H = self.num_heads
|
||||||
|
|
||||||
|
if prefix_lse.dim() == 2:
|
||||||
|
prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
|
||||||
|
prefix_output = prefix_output.to(torch.float32)
|
||||||
|
prefix_lse = prefix_lse.to(torch.float32)
|
||||||
|
out_list = [prefix_output.reshape(num_tokens * H, D)]
|
||||||
|
lse_list = [prefix_lse.reshape(num_tokens * H)]
|
||||||
|
|
||||||
for i in range(iters):
|
for i in range(iters):
|
||||||
toks = prefill_metadata.chunked_context.seq_tot[i]
|
toks = prefill_metadata.chunked_context.seq_tot[i]
|
||||||
@@ -1105,43 +1116,16 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
actual_seq_lengths=actual_seq_lengths_q,
|
actual_seq_lengths=actual_seq_lengths_q,
|
||||||
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
actual_seq_lengths_kv=actual_seq_lengths_kv,
|
||||||
)
|
)
|
||||||
chunk_outputs.append(chunk_out)
|
|
||||||
chunk_lses.append(chunk_lse)
|
|
||||||
|
|
||||||
if len(chunk_outputs) > 0:
|
|
||||||
num_tokens = q_nope.size(0)
|
|
||||||
D = self.v_head_dim
|
|
||||||
H = self.num_heads
|
|
||||||
|
|
||||||
# Normalize prefix output/lse to [num_tokens, H, D] and [num_tokens, H, 1]
|
|
||||||
prefix_output = prefix_output.to(torch.float32)
|
|
||||||
prefix_lse = prefix_lse.to(torch.float32)
|
|
||||||
if prefix_lse.dim() == 2:
|
|
||||||
prefix_lse = prefix_lse.transpose(0, 1).unsqueeze(-1)
|
|
||||||
|
|
||||||
# Concat output and lse: [num_tokens, H, D+1]
|
|
||||||
all_out_lse = [torch.cat([prefix_output, prefix_lse], dim=-1)]
|
|
||||||
for chunk_out, chunk_lse in zip(chunk_outputs, chunk_lses):
|
|
||||||
chunk_out = chunk_out.to(torch.float32)
|
|
||||||
chunk_lse = chunk_lse.to(torch.float32)
|
|
||||||
if chunk_lse.dim() == 2:
|
if chunk_lse.dim() == 2:
|
||||||
chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1)
|
chunk_lse = chunk_lse.transpose(0, 1).unsqueeze(-1)
|
||||||
all_out_lse.append(torch.cat([chunk_out, chunk_lse], dim=-1))
|
chunk_out = chunk_out.to(torch.float32)
|
||||||
|
chunk_lse = chunk_lse.to(torch.float32)
|
||||||
|
out_list.append(chunk_out.reshape(num_tokens * H, D))
|
||||||
|
lse_list.append(chunk_lse.reshape(num_tokens * H))
|
||||||
|
|
||||||
# Stack and split: [N, num_tokens, H, D+1]
|
output_final, _ = torch_npu.npu_attention_update(tuple(lse_list), tuple(out_list), 0)
|
||||||
all_out_lse = torch.stack(all_out_lse, dim=0)
|
|
||||||
N = all_out_lse.size(0)
|
|
||||||
out_flat, lse_flat = torch.split(all_out_lse, [D, 1], dim=-1)
|
|
||||||
|
|
||||||
# Flatten and unbind for npu_attention_update
|
|
||||||
out_list = out_flat.view(N, num_tokens * H, D).unbind(0)
|
|
||||||
lse_list = lse_flat.view(N, num_tokens * H).unbind(0)
|
|
||||||
|
|
||||||
output_final, _ = torch_npu.npu_attention_update(lse_list, out_list, 0)
|
|
||||||
return output_final.view(num_tokens, H, D), None
|
return output_final.view(num_tokens, H, D), None
|
||||||
|
|
||||||
return prefix_output, prefix_lse
|
|
||||||
|
|
||||||
def _forward_prefill(
|
def _forward_prefill(
|
||||||
self,
|
self,
|
||||||
q_nope: torch.Tensor,
|
q_nope: torch.Tensor,
|
||||||
|
|||||||
Reference in New Issue
Block a user