from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar

import scipy  # type: ignore
import torch
import torch_npu
import vllm.envs as envs_vllm
from torch import nn
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.logger import logger
from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.triton_utils import HAS_TRITON
from vllm.v1.attention.backend import (
    AttentionBackend,  # type: ignore
    AttentionCGSupport,
    MLAAttentionImpl,
)
from vllm.v1.kv_cache_interface import AttentionSpec

from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata
from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE, MLAPO_MAX_SUPPORTED_TOKENS
from vllm_ascend.attention.utils import (
    AscendCommonAttentionMetadata,
    ascend_chunked_prefill_workspace_size,
    enable_cp,
    maybe_save_kv_layer_to_connector,
    trans_rope_weight,
    transdata,
    wait_for_kv_layer_from_connector,
)
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.distributed.utils import all_gather_async
from vllm_ascend.ops.layer_shard_linear import (
    is_hidden_layer,
    post_process_after_loading_for_shard_weight_series,
    reach_layer_for_shard_weight_series,
    register_all_layers_to_shard_weight_series,
)
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.ops.triton.rope import rope_forward_triton_siso
from vllm_ascend.quantization.methods import AscendW8A8LinearMethod
from vllm_ascend.utils import (
    ACL_FORMAT_FRACTAL_ND,
    _round_up,
    dispose_layer,
    enable_dsa_cp,
    enable_dsa_cp_with_layer_shard,
    enable_dsa_cp_with_o_proj_tp,
    get_weight_prefetch_method,
    maybe_trans_nz,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput

# Token count limit of the bmm_transpose operator.
BMM_TRANS_MAX_SUPPORTED_TOKENS = 1024


class AscendSFABackend(AttentionBackend):
    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        # HACK(Ronald1995): vllm's `initialize_kv_cache` method in model runner v2 asserts on the
        # attention backend name, so we report FLASH_ATTN to avoid the assertion error.
        # Rectify this once vllm removes the assertion.
return "ASCEND_SFA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN" @staticmethod def get_builder_cls(): if enable_cp(): from vllm_ascend.attention.context_parallel.sfa_cp import AscendSFACPMetadataBuilder return AscendSFACPMetadataBuilder return AscendSFAMetadataBuilder @staticmethod def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int) -> tuple[int, ...]: return (num_blocks, block_size, num_kv_heads, head_size) @staticmethod def get_impl_cls() -> type["AscendSFAImpl"]: if enable_cp(): from vllm_ascend.attention.context_parallel.sfa_cp import AscendSFACPImpl return AscendSFACPImpl return AscendSFAImpl @staticmethod def get_supported_kernel_block_sizes() -> list[int]: return [128] @dataclass class DSACPContext: num_tokens: int num_tokens_pad: int local_start: int local_end: int local_end_with_pad: int slot_mapping_cp: torch.Tensor actual_seq_lengths_query: torch.Tensor actual_seq_lengths_key: torch.Tensor @dataclass class AscendSFAMetadata: """Metadata for MLACommon. NOTE: Please read the comment at the top of the file before trying to understand this class """ # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| # |-------------------- seq_len ---------------------| # |-- query_len ---| num_actual_tokens: int # Number of tokens excluding padding. slot_mapping: torch.Tensor seq_lens: torch.Tensor cum_query_lens: torch.Tensor block_table: torch.Tensor sin: torch.Tensor cos: torch.Tensor # For logging. num_input_tokens: int = 0 # Number of tokens including padding. # The dimension of the attention heads head_dim: int | None = None attn_mask: torch.Tensor = None # chunked prefill by default if no attn_states passed attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill dsa_cp_context: DSACPContext | None = None reshape_cache_event: torch.npu.Event = None sfa_cp_metadata: AscendPCPMetadata | None = None num_decodes: int = 0 num_decode_tokens: int = 0 num_prefills: int = 0 M = TypeVar("M", bound=AscendSFAMetadata) class AscendSFAMetadataBuilder(MLACommonMetadataBuilder[AscendSFAMetadata]): """ NOTE: Please read the comment at the top of the file before trying to understand this class """ def __init__( self, kv_cache_spec, layer_names: list[str], vllm_config: VllmConfig, device: torch.device, metadata_cls: type[AscendSFAMetadata] | None = None, supports_dcp_with_varlen: bool = False, ): super().__init__( kv_cache_spec, layer_names, vllm_config, device, metadata_cls if metadata_cls is not None else AscendSFAMetadata, supports_dcp_with_varlen, ) self.block_size = vllm_config.cache_config.block_size self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size self.speculative_config = vllm_config.speculative_config self.decode_threshold = 1 if self.speculative_config: spec_token_num = self.speculative_config.num_speculative_tokens self.decode_threshold += spec_token_num assert self.decode_threshold <= 16, ( f"decode_threshold exceeded \ npu_fused_infer_attention_score TND layout's limit of 16, \ got {self.decode_threshold}" ) self.reorder_batch_threshold = self.decode_threshold self.attn_mask_builder = AttentionMaskBuilder(self.device) self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim self.enable_dsa_cp = enable_dsa_cp() max_num_reqs = vllm_config.scheduler_config.max_num_seqs 
        self.actual_seq_lengths_query = torch.zeros(max_num_reqs + 1, dtype=torch.int32, device=device)
        self.actual_seq_lengths_key = torch.empty_like(self.actual_seq_lengths_query)

    @staticmethod
    def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
        return ascend_chunked_prefill_workspace_size(vllm_config)

    @classmethod
    def get_cudagraph_support(
        cls: type["AscendSFAMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        # Explicit override in case the underlying builder specialized this getter.
        # @override omitted only because of a mypy limitation with type variables.
        return AttentionCGSupport.UNIFORM_BATCH

    def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool:
        # No need to reorder for Ascend SFA
        return False

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AscendSFAMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        num_input_tokens = common_attn_metadata.num_input_tokens

        block_table = common_attn_metadata.block_table_tensor[:num_reqs]
        slot_mapping = common_attn_metadata.slot_mapping[:num_input_tokens]
        input_positions = common_attn_metadata.positions[:num_input_tokens].long()
        cum_query_lens = common_attn_metadata.query_start_loc[1 : num_reqs + 1]
        seq_lens = common_attn_metadata.seq_lens[:num_reqs]
        cos, sin = get_cos_and_sin_mla(input_positions, True)

        dsa_cp_context = None
        if self.enable_dsa_cp:
            global_tp_size = get_tp_group().world_size
            num_tokens = num_input_tokens
            num_tokens_pad = _round_up(num_tokens, global_tp_size)
            num_tokens_per_device = num_tokens_pad // global_tp_size
            local_start = get_tp_group().rank_in_group * num_tokens_per_device
            local_end_with_pad = local_start + num_tokens_per_device
            local_end = min(local_end_with_pad, num_actual_tokens)

            pad_size = num_tokens_pad - cos.shape[0]
            assert cos.shape == sin.shape, f"cos.shape must be equal to sin.shape, got {cos.shape} and {sin.shape}"
            if pad_size > 0:
                cos = nn.functional.pad(cos, (0, 0, 0, 0, 0, 0, 0, pad_size))
                sin = nn.functional.pad(sin, (0, 0, 0, 0, 0, 0, 0, pad_size))
            pad_size_slot = num_tokens_pad - slot_mapping.shape[0]
            if pad_size_slot > 0:
                slot_mapping = nn.functional.pad(slot_mapping, (0, pad_size_slot), value=-1)
            else:
                slot_mapping = slot_mapping[:num_tokens_pad]
            slot_mapping_cp = slot_mapping[local_start:local_end_with_pad]
            cos = cos[local_start:local_end_with_pad]
            sin = sin[local_start:local_end_with_pad]
            assert cos.shape[0] == num_tokens_per_device, (
                f"cos.shape[0] must be equal to num_tokens_per_device, got {cos.shape[0]} and {num_tokens_per_device}"
            )
            assert slot_mapping_cp.shape[0] == num_tokens_per_device, (
                "slot_mapping_cp.shape[0] must be equal to num_tokens_per_device, "
                f"got {slot_mapping_cp.shape[0]} and {num_tokens_per_device}"
            )
            assert slot_mapping.shape[0] == num_tokens_pad, (
                "slot_mapping.shape[0] must be equal to num_tokens_pad, "
                f"got {slot_mapping.shape[0]} and {num_tokens_pad}"
            )

            actual_seq_lengths_query = self.actual_seq_lengths_query
            actual_seq_lengths_key = self.actual_seq_lengths_key
            num_segs = cum_query_lens.shape[0]
            last_token = 0
            cum = 0
            for i in range(0, num_segs):
                global_start = last_token
                global_end = cum_query_lens[i].item()
                last_token = global_end
                req_local_start = max(global_start, local_start)
                req_local_end = min(global_end, local_end_with_pad)
                num_local_tokens = req_local_end - req_local_start
                if num_local_tokens > 0:
                    cum += num_local_tokens
                    actual_seq_lengths_query[i] = cum
                    offset = global_end - req_local_end
                    actual_seq_lengths_key[i] = seq_lens[i].item() - offset
                else:
                    actual_seq_lengths_query[i] = cum
                    actual_seq_lengths_key[i] = 0
            actual_seq_lengths_query = actual_seq_lengths_query[:num_reqs]
            actual_seq_lengths_key = actual_seq_lengths_key[:num_reqs]

            dsa_cp_context = DSACPContext(
                num_tokens=num_tokens,
                num_tokens_pad=num_tokens_pad,
                local_start=local_start,
                local_end=local_end,
                local_end_with_pad=local_end_with_pad,
                slot_mapping_cp=slot_mapping_cp,
                actual_seq_lengths_query=actual_seq_lengths_query,
                actual_seq_lengths_key=actual_seq_lengths_key,
            )

        return self.metadata_cls(  # type: ignore
            num_input_tokens=common_attn_metadata.num_input_tokens,
            num_actual_tokens=num_actual_tokens,
            cum_query_lens=cum_query_lens,
            seq_lens=seq_lens,
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            attn_mask=self.attn_mask_builder.get_attention_mask(self.model_config),
            attn_state=common_attn_metadata.attn_state,
            block_table=block_table,
            sin=sin[:num_input_tokens],
            cos=cos[:num_input_tokens],
            dsa_cp_context=dsa_cp_context,
        )

    def build_for_graph_capture(
        self,
        common_attn_metadata: AscendCommonAttentionMetadata,
        attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly,
    ):
        if attn_state in {AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding}:
            attn_metadata = self.build(
                common_prefix_len=0,
                common_attn_metadata=common_attn_metadata,
            )
        else:
            raise NotImplementedError("Currently we only support building dummy metadata for the DecodeOnly state")

        attn_metadata.attn_state = attn_state
        return attn_metadata


class AscendSFAImpl(MLAAttentionImpl):
    """
    NOTE: Please read the comment at the top of the file before trying to understand this class.
    """

    # Shared pool holding the all-gathered o_proj weight, used for weight switching when sharded CP
    # (DSA-CP with o_proj TP) is enabled.
    o_proj_full_pool: torch.Tensor | None = None
    # qk_hadamard tensor shared across layers when DSA C8 is enabled
    qk_hadamard: torch.Tensor | None = None

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: list[float] | None,
        sliding_window: int | None,
        kv_cache_dtype: str,
        logits_soft_cap: float | None,
        attn_type: str,
        kv_sharing_target_layer_name: str | None,
        **kwargs,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        # MLA Args
        self.q_lora_rank = kwargs["q_lora_rank"]
        self.kv_lora_rank = kwargs["kv_lora_rank"]
        self.qk_nope_head_dim = kwargs["qk_nope_head_dim"]
        self.qk_rope_head_dim = kwargs["qk_rope_head_dim"]
        self.qk_head_dim = kwargs["qk_head_dim"]
        self.v_head_dim = kwargs["v_head_dim"]
        self.rotary_emb = kwargs["rotary_emb"]
        self.q_proj = kwargs["q_proj"] if self.q_lora_rank is None else kwargs["q_b_proj"]
        self.fused_qkv_a_proj = kwargs.get("fused_qkv_a_proj")
        self.kv_b_proj = kwargs["kv_b_proj"]
        self.o_proj = kwargs["o_proj"]
        self.indexer = kwargs["indexer"]
        self.kv_a_proj_with_mqa = kwargs.get("kv_a_proj_with_mqa")
        self.kv_a_layernorm = kwargs.get("kv_a_layernorm")
        self.q_a_layernorm = kwargs.get("q_a_layernorm")
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tp_group().rank_in_group
        self.q_b_proj = kwargs["q_b_proj"]

        ascend_config = get_ascend_config()
        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
        # The MLAPO operator fuses the pre-processing steps on Q/K/V in MLA into a single operator.
        # NOTE: it imposes a limit on the number of input tokens and conflicts with FlashComm.
        self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO

        assert self.indexer is not None, "Indexer is required for DSA."
        self.local_num_heads = self.num_heads

        self.vllm_config = get_current_vllm_config()
        self.is_kv_producer = (
            self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
        )

        # Indexer params
        self.n_head: int = self.indexer.n_head  # 64
        self.head_dim: int = self.indexer.head_dim  # 128
        self.wq_b = self.indexer.wq_b
        self.wk = self.indexer.wk
        self.weights_proj = self.indexer.weights_proj
        self.k_norm = self.indexer.k_norm
        self.cp_size = 1

        self.is_rope_neox_style = True
        self.use_torch_npu_lightning_indexer = False
        if self.vllm_config.model_config.hf_config.model_type in ["glm_moe_dsa"]:
            self.is_rope_neox_style = False
            self.use_torch_npu_lightning_indexer = True

        # DSA C8
        self.use_sparse_c8_indexer = ascend_config.enable_sparse_c8
        if self.use_sparse_c8_indexer:
            self.c8_k_cache_dtype = torch.int8
            self.c8_k_scale_cache_dtype = torch.float16

        # Effective in SFA when FlashComm is enabled.
        self.enable_dsa_cp = enable_dsa_cp()
        # Enable layer sharding via DSA-CP on the P node in the PD-disaggregated setup.
        self.enable_dsa_cp_with_layer_shard = enable_dsa_cp_with_layer_shard()
        # Use the original TP o_proj weight in the PD mix stage, and all-gather the full
        # o_proj weight for the prefill stage.
        self.enable_dsa_cp_with_o_proj_tp = enable_dsa_cp_with_o_proj_tp()
        if self.enable_dsa_cp:
            self.local_num_heads = self.num_heads * self.tp_size

        if self.enable_dsa_cp_with_layer_shard:
            self.layer_sharding_kwargs = []
            for layer_name in get_ascend_config().layer_sharding or []:
                if layer_name in kwargs:
                    self.layer_sharding_kwargs.append(kwargs[layer_name])
                else:
                    logger.warning_once(
                        f"[SFAImpl init] Layer '{layer_name}' not found in kwargs for layer sharding, "
                        "skipping sharding configuration"
                    )
            register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs)

    def process_weights_after_loading(self, act_dtype: torch.dtype):
        # NOTE: We currently do not support a quantized kv_b_proj.
        assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod)
        # NOTE: The weight will be reshaped next, so revert its format and transpose it first.
        kv_b_proj_weight = torch_npu.npu_format_cast(self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.local_num_heads * (self.qk_nope_head_dim + self.v_head_dim),
        ), (
            f"{kv_b_proj_weight.shape=}, "
            f"{self.kv_lora_rank=}, "
            f"{self.local_num_heads=}, "
            f"{self.qk_nope_head_dim=}, "
            f"{self.v_head_dim=}"
        )
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.local_num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )
        W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Convert from (L, N, V) to (N, L, V)
        self.W_UV = W_UV.transpose(0, 1).contiguous()
        # Convert from (L, N, P) to (N, P, L)
        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
        # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight NZ.
        # self.W_UV = maybe_trans_nz(self.W_UV)

        # Dispose kv_b_proj since it is replaced by W_UV and W_UK_T, to save memory
        dispose_layer(self.kv_b_proj)

        if self.enable_dsa_cp:
            if self.enable_dsa_cp_with_layer_shard:
                for layer in self.layer_sharding_kwargs or []:
                    if is_hidden_layer(layer):
                        post_process_after_loading_for_shard_weight_series(layer)
            else:
                self._init_o_proj_tp_full_params()

        if self.enable_mlapo:
            quant_method = getattr(
                getattr(self.fused_qkv_a_proj, "quant_method", None),
                "quant_method",
                None,
            )
            reasons = []
            if self.fused_qkv_a_proj is None or not isinstance(quant_method, AscendW8A8LinearMethod):
                reasons.append(
                    "Currently mlapo only supports W8A8 quantization in the SFA scenario. "
                    "Some layers in your model are not quantized with W8A8, "
                    "thus mlapo is disabled for these layers."
                )
            if self.enable_dsa_cp:
                reasons.append("Currently mlapo does not support SFA with CP, thus mlapo is disabled for these layers.")
            if reasons:
                self.enable_mlapo = False
                for msg in reasons:
                    logger.warning_once(msg)
            else:
                self._process_weights_for_fused_mlapo(act_dtype)
        if not self.enable_mlapo:
            # When MLAPO is enabled, W_UK_T cannot be converted to NZ format, so only do it here.
            self.W_UK_T = maybe_trans_nz(self.W_UK_T)

        if self.use_sparse_c8_indexer and AscendSFAImpl.qk_hadamard is None:
            AscendSFAImpl.qk_hadamard = torch.tensor(scipy.linalg.hadamard(128), dtype=torch.bfloat16, device="npu") / (
                128**0.5
            )

    # Processes the input parameters for MLAPO by reordering and transposing the
    # QKV (and part of Q) weights, applying RoPE-related dimension transformations,
    # and handling quantization parameters.
    def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
        assert self.kv_a_proj_with_mqa is None
        assert self.fused_qkv_a_proj is not None
        kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous()
        q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous()

        kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
        kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim)
        kv_a_proj_wt = kv_a_proj_wt.t().contiguous()
        wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1)
        wd_qkv = wd_qkv.t().contiguous()
        wd_qkv = transdata(wd_qkv, block_size=(16, 32)).unsqueeze(0).contiguous()
        self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29)

        kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[self.q_lora_rank :].contiguous()
        q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[: self.q_lora_rank].contiguous()
        kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
        kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, self.qk_rope_head_dim)
        kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
        self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), dim=-1).contiguous()

        kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[self.q_lora_rank :].contiguous()
        q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[: self.q_lora_rank].contiguous()
        kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous()
        kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, self.qk_rope_head_dim)
        kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous()
        self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), dim=-1).contiguous()

        wu_q = self.q_proj.weight.data
        wu_q = wu_q.t().reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
        wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim)
        wu_q = wu_q.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1)
        wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous()
        self.wu_q = torch_npu.npu_format_cast(wu_q, 29)

        qb_deq_scl = self.q_proj.deq_scale.data
        qb_deq_scl = qb_deq_scl.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
        qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim)
        self.qb_deq_scl = qb_deq_scl.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))

        qb_qt_bias = self.q_proj.quant_bias.data
        qb_qt_bias = qb_qt_bias.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1)
        qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim)
        self.qb_qt_bias = qb_qt_bias.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim))

        device = self.q_proj.weight.device
        self.gamma1 = self.q_a_layernorm.weight.data  # type: ignore[union-attr]
        self.beta1 = self.q_a_layernorm.bias.data  # type: ignore[union-attr]
        self.gamma2 = self.kv_a_layernorm.weight.data  # type: ignore[union-attr]
        self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data
        self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data
        self.quant_scale1 = self.q_proj.input_scale.data
        self.quant_offset1 = self.q_proj.input_offset.data
        self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device)
        self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device)

        # On KV consumers (decode-only) MLAPO uses the transformed weights built above;
        # the original fused_qkv_a_proj/q_proj weights and quant params are no longer
        # referenced, so drop them to save memory.
        if (
            self.vllm_config.kv_transfer_config is not None
            and self.vllm_config.kv_transfer_config.is_kv_consumer
            and self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS
        ):
            self.fused_qkv_a_proj.weight = None
            self.fused_qkv_a_proj.deq_scale = None
            self.fused_qkv_a_proj.quant_bias = None
            self.q_proj.weight = None
            self.q_proj.deq_scale = None
            self.q_proj.quant_bias = None
            torch.npu.empty_cache()

    def forward_mha(
        self,
        q: torch.Tensor,
        kv_c_normed: torch.Tensor,
        k_pe: torch.Tensor,
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: M,
        k_scale: torch.Tensor,
        output: torch.Tensor,
    ) -> None:
        raise NotImplementedError("forward_mha is not supported for SFA attention. Use forward() instead.")

    def forward_mqa(
        self,
        q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
        kv_c_and_k_pe_cache: torch.Tensor,
        attn_metadata: M,
        layer,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        raise NotImplementedError("forward_mqa is not supported for SFA attention. Use forward() instead.")

    def rope_single(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ) -> torch.Tensor:
        B, N, D = x.shape
        S = 1
        x = x.view(B, N, S, D)
        x = torch_npu.npu_interleave_rope(x, cos, sin)
        return x.view(B, N, D)

    def _init_o_proj_tp_full_params(self):
        """
        Initialize TP-mode and Full-mode parameters for the o_proj weight, preparing for
        weight switching in the PD mix stage.

        For the PD mix stage:
        - Use the original TP o_proj weight for the decode phase
        - All-gather the full o_proj weight from all TP ranks for the prefill phase
        """
        if AscendSFAImpl.o_proj_full_pool is None:
            sample = self.o_proj.weight
            AscendSFAImpl.o_proj_full_pool = torch.empty(
                (sample.shape[0] * self.tp_size, sample.shape[1]), dtype=sample.dtype, device=sample.device
            )

        # Save TP-mode parameters (original sharded weights)
        self.o_proj_tp_weight = self.o_proj.weight.clone().detach()
        self.o_proj_tp_aclnn_input_scale = self.o_proj.aclnn_input_scale.clone().detach()
        self.o_proj_tp_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.clone().detach()
        self.o_proj_tp_aclnn_input_offset = self.o_proj.aclnn_input_offset.clone().detach()

        # Initially switch to TP mode for graph capture
        self.o_proj.weight.set_(self.o_proj_tp_weight)
        self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale)
        self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal)
        self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset)

        # Precompute Full-mode quantization parameters by repeating TP parameters across all TP ranks
        self.o_proj_full_aclnn_input_scale = self.o_proj.aclnn_input_scale.repeat(self.tp_size)
        self.o_proj_full_aclnn_input_scale_reciprocal = self.o_proj.aclnn_input_scale_reciprocal.repeat(self.tp_size)
        self.o_proj_full_aclnn_input_offset = self.o_proj.aclnn_input_offset.repeat(self.tp_size)

    def _handle_o_proj_weight_switch_and_forward(
        self,
        attn_output: torch.Tensor,
        output: torch.Tensor,
        o_proj_full_handle: torch.distributed.Work | None,
        should_shard_weight: bool,
    ) -> tuple[torch.Tensor, bool]:
        """
        Handle o_proj weight switching between TP-mode and Full-mode, and execute the forward computation.
""" # Gather o_proj weight from all TP ranks for Full-mode computation if should_shard_weight: # Wait for the completion of o_proj weight all-gather operation if o_proj_full_handle is not None: o_proj_full_handle.wait() # Switch o_proj to Full-mode (gathered weight from all TP ranks) self.o_proj.weight.set_(AscendSFAImpl.o_proj_full_pool) self.o_proj.aclnn_input_scale.set_(self.o_proj_full_aclnn_input_scale) self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_full_aclnn_input_scale_reciprocal) self.o_proj.aclnn_input_offset.set_(self.o_proj_full_aclnn_input_offset) # Apply quantization method and execute forward computation output[...] = self.o_proj.quant_method.quant_method.apply(self.o_proj, attn_output) # Switch o_proj back to TP-mode for subsequent decode operations self.o_proj.weight.set_(self.o_proj_tp_weight) self.o_proj.aclnn_input_scale.set_(self.o_proj_tp_aclnn_input_scale) self.o_proj.aclnn_input_scale_reciprocal.set_(self.o_proj_tp_aclnn_input_scale_reciprocal) self.o_proj.aclnn_input_offset.set_(self.o_proj_tp_aclnn_input_offset) return output, False else: # For decode scenario: perform all-to-all communication on o_proj input activations # Reshape for all-to-all: [batch * seq, tp_size, head_dim] -> [tp_size, batch * seq, head_dim] send = ( attn_output.view(-1, self.tp_size, self.num_heads * self.v_head_dim) .permute(1, 0, 2) .reshape(-1, self.num_heads * self.v_head_dim) ) attn_output = torch.empty_like(send) torch.distributed.all_to_all_single(attn_output, send, group=get_tp_group().device_group) return attn_output, True def _get_full_kv(self, k, attn_metadata): return k def exec_kv( self, kv_no_split: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, kv_cache: tuple, slots: torch.Tensor, attn_metadata: M, ): B = kv_no_split.shape[0] N = self.num_kv_heads S = 1 # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) cache_mode = "PA" if self.enable_dsa_cp: _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, # type: ignore[union-attr] cos, sin, slots.to(torch.int64), kv_cache[1], kv_cache[0], epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] cache_mode=cache_mode, is_output_kv=True, ) return k_pe, k_nope else: torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, # type: ignore[union-attr] cos, sin, slots.to(torch.int64), kv_cache[1], kv_cache[0], epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] cache_mode=cache_mode, ) return None, None # Return `ql_nope`, `q_pe` def _q_proj_and_k_up_proj(self, x): q_nope, q_pe = ( self.q_proj(x)[0] .view(-1, self.local_num_heads, self.qk_head_dim) .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) ) # Convert from (B, N, P) to (N, B, P) q_nope = q_nope.transpose(0, 1) # Multiply (N, B, P) x (N, P, L) -> (N, B, L) ql_nope = torch.bmm(q_nope, self.W_UK_T) # Convert from (N, B, L) to (B, N, L) return ql_nope.transpose(0, 1), q_pe def _v_up_proj(self, x): num_input_tokens, _, _ = x.shape if ( x.dtype in [torch.float16, torch.bfloat16] and hasattr(torch.ops._C_ascend, "batch_matmul_transpose") and num_input_tokens <= BMM_TRANS_MAX_SUPPORTED_TOKENS ): x = x.view(-1, self.local_num_heads, self.kv_lora_rank) res = torch.empty((num_input_tokens, self.local_num_heads, self.v_head_dim), dtype=x.dtype, device=x.device) torch.ops._C_ascend.batch_matmul_transpose(x, self.W_UV, res) x = res.reshape(-1, self.local_num_heads * self.v_head_dim) else: 
            # Convert from (B, N, L) to (N, B, L)
            x = x.view(-1, self.local_num_heads, self.kv_lora_rank).transpose(0, 1)
            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
            x = torch.bmm(x, self.W_UV)
            # Convert from (N, B, V) to (B, N * V)
            x = x.transpose(0, 1).reshape(-1, self.local_num_heads * self.v_head_dim)
        return x

    def _sfa_preprocess_with_mlapo(
        self,
        hidden_states: torch.Tensor,
        kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        cos: torch.Tensor,
        sin: torch.Tensor,
        slot_mapping: torch.Tensor,
        num_input_tokens: int,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        k_nope, k_pe = kv_cache[0], kv_cache[1]
        ql_nope = torch.empty(
            (num_input_tokens, self.W_UK_T.shape[0], k_nope.shape[-1]),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        q_pe = torch.empty(
            (num_input_tokens, self.W_UK_T.shape[0], k_pe.shape[-1]),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        q_c = torch.empty(
            (num_input_tokens, self.q_lora_rank),
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        torch.ops._C_ascend.mla_preprocess(
            hidden_states,
            self.wd_qkv,
            self.deq_scale_qkv,
            self.gamma1,
            self.beta1,
            self.wu_q,
            self.qb_deq_scl,
            self.gamma2,
            cos,
            sin,
            self.W_UK_T,
            k_nope,
            k_pe,
            slot_mapping,
            quant_scale0=self.quant_scale0,
            quant_offset0=self.quant_offset0,
            bias0=self.quant_bias_qkv,
            quant_scale1=self.quant_scale1,
            quant_offset1=self.quant_offset1,
            bias1=self.qb_qt_bias,
            ctkv_scale=self.ctkv_scale,
            q_nope_scale=self.q_nope_scale,
            cache_mode="krope_ctkv",
            quant_mode="per_tensor_quant_asymm",
            enable_inner_out=True,
            q_out0=ql_nope,
            kv_cache_out0=k_nope,
            q_out1=q_pe,
            kv_cache_out1=k_pe,
            inner_out=q_c,
        )
        return hidden_states, ql_nope, q_pe, q_c

    def indexer_select_pre_process(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
    ):
        k_li, _ = self.wk(x)  # [b,s,7168] @ [7168,128] = [b,s,128]
        k_li = self.k_norm(k_li).unsqueeze(1)
        k_li = k_li.view(-1, 1, self.head_dim)
        if HAS_TRITON:
            cos = cos.view(-1, self.qk_rope_head_dim)
            sin = sin.view(-1, self.qk_rope_head_dim)
            k_li = rope_forward_triton_siso(
                k_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style
            )
        else:
            k_li_pe, k_li_nope = torch.split(
                k_li, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], dim=-1
            )
            cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
            sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
            k_li_pe = k_li_pe.unsqueeze(2)
            k_li_pe = torch_npu.npu_rotary_mul(k_li_pe, cos, sin)
            k_li_pe = k_li_pe.squeeze(2)
            k_li = torch.cat([k_li_pe, k_li_nope], dim=-1)  # [b*s,128]

        if self.use_sparse_c8_indexer:
            k_li = k_li @ AscendSFAImpl.qk_hadamard
            k_li, k_li_scale = torch_npu.npu_dynamic_quant(k_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
            k_li_scale = k_li_scale.to(self.c8_k_scale_cache_dtype)  # [b*s,]
            k_li_scale = k_li_scale.unsqueeze(-1)  # [b*s,1]
        else:
            k_li_scale = None
        return k_li, k_li_scale

    def indexer_select_post_process(
        self,
        x: torch.Tensor,
        q_c: torch.Tensor,
        kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        attn_metadata: M,
        cos: torch.Tensor,
        sin: torch.Tensor,
        actual_seq_lengths_query: torch.Tensor,
        actual_seq_lengths_key: torch.Tensor,
    ):
        weights, _ = self.weights_proj(x)
        q_li, _ = self.wq_b(q_c)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
        q_li = q_li.view(-1, self.n_head, self.head_dim)  # [n_toks,64,128]
        if HAS_TRITON:
            q_li = rope_forward_triton_siso(
                q_li, cos, sin, rope_dim=self.qk_rope_head_dim, is_neox_style=self.is_rope_neox_style
            )
        else:
            q_li_pe, q_li_nope = torch.split(
                q_li,
                [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
                dim=-1,
            )  # [b,s,64,64+64]
            q_li_pe = q_li_pe.unsqueeze(2)
            q_li_pe = torch_npu.npu_rotary_mul(q_li_pe, cos, sin)
            q_li_pe = q_li_pe.squeeze(2)
            q_li = torch.cat([q_li_pe, q_li_nope], dim=-1)  # [b*s,64,128]

        if self.use_sparse_c8_indexer:
            q_li_shape_ori = q_li.shape
            q_li = q_li @ AscendSFAImpl.qk_hadamard
            q_li, q_li_scale = torch_npu.npu_dynamic_quant(q_li.view(-1, self.head_dim), dst_type=self.c8_k_cache_dtype)
            q_li_scale = q_li_scale.to(self.c8_k_scale_cache_dtype)

        # DSV3.2 currently has graph compilation issues when using torch_npu.npu_lightning_indexer,
        # so two branches are maintained temporarily.
        # TODO: torch.ops._C_ascend.npu_lightning_indexer needs to be removed.
        if self.use_sparse_c8_indexer:
            assert len(kv_cache) == 4
            weights = weights.to(torch.float16)
            topk_indices = torch.ops._C_ascend.npu_lightning_indexer_quant(
                query=q_li.view(q_li_shape_ori),
                key=kv_cache[2],
                weights=weights,
                query_dequant_scale=q_li_scale.view(q_li_shape_ori[:-1]),
                key_dequant_scale=kv_cache[3].squeeze(2),  # B S N D -> B S D
                actual_seq_lengths_query=actual_seq_lengths_query,
                actual_seq_lengths_key=actual_seq_lengths_key,
                block_table=attn_metadata.block_table,
                query_quant_mode=0,
                key_quant_mode=0,
                layout_query="TND",
                layout_key="PA_BSND",
                sparse_count=2048,
                sparse_mode=3,
            )
        elif self.use_torch_npu_lightning_indexer:
            topk_indices, _ = torch_npu.npu_lightning_indexer(
                query=q_li,
                key=kv_cache[2],
                weights=weights,
                actual_seq_lengths_query=actual_seq_lengths_query,
                actual_seq_lengths_key=actual_seq_lengths_key,
                block_table=attn_metadata.block_table,
                layout_query="TND",
                layout_key="PA_BSND",
                sparse_count=2048,
                sparse_mode=3,
            )
        else:
            topk_indices = torch.ops._C_ascend.npu_lightning_indexer(
                query=q_li,
                key=kv_cache[2],
                weights=weights,
                actual_seq_lengths_query=actual_seq_lengths_query,
                actual_seq_lengths_key=actual_seq_lengths_key,
                block_table=attn_metadata.block_table,
                layout_query="TND",
                layout_key="PA_BSND",
                sparse_count=2048,
                sparse_mode=3,
            )
        return topk_indices

    def _execute_sparse_flash_attention_process(
        self, ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key
    ):
        block_table = attn_metadata.block_table
        kv = kv_cache[0]
        key_rope = kv_cache[1]
        attn_output = torch.ops._C_ascend.npu_sparse_flash_attention(
            query=ql_nope,
            key=kv,
            value=kv,
            sparse_indices=topk_indices,
            scale_value=self.scale,
            sparse_block_size=1,
            block_table=block_table,
            actual_seq_lengths_query=actual_seq_lengths_query,
            actual_seq_lengths_kv=actual_seq_lengths_key,
            query_rope=q_pe,
            key_rope=key_rope,
            layout_query="TND",
            layout_kv="PA_BSND",
            sparse_mode=3,
        )
        return attn_output

    def forward(
        self,
        layer_name,
        hidden_states: torch.Tensor,  # query in unified attn
        kv_cache: tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        attn_metadata: M,
        need_gather_q_kv: bool = False,
        output: torch.Tensor | None = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."
        if attn_metadata is None:
            # Profiling run.
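            # attn_metadata is None only for the profiling run: touch the shard-weight series so the
            # layer-sharding schedule stays in sync, then return a zero-filled output without
            # accessing the KV cache.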
            if self.enable_dsa_cp_with_layer_shard and not _EXTRA_CTX.in_profile_run:
                for layer in self.layer_sharding_kwargs or []:
                    if is_hidden_layer(layer):
                        reach_layer_for_shard_weight_series(layer)
            return output.fill_(0)

        cos = attn_metadata.cos
        sin = attn_metadata.sin
        slot_mapping = attn_metadata.slot_mapping
        slot_mapping_cp = None
        if self.enable_dsa_cp:
            assert attn_metadata.dsa_cp_context is not None
            slot_mapping_cp = attn_metadata.dsa_cp_context.slot_mapping_cp
            actual_seq_lengths_query = attn_metadata.dsa_cp_context.actual_seq_lengths_query
            actual_seq_lengths_key = attn_metadata.dsa_cp_context.actual_seq_lengths_key
        else:
            actual_seq_lengths_query = attn_metadata.cum_query_lens
            actual_seq_lengths_key = attn_metadata.seq_lens

        # Inputs and outputs may be padded for CUDA graphs
        num_input_tokens = attn_metadata.num_input_tokens
        output_padded = output

        # All-gather handle for the o_proj weight used in the prefill stage of a PD mix node
        o_proj_full_handle = None
        # In the PD mix stage, the original TP o_proj weight is used for decode, and the full o_proj
        # weight is all-gathered for the prefill stage.
        full_gather_o_proj_enabled = self.enable_dsa_cp_with_o_proj_tp and attn_metadata.attn_state not in {
            AscendAttentionState.DecodeOnly,
            AscendAttentionState.SpecDecoding,
        }

        # Run the MLAPO ops when DSA-CP is disabled and num_tokens satisfies the count limitation
        if self.enable_mlapo and num_input_tokens <= MLAPO_MAX_SUPPORTED_TOKENS:
            hidden_states, ql_nope, q_pe, q_c = self._sfa_preprocess_with_mlapo(
                hidden_states=hidden_states,
                kv_cache=kv_cache,
                cos=cos,
                sin=sin,
                slot_mapping=slot_mapping,
                num_input_tokens=num_input_tokens,
            )
            k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)
        # native
        else:
            assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
            weight_prefetch_method = get_weight_prefetch_method()
            weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
                inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states
            )
            qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
            q_c, kv_no_split = qkv_lora.split(
                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
                dim=-1,
            )
            assert self.q_a_layernorm is not None, "q_a_layernorm must be initialized"
            q_c = self.q_a_layernorm(q_c)
            k_li, k_li_scale = self.indexer_select_pre_process(x=hidden_states, cos=cos, sin=sin)

            wait_for_kv_layer_from_connector(layer_name)
            if self.enable_dsa_cp:
                assert slot_mapping_cp is not None
                k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping_cp, attn_metadata)
            else:
                k_pe, k_nope = self.exec_kv(kv_no_split, cos, sin, kv_cache, slot_mapping, attn_metadata)

            if self.enable_dsa_cp:
                assert k_pe is not None
                assert k_nope is not None
                assert k_li is not None
                async_op = self.enable_dsa_cp_with_layer_shard or full_gather_o_proj_enabled
                # Support async all-gather of KV for communication/computation overlap
                if not self.use_sparse_c8_indexer:
                    fused_kv_no_split, kv_ag_handle = all_gather_async(
                        torch.cat(
                            [
                                k_pe.view(-1, k_pe.shape[-1]),
                                k_nope.view(-1, k_nope.shape[-1]),
                                k_li.view(-1, k_li.shape[-1]),
                            ],
                            dim=1,
                        ),
                        get_tp_group(),
                        async_op=async_op,
                    )
                else:
                    # Due to the different dtypes, the communication has to be split into separate passes
                    assert k_li_scale is not None
                    fused_kv_no_split, _ = all_gather_async(
                        torch.cat(
                            [
                                k_pe.view(-1, k_pe.shape[-1]),
                                k_nope.view(-1, k_nope.shape[-1]),
                            ],
                            dim=1,
                        ),
                        get_tp_group(),
                        async_op=async_op,
                    )
                    k_li, _ = all_gather_async(
                        k_li,
                        get_tp_group(),
                        async_op=async_op,
                    )
                    k_li_scale, kv_ag_handle = all_gather_async(
                        k_li_scale,
                        get_tp_group(),
                        async_op=async_op,
                    )

            ql_nope, q_pe = self._q_proj_and_k_up_proj(q_c)
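            # Apply RoPE to the query's positional part; the key-side RoPE was already applied
            # inside exec_kv (npu_kv_rmsnorm_rope_cache) on this native path.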
            q_pe = self.rope_single(q_pe, cos, sin)

            if self.enable_dsa_cp:
                if kv_ag_handle is not None:
                    kv_ag_handle.wait()
                if self.enable_dsa_cp_with_layer_shard:
                    for layer in self.layer_sharding_kwargs or []:
                        if is_hidden_layer(layer):
                            reach_layer_for_shard_weight_series(layer)
                elif full_gather_o_proj_enabled:
                    _, o_proj_full_handle = all_gather_async(
                        self.o_proj_tp_weight, get_tp_group(), output=AscendSFAImpl.o_proj_full_pool
                    )
                if kv_cache is not None:
                    assert fused_kv_no_split is not None
                    if not self.use_sparse_c8_indexer:
                        k_pe, k_nope, k_li = fused_kv_no_split.split(
                            [self.qk_rope_head_dim, self.kv_lora_rank, self.head_dim], dim=-1
                        )
                    else:
                        k_pe, k_nope = fused_kv_no_split.split([self.qk_rope_head_dim, self.kv_lora_rank], dim=-1)
                    k_nope = k_nope.view(k_nope.shape[0], 1, -1)
                    k_pe = k_pe.view(k_pe.shape[0], 1, -1)
                    DeviceOperator.reshape_and_cache(
                        key=k_nope[: attn_metadata.num_actual_tokens],
                        value=k_pe[: attn_metadata.num_actual_tokens],
                        key_cache=kv_cache[0],
                        value_cache=kv_cache[1],
                        slot_mapping=slot_mapping[: attn_metadata.num_actual_tokens],
                    )

        k_li = self._get_full_kv(k_li, attn_metadata)
        if kv_cache is not None:
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event = torch.npu.Event()
            torch_npu.npu_scatter_nd_update_(
                kv_cache[2].view(-1, k_li.shape[-1]), slot_mapping.view(-1, 1), k_li.view(-1, k_li.shape[-1])
            )  # b, s, n, d
            if self.use_sparse_c8_indexer:
                assert len(kv_cache) == 4
                assert k_li_scale is not None
                torch_npu.npu_scatter_nd_update_(
                    kv_cache[3].view(-1, k_li_scale.shape[-1]),
                    slot_mapping.view(-1, 1),
                    k_li_scale.view(-1, k_li_scale.shape[-1]),
                )
            if self.is_kv_producer:
                attn_metadata.reshape_cache_event.record()

        topk_indices = self.indexer_select_post_process(
            x=hidden_states,
            q_c=q_c,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
            cos=cos,
            sin=sin,
            actual_seq_lengths_query=actual_seq_lengths_query,
            actual_seq_lengths_key=actual_seq_lengths_key,
        )
        attn_output = self._execute_sparse_flash_attention_process(
            ql_nope, q_pe, kv_cache, topk_indices, attn_metadata, actual_seq_lengths_query, actual_seq_lengths_key
        )
        attn_output = self._v_up_proj(attn_output)
        weight_prefetch_method = get_weight_prefetch_method()
        weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream(
            inputs=self.o_proj.weight,
            dependency=attn_output,
            max_size=MAX_O_PROJ_PREFETCH_SIZE,
            linear_layer=self.o_proj,
        )
        if self.enable_dsa_cp_with_o_proj_tp:
            # When using SFA-CP with PD mixed deployment, o_proj has two cases:
            # 1. prefill: o_proj holds a TP-sharded weight, so the full o_proj weight is all-gathered (TP=1).
            # 2. decode: all-to-all the hidden_state before the o_proj forward.
            result, require_o_proj_forward = self._handle_o_proj_weight_switch_and_forward(
                attn_output=attn_output,
                output=output,
                o_proj_full_handle=o_proj_full_handle,
                should_shard_weight=full_gather_o_proj_enabled,
            )
            if not require_o_proj_forward:
                return result
            attn_output = result
        output[...] = self.o_proj(attn_output)[0]
        maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
        return output_padded
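

# The sketch below is illustrative only and is not used by the classes above: it restates the
# DSA-CP token partitioning arithmetic from AscendSFAMetadataBuilder.build() (pad the token count
# to a multiple of the TP world size, then give every rank one contiguous slice) so the slicing
# logic can be checked in isolation. The helper name `_dsa_cp_token_slice` is ours and is not part
# of the backend API.
def _dsa_cp_token_slice(
    num_input_tokens: int, num_actual_tokens: int, tp_size: int, tp_rank: int
) -> tuple[int, int, int, int]:
    """Return (num_tokens_pad, local_start, local_end, local_end_with_pad) for one TP rank."""
    # Same rounding as _round_up(num_tokens, global_tp_size) in build().
    num_tokens_pad = ((num_input_tokens + tp_size - 1) // tp_size) * tp_size
    num_tokens_per_device = num_tokens_pad // tp_size
    local_start = tp_rank * num_tokens_per_device
    local_end_with_pad = local_start + num_tokens_per_device
    # Real (unpadded) tokens end at num_actual_tokens; the remainder of the slice is padding
    # (build() pads slot_mapping with -1 for those positions).
    local_end = min(local_end_with_pad, num_actual_tokens)
    return num_tokens_pad, local_start, local_end, local_end_with_pad


# Example: with num_input_tokens=num_actual_tokens=10 and tp_size=4, the padded length is 12 and
# ranks 0..3 get the slices [0, 3), [3, 6), [6, 9) and [9, 12); the last rank covers one real token
# plus two padded slots.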