From f57bdb09fc3126a1857edf95a3bbc208312678a0 Mon Sep 17 00:00:00 2001
From: pichangping <1337510399@qq.com>
Date: Wed, 29 Oct 2025 09:33:35 +0800
Subject: [PATCH] [long_seq_optim] BSND to TND and FA_UPDATE replacement
 (#3778)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### What this PR does / why we need it?
This PR optimizes long-sequence performance in two ways. First, it changes the input data format for attention computation: instead of converting between TND and BSND, the conversion logic is removed and the TND format is used directly. The TND input can be reused as-is, which shortens the data path; converting to BSND was an unnecessary processing step. Second, the output update previously assembled from concatenated small operators is replaced with the npu_attention_update fusion operator to improve performance.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/c9461e05a4ed3557cfbf4b15ded1e26761cc39ca

---------

Signed-off-by: pichangping <1337510399@qq.com>
---
 vllm_ascend/attention/attention_v1.py | 212 +++++++++++++-------------
 vllm_ascend/worker/model_runner_v1.py |   8 +-
 2 files changed, 108 insertions(+), 112 deletions(-)

diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index 07ef6d91..594e825e 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -23,7 +23,6 @@
 import numpy as np
 import torch
 import torch.distributed as dist
 import torch.nn as nn
-import torch.nn.functional as F
 import torch_npu
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
@@ -318,6 +317,18 @@ class AscendAttentionMetadataBuilder:
         pcp_metadata = None
         common_long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata
         if common_long_seq_metadata is not None:
+            attn_mask_seqlens = common_long_seq_metadata.attn_mask_seqlens
+            head_attn_nomask_seqlens = common_long_seq_metadata.head_attn_nomask_seqlens
+            tail_attn_nomask_seqlens = common_long_seq_metadata.tail_attn_nomask_seqlens
+            pcp_size = get_prefill_context_model_parallel_world_size(
+            ) if prefill_context_parallel_enable() else 1
+            if pcp_size > 1:
+                attn_mask_seqlens = torch.cumsum(attn_mask_seqlens[0],
+                                                 dim=0).tolist()
+                head_attn_nomask_seqlens = torch.cumsum(
+                    head_attn_nomask_seqlens[1], dim=0).tolist()
+                tail_attn_nomask_seqlens = torch.cumsum(
+                    tail_attn_nomask_seqlens[1], dim=0).tolist()
             pcp_metadata = AscendPCPMetadata(
                 q_head_idx=common_long_seq_metadata.q_head_idx_tensor,
                 q_tail_idx=common_long_seq_metadata.q_tail_idx_tensor,
@@ -329,12 +340,9 @@ class AscendAttentionMetadataBuilder:
                 kv_with_q_tail_nomask_idx_tensor,
                 kv_with_q_tail_mask_idx=common_long_seq_metadata.
                 kv_with_q_tail_mask_idx_tensor,
-                attn_mask_seqlens=common_long_seq_metadata.
-                attn_mask_seqlens,
-                head_attn_nomask_seqlens=common_long_seq_metadata.
-                head_attn_nomask_seqlens,
-                tail_attn_nomask_seqlens=common_long_seq_metadata.
-                tail_attn_nomask_seqlens,
+                attn_mask_seqlens=attn_mask_seqlens,
+                head_attn_nomask_seqlens=head_attn_nomask_seqlens,
+                tail_attn_nomask_seqlens=tail_attn_nomask_seqlens,
                 q_full_idx=common_long_seq_metadata.q_full_idx,
                 pcp_prefill_mask=common_long_seq_metadata.pcp_prefill_mask)
         prefill_metadata = AscendMetadataForPrefill(
@@ -726,28 +734,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 out=output)
         return output
 
-    def _pack_tnd_2_bsnd(self, tensor_tnd: torch.Tensor,
-                         lengths: List[int]) -> torch.Tensor:
-        max_len = max(lengths)
-        splits = torch.split(tensor_tnd, lengths, dim=0)
-
-        padded = []
-        for s in splits:
-            pad_len = max_len - s.shape[0]
-            s_pad = F.pad(s, (0, 0, 0, 0, 0, pad_len))
-            padded.append(s_pad)
-
-        tensor_bsnd = torch.stack(padded, dim=0)
-        return tensor_bsnd
-
-    def _unpack_bsnd_2_tnd(self, tensor_bsnd: torch.Tensor,
-                           lengths: List[int]) -> torch.Tensor:
-        slices = []
-        for i, length in enumerate(lengths):
-            slices.append(tensor_bsnd[i, :length])
-        tensor_tnd = torch.cat(slices, dim=0)
-        return tensor_tnd
-
     def _attention_with_nomask_and_mask(self, q: torch.Tensor,
                                         q_seqlens: List[int],
                                         k_nomask: torch.Tensor,
@@ -757,17 +743,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
                                         v_mask: torch.Tensor,
                                         kv_seqlens_mask: List[int],
                                         mask: torch.Tensor) -> torch.Tensor:
-        q = self._pack_tnd_2_bsnd(q, q_seqlens)
-
         # nomask Attention
         if k_nomask is not None:
             attn_out_nomask, attn_lse_nomask = torch.ops.npu.npu_fused_infer_attention_score(
                 q,
-                self._pack_tnd_2_bsnd(k_nomask, kv_seqlens_nomask),
-                self._pack_tnd_2_bsnd(v_nomask, kv_seqlens_nomask),
+                k_nomask,
+                v_nomask,
                 num_heads=self.num_heads,
                 num_key_value_heads=self.num_kv_heads,
-                input_layout="BSND",
+                input_layout="TND",
                 atten_mask=None,
                 scale=self.scale,
                 sparse_mode=0,
@@ -776,38 +760,46 @@ class AscendAttentionBackendImpl(AttentionImpl):
                 softmax_lse_flag=True,
                 actual_seq_lengths_kv=kv_seqlens_nomask,
                 actual_seq_lengths=q_seqlens)
-            attn_out_nomask = self._unpack_bsnd_2_tnd(attn_out_nomask,
-                                                      q_seqlens)
-            # (B, N, Q_S, 1) -> (B, Q_S, N, 1) -> (T, N, 1)
-            attn_lse_nomask = self._unpack_bsnd_2_tnd(
-                attn_lse_nomask.permute([0, 2, 1, 3]), q_seqlens)
 
         # mask Attention
         attn_out_mask, attn_lse_mask = torch.ops.npu.npu_fused_infer_attention_score(
             q,
-            self._pack_tnd_2_bsnd(k_mask, kv_seqlens_mask),
-            self._pack_tnd_2_bsnd(v_mask, kv_seqlens_mask),
+            k_mask,
+            v_mask,
             num_heads=self.num_heads,
             num_key_value_heads=self.num_kv_heads,
-            input_layout="BSND",
+            input_layout="TND",
             atten_mask=mask,
             scale=self.scale,
-            sparse_mode=0,
+            sparse_mode=3,
             antiquant_mode=0,
             antiquant_scale=None,
             softmax_lse_flag=True,
             actual_seq_lengths_kv=kv_seqlens_mask,
             actual_seq_lengths=q_seqlens)
-        attn_out_mask = self._unpack_bsnd_2_tnd(attn_out_mask, q_seqlens)
-        attn_lse_mask = self._unpack_bsnd_2_tnd(
-            attn_lse_mask.permute([0, 2, 1, 3]), q_seqlens)
 
         # update
         output = attn_out_mask
         if k_nomask is not None:
-            output, _ = self._update_out_and_lse(
-                torch.stack([attn_out_nomask, attn_out_mask], dim=0),
-                torch.stack([attn_lse_nomask, attn_lse_mask], dim=0))
+            T = attn_out_mask.shape[0]
+            N = attn_out_mask.shape[1]
+            D = attn_out_mask.shape[2]
+
+            attn_out_mask, attn_lse_mask = self._out_lse_reshape(
+                attn_out_mask, attn_lse_mask)
+            attn_out_nomask, attn_lse_nomask = self._out_lse_reshape(
+                attn_out_nomask, attn_lse_nomask)
+            attn_out_mask = attn_out_mask.to(torch.float32)
+            attn_out_nomask = attn_out_nomask.to(torch.float32)
+            attn_lse_mask = attn_lse_mask.to(torch.float32)
+            attn_lse_nomask = attn_lse_nomask.to(torch.float32)
+
+            attn_output = [attn_out_nomask, attn_out_mask]
+            attn_lse = [attn_lse_nomask, attn_lse_mask]
+            update_type = 0
+            output, _ = torch_npu.npu_attention_update(attn_lse, attn_output,
                                                        update_type)
+            output = output.view(T, N, D)
 
         return output
 
@@ -832,15 +824,15 @@ class AscendAttentionBackendImpl(AttentionImpl):
         # 1. Attention calculation in the first half of Q in load balancing
         output_head = self._attention_with_nomask_and_mask(
             q=torch.index_select(query, 0, q_head_idx),
-            q_seqlens=attn_mask_seqlens[0].tolist(),
+            q_seqlens=attn_mask_seqlens,
             k_nomask=torch.index_select(key, 0, kv_with_q_head_nomask_idx)
             if self.pcp_rank > 0 else None,
             v_nomask=torch.index_select(value, 0, kv_with_q_head_nomask_idx)
             if self.pcp_rank > 0 else None,
-            kv_seqlens_nomask=head_attn_nomask_seqlens[1].tolist(),
+            kv_seqlens_nomask=head_attn_nomask_seqlens,
             k_mask=torch.index_select(key, 0, kv_with_q_head_mask_idx),
             v_mask=torch.index_select(value, 0, kv_with_q_head_mask_idx),
-            kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
+            kv_seqlens_mask=attn_mask_seqlens,
             mask=mask)
 
         # 2. the Attention calculation in the latter half of Q in load balancing
@@ -848,13 +840,13 @@ class AscendAttentionBackendImpl(AttentionImpl):
         # pcp_rank0: Q3*KV0~KV2
         # pcp_rank1: Q2*KV0~KV1 + Q2*KV2
         output_tail = self._attention_with_nomask_and_mask(
             q=torch.index_select(query, 0, q_tail_idx),
-            q_seqlens=attn_mask_seqlens[0].tolist(),
+            q_seqlens=attn_mask_seqlens,
             k_nomask=torch.index_select(key, 0, kv_with_q_tail_nomask_idx),
             v_nomask=torch.index_select(value, 0, kv_with_q_tail_nomask_idx),
-            kv_seqlens_nomask=tail_attn_nomask_seqlens[1].tolist(),
+            kv_seqlens_nomask=tail_attn_nomask_seqlens,
             k_mask=torch.index_select(key, 0, kv_with_q_tail_mask_idx),
             v_mask=torch.index_select(value, 0, kv_with_q_tail_mask_idx),
-            kv_seqlens_mask=attn_mask_seqlens[0].tolist(),
+            kv_seqlens_mask=attn_mask_seqlens,
             mask=mask)
 
         # 3. Combine the output of the first half and second half.
@@ -863,20 +855,36 @@ class AscendAttentionBackendImpl(AttentionImpl):
             torch.cat([output_head, output_tail], dim=0), 0, q_full_idx)
         return output
 
-    def _update_out_and_lse(self, out_list: torch.Tensor,
-                            lse_list: torch.Tensor) -> torch.Tensor:
-        """LSE_final = log(sum(exp(LSE_i))), O_final = sum(exp(LSE_i - LSE_final) * O_i)
-        Args:
-            out_list: shape = [N, batch_size, num_heads, head_size]
-            lse_list: shape = [N, batch_size, num_heads, 1]
-        Returns:
-            out_final: shape = [batch_size, num_heads, head_size]
-            lse_final: shape = [batch_size, num_heads, 1]
-        """
-        lse_final = torch.logsumexp(lse_list, dim=0, keepdim=False)
-        out_final = torch.sum(torch.exp(lse_list - lse_final) * out_list,
-                              dim=0)
-        return out_final, lse_final
+    def _out_lse_reshape(self, attn_out: torch.Tensor,
+                         attn_lse: torch.Tensor) -> torch.Tensor:
+        attn_out = attn_out.contiguous().view(
+            attn_out.shape[0] * attn_out.shape[1], attn_out.shape[2])
+        attn_lse = attn_lse.contiguous().view(
+            attn_lse.shape[0] * attn_lse.shape[1] * attn_lse.shape[2])
+        return attn_out, attn_lse
+
+    def _npu_attention_update(
+            self, attn_out_lse_list: List[torch.Tensor]) -> torch.Tensor:
+        update_type = 0
+
+        batch = attn_out_lse_list[0].shape[0]
+        num_heads = attn_out_lse_list[0].shape[1]
+        head_dim = attn_out_lse_list[0].shape[2] - 1
+
+        attn_out_split_cp = []
+        attn_lse_split_cp = []
+
+        for i in attn_out_lse_list:
+            attn_out_allgather, attn_lse_allgather = self._out_lse_reshape(
+                *torch.split(i, [self.head_size, 1], dim=-1))
+            attn_out_split_cp.append(attn_out_allgather)
+            attn_lse_split_cp.append(attn_lse_allgather)
+
+        attn_out, attn_lse = torch_npu.npu_attention_update(
+            attn_lse_split_cp, attn_out_split_cp, update_type)
+        attn_out = attn_out.view(batch, num_heads, head_dim)
+
+        return attn_out
 
     def _forward_decode_pcp_dcp(self, query: torch.Tensor,
                                 attn_metadata: AscendMetadata) -> torch.Tensor:
@@ -889,9 +897,6 @@ class AscendAttentionBackendImpl(AttentionImpl):
         else:
             num_heads = self.num_heads
 
-        # 1. Compute out&lse by "npu_fused_infer_attention_score"
-        q_nope = query.view(query.shape[0], 1, query.shape[1], query.shape[2])
-        # [b,num_heads,head_size] -> [b,1,num_heads,head_size]
         k_nope = self.key_cache.view(self.key_cache.shape[0],
                                      self.key_cache.shape[1], -1)
         value = self.value_cache.view(self.key_cache.shape[0],
@@ -902,7 +907,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
             'num_key_value_heads':
             self.num_kv_heads,
             'input_layout':
-            "BSND",
+            "TND",
             'atten_mask':
             None,
             'scale':
@@ -917,9 +922,11 @@ class AscendAttentionBackendImpl(AttentionImpl):
             attn_metadata.block_tables,
             'block_size':
             self.key_cache.shape[1],
-            "actual_seq_lengths_kv":
-            attn_metadata.decode_meta.
-            num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank],
+            'actual_seq_lengths_kv':
+            attn_metadata.seq_lens_list[:attn_metadata.num_decode_tokens],
+            'actual_seq_lengths':
+            attn_metadata.actual_seq_lengths_q[:attn_metadata.
+                                               num_decode_tokens]
         }
         graph_params = get_graph_params()
         forward_context: ForwardContext = get_forward_context()
@@ -935,16 +942,16 @@ class AscendAttentionBackendImpl(AttentionImpl):
             workspace = graph_params.workspaces.get(num_tokens)
             if workspace is None:
                 workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
-                    q_nope, k_nope, value, **common_kwargs)
+                    query, k_nope, value, **common_kwargs)
                 update_graph_params_workspaces(num_tokens,
                                                weak_ref_tensors(workspace))
 
-            attn_out = torch.empty_like(q_nope)
+            attn_out = torch.empty_like(query)
             attn_lse = torch.empty((num_tokens, num_heads, 1, 1),
                                    dtype=torch.float,
-                                   device=q_nope.device)
+                                   device=query.device)
 
             graph_params.attn_params[num_tokens].append(
-                (weak_ref_tensors(q_nope), weak_ref_tensors(k_nope),
+                (weak_ref_tensors(query), weak_ref_tensors(k_nope),
                  weak_ref_tensors(value), self.num_heads, self.num_kv_heads,
                  self.scale, attn_metadata.block_tables,
                  self.key_cache.shape[1], attn_metadata.decode_meta.
@@ -954,7 +961,7 @@ class AscendAttentionBackendImpl(AttentionImpl):
                  self.pcp_rank, self.dcp_rank, self.dcp_size))
             torch.npu.graph_task_group_begin(stream)
             torch_npu.npu_fused_infer_attention_score.out(
-                q_nope,
+                query,
                 k_nope,
                 value,
                 **common_kwargs,
@@ -964,14 +971,12 @@ class AscendAttentionBackendImpl(AttentionImpl):
             graph_params.handles[num_tokens].append(handle)
         else:
             attn_out, attn_lse = torch_npu.npu_fused_infer_attention_score(
-                q_nope, k_nope, value, **common_kwargs)
+                query, k_nope, value, **common_kwargs)
 
-        attn_out = attn_out.view(attn_out.shape[0], attn_out.shape[2],
-                                 attn_out.shape[3])
-        attn_lse = attn_lse.view(attn_lse.shape[0], attn_lse.shape[1], 1)
+        attn_out_lse_list = []
+        # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
+        attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
         if self.dcp_size > 1:
-            # Concat out&lse: [bs,num_heads,v_head_dim] + [bs,num_heads,1] -> [bs,num_heads,v_head_dim+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
             # permute: [bs, num_heads, v_head_dim+1] -> [num_heads, v_head_dim+1, bs]
             attn_out_lse = attn_out_lse.permute([1, 2, 0]).contiguous()
             attn_out_lse_all2all = torch.empty_like(attn_out_lse)
@@ -980,35 +985,28 @@ class AscendAttentionBackendImpl(AttentionImpl):
                                      group=self.dcp_group)
             # permute: [num_heads, v_head_dim+1, bs] -> [bs, num_heads, v_head_dim+1]
             attn_out_lse_all2all = attn_out_lse_all2all.permute([2, 0, 1])
-            attn_out_lse_split_on_seq = list(
+            if self.pcp_size > 1:
+                attn_out_lse = attn_out_lse_all2all.contiguous()
+            attn_out_lse_list = list(
                 torch.chunk(attn_out_lse_all2all, self.dcp_size, dim=1))
-            attn_out_lse_split_dcp = torch.stack(
-                attn_out_lse_split_on_seq,
-                dim=0)  # [dcp, batch_size, num_heads, head_size+1]
-            # Update out&lse
-            attn_out_split_dcp, attn_lse_split_dcp = torch.split(
-                attn_out_lse_split_dcp, [self.head_size, 1], dim=-1)
-            attn_out, attn_lse = self._update_out_and_lse(
-                attn_out_split_dcp, attn_lse_split_dcp)
 
         if self.pcp_size > 1:
-            # 2. Concat out&lse: [bs,num_heads,head_size] + [bs,num_heads,1] -> [bs,num_heads,head_size+1]
-            attn_out_lse = torch.cat([attn_out, attn_lse], dim=-1)
-            # 3. AllGather out&lse within CP group
+            # AllGather out&lse within CP group
            attn_out_lse_list = [
                 torch.empty_like(attn_out_lse) for _ in range(self.pcp_size)
             ]
             dist.all_gather(attn_out_lse_list,
                             attn_out_lse,
                             group=self.pcp_group)
-            # 4. Update out&lse
-            attn_out_lse_allgather = torch.stack(
-                attn_out_lse_list,
-                dim=0)  # [pcp, batch_size, num_heads, head_size+1]
-            attn_out_allgather, attn_lse_allgather = torch.split(
-                attn_out_lse_allgather, [self.head_size, 1], dim=-1)
-            attn_out, _ = self._update_out_and_lse(attn_out_allgather,
-                                                   attn_lse_allgather)
+            if self.dcp_size > 1 and self.pcp_size > 1:
+                attn_out_lse_list_pcp_dcp = []
+                for s in attn_out_lse_list:
+                    attn_out_lse_list_split = list(
+                        torch.chunk(s, self.dcp_size, dim=1))
+                    attn_out_lse_list_pcp_dcp += attn_out_lse_list_split
+                attn_out_lse_list = attn_out_lse_list_pcp_dcp
+        # Update out&lse
+        attn_out = self._npu_attention_update(attn_out_lse_list)
         return attn_out
 
     def _forward_pcp_dcp(self, query: torch.Tensor, key: torch.Tensor,
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index 01bd128b..fbf1d3e2 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -1374,13 +1374,13 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         self.input_batch.block_table.compute_slot_mapping(
             req_indices, positions_np)
+        self.input_batch.block_table.commit_slot_mapping(
+            total_num_scheduled_tokens)
         tokens, position_pcp, pcp_unpad_mask = self._update_tokens_for_pcp(
             tokens)
         num_scheduled_tokens = np.array(tokens, dtype=np.int32)
         # update total_num_scheduled_tokens
         total_num_scheduled_tokens = sum(num_scheduled_tokens[:num_reqs])
-        self.input_batch.block_table.commit_slot_mapping(
-            total_num_scheduled_tokens)
         total_num_pcp_pads = sum(self.num_pcp_pads)
         max_num_scheduled_tokens = max(tokens)
 
@@ -4140,7 +4140,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         num_decodes = sum(self.input_batch.num_computed_tokens_cpu[:num_reqs]
                           >= self.input_batch.num_prompt_tokens[:num_reqs])
         num_actual_tokens_pcp_padded = total_num_scheduled_tokens * self.pcp_size
-        num_prefills = num_reqs - num_decodes
         long_seq_metadata = None
         if self.pcp_size * self.dcp_size > 1:
             long_seq_metadata = AscendPrefillContextParallelMetadata(
@@ -4248,9 +4247,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                device=self.device,
                                dtype=self.dtype), 1)
         else:
-            max_seq_len = max(seq_lens, default=0)
             pcp_prefill_mask = torch.triu(
-                torch.full((num_prefills, max_seq_len, max_seq_len),
+                torch.full((2048, 2048),
                            True,
                            device=self.device,
                            dtype=torch.bool), 1)
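
For reviewers, here is a minimal PyTorch sketch of the log-sum-exp merge that the removed `_update_out_and_lse` helper computed and that the `torch_npu.npu_attention_update` fused operator now performs on-device. The formula is taken from the removed docstring; treating `update_type = 0` as this standard merge is an assumption based on the replaced code, and `reference_attention_update` is a hypothetical name used only for illustration.

```python
from typing import List

import torch


def reference_attention_update(lse_list: List[torch.Tensor],
                               out_list: List[torch.Tensor]):
    """CPU reference for the merge replaced by npu_attention_update.

    LSE_final = log(sum_i(exp(LSE_i)))
    O_final   = sum_i(exp(LSE_i - LSE_final) * O_i)

    out_list[i]: [T, N, D] partial attention outputs (TND layout).
    lse_list[i]: [T, N, 1] matching log-sum-exp values.
    """
    lse = torch.stack(lse_list, dim=0)       # [P, T, N, 1]
    out = torch.stack(out_list, dim=0)       # [P, T, N, D]
    lse_final = torch.logsumexp(lse, dim=0)  # [T, N, 1]
    out_final = torch.sum(torch.exp(lse - lse_final) * out, dim=0)
    return out_final, lse_final
```

The patch casts the partial outputs and LSEs to float32 before calling the fused operator, which is consistent with performing this merge in higher precision.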