diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index c5c29cbc..faf03253 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -52,6 +52,8 @@ if prefill_context_parallel_enable():
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
+MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
+
 
 class AscendMLABackend(AttentionBackend):
 
@@ -808,16 +810,17 @@ class AscendMLAImpl(MLAAttentionImpl):
         # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
         # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
 
-        # Currently mlapo only supports W8A8 quantization in MLA scenario
-        # TODO(whx): modify this limitation when mlapo supports floating point
-        if self.fused_qkv_a_proj is None or not isinstance(
-                getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
-                        None), AscendW8A8LinearMethod):
-            self.enable_mlapo = False
-            logger.warning_once(
-                "Currently mlapo only supports W8A8 quantization in MLA scenario."
-                "Some layers in your model are not quantized with W8A8,"
-                "thus mlapo is disabled for these layers.")
+        if self.enable_mlapo:
+            # Currently mlapo only supports W8A8 quantization in MLA scenario
+            # TODO(whx): modify this limitation when mlapo supports floating point
+            if self.fused_qkv_a_proj is None or not isinstance(
+                    getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
+                            None), AscendW8A8LinearMethod):
+                self.enable_mlapo = False
+                logger.warning_once(
+                    "Currently mlapo only supports W8A8 quantization in MLA scenario. "
+                    "Some layers in your model are not quantized with W8A8, "
+                    "thus mlapo is disabled for these layers.")
 
         if self.enable_mlapo:
             self._process_weights_for_fused_mlapo(act_dtype)
@@ -1282,12 +1285,13 @@ class AscendMLAImpl(MLAAttentionImpl):
     def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
                         attn_metadata, need_gather_q_kv):
         # MLA Preprocess:
-        # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
-        # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
-        # 3. If need_gather_q_kv, perform all_gather.
-        # 4. Preprocess decode tokens, write kv cache and get:
+        # 1. Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split,
+        #    or
+        #    perform kv_a_proj_with_mqa to obtain kv_no_split
+        # 2. If need_gather_q_kv, perform all_gather.
+        # 3. Preprocess decode tokens, write kv cache and get:
         #    decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
-        # 5. Preprocess prefill tokens, write kv cache and get:
+        # 4. 
Preprocess prefill tokens, write kv cache and get: # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value has_decode = attn_metadata.num_decodes > 0 has_prefill = attn_metadata.num_prefills > 0 diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index b6514028..0558b384 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -1,976 +1,535 @@ -from dataclasses import dataclass -from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type, - TypeVar) - -import torch -import torch_npu -from torch import nn -from vllm.attention.backends.abstract import (AttentionBackend, - AttentionMetadata, - MLAAttentionImpl) -from vllm.config import VllmConfig, get_current_vllm_config -from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group -from vllm.model_executor.layers.linear import (LinearBase, - UnquantizedLinearMethod) -from vllm.utils import cdiv, round_down -from vllm.v1.attention.backends.utils import AttentionCGSupport - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.attention.attention_v1 import AscendAttentionState -from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, - split_decodes_and_prefills) -from vllm_ascend.worker.npu_input_batch import InputBatch - -if TYPE_CHECKING: - from vllm.v1.core.sched.output import SchedulerOutput - - -class AscendSFABackend(AttentionBackend): - - accept_output_buffer: bool = True - - @staticmethod - def get_name() -> str: - return "ASCEND_SFA" - - @staticmethod - def get_metadata_cls() -> type["AttentionMetadata"]: - return AscendSFAMetadata - - @staticmethod - def get_builder_cls(): - return AscendSFAMetadataBuilder - - @staticmethod - def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int, - head_size: int) -> tuple[int, ...]: - return (num_blocks, block_size, num_kv_heads, head_size) - - @staticmethod - def get_impl_cls() -> Type["AscendSFAImpl"]: - return AscendSFAImpl - - -@dataclass -class AscendSFAPrefillMetadata: - """ Prefill Specific Metadata for Ascend""" - - @dataclass - class ChunkedContextMetadata: - # New for MLA (compared to FlashAttention) - # For handling chunked prefill - cu_seq_lens: torch.Tensor - starts: torch.Tensor - seq_tot: list[int] - max_seq_lens: list[int] - workspace: torch.Tensor - chunk_seq_lens: torch.Tensor - - attn_mask: torch.Tensor - query_lens: list[int] - seq_lens: list[int] - - context_lens: torch.Tensor - input_positions: torch.Tensor - query_start_loc: torch.Tensor - block_table: torch.Tensor - max_query_len: int - max_seq_lens: int - sin: torch.Tensor - cos: torch.Tensor - chunked_context: Optional[ChunkedContextMetadata] = None - - -@dataclass -class AscendSFADecodeMetadata: - # Input positions for rotrary embeddings since for MLA the rotary - # position embeddings are applied inside the attention backend - input_positions: torch.Tensor - block_table: torch.Tensor - seq_lens: torch.Tensor - max_seq_lens: int - seq_lens_list: list[int] - actual_seq_lengths_q: torch.Tensor - sin: torch.Tensor - cos: torch.Tensor - attn_mask: Optional[torch.Tensor] = None - - -@dataclass -class AscendSFAMetadata: - """Metadata for MLACommon. - - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - # NOTE(sang): Definition of context_len, query_len, and seq_len. 
- # |---------- N-1 iteration --------| - # |---------------- N iteration ---------------------| - # |- tokenA -|......................|-- newTokens ---| - # |---------- context_len ----------| - # |-------------------- seq_len ---------------------| - # |-- query_len ---| - - num_actual_tokens: int # Number of tokens excluding padding. - slot_mapping: torch.Tensor - query_start_loc: torch.Tensor - seq_lens: torch.Tensor - block_tables: torch.Tensor - - # New for MLA (compared to FlashAttention) - # For handling prefill decode split - num_decodes: int - num_decode_tokens: int - num_prefills: int - - # For logging. - num_input_tokens: int = 0 # Number of tokens including padding. - - query_lens: Optional[list[int]] = None - # The dimension of the attention heads - head_dim: Optional[int] = None - attn_mask: torch.Tensor = None - # chunked prefill by default if no attn_states passed - attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill - - decode: Optional[AscendSFADecodeMetadata] = None - prefill: Optional[AscendSFAPrefillMetadata] = None - - def __post_init__(self): - pass - # supported_head_sizes = AscendMLABackend.get_supported_head_sizes() - # if self.head_dim is not None and self.head_dim \ - # not in supported_head_sizes: - # raise ValueError( - # f"Only {supported_head_sizes} are supported for head_dim,", - # f"received {self.head_dim}.") - - -M = TypeVar("M", bound=AscendSFAMetadata) - - -class AscendSFAMetadataBuilder: - # Does this backend/builder support ACL Graphs for attention (default: no). - aclgraph_support: ClassVar[AttentionCGSupport] = \ - AttentionCGSupport.NEVER - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - - # _attn_mask_builder = None - def __init__(self, - kv_cache_spec, - layer_names, - vllm_config: VllmConfig, - device: torch.device, - metadata_cls: Optional[AscendSFAMetadata] = None): - self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \ - if metadata_cls is not None else AscendSFAMetadata # type: ignore - self.vllm_config = vllm_config - self.model_config = vllm_config.model_config - self.device = device - scheduler_config = vllm_config.scheduler_config - self.block_size = vllm_config.cache_config.block_size - self.max_blocks = (vllm_config.model_config.max_model_len + - self.block_size - 1) // self.block_size - self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled - - self.speculative_config = vllm_config.speculative_config - self.decode_threshold = 1 - if self.speculative_config: - spec_token_num = self.speculative_config.num_speculative_tokens - self.decode_threshold += spec_token_num - assert self.decode_threshold <= 16, f"decode_threshold exceeded \ - npu_fused_infer_attention_score TND layout's limit of 16, \ - got {self.decode_threshold}" - - if self.chunked_prefill_enabled: - self.chunked_prefill_workspace_size = min( - # Max sure there is enough for 8 full length request or at least - # 4 pages of cache per request - max(8 * self.model_config.max_model_len, - 4 * scheduler_config.max_num_seqs * self.block_size), - # For long-context models try not to over-allocate limiting - # kv-cache space, limiting it to 64k tokens, - # which would result in the workspace being: - # 2*(576)*(64*1024) = 144mb - # (assuming 576 MLA head dim, and fp16) - # which would result in up-projected context being - # 2*(192*128)*(64*1024) = 3gb - # (assuming 192 QK head dim, 128 heads, and fp16) - 128 * 1024) - assert self.chunked_prefill_workspace_size >= \ - 
scheduler_config.max_num_seqs * self.block_size - self.chunked_prefill_workspace = torch.empty( - (self.chunked_prefill_workspace_size, - self.model_config.get_head_size()), - dtype=self.model_config.dtype, - device=device, - ) - self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim - self.cos_cache = None - self.sin_cache = None - - def reorder_batch(self, input_batch: "InputBatch", - scheduler_output: "SchedulerOutput") -> bool: - # We now want to reorder the batch so that the "decode" requests are at - # the front and the "prefill" requests are at the using the least amount - # swaps possible. (NOTE for now we loosely use "decode" to mean requests - # where attention is likely memory-bound and "prefill" to mean requests - # where attention is likely compute-bound, TODO(lucas): figure out a - # better naming here) - decodes = [] - prefills = [] - - for i, req_id in enumerate(input_batch.req_ids): - num_tokens = scheduler_output.num_scheduled_tokens[req_id] - if num_tokens <= self.decode_threshold: - decodes.append(i) - else: - prefills.append(i) - - # We hope that this is fairly minimal since decodes - # should be around for a number of iterations so hopefully they are - # relatively stationary (and new request are generally appended to the - # persistent batch so already should be at the back) - # To achieve this we loop over the decodes in descending order and - # the prefills in ascending order. We swap decodes from the "back" - # i.e. past where the last decode should be in the reodorered with - # prefills from the front of the batch. - # `decodes` and `prefills` are already in ascending order just based on - # the above loop - num_decodes = len(decodes) - num_prefills = len(prefills) - first_prefill = 0 - modified_batch = False - - for i in range(1, min(num_decodes, num_prefills) + 1): - # If the decode is at the "back" of the batch, i, we can swap it - # with the prefill closest to the front of the batch - if decodes[num_decodes - i] >= num_decodes: - input_batch.swap_states(prefills[first_prefill], - decodes[num_decodes - i]) - first_prefill += 1 - modified_batch = True - else: - break - - # Save for next `build` call - # TODO(lucas): this is a bit of a hack, we should probably have a - # better way of doing this - return modified_batch - - def build( - self, - common_prefix_len: int, - common_attn_metadata: AscendCommonAttentionMetadata, - model: nn.Module, - ) -> AscendSFAMetadata: - num_reqs = common_attn_metadata.num_reqs - num_actual_tokens = common_attn_metadata.num_actual_tokens - query_start_loc = common_attn_metadata.query_start_loc - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold) - assert num_decodes + num_prefills == num_reqs - assert num_decode_tokens + num_prefill_tokens == num_actual_tokens - - # Note(simon): be careful about the CPU <> GPU memory movement in this - # function. We should avoid GPU -> CPU sync as much as possible because - # it blocks on all previous kernels. 
- device = self.device - - block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) - slot_mapping = common_attn_metadata.slot_mapping[: - num_actual_tokens].to( - device, - non_blocking=True) - input_positions = common_attn_metadata.positions[: - num_actual_tokens].long( - ) - - if self.cos_cache is None: - self.cos_cache = model.model.layers[ - model.model.start_layer].self_attn.rotary_emb.cos_cached - self.sin_cache = model.model.layers[ - model.model.start_layer].self_attn.rotary_emb.sin_cached - if self.cos_cache.dtype != self.model_config.dtype: # type: ignore - self.cos_cache = self.cos_cache.to( # type: ignore - self.model_config.dtype) # type: ignore - self.sin_cache = self.sin_cache.to( # type: ignore - self.model_config.dtype) # type: ignore - - query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] - query_lens = query_seq_lens_cpu[:num_reqs] - seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] - num_computed_tokens_cpu = (seq_lens - query_lens) - - prefill_metadata = None - chunked_context_metadata = None - if num_prefills > 0: - reqs_start = num_decodes # prefill_start - tokens_start = num_decode_tokens - max_query_len = query_lens[reqs_start:].max().item() - max_seq_lens = seq_lens[reqs_start:].max().item() - prefill_query_start_loc = query_start_loc[ - reqs_start:] - query_start_loc[reqs_start] - - context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs] - max_context_len_cpu = context_lens_cpu.max().item() - num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item() - if self.chunked_prefill_enabled and max_context_len_cpu > 0: - max_context_chunk = (self.chunked_prefill_workspace_size // - num_prefills_with_context_cpu) - max_context_chunk = round_down(max_context_chunk, - self.block_size) - - assert max_context_chunk > 0 - num_chunks = cdiv(max_context_len_cpu, max_context_chunk) - chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \ - .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk - chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), - chunk_starts + max_context_chunk) - chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0) - cu_seq_lens_cpu = torch.zeros(num_chunks, - num_prefills + 1, - dtype=torch.int32, - pin_memory=True) - torch.cumsum(chunk_seq_lens, - dim=1, - out=cu_seq_lens_cpu[:, 1:], - dtype=torch.int32) - chunked_context_metadata = \ - AscendSFAPrefillMetadata.ChunkedContextMetadata( - cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True), - starts=chunk_starts.to(device, non_blocking=True), - seq_tot=chunk_seq_lens.sum(dim=1).tolist(), - max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(), - chunk_seq_lens=chunk_seq_lens, - workspace=self.chunked_prefill_workspace, - ) - prefill_input_positions = input_positions[tokens_start:] - cos = self.cos_cache[ - prefill_input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[ - prefill_input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - actual_query_lens = torch.tensor(query_lens[reqs_start:], - dtype=torch.int32).npu() - query_lens_prefill_sfa = torch.cumsum(actual_query_lens, - dim=0).to(torch.int32) - seq_lens_prefill_sfa = seq_lens[reqs_start:].to(torch.int32).npu() - prefill_metadata = AscendSFAPrefillMetadata( - attn_mask=common_attn_metadata.attn_mask, - query_lens=query_lens_prefill_sfa, - seq_lens=seq_lens_prefill_sfa, - context_lens=seq_lens[reqs_start:], - input_positions=prefill_input_positions, - block_table=block_table[reqs_start:, ...], - max_query_len=max_query_len, - 
max_seq_lens=max_seq_lens, - query_start_loc=prefill_query_start_loc, - chunked_context=chunked_context_metadata, - sin=sin, - cos=cos, - ) - - decode_metadata = None - if num_decodes > 0: - # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario - actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to( - torch.int32).npu() - max_seq_lens = seq_lens[:num_decodes].max().item() - seq_lens = seq_lens[:num_decodes].to(torch.int32).npu() - input_positions = input_positions[:num_decode_tokens] - block_table = block_table[:num_decodes, ...] - seq_lens_list = seq_lens.tolist() - - cos = self.cos_cache[input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - sin = self.sin_cache[input_positions].unsqueeze( # type: ignore - 1).unsqueeze(2) - - decode_metadata = AscendSFADecodeMetadata( - input_positions=input_positions, - block_table=block_table, - seq_lens=seq_lens, - seq_lens_list=seq_lens_list, - max_seq_lens=max_seq_lens, - attn_mask=common_attn_metadata.spec_attn_mask, - actual_seq_lengths_q=actual_seq_lengths_q, - sin=sin, - cos=cos) - - return self.metadata_cls( # type: ignore - num_input_tokens=common_attn_metadata.num_input_tokens, - num_actual_tokens=num_actual_tokens, - query_lens=query_lens.tolist(), - slot_mapping=slot_mapping, - head_dim=self.model_config.get_head_size(), - num_decodes=num_decodes, - num_decode_tokens=num_decode_tokens, - num_prefills=num_prefills, - attn_mask=common_attn_metadata.attn_mask, - attn_state=common_attn_metadata.attn_state, - prefill=prefill_metadata, - decode=decode_metadata, - query_start_loc=query_start_loc, - block_tables=block_table, - seq_lens=seq_lens, - ) - - -class PrefillSFAPreprocessResult(NamedTuple): - q_nope: Optional[torch.Tensor] = None - q_pe: Optional[torch.Tensor] = None - k_nope: Optional[torch.Tensor] = None - k_pe: Optional[torch.Tensor] = None - topk_indices: Optional[torch.Tensor] = None - query_states: Optional[torch.Tensor] = None - key_states: Optional[torch.Tensor] = None - - -class DecodeSFAPreprocessResult(NamedTuple): - q_nope: Optional[torch.Tensor] = None - q_pe: Optional[torch.Tensor] = None - # nope_cache: Optional[torch.Tensor] = None - # rope_cache: Optional[torch.Tensor] = None - topk_indices: Optional[torch.Tensor] = None - query_states: Optional[torch.Tensor] = None - key_states: Optional[torch.Tensor] = None - bsz: Optional[int] = None - - -class AscendSFAImpl(MLAAttentionImpl): - """ - NOTE: Please read the comment at the top of the file before trying to - understand this class - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: int, - alibi_slopes: Optional[list[float]], - sliding_window: Optional[int], - kv_cache_dtype: str, - logits_soft_cap: Optional[float], - attn_type: str, - kv_sharing_target_layer_name: Optional[str], - **kwargs, - ) -> None: - self.num_heads = num_heads - self.head_size = head_size - self.scale = float(scale) - self.num_kv_heads = num_kv_heads - self.kv_cache_dtype = kv_cache_dtype - - # MLA Args - self.q_lora_rank = kwargs['q_lora_rank'] - self.kv_lora_rank = kwargs['kv_lora_rank'] - self.qk_nope_head_dim = kwargs['qk_nope_head_dim'] - self.qk_rope_head_dim = kwargs['qk_rope_head_dim'] - self.qk_head_dim = kwargs['qk_head_dim'] - self.v_head_dim = kwargs['v_head_dim'] - self.rotary_emb = kwargs['rotary_emb'] - self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[ - 'q_b_proj'] - self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None) - self.kv_b_proj = kwargs['kv_b_proj'] - self.o_proj = 
kwargs['o_proj'] - self.indexer = kwargs['indexer'] - self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) - self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) - self.q_a_layernorm = kwargs.get('q_a_layernorm', None) - self.num_queries_per_kv = self.num_heads // self.num_kv_heads - self.tp_size = get_tensor_model_parallel_world_size() - self.num_heads_per_rank = self.num_heads // self.tp_size - self.q_b_proj = kwargs['q_b_proj'] - - ascend_config = get_ascend_config() - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz - - vllm_config = get_current_vllm_config() - self.ring_mla_mask_size = 512 - self.prefill_mask = None - - # indexer param - self.dim = self.indexer.dim - self.n_heads: int = self.indexer.n_heads # 64 - self.head_dim: int = self.indexer.head_dim # 128 - self.index_topk: int = self.indexer.index_topk # 2048 - self.wq_b = self.indexer.wq_b - self.wk = self.indexer.wk - self.weights_proj = self.indexer.weights_proj - self.k_norm = self.indexer.k_norm - self.softmax_scale = self.indexer.softmax_scale - - # Adapt torch air graph mode with spec decoding. - speculative_config = vllm_config.speculative_config - if speculative_config is not None: - self.spec_token_num = speculative_config.num_speculative_tokens - assert self.spec_token_num > 0 - - self.cp_size = 1 - - def process_weights_after_loading(self, act_dtype: torch.dtype): - - def get_layer_weight(layer): - WEIGHT_NAMES = ("weight", "qweight", "weight_packed") - for attr in WEIGHT_NAMES: - if hasattr(layer, attr): - return getattr(layer, attr) - raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") - - def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) - del eye - # standardize to (output, input) - return dequant_weights.T - return layer.weight - - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - # Convert from (L, N, V) to (N, L, V) - self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous() - # Convert from (L, N, P) to (N, P, L) - self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous() - - # Waiting for BMM NZ support - # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29) - # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29) - - def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata, - need_gather_q_kv): - # SFA Preprocess: - # 1. 
Perform q_a_proj and q_a_layernorm to obtain q_c - # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split - # 3. If need_gather_q_kv, perform all_gather. - # 4. Preprocess decode tokens, write kv cache and get: - # decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope - # 5. Preprocess prefill tokens, write kv cache and get: - # prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value - has_decode = attn_metadata.num_decodes > 0 - has_prefill = attn_metadata.num_prefills > 0 - - num_decode_tokens = attn_metadata.num_decode_tokens - num_actual_tokens = attn_metadata.num_actual_tokens - if need_gather_q_kv: - # q_c = get_tp_group().all_gather(q_c, 0) - # kv_no_split = get_tp_group().all_gather(kv_no_split, 0) - hidden_states = get_tp_group().all_gather(hidden_states, 0) - # hidden_states_decode = hidden_states[:num_decode_tokens] - # if self.q_a_proj is not None: - # npu_prefetch(self.q_a_proj.weight, - # hidden_states, - # enabled=self.enable_prefetch) - # ckq = self.q_a_proj(hidden_states) # q down - # q_c = self.q_a_layernorm(ckq) # q down layernorm - # else: - # q_c = hidden_states - - # kv_no_split = self.kv_a_proj_with_mqa(hidden_states) # c_kv - # Process for shared_expert_dp - - decode_preprocess_res = None - prefill_preprocess_res = None - # Preprocess for decode tokens - if has_decode: - q_len = 1 - hidden_states_decode = hidden_states[:num_decode_tokens] - decode_qkv_lora = self.fused_qkv_a_proj(hidden_states_decode)[0] - decode_q_c, decode_kv_no_split = decode_qkv_lora.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - decode_q_c = self.q_a_layernorm(decode_q_c) # q down layernorm - decode_kv_no_split = decode_kv_no_split.contiguous() - - # decode_q_c = q_c[:num_decode_tokens] - decode_slot_mapping = attn_metadata.slot_mapping[: - num_decode_tokens] - # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] - - decode_q = self.q_b_proj(decode_q_c) - bsz, _ = decode_q.shape - decode_q = decode_q.view(bsz, self.num_heads, 1, self.qk_head_dim) - decode_q_nope, decode_q_pe = torch.split( - decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) - decode_q_nope = decode_q_nope.view( - -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) - decode_q_nope = (torch.matmul(decode_q_nope, - self.kv_b_proj_w_k).transpose( - 1, - 0).view(bsz, q_len, - self.num_heads, - self.kv_lora_rank)) - - # stream2 kv - key_cache = kv_cache[0] - value_cache = kv_cache[1] - cos = attn_metadata.decode.cos - sin = attn_metadata.decode.sin - cos_q, sin_q = cos, sin - cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) - sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) - - decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze(1) - decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( - decode_kv_no_split, - self.kv_a_layernorm.weight, - cos, - sin, - decode_slot_mapping.to(torch.int64), - value_cache, - key_cache, - c_kv_scale=None, - epsilon=self.kv_a_layernorm.variance_epsilon, - cache_mode='PA') # adapter NZ - # nz_block_size = 16 - # KVCACHE_NZ_DIM = 16 - # decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size) - # decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM) - - decode_q_pe = torch_npu.npu_interleave_rope(decode_q_pe, cos, - sin) # BNSD - - decode_q_nope = decode_q_nope.view(bsz, self.num_heads, - self.kv_lora_rank) - decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1) - - 
topk_indices = self.indexer_select(hidden_states_decode, - decode_q_c, - attn_metadata=attn_metadata, - cos=cos, - sin=sin, - kv_cache=kv_cache) - - query_states = (decode_q_nope, decode_q_pe) - key_states = (decode_k_nope, decode_k_rope) - decode_preprocess_res = DecodeSFAPreprocessResult( - q_nope=decode_q_nope, - q_pe=decode_q_pe, - # nope_cache = nope_cache, - # rope_cache = rope_cache, - topk_indices=topk_indices, - query_states=query_states, - key_states=key_states, - bsz=bsz, - ) - - # Preprocess for prefill tokens - if has_prefill: - bsz = 1 - - hidden_states_prefill = hidden_states[ - num_decode_tokens:num_actual_tokens] - prefill_qkv_lora = self.fused_qkv_a_proj(hidden_states_prefill)[0] - prefill_q_c, prefill_kv_no_split = prefill_qkv_lora.split( - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - dim=-1, - ) - prefill_q_c = self.q_a_layernorm(prefill_q_c) # q down layernorm - prefill_kv_no_split = prefill_kv_no_split.contiguous() - - # prefill_q_c = q_c[ - # num_decode_tokens:num_actual_tokens] - prefill_slot_mapping = attn_metadata.slot_mapping[ - num_decode_tokens:num_actual_tokens] - # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] - - prefill_slot_mapping = attn_metadata.slot_mapping[ - num_decode_tokens:num_actual_tokens] - # prefill_kv_no_split = kv_no_split[ - # num_decode_tokens:num_actual_tokens] - # prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens] - prefill_qr = prefill_q_c - prefill_q = self.q_b_proj(prefill_qr) - prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim) - prefill_q_nope, prefill_q_pe = torch.split( - prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], - dim=-1) - prefill_q_nope = prefill_q_nope.view( - -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1) - prefill_q_nope = (torch.matmul(prefill_q_nope, - self.kv_b_proj_w_k).transpose( - 1, - 0).view(-1, self.num_heads, - self.kv_lora_rank)) - prefill_q_pe = prefill_q_pe.unsqueeze(2) - - # stream2 kv - - nope_cache = kv_cache[0] - rope_cache = kv_cache[1] - cos = attn_metadata.prefill.cos - sin = attn_metadata.prefill.sin - cos_q, sin_q = cos, sin - - # cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) - # sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) - - prefill_q_pe = torch_npu.npu_interleave_rope( - prefill_q_pe, cos_q, sin_q) # BNSD - prefill_q_pe = prefill_q_pe.squeeze(2) #BSH - # q[..., self.qk_nope_head_dim:] = prefill_q_pe # TODO:???? 
- - prefill_latent_cache = prefill_kv_no_split # (B,S,N,D) - prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache( - prefill_latent_cache.view( - -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim), - self.kv_a_layernorm.weight, - cos.view(-1, 1, 1, self.qk_rope_head_dim), - sin.view(-1, 1, 1, self.qk_rope_head_dim), - prefill_slot_mapping.to(torch.int64), - rope_cache, - nope_cache, - k_rope_scale=None, - c_kv_scale=None, - k_rope_offset=None, - c_kv_offset=None, - epsilon=self.kv_a_layernorm.variance_epsilon, - cache_mode="PA") - - topk_indices = self.indexer_select(x=hidden_states_prefill, - qr=prefill_qr, - kv_cache=kv_cache, - cos=cos, - sin=sin, - attn_metadata=attn_metadata) - query_states = (prefill_q_nope, prefill_q_pe) - key_states = (prefill_k_nope, prefill_k_pe) - prefill_preprocess_res = PrefillSFAPreprocessResult( - q_nope=prefill_q_nope, - q_pe=prefill_q_pe, - topk_indices=topk_indices, - k_nope=prefill_k_nope, - k_pe=prefill_k_pe, - query_states=query_states, - key_states=key_states, - ) - - return decode_preprocess_res, prefill_preprocess_res - - def forward( - self, - hidden_states: torch.Tensor, # query in unified attn - kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - attn_metadata: M, - need_gather_q_kv: bool = False, - output: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - assert output is not None, "Output tensor must be provided." - if attn_metadata is None: - # Profiling run. - return output.fill_(0) - num_actual_tokens = attn_metadata.num_actual_tokens - assert attn_metadata.num_decodes is not None and \ - attn_metadata.num_prefills is not None and \ - attn_metadata.num_decode_tokens is not None - num_decode_tokens = attn_metadata.num_decode_tokens - # Inputs and outputs may be padded for CUDA graphs - output = output[:num_actual_tokens, ...] - o_proj_input_shape = (num_actual_tokens, - self.num_heads * self.v_head_dim) - o_proj_input = torch.empty(o_proj_input_shape, - dtype=hidden_states.dtype, - device=hidden_states.device) - - # SFA Preprocess - decode_preprocess_res, prefill_preprocess_res = self._sfa_preprocess( - hidden_states, kv_cache, attn_metadata, need_gather_q_kv) - - if decode_preprocess_res is not None: - # bsz, q_len, _, _ = query_states[0].shape - decode_attn_output = self.apply_attention_fusion( - query_states=decode_preprocess_res.query_states, - key_states=decode_preprocess_res.key_states, - attn_metadata=attn_metadata, - topk_indices=decode_preprocess_res.topk_indices) - o_proj_input[:num_decode_tokens] = decode_attn_output - - if prefill_preprocess_res is not None: - prefill_attn_output = self.apply_attention_fusion( - query_states=prefill_preprocess_res.query_states, - key_states=prefill_preprocess_res.key_states, - attn_metadata=attn_metadata, - topk_indices=prefill_preprocess_res.topk_indices) - o_proj_input[num_decode_tokens:] = prefill_attn_output - - output[...] 
= self.mla_epilog(o_proj_input, absorb=True) - return output - - def apply_attention_fusion(self, query_states, key_states, topk_indices, - attn_metadata: M): - # repeat k/v heads if n_kv_heads < n_heads - q_nope, q_pe = query_states - k_nope, k_rope = key_states - - if attn_metadata.prefill is not None: - - prefill_metadata = attn_metadata.prefill - - slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( - query=q_nope, - key=k_nope, - value=k_nope, - sparse_indices=topk_indices, - scale_value=self.scale, - sparse_block_size=1, - block_table=prefill_metadata.block_table, - actual_seq_lengths_query=prefill_metadata.query_lens, - actual_seq_lengths_kv=prefill_metadata.seq_lens, - query_rope=q_pe, - key_rope=k_rope, - layout_query="TND", - layout_kv="PA_BSND", - sparse_mode=3, - ) - - elif attn_metadata.decode is not None: - decode_metadata = attn_metadata.decode - - slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( - query=q_nope, - key=k_nope, - value=k_nope, - sparse_indices=topk_indices, - scale_value=self.scale, - sparse_block_size=1, - block_table=attn_metadata.decode.block_table, - actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q, - actual_seq_lengths_kv=decode_metadata.seq_lens, - query_rope=q_pe, - key_rope=k_rope, - layout_query="TND", - layout_kv="PA_BSND", - sparse_mode=3, - ) - slc_fa_fusion = slc_fa_fusion.squeeze(1) - - slc_fa_fusion = slc_fa_fusion.transpose(0, 1) - - # input shape [N//attn_tp_size, T(bs*q_len), D] - # output shape [T(bs*q_len), N//attn_tp_size, D] - attn_output = torch.matmul(slc_fa_fusion, - self.kv_b_proj_w_v).transpose(1, 0).reshape( - -1, self.num_heads * self.v_head_dim) - # Note: Considering the fusion rules of TBMM, attn_output shape requires a 3-dim shape, and - # with appropriate tensor stride for the later 'view' operation if oproj_tp_size > 1. 
- # after reshape: [T(bs*q_len), 1, N//attn_tp_size*D] - # attn_output = attn_output.reshape(-1, self.num_heads * self.v_head_dim) - - return attn_output - - def mla_epilog(self, - attn_output: torch.Tensor = None, - absorb: bool = False): - # TODO: need to check - attn_output = self.o_proj(attn_output.reshape(attn_output.shape[0], - -1), - is_prefill=True, - is_force_scatter=False) - - return attn_output - - def indexer_select( - self, - x: torch.Tensor, - qr: torch.Tensor, - kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], - cos, - sin, - attn_metadata: M, - ): - if attn_metadata.prefill is not None: - actual_seq_lengths_query = attn_metadata.prefill.query_lens - actual_seq_lengths_key = attn_metadata.prefill.seq_lens - block_table = attn_metadata.prefill.block_table - elif attn_metadata.decode is not None: - actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q - actual_seq_lengths_key = attn_metadata.decode.seq_lens - block_table = attn_metadata.decode.block_table - - cos_q, sin_q = cos, sin - cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) - sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) - - # q process in new stream - q = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] - q = q.view(-1, self.n_heads, self.head_dim) # [b,s,64,128] - q_pe, q_nope = torch.split( - q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], - dim=-1) # [b,s,64,64+64] - - q_pe = q_pe.unsqueeze(2) - q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q) - q_pe = q_pe.squeeze(2) - q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] - - k_proj = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] - k = self.k_norm(k_proj).unsqueeze(1) - k_pe, k_nope = torch.split( - k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], - dim=-1) # [b,s,64+64] - - k_pe = k_pe.unsqueeze(2) - k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) - k_pe = k_pe.squeeze(2) - - k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] - - if kv_cache is not None: - torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), - attn_metadata.slot_mapping.view( - -1, 1), - k.view(-1, - k.shape[-1])) # b, s, n, d - - weights = self.weights_proj(x) - - topk_indices = torch.ops.custom.npu_lightning_indexer( - query=q, - key=kv_cache[2], - weights=weights, - actual_seq_lengths_query=actual_seq_lengths_query, - actual_seq_lengths_key=actual_seq_lengths_key, - block_table=block_table, - layout_query="TND", - layout_key="PA_BSND", - sparse_count=2048, - sparse_mode=3) - return topk_indices +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar + +import torch +import torch_npu +from torch import nn +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata, + MLAAttentionImpl) +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.model_executor.layers.linear import (LinearBase, + UnquantizedLinearMethod) +from vllm.v1.attention.backends.utils import AttentionCGSupport + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.attention.attention_v1 import AscendAttentionState +from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE +from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata, + wait_for_kv_layer_from_connector) +from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ, + is_enable_nz) +from 
vllm_ascend.worker.npu_input_batch import InputBatch
+
+if TYPE_CHECKING:
+    from vllm.v1.core.sched.output import SchedulerOutput
+
+
+class AscendSFABackend(AttentionBackend):
+
+    accept_output_buffer: bool = True
+
+    @staticmethod
+    def get_name() -> str:
+        return "ASCEND_SFA"
+
+    @staticmethod
+    def get_metadata_cls() -> type["AttentionMetadata"]:
+        return AscendSFAMetadata
+
+    @staticmethod
+    def get_builder_cls():
+        return AscendSFAMetadataBuilder
+
+    @staticmethod
+    def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
+                           head_size: int) -> tuple[int, ...]:
+        return (num_blocks, block_size, num_kv_heads, head_size)
+
+    @staticmethod
+    def get_impl_cls() -> Type["AscendSFAImpl"]:
+        return AscendSFAImpl
+
+
+@dataclass
+class AscendSFAMetadata:
+    """Metadata for AscendSFABackend.
+
+    NOTE: Please read the comment at the top of the file before trying to
+    understand this class
+    """
+    # NOTE(sang): Definition of context_len, query_len, and seq_len.
+    #   |---------- N-1 iteration --------|
+    #   |---------------- N iteration ---------------------|
+    #   |- tokenA -|......................|-- newTokens ---|
+    #   |---------- context_len ----------|
+    #   |-------------------- seq_len ---------------------|
+    #                                     |-- query_len ---|
+    has_prefill: bool
+    num_actual_tokens: int  # Number of tokens excluding padding.
+    slot_mapping: torch.Tensor
+    seq_lens: torch.Tensor
+    cum_query_lens: torch.Tensor
+    block_tables: torch.Tensor
+    sin: torch.Tensor
+    cos: torch.Tensor
+
+    # For logging.
+    num_input_tokens: int = 0  # Number of tokens including padding.
+    # The dimension of the attention heads
+    head_dim: Optional[int] = None
+    attn_mask: torch.Tensor = None
+    # chunked prefill by default if no attn_states passed
+    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
+
+
+M = TypeVar("M", bound=AscendSFAMetadata)
+
+
+class AscendSFAMetadataBuilder:
+    # Does this backend/builder support ACL Graphs for attention (default: no). 
+ aclgraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.NEVER + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + # _attn_mask_builder = None + def __init__(self, + kv_cache_spec, + layer_names, + vllm_config: VllmConfig, + device: torch.device, + metadata_cls: Optional[AscendSFAMetadata] = None): + self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \ + if metadata_cls is not None else AscendSFAMetadata # type: ignore + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.device = device + self.block_size = vllm_config.cache_config.block_size + self.max_blocks = (vllm_config.model_config.max_model_len + + self.block_size - 1) // self.block_size + + self.speculative_config = vllm_config.speculative_config + self.decode_threshold = 1 + if self.speculative_config: + spec_token_num = self.speculative_config.num_speculative_tokens + self.decode_threshold += spec_token_num + assert self.decode_threshold <= 16, f"decode_threshold exceeded \ + npu_fused_infer_attention_score TND layout's limit of 16, \ + got {self.decode_threshold}" + + self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim + self.cos_cache = None + self.sin_cache = None + + def reorder_batch(self, input_batch: "InputBatch", + scheduler_output: "SchedulerOutput") -> bool: + # No need to reorder for Ascend SFA + return False + + def build( + self, + common_prefix_len: int, + common_attn_metadata: AscendCommonAttentionMetadata, + model: nn.Module, + ) -> AscendSFAMetadata: + num_reqs = common_attn_metadata.num_reqs + num_actual_tokens = common_attn_metadata.num_actual_tokens + query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu + device = self.device + + block_table = (common_attn_metadata.block_table_tensor[:num_reqs]) + slot_mapping = common_attn_metadata.slot_mapping[: + num_actual_tokens].to( + device, + non_blocking=True) + input_positions = common_attn_metadata.positions[: + num_actual_tokens].long( + ) + query_start_loc = common_attn_metadata.query_start_loc + query_lens = query_start_loc[1:] - query_start_loc[:-1] + has_prefill = any(query_lens > self.decode_threshold) + + if self.cos_cache is None: + self.cos_cache = model.model.layers[ + model.model.start_layer].self_attn.rotary_emb.cos_cached + self.sin_cache = model.model.layers[ + model.model.start_layer].self_attn.rotary_emb.sin_cached + if self.cos_cache.dtype != self.model_config.dtype: # type: ignore + self.cos_cache = self.cos_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + self.sin_cache = self.sin_cache.to( # type: ignore + self.model_config.dtype) # type: ignore + + cum_query_lens = query_start_loc_cpu[1:num_reqs + 1].to( + torch.int32).to(device, non_blocking=True) + seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].to( + torch.int32).to(device, non_blocking=True) + + cos = self.cos_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + sin = self.sin_cache[input_positions].unsqueeze( # type: ignore + 1).unsqueeze(2) + + return self.metadata_cls( # type: ignore + has_prefill=has_prefill, + num_input_tokens=common_attn_metadata.num_input_tokens, + num_actual_tokens=num_actual_tokens, + cum_query_lens=cum_query_lens, + seq_lens=seq_lens, + slot_mapping=slot_mapping, + head_dim=self.model_config.get_head_size(), + attn_mask=common_attn_metadata.attn_mask, + attn_state=common_attn_metadata.attn_state, + block_tables=block_table, + sin=sin, + cos=cos) + + +class 
AscendSFAImpl(MLAAttentionImpl):
+    """
+    NOTE: Please read the comment at the top of the file before trying to
+    understand this class
+    """
+
+    def __init__(
+        self,
+        num_heads: int,
+        head_size: int,
+        scale: float,
+        num_kv_heads: int,
+        alibi_slopes: Optional[list[float]],
+        sliding_window: Optional[int],
+        kv_cache_dtype: str,
+        logits_soft_cap: Optional[float],
+        attn_type: str,
+        kv_sharing_target_layer_name: Optional[str],
+        **kwargs,
+    ) -> None:
+        self.num_heads = num_heads
+        self.head_size = head_size
+        self.scale = float(scale)
+        self.num_kv_heads = num_kv_heads
+        self.kv_cache_dtype = kv_cache_dtype
+
+        # MLA Args
+        self.q_lora_rank = kwargs['q_lora_rank']
+        self.kv_lora_rank = kwargs['kv_lora_rank']
+        self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
+        self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
+        self.qk_head_dim = kwargs['qk_head_dim']
+        self.v_head_dim = kwargs['v_head_dim']
+        self.rotary_emb = kwargs['rotary_emb']
+        self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[
+            'q_b_proj']
+        self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None)
+        self.kv_b_proj = kwargs['kv_b_proj']
+        self.o_proj = kwargs['o_proj']
+        self.indexer = kwargs['indexer']
+        self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
+        self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
+        self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
+        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.num_heads_per_rank = self.num_heads // self.tp_size
+        self.q_b_proj = kwargs['q_b_proj']
+
+        ascend_config = get_ascend_config()
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
+        self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
+        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
+
+        assert self.indexer is not None, "Indexer is required for DSA."
+        # indexer param
+        self.n_head: int = self.indexer.n_head  # 64
+        self.head_dim: int = self.indexer.head_dim  # 128
+        self.wq_b = self.indexer.wq_b
+        self.wk = self.indexer.wk
+        self.weights_proj = self.indexer.weights_proj
+        self.k_norm = self.indexer.k_norm
+
+        self.cp_size = 1
+
+    def process_weights_after_loading(self, act_dtype: torch.dtype):
+
+        def get_layer_weight(layer):
+            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
+            for attr in WEIGHT_NAMES:
+                if hasattr(layer, attr):
+                    return getattr(layer, attr)
+            raise AttributeError(
+                f"Layer '{layer}' has no recognized weight attribute:"
+                f" {WEIGHT_NAMES}.")
+
+        def get_and_maybe_dequant_weights(layer: LinearBase):
+            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
+                # NOTE: This should only be used offline, since it's O(N^3)
+                eye = torch.eye(layer.input_size_per_partition,
+                                dtype=act_dtype,
+                                device=get_layer_weight(layer).device)
+                dequant_weights = layer.quant_method.apply(layer,
+                                                           eye,
+                                                           bias=None)
+                del eye
+                # standardize to (output, input)
+                return dequant_weights.T
+            # Weight will be reshaped next. To be on the safe side, the format
+            # of the weight should be reverted to FRACTAL_ND. 
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, ACL_FORMAT_FRACTAL_ND)
+            return layer.weight
+
+        # we currently do not have quantized bmm's which are needed for
+        # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
+        # the bmm's in 16-bit, the extra memory overhead of this is fairly low
+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+        assert kv_b_proj_weight.shape == (
+            self.kv_lora_rank,
+            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
+                f"{kv_b_proj_weight.shape=}, "
+                f"{self.kv_lora_rank=}, "
+                f"{self.num_heads=}, "
+                f"{self.qk_nope_head_dim=}, "
+                f"{self.v_head_dim=}")
+        kv_b_proj_weight = kv_b_proj_weight.view(
+            self.kv_lora_rank,
+            self.num_heads,
+            self.qk_nope_head_dim + self.v_head_dim,
+        )
+
+        W_UK, W_UV = kv_b_proj_weight.split(
+            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+
+        # Convert from (L, N, V) to (N, L, V)
+        self.W_UV = W_UV.transpose(0, 1).contiguous()
+        # Convert from (L, N, P) to (N, P, L)
+        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+
+        # Function `get_and_maybe_dequant_weights` will cast the weights to
+        # FRACTAL_ND. So we need to cast to FRACTAL_NZ again.
+        if is_enable_nz():
+            self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
+                self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
+
+        # Waiting for BMM NZ support
+        # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
+        # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
+
+    def _v_up_proj(self, x):
+        if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
+            x = x.view(-1, self.num_heads, self.kv_lora_rank)
+            x = torch_npu.npu_transpose_batchmatmul(x,
+                                                    self.W_UV,
+                                                    perm_x1=[1, 0, 2],
+                                                    perm_x2=[0, 1, 2],
+                                                    perm_y=[1, 0, 2])
+            x = x.reshape(-1, self.num_heads * self.v_head_dim)
+        else:
+            # Convert from (B, N, L) to (N, B, L)
+            x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+            # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+            x = torch.bmm(x, self.W_UV)
+            # Convert from (N, B, V) to (B, N * V)
+            x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+        return x
+
+    # Return `ql_nope`, `q_pe`
+    def _q_proj_and_k_up_proj(self, x):
+        q_nope, q_pe = self.q_proj(x)[0]\
+            .view(-1, self.num_heads, self.qk_head_dim)\
+            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+
+        # Convert from (B, N, P) to (N, B, P)
+        q_nope = q_nope.transpose(0, 1)
+        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+        ql_nope = torch.bmm(q_nope, self.W_UK_T)
+        # Convert from (N, B, L) to (B, N, L)
+        return ql_nope.transpose(0, 1), q_pe
+
+    def exec_kv(
+        self,
+        kv_no_split: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+        kv_cache: Tuple,
+        slots: torch.Tensor,
+    ):
+        B = kv_no_split.shape[0]
+        N = self.num_kv_heads
+        S = 1
+        # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
+        kv_no_split = kv_no_split.view(
+            B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
+        cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
+        k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
+            kv_no_split,
+            self.kv_a_layernorm.weight,
+            cos,
+            sin,
+            slots.to(torch.int64),
+            kv_cache[1],
+            kv_cache[0],
+            epsilon=self.kv_a_layernorm.variance_epsilon,
+            cache_mode=cache_mode,
+        )
+        return k_pe, k_nope
+
+    def rope_single(
+        self,
+        x: torch.Tensor,
+        cos: torch.Tensor,
+        sin: torch.Tensor,
+    ) -> torch.Tensor:
+        B, N, D = x.shape
+        S = 1
+        x = x.view(B, N, S, D)
+        x = torch_npu.npu_interleave_rope(x, cos, sin)
+        return x.view(B, N, D)
+
+    def forward(
+        self,
+        layer_name,
+        hidden_states: 
torch.Tensor, # query in unified attn + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + need_gather_q_kv: bool = False, + output: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + assert output is not None, "Output tensor must be provided." + if attn_metadata is None: + # Profiling run. + return output.fill_(0) + has_prefill = attn_metadata.has_prefill + num_actual_tokens = attn_metadata.num_actual_tokens + hidden_states = hidden_states[:num_actual_tokens] + # Inputs and outputs may be padded for CUDA graphs + output_padded = output + output = output[:num_actual_tokens] + assert self.fused_qkv_a_proj is not None, "q lora is required for DSA." + maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight, + dependency=hidden_states, + enabled=self.enable_prefetch) + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_no_split = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + q_c = self.q_a_layernorm(q_c) + + # Process for Flash Comm V1 + q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + q_c.contiguous(), need_gather_q_kv) + kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + kv_no_split.contiguous(), need_gather_q_kv) + + if has_prefill: + wait_for_kv_layer_from_connector(layer_name) + + slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens] + ql_nope, q_pe = \ + self._q_proj_and_k_up_proj(q_c) + q_pe = self.rope_single(q_pe, attn_metadata.cos, attn_metadata.sin) + k_pe, k_nope = self.exec_kv(kv_no_split, attn_metadata.cos, + attn_metadata.sin, kv_cache, slot_mapping) + + topk_indices = self.indexer_select(x=hidden_states, + qr=q_c, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + need_gather_q_kv=need_gather_q_kv) + attn_output = torch.ops.custom.npu_sparse_flash_attention( + query=ql_nope, + key=k_nope, + value=k_nope, + sparse_indices=topk_indices, + scale_value=self.scale, + sparse_block_size=1, + block_table=attn_metadata.block_tables, + actual_seq_lengths_query=attn_metadata.cum_query_lens, + actual_seq_lengths_kv=attn_metadata.seq_lens, + query_rope=q_pe, + key_rope=k_pe, + layout_query="TND", + layout_kv="PA_BSND", + sparse_mode=3, + ) + attn_output = self._v_up_proj(attn_output) + maybe_npu_prefetch(inputs=self.o_proj.weight, + dependency=attn_output, + max_size=MAX_O_PROJ_PREFETCH_SIZE, + enabled=self.enable_prefetch) + output[...] 
= self.o_proj(attn_output)[0] + return output_padded + + def indexer_select( + self, + x: torch.Tensor, + qr: torch.Tensor, + kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + attn_metadata: M, + need_gather_q_kv: bool = False, + ): + cos = attn_metadata.cos + sin = attn_metadata.sin + + cos_q, sin_q = cos, sin + cos = cos.view(-1, 1, 1, self.qk_rope_head_dim) + sin = sin.view(-1, 1, 1, self.qk_rope_head_dim) + + # q process in new stream + q, _ = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128] + q = q.view(-1, self.n_head, self.head_dim) # [b,s,64,128] + q_pe, q_nope = torch.split( + q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64,64+64] + + q_pe = q_pe.unsqueeze(2) + q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q) + q_pe = q_pe.squeeze(2) + q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128] + + k_proj, _ = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128] + k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + k_proj, need_gather_q_kv) + k = self.k_norm(k_proj).unsqueeze(1) + k_pe, k_nope = torch.split( + k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim], + dim=-1) # [b,s,64+64] + + k_pe = k_pe.unsqueeze(2) + k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin) + k_pe = k_pe.squeeze(2) + + k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128] + + if kv_cache is not None: + torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]), + attn_metadata.slot_mapping.view( + -1, 1), + k.view(-1, + k.shape[-1])) # b, s, n, d + + weights, _ = self.weights_proj(x) + weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( + weights, need_gather_q_kv) + + block_table = attn_metadata.block_tables + seq_lens = attn_metadata.seq_lens + cum_query_lens = attn_metadata.cum_query_lens + + topk_indices = torch.ops.custom.npu_lightning_indexer( + query=q, + key=kv_cache[2], + weights=weights, + actual_seq_lengths_query=cum_query_lens, + actual_seq_lengths_key=seq_lens, + block_table=block_table, + layout_query="TND", + layout_key="PA_BSND", + sparse_count=2048, + sparse_mode=3) + return topk_indices diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 2ebbdeb6..21ea48e3 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -29,10 +29,6 @@ def register_model(): "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding" ) - ModelRegistry.register_model( - "DeepseekV32ForCausalLM", - "vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM") - # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM. ModelRegistry.register_model( diff --git a/vllm_ascend/models/deepseek_v3_2.py b/vllm_ascend/models/deepseek_v3_2.py deleted file mode 100644 index bf17c977..00000000 --- a/vllm_ascend/models/deepseek_v3_2.py +++ /dev/null @@ -1,658 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. 
It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# # Adapted from -# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py -# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py -# """Inference-only DeepseekV2/DeepseekV3 model.""" - -from typing import Any, Dict, Iterable, Optional, Union - -import torch -from torch import nn -from transformers import PretrainedConfig -from vllm.attention import AttentionMetadata -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (divide, get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - get_tp_group, split_tensor_along_last_dim, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, - ColumnParallelLinear, - MergedColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.deepseek_v2 import \ - yarn_get_mscale # noqa: E501 -from vllm.model_executor.models.deepseek_v2 import ( - DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, - DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE, - get_spec_layer_idx_from_weight_name) -from vllm.model_executor.models.utils import ( - PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.models.layers.sfa import AscendSFAModules, Indexer -from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE -from vllm_ascend.ops.linear import AscendLinearBase -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.model_executor.layers.mla import MultiHeadLatentAttention -else: - from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper - - -@support_torch_compile -class AscendDeepseekV2Model(DeepseekV2Model, nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - # Rewrite this init func mainly for removing cuda-hard 
code - nn.Module.__init__(self) - - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - - self.vocab_size = config.vocab_size - assert hasattr(config, "index_topk") - topk_tokens = config.index_topk - topk_indices_buffer = torch.empty( - vllm_config.scheduler_config.max_num_batched_tokens, - topk_tokens, - dtype=torch.int32, - device=current_platform.device_type) - - if get_pp_group().is_first_rank: - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") - else: - self.embed_tokens = PPMissingLayer() - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, - topk_indices_buffer), - prefix=f"{prefix}.layers") - - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - -class CustomDeepseekV2RowParallelLinear(RowParallelLinear): - - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - *, - return_bias: bool = True, - disable_tp: bool = False, - ): - # Divide the weight matrix along the first dimension. - self.tp_rank = (get_tensor_model_parallel_rank() - if not disable_tp else 0) - self.tp_size = (get_tensor_model_parallel_world_size() - if not disable_tp else 1) - self.input_size_per_partition = divide(input_size, self.tp_size) - self.output_size_per_partition = output_size - self.output_partition_sizes = [output_size] - - AscendLinearBase.__init__(self, - input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - return_bias=return_bias, - disable_tp=disable_tp) - - self.input_is_parallel = input_is_parallel - self.reduce_results = reduce_results - - assert self.quant_method is not None - self.quant_method.create_weights( - layer=self, - input_size_per_partition=self.input_size_per_partition, - output_partition_sizes=self.output_partition_sizes, - input_size=self.input_size, - output_size=self.output_size, - params_dtype=self.params_dtype, - weight_loader=( - self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) - if not reduce_results and (bias and not skip_bias_add): - raise ValueError("When not reduce the results, adding bias to the " - "results can lead to incorrect results") - - if bias: - self.bias = nn.Parameter( - torch.empty(self.output_size, dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) - else: - self.register_parameter("bias", None) - self.update_param_tp_status() - - def forward( - self, - input_, - is_prefill=True, - is_force_scatter=False - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: - if self.input_is_parallel: - input_parallel = input_ - else: - tp_rank = get_tensor_model_parallel_rank() - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() - - # Matrix multiply. 
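# [Editor's note] A minimal single-process sketch of what the row-parallel
# forward above computes, assuming plain dense weights (no quantization) and
# simulating tensor parallelism by slicing instead of using torch.distributed.
# All names below are illustrative, not vllm code: each rank holds a slice of
# the weight along the input dimension, multiplies its input shard, and the
# partial products are summed, which is what tensor_model_parallel_all_reduce
# does across real ranks. The bias is fused only on rank 0 so it is not
# accumulated tp_size times.
import torch

def row_parallel_reference(x, weight, bias, tp_size):
    # x: [tokens, input_size]; weight: [output_size, input_size]
    in_per_rank = x.shape[-1] // tp_size
    partials = []
    for rank in range(tp_size):
        x_shard = x[:, rank * in_per_rank:(rank + 1) * in_per_rank]
        w_shard = weight[:, rank * in_per_rank:(rank + 1) * in_per_rank]
        out = x_shard @ w_shard.t()
        if rank == 0 and bias is not None:
            out = out + bias  # mirrors bias_ being None on tp_rank > 0
        partials.append(out)
    return torch.stack(partials).sum(dim=0)  # stands in for the all-reduce

x, w, b = torch.randn(4, 16), torch.randn(8, 16), torch.randn(8)
assert torch.allclose(row_parallel_reference(x, w, b, tp_size=4),
                      x @ w.t() + b, atol=1e-4)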
- assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) - bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, - input_parallel, - bias=bias_) - if self.reduce_results and self.tp_size > 1: - output = tensor_model_parallel_all_reduce(output_parallel) - else: - output = output_parallel - - output_bias = self.bias if self.skip_bias_add else None - - if not self.return_bias: - return output - return output, output_bias - - -class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention): - - def __init__( - self, - config: PretrainedConfig, - hidden_size: int, - num_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: Optional[int], - kv_lora_rank: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - nn.Module.__init__(self) - self.hidden_size = hidden_size - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.v_head_dim = v_head_dim - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - - self.num_heads = num_heads - self.tp_size = get_tensor_model_parallel_world_size() - assert num_heads % self.tp_size == 0 - self.num_local_heads = num_heads // self.tp_size - self.layers = config.num_hidden_layers - self.first_k_dense_replace = config.first_k_dense_replace - - self.scaling = self.qk_head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - - ascend_config = get_ascend_config() - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - - if self.q_lora_rank is not None: - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear( - q_lora_rank, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj", - return_bias=False, - ) - else: - self.q_proj = ColumnParallelLinear( - self.hidden_size, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj", - return_bias=False, - ) - - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj", - return_bias=False, - ) - self.o_proj = CustomDeepseekV2RowParallelLinear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - return_bias=False, - ) - - if self.q_lora_rank is not None: - self.fused_qkv_a_proj = MergedColumnParallelLinear( - self.hidden_size, - [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.fused_qkv_a_proj", - disable_tp=True) - self.kv_a_proj_with_mqa = None - else: - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") - - if 
rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) - - if rope_scaling: - mscale_all_dim = rope_scaling.get("mscale_all_dim", False) - scaling_factor = rope_scaling["factor"] - mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) - self.scaling = self.scaling * mscale * mscale - - self.dim: int = config.hidden_size # 7168 - # TODO(zzzzwwjj): wait transformers add these params - self.n_heads: int = 64 # 64 - self.head_dim: int = 128 # 128 - self.index_topk: int = 2048 # 2048 - self.indexer = Indexer( - config, - quant_config=quant_config, - dim=self.dim, - n_heads=self.n_heads, - head_dim=self.head_dim, - index_topk=self.index_topk, - prefix=f"{prefix}.indexer", - ) - - sfa_modules = AscendSFAModules( - q_a_layernorm=self.q_a_layernorm - if self.q_lora_rank is not None else None, - q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, - q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, - kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, - fused_qkv_a_proj=self.fused_qkv_a_proj - if self.q_lora_rank is not None else None, - kv_a_layernorm=self.kv_a_layernorm, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - rotary_emb=self.rotary_emb, - indexer=self.indexer, - is_sparse=hasattr(config, "index_topk"), - topk_indices_buffer=None) - - if vllm_version_is("0.11.0"): - self.sfa_attn = MultiHeadLatentAttention( - hidden_size=self.hidden_size, - num_heads=self.num_local_heads, - scale=self.scaling, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - mla_modules=sfa_modules, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - ) - else: - self.sfa_attn = MultiHeadLatentAttentionWrapper( - hidden_size=self.hidden_size, - num_heads=self.num_local_heads, - scale=self.scaling, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - mla_modules=sfa_modules, - cache_config=cache_config, - quant_config=quant_config, - prefix=prefix, - ) - self.prefix = prefix - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata) - - -class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): - - def __init__(self, - vllm_config: VllmConfig, - prefix: str, - topk_indices_buffer=None) -> None: - nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - parallel_config = vllm_config.parallel_config - - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - # DecoderLayers are created with `make_layers` which passes the prefix - # with the layer's index. 
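# [Editor's note] Worked example of the softmax-scale correction applied a few
# lines above. yarn_get_mscale follows the DeepSeek-YaRN formula
# (0.1 * mscale * ln(scale) + 1 for scale > 1); the config values below are
# illustrative assumptions, not taken from a real checkpoint.
import math

def yarn_get_mscale(scale: float = 1.0, mscale: float = 1.0) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0

qk_head_dim = 128 + 64                # qk_nope_head_dim + qk_rope_head_dim
base_scaling = qk_head_dim**-0.5      # the default attention scale
factor, mscale_all_dim = 40.0, 1.0    # assumed rope_scaling entries
m = yarn_get_mscale(factor, mscale_all_dim)
scaling = base_scaling * m * m        # mscale enters squared, as in the code above
print(f"mscale={m:.4f}, corrected scaling={scaling:.6f}")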
- layer_idx = int(prefix.split(sep='.')[-1]) - self.layer_idx = layer_idx - self.layers = config.num_hidden_layers - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tp_group().rank_in_group - # TODO: enable mla in vllm-ascend - if model_config.use_mla: - attn_cls = CustomDeepseekV2SFAAttention - else: - attn_cls = DeepseekV2Attention - self.self_attn = attn_cls( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekV2MoE( - config=config, - parallel_config=parallel_config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - if self.mlp.gate.e_score_correction_bias is not None: - self.mlp.gate.e_score_correction_bias.data = ( - self.mlp.gate.e_score_correction_bias.data.to( - dtype=torch.get_default_dtype())) - else: - self.mlp = DeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.routed_scaling_factor = config.routed_scaling_factor - self.first_k_dense_replace = config.first_k_dense_replace - self.tp_group = get_tp_group().device_group - - -class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - - # `packed_modules_mapping` needs to be modified before - # initializing DeepseekV2Model, as it is passed inplace to - # quantization config init and may be used to select the - # quant_method for relevant layers during initialization. 
- self.fuse_qkv_a_proj = hasattr( - config, "q_lora_rank") and config.q_lora_rank is not None - if self.fuse_qkv_a_proj: - self.packed_modules_mapping["fused_qkv_a_proj"] = [ - "q_a_proj", - "kv_a_proj_with_mqa", - ] - - self.model = AscendDeepseekV2Model(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) - if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) - else: - self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - self.expert_weights: list[Any] = [] - - # Set MoE hyperparameters - self.num_moe_layers = (config.num_hidden_layers - - config.first_k_dense_replace) - self.num_expert_groups = config.n_group - - self.moe_layers: list[FusedMoE] = [] - example_moe = None - for layer in self.model.layers: - if isinstance(layer, PPMissingLayer): - continue - - assert isinstance(layer, DeepseekV2DecoderLayer) - if isinstance(layer.mlp, DeepseekV2MoE): - # Pick last one layer since the first ones may be dense layers. - example_moe = layer.mlp - self.moe_layers.append(layer.mlp.experts) - - if example_moe is None: - raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") - - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - self.num_shared_experts = example_moe.n_shared_experts - self.num_redundant_experts = example_moe.n_redundant_experts - - # NOTE: This `load_weights` is mainly copied from - # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5 - # to fix CI, and it is different from the implementation in main - # TODO: support eplb style load_weights - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - """""" - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ("fused_qkv_a_proj", "q_a_proj", 0), - ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = AscendFusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if "module" in name: - continue - - spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) - if spec_layer is not None: - continue # skip spec decode layers for main model - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). - if weight_name not in name: - continue - # We have mlp.experts[0].gate_proj in the checkpoint. - # Since we handle the experts below in expert_params_mapping, - # we need to skip here BEFORE we update the name, otherwise - # name will be updated to mlp.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping - # for mlp.experts[0].gate_gate_up_proj, which breaks load. 
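# [Editor's note] A toy run of the guard described in the comment above, with
# made-up names. Without the early `continue`, the stacked-params pass would
# rename an expert weight (gate_proj -> gate_up_proj) before the expert
# mapping sees it, leaving a parameter name that exists nowhere in the model.
stacked_params_mapping = [("gate_up_proj", "gate_proj", 0),
                          ("gate_up_proj", "up_proj", 1)]
params_dict = {"model.layers.4.mlp.experts.w13_weight"}  # what really exists

name = "model.layers.4.mlp.experts.0.gate_proj.weight"
for param_name, weight_name, _shard_id in stacked_params_mapping:
    if weight_name not in name:
        continue
    if ("mlp.experts." in name) and name not in params_dict:
        continue  # skip BEFORE renaming; expert_params_mapping handles it
    name = name.replace(weight_name, param_name)
print(name)  # unchanged -> falls through to the expert mapping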
- if (("mlp.experts." in name) and name not in params_dict): - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id, - return_success=False) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): - pass - - -DeepseekV2DecoderLayer.__init__ = CustomDeepseekV2DecoderLayer.__init__ diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py index f77f4677..4ea4a27b 100644 --- a/vllm_ascend/models/layers/mla.py +++ b/vllm_ascend/models/layers/mla.py @@ -52,6 +52,35 @@ else: from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper +class IndexerWrapper(nn.Module): + ''' + A wrapper of Indexer for Deepseek v3.2. + This wrapper is currently used to solve the fp8 hard code issue of vllm's deepseek_v2.py. + It wraps the original Indexer, inherits its module weights + (including wq_b, wk, weights_proj, k_norm) + while deletes the unused topk_indices_buffer and k_cache to save memory. + TODO: Will be removed once original Indexer supports different quantization methods. 
+ ''' + + def __init__(self, vllm_indexer: nn.Module) -> None: + super().__init__() + + self.n_head: int = vllm_indexer.n_head # 64 + self.head_dim: int = vllm_indexer.head_dim # 128 + self.topk_tokens: int = vllm_indexer.topk_tokens # 2048 + self.q_lora_rank: int = vllm_indexer.q_lora_rank # 1536 + self.wq_b = vllm_indexer.wq_b + self.wk = vllm_indexer.wk + self.weights_proj = vllm_indexer.weights_proj + self.k_norm = vllm_indexer.k_norm + self.softmax_scale = vllm_indexer.softmax_scale + vllm_indexer.topk_indices_buffer = None # delete topk_indices_buffer + vllm_indexer.k_cache = None # delete k_cache + + def forward(self): + return + + # TODO(whx): adapt v0.11.0 and DSA class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): @@ -86,6 +115,10 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): self.first_k_dense_replace = hf_config.first_k_dense_replace self.tp_size = get_tensor_model_parallel_world_size() self.layers = hf_config.num_hidden_layers + if mla_modules.indexer is not None: + ascend_indexer = IndexerWrapper(mla_modules.indexer) + else: + ascend_indexer = None if vllm_version_is("0.11.0"): self.mla_attn = Attention( @@ -97,6 +130,8 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): quant_config=quant_config, prefix=f"{prefix}.attn", use_mla=True, + indexer=ascend_indexer, + use_sparse=mla_modules.is_sparse, # MLA Args q_lora_rank=self.q_lora_rank, kv_lora_rank=self.kv_lora_rank, @@ -128,7 +163,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): quant_config=quant_config, prefix=f"{prefix}.attn", use_sparse=mla_modules.is_sparse, - indexer=mla_modules.indexer, + indexer=ascend_indexer, # extra args rotary_emb=mla_modules.rotary_emb, fused_qkv_a_proj=mla_modules.fused_qkv_a_proj, diff --git a/vllm_ascend/models/layers/sfa.py b/vllm_ascend/models/layers/sfa.py deleted file mode 100644 index 53343716..00000000 --- a/vllm_ascend/models/layers/sfa.py +++ /dev/null @@ -1,275 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
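# [Editor's note] The pattern used by IndexerWrapper earlier in this diff,
# reduced to its essentials with illustrative names: re-parent the submodules
# that are still needed (sharing weights, not copying them) and drop
# references to large unused buffers so the allocator can reclaim them.
import torch
from torch import nn

class SlimWrapper(nn.Module):

    def __init__(self, inner: nn.Module) -> None:
        super().__init__()
        self.proj = inner.proj      # shared reference, not a copy
        inner.big_buffer = None     # release the unused allocation

inner = nn.Module()
inner.proj = nn.Linear(8, 8)
inner.big_buffer = torch.empty(1 << 20)
slim = SlimWrapper(inner)
assert slim.proj is inner.proj and inner.big_buffer is None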
-from dataclasses import dataclass -from typing import Optional - -import torch -from torch import nn -from vllm.attention import AttentionMetadata -from vllm.config import CacheConfig, get_current_vllm_config -from vllm.distributed import get_tensor_model_parallel_world_size -from vllm.forward_context import ForwardContext, get_forward_context -from vllm.model_executor.layers.linear import ReplicatedLinear -from vllm.model_executor.layers.mla import MLAModules -from vllm.model_executor.layers.quantization import QuantizationConfig - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.utils import vllm_version_is - -if vllm_version_is("0.11.0"): - from vllm.attention import Attention - from vllm.model_executor.layers.mla import \ - MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper - from vllm.utils import direct_register_custom_op -else: - from vllm.attention.layer import MLAAttention - from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper - from vllm.utils.torch_utils import direct_register_custom_op - - -@dataclass -class AscendSFAModules: - q_a_layernorm: Optional[torch.nn.Module] - q_proj: Optional[torch.nn.Module] - kv_a_proj_with_mqa: torch.nn.Module - kv_a_layernorm: torch.nn.Module - kv_b_proj: torch.nn.Module - o_proj: torch.nn.Module - rotary_emb: torch.nn.Module - indexer: torch.nn.Module - is_sparse: bool - fused_qkv_a_proj: Optional[torch.nn.Module] - q_b_proj: Optional[torch.nn.Module] - topk_indices_buffer: Optional[torch.Tensor] - - -class AscendSparseFlashAttention(MultiHeadLatentAttentionWrapper): - - def __init__( - self, - hidden_size: int, - num_heads: int, - scale: float, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: Optional[int], - kv_lora_rank: int, - mla_modules: MLAModules, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - nn.Module.__init__(self) - self.hidden_size = hidden_size - self.kv_lora_rank = kv_lora_rank - self.qk_rope_head_dim = qk_rope_head_dim - self.q_lora_rank = q_lora_rank - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_head_dim = qk_rope_head_dim + qk_nope_head_dim - self.v_head_dim = v_head_dim - self.prefix = prefix - self.scaling = scale - self.indexer = mla_modules.indexer - self.is_sparse = mla_modules.is_sparse - hf_config = get_current_vllm_config().model_config.hf_config - self.enable_shared_expert_dp = get_ascend_config( - ).enable_shared_expert_dp - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - self.first_k_dense_replace = hf_config.first_k_dense_replace - self.tp_size = get_tensor_model_parallel_world_size() - self.layers = hf_config.num_hidden_layers - - if vllm_version_is("0.11.0"): - self.sfa_attn = Attention( - num_heads=num_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=scale, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - use_sparse=True, - indexer=self.indexer, - # SFA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - qk_head_dim=self.qk_head_dim, - rotary_emb=mla_modules.rotary_emb, - fused_qkv_a_proj=mla_modules.fused_qkv_a_proj, - q_b_proj=mla_modules.q_b_proj, - q_a_layernorm=mla_modules.q_a_layernorm, - q_proj=mla_modules.q_proj, - kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, - 
kv_a_layernorm=mla_modules.kv_a_layernorm, - kv_b_proj=mla_modules.kv_b_proj, - o_proj=mla_modules.o_proj, - ) - else: - self.sfa_attn = MLAAttention( - num_heads=num_heads, - scale=scale, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - v_head_dim=self.v_head_dim, - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - kv_b_proj=mla_modules.kv_b_proj, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_sparse=mla_modules.is_sparse, - indexer=mla_modules.indexer, - # extra args - rotary_emb=mla_modules.rotary_emb, - fused_qkv_a_proj=mla_modules.fused_qkv_a_proj, - q_b_proj=mla_modules.q_b_proj, - q_a_layernorm=mla_modules.q_a_layernorm, - q_proj=mla_modules.q_proj, - kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, - kv_a_layernorm=mla_modules.kv_a_layernorm, - o_proj=mla_modules.o_proj, - ) - - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - num_tokens = hidden_states.shape[0] - need_gather_q_kv = False - if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers: - # Simulate all gather to calculate output shape - num_tokens = num_tokens * self.tp_size - need_gather_q_kv = True - if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace: - output_shape = hidden_states.shape - else: - rows = num_tokens // self.tp_size - if num_tokens % self.tp_size: - rows += 1 - output_shape = (rows, hidden_states.shape[1]) - # FIXME: This does not seem right, should make sure the buffer is fixed - output = torch.empty(output_shape, - dtype=hidden_states.dtype, - device=hidden_states.device) - torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output, - self.prefix) - output = output.view(-1, output_shape[-1]) - return output - - -def sfa_forward( - hidden_states: torch.Tensor, - need_gather_q_kv: bool, - output: torch.Tensor, - layer_name: str, -) -> None: - forward_context: ForwardContext = get_forward_context() - self = forward_context.no_compile_layers[layer_name] - if forward_context.attn_metadata: - attn_metadata = forward_context.attn_metadata[self.sfa_attn.layer_name] - else: - attn_metadata = forward_context.attn_metadata - kv_cache = self.sfa_attn.kv_cache[forward_context.virtual_engine] - self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata, - need_gather_q_kv, output) - return - - -class Indexer(nn.Module): - - def __init__(self, - config, - dim: int = 7168, - n_heads: int = 64, - head_dim: int = 128, - index_topk: int = 2048, - q_lora_rank: int = 1536, - rope_head_dim: int = 64, - quant_config: Optional[QuantizationConfig] = None, - prefix: Optional[str] = ""): - super().__init__() - - self.dim: int = dim # 7168 - self.n_heads: int = n_heads # 64 - self.head_dim: int = head_dim # 128 - self.rope_head_dim: int = rope_head_dim # 64 - self.index_topk: int = index_topk # 2048 - self.q_lora_rank: int = q_lora_rank # 1536 - self.wq_b = ReplicatedLinear( - self.q_lora_rank, - self.n_heads * self.head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wq_b", - return_bias=False, - ) - self.wk = 
ReplicatedLinear( - self.dim, - self.head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.wk", - return_bias=False, - ) - self.weights_proj = ReplicatedLinear( - self.dim, - self.n_heads, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.weights_proj", - return_bias=False, - ) - self.k_norm = nn.LayerNorm(self.head_dim) - self.softmax_scale = self.head_dim**-0.5 - - def forward(self): - return - - -def sfa_forward_fake( - hidden_states: torch.Tensor, - need_gather_q_kv: bool, - output: torch.Tensor, - layer_name: str, -) -> None: - return - - -direct_register_custom_op( - op_name="sfa_forward", - op_func=sfa_forward, - mutates_args=["output"], - fake_impl=sfa_forward_fake, - dispatch_key="PrivateUse1", -) diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index ae6d3997..846c4832 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -33,3 +33,4 @@ from vllm_ascend.utils import vllm_version_is if vllm_version_is("0.11.0"): import vllm_ascend.patch.worker.patch_deepseek_mtp # noqa + import vllm_ascend.patch.worker.patch_deepseek_v3_2 # noqa diff --git a/vllm_ascend/patch/worker/patch_deepseek_v3_2.py b/vllm_ascend/patch/worker/patch_deepseek_v3_2.py new file mode 100644 index 00000000..cdafcb67 --- /dev/null +++ b/vllm_ascend/patch/worker/patch_deepseek_v3_2.py @@ -0,0 +1,108 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
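# [Editor's note] The (deleted) sfa_forward registration above follows the
# standard custom-op recipe: a real kernel that mutates `output` in place plus
# a fake implementation so the compiler can trace it without running NPU code.
# A plain torch.library equivalent (illustrative op name, default dispatch
# instead of "PrivateUse1", requires PyTorch >= 2.4) looks like this:
import torch

@torch.library.custom_op("demo::copy_forward", mutates_args=["output"])
def copy_forward(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    output.copy_(hidden_states)  # stand-in for the real attention kernel

@copy_forward.register_fake
def _(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    return  # nothing to compute at trace time; `output` already has the shape

x, out = torch.randn(4, 8), torch.empty(4, 8)
torch.ops.demo.copy_forward(x, out)
assert torch.equal(out, x)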
+# +from itertools import islice +from typing import Optional, Union + +import torch +import vllm.model_executor.models.deepseek_v2 +from torch import nn +from vllm.compilation.decorators import support_torch_compile +from vllm.config import VllmConfig +from vllm.distributed import get_pp_group +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.vocab_parallel_embedding import \ + VocabParallelEmbedding +from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer +from vllm.model_executor.models.utils import ( + PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers) +from vllm.sequence import IntermediateTensors + + +@support_torch_compile +class DeepseekV2Model(nn.Module): + + fall_back_to_pt_during_load = False + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + + self.vocab_size = config.vocab_size + self.is_v32 = hasattr(config, "index_topk") + topk_indices_buffer = None + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, + topk_indices_buffer), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + intermediate_tensors: Optional[IntermediateTensors], + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for layer in islice(self.layers, self.start_layer, self.end_layer): + hidden_states, residual = layer(positions, hidden_states, residual) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +vllm.model_executor.models.deepseek_v2.DeepseekV2Model = DeepseekV2Model diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py index 3709be33..3faf28f8 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -69,7 +69,6 @@ from vllm.sequence import IntermediateTensors from vllm_ascend import envs from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.models.layers.sfa import Indexer from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch from vllm_ascend.quantization.quant_config import AscendLinearMethod from 
vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE @@ -83,6 +82,57 @@ else: from vllm.attention.layer import MLAAttention +class Indexer(nn.Module): + + def __init__(self, + config, + dim: int = 7168, + n_heads: int = 64, + head_dim: int = 128, + index_topk: int = 2048, + q_lora_rank: int = 1536, + rope_head_dim: int = 64, + quant_config: Optional[QuantizationConfig] = None, + prefix: Optional[str] = ""): + super().__init__() + + self.dim: int = dim # 7168 + self.n_heads: int = n_heads # 64 + self.head_dim: int = head_dim # 128 + self.rope_head_dim: int = rope_head_dim # 64 + self.index_topk: int = index_topk # 2048 + self.q_lora_rank: int = q_lora_rank # 1536 + self.wq_b = ReplicatedLinear( + self.q_lora_rank, + self.n_heads * self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b", + return_bias=False, + ) + self.wk = ReplicatedLinear( + self.dim, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wk", + return_bias=False, + ) + self.weights_proj = ReplicatedLinear( + self.dim, + self.n_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.weights_proj", + return_bias=False, + ) + self.k_norm = nn.LayerNorm(self.head_dim) + self.softmax_scale = self.head_dim**-0.5 + + def forward(self): + return + + class TorchairDeepseekV2SiluAndMul(SiluAndMul): def __init__(self, diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index d591f2b9..ddcbfb59 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -577,7 +577,6 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): from vllm.model_executor.custom_op import CustomOp from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention - from vllm_ascend.models.layers.sfa import AscendSparseFlashAttention from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE, AscendSharedFusedMoE) @@ -625,10 +624,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None): mla_to_register = "MultiHeadLatentAttention" if vllm_version_is( "0.11.0") else "MultiHeadLatentAttentionWrapper" if vllm_config and vllm_config.model_config and vllm_config.model_config.use_mla: - AscendMLAAttentionWarrper = AscendSparseFlashAttention if hasattr( - vllm_config.model_config.hf_config, - "index_topk") else AscendMultiHeadLatentAttention - REGISTERED_ASCEND_OPS[mla_to_register] = AscendMLAAttentionWarrper + REGISTERED_ASCEND_OPS[mla_to_register] = AscendMultiHeadLatentAttention for name, op_cls in REGISTERED_ASCEND_OPS.items(): CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)
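# [Editor's note] Reference semantics, as the editor understands the
# DeepSeek-V3.2 lightning indexer, for the top-k selection that
# npu_lightning_indexer performs in indexer_select earlier in this diff:
# score each cached key against every query token per head, gate with ReLU,
# mix heads with the weights_proj output, and keep the top `sparse_count`
# key positions. A plain-PyTorch sketch on random data, not the NPU kernel:
import torch

def lightning_indexer_reference(q, k, w, topk):
    # q: [tokens, n_heads, head_dim]; k: [kv_len, head_dim]; w: [tokens, n_heads]
    scores = torch.relu(torch.einsum("tnd,sd->tns", q, k))  # per-head q.k
    scores = (w.unsqueeze(-1) * scores).sum(dim=1)          # [tokens, kv_len]
    return scores.topk(min(topk, k.shape[0]), dim=-1).indices

q = torch.randn(4, 64, 128)   # queries from wq_b, per the shapes above
k = torch.randn(32, 128)      # cached keys from wk + k_norm
w = torch.randn(4, 64)        # per-head weights from weights_proj
print(lightning_indexer_reference(q, k, w, topk=8).shape)  # torch.Size([4, 8])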