From 2bd9c357886f04eca1ce5f9a00fe09aa77570479 Mon Sep 17 00:00:00 2001 From: rjg-lyh <83491835+rjg-lyh@users.noreply.github.com> Date: Thu, 5 Mar 2026 14:27:11 +0800 Subject: [PATCH] [perf][refactor] Refactor and optimize sfa_v1.py for dsv3.2/glm5 (#6874) ### What this PR does / why we need it? This PR refactors sfa_v1.py to improve code readability and usability, fixes a code bug, and enhances performance through the replacement of certain operators. ### changes - **improve code readability**: Optimizes parts of the code structure in sfa_v1.py, adds supplementary comments for key code blocks, removes some unused variables, and improves the naming of certain functions and variables. - **resolved a duplicated write to k_cache**: Fixed redundant double writes of k_cache in the indexer_select module (in both the `forward` function and `indexer_select_post_process`), improving performance to some extent. - **replace `scatter` ops with `reshape_and_cache`**: This optimization replaces two separate cache storage operations on `k_nope` and `k_pe` with a single call to the `reshape_and_cache` operator, improving performance. The original `scatter` operator involves reordering slot_mapping for generality, introducing significant scalar computations. In contrast, the `reshape_and_cache` operator eliminates this redundant reordering step, thus reducing unnecessary computation time and enhancing the operator's performance. ### performance comparison 4*A3, 1P1D, P dp2tp16, D dp8tp4, input/output: 64K/3K origin: TTFT: **28s**, TPOT: 26ms, TPS: **820 token/s** fixed redundant double writes of k_cache: TTFT: **24s**, TPOT: 26ms, TPS: **840 token/s** replace scatter ops with reshape_and_cache: TTFT: **24s**, TPOT: 26ms, TPS: **850 token/s** ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed with newly added/existing tests. 
- vLLM version: v0.16.0 - vLLM main: https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7 --------- Signed-off-by: rjg-lyh <1318825571@qq.com> --- .../attention/context_parallel/sfa_cp.py | 43 +- vllm_ascend/attention/sfa_v1.py | 981 +++++++++--------- vllm_ascend/ops/triton/rope.py | 153 +++ vllm_ascend/utils.py | 14 + 4 files changed, 676 insertions(+), 515 deletions(-) diff --git a/vllm_ascend/attention/context_parallel/sfa_cp.py b/vllm_ascend/attention/context_parallel/sfa_cp.py index 73366682..2b6f361d 100644 --- a/vllm_ascend/attention/context_parallel/sfa_cp.py +++ b/vllm_ascend/attention/context_parallel/sfa_cp.py @@ -5,10 +5,12 @@ import torch import torch_npu from vllm.config import VllmConfig from vllm.distributed import get_dcp_group, get_pcp_group +from vllm.triton_utils import HAS_TRITON from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata from vllm_ascend.attention.sfa_v1 import AscendSFAImpl, AscendSFAMetadata, AscendSFAMetadataBuilder from vllm_ascend.attention.utils import AscendCommonAttentionMetadata, enabling_mlapo, split_decodes_and_prefills +from vllm_ascend.ops.triton.rope import rope_forward_triton_siso M = TypeVar("M", bound=AscendSFAMetadata) @@ -299,42 +301,33 @@ class AscendSFACPImpl(AscendSFAImpl): def indexer_select_post_process( self, x: torch.Tensor, - qr: torch.Tensor, - q: torch.Tensor | None, - k: torch.Tensor, + q_c: torch.Tensor, kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], attn_metadata: M, cos: torch.Tensor, sin: torch.Tensor, actual_seq_lengths_query: torch.Tensor, actual_seq_lengths_key: torch.Tensor, - need_gather_q_kv: bool = False, ): - if q is None: - q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] - q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] - cos_q, sin_q = cos, sin + weights, _ = self.weights_proj(x) - q_pe, q_nope = torch.split( - q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 + q_li, 
_ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] + if HAS_TRITON: + q_li = rope_forward_triton_siso( + q_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style + ) + else: + q_li_pe, q_li_nope = torch.split( + q_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 ) # [b,s,64,64+64] - q_pe = q_pe.unsqueeze(2) - q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q) - q_pe = q_pe.squeeze(2) - q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] + q_li_pe = q_li_pe.unsqueeze(2) + q_li_pe = torch_npu.npu_rotary_mul(q_li_pe, cos, sin) + q_li_pe = q_li_pe.squeeze(2) + q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128] - if kv_cache is not None: - if self.is_kv_producer: - attn_metadata.reshape_cache_event = torch.npu.Event() - torch_npu.npu_scatter_nd_update_( - kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1]) - ) # b, s, n, d - if self.is_kv_producer: - attn_metadata.reshape_cache_event.record() - - weights, _ = self.weights_proj(x) - weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv) + q = q_li key = kv_cache[2] assert attn_metadata.sfa_cp_metadata is not None diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index 64d21ef2..22def96c 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -30,6 +30,7 @@ from vllm_ascend.attention.utils import ( transdata, wait_for_kv_layer_from_connector, ) +from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.distributed.utils import all_gather_async from vllm_ascend.ops.layer_shard_linear import ( is_hidden_layer, @@ -38,7 +39,7 @@ from vllm_ascend.ops.layer_shard_linear import ( register_all_layers_to_shard_weight_series, ) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla -from vllm_ascend.ops.triton.rope import 
rope_forward_triton +from vllm_ascend.ops.triton.rope import rope_forward_triton_siso from vllm_ascend.quantization.methods import AscendW8A8LinearMethod from vllm_ascend.utils import ( ACL_FORMAT_FRACTAL_ND, @@ -46,6 +47,7 @@ from vllm_ascend.utils import ( dispose_layer, enable_dsa_cp, enable_dsa_cp_with_layer_shard, + enable_dsa_cp_with_o_proj_tp, get_weight_prefetch_method, maybe_trans_nz, ) @@ -393,8 +395,8 @@ class AscendSFAImpl(MLAAttentionImpl): ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - # In sfa, prefill and decode have the same calculation formula, - # so do not distinguish between prefill and decode here. + # The MLAPO operator fuses the pre-processing steps on Q/K/V in MLA into a single operator + # NOTE: it imposes a limit on the number of input tokens and conflicts with FlashComm self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO assert self.indexer is not None, "Indexer is required for DSA." @@ -419,21 +421,29 @@ class AscendSFAImpl(MLAAttentionImpl): self.is_rope_neox_style = False self.use_torch_npu_lightning_indexer = True + # Effective in SFA when FlashComm is enabled. self.enable_dsa_cp = enable_dsa_cp() - self.enable_dsa_cp_prefill_only = enable_dsa_cp_with_layer_shard() + + # Enable layer sharding via DSA-CP on the P node in the PD-disaggregated setup. + self.enable_dsa_cp_with_layer_shard = enable_dsa_cp_with_layer_shard() + + # use original TP o_proj weight in PD mix stage, and full gather + # for o_proj weight for prefill stage. 
+ self.enable_dsa_cp_with_o_proj_tp = enable_dsa_cp_with_o_proj_tp() + if self.enable_dsa_cp: self.local_num_heads = self.num_heads * self.tp_size - if self.enable_dsa_cp_prefill_only: - self.layer_sharding_kwargs = [] - for layer_name in get_ascend_config().layer_sharding or []: - if layer_name in kwargs: - self.layer_sharding_kwargs.append(kwargs[layer_name]) - else: - logger.warning_once( - f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, " - "skipping sharding configuration" - ) - register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs) + if self.enable_dsa_cp_with_layer_shard: + self.layer_sharding_kwargs = [] + for layer_name in get_ascend_config().layer_sharding or []: + if layer_name in kwargs: + self.layer_sharding_kwargs.append(kwargs[layer_name]) + else: + logger.warning_once( + f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, " + "skipping sharding configuration" + ) + register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs) def process_weights_after_loading(self, act_dtype: torch.dtype): # NOTE: We currently do not support quant kv_b_proj. 
@@ -469,7 +479,7 @@ class AscendSFAImpl(MLAAttentionImpl): # Dispose kv_b_proj since it is replaced by W_UV and W_UK_T to save memory dispose_layer(self.kv_b_proj) if self.enable_dsa_cp: - if self.enable_dsa_cp_prefill_only: + if self.enable_dsa_cp_with_layer_shard: for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): post_process_after_loading_for_shard_weight_series(layer) @@ -501,100 +511,6 @@ class AscendSFAImpl(MLAAttentionImpl): # if mlapo, W_UK_T can't trans nz self.W_UK_T = maybe_trans_nz(self.W_UK_T) - def _v_up_proj(self, x): - num_input_tokens, _, _ = x.shape - if ( - x.dtype in [torch.float16, torch.bfloat16] - and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") - and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS - ): - x = x.view(-1, self.local_num_heads, self.kv_lora_rank) - res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim), dtype=x.dtype, device=x.device) - torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res) - x = res.reshape(-1, self.local_num_heads * self.v_head_dim) - else: - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.local_num_heads, self.kv_lora_rank).transpose(0, 1) - # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim) - return x - - # Return `ql_nope`, `q_pe` - def _q_proj_and_k_up_proj(self, x): - q_nope, q_pe = ( - self.q_proj(x)[0] - .view(-1, self.local_num_heads, self.qk_head_dim) - .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - ) - - # Convert from (B, N, P) to (N, B, P) - q_nope = q_nope.transpose(0, 1) - # Multiply (N, B, P) x (N, P, L) -> (N, B, L) - ql_nope = torch.bmm(q_nope, self.W_UK_T) - # Convert from (N, B, L) to (B, N, L) - return ql_nope.transpose(0, 1), q_pe - - def _get_full_kv(self, k, attn_metadata): - return k - - def exec_kv( - self, - kv_no_split: torch.Tensor, - cos: 
torch.Tensor, - sin: torch.Tensor, - kv_cache: tuple, - slots: torch.Tensor, - attn_metadata: M, - ): - B = kv_no_split.shape[0] - N = self.num_kv_heads - S = 1 - # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] - kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) - cache_mode = "PA" - - if self.enable_dsa_cp: - _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( - kv_no_split, - self.kv_a_layernorm.weight, # type: ignore[union-attr] - cos, - sin, - slots.to(torch.int64), - kv_cache[1], - kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] - cache_mode=cache_mode, - is_output_kv=True, - ) - return k_pe, k_nope - else: - torch_npu.npu_kv_rmsnorm_rope_cache( - kv_no_split, - self.kv_a_layernorm.weight, # type: ignore[union-attr] - cos, - sin, - slots.to(torch.int64), - kv_cache[1], - kv_cache[0], - epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] - cache_mode=cache_mode, - ) - return None, None - - def rope_single( - self, - x: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - ) -> torch.Tensor: - B, N, D = x.shape - S = 1 - x = x.view(B, N, S, D) - x = torch_npu.npu_interleave_rope(x, cos, sin) - return x.view(B, N, D) - # Processing the input parameters for MLAPO by reordering and transposing # QKV(and part of Q) weight, applying RoPE-related dimension transformations, # and handling quantization parameters. 
@@ -672,375 +588,38 @@ class AscendSFAImpl(MLAAttentionImpl): self.q_proj.quant_bias = None torch.npu.empty_cache() - def _sfa_preprocess_decode( + def forward_mha( self, - hidden_states: torch.Tensor, - kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + q: torch.Tensor, + kv_c_normed: torch.Tensor, + k_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, - need_gather_q_kv: bool, - num_input_tokens: int, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states.contiguous(), need_gather_q_kv) - k_nope, k_pe = kv_cache[0], kv_cache[1] - ql_nope = torch.empty( - (num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - q_pe = torch.empty( - (num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - q_c = torch.empty( - (num_input_tokens, self.q_lora_rank), - dtype=hidden_states.dtype, - device=hidden_states.device, - ) - torch.ops._C_ascend.mla_preprocess( - hidden_states, - self.wd_qkv, - self.deq_scale_qkv, - self.gamma1, - self.beta1, - self.wu_q, - self.qb_deq_scl, - self.gamma2, - attn_metadata.cos, - attn_metadata.sin, - self.W_UK_T, - k_nope, - k_pe, - attn_metadata.slot_mapping, - quant_scale0=self.quant_scale0, - quant_offset0=self.quant_offset0, - bias0=self.quant_bias_qkv, - quant_scale1=self.quant_scale1, - quant_offset1=self.quant_offset1, - bias1=self.qb_qt_bias, - ctkv_scale=self.ctkv_scale, - q_nope_scale=self.q_nope_scale, - cache_mode="krope_ctkv", - quant_mode="per_tensor_quant_asymm", - enable_inner_out=True, - q_out0=ql_nope, - kv_cache_out0=k_nope, - q_out1=q_pe, - kv_cache_out1=k_pe, - inner_out=q_c, - ) - return hidden_states, ql_nope, q_pe, q_c + k_scale: torch.Tensor, + output: torch.Tensor, + ) -> None: + raise NotImplementedError("forward_mha is not supported for SFA attention. 
Use forward() instead.") - def forward( + def forward_mqa( self, - layer_name, - hidden_states: torch.Tensor, # query in unified attn - kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], + kv_c_and_k_pe_cache: torch.Tensor, attn_metadata: M, - need_gather_q_kv: bool = False, - output: torch.Tensor | None = None, + layer, + ) -> tuple[torch.Tensor, torch.Tensor | None]: + raise NotImplementedError("forward_mqa is not supported for SFA attention. Use forward() instead.") + + def rope_single( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, ) -> torch.Tensor: - assert output is not None, "Output tensor must be provided." - forward_context = get_forward_context() - if attn_metadata is None: - # Profiling run. - if self.enable_dsa_cp_prefill_only and not forward_context.in_profile_run: - for layer in self.layer_sharding_kwargs or []: - if is_hidden_layer(layer): - reach_layer_for_shard_weight_series(layer) - return output.fill_(0) - - cos = attn_metadata.cos - sin = attn_metadata.sin - actual_seq_lengths_query = attn_metadata.cum_query_lens - actual_seq_lengths_key = attn_metadata.seq_lens - if self.enable_dsa_cp: - need_gather_q_kv = False - # Inputs and outputs may be padded for CUDA graphs - num_input_tokens = attn_metadata.num_input_tokens - output_padded = output - - # all-gather o_proj weight for prefill stage of PD mix node - o_proj_full_handle = None - # if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj - # weight for prefill stage. 
- should_shard_weight = self.enable_dsa_cp_prefill_only or attn_metadata.attn_state not in { - AscendAttentionState.DecodeOnly, - AscendAttentionState.SpecDecoding, - } - - if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: - hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_decode( - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - need_gather_q_kv=need_gather_q_kv, - num_input_tokens=num_input_tokens, - ) - q, k = self.indexer_select_pre_process( - x=hidden_states, qr=q_c, cos=cos, sin=sin, need_gather_q_kv=need_gather_q_kv - ) - else: - assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." - weight_prefetch_method = get_weight_prefetch_method() - weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( - inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states - ) - qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] - q_c, kv_no_split = qkv_lora.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized" - q_c = self.q_a_layernorm(q_c) - # Process for Flash Comm V1 - if need_gather_q_kv: - q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(q_c.contiguous(), need_gather_q_kv) - kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - kv_no_split.contiguous(), need_gather_q_kv - ) - - q, k = self.indexer_select_pre_process( - x=hidden_states, qr=q_c, cos=cos, sin=sin, need_gather_q_kv=need_gather_q_kv - ) - - wait_for_kv_layer_from_connector(layer_name) - - slot_mapping = attn_metadata.slot_mapping - if self.enable_dsa_cp: - assert attn_metadata.dsa_cp_context is not None - slot_mapping = attn_metadata.dsa_cp_context.slot_mapping_cp - actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query - actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key - - k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, 
kv_cache, slot_mapping, attn_metadata) - - if self.enable_dsa_cp: - assert k_pe is not None - assert k_nope is not None - # support all_gather kv async for communication calculation overlap - fused_kv_no_split, kv_ag_handle = all_gather_async( - torch.cat( - [k_pe.view(-1, k_pe.shape[-1]), k_nope.view(-1, k_nope.shape[-1]), k.view(-1, k.shape[-1])], - dim=1, - ), - get_tp_group(), - async_op=should_shard_weight, - ) - - ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) - q_pe = self.rope_single(q_pe, cos, sin) - - if self.enable_dsa_cp: - if kv_ag_handle is not None: - kv_ag_handle.wait() - - if self.enable_dsa_cp_prefill_only: - for layer in self.layer_sharding_kwargs or []: - if is_hidden_layer(layer): - reach_layer_for_shard_weight_series(layer) - elif should_shard_weight: - _, o_proj_full_handle = all_gather_async( - self.o_proj_tp_weight, get_tp_group(), output=AscendSFAImpl.o_proj_full_pool - ) - - if kv_cache is not None: - assert fused_kv_no_split is not None - k_pe, k_nope, k = fused_kv_no_split.split( - [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1 - ) - slot_mapping = attn_metadata.slot_mapping.view(-1, 1) - torch_npu.npu_scatter_nd_update_(kv_cache[0].view(-1, k_nope.shape[-1]), slot_mapping, k_nope) - torch_npu.npu_scatter_nd_update_(kv_cache[1].view(-1, k_pe.shape[-1]), slot_mapping, k_pe) - - k = self._get_full_kv(k, attn_metadata) - if kv_cache is not None: - torch_npu.npu_scatter_nd_update_( - kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, k.shape[-1]) - ) # b, s, n, d - - topk_indices = self.indexer_select_post_process( - x=hidden_states, - qr=q_c, - q=q, - k=k, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - cos=cos, - sin=sin, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_key=actual_seq_lengths_key, - need_gather_q_kv=need_gather_q_kv, - ) - - attn_output = self._execute_sparse_flash_attention_process( - ql_nope, q_pe, kv_cache, topk_indices, 
attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key - ) - - attn_output = self._v_up_proj(attn_output) - weight_prefetch_method = get_weight_prefetch_method() - weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( - inputs=self.o_proj.weight, - dependency=attn_output, - max_size=MAX_O_PROJ_PREFETCH_SIZE, - linear_layer=self.o_proj, - ) - - if self.enable_dsa_cp and not self.enable_dsa_cp_prefill_only: - # When using SFA-CP with pd mixed, o_proj has two cases: - # 1. prefill: o_proj is a TP weight, we need to all-gather o_proj weight to switch TP=1. - # 2. decode: all-to-all the hidden_state before the o_proj forward. - result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward( - attn_output=attn_output, - output=output, - o_proj_full_handle=o_proj_full_handle, - should_shard_weight=should_shard_weight, - ) - if not require_o_proj_forward: - return result - attn_output = result - - output[...] = self.o_proj(attn_output)[0] - - maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) - - return output_padded - - def _execute_sparse_flash_attention_process( - self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key - ): - block_table = attn_metadata.block_table - kv = kv_cache[0] - key_rope = kv_cache[1] - - attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( - query=ql_nope, - key=kv, - value=kv, - sparse_indices=topk_indices, - scale_value=self.scale, - sparse_block_size=1, - block_table=block_table, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_kv=actual_seq_lengths_key, - query_rope=q_pe, - key_rope=key_rope, - layout_query="TND", - layout_kv="PA_BSND", - sparse_mode=3, - ) - return attn_output - - def indexer_select_pre_process( - self, - x: torch.Tensor, - qr: torch.Tensor, - cos: torch.Tensor, - sin: torch.Tensor, - need_gather_q_kv: bool = False, - ): - k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] 
- k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(k_proj, need_gather_q_kv) - k = self.k_norm(k_proj).unsqueeze(1) - k = k.view(-1, 1, self.head_dim) - - if HAS_TRITON: - q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] - q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] - - cos = cos.view(-1, self.qk_rope_head_dim) - sin = sin.view(-1, self.qk_rope_head_dim) - q, k = rope_forward_triton( - q, k, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style - ) - else: - k_pe, k_nope = torch.split(k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1) - - cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) - sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) - - k_pe = k_pe.unsqueeze(2) - k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) - k_pe = k_pe.squeeze(2) - - k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] - q = None - - return q, k - - def indexer_select_post_process( - self, - x: torch.Tensor, - qr: torch.Tensor, - q: torch.Tensor | None, - k: torch.Tensor, - kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], - attn_metadata: M, - cos: torch.Tensor, - sin: torch.Tensor, - actual_seq_lengths_query: torch.Tensor, - actual_seq_lengths_key: torch.Tensor, - need_gather_q_kv: bool = False, - ): - if q is None: - q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] - q = q.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] - cos_q, sin_q = cos, sin - - q_pe, q_nope = torch.split( - q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 - ) # [b,s,64,64+64] - - q_pe = q_pe.unsqueeze(2) - q_pe = torch_npu.npu_rotary_mul(q_pe, cos_q, sin_q) - q_pe = q_pe.squeeze(2) - q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] - - if kv_cache is not None: - if self.is_kv_producer: - attn_metadata.reshape_cache_event = torch.npu.Event() - torch_npu.npu_scatter_nd_update_( - kv_cache[2].view(-1, k.shape[-1]), attn_metadata.slot_mapping.view(-1, 1), k.view(-1, 
k.shape[-1]) - ) # b, s, n, d - if self.is_kv_producer: - attn_metadata.reshape_cache_event.record() - - weights, _ = self.weights_proj(x) - weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(weights, need_gather_q_kv) - - key = kv_cache[2] - block_table = attn_metadata.block_table - - # DSV3.2 currently has graph compilation issues when using torch_npu.npu.lightning_indexer. - # So two branches are maintained temporarily. - # TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed. - if self.use_torch_npu_lightning_indexer: - topk_indices, _ = torch_npu.npu_lightning_indexer( - query=q, - key=key, - weights=weights, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_key=actual_seq_lengths_key, - block_table=block_table, - layout_query="TND", - layout_key="PA_BSND", - sparse_count=2048, - sparse_mode=3, - ) - else: - topk_indices = torch.ops._C_ascend.npu_lightning_indexer( - query=q, - key=key, - weights=weights, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_key=actual_seq_lengths_key, - block_table=block_table, - layout_query="TND", - layout_key="PA_BSND", - sparse_count=2048, - sparse_mode=3, - ) - return topk_indices + B, N, D = x.shape + S = 1 + x = x.view(B, N, S, D) + x = torch_npu.npu_interleave_rope(x, cos, sin) + return x.view(B, N, D) def _init_o_proj_tp_full_params(self): """ @@ -1120,23 +699,445 @@ class AscendSFAImpl(MLAAttentionImpl): return attn_output, True - def forward_mha( - self, - q: torch.Tensor, - kv_c_normed: torch.Tensor, - k_pe: torch.Tensor, - kv_c_and_k_pe_cache: torch.Tensor, - attn_metadata: M, - k_scale: torch.Tensor, - output: torch.Tensor, - ) -> None: - raise NotImplementedError("forward_mha is not supported for SFA attention. 
Use forward() instead.") + def _get_full_kv(self, k, attn_metadata): + return k - def forward_mqa( + def exec_kv( self, - q: torch.Tensor | tuple[torch.Tensor, torch.Tensor], - kv_c_and_k_pe_cache: torch.Tensor, + kv_no_split: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + kv_cache: tuple, + slots: torch.Tensor, attn_metadata: M, - layer, - ) -> tuple[torch.Tensor, torch.Tensor | None]: - raise NotImplementedError("forward_mqa is not supported for SFA attention. Use forward() instead.") + ): + B = kv_no_split.shape[0] + N = self.num_kv_heads + S = 1 + # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] + kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) + cache_mode = "PA" + + if self.enable_dsa_cp: + _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( + kv_no_split, + self.kv_a_layernorm.weight, # type: ignore[union-attr] + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] + cache_mode=cache_mode, + is_output_kv=True, + ) + return k_pe, k_nope + else: + torch_npu.npu_kv_rmsnorm_rope_cache( + kv_no_split, + self.kv_a_layernorm.weight, # type: ignore[union-attr] + cos, + sin, + slots.to(torch.int64), + kv_cache[1], + kv_cache[0], + epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] + cache_mode=cache_mode, + ) + return None, None + + # Return `ql_nope`, `q_pe` + def _q_proj_and_k_up_proj(self, x): + q_nope, q_pe = ( + self.q_proj(x)[0] + .view(-1, self.local_num_heads, self.qk_head_dim) + .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + ) + + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + return ql_nope.transpose(0, 1), q_pe + + def _v_up_proj(self, x): + num_input_tokens, _, _ = x.shape + if ( + x.dtype in [torch.float16, torch.bfloat16] 
+ and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") + and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS + ): + x = x.view(-1, self.local_num_heads, self.kv_lora_rank) + res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim), dtype=x.dtype, device=x.device) + torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res) + x = res.reshape(-1, self.local_num_heads * self.v_head_dim) + else: + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.local_num_heads, self.kv_lora_rank).transpose(0, 1) + # # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim) + return x + + def _sfa_preprocess_with_mlapo( + self, + hidden_states: torch.Tensor, + kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + cos: torch.Tensor, + sin: torch.Tensor, + slot_mapping: torch.Tensor, + num_input_tokens: int, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + k_nope, k_pe = kv_cache[0], kv_cache[1] + ql_nope = torch.empty( + (num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + q_pe = torch.empty( + (num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + q_c = torch.empty( + (num_input_tokens, self.q_lora_rank), + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + torch.ops._C_ascend.mla_preprocess( + hidden_states, + self.wd_qkv, + self.deq_scale_qkv, + self.gamma1, + self.beta1, + self.wu_q, + self.qb_deq_scl, + self.gamma2, + cos, + sin, + self.W_UK_T, + k_nope, + k_pe, + slot_mapping, + quant_scale0=self.quant_scale0, + quant_offset0=self.quant_offset0, + bias0=self.quant_bias_qkv, + quant_scale1=self.quant_scale1, + quant_offset1=self.quant_offset1, + bias1=self.qb_qt_bias, + ctkv_scale=self.ctkv_scale, + 
q_nope_scale=self.q_nope_scale, + cache_mode="krope_ctkv", + quant_mode="per_tensor_quant_asymm", + enable_inner_out=True, + q_out0=ql_nope, + kv_cache_out0=k_nope, + q_out1=q_pe, + kv_cache_out1=k_pe, + inner_out=q_c, + ) + return hidden_states, ql_nope, q_pe, q_c + + def indexer_select_pre_process( + self, + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + ): + k_li, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] + k_li = self.k_norm(k_li).unsqueeze(1) + k_li = k_li.view(-1, 1, self.head_dim) + + if HAS_TRITON: + cos = cos.view(-1, self.qk_rope_head_dim) + sin = sin.view(-1, self.qk_rope_head_dim) + k_li = rope_forward_triton_siso( + k_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style + ) + else: + k_li_pe, k_li_nope = torch.split( + k_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 + ) + + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + k_li_pe = k_li_pe.unsqueeze(2) + k_li_pe = torch_npu.npu_interleave_rope(k_li_pe, cos, sin) + k_li_pe = k_li_pe.squeeze(2) + + k_li = torch.cat([k_li_pe, k_li_nope], dim=-1) # [b*s,128] + + return k_li + + def indexer_select_post_process( + self, + x: torch.Tensor, + q_c: torch.Tensor, + kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + cos: torch.Tensor, + sin: torch.Tensor, + actual_seq_lengths_query: torch.Tensor, + actual_seq_lengths_key: torch.Tensor, + ): + weights, _ = self.weights_proj(x) + + q_li, _ = self.wq_b(q_c) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q_li = q_li.view(-1, self.n_head, self.head_dim) # [n_toks,64,128] + if HAS_TRITON: + q_li = rope_forward_triton_siso( + q_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style + ) + else: + q_li_pe, q_li_nope = torch.split( + q_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1 + ) # [b,s,64,64+64] + + q_li_pe = q_li_pe.unsqueeze(2) + q_li_pe = 
torch_npu.npu_rotary_mul(q_li_pe, cos, sin) + q_li_pe = q_li_pe.squeeze(2) + q_li = torch.cat([q_li_pe, q_li_nope], dim=-1) # [b*s,64,128] + + # DSV3.2 currently has graph compilation issues when using torch_npu.npu.lightning_indexer. + # So two branches are maintained temporarily. + # TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed. + if self.use_torch_npu_lightning_indexer: + topk_indices, _ = torch_npu.npu_lightning_indexer( + query=q_li, + key=kv_cache[2], + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=attn_metadata.block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3, + ) + else: + topk_indices = torch.ops._C_ascend.npu_lightning_indexer( + query=q_li, + key=kv_cache[2], + weights=weights, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + block_table=attn_metadata.block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3, + ) + return topk_indices + + def _execute_sparse_flash_attention_process( + self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key + ): + block_table = attn_metadata.block_table + kv = kv_cache[0] + key_rope = kv_cache[1] + + attn_output = torch.ops._C_ascend.npu_sparse_flash_attention( + query=ql_nope, + key=kv, + value=kv, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=block_table, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_kv=actual_seq_lengths_key, + query_rope=q_pe, + key_rope=key_rope, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + return attn_output + + def forward( + self, + layer_name, + hidden_states: torch.Tensor, # query in unified attn + kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + need_gather_q_kv: bool 
= False, + output: torch.Tensor | None = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + forward_context = get_forward_context() + if attn_metadata is None: + # Profiling run. + if self.enable_dsa_cp_with_layer_shard and not forward_context.in_profile_run: + for layer in self.layer_sharding_kwargs or []: + if is_hidden_layer(layer): + reach_layer_for_shard_weight_series(layer) + return output.fill_(0) + + cos = attn_metadata.cos + sin = attn_metadata.sin + slot_mapping = attn_metadata.slot_mapping + slot_mapping_cp = None + if self.enable_dsa_cp: + assert attn_metadata.dsa_cp_context is not None + slot_mapping_cp = attn_metadata.dsa_cp_context.slot_mapping_cp + actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query + actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key + else: + actual_seq_lengths_query = attn_metadata.cum_query_lens + actual_seq_lengths_key = attn_metadata.seq_lens + + # Inputs and outputs may be padded for CUDA graphs + num_input_tokens = attn_metadata.num_input_tokens + output_padded = output + + # all-gather o_proj weight for prefill stage of PD mix node + o_proj_full_handle = None + # if is PD mix stage, using original TP o_proj weight, and also need to full gather for o_proj + # weight for prefill stage. 
+ full_gather_o_proj_enabled = self.enable_dsa_cp_with_o_proj_tp and attn_metadata.attn_state not in { + AscendAttentionState.DecodeOnly, + AscendAttentionState.SpecDecoding, + } + + # run mlapo ops when dsa-cp is disabled, and ensure that num_tokens satisfies the count limitation + if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: + hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_with_mlapo( + hidden_states=hidden_states, + kv_cache=kv_cache, + cos=cos, + sin=sin, + slot_mapping=slot_mapping, + num_input_tokens=num_input_tokens, + ) + k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin) + # native + else: + assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." + weight_prefetch_method = get_weight_prefetch_method() + weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( + inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states + ) + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_no_split = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized" + q_c = self.q_a_layernorm(q_c) + + k_li = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin) + + wait_for_kv_layer_from_connector(layer_name) + + if self.enable_dsa_cp: + assert slot_mapping_cp is not None + k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping_cp, attn_metadata) + else: + k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, attn_metadata) + + if self.enable_dsa_cp: + assert k_pe is not None + assert k_nope is not None + async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled + # support all_gather kv async for communication calculation overlap + fused_kv_no_split, kv_ag_handle = all_gather_async( + torch.cat( + [ + k_pe.view(-1, k_pe.shape[-1]), + k_nope.view(-1, k_nope.shape[-1]), + k_li.view(-1, 
k_li.shape[-1]), + ], + dim=1, + ), + get_tp_group(), + async_op=async_op, + ) + + ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c) + q_pe = self.rope_single(q_pe, cos, sin) + + if self.enable_dsa_cp: + if kv_ag_handle is not None: + kv_ag_handle.wait() + + if self.enable_dsa_cp_with_layer_shard: + for layer in self.layer_sharding_kwargs or []: + if is_hidden_layer(layer): + reach_layer_for_shard_weight_series(layer) + elif full_gather_o_proj_enabled: + _, o_proj_full_handle = all_gather_async( + self.o_proj_tp_weight, get_tp_group(), output=AscendSFAImpl.o_proj_full_pool + ) + + if kv_cache is not None: + assert fused_kv_no_split is not None + k_pe, k_nope, k_li = fused_kv_no_split.split( + [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1 + ) + k_nope = k_nope.view(k_nope.shape[0], 1, -1) + k_pe = k_pe.view(k_pe.shape[0], 1, -1) + DeviceOperator.reshape_and_cache( + key=k_nope[: attn_metadata.num_actual_tokens], + value=k_pe[: attn_metadata.num_actual_tokens], + key_cache=kv_cache[0], + value_cache=kv_cache[1], + slot_mapping=slot_mapping[: attn_metadata.num_actual_tokens], + ) + + k_li = self._get_full_kv(k_li, attn_metadata) + + if kv_cache is not None: + if self.is_kv_producer: + attn_metadata.reshape_cache_event = torch.npu.Event() + torch_npu.npu_scatter_nd_update_( + kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1]) + ) # b, s, n, d + if self.is_kv_producer: + attn_metadata.reshape_cache_event.record() + + topk_indices = self.indexer_select_post_process( + x=hidden_states, + q_c=q_c, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + cos=cos, + sin=sin, + actual_seq_lengths_query=actual_seq_lengths_query, + actual_seq_lengths_key=actual_seq_lengths_key, + ) + + attn_output = self._execute_sparse_flash_attention_process( + ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key + ) + + attn_output = self._v_up_proj(attn_output) + 
weight_prefetch_method = get_weight_prefetch_method() + weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( + inputs=self.o_proj.weight, + dependency=attn_output, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + linear_layer=self.o_proj, + ) + + if self.enable_dsa_cp_with_o_proj_tp: + # When using SFA-CP with pd mixed, o_proj has two cases: + # 1. prefill: o_proj is a TP weight, we need to all-gather o_proj weight to switch TP=1. + # 2. decode: all-to-all the hidden_state before the o_proj forward. + result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward( + attn_output=attn_output, + output=output, + o_proj_full_handle=o_proj_full_handle, + should_shard_weight=full_gather_o_proj_enabled, + ) + if not require_o_proj_forward: + return result + attn_output = result + + output[...] = self.o_proj(attn_output)[0] + + maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) + + return output_padded diff --git a/vllm_ascend/ops/triton/rope.py b/vllm_ascend/ops/triton/rope.py index ad863e40..90906517 100644 --- a/vllm_ascend/ops/triton/rope.py +++ b/vllm_ascend/ops/triton/rope.py @@ -146,6 +146,79 @@ def _triton_rope( tl.store(k_start_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) +@triton.jit +def _triton_rope_siso( + qk_ptr, + qk_row_stride, + cos_ptr, + cos_row_stride, + sin_ptr, + sin_row_stride, + cos_sin_ptr, + cos_sin_row_stride, + pos_ptr, + num_tokens, + n_h: tl.constexpr, + hd: tl.constexpr, + rope_dim: tl.constexpr, + pad_n_h: tl.constexpr, + pad_rope_dim: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + IS_NEOX_STYLE: tl.constexpr, + USE_COS_SIN: tl.constexpr, +): + pid = tl.program_id(0).to(tl.int64) + row_block_size = tl.num_programs(0) + + for row_idx in tl.range(pid, num_tokens, row_block_size): + qk_start_ptr = qk_ptr + row_idx * qk_row_stride + + # #################################################################### + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position + # m of this 
program instance + # #################################################################### + cos_offsets = tl.arange(0, pad_rope_dim // 2) + sin_offsets = tl.arange(pad_rope_dim // 2, pad_rope_dim) + cos_mask = cos_offsets < (rope_dim // 2) + if USE_COS_SIN: + pos_idx = tl.load(pos_ptr + row_idx).to(tl.int64) + cos_start_ptr = cos_sin_ptr + pos_idx * cos_sin_row_stride + cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) + sin_row = tl.load(cos_start_ptr + sin_offsets, mask=cos_mask, other=0).to(tl.float32) + else: + cos_start_ptr = cos_ptr + row_idx * cos_row_stride + sin_start_ptr = sin_ptr + row_idx * sin_row_stride + cos_row = tl.load(cos_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) + sin_row = tl.load(sin_start_ptr + cos_offsets, mask=cos_mask, other=0).to(tl.float32) + + # #################################################################### + # Load the left and right half of q and k for the current + # program instance (i.e. for the current token) separately + # #################################################################### + # left half of the head + if IS_NEOX_STYLE: + first_half_offsets = tl.arange(0, pad_n_h)[:, None] * hd + tl.arange(0, pad_rope_dim // 2)[None, :] + else: + first_half_offsets = tl.arange(0, pad_n_h)[:, None] * hd + (2 * tl.arange(0, pad_rope_dim // 2)[None, :]) + + first_mask = (tl.arange(0, pad_n_h)[:, None] < n_h) & ( + tl.arange(0, pad_rope_dim // 2)[None, :] < (rope_dim // 2) + ) + qk_tile_1 = tl.load(qk_start_ptr + first_half_offsets, mask=first_mask, other=0).to(sin_row.dtype) + + # right half of the head + if IS_NEOX_STYLE: + second_half_offsets = first_half_offsets + (rope_dim // 2) + else: + second_half_offsets = first_half_offsets + 1 + second_mask = first_mask + qk_tile_2 = tl.load(qk_start_ptr + second_half_offsets, mask=second_mask, other=0).to(sin_row.dtype) + + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] + new_qk_tile_1 = qk_tile_1 * cos_row - qk_tile_2 * 
sin_row + tl.store(qk_start_ptr + first_half_offsets, new_qk_tile_1, mask=first_mask) + + def rope_forward_triton( q: torch.Tensor, k: torch.Tensor, @@ -237,3 +310,83 @@ def rope_forward_triton( "Please check whether you call rope_forward_triton correctly." ) return q, k + + +def rope_forward_triton_siso( + qk: torch.Tensor, + cos: torch.Tensor = None, + sin: torch.Tensor = None, + cos_sin_cache: torch.Tensor = None, + positions: torch.Tensor = None, + rope_dim: int = -1, + is_neox_style: bool = True, +) -> tuple[torch.Tensor, torch.Tensor]: + if not qk.is_contiguous(): + qk = qk.contiguous() + + num_tokens, n_head, head_dim = qk.shape + assert rope_dim <= head_dim + pad_rope_dim = triton.next_power_of_2(rope_dim) + pad_n_head = triton.next_power_of_2(n_head) + BLOCK_SIZE = pad_n_head + num_vectorcore = get_vectorcore_num() + n_row = min(num_tokens, num_vectorcore) + + if cos_sin_cache is not None and positions is not None: + assert positions.shape[0] == num_tokens + _triton_rope_siso[(n_row,)]( + qk, + qk.stride(0), + None, + None, + None, + None, + cos_sin_cache, + cos_sin_cache.stride(0), + positions, + num_tokens, + n_head, + head_dim, + rope_dim, + pad_n_head, + pad_rope_dim, + BLOCK_SIZE=BLOCK_SIZE, + IS_NEOX_STYLE=is_neox_style, + USE_COS_SIN=True, + ) + elif cos is not None and sin is not None: + assert cos.shape[0] == num_tokens and sin.shape[0] == num_tokens + cos = cos.view(num_tokens, -1) + sin = sin.view(num_tokens, -1) + if rope_dim == -1: + # If rope_dim is not specified, we assume that input cos/sin is not + # duplicated to rope_dim, which means rope_dim == cos.shape[-1] * 2 + rope_dim = cos.shape[-1] * 2 + _triton_rope_siso[(n_row,)]( + qk, + qk.stride(0), + cos, + cos.stride(0), + sin, + sin.stride(0), + None, + None, + None, + num_tokens, + n_head, + head_dim, + rope_dim, + pad_n_head, + pad_rope_dim, + BLOCK_SIZE=BLOCK_SIZE, + IS_NEOX_STYLE=is_neox_style, + USE_COS_SIN=False, + ) + else: + raise ValueError( + "Currently, rope_forward_triton 
supports passing:\n" + "1. positions and original cos_sin_cache.\n" + "2. cos and sin which are already selected by positions\n" + "Please check whether you call rope_forward_triton correctly." + ) + return qk diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 424e4fc0..da11e35b 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -1138,10 +1138,24 @@ def enable_dsa_cp_with_layer_shard() -> bool: from vllm.config import get_current_vllm_config vllm_config = get_current_vllm_config() + # because the broadcast in layer sharding needs to be overlapped with a heavy compute stream to be + # effectively hidden, it is enabled only during the prefill stage. is_prefill_instance = vllm_config.kv_transfer_config is not None and vllm_config.kv_transfer_config.is_kv_producer return is_prefill_instance +@lru_cache(maxsize=1) +def enable_dsa_cp_with_o_proj_tp() -> bool: + if not enable_dsa_cp(): + return False + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + # if is PD mix stage, using original TP o_proj weight, and also need to + # full gather for o_proj weight for prefill stage. + return vllm_config.kv_transfer_config is None + + def check_gdn_layer(vllm_config) -> bool: """ gdn layer is marked with `linear_attention`.