Add DeepSeek V3.2 support (#3270)

### What this PR does / why we need it?

This PR added the initial DeepSeek V3.2 support with [vLLM
v0.11.0](https://github.com/vllm-project/vllm/tree/releases/v0.11.0)
(not released yet). We will complete vLLM adaptation as soon as
possible. This feature will be ready in recent 1-2 days.

Related doc: https://github.com/vllm-project/vllm-ascend/pull/3223 .

### Does this PR introduce _any_ user-facing change?
Yes!

### How was this patch tested?
CI passed and Run deepseek doc soon.


- vLLM version: v0.11.0rc3
- vLLM main:
https://github.com/vllm-project/vllm/commit/releases/v0.11.0

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: zzzzwwjj <1183291235@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: zzzzwwjj <1183291235@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: wxsIcey <1790571317@qq.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
wangxiyuan
2025-09-30 03:25:58 +08:00
committed by GitHub
parent 5503a3142f
commit 81bd6e4c99
27 changed files with 4354 additions and 70 deletions

View File

@@ -121,7 +121,14 @@ jobs:
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib
pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \
--ignore=tests/ut/test_platform.py \
--ignore=tests/ut/patch/worker/patch_common/test_patch_minicpm.py
--ignore=tests/ut/patch/worker/patch_common/test_patch_minicpm.py \
--ignore=tests/ut/core/test_scheduler.py \
--ignore=tests/ut/kv_connector/test_llmdatadist_connector.py \
--ignore=tests/ut/kv_connector/test_mooncake_connector.py \
--ignore=tests/ut/kv_connector/test_remote_decode_lifecycle.py \
--ignore=tests/ut/kv_connector/test_remote_prefill_lifecycle.py \
--ignore=tests/ut/torchair/models/test_torchair_deepseek_v2.py \
--ignore=tests/ut/torchair/test_utils.py
- name: Upload coverage to Codecov
# only upload coverage when commits merged

View File

@@ -23,5 +23,7 @@ def register():
def register_model():
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
from .models import register_model
register_model()

View File

@@ -34,6 +34,8 @@ class AscendConfig:
def __init__(self, vllm_config):
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
self.is_deepseek_sfa = vllm_config.model_config is not None and vllm_config.model_config.is_deepseek_mla and vllm_config.model_config.hf_text_config.model_type == "deepseek_v32"
self.use_sfa = self.is_deepseek_sfa
torchair_graph_config = additional_config.get("torchair_graph_config",
{})

View File

@@ -73,7 +73,7 @@ class AttentionMaskBuilder:
device: torch.device):
self._update_attn_cache(max_seq_len, dtype)
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
).to(device)
).to(device, non_blocking=True)
def get_splitfuse_attn_mask(
self,

View File

@@ -0,0 +1,986 @@
from dataclasses import dataclass
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
TypeVar)
import torch
import torch_npu
from torch import nn
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata,
MLAAttentionImpl)
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.linear import (LinearBase,
UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
split_decodes_and_prefills)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.worker.npu_input_batch import InputBatch
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
class AscendSFABackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "ASCEND_SFA"
@staticmethod
def get_metadata_cls() -> type["AttentionMetadata"]:
return AscendSFAMetadata
@staticmethod
def get_builder_cls():
return AscendSFAMetadataBuilder
@staticmethod
def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
head_size: int) -> tuple[int, ...]:
return (num_blocks, block_size, num_kv_heads, head_size)
@staticmethod
def get_impl_cls() -> Type["AscendSFAImpl"]:
return AscendSFAImpl
@dataclass
class AscendSFAPrefillMetadata:
""" Prefill Specific Metadata for Ascend"""
@dataclass
class ChunkedContextMetadata:
# New for MLA (compared to FlashAttention)
# For handling chunked prefill
cu_seq_lens: torch.Tensor
starts: torch.Tensor
seq_tot: list[int]
max_seq_lens: list[int]
workspace: torch.Tensor
chunk_seq_lens: torch.Tensor
attn_mask: torch.Tensor
query_lens: list[int]
seq_lens: list[int]
context_lens: torch.Tensor
input_positions: torch.Tensor
query_start_loc: torch.Tensor
block_table: torch.Tensor
max_query_len: int
max_seq_lens: int
sin: torch.Tensor
cos: torch.Tensor
chunked_context: Optional[ChunkedContextMetadata] = None
@dataclass
class AscendSFADecodeMetadata:
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
block_table: torch.Tensor
seq_lens: torch.Tensor
max_seq_lens: int
seq_lens_list: list[int]
actual_seq_lengths_q: torch.Tensor
sin: torch.Tensor
cos: torch.Tensor
attn_mask: Optional[torch.Tensor] = None
@dataclass
class AscendSFAMetadata:
"""Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
slot_mapping: torch.Tensor
query_start_loc: torch.Tensor
seq_lens: torch.Tensor
block_tables: torch.Tensor
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
num_decodes: int
num_decode_tokens: int
num_prefills: int
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
query_lens: Optional[list[int]] = None
# The dimension of the attention heads
head_dim: Optional[int] = None
attn_mask: torch.Tensor = None
# chunked prefill by default if no attn_states passed
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
decode: Optional[AscendSFADecodeMetadata] = None
prefill: Optional[AscendSFAPrefillMetadata] = None
enable_dbo_across_dp: bool = False
def __post_init__(self):
pass
# supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
# if self.head_dim is not None and self.head_dim \
# not in supported_head_sizes:
# raise ValueError(
# f"Only {supported_head_sizes} are supported for head_dim,",
# f"received {self.head_dim}.")
def split_metadata_for_multistream(
self,
ms_split_config: MSAttentionMetadataSplitConfig,
) -> list["AscendSFAMetadata"]:
"""Split metadata for multi-stream with AscendSFAMetadata"""
return model_input_split_v1_mla_attn(
ms_split_config=ms_split_config,
attn_metadata=self,
_metadata_cls=AscendMLAMetadata,
)
M = TypeVar("M", bound=AscendSFAMetadata)
class AscendSFAMetadataBuilder:
# Does this backend/builder support ACL Graphs for attention (default: no).
aclgraph_support: ClassVar[AttentionCGSupport] = \
AttentionCGSupport.NEVER
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# _attn_mask_builder = None
def __init__(self,
kv_cache_spec,
layer_names,
vllm_config: VllmConfig,
device: torch.device,
metadata_cls: Optional[AscendSFAMetadata] = None):
self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \
if metadata_cls is not None else AscendSFAMetadata # type: ignore
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.device = device
scheduler_config = vllm_config.scheduler_config
self.block_size = vllm_config.cache_config.block_size
self.max_blocks = (vllm_config.model_config.max_model_len +
self.block_size - 1) // self.block_size
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
self.speculative_config = vllm_config.speculative_config
self.decode_threshold = 1
if self.speculative_config:
spec_token_num = self.speculative_config.num_speculative_tokens
self.decode_threshold += spec_token_num
assert self.decode_threshold <= 16, f"decode_threshold exceeded \
npu_fused_infer_attention_score TND layout's limit of 16, \
got {self.decode_threshold}"
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(8 * self.model_config.max_model_len,
4 * scheduler_config.max_num_seqs * self.block_size),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * self.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
self.model_config.get_head_size()),
dtype=self.model_config.dtype,
device=device,
)
self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
self.cos_cache = None
self.sin_cache = None
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput") -> bool:
# We now want to reorder the batch so that the "decode" requests are at
# the front and the "prefill" requests are at the using the least amount
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
# where attention is likely memory-bound and "prefill" to mean requests
# where attention is likely compute-bound, TODO(lucas): figure out a
# better naming here)
decodes = []
prefills = []
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
if num_tokens <= self.decode_threshold:
decodes.append(i)
else:
prefills.append(i)
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new request are generally appended to the
# persistent batch so already should be at the back)
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back"
# i.e. past where the last decode should be in the reodorered with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
first_prefill = 0
modified_batch = False
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
if decodes[num_decodes - i] >= num_decodes:
input_batch.swap_states(prefills[first_prefill],
decodes[num_decodes - i])
first_prefill += 1
modified_batch = True
else:
break
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
return modified_batch
def build(
self,
common_prefix_len: int,
common_attn_metadata: AscendCommonAttentionMetadata,
model: nn.Module,
) -> AscendSFAMetadata:
num_reqs = common_attn_metadata.num_reqs
num_actual_tokens = common_attn_metadata.num_actual_tokens
query_start_loc = common_attn_metadata.query_start_loc
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(common_attn_metadata, decode_threshold=self.decode_threshold)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_actual_tokens
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks on all previous kernels.
device = self.device
block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
slot_mapping = common_attn_metadata.slot_mapping[:
num_actual_tokens].to(
device,
non_blocking=True)
input_positions = common_attn_metadata.positions[:
num_actual_tokens].long(
)
if self.cos_cache is None:
self.cos_cache = model.model.layers[
0].self_attn.rotary_emb.cos_cached
self.sin_cache = model.model.layers[
0].self_attn.rotary_emb.sin_cached
if self.cos_cache.dtype != self.model_config.dtype: # type: ignore
self.cos_cache = self.cos_cache.to( # type: ignore
self.model_config.dtype) # type: ignore
self.sin_cache = self.sin_cache.to( # type: ignore
self.model_config.dtype) # type: ignore
query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
query_lens = query_seq_lens_cpu[:num_reqs]
seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
num_computed_tokens_cpu = (seq_lens - query_lens)
prefill_metadata = None
chunked_context_metadata = None
if num_prefills > 0:
reqs_start = num_decodes # prefill_start
tokens_start = num_decode_tokens
max_query_len = query_lens[reqs_start:].max().item()
max_seq_lens = seq_lens[reqs_start:].max().item()
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
if self.chunked_prefill_enabled and max_context_len_cpu > 0:
max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)
max_context_chunk = round_down(max_context_chunk,
self.block_size)
assert max_context_chunk > 0
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
cu_seq_lens_cpu = torch.zeros(num_chunks,
num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(chunk_seq_lens,
dim=1,
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata = \
AscendSFAPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
chunk_seq_lens=chunk_seq_lens,
workspace=self.chunked_prefill_workspace,
)
prefill_input_positions = input_positions[tokens_start:]
cos = self.cos_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[
prefill_input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
actual_query_lens = torch.tensor(query_lens[reqs_start:],
dtype=torch.int32).npu()
query_lens_prefill_sfa = torch.cumsum(actual_query_lens,
dim=0).to(torch.int32)
seq_lens_prefill_sfa = seq_lens[reqs_start:].to(torch.int32).npu()
prefill_metadata = AscendSFAPrefillMetadata(
attn_mask=common_attn_metadata.attn_mask,
query_lens=query_lens_prefill_sfa,
seq_lens=seq_lens_prefill_sfa,
context_lens=seq_lens[reqs_start:],
input_positions=prefill_input_positions,
block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
chunked_context=chunked_context_metadata,
sin=sin,
cos=cos,
)
decode_metadata = None
if num_decodes > 0:
# Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to(
torch.int32).npu()
max_seq_lens = seq_lens[:num_decodes].max().item()
seq_lens = seq_lens[:num_decodes].to(torch.int32).npu()
input_positions = input_positions[:num_decode_tokens]
block_table = block_table[:num_decodes, ...]
seq_lens_list = seq_lens.tolist()
cos = self.cos_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
sin = self.sin_cache[input_positions].unsqueeze( # type: ignore
1).unsqueeze(2)
decode_metadata = AscendSFADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
seq_lens=seq_lens,
seq_lens_list=seq_lens_list,
max_seq_lens=max_seq_lens,
attn_mask=common_attn_metadata.spec_attn_mask,
actual_seq_lengths_q=actual_seq_lengths_q,
sin=sin,
cos=cos)
return self.metadata_cls( # type: ignore
num_actual_tokens=num_actual_tokens,
query_lens=query_lens.tolist(),
slot_mapping=slot_mapping,
head_dim=self.model_config.get_head_size(),
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
attn_mask=common_attn_metadata.attn_mask,
attn_state=common_attn_metadata.attn_state,
prefill=prefill_metadata,
decode=decode_metadata,
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
)
class PrefillSFAPreprocessResult(NamedTuple):
q_nope: Optional[torch.Tensor] = None
q_pe: Optional[torch.Tensor] = None
k_nope: Optional[torch.Tensor] = None
k_pe: Optional[torch.Tensor] = None
topk_indices: Optional[torch.Tensor] = None
query_states: Optional[torch.Tensor] = None
key_states: Optional[torch.Tensor] = None
class DecodeSFAPreprocessResult(NamedTuple):
q_nope: Optional[torch.Tensor] = None
q_pe: Optional[torch.Tensor] = None
# nope_cache: Optional[torch.Tensor] = None
# rope_cache: Optional[torch.Tensor] = None
topk_indices: Optional[torch.Tensor] = None
query_states: Optional[torch.Tensor] = None
key_states: Optional[torch.Tensor] = None
bsz: Optional[int] = None
class AscendSFAImpl(MLAAttentionImpl):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
logits_soft_cap: Optional[float],
attn_type: str,
kv_sharing_target_layer_name: Optional[str],
**kwargs,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
# MLA Args
self.q_lora_rank = kwargs['q_lora_rank']
self.kv_lora_rank = kwargs['kv_lora_rank']
self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
self.qk_head_dim = kwargs['qk_head_dim']
self.v_head_dim = kwargs['v_head_dim']
self.rotary_emb = kwargs['rotary_emb']
self.q_proj = kwargs['q_proj']
self.kv_b_proj = kwargs['kv_b_proj']
self.o_proj = kwargs['o_proj']
self.indexer = kwargs['indexer']
self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
self.q_a_proj = kwargs.get('q_a_proj', None)
self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.tp_size = get_tensor_model_parallel_world_size()
self.num_heads_per_rank = self.num_heads // self.tp_size
if self.q_a_proj is not None:
self.q_b_proj = self.q_proj
else:
self.q_b_proj = None
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.enable_prefetch = ascend_config.enable_prefetch
self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
vllm_config = get_current_vllm_config()
self.ring_mla_mask_size = 512
self.prefill_mask = None
# indexer param
self.dim = self.indexer.dim
self.n_heads: int = self.indexer.n_heads # 64
self.head_dim: int = self.indexer.head_dim # 128
self.index_topk: int = self.indexer.index_topk # 2048
self.wq_b = self.indexer.wq_b
self.wk = self.indexer.wk
self.weights_proj = self.indexer.weights_proj
self.k_norm = self.indexer.k_norm
self.softmax_scale = self.indexer.softmax_scale
# Adapt torch air graph mode with spec decoding.
speculative_config = vllm_config.speculative_config
if speculative_config is not None:
self.spec_token_num = speculative_config.num_speculative_tokens
assert self.spec_token_num > 0
self.cp_size = 1
def process_weights_after_loading(self, act_dtype: torch.dtype):
def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
for attr in WEIGHT_NAMES:
if hasattr(layer, attr):
return getattr(layer, attr)
raise AttributeError(
f"Layer '{layer}' has no recognized weight attribute:"
f" {WEIGHT_NAMES}.")
def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device)
dequant_weights = layer.quant_method.apply(layer,
eye,
bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
# we currently do not have quantized bmm's which are needed for
# `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
# the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# Convert from (L, N, V) to (N, L, V)
self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous()
# Convert from (L, N, P) to (N, P, L)
self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous()
# Waiting for BMM NZ support
# self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
# self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata,
need_gather_q_kv):
# SFA Preprocess:
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
# 3. If need_gather_q_kv, perform all_gather.
# 4. Preprocess decode tokens, write kv cache and get:
# decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
# 5. Preprocess prefill tokens, write kv cache and get:
# prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
num_actual_tokens = attn_metadata.num_actual_tokens
if need_gather_q_kv:
# q_c = get_tp_group().all_gather(q_c, 0)
# kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
hidden_states = get_tp_group().all_gather(hidden_states, 0)
# hidden_states_decode = hidden_states[:num_decode_tokens]
# if self.q_a_proj is not None:
# npu_prefetch(self.q_a_proj.weight,
# hidden_states,
# enabled=self.enable_prefetch)
# ckq = self.q_a_proj(hidden_states) # q down
# q_c = self.q_a_layernorm(ckq) # q down layernorm
# else:
# q_c = hidden_states
# kv_no_split = self.kv_a_proj_with_mqa(hidden_states) # c_kv
# Process for shared_expert_dp
decode_preprocess_res = None
prefill_preprocess_res = None
# Preprocess for decode tokens
if has_decode:
q_len = 1
hidden_states_decode = hidden_states[:num_decode_tokens]
decode_kq = self.q_a_proj(hidden_states_decode) # q down
decode_q_c = self.q_a_layernorm(decode_kq) # q down layernorm
decode_kv_no_split = self.kv_a_proj_with_mqa(
hidden_states_decode) # c_kv
# decode_q_c = q_c[:num_decode_tokens]
decode_slot_mapping = attn_metadata.slot_mapping[:
num_decode_tokens]
# decode_kv_no_split = decode_kv_no_split[:num_decode_tokens]
decode_q = self.q_b_proj(decode_q_c)
bsz, _ = decode_q.shape
decode_q = decode_q.view(bsz, self.num_heads, 1, self.qk_head_dim)
decode_q_nope, decode_q_pe = torch.split(
decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
decode_q_nope = decode_q_nope.view(
-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
decode_q_nope = (torch.matmul(decode_q_nope,
self.kv_b_proj_w_k).transpose(
1,
0).view(bsz, q_len,
self.num_heads,
self.kv_lora_rank))
# stream2 kv
key_cache = kv_cache[0]
value_cache = kv_cache[1]
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze(1)
decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
decode_kv_no_split,
self.kv_a_layernorm.weight,
cos,
sin,
decode_slot_mapping.to(torch.int64),
value_cache,
key_cache,
c_kv_scale=None,
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode='PA') # adapter NZ
# nz_block_size = 16
# KVCACHE_NZ_DIM = 16
# decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size)
# decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM)
decode_q_pe = torch_npu.npu_interleave_rope(decode_q_pe, cos,
sin) # BNSD
decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
self.kv_lora_rank)
decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)
topk_indices = self.indexer_select(hidden_states_decode,
decode_q_c,
attn_metadata=attn_metadata,
kv_cache=kv_cache)
query_states = (decode_q_nope, decode_q_pe)
key_states = (decode_k_nope, decode_k_rope)
decode_preprocess_res = DecodeSFAPreprocessResult(
q_nope=decode_q_nope,
q_pe=decode_q_pe,
# nope_cache = nope_cache,
# rope_cache = rope_cache,
topk_indices=topk_indices,
query_states=query_states,
key_states=key_states,
bsz=bsz,
)
# Preprocess for prefill tokens
if has_prefill:
bsz = 1
hidden_states_prefill = hidden_states[
num_decode_tokens:num_actual_tokens]
prefill_kq = self.q_a_proj(hidden_states_prefill) # q down
prefill_q_c = self.q_a_layernorm(prefill_kq) # q down layernorm
prefill_kv_no_split = self.kv_a_proj_with_mqa(
hidden_states_prefill) # c_kv
# prefill_q_c = q_c[
# num_decode_tokens:num_actual_tokens]
prefill_slot_mapping = attn_metadata.slot_mapping[
num_decode_tokens:num_actual_tokens]
# decode_kv_no_split = decode_kv_no_split[:num_decode_tokens]
prefill_slot_mapping = attn_metadata.slot_mapping[
num_decode_tokens:num_actual_tokens]
# prefill_kv_no_split = kv_no_split[
# num_decode_tokens:num_actual_tokens]
# prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens]
prefill_qr = prefill_q_c
prefill_q = self.q_b_proj(prefill_qr)
prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_nope, prefill_q_pe = torch.split(
prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
dim=-1)
prefill_q_nope = prefill_q_nope.view(
-1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
prefill_q_nope = (torch.matmul(prefill_q_nope,
self.kv_b_proj_w_k).transpose(
1,
0).view(-1, self.num_heads,
self.kv_lora_rank))
prefill_q_pe = prefill_q_pe.unsqueeze(2)
# stream2 kv
nope_cache = kv_cache[0]
rope_cache = kv_cache[1]
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
cos_q, sin_q = cos, sin
# cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
# sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
prefill_q_pe = torch_npu.npu_interleave_rope(
prefill_q_pe, cos_q, sin_q) # BNSD
prefill_q_pe = prefill_q_pe.squeeze(2) #BSH
# q[..., self.qk_nope_head_dim:] = prefill_q_pe # TODO:????
prefill_latent_cache = prefill_kv_no_split # (B,S,N,D)
prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
prefill_latent_cache.view(
-1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
self.kv_a_layernorm.weight,
cos.view(-1, 1, 1, self.qk_rope_head_dim),
sin.view(-1, 1, 1, self.qk_rope_head_dim),
prefill_slot_mapping.to(torch.int64),
rope_cache,
nope_cache,
k_rope_scale=None,
c_kv_scale=None,
k_rope_offset=None,
c_kv_offset=None,
epsilon=self.kv_a_layernorm.variance_epsilon,
cache_mode="PA")
topk_indices = self.indexer_select(x=hidden_states_prefill,
qr=prefill_qr,
kv_cache=kv_cache,
attn_metadata=attn_metadata)
query_states = (prefill_q_nope, prefill_q_pe)
key_states = (prefill_k_nope, prefill_k_pe)
prefill_preprocess_res = PrefillSFAPreprocessResult(
q_nope=prefill_q_nope,
q_pe=prefill_q_pe,
topk_indices=topk_indices,
k_nope=prefill_k_nope,
k_pe=prefill_k_pe,
query_states=query_states,
key_states=key_states,
)
return decode_preprocess_res, prefill_preprocess_res
def forward(
self,
hidden_states: torch.Tensor, # query in unified attn
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
need_gather_q_kv: bool = False,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
return output
num_actual_tokens = attn_metadata.num_actual_tokens
assert attn_metadata.num_decodes is not None and \
attn_metadata.num_prefills is not None and \
attn_metadata.num_decode_tokens is not None
num_decode_tokens = attn_metadata.num_decode_tokens
# Inputs and outputs may be padded for CUDA graphs
output = output[:num_actual_tokens, ...]
o_proj_input_shape = (num_actual_tokens,
self.num_heads * self.v_head_dim)
o_proj_input = torch.empty(o_proj_input_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
# SFA Preprocess
decode_preprocess_res, prefill_preprocess_res = self._sfa_preprocess(
hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
if decode_preprocess_res is not None:
# bsz, q_len, _, _ = query_states[0].shape
decode_attn_output = self.apply_attention_fusion(
query_states=decode_preprocess_res.query_states,
key_states=decode_preprocess_res.key_states,
attn_metadata=attn_metadata,
topk_indices=decode_preprocess_res.topk_indices)
o_proj_input[:num_decode_tokens] = decode_attn_output
if prefill_preprocess_res is not None:
prefill_attn_output = self.apply_attention_fusion(
query_states=prefill_preprocess_res.query_states,
key_states=prefill_preprocess_res.key_states,
attn_metadata=attn_metadata,
topk_indices=prefill_preprocess_res.topk_indices)
o_proj_input[num_decode_tokens:] = prefill_attn_output
output[...] = self.mla_epilog(o_proj_input, absorb=True)
return output
def apply_attention_fusion(self, query_states, key_states, topk_indices,
attn_metadata: M):
# repeat k/v heads if n_kv_heads < n_heads
q_nope, q_pe = query_states
k_nope, k_rope = key_states
if attn_metadata.prefill is not None:
prefill_metadata = attn_metadata.prefill
slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
query=q_nope,
key=k_nope,
value=k_nope,
sparse_indices=topk_indices,
scale_value=self.scale,
sparse_block_size=1,
block_table=prefill_metadata.block_table,
actual_seq_lengths_query=prefill_metadata.query_lens,
actual_seq_lengths_kv=prefill_metadata.seq_lens,
query_rope=q_pe,
key_rope=k_rope,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
elif attn_metadata.decode is not None:
decode_metadata = attn_metadata.decode
slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
query=q_nope,
key=k_nope,
value=k_nope,
sparse_indices=topk_indices,
scale_value=self.scale,
sparse_block_size=1,
block_table=attn_metadata.decode.block_table,
actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q,
actual_seq_lengths_kv=decode_metadata.seq_lens,
query_rope=q_pe,
key_rope=k_rope,
layout_query="TND",
layout_kv="PA_BSND",
sparse_mode=3,
)
slc_fa_fusion = slc_fa_fusion.squeeze(1)
slc_fa_fusion = slc_fa_fusion.transpose(0, 1)
# input shape [N//attn_tp_size, T(bs*q_len), D]
# output shape [T(bs*q_len), N//attn_tp_size, D]
attn_output = torch.matmul(slc_fa_fusion,
self.kv_b_proj_w_v).transpose(1, 0).reshape(
-1, self.num_heads * self.v_head_dim)
# Note: Considering the fusion rules of TBMM, attn_output shape requires a 3-dim shape, and
# with appropriate tensor stride for the later 'view' operation if oproj_tp_size > 1.
# after reshape: [T(bs*q_len), 1, N//attn_tp_size*D]
# attn_output = attn_output.reshape(-1, self.num_heads * self.v_head_dim)
return attn_output
def mla_epilog(self,
attn_output: torch.Tensor = None,
absorb: bool = False):
# TODO: need to check
attn_output = self.o_proj(attn_output.reshape(attn_output.shape[0],
-1),
is_prefill=True,
is_force_scatter=False)
return attn_output
def indexer_select(
self,
x: torch.Tensor,
qr: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
attn_metadata: M,
):
if attn_metadata.prefill is not None:
cos = attn_metadata.prefill.cos
sin = attn_metadata.prefill.sin
actual_seq_lengths_query = attn_metadata.prefill.query_lens
actual_seq_lengths_key = attn_metadata.prefill.seq_lens
block_table = attn_metadata.prefill.block_table
elif attn_metadata.decode is not None:
cos = attn_metadata.decode.cos
sin = attn_metadata.decode.sin
actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
actual_seq_lengths_key = attn_metadata.decode.seq_lens
block_table = attn_metadata.decode.block_table
cos_q, sin_q = cos, sin
cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)
# q process in new stream
q = self.wq_b(qr) # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
q = q.view(-1, self.n_heads, self.head_dim) # [b,s,64,128]
q_pe, q_nope = torch.split(
q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64,64+64]
q_pe = q_pe.unsqueeze(2)
q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
q_pe = q_pe.squeeze(2)
q = torch.cat([q_pe, q_nope], dim=-1) # [b*s,64,128]
k_proj = self.wk(x) # [b,s,7168] @ [7168,128] = [b,s,128]
k = self.k_norm(k_proj).unsqueeze(1)
k_pe, k_nope = torch.split(
k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
dim=-1) # [b,s,64+64]
k_pe = k_pe.unsqueeze(2)
k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
k_pe = k_pe.squeeze(2)
k = torch.cat([k_pe, k_nope], dim=-1) # [b*s,128]
if kv_cache is not None:
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
attn_metadata.slot_mapping.view(
-1, 1),
k.view(-1,
k.shape[-1])) # b, s, n, d
weights = self.weights_proj(x)
topk_indices = torch.ops.custom.npu_lightning_indexer(
query=q,
key=kv_cache[2],
weights=weights,
actual_seq_lengths_query=actual_seq_lengths_query,
actual_seq_lengths_key=actual_seq_lengths_key,
block_table=block_table,
layout_query="TND",
layout_key="PA_BSND",
sparse_count=2048,
sparse_mode=3)
return topk_indices

View File

@@ -493,8 +493,11 @@ class LLMDataDistCMgrConnectorWorker():
assert self.local_agent_metadata is not None
kv_cache_dtype = first_kv_cache.dtype
self.use_mla: bool = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1)
-1) != first_kv_cache_tuple[1].size(-1) and len(
first_kv_cache_tuple) == 2
self.use_sfa: bool = len(first_kv_cache_tuple) == 3
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
# SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
# MHA case. [2 (k and v), num_blocks, ...]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
@@ -540,6 +543,58 @@ class LLMDataDistCMgrConnectorWorker():
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
elif self.use_sfa:
cache_k_normed_addr_list = []
cache_k_pe_addr_list = []
cache_k_idx_addr_list = []
k_normed = None
k_pe = None
k_idx = None
for cache_or_caches in kv_caches.values():
assert len(cache_or_caches) > 1
k_normed, k_pe, k_idx = cache_or_caches[0], cache_or_caches[
1], cache_or_caches[2]
cache_k_normed_addr_list.append(k_normed.data_ptr())
cache_k_pe_addr_list.append(k_pe.data_ptr())
cache_k_idx_addr_list.append(k_idx.data_ptr())
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list,
cache_k_idx_addr_list)
cache_desc_k_normed = CacheDesc(
len(self.cache_addr[0]), [*k_normed.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_desc_k_pe = CacheDesc(
len(self.cache_addr[1]), [*k_pe.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_desc_k_idx = CacheDesc(
len(self.cache_addr[2]), [*k_idx.shape],
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=0)
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=1)
cache_key_k_idx = BlocksCacheKey(cluster_id=int(
self.local_agent_metadata.cluster_id),
model_id=2)
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe,
cache_desc_k_idx)
self.cache_key = (cache_key_k_normed, cache_key_k_pe,
cache_key_k_idx)
try:
cache_k_normed = self.cache_manager.register_blocks_cache(
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
cache_k_pe = self.cache_manager.register_blocks_cache(
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
cache_k_idx = self.cache_manager.register_blocks_cache(
self.cache_desc[2], self.cache_addr[2], self.cache_key[2])
self.cache = (cache_k_normed, cache_k_pe, cache_k_idx)
logger.info("LLMDataDistWorker: End of register Paged Cache.")
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
)
else:
for cache_or_caches in kv_caches.values():
for cache in cache_or_caches:
@@ -826,6 +881,38 @@ class LLMDataDistCMgrConnectorWorker():
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
elif self.use_sfa:
remote_cache_key_k_normed = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=0)
remote_cache_key_k_pe = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=1)
remote_cache_key_k_idx = BlocksCacheKey(
cluster_id=remote_cluster_id, model_id=2)
logger.info("Try pull blocks from remote server")
try:
self.cache_manager.pull_blocks(
remote_cache_key_k_normed,
self.cache[0], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_pe,
self.cache[1], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
self.cache_manager.pull_blocks(
remote_cache_key_k_idx,
self.cache[2], # type: ignore[has-type]
remote_block_ids,
local_block_ids)
except (TypeError, ValueError):
raise RuntimeError(
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
)
except LLMException:
raise RuntimeError(
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
)
else:
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
logger.info("Try pull blocks from remote server")

View File

@@ -30,6 +30,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
@@ -238,6 +239,7 @@ class KVCacheRecvingThread(threading.Thread):
self.block_len = block_len
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2
self.use_sfa = len(block_len) == 3
self.request_queue: queue.Queue[Any] = queue.Queue()
# TODO(jianzs): make this configurable
@@ -349,8 +351,12 @@ class KVCacheRecvingThread(threading.Thread):
src_list, dst_list, length_list = [], [], []
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
block_len = (self.block_len[k % 2]
if self.use_mla else self.block_len[0])
if self.use_mla:
block_len = (self.block_len[k % 2])
elif self.use_sfa:
block_len = (self.block_len[k % 3])
else:
block_len = (self.block_len[0])
for i, remote_block_id in enumerate(grouped_remote_block_ids):
local_block_ids = grouped_local_block_ids[i]
src = src_layer_base_addr + local_block_ids[0] * block_len
@@ -567,6 +573,7 @@ class MooncakeConnectorScheduler:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self.ascend_config = get_ascend_config()
self.block_size = vllm_config.cache_config.block_size
self.engine_id = engine_id
logger.info("Initializing Mooncake Scheduler %s", engine_id)
@@ -726,7 +733,7 @@ class MooncakeConnectorScheduler:
assert "tp_size" in decode_parallel_config.keys()
self._decode_tp_size = decode_parallel_config["tp_size"]
if self.vllm_config.model_config.use_mla:
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
return self._decode_tp_size
else:
# TODO support mha and gqa
@@ -847,7 +854,9 @@ class MooncakeConnectorWorker:
# TODO(tms): Find a more robust way to detect and handle MLA
self.use_mla = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1)
-1) != first_kv_cache_tuple[1].size(-1) and len(
first_kv_cache_tuple) == 2
self.use_sfa = len(first_kv_cache_tuple) == 3
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
@@ -861,6 +870,21 @@ class MooncakeConnectorWorker:
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
elif self.use_sfa:
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
first_kv_cache[2].element_size() * math.prod(block_shape_k)
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
self.num_blocks, block_shape_norm, block_shape_pe,
block_shape_k)
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
@@ -871,8 +895,9 @@ class MooncakeConnectorWorker:
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
logger.info(
"Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
self.use_mla, self.use_sfa, first_kv_cache.shape)
self.kv_caches = kv_caches
kv_caches_base_addr = []
@@ -884,9 +909,16 @@ class MooncakeConnectorWorker:
region_len = self.num_blocks * self.block_len[i % 2]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
elif self.use_sfa:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 3]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
else:
cache_list = [cache_or_caches
] if self.use_mla else cache_or_caches
cache_list = [
cache_or_caches
] if self.use_mla or self.use_sfa else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0]

View File

@@ -162,6 +162,13 @@ env_variables: Dict[str, Callable[[], Any]] = {
# Whether to enable msMonitor tool to monitor the performance of vllm-ascend.
"MSMONITOR_USE_DAEMON":
lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))),
# Timeout (in seconds) for delayed KVCache block release. In the prefill
# node, if a request is marked for delayed KV block release and the blocks
# are not freed within this timeout, they will be forcibly released.
"VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT":
lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
"VLLM_ASCEND_ENABLE_MLAPO":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
}
# end-env-vars-definition

View File

@@ -37,6 +37,10 @@ def register_model():
"DeepseekV3ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
ModelRegistry.register_model(
"DeepseekV32ForCausalLM",
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
ModelRegistry.register_model(
"DeepSeekMTPModel",
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")

View File

@@ -60,6 +60,8 @@ from vllm.model_executor.models.utils import (PPMissingLayer,
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.mla import AscendMLAModules
from vllm_ascend.models.layers.sfa import (AscendSFAModules,
AscendSparseFlashAttention, Indexer)
from vllm_ascend.ops.fused_moe import AscendFusedMoE
@@ -244,6 +246,180 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata)
class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
self.tp_size = get_tensor_model_parallel_world_size()
assert num_heads % self.tp_size == 0
self.num_local_heads = num_heads // self.tp_size
self.layers = config.num_hidden_layers
self.first_k_dense_replace = config.first_k_dense_replace
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj",
return_bias=False,
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
return_bias=False,
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
return_bias=False,
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
return_bias=False,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
return_bias=False,
)
self.o_proj = CustomDeepseekV2RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
return_bias=False,
)
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.dim: int = config.hidden_size # 7168
# TODO(zzzzwwjj): wait transformers add these params
self.n_heads: int = 64 # 64
self.head_dim: int = 128 # 128
self.index_topk: int = 2048 # 2048
self.indexer = Indexer(
config,
quant_config=quant_config,
dim=self.dim,
n_heads=self.n_heads,
head_dim=self.head_dim,
index_topk=self.index_topk,
prefix=f"{prefix}.indexer",
)
sfa_modules = AscendSFAModules(
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
rotary_emb=self.rotary_emb,
indexer=self.indexer)
self.sfa_attn = AscendSparseFlashAttention(
self.hidden_size,
self.enable_shared_expert_dp,
self.debug_layer_idx,
self.first_k_dense_replace,
self.tp_size,
sfa_modules,
self.num_local_heads,
self.scaling,
self.layers,
self.kv_lora_rank,
self.qk_rope_head_dim,
self.q_lora_rank,
self.qk_nope_head_dim,
self.qk_head_dim,
self.v_head_dim,
cache_config,
quant_config,
prefix,
)
self.prefix = prefix
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata)
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
@@ -253,6 +429,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
parallel_config = vllm_config.parallel_config
ascend_config = get_ascend_config()
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
@@ -268,7 +445,10 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
self.tp_rank = get_tp_group().rank_in_group
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
attn_cls = CustomDeepseekV2MLAAttention
if ascend_config.use_sfa:
attn_cls = CustomDeepseekV2SFAAttention
else:
attn_cls = CustomDeepseekV2MLAAttention
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(

View File

@@ -0,0 +1,233 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional
import torch
from torch import nn
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.utils import direct_register_custom_op
@dataclass
class AscendSFAModules:
q_a_proj: Optional[torch.nn.Module]
q_a_layernorm: Optional[torch.nn.Module]
q_proj: Optional[torch.nn.Module]
kv_a_proj_with_mqa: torch.nn.Module
kv_a_layernorm: torch.nn.Module
kv_b_proj: torch.nn.Module
o_proj: torch.nn.Module
rotary_emb: torch.nn.Module
indexer: torch.nn.Module
class AscendSparseFlashAttention(MultiHeadLatentAttention):
def __init__(
self,
hidden_size: int,
enable_shared_expert_dp: bool,
debug_layer_idx: int,
first_k_dense_replace: int,
tp_size: int,
sfa_modules: AscendSFAModules,
num_local_heads: int,
scaling: float,
layers: int,
kv_lora_rank: int,
qk_rope_head_dim: int,
q_lora_rank: Optional[int],
qk_nope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.enable_shared_expert_dp = enable_shared_expert_dp
self.debug_layer_idx = debug_layer_idx
self.first_k_dense_replace = first_k_dense_replace
self.tp_size = tp_size
self.num_local_heads = num_local_heads
self.layers = layers
self.kv_lora_rank = kv_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.q_lora_rank = q_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.prefix = prefix
self.sfa_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sfa=True,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=sfa_modules.rotary_emb,
q_a_proj=sfa_modules.q_a_proj,
q_a_layernorm=sfa_modules.q_a_layernorm,
q_proj=sfa_modules.q_proj,
kv_a_proj_with_mqa=sfa_modules.kv_a_proj_with_mqa,
kv_a_layernorm=sfa_modules.kv_a_layernorm,
kv_b_proj=sfa_modules.kv_b_proj,
o_proj=sfa_modules.o_proj,
indexer=sfa_modules.indexer)
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# Simulate all gather to calculate output shape
num_tokens = num_tokens * self.tp_size
need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
output_shape = hidden_states.shape
else:
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
# FIXME: This does not seem right, should make sure the buffer is fixed
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output,
self.prefix)
output = output.view(-1, output_shape[-1])
return output
def sfa_forward(
hidden_states: torch.Tensor,
need_gather_q_kv: bool,
output: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
if forward_context.attn_metadata:
attn_metadata = forward_context.attn_metadata[self.sfa_attn.layer_name]
else:
attn_metadata = forward_context.attn_metadata
kv_cache = self.sfa_attn.kv_cache[forward_context.virtual_engine]
self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
need_gather_q_kv, output)
return
class Indexer(nn.Module):
def __init__(self,
config,
dim: int = 7168,
n_heads: int = 64,
head_dim: int = 128,
index_topk: int = 2048,
q_lora_rank: int = 1536,
rope_head_dim: int = 64,
quant_config: Optional[QuantizationConfig] = None,
prefix: Optional[str] = ""):
super().__init__()
self.dim: int = dim # 7168
self.n_heads: int = n_heads # 64
self.head_dim: int = head_dim # 128
self.rope_head_dim: int = rope_head_dim # 64
self.index_topk: int = index_topk # 2048
self.q_lora_rank: int = q_lora_rank # 1536
self.wq_b = ReplicatedLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wq_b",
return_bias=False,
)
self.wk = ReplicatedLinear(
self.dim,
self.head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.wk",
return_bias=False,
)
self.weights_proj = ReplicatedLinear(
self.dim,
self.n_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.weights_proj",
return_bias=False,
)
self.k_norm = nn.LayerNorm(self.head_dim)
self.softmax_scale = self.head_dim**-0.5
def forward(self):
return
def sfa_forward_fake(
hidden_states: torch.Tensor,
need_gather_q_kv: bool,
output: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="sfa_forward",
op_func=sfa_forward,
mutates_args=["output"],
fake_impl=sfa_forward_fake,
dispatch_key="PrivateUse1",
)

View File

@@ -15,6 +15,10 @@
# limitations under the License.
#
import vllm_ascend.patch.platform.patch_common.patch_config # noqa
import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa
import vllm_ascend.patch.platform.patch_common.patch_multimodal_merge # noqa
import vllm_ascend.patch.platform.patch_common.patch_transformers_utils # noqa
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa

View File

@@ -0,0 +1,313 @@
import ast
import vllm.envs as envs
from transformers import PretrainedConfig
from vllm.config import ModelConfig
from vllm.config.speculative import SpeculativeConfig
from vllm.logger import logger
# mypy: ignore-errors
@property
def is_deepseek_mla(self: ModelConfig):
if not hasattr(self.hf_text_config, "model_type"):
return False
elif self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
'kimi_k2', 'longcat_flash', 'deepseek_v32'):
return self.hf_text_config.kv_lora_rank is not None
elif self.hf_text_config.model_type == 'eagle':
# if the model is an EAGLE module, check for the
# underlying architecture
return self.hf_text_config.model.model_type in \
('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \
and self.hf_text_config.kv_lora_rank is not None
return False
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
hf_config.model_type = "deepseek_mtp"
if hf_config.model_type == "deepseek_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["DeepSeekMTPModel"]
})
if hf_config.architectures[0] == "MiMoForCausalLM":
hf_config.model_type = "mimo_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["MiMoMTPModel"]
})
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
hf_config.model_type = "glm4_moe_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"num_hidden_layers": 0,
"n_predict": n_predict,
"architectures": ["Glm4MoeMTPModel"]
})
if hf_config.model_type == "ernie4_5_moe":
hf_config.model_type = "ernie_mtp"
if hf_config.model_type == "ernie_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["ErnieMTPModel"]
})
if hf_config.model_type == "qwen3_next":
hf_config.model_type = "qwen3_next_mtp"
if hf_config.model_type == "qwen3_next_mtp":
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
hf_config.update({
"n_predict": n_predict,
"architectures": ["Qwen3NextMTP"]
})
if hf_config.model_type == "longcat_flash":
hf_config.model_type = "longcat_flash_mtp"
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
hf_config.update({
"n_predict": n_predict,
"architectures": ["LongCatFlashMTPModel"]
})
return hf_config
def __post_init__(self):
# Note: "method" is a new parameter that helps to extend the
# configuration of non-model-based proposers, and the "model" parameter
# will be used to set the draft model, eagle head, or additional weight
# when needed. If users do not specify "method", the speculative method
# will be detected automatically if possible. If the speculative method
# can not be detected, it will be considered as the "draft_model" by
# default.
if self.model is None and self.num_speculative_tokens is not None:
# TODO(Shangming): Refactor mtp configuration logic when supporting
if (self.target_model_config
and self.target_model_config.hf_text_config.model_type
in ("deepseek_v3", "deepseek_v32", "mimo", "ernie4_5_moe",
"qwen3_next")):
# use the draft model from the same model:
self.model = self.target_model_config.model
# Align the quantization of draft model for cases such as
# --quantization fp8 with a bf16 checkpoint.
if not self.quantization:
self.quantization = self.target_model_config.quantization
elif self.method in ("ngram", "[ngram]"):
self.model = "ngram"
else:
raise ValueError("num_speculative_tokens was provided but without "
"speculative model.")
# Automatically configure the method for ngram when "model" is used
# instead of "method"
if self.method is None and (self.model is not None
and self.model in ("ngram", "[ngram]")):
self.method = "ngram"
if self.method in ("ngram", "[ngram]"):
# Unified to "ngram" internally
self.method = "ngram"
# Set default values if not provided
if (self.prompt_lookup_min is None and self.prompt_lookup_max is None):
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
self.prompt_lookup_min = 5
self.prompt_lookup_max = 5
elif self.prompt_lookup_min is None:
assert self.prompt_lookup_max is not None
self.prompt_lookup_min = self.prompt_lookup_max
elif self.prompt_lookup_max is None:
assert self.prompt_lookup_min is not None
self.prompt_lookup_max = self.prompt_lookup_min
# Validate values
if self.prompt_lookup_min < 1:
raise ValueError(
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
if self.prompt_lookup_max < 1:
raise ValueError(
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
if self.prompt_lookup_min > self.prompt_lookup_max:
raise ValueError(
f"prompt_lookup_min={self.prompt_lookup_min} must "
f"be <= prompt_lookup_max={self.prompt_lookup_max}")
# TODO: current we still need extract vocab_size from target model
# config, in future, we may try refactor it out, and set
# draft related config as None here.
self.draft_model_config = self.target_model_config
self.draft_parallel_config = self.target_parallel_config
else:
self.prompt_lookup_max = 0
self.prompt_lookup_min = 0
if self.model is not None:
# TODO: Move this import to the top once `ModelConfig`
# lives in `vllm.config.model`.
from vllm.config import ModelConfig
self.draft_model_config = ModelConfig(
model=self.model,
runner="draft",
tokenizer=self.target_model_config.tokenizer,
tokenizer_mode=self.target_model_config.tokenizer_mode,
trust_remote_code=self.target_model_config.trust_remote_code,
allowed_local_media_path=self.target_model_config.
allowed_local_media_path,
allowed_media_domains=self.target_model_config.
allowed_media_domains,
dtype=self.target_model_config.dtype,
seed=self.target_model_config.seed,
revision=self.revision,
code_revision=self.code_revision,
tokenizer_revision=self.target_model_config.tokenizer_revision,
spec_target_max_model_len=self.target_model_config.
max_model_len,
quantization=self.quantization,
enforce_eager=self.target_model_config.enforce_eager,
max_logprobs=self.target_model_config.max_logprobs,
hf_overrides=SpeculativeConfig.hf_config_override,
)
# Automatically detect the method
if self.method in ('eagle', 'eagle3'):
pass
# examples:
# yuhuili/EAGLE-LLaMA3-Instruct-8B
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
# AngelSlim/Qwen3-8B_eagle3
elif "eagle-" in self.draft_model_config.model.lower():
self.method = "eagle"
elif "eagle3" in self.draft_model_config.model.lower():
self.method = "eagle3"
elif self.draft_model_config.hf_config.model_type == "medusa":
self.method = "medusa"
elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
self.method = "deepseek_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Deepseek MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type == "ernie_mtp"):
self.method = "ernie_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Ernie MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type ==
"qwen3_next_mtp"):
self.method = "qwen3_next_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Qwen3Next MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
elif (self.draft_model_config.hf_config.model_type
in ("longcat_flash_mtp")):
self.method = "longcat_flash_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"LongCat MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"
raise NotImplementedError(
"Speculative decoding with draft model is not "
"supported yet. Please consider using other "
"speculative decoding methods such as ngram, medusa, "
"eagle, or deepseek_mtp.")
# Replace hf_config for EAGLE draft_model
if self.method in ("eagle", "eagle3"):
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
raise ValueError(
"Chunked prefill and EAGLE are not compatible "
"when using V0.")
from vllm.transformers_utils.configs import SpeculatorsConfig
from vllm.transformers_utils.configs.eagle import EAGLEConfig
if isinstance(self.draft_model_config.hf_config,
(EAGLEConfig, SpeculatorsConfig)):
pass
else:
eagle_config = EAGLEConfig(
self.draft_model_config.hf_config,
method=self.method,
model_type="eagle")
self.draft_model_config.hf_config = eagle_config
if (self.num_speculative_tokens is not None
and hasattr(self.draft_model_config.hf_config,
"num_lookahead_tokens")):
self.draft_model_config.hf_config.num_lookahead_tokens = \
self.num_speculative_tokens
n_predict = getattr(self.draft_model_config.hf_config, "n_predict",
None)
if n_predict is not None:
if self.num_speculative_tokens is None:
# Default to max value defined in draft model config.
self.num_speculative_tokens = n_predict
elif self.num_speculative_tokens > n_predict and \
self.num_speculative_tokens % n_predict != 0:
# Ensure divisibility for MTP module reuse.
raise ValueError(
f"num_speculative_tokens:{self.num_speculative_tokens}"
f" must be divisible by {n_predict=}")
if self.speculative_token_tree is None:
# Generate chain of tokens.
self.speculative_token_tree = str([
(i + 1) * (0, ) for i in range(self.num_speculative_tokens)
])
else:
# Sort the token tree breadth-first.
tree_choices = ast.literal_eval(self.speculative_token_tree)
self.speculative_token_tree = str(
sorted(tree_choices, key=lambda t: (len(t), t)))
self.draft_tensor_parallel_size = \
SpeculativeConfig._verify_and_get_draft_tp(
self.target_parallel_config,
self.draft_tensor_parallel_size,
self.draft_model_config.hf_config
)
self.draft_model_config.max_model_len = (
SpeculativeConfig._maybe_override_draft_max_model_len(
self.max_model_len,
self.draft_model_config.max_model_len,
self.target_model_config.max_model_len,
))
self.draft_parallel_config = (
SpeculativeConfig.create_draft_parallel_config(
self.target_parallel_config,
self.draft_tensor_parallel_size))
ModelConfig.is_deepseek_mla = is_deepseek_mla
SpeculativeConfig.__post_init__ = __post_init__
SpeculativeConfig.hf_config_override = hf_config_override

View File

@@ -6,6 +6,8 @@ from vllm.model_executor.models.config import MambaModelConfig
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
from vllm_ascend.ascend_config import get_ascend_config
@classmethod
def verify_and_update_config(cls, vllm_config) -> None:
@@ -22,6 +24,7 @@ def verify_and_update_config(cls, vllm_config) -> None:
logger = init_logger(__name__)
# Enable FULL_AND_PIECEWISE by default
MambaModelConfig.verify_and_update_config(vllm_config)
ascend_config = get_ascend_config()
cache_config = vllm_config.cache_config
model_config = vllm_config.model_config
@@ -38,7 +41,7 @@ def verify_and_update_config(cls, vllm_config) -> None:
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
head_size=model_config.get_head_size(),
dtype=kv_cache_dtype,
use_mla=model_config.use_mla).page_size_bytes
use_mla=model_config.use_mla or ascend_config.use_sfa).page_size_bytes
model_cls, _ = ModelRegistry.resolve_model_cls(
model_config.architecture,

View File

@@ -0,0 +1,200 @@
import vllm.transformers_utils.configs
from transformers.configuration_utils import PretrainedConfig
from transformers.utils import logging
from vllm.transformers_utils import config
logger = logging.get_logger(__name__)
class DeepseekV3Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
defaults will yield a similar configuration to that of the DeepSeek-V3.
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.
Args:
vocab_size (`int`, *optional*, defaults to 129280):
Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the
`inputs_ids` passed when calling [`DeepseekV3Model`]
hidden_size (`int`, *optional*, defaults to 4096):
Dimension of the hidden representations.
intermediate_size (`int`, *optional*, defaults to 11008):
Dimension of the MLP representations.
moe_intermediate_size (`int`, *optional*, defaults to 1407):
Dimension of the MoE representations.
num_hidden_layers (`int`, *optional*, defaults to 32):
Number of hidden layers in the Transformer decoder.
num_nextn_predict_layers (`int`, *optional*, defaults to 1):
Number of nextn predict layers in the DeepSeekV3 Model.
num_attention_heads (`int`, *optional*, defaults to 32):
Number of attention heads for each attention layer in the Transformer decoder.
n_shared_experts (`int`, *optional*, defaults to None):
Number of shared experts, None means dense model.
n_routed_experts (`int`, *optional*, defaults to None):
Number of routed experts, None means dense model.
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
Scaling factor or routed experts.
topk_method (`str`, *optional*, defaults to `gready`):
Topk method used in routed gate.
n_group (`int`, *optional*, defaults to None):
Number of groups for routed experts.
topk_group (`int`, *optional*, defaults to None):
Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups).
num_experts_per_tok (`int`, *optional*, defaults to None):
Number of selected experts, None means dense model.
moe_layer_freq (`int`, *optional*, defaults to 1):
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
first_k_dense_replace (`int`, *optional*, defaults to 0):
Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head).
\--k dense layers--/
norm_topk_prob (`bool`, *optional*, defaults to False):
Whether to normalize the weights of the routed experts.
scoring_func (`str`, *optional*, defaults to 'softmax'):
Method of computing expert weights.
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
Auxiliary loss weight coefficient.
seq_aux = (`bool`, *optional*, defaults to True):
Whether to compute the auxiliary loss for each individual sample.
num_key_value_heads (`int`, *optional*):
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
by meanpooling all the original heads within that group. For more details checkout [this
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
`num_attention_heads`.
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
The non-linear activation function (function or string) in the decoder.
max_position_embeddings (`int`, *optional*, defaults to 2048):
The maximum sequence length that this model might ever be used with.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
The epsilon used by the rms normalization layers.
use_cache (`bool`, *optional*, defaults to `True`):
Whether or not the model should return the last key/values attentions (not used by all models). Only
relevant if `config.is_decoder=True`.
pad_token_id (`int`, *optional*):
Padding token id.
bos_token_id (`int`, *optional*, defaults to 1):
Beginning of stream token id.
eos_token_id (`int`, *optional*, defaults to 2):
End of stream token id.
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
Whether to tie weight embeddings
rope_theta (`float`, *optional*, defaults to 10000.0):
The base period of the RoPE embeddings.
rope_scaling (`Dict`, *optional*):
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
`max_position_embeddings` to the expected new maximum.
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
Whether to use a bias in the query, key, value and output projection layers during self-attention.
attention_dropout (`float`, *optional*, defaults to 0.0):
The dropout ratio for the attention probabilities.
```python
>>> from transformers import DeepseekV3Model, DeepseekV3Config
>>> # Initializing a Deepseek-V3 style configuration
>>> configuration = DeepseekV3Config()
>>> # Accessing the model configuration
>>> configuration = model.config
```"""
model_type = "deepseek_v3"
keys_to_ignore_at_inference = ["past_key_values"]
def __init__(
self,
vocab_size=129280,
hidden_size=7168,
intermediate_size=18432,
moe_intermediate_size=2048,
num_hidden_layers=61,
num_nextn_predict_layers=1,
num_attention_heads=128,
num_key_value_heads=128,
n_shared_experts=1,
n_routed_experts=256,
ep_size=1,
routed_scaling_factor=2.5,
kv_lora_rank=512,
q_lora_rank=1536,
qk_rope_head_dim=64,
v_head_dim=128,
qk_nope_head_dim=128,
topk_method='noaux_tc',
n_group=8,
topk_group=4,
num_experts_per_tok=8,
moe_layer_freq=1,
first_k_dense_replace=3,
norm_topk_prob=True,
scoring_func='sigmoid',
hidden_act="silu",
max_position_embeddings=4096,
initializer_range=0.02,
rms_norm_eps=1e-6,
use_cache=True,
pad_token_id=None,
bos_token_id=0,
eos_token_id=1,
tie_word_embeddings=False,
rope_theta=10000.0,
rope_scaling=None,
attention_bias=False,
attention_dropout=0.0,
**kwargs,
):
self.vocab_size = vocab_size
self.max_position_embeddings = max_position_embeddings
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
self.moe_intermediate_size = moe_intermediate_size
self.num_hidden_layers = num_hidden_layers
self.num_nextn_predict_layers = num_nextn_predict_layers
self.num_attention_heads = num_attention_heads
self.n_shared_experts = n_shared_experts
self.n_routed_experts = n_routed_experts
self.ep_size = ep_size
self.routed_scaling_factor = routed_scaling_factor
self.kv_lora_rank = kv_lora_rank
self.q_lora_rank = q_lora_rank
self.qk_rope_head_dim = qk_rope_head_dim
self.v_head_dim = v_head_dim
self.qk_nope_head_dim = qk_nope_head_dim
self.topk_method = topk_method
self.n_group = n_group
self.topk_group = topk_group
self.num_experts_per_tok = num_experts_per_tok
self.moe_layer_freq = moe_layer_freq
self.first_k_dense_replace = first_k_dense_replace
self.norm_topk_prob = norm_topk_prob
self.scoring_func = scoring_func
# for backward compatibility
if num_key_value_heads is None:
num_key_value_heads = num_attention_heads
self.num_key_value_heads = num_key_value_heads
self.hidden_act = hidden_act
self.initializer_range = initializer_range
self.rms_norm_eps = rms_norm_eps
self.use_cache = use_cache
self.rope_theta = rope_theta
self.rope_scaling = rope_scaling
self.attention_bias = attention_bias
self.attention_dropout = attention_dropout
super().__init__(
pad_token_id=pad_token_id,
bos_token_id=bos_token_id,
eos_token_id=eos_token_id,
tie_word_embeddings=tie_word_embeddings,
**kwargs,
)
vllm.transformers_utils.configs.__all__.append("DeepseekV3Config")
vllm.transformers_utils.configs.DeepseekV3Config = DeepseekV3Config
config._CONFIG_REGISTRY["deepseek_v32"] = "DeepseekV3Config"

