From db12c1e2c840a98570bf8a793f2a9786999e8c3f Mon Sep 17 00:00:00 2001
From: zzhxxx
Date: Sun, 11 Jan 2026 09:47:27 +0800
Subject: [PATCH] [Perf] Supports compute-communication overlap in the forward
 of sfa_v1 in the Sharded-CP feature. (#5701)

### What this PR does / why we need it?
> Extracted from PR #5513

Based on the Sharded-CP feature (PR #4702); RFC: https://github.com/vllm-project/vllm/issues/30055

### All-gather KV Cache for Communication Overlap
- This PR reorders the computation in the SFA forward pass.
- `indexer_select` is split into `indexer_select_pre_process` and `indexer_select_post_process`.
- The `nope`, `rope`, and `index-k` tensors are concatenated into a single tensor so that one asynchronous all-gather can overlap with the query-side computation (see the sketch below).
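A minimal sketch of the overlap pattern (the fuse → async all-gather → q-side compute → wait → split ordering follows this PR; `forward_with_kv_overlap`, `compute_q`, and the plain `torch.distributed` calls are illustrative stand-ins, not the vLLM-Ascend API):

```python
import torch
import torch.distributed as dist


def forward_with_kv_overlap(k_pe, k_nope, index_k, q_c, compute_q, group):
    # Fuse the three per-rank KV tensors so a single collective covers them all.
    fused = torch.cat([k_pe, k_nope, index_k], dim=-1)

    world_size = dist.get_world_size(group)
    gathered = torch.empty((fused.shape[0] * world_size, fused.shape[-1]),
                           dtype=fused.dtype, device=fused.device)
    # async_op=True returns immediately with a handle; the collective runs
    # in the background while the query-side projections execute.
    handle = dist.all_gather_into_tensor(gathered, fused, group=group,
                                         async_op=True)

    # Work that does not depend on the gathered KV overlaps with the
    # communication (q projection + rope in the PR).
    ql_nope, q_pe = compute_q(q_c)

    # Block only when the gathered KV is actually needed, then split it
    # back into its three parts.
    handle.wait()
    k_pe_full, k_nope_full, index_k_full = gathered.split(
        [k_pe.shape[-1], k_nope.shape[-1], index_k.shape[-1]], dim=-1)
    return ql_nope, q_pe, k_pe_full, k_nope_full, index_k_full
```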
### Benchmark (input=40k, num_batch_token=20k)
- before:
```
Mean TTFT (ms):   2614.52
Median TTFT (ms): 3148.03
P50 TTFT (ms):    3148.03
P90 TTFT (ms):    3163.48
P99 TTFT (ms):    3170.20
```
- after:
```
Mean TTFT (ms):   2529.92
Median TTFT (ms): 3051.69
P50 TTFT (ms):    3051.69
P90 TTFT (ms):    3067.31
P99 TTFT (ms):    3072.15
```

### Does this PR introduce _any_ user-facing change?
None

### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d

---------

Signed-off-by: zzhx1
---
 vllm_ascend/attention/sfa_v1.py  | 145 +++++++++++++++++++++----------
 vllm_ascend/distributed/utils.py |  21 ++++-
 vllm_ascend/utils.py             |  19 ++--
 3 files changed, 124 insertions(+), 61 deletions(-)

diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index 7673322c..3810c0be 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -24,6 +24,7 @@ from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                          trans_rope_weight, transdata,
                                          wait_for_kv_layer_from_connector)
+from vllm_ascend.distributed.utils import all_gather_async
 from vllm_ascend.ops.layer_shard_linear import (
     is_hidden_layer, post_process_after_loading_for_shard_weight_series,
     reach_layer_for_shard_weight_series,
@@ -227,7 +228,10 @@ class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]):
 
             cos = cos[local_start:local_end_with_pad]
             sin = sin[local_start:local_end_with_pad]
-            slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
+            slot_mapping_cp = torch.full(size=(num_tokens_per_device, ),
+                                         fill_value=-1,
+                                         dtype=slot_mapping.dtype,
+                                         device=slot_mapping.device)
             assert cos.shape[0] == num_tokens_per_device, \
                 f"cos.shape[0] must be equal to num_tokens_per_device, \
                 got {cos.shape[0]} and {num_tokens_per_device}"
@@ -505,7 +509,6 @@ class AscendSFAImpl(MLAAttentionImpl):
         sin: torch.Tensor,
         kv_cache: Tuple,
         slots: torch.Tensor,
-        slots_cp: Optional[torch.Tensor],
     ):
         B = kv_no_split.shape[0]
         N = self.num_kv_heads
@@ -516,30 +519,19 @@ class AscendSFAImpl(MLAAttentionImpl):
         cache_mode = "PA"
 
         if self.enable_sfa_cp:
-            assert slots_cp is not None
             _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache(
                 kv_no_split,
                 self.kv_a_layernorm.weight,
                 cos,
                 sin,
-                slots_cp.to(torch.int64),
+                slots.to(torch.int64),
                 kv_cache[1],
                 kv_cache[0],
                 epsilon=self.kv_a_layernorm.variance_epsilon,
                 cache_mode=cache_mode,
                 is_output_kv=True,
             )
-            # TODO: Temporarily adapt SFA-CP and replace it later with PCP. --clrs97
-            k_pe = get_tp_group().all_gather(k_pe, 0)
-            k_nope = get_tp_group().all_gather(k_nope, 0)
-
-            if kv_cache is not None:
-                torch_npu.npu_scatter_nd_update_(
-                    kv_cache[0].view(-1, k_nope.shape[-1]), slots.view(-1, 1),
-                    k_nope.view(-1, k_nope.shape[-1]))
-                torch_npu.npu_scatter_nd_update_(
-                    kv_cache[1].view(-1, k_pe.shape[-1]), slots.view(-1, 1),
-                    k_pe.view(-1, k_pe.shape[-1]))
+            return k_pe, k_nope
         else:
             torch_npu.npu_kv_rmsnorm_rope_cache(
                 kv_no_split,
@@ -552,6 +544,7 @@ class AscendSFAImpl(MLAAttentionImpl):
             epsilon=self.kv_a_layernorm.variance_epsilon,
             cache_mode=cache_mode,
         )
+        return None, None
 
     def rope_single(
         self,
@@ -744,6 +737,7 @@ class AscendSFAImpl(MLAAttentionImpl):
                 if is_hidden_layer(layer):
                     reach_layer_for_shard_weight_series(layer)
             return output.fill_(0)
+
         has_prefill = attn_metadata.has_prefill
         cos = attn_metadata.cos
         sin = attn_metadata.sin
@@ -764,6 +758,12 @@ class AscendSFAImpl(MLAAttentionImpl):
                 need_gather_q_kv=need_gather_q_kv,
                 num_input_tokens=attn_metadata.num_input_tokens,
             )
+            q, k = self.indexer_select_pre_process(
+                x=hidden_states,
+                qr=q_c,
+                cos=cos,
+                sin=sin,
+                need_gather_q_kv=need_gather_q_kv)
         else:
             assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
             maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
@@ -782,31 +782,67 @@ class AscendSFAImpl(MLAAttentionImpl):
             kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
                 kv_no_split.contiguous(), need_gather_q_kv)
 
+            q, k = self.indexer_select_pre_process(
+                x=hidden_states,
+                qr=q_c,
+                cos=cos,
+                sin=sin,
+                need_gather_q_kv=need_gather_q_kv)
+
         if has_prefill:
             wait_for_kv_layer_from_connector(layer_name)
 
         slot_mapping = attn_metadata.slot_mapping
-        slot_mapping_cp = None
         if self.enable_sfa_cp:
             assert attn_metadata.sfa_cp_context is not None
-            slot_mapping_cp = attn_metadata.sfa_cp_context.slot_mapping_cp
+            slot_mapping = attn_metadata.sfa_cp_context.slot_mapping_cp
             actual_seq_lengths_query = attn_metadata.sfa_cp_context.actual_seq_lengths_query
             actual_seq_lengths_key = attn_metadata.sfa_cp_context.actual_seq_lengths_key
 
-        self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping,
-                     slot_mapping_cp)
+        k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache,
+                                    slot_mapping)
 
-        if self.enable_sfa_cp and attn_metadata.sfa_cp_context is not None:
-            for layer in (self.layer_sharding_kwargs or []):
-                if is_hidden_layer(layer):
-                    reach_layer_for_shard_weight_series(layer)
+        if self.enable_sfa_cp:
+            assert k_pe is not None
+            assert k_nope is not None
+            # all-gather the fused kv asynchronously so communication overlaps with computation
+            fused_kv_no_split, kv_ag_handle = all_gather_async(
+                torch.cat([
+                    k_pe.view(-1, k_pe.shape[-1]),
+                    k_nope.view(-1, k_nope.shape[-1]),
+                    k.view(-1, k.shape[-1])
+                ],
+                          dim=1), get_tp_group())
 
         ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
         q_pe = self.rope_single(q_pe, cos, sin)
-        topk_indices = self.indexer_select(
+
+        if self.enable_sfa_cp:
+            if kv_ag_handle is not None:
+                kv_ag_handle.wait()
+            for layer in (self.layer_sharding_kwargs or []):
+                if is_hidden_layer(layer):
+                    reach_layer_for_shard_weight_series(layer)
+
+            if kv_cache is not None:
+                assert fused_kv_no_split is not None
+                k_pe, k_nope, k = fused_kv_no_split.split([
+                    self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim
+                ],
+                                                          dim=-1)
+                slot_mapping = attn_metadata.slot_mapping.view(-1, 1)
+                torch_npu.npu_scatter_nd_update_(
+                    kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping,
+                    k_nope)
+                torch_npu.npu_scatter_nd_update_(
+                    kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping,
+                    k_pe)
+
+        topk_indices = self.indexer_select_post_process(
             x=hidden_states,
             qr=q_c,
+            q=q,
+            k=k,
             kv_cache=kv_cache,
             attn_metadata=attn_metadata,
             cos=cos,
@@ -814,6 +850,7 @@ class AscendSFAImpl(MLAAttentionImpl):
             actual_seq_lengths_query=actual_seq_lengths_query,
             actual_seq_lengths_key=actual_seq_lengths_key,
             need_gather_q_kv=need_gather_q_kv)
+
         attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
             query=ql_nope,
             key=kv_cache[0],
@@ -830,6 +867,7 @@ class AscendSFAImpl(MLAAttentionImpl):
             layout_kv="PA_BSND",
             sparse_mode=3,
         )
+
         attn_output = self._v_up_proj(attn_output, has_prefill)
         maybe_npu_prefetch(inputs=self.o_proj.weight,
                            dependency=attn_output,
@@ -838,22 +876,14 @@ class AscendSFAImpl(MLAAttentionImpl):
         output[...] = self.o_proj(attn_output)[0]
         return output_padded
 
-    def indexer_select(
+    def indexer_select_pre_process(
         self,
         x: torch.Tensor,
         qr: torch.Tensor,
-        kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-        attn_metadata: M,
         cos: torch.Tensor,
         sin: torch.Tensor,
-        actual_seq_lengths_query: torch.Tensor,
-        actual_seq_lengths_key: torch.Tensor,
         need_gather_q_kv: bool = False,
     ):
-        # q process in new stream
-        q, _ = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
-        q = q.view(-1, self.n_head, self.head_dim)  # [n_toks,64,128]
-
         k_proj, _ = self.wk(x)  # [b,s,7168] @ [7168,128] = [b,s,128]
         k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
             k_proj, need_gather_q_kv)
@@ -861,6 +891,9 @@ class AscendSFAImpl(MLAAttentionImpl):
         k = k.view(-1, 1, self.head_dim)
 
         if HAS_TRITON:
+            q, _ = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
+            q = q.view(-1, self.n_head, self.head_dim)  # [n_toks,64,128]
+
             cos = cos.view(-1, self.qk_rope_head_dim)
             sin = sin.view(-1, self.qk_rope_head_dim)
             q, k = rope_forward_triton(q,
@@ -870,6 +903,38 @@ class AscendSFAImpl(MLAAttentionImpl):
                                        rope_dim=self.qk_rope_head_dim,
                                        is_neox_style=True)
         else:
+            k_pe, k_nope = torch.split(
+                k,
+                [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
+                dim=-1)  # [b,s,64+64]
+
+            k_pe = k_pe.unsqueeze(2)
+            k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
+            k_pe = k_pe.squeeze(2)
+
+            k = torch.cat([k_pe, k_nope], dim=-1)  # [b*s,128]
+            q = None
+
+        return q, k
+
+    def indexer_select_post_process(
+        self,
+        x: torch.Tensor,
+        qr: torch.Tensor,
+        q: Optional[torch.Tensor],
+        k: torch.Tensor,
+        kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+        attn_metadata: M,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        actual_seq_lengths_query: torch.Tensor,
+        actual_seq_lengths_key: torch.Tensor,
+        need_gather_q_kv: bool = False,
+    ):
+        if q is None:
+            q, _ = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
+            q = q.view(-1, self.n_head, self.head_dim)  # [n_toks,64,128]
+
             cos_q, sin_q = cos, sin
             cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
             sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
@@ -884,20 +949,6 @@ class AscendSFAImpl(MLAAttentionImpl):
             q_pe = q_pe.squeeze(2)
             q = torch.cat([q_pe, q_nope], dim=-1)  # [b*s,64,128]
 
-            k_pe, k_nope = torch.split(
-                k,
-                [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
-                dim=-1)  # [b,s,64+64]
-
-            k_pe = k_pe.unsqueeze(2)
-            k_pe = torch_npu.npu_rotary_mul(k_pe, cos, sin)
-            k_pe = k_pe.squeeze(2)
-
-            k = torch.cat([k_pe, k_nope], dim=-1)  # [b*s,128]
-
-        if self.enable_sfa_cp:
-            k = get_tp_group().all_gather(k, 0)
-
         if kv_cache is not None:
             if self.is_kv_producer:
                 attn_metadata.reshape_cache_event = torch.npu.Event()
diff --git a/vllm_ascend/distributed/utils.py b/vllm_ascend/distributed/utils.py
index 70c57d28..3a624de2 100644
--- a/vllm_ascend/distributed/utils.py
+++ b/vllm_ascend/distributed/utils.py
@@ -1,8 +1,9 @@
 import os
+from typing import Optional
 
 import torch
 import torch.distributed as dist
-from vllm.distributed.parallel_state import get_dp_group
+from vllm.distributed.parallel_state import GroupCoordinator, get_dp_group
 from vllm.forward_context import get_forward_context
 
 from vllm_ascend.distributed.parallel_state import (get_fc3_quant_x_group,
@@ -90,3 +91,21 @@ def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
         offset += num_tokens_dp
     x = result
     return x
+
+
+def all_gather_async(input: torch.Tensor,
+                     group: GroupCoordinator,
+                     output: Optional[torch.Tensor] = None,
+                     async_op: bool = True):
+    if group.world_size == 1:
+        return input, None
+    if output is None:
+        input_size = input.size()
+        output_size = (input_size[0] * group.world_size, ) + input_size[1:]
+        output = torch.empty(output_size,
+                             dtype=input.dtype,
+                             device=input.device)
+    return output, dist.all_gather_into_tensor(output,
+                                               input,
+                                               group=group.device_group,
+                                               async_op=async_op)
\ No newline at end of file
diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py
index 0000d696..8f8fb0c1 100644
--- a/vllm_ascend/utils.py
+++ b/vllm_ascend/utils.py
@@ -1172,20 +1172,14 @@ def singleton(cls):
 
 @lru_cache(maxsize=1)
 def enable_dsa_cp() -> bool:
     from vllm.config import get_current_vllm_config
     vllm_config = get_current_vllm_config()
-    if vllm_config is None:
-        return False
-
-    model_config = getattr(vllm_config, "model_config", None)
-    if model_config is None:
-        return False
-
-    hf_text_config = getattr(model_config, "hf_text_config", None)
-    if hf_text_config is None:
-        return False
-
-    return hasattr(hf_text_config, "index_topk")
+    is_ds_v32 = hasattr(
+        vllm_config.model_config, "hf_text_config") and hasattr(
+            vllm_config.model_config.hf_text_config, "index_topk")
+    if is_ds_v32 and enable_sp():
+        return True
+    return False
 
 
 @lru_cache(maxsize=1)