from dataclasses import dataclass from typing import TYPE_CHECKING, NamedTuple, TypeVar import numpy as np import torch import torch_npu import vllm.envs as envs_vllm from vllm.config import VllmConfig, get_current_vllm_config from vllm.forward_context import ForwardContext, get_forward_context from vllm.logger import logger from vllm.model_executor.layers.attention.mla_attention import MLACommonMetadataBuilder from vllm.model_executor.layers.linear import UnquantizedLinearMethod from vllm.utils.math_utils import cdiv, round_down from vllm.v1.attention.backend import AttentionBackend, AttentionCGSupport, MLAAttentionImpl # type: ignore from vllm.v1.attention.backends.utils import PAD_SLOT_ID # type: ignore from vllm.v1.kv_cache_interface import AttentionSpec, MLAAttentionSpec from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.context_parallel.common_cp import AscendPCPMetadata, CPChunkedContextMetadata from vllm_ascend.attention.utils import ( AscendCommonAttentionMetadata, ascend_chunked_prefill_workspace_size, enable_cp, enabling_mlapo, maybe_save_kv_layer_to_connector, split_decodes_and_prefills, trans_rope_weight, transdata, wait_for_kv_layer_from_connector, ) from vllm_ascend.compilation.acl_graph import ( get_draft_graph_params, get_graph_params, update_draft_graph_params_workspaces, update_graph_params_workspaces, ) from vllm_ascend.ops.layer_shard_linear import ( is_hidden_layer, post_process_after_loading_for_shard_weight_series, reach_layer_for_shard_weight_series, register_all_layers_to_shard_weight_series, ) from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.quantization.methods import AscendW8A8LinearMethod from vllm_ascend.utils import ( ACL_FORMAT_FRACTAL_ND, get_weight_prefetch_method, maybe_trans_nz, weak_ref_tensors, ) from vllm_ascend.worker.npu_input_batch import NPUInputBatch if TYPE_CHECKING: from vllm.v1.core.sched.output import SchedulerOutput MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024 BUILD_METADATA_STEP_PREFILL = 0 BUILD_METADATA_STEP_DECODE = 1 # token count limits within the mlapo operator MLAPO_MAX_SUPPORTED_TOKENS = 1024 class AscendMLABackend(AttentionBackend): accept_output_buffer: bool = True @staticmethod def get_name() -> str: # HACK(Ronald1995): vllm `initialize_kv_cache` method in model runner v2 make # attention name assertion, we just set name to FLASH_ATTN to avoid assertion error. # rectify this when vllm disable the assertion. return "ASCEND_MLA" if not envs_vllm.VLLM_USE_V2_MODEL_RUNNER else "FLASH_ATTN" @staticmethod def get_builder_cls(): if enable_cp(): from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPMetadataBuilder return AscendMlaCPMetadataBuilder return AscendMLAMetadataBuilder @staticmethod def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, head_size: int) -> tuple[int, ...]: return num_blocks, block_size, num_kv_heads, head_size @staticmethod def get_impl_cls() -> type["MLAAttentionImpl"]: if enable_cp(): from vllm_ascend.attention.context_parallel.mla_cp import AscendMlaCPImpl return AscendMlaCPImpl return AscendMLAImpl @staticmethod def get_supported_kernel_block_sizes() -> list[int]: return [128] @dataclass class ChunkedContextMetadata: """ Metadata for chunked context handling in MLA attention. Manages sequence boundaries and workspace for chunked prefill processing. 
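
    Illustrative example (hypothetical numbers, not from a real trace): with
    two prefill requests holding 384 and 256 cached context tokens and
    max_context_chunk=192, the builder produces two chunk iterations with
    chunk_seq_lens=[[192, 192], [192, 64]], starts=[[0, 0], [192, 192]] and
    seq_tot=[384, 256].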
""" cu_seq_lens: torch.Tensor starts: torch.Tensor seq_tot: list[int] max_seq_lens: list[int] workspace: torch.Tensor chunk_seq_lens: torch.Tensor chunk_seq_lens_npu: torch.Tensor @dataclass class AscendMLAPrefillMetadata: """Prefill Specific Metadata for Ascend""" attn_mask: torch.Tensor query_lens: torch.Tensor seq_lens: list[int] context_lens: torch.Tensor input_positions: torch.Tensor query_start_loc: torch.Tensor block_table: torch.Tensor max_query_len: int max_seq_lens: int chunked_context: ChunkedContextMetadata | CPChunkedContextMetadata | None = None sin: torch.Tensor = None cos: torch.Tensor = None pcp_metadata: AscendPCPMetadata | None = None @dataclass class AscendMLADecodeMetadata: """Decode-specific metadata for Ascend MLA attention.""" # Input positions for rotary embeddings since for MLA the rotary # position embeddings are applied inside the attention backend input_positions: torch.Tensor block_table: torch.Tensor seq_lens: torch.Tensor max_seq_lens: int seq_lens_list: list[int] actual_seq_lengths_q: list[int] | None = None attn_mask: torch.Tensor | None = None sin: torch.Tensor = None cos: torch.Tensor = None cp_seq_len: torch.Tensor = None @dataclass class AscendMLAMetadata: """Metadata for MLACommon. NOTE: Please read the comment at the top of the file before trying to understand this class """ # NOTE(sang): Definition of context_len, query_len, and seq_len. # |---------- N-1 iteration --------| # |---------------- N iteration ---------------------| # |- tokenA -|......................|-- newTokens ---| # |---------- context_len ----------| # |-------------------- seq_len ---------------------| # |-- query_len ---| num_actual_tokens_pcp_padded: int num_actual_tokens: int # Number of tokens excluding padding. slot_mapping: torch.Tensor query_start_loc: torch.Tensor seq_lens: torch.Tensor block_tables: torch.Tensor # New for MLA (compared to FlashAttention) # For handling prefill decode split num_decodes: int num_decode_tokens: int num_prefills: int # For logging. num_input_tokens: int = 0 # Number of tokens including padding. 
    query_lens: list[int] | None = None
    # The dimension of the attention heads
    head_dim: int | None = None
    attn_mask: torch.Tensor = None
    # chunked prefill by default if no attn_states passed
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
    decode: AscendMLADecodeMetadata | None = None
    prefill: AscendMLAPrefillMetadata | None = None
    reshape_cache_event: torch.npu.Event = None

    def __post_init__(self):
        pass
        # supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
        # if self.head_dim is not None and self.head_dim \
        #         not in supported_head_sizes:
        #     raise ValueError(
        #         f"Only {supported_head_sizes} are supported for head_dim,",
        #         f"received {self.head_dim}.")


M = TypeVar("M", bound=AscendMLAMetadata)


class AscendMLAMetadataBuilder(MLACommonMetadataBuilder[AscendMLAMetadata]):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(
        self,
        kv_cache_spec: MLAAttentionSpec,
        layer_names: list[str],
        vllm_config: VllmConfig,
        device: torch.device,
        metadata_cls: type[AscendMLAMetadata] | None = None,
        supports_dcp_with_varlen: bool = False,
    ):
        super().__init__(
            kv_cache_spec,
            layer_names,
            vllm_config,
            device,
            metadata_cls if metadata_cls is not None else AscendMLAMetadata,
            supports_dcp_with_varlen,
        )
        scheduler_config = vllm_config.scheduler_config
        self.block_size = vllm_config.cache_config.block_size
        self.max_blocks = (vllm_config.model_config.max_model_len + self.block_size - 1) // self.block_size
        self.chunked_prefill_enabled = scheduler_config.enable_chunked_prefill
        self.speculative_config = vllm_config.speculative_config
        self.decode_threshold = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            self.decode_threshold += spec_token_num
            assert self.decode_threshold <= 16, (
                f"decode_threshold exceeds npu_fused_infer_attention_score's "
                f"TND layout limit of 16, got {self.decode_threshold}"
            )
        self.reorder_batch_threshold = self.decode_threshold
        self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
        self.cos_cache = None
        self.sin_cache = None
        self.chunk_seq_lens: torch.Tensor = None
        self.cu_seq_lens_cpu: torch.Tensor = None
        self.num_chunks: torch.Tensor = None
        self.max_context_chunk = 0
        self.num_decodes = 0
        self.num_prefills = 0
        self.num_decode_tokens = 0
        self.num_prefill_tokens = 0
        self.context_lens_cpu: torch.Tensor = None
        self.num_actual_tokens: int | None = None
        self.block_table: torch.Tensor = None
        self.slot_mapping: torch.Tensor = None
        self.graph_pad_size = 0
        self.query_lens: torch.Tensor = None
        self.seq_lens: torch.Tensor = None
        self.attn_mask_builder = AttentionMaskBuilder(self.device)

    @staticmethod
    def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
        return ascend_chunked_prefill_workspace_size(vllm_config)

    @classmethod
    def get_cudagraph_support(
        cls: type["AscendMLAMetadataBuilder"],
        vllm_config: VllmConfig,
        kv_cache_spec: AttentionSpec,
    ) -> AttentionCGSupport:
        # Explicit override in case the underlying builder specialized this getter.
        # @override is omitted only because of a mypy limitation with type variables.
        return AttentionCGSupport.UNIFORM_BATCH

    def reorder_batch(self, input_batch: "NPUInputBatch", scheduler_output: "SchedulerOutput") -> bool:
        # We now want to reorder the batch so that the "decode" requests are
        # at the front and the "prefill" requests are at the back, using the
        # least amount of swaps possible.
        # (NOTE: for now we loosely use "decode" to mean requests where
        # attention is likely memory-bound and "prefill" to mean requests
        # where attention is likely compute-bound. TODO(lucas): figure out
        # better naming here.)
        decodes = []
        prefills = []

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            if num_tokens <= self.decode_threshold:
                decodes.append(i)
            else:
                prefills.append(i)

        # We hope the number of swaps is fairly minimal: decodes should stick
        # around for a number of iterations, so they are relatively stationary
        # (and new requests are generally appended to the persistent batch, so
        # they should already be at the back).
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back",
        # i.e. past where the last decode should be in the reordered batch,
        # with prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based
        # on the above loop.
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        first_prefill = 0
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the i-th decode from the back sits past the decode region,
            # swap it with the prefill closest to the front of the batch.
            if decodes[num_decodes - i] >= num_decodes:
                input_batch.swap_states(prefills[first_prefill], decodes[num_decodes - i])
                first_prefill += 1
                modified_batch = True
            else:
                break

        return modified_batch

    def pad_actual_seq_len_q_mtp_enable_pad(
        self, num_reqs_pad_size, num_reqs, actual_seq_lengths_q, common_attn_metadata
    ):
        """
        Pads actual_seq_lengths_q evenly, keeping each request at no more than
        16 tokens, in order to meet the requirements of
        npu_fused_infer_attention_score.
        In the Torchair scenario, the query lengths must be padded to the same
        length, and npu_fused_infer_attention_score additionally requires that
        the last element equal batch_size (num_tokens).
        For example: batch_size=36, num_reqs_pad_size=2, num_reqs=16.
        By default, each request should infer 2 tokens, which means
        actual_seq_lengths_q should be
        [2,4,6,8,10,12,14,16,18,20,22,24,26,28,30,32,34,36].
        However, in the MTP Torchair + PD scenario, actual_seq_lengths_q may be
        [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16] before padding, since the
        first decode request only has 1 token. To meet the requirement of
        npu_fused_infer_attention_score, we pad actual_seq_lengths_q evenly so
        that no request exceeds 16 tokens.
        After padding, actual_seq_lengths_q should be similar to
        [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,32,36].
        """
        FIA_SEQ_LEN_LIMIT = 16
        need_padding = (
            num_reqs_pad_size != 0
            and len(common_attn_metadata.actual_seq_lengths_q) > num_reqs
            and common_attn_metadata.actual_seq_lengths_q[num_reqs] - actual_seq_lengths_q[-1] > FIA_SEQ_LEN_LIMIT
        )
        if need_padding:
            padding_seq_len_q = common_attn_metadata.actual_seq_lengths_q[num_reqs : num_reqs + num_reqs_pad_size]
            start_val = actual_seq_lengths_q[-1]
            end_val = padding_seq_len_q[-1]
            num_step = len(padding_seq_len_q)
            interpolated = np.round(np.linspace(start_val, end_val, num_step + 1)[1:]).astype(int).tolist()
            assert interpolated[-1] == end_val
            assert len(interpolated) == len(padding_seq_len_q)
            actual_seq_lengths_q = actual_seq_lengths_q + interpolated
        else:
            actual_seq_lengths_q = (
                actual_seq_lengths_q
                + common_attn_metadata.actual_seq_lengths_q[num_reqs : num_reqs + num_reqs_pad_size]
            )
        return actual_seq_lengths_q

    def pad_actual_seq_len_q_mtp_disable_pad(self, num_reqs_pad_size, num_reqs, actual_seq_lengths_q):
        """
        Only used in ACL full-graph mode.
        Pads actual_seq_lengths_q so that its last element equals T of the TND
        layout and its number of elements equals the batch size of the main
        model.
        For example: batch_size = 8, num_reqs = 4, num_speculative_tokens = 1,
        input actual_seq_lengths_q = [1, 2, 4, 5] (the 3rd req accepted a token).
        After padding, actual_seq_lengths_q will be similar to
        [1, 2, 4, 5, 6, 6, 7, 8].
        """
        need_padding = num_reqs_pad_size > 0
        if need_padding:
            start_val = actual_seq_lengths_q[-1]
            end_val = num_reqs + num_reqs_pad_size
            num_step = num_reqs_pad_size
            interpolated = np.round(np.linspace(start_val, end_val, num_step + 1)[1:]).astype(int).tolist()
            assert interpolated[-1] == end_val
            assert len(interpolated) == num_reqs_pad_size
            actual_seq_lengths_q = actual_seq_lengths_q + interpolated
        return actual_seq_lengths_q

    def set_num_actual_tokens(
        self,
        common_attn_metadata: AscendCommonAttentionMetadata,
    ):
        self.num_actual_tokens = common_attn_metadata.num_actual_tokens

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        fast_build: bool = False,
    ) -> AscendMLAMetadata:
        num_reqs = common_attn_metadata.num_reqs
        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        self.num_decodes, self.num_prefills, self.num_decode_tokens, self.num_prefill_tokens = (
            split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
        )
        self.set_num_actual_tokens(common_attn_metadata)
        assert self.num_decodes + self.num_prefills == num_reqs
        assert self.num_decode_tokens + self.num_prefill_tokens == common_attn_metadata.num_actual_tokens
        # NOTE: Currently, MTP full graph is incompatible with PCP.
        self.slot_mapping = common_attn_metadata.slot_mapping[: self.num_actual_tokens]
        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        self.query_lens = query_seq_lens_cpu[:num_reqs]
        self.seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
        self.graph_pad_size = common_attn_metadata.graph_pad_size
        block_table_size = self.get_block_table_size(common_attn_metadata, BUILD_METADATA_STEP_PREFILL)
        self.block_table = common_attn_metadata.block_table_tensor[:block_table_size]
        prefill_metadata = None
        if self.num_prefills > 0:
            prefill_metadata = self.build_prefill_metadata(common_prefix_len, common_attn_metadata)
        decode_metadata = None
        if self.num_decodes > 0:
            decode_metadata = self.build_decode_metadata(common_prefix_len, common_attn_metadata)
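
        # Illustrative sketch (hypothetical numbers): with decode_threshold=2
        # and scheduled query lens [1, 2, 5, 7], split_decodes_and_prefills
        # returns num_decodes=2, num_prefills=2, num_decode_tokens=3 and
        # num_prefill_tokens=12, so the metadata built below describes a
        # reordered batch whose first 3 tokens belong to decode requests.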
        return self.metadata_cls(  # type: ignore
            num_actual_tokens_pcp_padded=self.num_actual_tokens,
            num_input_tokens=common_attn_metadata.num_input_tokens,
            num_actual_tokens=self.num_actual_tokens,
            query_lens=self.query_lens.tolist(),
            slot_mapping=self.slot_mapping,
            head_dim=self.model_config.get_head_size(),
            num_decodes=self.num_decodes,
            num_decode_tokens=self.num_decode_tokens,
            num_prefills=self.num_prefills,
            attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config),
            attn_state=common_attn_metadata.attn_state,
            prefill=prefill_metadata,
            decode=decode_metadata,
            query_start_loc=query_start_loc,
            block_tables=self.block_table,
            seq_lens=self.seq_lens,
        )

    def build_chunked_metadata(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
    ):
        if not self.chunked_prefill_enabled:
            return None
        num_reqs = common_attn_metadata.num_reqs
        num_computed_tokens_cpu = self.seq_lens - self.query_lens
        reqs_start = self.num_decodes  # prefill_start
        self.context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
        max_context_len_cpu = self.context_lens_cpu.max().item()
        if not max_context_len_cpu > 0:
            return None
        num_prefills_with_context_cpu = (self.context_lens_cpu > 0).sum().item()
        self.max_context_chunk = self.chunked_prefill_workspace_size // num_prefills_with_context_cpu
        self.max_context_chunk = round_down(self.max_context_chunk, self.block_size)
        assert self.max_context_chunk > 0
        self.num_chunks = cdiv(max_context_len_cpu, self.max_context_chunk)
        chunk_starts = (
            torch.arange(self.num_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, self.num_prefills)
            * self.max_context_chunk
        )
        chunk_ends = torch.min(self.context_lens_cpu.unsqueeze(0), chunk_starts + self.max_context_chunk)
        self.chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
        self.cu_seq_lens_cpu = torch.zeros(self.num_chunks, self.num_prefills + 1, dtype=torch.int32, pin_memory=True)
        torch.cumsum(self.chunk_seq_lens, dim=1, out=self.cu_seq_lens_cpu[:, 1:], dtype=torch.int32)
        return ChunkedContextMetadata(
            cu_seq_lens=self.cu_seq_lens_cpu.pin_memory().to(self.device, non_blocking=True),
            starts=chunk_starts.pin_memory().to(self.device, non_blocking=True),
            seq_tot=self.chunk_seq_lens.sum(dim=1).tolist(),
            max_seq_lens=self.chunk_seq_lens.max(dim=1).values.tolist(),
            chunk_seq_lens=self.chunk_seq_lens,
            chunk_seq_lens_npu=self.chunk_seq_lens.npu(),
            workspace=self.chunked_prefill_workspace,
        )

    def get_block_table_size(self, common_attn_metadata: AscendCommonAttentionMetadata, build_metadata_step: int):
        if build_metadata_step == BUILD_METADATA_STEP_PREFILL:
            # If graph_pad_size > -1, it means we are running in full-graph mode.
            # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
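            # Illustrative example (hypothetical numbers): with num_reqs=4 and
            # graph_pad_size=8 under disable_padded_drafter_batch, the prefill
            # step keeps 8 block-table rows so the captured graph sees a fixed
            # batch shape; otherwise only the first num_reqs rows are kept.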
            if (
                self.graph_pad_size > common_attn_metadata.num_reqs
                and self.speculative_config.disable_padded_drafter_batch
            ):
                return self.graph_pad_size
            return common_attn_metadata.num_reqs
        return self.num_decodes

    def build_prefill_metadata(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
    ) -> AscendMLAPrefillMetadata:
        query_start_loc = common_attn_metadata.query_start_loc
        # NOTE: Currently, MTP full graph is incompatible with PCP.
        input_positions = common_attn_metadata.positions[: self.num_actual_tokens].long()
        chunked_context_metadata = self.build_chunked_metadata(common_prefix_len, common_attn_metadata)
        reqs_start = self.num_decodes  # prefill_start
        tokens_start = self.num_decode_tokens
        max_query_len = self.query_lens[reqs_start:].max().item()
        max_seq_lens = self.seq_lens[reqs_start:].max().item()
        prefill_query_start_loc = query_start_loc[reqs_start:] - query_start_loc[reqs_start]
        prefill_input_positions = input_positions[tokens_start:]
        cos, sin = get_cos_and_sin_mla(prefill_input_positions)
        return AscendMLAPrefillMetadata(
            attn_mask=self.attn_mask_builder.get_final_mla_mask(self.model_config),
            query_lens=self.query_lens[reqs_start:].to(torch.int32),
            seq_lens=self.seq_lens,
            context_lens=self.seq_lens[reqs_start:],
            input_positions=prefill_input_positions,
            block_table=self.block_table[reqs_start:, ...],
            max_query_len=max_query_len,
            max_seq_lens=max_seq_lens,
            query_start_loc=prefill_query_start_loc,
            chunked_context=chunked_context_metadata,
            sin=sin,
            cos=cos,
        )

    def build_decode_metadata(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
    ) -> AscendMLADecodeMetadata:
        num_reqs = common_attn_metadata.num_reqs
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        input_positions = common_attn_metadata.positions[: self.num_actual_tokens].long()
        # Notice that num_decodes != num_decode_tokens in the spec decoding scenario.
        actual_seq_lengths_q = query_start_loc_cpu[1 : self.num_decodes + 1].tolist()
        max_seq_lens = self.seq_lens[: self.num_decodes].max().item()
        self.seq_lens = self.seq_lens[: self.num_decodes]
        input_positions = input_positions[: self.num_decode_tokens]
        block_table_size = self.get_block_table_size(common_attn_metadata, BUILD_METADATA_STEP_DECODE)
        self.block_table = self.block_table[:block_table_size]
        # NOTE: Currently, MTP full graph is incompatible with PCP.
        # NOTE: Maybe this block_table change can be removed when graph_pad_size > 1.
        if self.graph_pad_size > self.num_decodes and self.speculative_config.disable_padded_drafter_batch:
            self.block_table = self.block_table[: self.graph_pad_size, ...]
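
        # Illustrative sketch (hypothetical numbers): with num_decodes=3,
        # seq lens [10, 12, 9] and graph_pad_size=5, the padding below yields
        # seq_lens_list=[10, 12, 9, 0, 0], the block table is zero-padded to
        # five rows, and actual_seq_lengths_q is extended by interpolation so
        # that its last element matches the padded token count.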
seq_lens_list = self.seq_lens.tolist() cp_seq_len = None if self.graph_pad_size > num_reqs: if self.speculative_config.disable_padded_drafter_batch: num_reqs_pad_size = self.graph_pad_size - num_reqs actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_disable_pad( num_reqs_pad_size, num_reqs, actual_seq_lengths_q ) seq_lens_list = seq_lens_list + [0] * (self.graph_pad_size - self.num_decodes) num_block_pad_size = self.graph_pad_size - self.block_table.shape[0] if num_block_pad_size > 0: block_table_padding = torch.zeros( (num_block_pad_size,) + self.block_table.shape[1:], dtype=self.block_table.dtype, device=self.block_table.device, ) self.block_table = torch.cat([self.block_table, block_table_padding], dim=0) else: num_token_pad_size = self.graph_pad_size - self.num_decode_tokens num_reqs_pad_size = self.graph_pad_size // common_attn_metadata.decode_token_per_req - num_reqs num_block_table_pad_size = ( self.graph_pad_size // common_attn_metadata.decode_token_per_req - self.num_decodes ) seq_lens_list = self.seq_lens.tolist() + [0] * num_reqs_pad_size slot_padding = torch.full( (num_token_pad_size,), PAD_SLOT_ID, dtype=self.slot_mapping.dtype, device=self.slot_mapping.device ) self.slot_mapping = torch.cat([self.slot_mapping, slot_padding]) block_table_padding = torch.zeros( (num_block_table_pad_size,) + self.block_table.shape[1:], dtype=self.block_table.dtype, device=self.block_table.device, ) self.block_table = torch.cat([self.block_table, block_table_padding], dim=0) position_padding = torch.zeros( num_token_pad_size, dtype=input_positions.dtype, device=input_positions.device ) input_positions = torch.cat([input_positions, position_padding]) actual_seq_lengths_q = self.pad_actual_seq_len_q_mtp_enable_pad( num_reqs_pad_size, num_reqs, actual_seq_lengths_q, common_attn_metadata ) cos, sin = get_cos_and_sin_mla(input_positions, use_cache=True) decode_metadata = AscendMLADecodeMetadata( input_positions=input_positions, block_table=self.block_table, seq_lens=self.seq_lens, seq_lens_list=seq_lens_list, max_seq_lens=max_seq_lens, attn_mask=self.attn_mask_builder.get_splitfuse_attn_mask(), actual_seq_lengths_q=actual_seq_lengths_q, sin=sin[: self.num_decode_tokens, ...], cos=cos[: self.num_decode_tokens, ...], cp_seq_len=cp_seq_len, ) return decode_metadata def build_for_graph_capture( self, common_attn_metadata: AscendCommonAttentionMetadata, attn_state: AscendAttentionState = AscendAttentionState.DecodeOnly, ): if attn_state in {AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding}: attn_metadata = self.build( common_prefix_len=0, common_attn_metadata=common_attn_metadata, ) else: raise NotImplementedError( "Currently we only support building dummy metadata for DecodeOnly and SpecDecoding state" ) attn_metadata.attn_state = attn_state return attn_metadata class DecodeMLAPreprocessResult(NamedTuple): ql_nope: torch.Tensor | None = None q_pe: torch.Tensor | None = None k_nope: torch.Tensor | None = None k_pe: torch.Tensor | None = None decode_q_wo_k_up: torch.Tensor | None = None class PrefillMLAPreprocessResult(NamedTuple): q_nope: torch.Tensor | None = None q_pe: torch.Tensor | None = None k_nope: torch.Tensor | None = None k_pe: torch.Tensor | None = None value: torch.Tensor | None = None class AscendMLAImpl(MLAAttentionImpl): """ NOTE: Please read the comment at the top of the file before trying to understand this class """ def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, alibi_slopes: list[float] | None, sliding_window: int | None, 
kv_cache_dtype: str, logits_soft_cap: float | None, attn_type: str, kv_sharing_target_layer_name: str | None, **kwargs, ): self.vllm_config = get_current_vllm_config() self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_kv_heads self.kv_cache_dtype = kv_cache_dtype # MLA Args self.q_lora_rank = kwargs["q_lora_rank"] self.kv_lora_rank = kwargs["kv_lora_rank"] self.qk_nope_head_dim = kwargs["qk_nope_head_dim"] self.qk_rope_head_dim = kwargs["qk_rope_head_dim"] self.qk_head_dim = kwargs["qk_head_dim"] self.v_head_dim = kwargs["v_head_dim"] self.rotary_emb = kwargs["rotary_emb"] self.fused_qkv_a_proj = kwargs.get("fused_qkv_a_proj") self.q_proj = kwargs["q_proj"] if self.q_lora_rank is None else kwargs["q_b_proj"] self.kv_b_proj = kwargs["kv_b_proj"] self.o_proj = kwargs["o_proj"] self.vllm_config = get_current_vllm_config() self.kv_a_proj_with_mqa = kwargs.get("kv_a_proj_with_mqa") self.kv_a_layernorm = kwargs.get("kv_a_layernorm") self.q_a_layernorm = kwargs.get("q_a_layernorm") self.num_queries_per_kv = self.num_heads // self.num_kv_heads ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp self.enable_kv_nz = ascend_config.enable_kv_nz self.ring_mla_mask_size = 512 self.speculative_config = self.vllm_config.speculative_config self.enable_mlapo = enabling_mlapo(self.vllm_config) self.is_kv_producer = ( self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer ) self.layer_sharding_kwargs = [] for layer_name in get_ascend_config().layer_sharding or []: if layer_name in kwargs: self.layer_sharding_kwargs.append(kwargs[layer_name]) else: logger.warning_once( f"Layer '{layer_name}' not found in kwargs for layer sharding, skipping sharding configuration" ) register_all_layers_to_shard_weight_series(self.layer_sharding_kwargs) @staticmethod def update_graph_params( update_stream, forward_context, num_tokens, vllm_config=None, speculative_config=None, num_dcp_pcp_tokens=None, draft_attn_metadatas=None, ): if forward_context.is_draft_model: graph_params = get_draft_graph_params() else: graph_params = get_graph_params() # FIXME: Behold! We are using a temporary hack here to update the args # for each layer's attention op in the graph. with torch.npu.stream(update_stream): for key, param, handle, event in zip( forward_context.attn_metadata, graph_params.attn_params[num_tokens], graph_params.handles[num_tokens], graph_params.events[num_tokens], ): ( q_nope, k_nope, q_pe, k_pe, num_heads, num_kv_heads, input_layout, attn_mask, sparse_mode, scale, block_table, block_size, seq_lens_list, actual_seq_lengths, attn_output, softmax_lse, ) = param seq_lens_list = forward_context.attn_metadata[key].decode.seq_lens_list if speculative_config and speculative_config.method == "mtp" and not forward_context.is_draft_model: actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q spec_multiple = speculative_config.num_speculative_tokens + 1 seq_lens_list = seq_lens_list + [0] * (num_tokens // spec_multiple - len(seq_lens_list)) actual_seq_lengths = [spec_multiple * (i + 1) for i in range(num_tokens // spec_multiple)] elif forward_context.is_draft_model: actual_seq_lengths = forward_context.attn_metadata[key].decode.actual_seq_lengths_q block_table = forward_context.attn_metadata[key].decode.block_table # TODO: This is a hack and should be fixed in the future. 
if speculative_config.disable_padded_drafter_batch: block_table = block_table[: len(actual_seq_lengths)] seq_lens_list = seq_lens_list + [0] * (len(actual_seq_lengths) - len(seq_lens_list)) else: seq_lens_list = seq_lens_list + [0] * (num_tokens - len(seq_lens_list)) torch.npu.graph_task_update_begin(update_stream, handle) torch_npu.npu_fused_infer_attention_score.out( q_nope, k_nope, k_nope, query_rope=q_pe, key_rope=k_pe, num_heads=num_heads, num_key_value_heads=num_kv_heads, input_layout=input_layout, atten_mask=attn_mask, sparse_mode=sparse_mode, scale=scale, antiquant_mode=0, antiquant_scale=None, block_table=block_table, block_size=block_size, actual_seq_lengths_kv=seq_lens_list, actual_seq_lengths=actual_seq_lengths, workspace=graph_params.workspaces.get(num_tokens), out=[attn_output, softmax_lse], ) torch.npu.graph_task_update_end(update_stream) event.record(update_stream) def _v_up_proj(self, x): # Convert from (N, B, L)/(N, B, 1, L) to (N, B, L) x = x.view(self.num_heads, -1, self.kv_lora_rank) # Multiply (N, B, L) x (N, L, V) -> (B, N, V) x = torch_npu.npu_transpose_batchmatmul(x, self.W_UV, perm_y=(1, 0, 2)) # Convert from (B, N, V) to (B, N * V) x = x.reshape(-1, self.num_heads * self.v_head_dim) return x # Return `ql_nope`, `q_pe` def _q_proj_and_k_up_proj(self, x): q_nope, q_pe = ( self.q_proj(x)[0] .view(-1, self.num_heads, self.qk_head_dim) .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) ) # Convert from (B, N, P) to (N, B, P) q_nope = q_nope.transpose(0, 1) # Multiply (N, B, P) x (N, P, L) -> (N, B, L) ql_nope = torch.bmm(q_nope, self.W_UK_T) # Convert from (N, B, L) to (B, N, L) return ql_nope.transpose(0, 1), q_pe def process_weights_after_loading(self, act_dtype: torch.dtype): # NOTE: We currently do not support quant kv_b_proj. assert isinstance(self.kv_b_proj.quant_method, UnquantizedLinearMethod) # NOTE: Weight will be reshaped next, we need to revert and transpose it. kv_b_proj_weight = torch_npu.npu_format_cast(self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_ND).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), ), ( f"{kv_b_proj_weight.shape=}, " f"{self.kv_lora_rank=}, " f"{self.num_heads=}, " f"{self.qk_nope_head_dim=}, " f"{self.v_head_dim=}" ) kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim + self.v_head_dim, ) W_UK, W_UV = kv_b_proj_weight.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) # Convert from (L, N, V) to (N, L, V) self.W_UV = W_UV.transpose(0, 1).contiguous() # Convert from (L, N, P) to (N, P, L) self.W_UK_T = W_UK.permute(1, 2, 0).contiguous() # TODO(zzzzwwjj): Currently, torch.ops._C_ascend.batch_matmul_transpose cannot support weight nz # self.W_UV = maybe_trans_nz(self.W_UV) if self.enable_mlapo: # Currently mlapo only supports W8A8 quantization in MLA scenario # TODO(whx): modify this limitation when mlapo supports floating point if self.fused_qkv_a_proj is None or not isinstance( getattr(self.fused_qkv_a_proj.quant_method, "quant_method", None), AscendW8A8LinearMethod ): self.enable_mlapo = False logger.warning_once( "Currently mlapo only supports W8A8 quantization in MLA scenario." "Some layers in your model are not quantized with W8A8," "thus mlapo is disabled for these layers." 
) if self.enable_mlapo: self._process_weights_for_fused_mlapo(act_dtype) else: # if mlapo, W_UK_T can't trans nz self.W_UK_T = maybe_trans_nz(self.W_UK_T) for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): post_process_after_loading_for_shard_weight_series(layer) def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., self.q_lora_rank :].contiguous() # type: ignore[union-attr] q_a_proj_wt = self.fused_qkv_a_proj.weight.data[..., : self.q_lora_rank].contiguous() # type: ignore[union-attr] kv_a_proj_wt = kv_a_proj_wt.t().contiguous() kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) kv_a_proj_wt = kv_a_proj_wt.t().contiguous() wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1) wd_qkv = wd_qkv.t().contiguous() wd_qkv = transdata(wd_qkv, block_size=(16, 32)).unsqueeze(0).contiguous() self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29) kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[self.q_lora_rank :].contiguous() # type: ignore[union-attr] q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[: self.q_lora_rank].contiguous() # type: ignore[union-attr] kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, self.qk_rope_head_dim) kv_a_proj_deq_scl = kv_a_proj_deq_scl.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous() self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), dim=-1).contiguous() kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[self.q_lora_rank :].contiguous() # type: ignore[union-attr] q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[: self.q_lora_rank].contiguous() # type: ignore[union-attr] kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape(self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, self.qk_rope_head_dim) kv_a_proj_qt_bias = kv_a_proj_qt_bias.view(self.kv_lora_rank + self.qk_rope_head_dim).contiguous() self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), dim=-1).contiguous() wu_q = self.q_proj.weight.data wu_q = wu_q.t().reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) wu_q = trans_rope_weight(wu_q, self.qk_rope_head_dim) wu_q = wu_q.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim), -1) wu_q = transdata(wu_q, block_size=(16, 32)).unsqueeze(0).contiguous() self.wu_q = torch_npu.npu_format_cast(wu_q, 29) qb_deq_scl = self.q_proj.deq_scale.data qb_deq_scl = qb_deq_scl.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) qb_deq_scl = trans_rope_weight(qb_deq_scl, self.qk_rope_head_dim) self.qb_deq_scl = qb_deq_scl.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) qb_qt_bias = self.q_proj.quant_bias.data qb_qt_bias = qb_qt_bias.reshape(self.num_heads, self.qk_nope_head_dim + self.qk_rope_head_dim, -1) qb_qt_bias = trans_rope_weight(qb_qt_bias, self.qk_rope_head_dim) self.qb_qt_bias = qb_qt_bias.reshape(self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) device = self.q_proj.weight.device self.gamma1 = self.q_a_layernorm.weight.data # type: ignore[union-attr] self.beta1 = torch.zeros_like(self.gamma1) if (_bias := self.q_a_layernorm.bias) is None else _bias.data # type: ignore[union-attr] self.gamma2 = self.kv_a_layernorm.weight.data # type: ignore[union-attr] self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data # type: 
ignore[union-attr] self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data # type: ignore[union-attr] self.quant_scale1 = self.q_proj.input_scale.data self.quant_offset1 = self.q_proj.input_offset.data self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device) self.q_nope_scale = torch.tensor([1], dtype=act_dtype, device=device) # On KV consumers (decode-only) MLAPO uses the transformed weights built above; # the original fused_qkv_a_proj/q_proj weights and quant params are no longer # referenced, so drop them to save memory. if ( self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_consumer and self.vllm_config.scheduler_config.max_num_batched_tokens <= MLAPO_MAX_SUPPORTED_TOKENS ): self.fused_qkv_a_proj.weight = None # type: ignore[union-attr] self.fused_qkv_a_proj.deq_scale = None # type: ignore[union-attr] self.fused_qkv_a_proj.quant_bias = None # type: ignore[union-attr] self.q_proj.weight = None self.q_proj.deq_scale = None self.q_proj.quant_bias = None torch.npu.empty_cache() def get_context_seq_len_npu(self, index: int, attn_metadata: AscendMLAMetadata): prefill_metadata = attn_metadata.prefill assert prefill_metadata is not None assert prefill_metadata.chunked_context is not None assert prefill_metadata.chunked_context.chunk_seq_lens_npu is not None iters = len(prefill_metadata.chunked_context.seq_tot) assert 0 <= index < iters return prefill_metadata.chunked_context.chunk_seq_lens_npu[index] def _reorg_kvcache( self, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, chunked_context: CPChunkedContextMetadata, chunk_idx: int, toks: int, ) -> tuple[torch.Tensor, torch.Tensor]: return kv_c_normed, k_pe def _compute_prefill_context( self, q_nope: torch.Tensor, q_pe: torch.Tensor, kv_c_and_k_pe_cache: tuple[torch.Tensor], rope_dim: int, attn_metadata: AscendMLAMetadata, prefix_output: torch.Tensor, prefix_lse: torch.Tensor, ): assert len(kv_c_and_k_pe_cache) > 1 prefill_metadata = attn_metadata.prefill if prefill_metadata is None or prefill_metadata.chunked_context is None: return prefix_output, prefix_lse iters = len(prefill_metadata.chunked_context.seq_tot) current_seq_len = torch.tensor(prefill_metadata.query_lens, dtype=torch.int32) cache_kv_c = kv_c_and_k_pe_cache[0] cache_k_pe = kv_c_and_k_pe_cache[1] num_heads = cache_k_pe.size(2) latent_kv_dim = kv_c_and_k_pe_cache[0].size(-1) for i in range(iters): toks = prefill_metadata.chunked_context.seq_tot[i] # chunk_seq_lens will be padded when pcp&dcp context_seq_len = prefill_metadata.chunked_context.chunk_seq_lens[i] seq_len = torch.stack([current_seq_len, context_seq_len]) context_seq_len_npu = self.get_context_seq_len_npu(i, attn_metadata) kv_c_normed = torch.empty(toks, num_heads, latent_kv_dim, dtype=q_nope.dtype, device=q_nope.device) k_pe = torch.empty(toks, num_heads, rope_dim, dtype=q_nope.dtype, device=q_nope.device) torch_npu.atb.npu_paged_cache_load( cache_kv_c, cache_k_pe, prefill_metadata.block_table, context_seq_len_npu, seq_starts=prefill_metadata.chunked_context.starts[i], key=kv_c_normed, value=k_pe, ) kv_c_normed, k_pe = self._reorg_kvcache( kv_c_normed, k_pe, chunked_context=prefill_metadata.chunked_context, chunk_idx=i, toks=toks, ) kv_c_normed = kv_c_normed.squeeze() kv_nope = self.kv_b_proj(kv_c_normed)[0].view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k_pe = k_pe.expand((*k_nope.shape[:-1], -1)) mask = attn_metadata.attn_mask torch_npu.atb.npu_ring_mla( 
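                # Each chunk iteration folds this chunk's attention into the
                # running (prefix_output, prefix_lse) accumulator; the op uses
                # the log-sum-exp values so partial softmax results over
                # context chunks can be combined, in the same spirit as
                # flash-attention-style LSE merging.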
q_nope=q_nope, q_rope=q_pe, k_nope=k_nope, k_rope=k_pe, value=v, mask=mask, seqlen=seq_len, head_num=self.num_heads, kv_head_num=self.num_heads, pre_out=prefix_output, prev_lse=prefix_lse, qk_scale=self.scale, kernel_type="kernel_type_high_precision", mask_type="no_mask", input_layout="type_bsnd", calc_type="calc_type_default", output=prefix_output, softmax_lse=prefix_lse, ) return prefix_output, prefix_lse def _forward_prefill( self, q_nope: torch.Tensor, q_pe: torch.Tensor, k_nope: torch.Tensor, k_pe: torch.Tensor, value: torch.Tensor, kv_c_and_k_pe_cache: tuple[torch.Tensor], attn_metadata: AscendMLAMetadata, ) -> torch.Tensor: assert attn_metadata.prefill is not None assert len(kv_c_and_k_pe_cache) > 1 num_tokens = q_nope.size(0) attn_output = torch.empty(num_tokens, self.num_heads, self.v_head_dim, dtype=q_nope.dtype, device=q_nope.device) attn_lse = torch.empty(self.num_heads, num_tokens, dtype=torch.float32, device=q_nope.device) torch_npu.atb.npu_ring_mla( q_nope=q_nope, q_rope=q_pe, k_nope=k_nope, k_rope=k_pe, value=value, mask=attn_metadata.attn_mask, seqlen=attn_metadata.prefill.query_lens, head_num=self.num_heads, kv_head_num=self.num_heads, pre_out=None, prev_lse=None, qk_scale=self.scale, kernel_type="kernel_type_high_precision", mask_type="mask_type_triu", input_layout="type_bsnd", calc_type="calc_type_first_ring", output=attn_output, softmax_lse=attn_lse, ) attn_output, attn_lse = self._compute_prefill_context( q_nope, q_pe, kv_c_and_k_pe_cache, self.qk_rope_head_dim, attn_metadata, attn_output, attn_lse ) attn_output = attn_output.reshape([num_tokens, self.num_heads * self.v_head_dim]) return attn_output def exec_kv_decode( self, kv_no_split: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, kv_cache: tuple, slots: torch.Tensor, ): B = kv_no_split.shape[0] N = self.num_kv_heads S = 1 # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) cache_mode = "PA_NZ" if self.enable_kv_nz else "PA" k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, # type: ignore[union-attr] cos, sin, slots.to(torch.int64), kv_cache[1], kv_cache[0], epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] cache_mode=cache_mode, ) return k_pe, k_nope def exec_kv_prefill( self, kv_no_split: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, kv_cache: tuple, slots: torch.Tensor, ): B = kv_no_split.shape[0] N = self.num_kv_heads S = 1 # npu_kv_rmsnorm_rope_cache needs [B, N, S, D] kv_no_split = kv_no_split.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim) cache_mode = "PA" _, _, k_pe, k_nope = torch_npu.npu_kv_rmsnorm_rope_cache( kv_no_split, self.kv_a_layernorm.weight, # type: ignore[union-attr] cos, sin, slots.to(torch.int64), kv_cache[1], kv_cache[0], epsilon=self.kv_a_layernorm.variance_epsilon, # type: ignore[union-attr] cache_mode=cache_mode, is_output_kv=True, ) return k_pe, k_nope def rope_single( self, x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> torch.Tensor: B, N, D = x.shape S = 1 x = x.view(B, N, S, D) x = torch_npu.npu_interleave_rope(x, cos, sin) return x.view(B, N, D) def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, k_nope: torch.Tensor, k_pe: torch.Tensor, block_size: int, attn_metadata: AscendMLAMetadata, ) -> torch.Tensor: decode_meta = attn_metadata.decode assert decode_meta is not None num_tokens = q_nope.size(0) # shape of knope/k_pe for npu graph mode should be: # [num_blocks, num_kv_heads, 
block_size, self.kv_lora_rank/self.qk_rope_head_dim] actual_seq_lengths = None if self.enable_kv_nz: nz_fmt_last_dim = 16 k_nope = k_nope.view( -1, self.num_kv_heads, self.kv_lora_rank // nz_fmt_last_dim, block_size, nz_fmt_last_dim ) k_pe = k_pe.view( -1, self.num_kv_heads, self.qk_rope_head_dim // nz_fmt_last_dim, block_size, nz_fmt_last_dim ) else: k_nope = k_nope.view(-1, self.num_kv_heads, block_size, self.kv_lora_rank) k_pe = k_pe.view(-1, self.num_kv_heads, block_size, self.qk_rope_head_dim) attn_output_shape: tuple | None = None if ( attn_metadata.attn_state in [ AscendAttentionState.SpecDecoding, AscendAttentionState.ChunkedPrefill, AscendAttentionState.DecodeOnly, ] and self.speculative_config is not None ): # The right part layout indicates the layout of the attention # output. It is set to NTD to avoid the need for a transpose # operation after attention. input_layout = "TND_NTD" # TODO: If the driver is upgraded later, the contiguous function can be deleted. # Input shape: [num_tokens, num_heads, dim] q_nope = q_nope.view(num_tokens, self.num_heads, -1).contiguous() q_pe = q_pe.view(num_tokens, self.num_heads, -1) # Output shape: [num_heads, num_tokens, dim] attn_output_shape = (self.num_heads, num_tokens, self.kv_lora_rank) sparse_mode = 3 attn_mask = attn_metadata.decode.attn_mask # type:ignore actual_seq_lengths = decode_meta.actual_seq_lengths_q else: # The output layout is set to NBSD to eliminate the need for a # transpose operation after attention. if self.enable_kv_nz: # Input shape: [num_tokens, seq_len, num_heads, dim] input_layout = "BSND_NBSD" q_nope = q_nope.view(num_tokens, 1, self.num_heads, -1).contiguous() q_pe = q_pe.view(num_tokens, 1, self.num_heads, -1) else: # Input shape: [num_tokens, num_heads, seq_len, dim] input_layout = "BNSD_NBSD" q_nope = q_nope.view(num_tokens, self.num_heads, 1, -1).contiguous() q_pe = q_pe.view(num_tokens, self.num_heads, 1, -1) # Output shape: [num_heads, num_tokens, seq_len, dim] attn_output_shape = (self.num_heads, num_tokens, 1, self.kv_lora_rank) sparse_mode = 0 attn_mask = None common_kwargs = { "query_rope": q_pe, "key_rope": k_pe, "num_heads": self.num_heads, "num_key_value_heads": self.num_kv_heads, "input_layout": input_layout, "atten_mask": attn_mask, "sparse_mode": sparse_mode, "scale": self.scale, "antiquant_mode": 0, "antiquant_scale": None, "block_table": decode_meta.block_table, "block_size": block_size, "actual_seq_lengths": actual_seq_lengths, "actual_seq_lengths_kv": decode_meta.seq_lens_list, } forward_context: ForwardContext = get_forward_context() if forward_context.is_draft_model: graph_params = get_draft_graph_params() else: graph_params = get_graph_params() if forward_context.capturing: stream = torch_npu.npu.current_stream() event = torch.npu.ExternalEvent() event.wait(stream) event.reset(stream) graph_params.events[num_tokens].append(event) workspace = graph_params.workspaces.get(num_tokens) if workspace is None: workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace( q_nope, k_nope, k_nope, **common_kwargs ) if forward_context.is_draft_model: update_draft_graph_params_workspaces(num_tokens, workspace) else: update_graph_params_workspaces(num_tokens, workspace) attn_output = torch.empty(attn_output_shape, dtype=q_nope.dtype, device=q_nope.device) softmax_lse = torch.empty(num_tokens, dtype=q_nope.dtype, device=q_nope.device) graph_params.attn_params[num_tokens].append( ( weak_ref_tensors(q_nope), weak_ref_tensors(k_nope), weak_ref_tensors(q_pe), weak_ref_tensors(k_pe), self.num_heads, 
self.num_kv_heads, input_layout, weak_ref_tensors(attn_mask) if attn_mask is not None else None, sparse_mode, self.scale, decode_meta.block_table, block_size, decode_meta.seq_lens_list, actual_seq_lengths, weak_ref_tensors(attn_output), weak_ref_tensors(softmax_lse), ) ) torch.npu.graph_task_group_begin(stream) torch_npu.npu_fused_infer_attention_score.out( q_nope, k_nope, k_nope, **common_kwargs, workspace=workspace, out=[attn_output, softmax_lse] ) handle = torch.npu.graph_task_group_end(stream) graph_params.handles[num_tokens].append(handle) else: attn_output, _ = torch_npu.npu_fused_infer_attention_score(q_nope, k_nope, k_nope, **common_kwargs) return self._v_up_proj(attn_output) def reorg_decode_q(self, decode_q_nope, decode_q_pe): return decode_q_nope, decode_q_pe def _mla_preprocess_only_decode(self, hidden_states, kv_cache, attn_metadata): bsz = attn_metadata.num_decode_tokens hidden_states = hidden_states[:bsz] cos_shape = attn_metadata.decode.cos.shape cos = attn_metadata.decode.cos.view(cos_shape[0], cos_shape[-1]) sin = attn_metadata.decode.sin.view(cos_shape[0], cos_shape[-1]) decode_k_nope, decode_k_pe = kv_cache[0], kv_cache[1] decode_q_nope = torch.empty( (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_nope.shape[-1]), dtype=hidden_states.dtype, device=hidden_states.device, ) decode_q_pe = torch.empty( (hidden_states.shape[0], self.W_UK_T.shape[0], decode_k_pe.shape[-1]), dtype=hidden_states.dtype, device=hidden_states.device, ) torch.ops._C_ascend.mla_preprocess( hidden_states, self.wd_qkv, self.deq_scale_qkv, self.gamma1, self.beta1, self.wu_q, self.qb_deq_scl, self.gamma2, cos, sin, self.W_UK_T, decode_k_nope, decode_k_pe, attn_metadata.slot_mapping[:bsz], quant_scale0=self.quant_scale0, quant_offset0=self.quant_offset0, bias0=self.quant_bias_qkv, quant_scale1=self.quant_scale1, quant_offset1=self.quant_offset1, bias1=self.qb_qt_bias, ctkv_scale=self.ctkv_scale, q_nope_scale=self.q_nope_scale, cache_mode="nzcache" if self.enable_kv_nz else "krope_ctkv", quant_mode="per_tensor_quant_asymm", q_out0=decode_q_nope, kv_cache_out0=decode_k_nope, q_out1=decode_q_pe, kv_cache_out1=decode_k_pe, enable_inner_out=False, inner_out=torch.tensor([], device=hidden_states.device), ) decode_q_nope = decode_q_nope.view(bsz, self.num_heads, self.kv_lora_rank) decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) decode_q_nope, decode_q_pe = self.reorg_decode_q(decode_q_nope, decode_q_pe) decode_preprocess_res = DecodeMLAPreprocessResult(decode_q_nope, decode_q_pe, decode_k_nope, decode_k_pe) return decode_preprocess_res, None def mla_preprocess_prefill(self, q_c, kv_no_split, kv_cache, attn_metadata): num_decode_tokens = attn_metadata.num_decode_tokens num_actual_tokens = attn_metadata.num_actual_tokens prefill_kv_no_split = kv_no_split[num_decode_tokens:num_actual_tokens] prefill_q_c = q_c[num_decode_tokens:num_actual_tokens] prefill_q = self.q_proj(prefill_q_c)[0].view(-1, self.num_heads, self.qk_head_dim) prefill_q_pe = prefill_q[..., self.qk_nope_head_dim :] prefill_q_nope = prefill_q[..., : self.qk_nope_head_dim] cos = attn_metadata.prefill.cos sin = attn_metadata.prefill.sin prefill_slots = attn_metadata.slot_mapping[num_decode_tokens:num_actual_tokens] prefill_q_pe = self.rope_single(prefill_q_pe, cos, sin) if self.is_kv_producer: attn_metadata.reshape_cache_event = torch.npu.Event() prefill_k_pe, prefill_k_c_normed = self.exec_kv_prefill(prefill_kv_no_split, cos, sin, kv_cache, prefill_slots) if self.is_kv_producer: attn_metadata.reshape_cache_event.record() 
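
        # kv_b_proj up-projects the rank-compressed KV (kv_lora_rank) back to
        # per-head k_nope and value. Illustrative shapes (hypothetical sizes):
        # with kv_lora_rank=512, num_heads=16, qk_nope_head_dim=128 and
        # v_head_dim=128, a [T, 512] input becomes [T, 16, 256], split into
        # k_nope [T, 16, 128] and value [T, 16, 128].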
prefill_k_nope, prefill_value = ( self.kv_b_proj(prefill_k_c_normed)[0] .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) ) prefill_k_pe = prefill_k_pe.view(prefill_q_c.shape[0], self.num_kv_heads, -1) prefill_k_pe = prefill_k_pe.expand((*prefill_k_nope.shape[:-1], -1)) return PrefillMLAPreprocessResult(prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value) def mla_preprocess_decode(self, q_c, kv_no_split, kv_cache, attn_metadata): num_decode_tokens = attn_metadata.num_decode_tokens decode_q_c = q_c[:num_decode_tokens] cos = attn_metadata.decode.cos sin = attn_metadata.decode.sin decode_ql_nope, decode_q_pe = self._q_proj_and_k_up_proj(decode_q_c) decode_q_pe = self.rope_single(decode_q_pe, cos, sin) decode_slots = attn_metadata.slot_mapping[:num_decode_tokens:1] decode_kv_no_split = kv_no_split[:num_decode_tokens] decode_k_pe, decode_k_nope = self.exec_kv_decode(decode_kv_no_split, cos, sin, kv_cache, decode_slots) return DecodeMLAPreprocessResult(decode_ql_nope, decode_q_pe, decode_k_nope, decode_k_pe) def _mla_preprocess(self, layer_name, hidden_states, kv_cache, attn_metadata, need_gather_q_kv): # MLA Preprocess: # 1. Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split # or # Perform kv_a_proj_with_mqa to obtain kv_no_split # 2. If need_gather_q_kv, perform all_gather. # 3. Preprocess decode tokens, write kv cache and get: # decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope # 4. Preprocess prefill tokens, write kv cache and get: # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 if self.fused_qkv_a_proj is not None: weight_prefetch_method = get_weight_prefetch_method() weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states ) qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] q_c, kv_no_split = qkv_lora.split( [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1, ) q_c = self.q_a_layernorm(q_c) # type: ignore[misc] # allgather need contiguous data kv_no_split = kv_no_split.contiguous() else: q_c = hidden_states kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] # type: ignore[misc] # Process for Flash Comm V1 q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(q_c.contiguous(), need_gather_q_kv) kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(kv_no_split.contiguous(), need_gather_q_kv) for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): reach_layer_for_shard_weight_series(layer) decode_preprocess_res = None prefill_preprocess_res = None if has_prefill: wait_for_kv_layer_from_connector(layer_name) # Preprocess for decode tokens if has_decode: decode_preprocess_res = self.mla_preprocess_decode(q_c, kv_no_split, kv_cache, attn_metadata) # Preprocess for prefill tokens if has_prefill: prefill_preprocess_res = self.mla_preprocess_prefill(q_c, kv_no_split, kv_cache, attn_metadata) return decode_preprocess_res, prefill_preprocess_res def get_num_actual_tokens(self, attn_metadata: M): return attn_metadata.num_actual_tokens def forward_mha( self, layer_name: str, hidden_states: torch.Tensor, kv_cache: tuple[torch.Tensor], attn_metadata: M, need_gather_q_kv: bool = False, output: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError("forward_mha is not supported for MLA attention. 
Use forward() instead.") def forward_mqa( self, layer_name: str, hidden_states: torch.Tensor, kv_cache: tuple[torch.Tensor], attn_metadata: M, need_gather_q_kv: bool = False, output: torch.Tensor | None = None, ) -> torch.Tensor: raise NotImplementedError("forward_mqa is not supported for MLA attention. Use forward() instead.") def forward( self, layer_name, hidden_states: torch.Tensor, # query in unified attn kv_cache: tuple[torch.Tensor], attn_metadata: M, need_gather_q_kv: bool = False, output: torch.Tensor | None = None, ) -> torch.Tensor: assert output is not None, "Output tensor must be provided." if attn_metadata is None: # Profiling run. for layer in self.layer_sharding_kwargs or []: if is_hidden_layer(layer): reach_layer_for_shard_weight_series(layer) return output.fill_(0) forward_context = get_forward_context() num_actual_tokens = self.get_num_actual_tokens(attn_metadata) assert ( attn_metadata.num_decodes is not None and attn_metadata.num_prefills is not None and attn_metadata.num_decode_tokens is not None ) has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens # Inputs and outputs may be padded for CUDA graphs output_padded = output o_proj_input_shape = (forward_context.num_tokens, self.num_heads * self.v_head_dim) o_proj_input = torch.empty(o_proj_input_shape, dtype=hidden_states.dtype, device=hidden_states.device) # MLA Preprocess if self.enable_mlapo and attn_metadata.num_decode_tokens <= MLAPO_MAX_SUPPORTED_TOKENS: hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( hidden_states.contiguous(), need_gather_q_kv ) decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess_only_decode( hidden_states, kv_cache, attn_metadata ) else: decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess( layer_name, hidden_states, kv_cache, attn_metadata, need_gather_q_kv ) if decode_preprocess_res is not None: # MLA Preprocess for decoding output_decode = self._forward_decode( decode_preprocess_res.ql_nope, decode_preprocess_res.q_pe, decode_preprocess_res.k_nope, decode_preprocess_res.k_pe, kv_cache[0].shape[1], attn_metadata, ) o_proj_input[:num_decode_tokens] = output_decode if prefill_preprocess_res is not None: # FIX: aicore move should be also placed on the comm stream in dbo, # otherwise it may affect the accuracy # TODO: use an elegant way to overlap output_prefill = self._forward_prefill( prefill_preprocess_res.q_nope, prefill_preprocess_res.q_pe, prefill_preprocess_res.k_nope, prefill_preprocess_res.k_pe, prefill_preprocess_res.value, kv_cache, attn_metadata, ) o_proj_input[num_decode_tokens:num_actual_tokens] = output_prefill # O proj weight_prefetch_method = get_weight_prefetch_method() weight_prefetch_method.maybe_prefetch_mla_or_sla_weight_in_current_stream( inputs=self.o_proj.weight, dependency=o_proj_input, max_size=MAX_O_PROJ_PREFETCH_SIZE, linear_layer=self.o_proj, ) output[...] = self.o_proj(o_proj_input, is_prefill=prefill_preprocess_res is not None)[0] del o_proj_input if has_prefill: maybe_save_kv_layer_to_connector(layer_name, list(kv_cache)) return output_padded
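
# Rough call-order sketch of AscendMLAImpl.forward() (commentary only):
#   1. _mla_preprocess (or _mla_preprocess_only_decode when mlapo is enabled)
#      projects hidden_states into q/kv latents and writes the KV cache;
#   2. _forward_decode runs npu_fused_infer_attention_score on the decode
#      tokens, while _forward_prefill runs npu_ring_mla on the prefill tokens;
#   3. both results land in o_proj_input, which o_proj maps back to the hidden
#      size before the (possibly padded) output buffer is returned.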