[Model][3/N] Refactor sfa into mla and remove deepseek_v3_2.py (#3769)

This is the follow-up to PR #3189. It continues refactoring sfa into mla and
finally removes deepseek_v3_2.py. This is the last PR of the DeepSeek modeling
refactoring; after it, all DeepSeek-related model code is removed from
vllm_ascend.

Furthermore, after this PR DeepSeek V3.2 can run chunked prefill with correct
accuracy.
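For quick verification, a minimal offline sketch of exercising chunked prefill with this model through the vLLM Python API; the model path, parallelism, and token budget below are illustrative placeholders, not values taken from this PR:

# Hedged sketch: assumes a working vllm + vllm-ascend install; the model id,
# tensor_parallel_size and max_num_batched_tokens are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2-Exp",  # placeholder model id
    tensor_parallel_size=16,                # placeholder parallelism
    enable_chunked_prefill=True,
    max_num_batched_tokens=2048,            # small budget so long prompts are chunked
    trust_remote_code=True,
)

prompts = ["Explain multi-head latent attention." * 64]  # long prompt to trigger chunking
outputs = llm.generate(prompts, SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)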

- vLLM version: v0.11.0rc3
- vLLM main: 83f478bb19

---------

Signed-off-by: whx-sjtu <2952154980@qq.com>
whx
2025-10-30 17:06:38 +08:00
committed by GitHub
parent eff3e5fc6f
commit f6149f3894
10 changed files with 751 additions and 1935 deletions


@@ -52,6 +52,8 @@ if prefill_context_parallel_enable():
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput
 
+MAX_O_PROJ_PREFETCH_SIZE = 16 * 1024 * 1024
+
 
 class AscendMLABackend(AttentionBackend):
@@ -808,16 +810,17 @@ class AscendMLAImpl(MLAAttentionImpl):
         # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
         # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
-        # Currently mlapo only supports W8A8 quantization in MLA scenario
-        # TODO(whx): modify this limitation when mlapo supports floating point
-        if self.fused_qkv_a_proj is None or not isinstance(
-                getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
-                        None), AscendW8A8LinearMethod):
-            self.enable_mlapo = False
-            logger.warning_once(
-                "Currently mlapo only supports W8A8 quantization in MLA scenario."
-                "Some layers in your model are not quantized with W8A8,"
-                "thus mlapo is disabled for these layers.")
+        if self.enable_mlapo:
+            # Currently mlapo only supports W8A8 quantization in MLA scenario
+            # TODO(whx): modify this limitation when mlapo supports floating point
+            if self.fused_qkv_a_proj is None or not isinstance(
+                    getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
+                            None), AscendW8A8LinearMethod):
+                self.enable_mlapo = False
+                logger.warning_once(
+                    "Currently mlapo only supports W8A8 quantization in MLA scenario."
+                    "Some layers in your model are not quantized with W8A8,"
+                    "thus mlapo is disabled for these layers.")
 
         if self.enable_mlapo:
             self._process_weights_for_fused_mlapo(act_dtype)
@@ -1282,12 +1285,13 @@ class AscendMLAImpl(MLAAttentionImpl):
     def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
                         attn_metadata, need_gather_q_kv):
         # MLA Preprocess:
-        # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
-        # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
-        # 3. If need_gather_q_kv, perform all_gather.
-        # 4. Preprocess decode tokens, write kv cache and get:
+        # 1. Perform fused_qkv_a_proj and q_a_layernorm to obtain q_c and kv_no_split
+        #    or
+        #    Perform kv_a_proj_with_mqa to obtain kv_no_split
+        # 2. If need_gather_q_kv, perform all_gather.
+        # 3. Preprocess decode tokens, write kv cache and get:
         #    decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
-        # 5. Preprocess prefill tokens, write kv cache and get:
+        # 4. Preprocess prefill tokens, write kv cache and get:
         #    prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
         has_decode = attn_metadata.num_decodes > 0
         has_prefill = attn_metadata.num_prefills > 0
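As a standalone illustration of step 1 in the comment above, a rough sketch in plain torch (dimensions are DeepSeek-V3-like placeholders; the linear layer and the RMS norm here are simple stand-ins for the real vllm-ascend modules):

import torch

hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 7168, 1536, 512, 64
num_tokens = 4

# one fused GEMM produces both the q low-rank branch and the latent kv branch
fused_qkv_a_proj = torch.nn.Linear(
    hidden_size, q_lora_rank + kv_lora_rank + qk_rope_head_dim, bias=False)

hidden_states = torch.randn(num_tokens, hidden_size)
qkv_lora = fused_qkv_a_proj(hidden_states)
q_c, kv_no_split = qkv_lora.split(
    [q_lora_rank, kv_lora_rank + qk_rope_head_dim], dim=-1)

# only the q branch is normalized here (stand-in RMSNorm, eps is a placeholder)
q_c = q_c * torch.rsqrt(q_c.pow(2).mean(dim=-1, keepdim=True) + 1e-6)

print(q_c.shape, kv_no_split.shape)  # torch.Size([4, 1536]) torch.Size([4, 576])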


@@ -1,6 +1,5 @@
 from dataclasses import dataclass
-from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
-                    TypeVar)
+from typing import TYPE_CHECKING, ClassVar, Optional, Tuple, Type, TypeVar
 
 import torch
 import torch_npu
@@ -8,17 +7,20 @@ from torch import nn
 from vllm.attention.backends.abstract import (AttentionBackend,
                                               AttentionMetadata,
                                               MLAAttentionImpl)
-from vllm.config import VllmConfig, get_current_vllm_config
-from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
+from vllm.config import VllmConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
-from vllm.utils import cdiv, round_down
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.mla_v1 import MAX_O_PROJ_PREFETCH_SIZE
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         split_decodes_and_prefills)
+                                         wait_for_kv_layer_from_connector)
+from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
+from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
+                               is_enable_nz)
 from vllm_ascend.worker.npu_input_batch import InputBatch
 
 if TYPE_CHECKING:
@@ -51,51 +53,6 @@ class AscendSFABackend(AttentionBackend):
         return AscendSFAImpl
 
 
-@dataclass
-class AscendSFAPrefillMetadata:
-    """ Prefill Specific Metadata for Ascend"""
-
-    @dataclass
-    class ChunkedContextMetadata:
-        # New for MLA (compared to FlashAttention)
-        # For handling chunked prefill
-        cu_seq_lens: torch.Tensor
-        starts: torch.Tensor
-        seq_tot: list[int]
-        max_seq_lens: list[int]
-        workspace: torch.Tensor
-        chunk_seq_lens: torch.Tensor
-
-    attn_mask: torch.Tensor
-    query_lens: list[int]
-    seq_lens: list[int]
-    context_lens: torch.Tensor
-    input_positions: torch.Tensor
-    query_start_loc: torch.Tensor
-    block_table: torch.Tensor
-    max_query_len: int
-    max_seq_lens: int
-    sin: torch.Tensor
-    cos: torch.Tensor
-    chunked_context: Optional[ChunkedContextMetadata] = None
-
-
-@dataclass
-class AscendSFADecodeMetadata:
-    # Input positions for rotrary embeddings since for MLA the rotary
-    # position embeddings are applied inside the attention backend
-    input_positions: torch.Tensor
-    block_table: torch.Tensor
-    seq_lens: torch.Tensor
-    max_seq_lens: int
-    seq_lens_list: list[int]
-    actual_seq_lengths_q: torch.Tensor
-    sin: torch.Tensor
-    cos: torch.Tensor
-    attn_mask: Optional[torch.Tensor] = None
-
-
 @dataclass
 class AscendSFAMetadata:
     """Metadata for MLACommon.
@@ -110,41 +67,23 @@ class AscendSFAMetadata:
     # |---------- context_len ----------|
     # |-------------------- seq_len ---------------------|
     #                                   |-- query_len ---|
+    has_prefill: bool
     num_actual_tokens: int  # Number of tokens excluding padding.
     slot_mapping: torch.Tensor
-    query_start_loc: torch.Tensor
     seq_lens: torch.Tensor
+    cum_query_lens: torch.Tensor
     block_tables: torch.Tensor
+    sin: torch.Tensor
+    cos: torch.Tensor
 
-    # New for MLA (compared to FlashAttention)
-    # For handling prefill decode split
-    num_decodes: int
-    num_decode_tokens: int
-    num_prefills: int
 
     # For logging.
     num_input_tokens: int = 0  # Number of tokens including padding.
 
-    query_lens: Optional[list[int]] = None
     # The dimension of the attention heads
     head_dim: Optional[int] = None
     attn_mask: torch.Tensor = None
     # chunked prefill by default if no attn_states passed
     attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
-    decode: Optional[AscendSFADecodeMetadata] = None
-    prefill: Optional[AscendSFAPrefillMetadata] = None
-
-    def __post_init__(self):
-        pass
-        # supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
-        # if self.head_dim is not None and self.head_dim \
-        #     not in supported_head_sizes:
-        #     raise ValueError(
-        #         f"Only {supported_head_sizes} are supported for head_dim,",
-        #         f"received {self.head_dim}.")
 
 
 M = TypeVar("M", bound=AscendSFAMetadata)
@@ -170,11 +109,9 @@ class AscendSFAMetadataBuilder:
         self.vllm_config = vllm_config
         self.model_config = vllm_config.model_config
         self.device = device
-        scheduler_config = vllm_config.scheduler_config
         self.block_size = vllm_config.cache_config.block_size
         self.max_blocks = (vllm_config.model_config.max_model_len +
                            self.block_size - 1) // self.block_size
-        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
         self.speculative_config = vllm_config.speculative_config
 
         self.decode_threshold = 1
@@ -185,81 +122,14 @@ class AscendSFAMetadataBuilder:
            npu_fused_infer_attention_score TND layout's limit of 16, \
            got {self.decode_threshold}"
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(8 * self.model_config.max_model_len,
4 * scheduler_config.max_num_seqs * self.block_size),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * self.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
        self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
        self.cos_cache = None
        self.sin_cache = None

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # No need to reorder for Ascend SFA
        return False
        # We now want to reorder the batch so that the "decode" requests are at
        # the front and the "prefill" requests are at the using the least amount
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
# where attention is likely memory-bound and "prefill" to mean requests
# where attention is likely compute-bound, TODO(lucas): figure out a
# better naming here)
decodes = []
prefills = []
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
if num_tokens <= self.decode_threshold:
decodes.append(i)
else:
prefills.append(i)
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new request are generally appended to the
# persistent batch so already should be at the back)
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back"
# i.e. past where the last decode should be in the reodorered with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
first_prefill = 0
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
if decodes[num_decodes - i] >= num_decodes:
input_batch.swap_states(prefills[first_prefill],
decodes[num_decodes - i])
first_prefill += 1
modified_batch = True
else:
break
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
return modified_batch
    def build(
        self,
@@ -269,16 +139,7 @@ class AscendSFAMetadataBuilder:
    ) -> AscendSFAMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
        device = self.device
        block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
@@ -289,6 +150,9 @@ class AscendSFAMetadataBuilder:
        input_positions = common_attn_metadata.positions[:
                                                          num_actual_tokens].long(
                                                          )
        query_start_loc = common_attn_metadata.query_start_loc
        query_lens = query_start_loc[1:] - query_start_loc[:-1]
        has_prefill = any(query_lens > self.decode_threshold)
        if self.cos_cache is None:
            self.cos_cache = model.model.layers[
@@ -301,146 +165,29 @@ class AscendSFAMetadataBuilder:
            self.sin_cache = self.sin_cache.to(  # type: ignore
                self.model_config.dtype)  # type: ignore

        cum_query_lens = query_start_loc_cpu[1:num_reqs + 1].to(
            torch.int32).to(device, non_blocking=True)
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs].to(
            torch.int32).to(device, non_blocking=True)
        cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
            1).unsqueeze(2)
        sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
            1).unsqueeze(2)

        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        query_lens = query_seq_lens_cpu[:num_reqs]
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
        num_computed_tokens_cpu = (seq_lens - query_lens)
        prefill_metadata = None
        chunked_context_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes  # prefill_start
tokens_start = num_decode_tokens
max_query_len = query_lens[reqs_start:].max().item()
max_seq_lens = seq_lens[reqs_start:].max().item()
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)
max_context_chunk = round_down(max_context_chunk,
self.block_size)
assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
cu_seq_lens_cpu = torch.zeros(num_chunks,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(chunk_seq_lens,
dim=1,
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata = \
AscendSFAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
)
prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
actual_query_lens = torch.tensor(query_lens[reqs_start:],
dtype=torch.int32).npu()
query_lens_prefill_sfa = torch.cumsum(actual_query_lens,
dim=0).to(torch.int32)
seq_lens_prefill_sfa = seq_lens[reqs_start:].to(torch.int32).npu()
prefill_metadata = AscendSFAPrefillMetadata(
attn_mask=common_attn_metadata.attn_mask,
query_lens=query_lens_prefill_sfa,
seq_lens=seq_lens_prefill_sfa,
context_lens=seq_lens[reqs_start:],
input_positions=prefill_input_positions,
block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
chunked_context=chunked_context_metadata,
sin=sin,
cos=cos,
)
decode_metadata = None
if num_decodes > 0:
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to(
torch.int32).npu()
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decodes].to(torch.int32).npu()
input_positions = input_positions[:num_decode_tokens]
block_table = block_table[:num_decodes, ...]
seq_lens_list = seq_lens.tolist()
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
decode_metadata = AscendSFADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin,
cos=cos)
        return self.metadata_cls(  # type: ignore
            num_input_tokens=common_attn_metadata.num_input_tokens,
            num_actual_tokens=num_actual_tokens,
            query_lens=query_lens.tolist(),
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            attn_mask=common_attn_metadata.attn_mask,
            attn_state=common_attn_metadata.attn_state,
            prefill=prefill_metadata,
            decode=decode_metadata,
            query_start_loc=query_start_loc,
            block_tables=block_table,
            seq_lens=seq_lens,
        )
        return self.metadata_cls(  # type: ignore
            has_prefill=has_prefill,
            num_input_tokens=common_attn_metadata.num_input_tokens,
            num_actual_tokens=num_actual_tokens,
            cum_query_lens=cum_query_lens,
            seq_lens=seq_lens,
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            attn_mask=common_attn_metadata.attn_mask,
            attn_state=common_attn_metadata.attn_state,
            block_tables=block_table,
            sin=sin,
            cos=cos)
class PrefillSFAPreprocessResult(NamedTuple):
q_nope: Optional[torch.Tensor] = None
q_pe: Optional[torch.Tensor] = None
k_nope: Optional[torch.Tensor] = None
k_pe: Optional[torch.Tensor] = None
topk_indices: Optional[torch.Tensor] = None
query_states: Optional[torch.Tensor] = None
key_states: Optional[torch.Tensor] = None
class DecodeSFAPreprocessResult(NamedTuple):
q_nope: Optional[torch.Tensor] = None
q_pe: Optional[torch.Tensor] = None
# nope_cache: Optional[torch.Tensor] = None
# rope_cache: Optional[torch.Tensor] = None
topk_indices: Optional[torch.Tensor] = None
query_states: Optional[torch.Tensor] = None
key_states: Optional[torch.Tensor] = None
bsz: Optional[int] = None
class AscendSFAImpl(MLAAttentionImpl):
@@ -493,28 +240,17 @@ class AscendSFAImpl(MLAAttentionImpl):
         ascend_config = get_ascend_config()
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
+        self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
-        vllm_config = get_current_vllm_config()
-        self.ring_mla_mask_size = 512
-        self.prefill_mask = None
+        assert self.indexer is not None, "Indexer is required for DSA."
 
         # indexer param
-        self.dim = self.indexer.dim
-        self.n_heads: int = self.indexer.n_heads  # 64
+        self.n_head: int = self.indexer.n_head  # 64
         self.head_dim: int = self.indexer.head_dim  # 128
-        self.index_topk: int = self.indexer.index_topk  # 2048
         self.wq_b = self.indexer.wq_b
         self.wk = self.indexer.wk
         self.weights_proj = self.indexer.weights_proj
         self.k_norm = self.indexer.k_norm
-        self.softmax_scale = self.indexer.softmax_scale
-
-        # Adapt torch air graph mode with spec decoding.
-        speculative_config = vllm_config.speculative_config
-        if speculative_config is not None:
-            self.spec_token_num = speculative_config.num_speculative_tokens
-            assert self.spec_token_num > 0
 
         self.cp_size = 1
@@ -541,6 +277,10 @@ class AscendSFAImpl(MLAAttentionImpl):
                 del eye
                 # standardize to (output, input)
                 return dequant_weights.T
+            # Weight will be reshaped next. To be on the safe side, the format
+            # of the weight should be reverted to FRACTAL_AND.
+            layer.weight.data = torch_npu.npu_format_cast(
+                layer.weight.data, ACL_FORMAT_FRACTAL_ND)
             return layer.weight
        # we currently do not have quantized bmm's which are needed for
@@ -561,232 +301,98 @@ class AscendSFAImpl(MLAAttentionImpl):
             self.qk_nope_head_dim + self.v_head_dim,
         )
-        self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split(
+        W_UK, W_UV = kv_b_proj_weight.split(
             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
 
         # Convert from (L, N, V) to (N, L, V)
-        self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous()
+        self.W_UV = W_UV.transpose(0, 1).contiguous()
         # Convert from (L, N, P) to (N, P, L)
-        self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous()
+        self.W_UK_T = W_UK.permute(1, 2, 0).contiguous()
+
+        # Function `get_and_maybe_dequant_weights` will cast the weights to
+        # FRACTAL_AND. So we need to cast to FRACTAL_NZ again.
+        if is_enable_nz():
+            self.kv_b_proj.weight.data = torch_npu.npu_format_cast(
+                self.kv_b_proj.weight.data, ACL_FORMAT_FRACTAL_NZ)
         # Waiting for BMM NZ support
         # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
         # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata, def _v_up_proj(self, x):
need_gather_q_kv): if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
# SFA Preprocess: x = x.view(-1, self.num_heads, self.kv_lora_rank)
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c x = torch_npu.npu_transpose_batchmatmul(x,
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split self.W_UV,
# 3. If need_gather_q_kv, perform all_gather. perm_x1=[1, 0, 2],
# 4. Preprocess decode tokens, write kv cache and get: perm_x2=[0, 1, 2],
# decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope perm_y=[1, 0, 2])
# 5. Preprocess prefill tokens, write kv cache and get: x = x.reshape(-1, self.num_heads * self.v_head_dim)
# prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value else:
has_decode = attn_metadata.num_decodes > 0 # Convert from (B, N, L) to (N, B, L)
has_prefill = attn_metadata.num_prefills > 0 x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# # Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return x
num_decode_tokens = attn_metadata.num_decode_tokens # Return `ql_nope`, `q_pe`
num_actual_tokens = attn_metadata.num_actual_tokens def _q_proj_and_k_up_proj(self, x):
if need_gather_q_kv: q_nope, q_pe = self.q_proj(x)[0]\
# q_c = get_tp_group().all_gather(q_c, 0) .view(-1, self.num_heads, self.qk_head_dim)\
# kv_no_split = get_tp_group().all_gather(kv_no_split, 0) .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
hidden_states = get_tp_group().all_gather(hidden_states, 0)
# hidden_states_decode = hidden_states[:num_decode_tokens]
# if self.q_a_proj is not None:
# npu_prefetch(self.q_a_proj.weight,
# hidden_states,
# enabled=self.enable_prefetch)
# ckq = self.q_a_proj(hidden_states) # q down
# q_c = self.q_a_layernorm(ckq) # q down layernorm
# else:
# q_c = hidden_states
# kv_no_split = self.kv_a_proj_with_mqa(hidden_states) # c_kv # Convert from (B, N, P) to (N, B, P)
# Process for shared_expert_dp q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
decode_preprocess_res = None def exec_kv(
prefill_preprocess_res = None self,
# Preprocess for decode tokens kv_no_split: torch.Tensor,
if has_decode: cos: torch.Tensor,
q_len = 1 sin: torch.Tensor,
hidden_states_decode = hidden_states[:num_decode_tokens] kv_cache: Tuple,
decode_qkv_lora = self.fused_qkv_a_proj(hidden_states_decode)[0] slots: torch.Tensor,
decode_q_c, decode_kv_no_split = decode_qkv_lora.split( ):
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], B = kv_no_split.shape[0]
dim=-1, N = self.num_kv_heads
) S = 1
decode_q_c = self.q_a_layernorm(decode_q_c) # q down layernorm # npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
decode_kv_no_split = decode_kv_no_split.contiguous() kv_no_split = kv_no_split.view(
B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
cache_mode = "PA_NZ" if self.enable_kv_nz else "PA"
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv_no_split,
self.kv_a_layernorm.weight,
cos,
sin,
slots.to(torch.int64),
kv_cache[1],
kv_cache[0],
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode=cache_mode,
)
return k_pe, k_nope
# decode_q_c = q_c[:num_decode_tokens] def rope_single(
decode_slot_mapping = attn_metadata.slot_mapping[: self,
num_decode_tokens] x: torch.Tensor,
# decode_kv_no_split = decode_kv_no_split[:num_decode_tokens] cos: torch.Tensor,
sin: torch.Tensor,
decode_q = self.q_b_proj(decode_q_c) ) -> torch.Tensor:
bsz, _ = decode_q.shape B, N, D = x.shape
decode_q = decode_q.view(bsz, self.num_heads, 1, self.qk_head_dim) S = 1
decode_q_nope, decode_q_pe = torch.split( x = x.view(B, N, S, D)
decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim], x = torch_npu.npu_interleave_rope(x, cos, sin)
dim=-1) return x.view(B, N, D)
decode_q_nope = decode_q_nope.view(
-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
decode_q_nope = (torch.matmul(decode_q_nope,
self.kv_b_proj_w_k).transpose(
1,
0).view(bsz, q_len,
self.num_heads,
self.kv_lora_rank))
# stream2 kv
key_cache = kv_cache[0]
value_cache = kv_cache[1]
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze(1)
decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
decode_kv_no_split,
self.kv_a_layernorm.weight,
cos,
sin,
decode_slot_mapping.to(torch.int64),
value_cache,
key_cache,
c_kv_scale=None,
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode='PA') # adapter NZ
# nz_block_size = 16
# KVCACHE_NZ_DIM = 16
# decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size)
# decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM)
decode_q_pe = torch_npu.npu_interleave_rope(decode_q_pe, cos,
sin) # BNSD
decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
self.kv_lora_rank)
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
topk_indices = self.indexer_select(hidden_states_decode,
decode_q_c,
attn_metadata=attn_metadata,
cos=cos,
sin=sin,
kv_cache=kv_cache)
query_states = (decode_q_nope, decode_q_pe)
key_states = (decode_k_nope, decode_k_rope)
decode_preprocess_res = DecodeSFAPreprocessResult(
q_nope=decode_q_nope,
q_pe=decode_q_pe,
# nope_cache = nope_cache,
# rope_cache = rope_cache,
topk_indices=topk_indices,
query_states=query_states,
key_states=key_states,
bsz=bsz,
)
# Preprocess for prefill tokens
if has_prefill:
bsz = 1
hidden_states_prefill = hidden_states[
num_decode_tokens:num_actual_tokens]
prefill_qkv_lora = self.fused_qkv_a_proj(hidden_states_prefill)[0]
prefill_q_c, prefill_kv_no_split = prefill_qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
prefill_q_c = self.q_a_layernorm(prefill_q_c) # q down layernorm
prefill_kv_no_split = prefill_kv_no_split.contiguous()
# prefill_q_c = q_c[
# num_decode_tokens:num_actual_tokens]
prefill_slot_mapping = attn_metadata.slot_mapping[
num_decode_tokens:num_actual_tokens]
# decode_kv_no_split = decode_kv_no_split[:num_decode_tokens]
prefill_slot_mapping = attn_metadata.slot_mapping[
num_decode_tokens:num_actual_tokens]
# prefill_kv_no_split = kv_no_split[
# num_decode_tokens:num_actual_tokens]
# prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens]
prefill_qr = prefill_q_c
prefill_q = self.q_b_proj(prefill_qr)
prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_nope, prefill_q_pe = torch.split(
prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
prefill_q_nope = prefill_q_nope.view(
-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
prefill_q_nope = (torch.matmul(prefill_q_nope,
self.kv_b_proj_w_k).transpose(
1,
0).view(-1, self.num_heads,
self.kv_lora_rank))
prefill_q_pe = prefill_q_pe.unsqueeze(2)
# stream2 kv
nope_cache = kv_cache[0]
rope_cache = kv_cache[1]
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
cos_q, sin_q = cos, sin
# cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
# sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
prefill_q_pe = torch_npu.npu_interleave_rope(
prefill_q_pe, cos_q, sin_q) # BNSD
prefill_q_pe = prefill_q_pe.squeeze(2) #BSH
# q[..., self.qk_nope_head_dim:] = prefill_q_pe # TODO:????
prefill_latent_cache = prefill_kv_no_split # (B,S,N,D)
prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
prefill_latent_cache.view(
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
self.kv_a_layernorm.weight,
cos.view(-1, 1, 1, self.qk_rope_head_dim),
sin.view(-1, 1, 1, self.qk_rope_head_dim),
prefill_slot_mapping.to(torch.int64),
rope_cache,
nope_cache,
k_rope_scale=None,
c_kv_scale=None,
k_rope_offset=None,
c_kv_offset=None,
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode="PA")
topk_indices = self.indexer_select(x=hidden_states_prefill,
qr=prefill_qr,
kv_cache=kv_cache,
cos=cos,
sin=sin,
attn_metadata=attn_metadata)
query_states = (prefill_q_nope, prefill_q_pe)
key_states = (prefill_k_nope, prefill_k_pe)
prefill_preprocess_res = PrefillSFAPreprocessResult(
q_nope=prefill_q_nope,
q_pe=prefill_q_pe,
topk_indices=topk_indices,
k_nope=prefill_k_nope,
k_pe=prefill_k_pe,
query_states=query_states,
key_states=key_states,
)
return decode_preprocess_res, prefill_preprocess_res
def forward( def forward(
self, self,
layer_name,
hidden_states: torch.Tensor, # query in unified attn hidden_states: torch.Tensor, # query in unified attn
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M, attn_metadata: M,
@@ -797,141 +403,86 @@ class AscendSFAImpl(MLAAttentionImpl):
if attn_metadata is None: if attn_metadata is None:
# Profiling run. # Profiling run.
return output.fill_(0) return output.fill_(0)
has_prefill = attn_metadata.has_prefill
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
assert attn_metadata.num_decodes is not None and \ hidden_states = hidden_states[:num_actual_tokens]
attn_metadata.num_prefills is not None and \
attn_metadata.num_decode_tokens is not None
num_decode_tokens = attn_metadata.num_decode_tokens
# Inputs and outputs may be padded for CUDA graphs # Inputs and outputs may be padded for CUDA graphs
output = output[:num_actual_tokens, ...] output_padded = output
o_proj_input_shape = (num_actual_tokens, output = output[:num_actual_tokens]
self.num_heads * self.v_head_dim) assert self.fused_qkv_a_proj is not None, "q lora is required for DSA."
o_proj_input = torch.empty(o_proj_input_shape, maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight,
dtype=hidden_states.dtype, dependency=hidden_states,
device=hidden_states.device) enabled=self.enable_prefetch)
qkv_lora = self.fused_qkv_a_proj(hidden_states)[0]
q_c, kv_no_split = qkv_lora.split(
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
dim=-1,
)
q_c = self.q_a_layernorm(q_c)
# SFA Preprocess # Process for Flash Comm V1
decode_preprocess_res, prefill_preprocess_res = self._sfa_preprocess( q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states, kv_cache, attn_metadata, need_gather_q_kv) q_c.contiguous(), need_gather_q_kv)
kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
kv_no_split.contiguous(), need_gather_q_kv)
if decode_preprocess_res is not None: if has_prefill:
# bsz, q_len, _, _ = query_states[0].shape wait_for_kv_layer_from_connector(layer_name)
decode_attn_output = self.apply_attention_fusion(
query_states=decode_preprocess_res.query_states,
key_states=decode_preprocess_res.key_states,
attn_metadata=attn_metadata,
topk_indices=decode_preprocess_res.topk_indices)
o_proj_input[:num_decode_tokens] = decode_attn_output
if prefill_preprocess_res is not None: slot_mapping = attn_metadata.slot_mapping[:num_actual_tokens]
prefill_attn_output = self.apply_attention_fusion( ql_nope, q_pe = \
query_states=prefill_preprocess_res.query_states, self._q_proj_and_k_up_proj(q_c)
key_states=prefill_preprocess_res.key_states, q_pe = self.rope_single(q_pe, attn_metadata.cos, attn_metadata.sin)
attn_metadata=attn_metadata, k_pe, k_nope = self.exec_kv(kv_no_split, attn_metadata.cos,
topk_indices=prefill_preprocess_res.topk_indices) attn_metadata.sin, kv_cache, slot_mapping)
o_proj_input[num_decode_tokens:] = prefill_attn_output
output[...] = self.mla_epilog(o_proj_input, absorb=True) topk_indices = self.indexer_select(x=hidden_states,
return output qr=q_c,
kv_cache=kv_cache,
def apply_attention_fusion(self, query_states, key_states, topk_indices, attn_metadata=attn_metadata,
attn_metadata: M): need_gather_q_kv=need_gather_q_kv)
# repeat k/v heads if n_kv_heads < n_heads attn_output = torch.ops.custom.npu_sparse_flash_attention(
q_nope, q_pe = query_states query=ql_nope,
k_nope, k_rope = key_states key=k_nope,
value=k_nope,
if attn_metadata.prefill is not None: sparse_indices=topk_indices,
scale_value=self.scale,
prefill_metadata = attn_metadata.prefill sparse_block_size=1,
block_table=attn_metadata.block_tables,
slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention( actual_seq_lengths_query=attn_metadata.cum_query_lens,
query=q_nope, actual_seq_lengths_kv=attn_metadata.seq_lens,
key=k_nope, query_rope=q_pe,
value=k_nope, key_rope=k_pe,
sparse_indices=topk_indices, layout_query="TND",
scale_value=self.scale, layout_kv="PA_BSND",
sparse_block_size=1, sparse_mode=3,
block_table=prefill_metadata.block_table, )
actual_seq_lengths_query=prefill_metadata.query_lens, attn_output = self._v_up_proj(attn_output)
actual_seq_lengths_kv=prefill_metadata.seq_lens, maybe_npu_prefetch(inputs=self.o_proj.weight,
query_rope=q_pe, dependency=attn_output,
key_rope=k_rope, max_size=MAX_O_PROJ_PREFETCH_SIZE,
layout_query="TND", enabled=self.enable_prefetch)
layout_kv="PA_BSND", output[...] = self.o_proj(attn_output)[0]
sparse_mode=3, return output_padded
)
elif attn_metadata.decode is not None:
decode_metadata = attn_metadata.decode
slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
query=q_nope,
key=k_nope,
value=k_nope,
sparse_indices=topk_indices,
scale_value=self.scale,
sparse_block_size=1,
block_table=attn_metadata.decode.block_table,
actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=decode_metadata.seq_lens,
query_rope=q_pe,
key_rope=k_rope,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
slc_fa_fusion = slc_fa_fusion.squeeze(1)
slc_fa_fusion = slc_fa_fusion.transpose(0, 1)
# input shape [N//attn_tp_size, T(bs*q_len), D]
# output shape [T(bs*q_len), N//attn_tp_size, D]
attn_output = torch.matmul(slc_fa_fusion,
self.kv_b_proj_w_v).transpose(1, 0).reshape(
-1, self.num_heads * self.v_head_dim)
# Note: Considering the fusion rules of TBMM, attn_output shape requires a 3-dim shape, and
# with appropriate tensor stride for the later 'view' operation if oproj_tp_size > 1.
# after reshape: [T(bs*q_len), 1, N//attn_tp_size*D]
# attn_output = attn_output.reshape(-1, self.num_heads * self.v_head_dim)
return attn_output
def mla_epilog(self,
attn_output: torch.Tensor = None,
absorb: bool = False):
# TODO: need to check
attn_output = self.o_proj(attn_output.reshape(attn_output.shape[0],
-1),
is_prefill=True,
is_force_scatter=False)
return attn_output
     def indexer_select(
         self,
         x: torch.Tensor,
         qr: torch.Tensor,
         kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
-        cos,
-        sin,
         attn_metadata: M,
+        need_gather_q_kv: bool = False,
     ):
-        if attn_metadata.prefill is not None:
-            actual_seq_lengths_query = attn_metadata.prefill.query_lens
-            actual_seq_lengths_key = attn_metadata.prefill.seq_lens
-            block_table = attn_metadata.prefill.block_table
-        elif attn_metadata.decode is not None:
-            actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
-            actual_seq_lengths_key = attn_metadata.decode.seq_lens
-            block_table = attn_metadata.decode.block_table
+        cos = attn_metadata.cos
+        sin = attn_metadata.sin
 
         cos_q, sin_q = cos, sin
         cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
         sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
         # q process in new stream
-        q = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
-        q = q.view(-1, self.n_heads, self.head_dim)  # [b,s,64,128]
+        q, _ = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
+        q = q.view(-1, self.n_head, self.head_dim)  # [b,s,64,128]
         q_pe, q_nope = torch.split(
             q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
             dim=-1)  # [b,s,64,64+64]
@@ -941,7 +492,9 @@ class AscendSFAImpl(MLAAttentionImpl):
         q_pe = q_pe.squeeze(2)
         q = torch.cat([q_pe, q_nope], dim=-1)  # [b*s,64,128]
-        k_proj = self.wk(x)  # [b,s,7168] @ [7168,128] = [b,s,128]
+        k_proj, _ = self.wk(x)  # [b,s,7168] @ [7168,128] = [b,s,128]
+        k_proj = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            k_proj, need_gather_q_kv)
         k = self.k_norm(k_proj).unsqueeze(1)
         k_pe, k_nope = torch.split(
             k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
@@ -960,14 +513,20 @@ class AscendSFAImpl(MLAAttentionImpl):
                 k.view(-1,
                        k.shape[-1]))  # b, s, n, d
 
-        weights = self.weights_proj(x)
+        weights, _ = self.weights_proj(x)
+        weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
+            weights, need_gather_q_kv)
+        block_table = attn_metadata.block_tables
+        seq_lens = attn_metadata.seq_lens
+        cum_query_lens = attn_metadata.cum_query_lens
         topk_indices = torch.ops.custom.npu_lightning_indexer(
             query=q,
             key=kv_cache[2],
             weights=weights,
-            actual_seq_lengths_query=actual_seq_lengths_query,
-            actual_seq_lengths_key=actual_seq_lengths_key,
+            actual_seq_lengths_query=cum_query_lens,
+            actual_seq_lengths_key=seq_lens,
             block_table=block_table,
             layout_query="TND",
             layout_key="PA_BSND",


@@ -29,10 +29,6 @@ def register_model():
"vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding" "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding"
) )
ModelRegistry.register_model(
"DeepseekV32ForCausalLM",
"vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM")
# There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
# to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM. # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
ModelRegistry.register_model( ModelRegistry.register_model(


@@ -1,658 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# # Adapted from
# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py
# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
# """Inference-only DeepseekV2/DeepseekV3 model."""
from typing import Any, Dict, Iterable, Optional, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (divide, get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
get_tp_group, split_tensor_along_last_dim,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED,
ColumnParallelLinear,
MergedColumnParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, maybe_remap_kv_scale_name)
from vllm.model_executor.models.deepseek_v2 import \
yarn_get_mscale # noqa: E501
from vllm.model_executor.models.deepseek_v2 import (
DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE,
get_spec_layer_idx_from_weight_name)
from vllm.model_executor.models.utils import (
PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import AscendSFAModules, Indexer
from vllm_ascend.ops.fused_moe.fused_moe import AscendFusedMoE
from vllm_ascend.ops.linear import AscendLinearBase
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
else:
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
@support_torch_compile
class AscendDeepseekV2Model(DeepseekV2Model, nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
# Rewrite this init func mainly for removing cuda-hard code
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
assert hasattr(config, "index_topk")
topk_tokens = config.index_topk
topk_indices_buffer = torch.empty(
vllm_config.scheduler_config.max_num_batched_tokens,
topk_tokens,
dtype=torch.int32,
device=current_platform.device_type)
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
def __init__(
self,
input_size: int,
output_size: int,
bias: bool = True,
input_is_parallel: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = True,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
*,
return_bias: bool = True,
disable_tp: bool = False,
):
# Divide the weight matrix along the first dimension.
self.tp_rank = (get_tensor_model_parallel_rank()
if not disable_tp else 0)
self.tp_size = (get_tensor_model_parallel_world_size()
if not disable_tp else 1)
self.input_size_per_partition = divide(input_size, self.tp_size)
self.output_size_per_partition = output_size
self.output_partition_sizes = [output_size]
AscendLinearBase.__init__(self,
input_size,
output_size,
skip_bias_add,
params_dtype,
quant_config,
prefix,
return_bias=return_bias,
disable_tp=disable_tp)
self.input_is_parallel = input_is_parallel
self.reduce_results = reduce_results
assert self.quant_method is not None
self.quant_method.create_weights(
layer=self,
input_size_per_partition=self.input_size_per_partition,
output_partition_sizes=self.output_partition_sizes,
input_size=self.input_size,
output_size=self.output_size,
params_dtype=self.params_dtype,
weight_loader=(
self.weight_loader_v2 if self.quant_method.__class__.__name__
in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader))
if not reduce_results and (bias and not skip_bias_add):
raise ValueError("When not reduce the results, adding bias to the "
"results can lead to incorrect results")
if bias:
self.bias = nn.Parameter(
torch.empty(self.output_size, dtype=params_dtype))
set_weight_attrs(self.bias, {
"output_dim": 0,
"weight_loader": self.weight_loader,
})
else:
self.register_parameter("bias", None)
self.update_param_tp_status()
def forward(
self,
input_,
is_prefill=True,
is_force_scatter=False
) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
if self.input_is_parallel:
input_parallel = input_
else:
tp_rank = get_tensor_model_parallel_rank()
splitted_input = split_tensor_along_last_dim(
input_, num_partitions=self.tp_size)
input_parallel = splitted_input[tp_rank].contiguous()
# Matrix multiply.
assert self.quant_method is not None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
output_parallel = self.quant_method.apply(self,
input_parallel,
bias=bias_)
if self.reduce_results and self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
if not self.return_bias:
return output
return output, output_bias
class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
self.tp_size = get_tensor_model_parallel_world_size()
assert num_heads % self.tp_size == 0
self.num_local_heads = num_heads // self.tp_size
self.layers = config.num_hidden_layers
self.first_k_dense_replace = config.first_k_dense_replace
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.q_lora_rank is not None:
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
return_bias=False,
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
return_bias=False,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
return_bias=False,
)
self.o_proj = CustomDeepseekV2RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
return_bias=False,
)
if self.q_lora_rank is not None:
self.fused_qkv_a_proj = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_qkv_a_proj",
disable_tp=True)
self.kv_a_proj_with_mqa = None
else:
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.dim: int = config.hidden_size # 7168
# TODO(zzzzwwjj): wait transformers add these params
self.n_heads: int = 64 # 64
self.head_dim: int = 128 # 128
self.index_topk: int = 2048 # 2048
self.indexer = Indexer(
config,
quant_config=quant_config,
dim=self.dim,
n_heads=self.n_heads,
head_dim=self.head_dim,
index_topk=self.index_topk,
prefix=f"{prefix}.indexer",
)
sfa_modules = AscendSFAModules(
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
fused_qkv_a_proj=self.fused_qkv_a_proj
if self.q_lora_rank is not None else None,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
rotary_emb=self.rotary_emb,
indexer=self.indexer,
is_sparse=hasattr(config, "index_topk"),
topk_indices_buffer=None)
if vllm_version_is("0.11.0"):
self.sfa_attn = MultiHeadLatentAttention(
hidden_size=self.hidden_size,
num_heads=self.num_local_heads,
scale=self.scaling,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
mla_modules=sfa_modules,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
)
else:
self.sfa_attn = MultiHeadLatentAttentionWrapper(
hidden_size=self.hidden_size,
num_heads=self.num_local_heads,
scale=self.scaling,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
mla_modules=sfa_modules,
cache_config=cache_config,
quant_config=quant_config,
prefix=prefix,
)
self.prefix = prefix
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata)
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
def __init__(self,
vllm_config: VllmConfig,
prefix: str,
topk_indices_buffer=None) -> None:
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.layer_idx = layer_idx
self.layers = config.num_hidden_layers
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
attn_cls = CustomDeepseekV2SFAAttention
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
qk_nope_head_dim=config.qk_nope_head_dim,
qk_rope_head_dim=config.qk_rope_head_dim,
v_head_dim=config.v_head_dim,
q_lora_rank=config.q_lora_rank
if hasattr(config, "q_lora_rank") else None,
kv_lora_rank=config.kv_lora_rank,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
)
if (config.n_routed_experts is not None
and layer_idx >= config.first_k_dense_replace
and layer_idx % config.moe_layer_freq == 0):
self.mlp = DeepseekV2MoE(
config=config,
parallel_config=parallel_config,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
if self.mlp.gate.e_score_correction_bias is not None:
self.mlp.gate.e_score_correction_bias.data = (
self.mlp.gate.e_score_correction_bias.data.to(
dtype=torch.get_default_dtype()))
else:
self.mlp = DeepseekV2MLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
hidden_act=config.hidden_act,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.routed_scaling_factor = config.routed_scaling_factor
self.first_k_dense_replace = config.first_k_dense_replace
self.tp_group = get_tp_group().device_group
class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
nn.Module.__init__(self)
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.quant_config = quant_config
# `packed_modules_mapping` needs to be modified before
# initializing DeepseekV2Model, as it is passed inplace to
# quantization config init and may be used to select the
# quant_method for relevant layers during initialization.
self.fuse_qkv_a_proj = hasattr(
config, "q_lora_rank") and config.q_lora_rank is not None
if self.fuse_qkv_a_proj:
self.packed_modules_mapping["fused_qkv_a_proj"] = [
"q_a_proj",
"kv_a_proj_with_mqa",
]
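            # The mapping then includes, e.g.:
            #   {"fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"]}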
self.model = AscendDeepseekV2Model(vllm_config=vllm_config,
prefix=maybe_prefix(
prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=maybe_prefix(
prefix, "lm_head"))
else:
self.lm_head = PPMissingLayer()
self.logits_processor = LogitsProcessor(config.vocab_size)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.expert_weights: list[Any] = []
# Set MoE hyperparameters
self.num_moe_layers = (config.num_hidden_layers -
config.first_k_dense_replace)
self.num_expert_groups = config.n_group
self.moe_layers: list[FusedMoE] = []
example_moe = None
for layer in self.model.layers:
if isinstance(layer, PPMissingLayer):
continue
assert isinstance(layer, DeepseekV2DecoderLayer)
if isinstance(layer.mlp, DeepseekV2MoE):
                # Pick the last layer since the first ones may be dense layers.
example_moe = layer.mlp
self.moe_layers.append(layer.mlp.experts)
if example_moe is None:
raise RuntimeError("No DeepseekV2MoE layer found in model.layers.")
self.num_logical_experts = example_moe.n_logical_experts
self.num_physical_experts = example_moe.n_physical_experts
self.num_local_physical_experts = example_moe.n_local_physical_experts
self.num_routed_experts = example_moe.n_routed_experts
self.num_shared_experts = example_moe.n_shared_experts
self.num_redundant_experts = example_moe.n_redundant_experts
# NOTE: This `load_weights` is mainly copied from
# https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
# to fix CI, and it is different from the implementation in main
# TODO: support eplb style load_weights
def load_weights(self, weights: Iterable[tuple[str,
torch.Tensor]]) -> set[str]:
""""""
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
("fused_qkv_a_proj", "q_a_proj", 0),
("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]
# Params for weights, fp8 weight scales, fp8 activation scales
# (param_name, weight_name, expert_id, shard_id)
expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.n_routed_experts)
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
if "module" in name:
continue
spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
if spec_layer is not None:
continue # skip spec decode layers for main model
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
continue
# We have mlp.experts[0].gate_proj in the checkpoint.
# Since we handle the experts below in expert_params_mapping,
# we need to skip here BEFORE we update the name, otherwise
# name will be updated to mlp.experts[0].gate_up_proj, which
# will then be updated below in expert_params_mapping
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
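                # e.g. a checkpoint name containing "mlp.experts.0.gate_proj"
                # must fall through to the expert mapping below instead.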
if (("mlp.experts." in name) and name not in params_dict):
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
return_success=False)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
name = maybe_remap_kv_scale_name(name, params_dict)
if name is None:
continue
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return loaded_params
class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM):
pass
DeepseekV2DecoderLayer.__init__ = CustomDeepseekV2DecoderLayer.__init__
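For reference, the stacked-parameter remapping performed in load_weights above can be illustrated with a minimal, standalone sketch (not part of the diff; the remap helper and the example checkpoint name are hypothetical, for demonstration only):
# Minimal sketch of the name rewrite load_weights applies for the fused
# q_a / kv_a projection; identifiers below are illustrative examples.
stacked_params_mapping = [
    ("fused_qkv_a_proj", "q_a_proj", 0),
    ("fused_qkv_a_proj", "kv_a_proj_with_mqa", 1),
]

def remap(name: str):
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name in name:
            # same rewrite as in load_weights: checkpoint name -> fused param name
            return name.replace(weight_name, param_name), shard_id
    return name, None

# e.g. a hypothetical checkpoint entry:
# remap("model.layers.0.self_attn.q_a_proj.weight")
#   -> ("model.layers.0.self_attn.fused_qkv_a_proj.weight", 0)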

View File

@@ -52,6 +52,35 @@ else:
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
class IndexerWrapper(nn.Module):
'''
A wrapper around vLLM's Indexer for DeepSeek V3.2.
It currently works around the hard-coded fp8 assumptions in vLLM's deepseek_v2.py:
it wraps the original Indexer and reuses its module weights
(wq_b, wk, weights_proj, k_norm),
while dropping the unused topk_indices_buffer and k_cache to save memory.
TODO: remove once the original Indexer supports other quantization methods.
'''
def __init__(self, vllm_indexer: nn.Module) -> None:
super().__init__()
self.n_head: int = vllm_indexer.n_head # 64
self.head_dim: int = vllm_indexer.head_dim # 128
self.topk_tokens: int = vllm_indexer.topk_tokens # 2048
self.q_lora_rank: int = vllm_indexer.q_lora_rank # 1536
self.wq_b = vllm_indexer.wq_b
self.wk = vllm_indexer.wk
self.weights_proj = vllm_indexer.weights_proj
self.k_norm = vllm_indexer.k_norm
self.softmax_scale = vllm_indexer.softmax_scale
vllm_indexer.topk_indices_buffer = None # delete topk_indices_buffer
vllm_indexer.k_cache = None # delete k_cache
def forward(self):
return
# TODO(whx): adapt v0.11.0 and DSA
class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
@@ -86,6 +115,10 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
self.first_k_dense_replace = hf_config.first_k_dense_replace
self.tp_size = get_tensor_model_parallel_world_size()
self.layers = hf_config.num_hidden_layers
if mla_modules.indexer is not None:
ascend_indexer = IndexerWrapper(mla_modules.indexer)
else:
ascend_indexer = None
if vllm_version_is("0.11.0"):
self.mla_attn = Attention(
@@ -97,6 +130,8 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
indexer=ascend_indexer,
use_sparse=mla_modules.is_sparse,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
@@ -128,7 +163,7 @@ class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper):
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_sparse=mla_modules.is_sparse,
indexer=ascend_indexer,
# extra args
rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,

View File

@@ -1,275 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.mla import MLAModules
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
from vllm.attention import Attention
from vllm.model_executor.layers.mla import \
MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper
from vllm.utils import direct_register_custom_op
else:
from vllm.attention.layer import MLAAttention
from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper
from vllm.utils.torch_utils import direct_register_custom_op
@dataclass
class AscendSFAModules:
q_a_layernorm: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module]
kv_a_proj_with_mqa: torch.nn.Module
kv_a_layernorm: torch.nn.Module
kv_b_proj: torch.nn.Module
o_proj: torch.nn.Module
rotary_emb: torch.nn.Module
indexer: torch.nn.Module
is_sparse: bool
fused_qkv_a_proj: Optional[torch.nn.Module]
q_b_proj: Optional[torch.nn.Module]
topk_indices_buffer: Optional[torch.Tensor]
class AscendSparseFlashAttention(MultiHeadLatentAttentionWrapper):
def __init__(
self,
hidden_size: int,
num_heads: int,
scale: float,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
mla_modules: MLAModules,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.q_lora_rank = q_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_rope_head_dim + qk_nope_head_dim
self.v_head_dim = v_head_dim
self.prefix = prefix
self.scaling = scale
self.indexer = mla_modules.indexer
self.is_sparse = mla_modules.is_sparse
hf_config = get_current_vllm_config().model_config.hf_config
self.enable_shared_expert_dp = get_ascend_config(
).enable_shared_expert_dp
self.debug_layer_idx = int(self.prefix.split(".")[-2])
self.first_k_dense_replace = hf_config.first_k_dense_replace
self.tp_size = get_tensor_model_parallel_world_size()
self.layers = hf_config.num_hidden_layers
if vllm_version_is("0.11.0"):
self.sfa_attn = Attention(
num_heads=num_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=scale,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sparse=True,
indexer=self.indexer,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
qk_head_dim=self.qk_head_dim,
rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
q_b_proj=mla_modules.q_b_proj,
q_a_layernorm=mla_modules.q_a_layernorm,
q_proj=mla_modules.q_proj,
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
kv_a_layernorm=mla_modules.kv_a_layernorm,
kv_b_proj=mla_modules.kv_b_proj,
o_proj=mla_modules.o_proj,
)
else:
self.sfa_attn = MLAAttention(
num_heads=num_heads,
scale=scale,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
v_head_dim=self.v_head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
kv_b_proj=mla_modules.kv_b_proj,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_sparse=mla_modules.is_sparse,
indexer=mla_modules.indexer,
# extra args
rotary_emb=mla_modules.rotary_emb,
fused_qkv_a_proj=mla_modules.fused_qkv_a_proj,
q_b_proj=mla_modules.q_b_proj,
q_a_layernorm=mla_modules.q_a_layernorm,
q_proj=mla_modules.q_proj,
kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa,
kv_a_layernorm=mla_modules.kv_a_layernorm,
o_proj=mla_modules.o_proj,
)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# Simulate all gather to calculate output shape
num_tokens = num_tokens * self.tp_size
need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
else:
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
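# i.e. rows = ceil(num_tokens / tp_size)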
output_shape = (rows, hidden_states.shape[1])
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output,
self.prefix)
output = output.view(-1, output_shape[-1])
return output
def sfa_forward(
hidden_states: torch.Tensor,
need_gather_q_kv: bool,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
if forward_context.attn_metadata:
attn_metadata = forward_context.attn_metadata[self.sfa_attn.layer_name]
else:
attn_metadata = forward_context.attn_metadata
kv_cache = self.sfa_attn.kv_cache[forward_context.virtual_engine]
self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
need_gather_q_kv, output)
return
class Indexer(nn.Module):
def __init__(self,
config,
dim: int = 7168,
n_heads: int = 64,
head_dim: int = 128,
index_topk: int = 2048,
q_lora_rank: int = 1536,
rope_head_dim: int = 64,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = ""):
super().__init__()
self.dim: int = dim # 7168
self.n_heads: int = n_heads # 64
self.head_dim: int = head_dim # 128
self.rope_head_dim: int = rope_head_dim # 64
self.index_topk: int = index_topk # 2048
self.q_lora_rank: int = q_lora_rank # 1536
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
return_bias=False,
)
self.wk = ReplicatedLinear(
self.dim,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
return_bias=False,
)
self.weights_proj = ReplicatedLinear(
self.dim,
self.n_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.weights_proj",
return_bias=False,
)
self.k_norm = nn.LayerNorm(self.head_dim)
self.softmax_scale = self.head_dim**-0.5
def forward(self):
return
def sfa_forward_fake(
hidden_states: torch.Tensor,
need_gather_q_kv: bool,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="sfa_forward",
op_func=sfa_forward,
mutates_args=["output"],
fake_impl=sfa_forward_fake,
dispatch_key="PrivateUse1",
)

View File

@@ -33,3 +33,4 @@ from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.11.0"):
import vllm_ascend.patch.worker.patch_deepseek_mtp # noqa
import vllm_ascend.patch.worker.patch_deepseek_v3_2 # noqa

View File

@@ -0,0 +1,108 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from itertools import islice
from typing import Optional, Union
import torch
import vllm.model_executor.models.deepseek_v2
from torch import nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.vocab_parallel_embedding import \
VocabParallelEmbedding
from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
from vllm.model_executor.models.utils import (
PPMissingLayer, make_empty_intermediate_tensors_factory, make_layers)
from vllm.sequence import IntermediateTensors
@support_torch_compile
class DeepseekV2Model(nn.Module):
fall_back_to_pt_during_load = False
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
self.config = config
self.vocab_size = config.vocab_size
self.is_v32 = hasattr(config, "index_topk")
topk_indices_buffer = None
if get_pp_group().is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=f"{prefix}.embed_tokens")
else:
self.embed_tokens = PPMissingLayer()
self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers,
lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
topk_indices_buffer),
prefix=f"{prefix}.layers")
if get_pp_group().is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer()
self.make_empty_intermediate_tensors = (
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
if get_pp_group().is_first_rank:
if inputs_embeds is not None:
hidden_states = inputs_embeds
else:
hidden_states = self.get_input_embeddings(input_ids)
residual = None
else:
assert intermediate_tensors is not None
hidden_states = intermediate_tensors["hidden_states"]
residual = intermediate_tensors["residual"]
for layer in islice(self.layers, self.start_layer, self.end_layer):
hidden_states, residual = layer(positions, hidden_states, residual)
if not get_pp_group().is_last_rank:
return IntermediateTensors({
"hidden_states": hidden_states,
"residual": residual
})
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
vllm.model_executor.models.deepseek_v2.DeepseekV2Model = DeepseekV2Model
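# Note: this assignment monkey-patches vLLM's DeepseekV2Model in place; any model
# constructed after this module is imported picks up the patched class.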

View File

@@ -69,7 +69,6 @@ from vllm.sequence import IntermediateTensors
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import Indexer
from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
@@ -83,6 +82,57 @@ else:
from vllm.attention.layer import MLAAttention
class Indexer(nn.Module):
def __init__(self,
config,
dim: int = 7168,
n_heads: int = 64,
head_dim: int = 128,
index_topk: int = 2048,
q_lora_rank: int = 1536,
rope_head_dim: int = 64,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = ""):
super().__init__()
self.dim: int = dim # 7168
self.n_heads: int = n_heads # 64
self.head_dim: int = head_dim # 128
self.rope_head_dim: int = rope_head_dim # 64
self.index_topk: int = index_topk # 2048
self.q_lora_rank: int = q_lora_rank # 1536
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
return_bias=False,
)
self.wk = ReplicatedLinear(
self.dim,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
return_bias=False,
)
self.weights_proj = ReplicatedLinear(
self.dim,
self.n_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.weights_proj",
return_bias=False,
)
self.k_norm = nn.LayerNorm(self.head_dim)
self.softmax_scale = self.head_dim**-0.5
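# standard 1/sqrt(head_dim) softmax scaling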
def forward(self):
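# no-op placeholder: this module only holds the indexer weights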
return
class TorchairDeepseekV2SiluAndMul(SiluAndMul):
def __init__(self,

View File

@@ -577,7 +577,6 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
from vllm.model_executor.custom_op import CustomOp
from vllm_ascend.models.layers.mla import AscendMultiHeadLatentAttention
from vllm_ascend.models.layers.sfa import AscendSparseFlashAttention
from vllm_ascend.ops.activation import AscendQuickGELU, AscendSiluAndMul
from vllm_ascend.ops.fused_moe.fused_moe import (AscendFusedMoE,
AscendSharedFusedMoE)
@@ -625,10 +624,7 @@ def register_ascend_customop(vllm_config: Optional[VllmConfig] = None):
mla_to_register = "MultiHeadLatentAttention" if vllm_version_is(
"0.11.0") else "MultiHeadLatentAttentionWrapper"
if vllm_config and vllm_config.model_config and vllm_config.model_config.use_mla:
REGISTERED_ASCEND_OPS[mla_to_register] = AscendMultiHeadLatentAttention
for name, op_cls in REGISTERED_ASCEND_OPS.items():
CustomOp.register_oot(_decorated_op_cls=op_cls, name=name)