View File

@@ -20,6 +20,10 @@ from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
import vllm_ascend.patch.worker.patch_common.patch_triton
# isort: off
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa
import vllm_ascend.patch.worker.patch_common.patch_attention_layer # noqa
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa
import vllm_ascend.patch.worker.patch_common.patch_weight_loader # noqa

View File

@@ -0,0 +1,202 @@
from typing import List, Optional
import torch
import vllm
import vllm.envs as envs
from torch import nn
from vllm.attention import Attention, AttentionType, get_attn_backend
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import backend_name_to_enum
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
from vllm.config import CacheConfig, get_current_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform
from vllm_ascend.utils import vllm_version_is
class AscendAttention(Attention, nn.Module, AttentionLayerBase):
"""Attention layer.
This class takes query, key, and value tensors as input. The input tensors
can either contain prompt tokens or generation tokens.
The class does the following:
1. Store the input key and value tensors in the KV cache.
2. Perform (multi-head/multi-query/grouped-query) attention.
3. Return the output tensor.
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: Optional[int] = None,
alibi_slopes: Optional[List[float]] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
use_mla: bool = False,
use_sfa: bool = False,
prefix: str = "",
attn_type: str = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[str] = None,
attn_backend: Optional[type[AttentionBackend]] = None,
**extra_impl_args,
) -> None:
"""
The KV cache is stored inside this class and is accessed via
`self.kv_cache`.
"""
nn.Module.__init__(self)
AttentionLayerBase.__init__(self)
if per_layer_sliding_window is not None:
# per-layer sliding window
sliding_window = per_layer_sliding_window
elif cache_config is not None:
# model-level sliding window
sliding_window = cache_config.sliding_window
else:
sliding_window = None
if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
is_attention_free = cache_config.is_attention_free
calculate_kv_scales = cache_config.calculate_kv_scales
else:
kv_cache_dtype = "auto"
block_size = 16
is_attention_free = False
calculate_kv_scales = False
if num_kv_heads is None:
num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, \
f"num_heads ({num_heads}) is not " \
f"divisible by num_kv_heads ({num_kv_heads})"
# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized k/v_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only
# but requires q to be quantized as well.
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
# We also keep q/k/v_scale on host (cpu) memory for attention
# backends that require the scales to be on host instead of on device.
# e.g. Flashinfer
self._q_scale_float = 1.0
self._k_scale_float = 1.0
self._v_scale_float = 1.0
# The output scale on host memory. This should be the input scale of
# the quant op after this attention layer.
self._o_scale_float: Optional[float] = None
self.use_mla = use_mla
self.num_heads = num_heads
self.head_size = head_size
self.num_kv_heads = num_kv_heads
self.sliding_window = sliding_window
self.has_sink = extra_impl_args.get("sinks") is not None
quant_method = quant_config.get_quant_method(
self, prefix=prefix) if quant_config else None
if quant_method is not None and not isinstance(
quant_method, UnquantizedLinearMethod):
assert isinstance(quant_method, BaseKVCacheMethod)
# TODO (mgoin): kv cache dtype should be specified in the FP8
# checkpoint config and become the "auto" behavior
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# If quantization is enabled, we make "k_scale" and "v_scale"
# parameters so that it can be loaded from the model checkpoint.
# The k/v_scale will then be converted back to native float32
# values after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
if attn_backend is None:
if vllm_version_is("0.10.2"):
self.attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
is_attention_free,
use_mla=use_mla,
use_sfa=use_sfa,
has_sink=self.has_sink)
else:
self.attn_backend = get_attn_backend(head_size,
dtype,
kv_cache_dtype,
block_size,
use_mla=use_mla,
use_sfa=use_sfa,
has_sink=self.has_sink)
else:
self.attn_backend = attn_backend
impl_cls = self.attn_backend.get_impl_cls()
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
logits_soft_cap, attn_type,
kv_sharing_target_layer_name, **extra_impl_args)
self.backend = backend_name_to_enum(self.attn_backend.get_name())
self.dtype = dtype
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
# torch.compile works by registering the attention as one giant
# opaque custom op. For other platforms, we directly call them
# and let torch.compile handle them.
self.use_direct_call = not current_platform.opaque_attention_op()
self.use_output = self.attn_backend.accept_output_buffer
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
self.layer_name = prefix
self.attn_type = attn_type
if kv_sharing_target_layer_name is not None:
validate_kv_sharing_target(
prefix,
kv_sharing_target_layer_name,
compilation_config.static_forward_context,
)
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
# use a placeholder kv cache tensor during init, which will be replaced
# by bind_kv_cache
# this variable will not be accessed if use_direct_call is True
self.kv_cache = [
torch.tensor([]) for _ in range(get_current_vllm_config(
).parallel_config.pipeline_parallel_size)
]
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
self.query_quant = None
vllm.attention.Attention = AscendAttention

View File

@@ -0,0 +1,181 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# mypy: ignore-errors
from functools import cache
from typing import Optional
import torch
import vllm
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.platforms import _Backend, current_platform
from vllm.utils import resolve_obj_by_qualname
from vllm_ascend.utils import vllm_version_is
if vllm_version_is("0.10.2"):
def get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool = False,
use_mla: bool = False,
use_sfa: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
# private function.
return _cached_get_attn_backend(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
block_size=block_size,
is_attention_free=is_attention_free,
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
use_sfa=use_sfa,
has_sink=has_sink,
)
@cache
def _cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
is_attention_free: bool,
use_v1: bool = False,
use_mla: bool = False,
use_sfa: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]:
# If there are no attention layers (e.g. we are running Mamba),
# use the placeholder NO_ATTENTION
if is_attention_free:
from vllm.attention.backends.placeholder_attn import \
PlaceholderAttentionBackend
return PlaceholderAttentionBackend
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
selected_backend = None
backend_by_global_setting: Optional[_Backend] = (
get_global_forced_attn_backend())
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the environment variable and override if specified
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
raise ValueError(
f"Invalid attention backend: '{backend_by_env_var}'. "
f"Valid backends are: {list(_Backend.__members__.keys())}"
)
# get device-specific attn_backend
attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, kv_cache_dtype, block_size,
use_v1, use_mla, use_sfa, has_sink)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}"
)
return resolve_obj_by_qualname(attention_cls)
else:
def get_attn_backend( # type: ignore[misc]
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
use_mla: bool = False,
use_sfa: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]:
"""Selects which attention backend to use and lazily imports it."""
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
# value to be returned from the cache if the value changes between calls.
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
# private function.
return _cached_get_attn_backend(
head_size=head_size,
dtype=dtype,
kv_cache_dtype=kv_cache_dtype,
block_size=block_size,
use_v1=envs.VLLM_USE_V1,
use_mla=use_mla,
use_sfa=use_sfa,
has_sink=has_sink,
)
@cache
def _cached_get_attn_backend(
head_size: int,
dtype: torch.dtype,
kv_cache_dtype: Optional[str],
block_size: int,
use_v1: bool = False,
use_mla: bool = False,
use_sfa: bool = False,
has_sink: bool = False,
) -> type[AttentionBackend]:
# Check whether a particular choice of backend was
# previously forced.
#
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
# ENVIRONMENT VARIABLE.
selected_backend = None
backend_by_global_setting: Optional[_Backend] = (
get_global_forced_attn_backend())
if backend_by_global_setting is not None:
selected_backend = backend_by_global_setting
else:
# Check the environment variable and override if specified
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)
if selected_backend is None:
raise ValueError(
f"Invalid attention backend: '{backend_by_env_var}'. "
f"Valid backends are: {list(_Backend.__members__.keys())}"
)
# get device-specific attn_backend
attention_cls = current_platform.get_attn_backend_cls(
selected_backend, head_size, dtype, kv_cache_dtype, block_size,
use_v1, use_mla, use_sfa, has_sink)
if not attention_cls:
raise ValueError(
f"Invalid attention backend for {current_platform.device_name}"
)
return resolve_obj_by_qualname(attention_cls)
vllm.attention.get_attn_backend = get_attn_backend
vllm.attention.selector._cached_get_attn_backend = _cached_get_attn_backend

View File

@@ -0,0 +1,110 @@
from dataclasses import dataclass, fields
from typing import Optional
import torch
import vllm
from typing_extensions import Self
from vllm.config import VllmConfig
from vllm.utils import cdiv, get_dtype_size
from vllm.v1.core.single_type_kv_cache_manager import (FullAttentionManager,
spec_manager_map)
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
@dataclass(frozen=True)
class AttentionSpec(KVCacheSpec):
num_kv_heads: int
head_size: int
dtype: torch.dtype
use_mla: bool
use_sfa: bool
@property
def page_size_bytes(self) -> int:
# For MLA we only store a single latent vector
coef = 1 if self.use_mla else 2
sfa_bytes = 128 * self.block_size * get_dtype_size(
self.dtype) if self.use_sfa else 0
return coef * self.block_size * self.num_kv_heads * self.head_size \
* get_dtype_size(self.dtype) + sfa_bytes
vllm.v1.kv_cache_interface.AttentionSpec = AttentionSpec
@dataclass(frozen=True)
class AscendFullAttentionSpec(FullAttentionSpec, AttentionSpec):
sliding_window: Optional[int] = None
attention_chunk_size: Optional[int] = None
"""
When hybrid allocator is disabled and the model contains both full
attention layers and sliding window attention layers, sliding
window attention are regarded as full attention in KV cache manager
(blocks are allocated for all tokens), while computed as sliding window
attention in model runner.
In this case, we use FullAttentionSpec and record the sliding window size.
Default to None for not using sliding window attention.
"""
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
dcp_world_size = \
vllm_config.parallel_config.decode_context_parallel_size
# Note(hc): each dcp rank only need save
# (max_model_len//dcp_world_size) tokens locally.
if dcp_world_size > 1:
max_model_len = cdiv(max_model_len, dcp_world_size)
return cdiv(max_model_len, self.block_size) * self.page_size_bytes
@classmethod
def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
if len(window_sizes) == 0:
return None
elif len(window_sizes) == 1:
return window_sizes.pop()
else:
raise ValueError(
"All attention layers in the same KV cache group must have the "
"same window size.")
@classmethod
def merge(cls, specs: list[Self]) -> Self:
"""
Merge a list of FullAttentionSpec objects into a single
FullAttentionSpec object.
"""
assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
"All attention layers in the same KV cache group must be "
"FullAttentionSpec.")
sliding_window = set(spec.sliding_window for spec in specs
if spec.sliding_window is not None)
attention_chunk_size = set(spec.attention_chunk_size for spec in specs
if spec.attention_chunk_size is not None)
merged_spec = cls(
block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
dtype=specs[0].dtype,
use_mla=specs[0].use_mla,
use_sfa=specs[0].use_sfa,
sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
)
for spec in specs:
for f in fields(AttentionSpec):
assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
"All attention layers in the same KV cache group must have "
"the same attention spec.")
assert (
(merged_spec.sliding_window is not None) +
(merged_spec.attention_chunk_size is not None) <= 1
), ("Model with both sliding window layers and chunked local attention "
"layers is not supported.")
return merged_spec
spec_manager_map.update({AscendFullAttentionSpec: FullAttentionManager})
vllm.v1.kv_cache_interface.FullAttentionSpec = AscendFullAttentionSpec

View File

@@ -300,6 +300,7 @@ class NPUPlatform(Platform):
block_size,
use_v1,
use_mla,
use_sfa,
has_sink=False):
if not use_v1:
raise ValueError("vLLM Ascend does not support V0 engine.")
@@ -307,21 +308,28 @@ class NPUPlatform(Platform):
ascend_config = get_ascend_config()
if use_mla and ascend_config.enable_shared_expert_dp:
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
if use_mla and not use_sfa:
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
if use_mla and use_sfa:
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
use_torchair = ascend_config.torchair_graph_config.enabled
# choose attention backend based on use_mla and use_torchair
backend_map = {
(True, True):
(True, False, True):
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
(True, False):
(True, False, False):
"vllm_ascend.attention.mla_v1.AscendMLABackend",
(False, True):
(False, False, True):
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
(False, False):
"vllm_ascend.attention.attention_v1.AscendAttentionBackend"
(False, False, False):
"vllm_ascend.attention.attention_v1.AscendAttentionBackend",
(True, True, False):
"vllm_ascend.attention.sfa_v1.AscendSFABackend",
(True, True, True):
"vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
}
return backend_map[(use_mla, use_torchair)]
return backend_map[(use_mla, use_sfa, use_torchair)]
@classmethod
def get_punica_wrapper(cls) -> str:

View File

@@ -603,10 +603,11 @@ class MtpProposer(Proposer):
torch.npu.set_compile_mode(jit_compile=False)
if not self.runner.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(self.model,
dynamic=True,
fullgraph=True,
backend=npu_backend)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=not get_ascend_config().use_sfa,
fullgraph=True,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
@@ -627,7 +628,7 @@ class MtpProposer(Proposer):
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=True,
dynamic=not get_ascend_config().use_sfa,
fullgraph=True,
cache_dir=TORCHAIR_CACHE_DIR,
config=config,

View File

@@ -67,7 +67,9 @@ from vllm.model_executor.models.utils import (
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import Indexer
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
@@ -435,6 +437,7 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
decoder_layer=None,
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
@@ -630,6 +633,225 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
output_shape=output_shape)
class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention):
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
decoder_layer=None,
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
self.tp_size = get_tensor_model_parallel_world_size()
assert num_heads % self.tp_size == 0
self.num_local_heads = num_heads // self.tp_size
self.layers = config.num_hidden_layers
self.first_k_dense_replace = config.first_k_dense_replace
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj",
return_bias=False,
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
return_bias=False,
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
return_bias=False,
)
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
return_bias=False,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
return_bias=False,
)
if (config.n_routed_experts is not None
and self.debug_layer_idx >= config.first_k_dense_replace
and self.debug_layer_idx % config.moe_layer_freq == 0
and (ascend_config.multistream_overlap_shared_expert
or self.enable_shared_expert_dp)):
self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
return_bias=False,
)
else:
self.o_proj = TorchairDeepseekV2RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
return_bias=False,
)
if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.dim: int = config.hidden_size # 7168
# TODO(zzzzwwjj): wait transformers add these params
self.n_heads: int = 64 # 64
self.head_dim: int = 128 # 128
self.index_topk: int = 2048 # 2048
self.indexer = Indexer(
config,
quant_config=quant_config,
dim=self.dim,
n_heads=self.n_heads,
head_dim=self.head_dim,
index_topk=self.index_topk,
prefix=f"{prefix}.indexer",
)
self.sfa_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=self.scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sfa=True,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
indexer=self.indexer,
decoder_layer=decoder_layer,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
forward_context = get_forward_context()
if not self.torchair_graph_enabled:
if forward_context.attn_metadata is not None and isinstance(
forward_context.attn_metadata, dict):
attn_metadata = next(
iter(forward_context.attn_metadata.values()), None)
else:
attn_metadata = forward_context.attn_metadata
if kv_cache is None:
kv_cache = self.sfa_attn.kv_cache[
forward_context.virtual_engine]
num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
# if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# # Simulate all gather to calculate output shape
# num_tokens = num_tokens * self.tp_size
# need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx != self.first_k_dense_replace:
output_shape = hidden_states.shape
if self.enable_shared_expert_dp and (
self.debug_layer_idx == self.first_k_dense_replace
or self.debug_layer_idx == self.layers):
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
need_gather_q_kv, output)
output = output.view(-1, output_shape[-1])
return output
class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
def __init__(
@@ -654,9 +876,16 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
ascend_config = get_ascend_config()
self.use_mla = False
self.use_sfa = False
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
attn_cls = TorchairDeepseekV2MLAAttention
if ascend_config.use_sfa:
attn_cls = TorchairDeepseekV2SFAAttention
self.use_sfa = True
else:
attn_cls = TorchairDeepseekV2MLAAttention # type: ignore[assignment]
self.use_mla = True
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
@@ -675,6 +904,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
decoder_layer=self,
)
if (config.n_routed_experts is not None
@@ -715,21 +945,34 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
replace_allreduce: bool = False,
) -> torch.Tensor:
# Self Attention
if attn_metadata is not None and attn_metadata.num_decodes > 0:
mla_moe_communication = self.mla_moe_communication and replace_allreduce
if attn_metadata is not None:
decoding_condition_met = (
not attn_metadata.is_prefill if self.use_sfa else
attn_metadata.num_decodes > 0 if self.use_mla else False)
mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce
else:
mla_moe_communication = False
if residual is None:
forward_context = get_forward_context()
if (envs.VLLM_ASCEND_ENABLE_MLAPO
and isinstance(self.self_attn, TorchairDeepseekV2SFAAttention)
and attn_metadata is not None
and not forward_context.with_prefill):
if residual is not None:
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
previous_hidden_states, previous_residual = hidden_states, residual
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
# Dispose hidden_states and residual from the previous layer
# to save npu memory because they're no longer used.
dispose_tensor(previous_hidden_states)
dispose_tensor(previous_residual)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
previous_hidden_states, previous_residual = hidden_states, residual
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
# Dispose hidden_states and residual from the previous layer
# to save npu memory because they're no longer used.
dispose_tensor(previous_hidden_states)
dispose_tensor(previous_residual)
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)

View File

@@ -48,8 +48,8 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
class NPUTorchairModelRunner(NPUModelRunner):
def __init__(self, vllm_config: VllmConfig, device: torch.device):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.ascend_config = get_ascend_config()
self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp
super().__init__(vllm_config, device)
if self.speculative_config:
self.actual_seq_lengths_q = list(
@@ -66,10 +66,10 @@ class NPUTorchairModelRunner(NPUModelRunner):
self.new_kv_cache_bytes = -1
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
self.use_cached_kv_cache_bytes = ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.use_cached_npu_graph = self.ascend_config.torchair_graph_config.use_cached_graph
self.use_cached_kv_cache_bytes = self.ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
self.torchair_graph_batch_sizes = self.ascend_config.torchair_graph_config.graph_batch_sizes
if self.ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.init_torchair_graph_batch_sizes()
self.update_torchair_graph_batch_sizes()
@@ -362,22 +362,23 @@ class NPUTorchairModelRunner(NPUModelRunner):
communication_adaptation_310p()
config = torchair.CompilerConfig()
if get_ascend_config().torchair_graph_config.mode:
config.mode = get_ascend_config().torchair_graph_config.mode
if self.ascend_config.torchair_graph_config.mode:
config.mode = self.ascend_config.torchair_graph_config.mode
config.experimental_config.frozen_parameter = \
get_ascend_config().torchair_graph_config.enable_frozen_parameter
self.ascend_config.torchair_graph_config.enable_frozen_parameter
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
# disable it on 300I Duo platform now.
config.experimental_config.tiling_schedule_optimize = not is_310p()
config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize
self.ascend_config.torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
if not self.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(self.model,
dynamic=True,
fullgraph=True,
backend=npu_backend)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=not self.ascend_config.use_sfa,
fullgraph=True,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
@@ -398,7 +399,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=True,
dynamic=not self.ascend_config.use_sfa,
fullgraph=True,
cache_dir=TORCHAIR_CACHE_DIR,
config=config,

File diff suppressed because it is too large Load Diff

View File

@@ -165,6 +165,11 @@ def register_torchair_model():
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
)
ModelRegistry.register_model(
"DeepseekV32ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
)
ModelRegistry.register_model(
"Qwen2ForCausalLM",
"vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM")

View File

@@ -285,8 +285,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.intermediate_tensors: Optional[IntermediateTensors] = None
self.runner_only_attn_layers: set[str] = set()
ascend_config = get_ascend_config()
if ascend_config.ascend_scheduler_config.enabled:
self.ascend_config = get_ascend_config()
if self.ascend_config.ascend_scheduler_config.enabled:
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
else:
self.chunked_prefill_enabled = True
@@ -298,6 +298,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.cache_config.cache_dtype]
# use_hybrid_blocks: if hybrid blocks is used.
self.use_hybrid_blocks: bool = False
self.need_accepted_tokens: bool = False
self.is_multimodal_model = self.model_config.is_multimodal_model
self.is_pooling_model = self.model_config.pooler_config is not None
@@ -315,7 +316,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
use_sfa=self.ascend_config.use_sfa)
else:
self.attn_backend = get_attn_backend(
0,
@@ -323,7 +324,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
None,
self.block_size,
use_mla=self.model_config.use_mla,
)
use_sfa=self.ascend_config.use_sfa)
if torch.version.cann.startswith("8.3"):
self.attn_mask_builder = AttentionMaskBuilder(
self.scheduler_config.max_num_batched_tokens, self.dtype,
@@ -457,7 +458,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.bool,
device=self.device,
)
self.dynamic_eplb = ascend_config.dynamic_eplb
self.dynamic_eplb = self.ascend_config.dynamic_eplb
if self.dynamic_eplb:
self.is_eplb_warmuped = False
self.eplb_loader = D2DExpertWeightLoader()
@@ -890,15 +891,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _make_attention_mask(self, seq_lens, position,
attn_state) -> torch.Tensor:
# Chunk Prefill situation.
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
if torch.version.cann.startswith("8.3"):
return self.attn_mask_builder.get_splitfuse_attn_mask()
else:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, position, self.dtype, self.device)
# Prefill without cache situation.
elif attn_state == AscendAttentionState.PrefillNoCache:
max_seq_len = max(seq_lens, default=0)
max_seq_len = max(seq_lens.max().item(), 0)
return self.attn_mask_builder.get_attn_mask(
max_seq_len, self.dtype, self.device)
# Prefill with cache hit.
@@ -1252,7 +1254,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
max_num_scheduled_tokens = num_scheduled_tokens.max()
num_valid_tokens = np.array([
num_tokens -
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
@@ -1376,8 +1378,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions_cpu = self.positions_cpu[:num_input_tokens]
positions = self.positions[:num_input_tokens]
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
position=positions_cpu,
attn_state=attn_state)
@@ -1477,7 +1477,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
spec_decode_common_attn_metadata = None
if use_spec_decode:
if use_spec_decode and self.need_accepted_tokens:
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
self.num_accepted_tokens.np[num_reqs:].fill(1)
@@ -1550,7 +1550,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
model=self.model,
**extra_attn_metadata_args)
if self.vllm_config.model_config.use_mla:
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
attn_metadata_i.num_input_tokens = num_input_tokens
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
@@ -2060,7 +2060,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
self._update_states_after_model_execute(output_token_ids)
if self.need_accepted_tokens:
self._update_states_after_model_execute(output_token_ids)
discard_sampled_tokens_req_indices: list[int] = []
# TODO(woosuk): The following loop can be slow since it iterates over
@@ -2683,10 +2684,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_config = kv_cache_config
self.initialize_attn_backend(kv_cache_config)
self.use_hybrid_blocks = (len(self.attn_groups) > 1)
# NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
if vllm_version_is("0.10.2"):
self.need_accepted_tokens = any([
isinstance(
self.kv_cache_config.kv_cache_groups[0].kv_cache_spec,
MambaSpec) for attn_group in self.attn_groups
])
else:
self.need_accepted_tokens = any([
isinstance(attn_group[0].kv_cache_spec, MambaSpec)
for attn_group in self.attn_groups
])
self.may_reinitialize_input_batch(kv_cache_config)
if self.model_config.is_deepseek_mla:
kv_caches = self.initialize_kv_cache_tensors_deepseek(
if self.ascend_config.is_deepseek_sfa:
kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
kv_cache_config)
elif self.model_config.is_deepseek_mla:
kv_caches = self.initialize_kv_cache_tensors_deepseek_mla(
kv_cache_config)
else:
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
@@ -2701,7 +2718,116 @@ class NPUModelRunner(LoRAModelRunnerMixin):
offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):]
def initialize_kv_cache_tensors_deepseek(
def initialize_kv_cache_tensors_deepseek_sfa(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
kv_cache_sizes = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
assert len(kv_cache_tensor.shared_by) == 1, (
"KV cache tensor shared by multiple layers is not supported in "
"NPU.")
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
kv_caches: Dict[str, torch.Tensor] = {}
for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
if vllm_version_is("0.10.2"):
kv_cache_spec, group = group
else:
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
tensor_size = kv_cache_sizes[layer_name]
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
if self.vllm_config.additional_config.get(
"kv_cache_dtype", None) == 'int8':
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
elif hasattr(
attn_backend, "get_supported_block_size"
) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa:
block_size = attn_backend.get_supported_block_size()[0]
block_size_chunk = kv_cache_spec.block_size // block_size
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks * block_size_chunk, block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
else:
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype
alignment = 2 * 1024 * 1024
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
nope_dim = head_size - rope_dim
nope_cache_shape = (num_blocks, block_size, num_kv_heads,
nope_dim)
rope_cache_shape = (num_blocks, block_size, num_kv_heads,
rope_dim)
#### k cache
# TODO(zzzzwwjj): wait transformers add these params
k_cache_shape = (num_blocks, block_size, 1, 128)
if self.vllm_config.kv_transfer_config is None:
# For no disaggregate pd scenario, allocate kv cache in normal way
rope_cache = torch.zeros(rope_cache_shape,
dtype=dtype,
device=self.device)
nope_cache = torch.zeros(nope_cache_shape,
dtype=dtype,
device=self.device)
rope_cache = self._convert_torch_format(rope_cache)
nope_cache = self._convert_torch_format(nope_cache)
#### k cache
k_cache = torch.zeros(k_cache_shape,
dtype=dtype,
device=self.device)
k_cache = self._convert_torch_format(k_cache)
else:
# In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
# address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but
# we found there are also some exceptions during test, so we manual align those memory here, this part
# of code may consume 2M * 2 * elem_size memory every layer.
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
nope_allocate_shape_alignment = nope_allocate_shape + alignment
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
rope_allocate_shape_alignment = rope_allocate_shape + alignment
nope_cache = torch.zeros(nope_allocate_shape_alignment,
dtype=dtype,
device=self.device)
rope_cache = torch.zeros(rope_allocate_shape_alignment,
dtype=dtype,
device=self.device)
#### k cache
# TODO(zzzzwwjj): wait transformers add these params
k_allocate_shape = num_blocks * block_size * 1 * 128
k_allocate_shape_alignment = k_allocate_shape + alignment
k_cache = torch.zeros(k_allocate_shape_alignment,
dtype=dtype,
device=self.device)
nope_cache = self._align_memory(
nope_cache,
alignment)[:nope_allocate_shape].view(nope_cache_shape)
rope_cache = self._align_memory(
rope_cache,
alignment)[:rope_allocate_shape].view(rope_cache_shape)
k_cache = self._align_memory(
k_cache,
alignment)[:k_allocate_shape].view(k_cache_shape)
kv_caches[layer_name] = (nope_cache, rope_cache, k_cache)
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)
return kv_caches
def initialize_kv_cache_tensors_deepseek_mla(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
kv_cache_sizes = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
@@ -3217,6 +3343,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
use_sfa = self.ascend_config.use_sfa
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
@@ -3243,7 +3370,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
use_mla=use_mla,
use_sfa=use_sfa)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.

View File

@@ -43,7 +43,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
from vllm.v1.worker.worker_base import WorkerBase
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
@@ -88,6 +88,17 @@ class NPUWorker(WorkerBase):
# init ascend config and soc version
init_ascend_config(vllm_config)
init_ascend_soc_version()
if get_ascend_config().use_sfa:
# Direct import instead of using try_register_lib to ensure proper error handling when
# custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments)
# yapf: disable
import custom_ops # type: ignore # noqa
# yapf: enable
logger.info(
"custom_ops module loaded successfully. Custom operators like "
"torch.ops.custom.npu_sparse_flash_attention are now available."
)
super().__init__(vllm_config=vllm_config,
local_rank=local_rank,