Add DeepSeek V3.2 support (#3270)
### What this PR does / why we need it?
This PR adds initial DeepSeek V3.2 support on top of [vLLM v0.11.0](https://github.com/vllm-project/vllm/tree/releases/v0.11.0) (not released yet). We will complete the vLLM adaptation as soon as possible; the feature should be ready within the next 1-2 days. Related doc: https://github.com/vllm-project/vllm-ascend/pull/3223.

### Does this PR introduce _any_ user-facing change?
Yes!

### How was this patch tested?
CI passed; a DeepSeek doc run will follow shortly.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0

---------

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: zzzzwwjj <1183291235@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: wxsIcey <1790571317@qq.com>
Signed-off-by: MengqingCao <cmq0113@163.com>
Co-authored-by: zzzzwwjj <1183291235@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: wxsIcey <1790571317@qq.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
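For orientation, a minimal offline-inference sketch of how DeepSeek V3.2 would be exercised once this support lands. The model id, parallel size, and flags below are illustrative assumptions, not values mandated by this PR:

```python
# Hypothetical usage sketch; model path and parallel settings are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3.2-Exp",  # assumed model id
    tensor_parallel_size=16,                # illustrative; size to your NPU pod
    enforce_eager=True,                     # conservative default while support stabilizes
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```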
.github/workflows/vllm_ascend_test.yaml (vendored, 9 lines changed)
@@ -121,7 +121,14 @@ jobs:
          export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib
          pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \
            --ignore=tests/ut/test_platform.py \
            --ignore=tests/ut/patch/worker/patch_common/test_patch_minicpm.py
            --ignore=tests/ut/patch/worker/patch_common/test_patch_minicpm.py \
            --ignore=tests/ut/core/test_scheduler.py \
            --ignore=tests/ut/kv_connector/test_llmdatadist_connector.py \
            --ignore=tests/ut/kv_connector/test_mooncake_connector.py \
            --ignore=tests/ut/kv_connector/test_remote_decode_lifecycle.py \
            --ignore=tests/ut/kv_connector/test_remote_prefill_lifecycle.py \
            --ignore=tests/ut/torchair/models/test_torchair_deepseek_v2.py \
            --ignore=tests/ut/torchair/test_utils.py

      - name: Upload coverage to Codecov
        # only upload coverage when commits merged
@@ -23,5 +23,7 @@ def register():

def register_model():
    import vllm_ascend.patch.worker.patch_common.patch_attention_selector  # noqa

    from .models import register_model
    register_model()

@@ -34,6 +34,8 @@ class AscendConfig:
    def __init__(self, vllm_config):
        additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
        self.is_deepseek_sfa = vllm_config.model_config is not None and vllm_config.model_config.is_deepseek_mla and vllm_config.model_config.hf_text_config.model_type == "deepseek_v32"
        self.use_sfa = self.is_deepseek_sfa

        torchair_graph_config = additional_config.get("torchair_graph_config", {})
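To make the detection logic above concrete, a minimal standalone helper mirroring the same check (the helper itself is an illustration, not part of this PR; the attributes are the ones vLLM's ModelConfig exposes):

```python
def is_deepseek_v32_sfa(model_config) -> bool:
    # SFA is enabled only for DeepSeek V3.2, which vLLM still reports as an
    # MLA model; the HF text config distinguishes it via model_type.
    return (model_config is not None
            and model_config.is_deepseek_mla
            and model_config.hf_text_config.model_type == "deepseek_v32")
```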
@@ -73,7 +73,7 @@ class AttentionMaskBuilder:
                          device: torch.device):
        self._update_attn_cache(max_seq_len, dtype)
        return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
        ).to(device)
        ).to(device, non_blocking=True)

    def get_splitfuse_attn_mask(
            self,
vllm_ascend/attention/sfa_v1.py (new file, 986 lines)
@@ -0,0 +1,986 @@
from dataclasses import dataclass
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
                    TypeVar)

import torch
import torch_npu
from torch import nn
from vllm.attention.backends.abstract import (AttentionBackend,
                                              AttentionMetadata,
                                              MLAAttentionImpl)
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.utils import cdiv, round_down
from vllm.v1.attention.backends.utils import AttentionCGSupport

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
                                         split_decodes_and_prefills)
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
from vllm_ascend.worker.npu_input_batch import InputBatch

if TYPE_CHECKING:
    from vllm.v1.core.sched.output import SchedulerOutput


class AscendSFABackend(AttentionBackend):

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        return "ASCEND_SFA"

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        return AscendSFAMetadata

    @staticmethod
    def get_builder_cls():
        return AscendSFAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
                           head_size: int) -> tuple[int, ...]:
        return (num_blocks, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_impl_cls() -> Type["AscendSFAImpl"]:
        return AscendSFAImpl

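A quick sanity check of `get_kv_cache_shape` as defined above; the numeric values are assumptions for illustration, roughly matching a DeepSeek-style latent KV width:

```python
# Hypothetical sizes: 1024 blocks of 128 tokens, one latent KV "head" of width 576.
shape = AscendSFABackend.get_kv_cache_shape(num_blocks=1024, block_size=128,
                                            num_kv_heads=1, head_size=576)
assert shape == (1024, 128, 1, 576)
```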
@dataclass
class AscendSFAPrefillMetadata:
    """ Prefill Specific Metadata for Ascend"""

    @dataclass
    class ChunkedContextMetadata:
        # New for MLA (compared to FlashAttention)
        # For handling chunked prefill
        cu_seq_lens: torch.Tensor
        starts: torch.Tensor
        seq_tot: list[int]
        max_seq_lens: list[int]
        workspace: torch.Tensor
        chunk_seq_lens: torch.Tensor

    attn_mask: torch.Tensor
    query_lens: list[int]
    seq_lens: list[int]

    context_lens: torch.Tensor
    input_positions: torch.Tensor
    query_start_loc: torch.Tensor
    block_table: torch.Tensor
    max_query_len: int
    max_seq_lens: int
    sin: torch.Tensor
    cos: torch.Tensor
    chunked_context: Optional[ChunkedContextMetadata] = None


@dataclass
class AscendSFADecodeMetadata:
    # Input positions for rotary embeddings since for MLA the rotary
    # position embeddings are applied inside the attention backend
    input_positions: torch.Tensor
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    max_seq_lens: int
    seq_lens_list: list[int]
    actual_seq_lengths_q: torch.Tensor
    sin: torch.Tensor
    cos: torch.Tensor
    attn_mask: Optional[torch.Tensor] = None


@dataclass
class AscendSFAMetadata:
    """Metadata for MLACommon.

    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    #   |---------- N-1 iteration --------|
    #   |---------------- N iteration ---------------------|
    #   |- tokenA -|......................|-- newTokens ---|
    #   |---------- context_len ----------|
    #   |-------------------- seq_len ---------------------|
    #                                     |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    slot_mapping: torch.Tensor
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    block_tables: torch.Tensor

    # New for MLA (compared to FlashAttention)
    # For handling prefill decode split
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

    query_lens: Optional[list[int]] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    attn_mask: torch.Tensor = None
    # chunked prefill by default if no attn_states passed
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    decode: Optional[AscendSFADecodeMetadata] = None
    prefill: Optional[AscendSFAPrefillMetadata] = None
    enable_dbo_across_dp: bool = False

    def __post_init__(self):
        pass
        # supported_head_sizes = AscendMLABackend.get_supported_head_sizes()
        # if self.head_dim is not None and self.head_dim \
        #         not in supported_head_sizes:
        #     raise ValueError(
        #         f"Only {supported_head_sizes} are supported for head_dim,",
        #         f"received {self.head_dim}.")

    def split_metadata_for_multistream(
        self,
        ms_split_config: MSAttentionMetadataSplitConfig,
    ) -> list["AscendSFAMetadata"]:
        """Split metadata for multi-stream with AscendSFAMetadata"""
        return model_input_split_v1_mla_attn(
            ms_split_config=ms_split_config,
            attn_metadata=self,
            _metadata_cls=AscendMLAMetadata,
        )


M = TypeVar("M", bound=AscendSFAMetadata)

class AscendSFAMetadataBuilder:
    # Does this backend/builder support ACL Graphs for attention (default: no).
    aclgraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.NEVER
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    # _attn_mask_builder = None
    def __init__(self,
                 kv_cache_spec,
                 layer_names,
                 vllm_config: VllmConfig,
                 device: torch.device,
                 metadata_cls: Optional[AscendSFAMetadata] = None):
        self.metadata_cls: Optional[AscendSFAMetadata] = metadata_cls \
            if metadata_cls is not None else AscendSFAMetadata  # type: ignore
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.device = device
        scheduler_config = vllm_config.scheduler_config
        self.block_size = vllm_config.cache_config.block_size
        self.max_blocks = (vllm_config.model_config.max_model_len +
                           self.block_size - 1) // self.block_size
        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled

        self.speculative_config = vllm_config.speculative_config
        self.decode_threshold = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            self.decode_threshold += spec_token_num
            assert self.decode_threshold <= 16, f"decode_threshold exceeded \
                npu_fused_infer_attention_score TND layout's limit of 16, \
                got {self.decode_threshold}"

        if self.chunked_prefill_enabled:
            self.chunked_prefill_workspace_size = min(
                # Make sure there is enough for 8 full-length requests or at
                # least 4 pages of cache per request
                max(8 * self.model_config.max_model_len,
                    4 * scheduler_config.max_num_seqs * self.block_size),
                # For long-context models try not to over-allocate limiting
                # kv-cache space, limiting it to 64k tokens,
                # which would result in the workspace being:
                # 2*(576)*(64*1024) = 144mb
                # (assuming 576 MLA head dim, and fp16)
                # which would result in up-projected context being
                # 2*(192*128)*(64*1024) = 3gb
                # (assuming 192 QK head dim, 128 heads, and fp16)
                128 * 1024)
            assert self.chunked_prefill_workspace_size >= \
                scheduler_config.max_num_seqs * self.block_size
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 self.model_config.get_head_size()),
                dtype=self.model_config.dtype,
                device=device,
            )
        self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
        self.cos_cache = None
        self.sin_cache = None

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        # We now want to reorder the batch so that the "decode" requests are at
        # the front and the "prefill" requests are at the back, using the least
        # amount of swaps possible. (NOTE for now we loosely use "decode" to
        # mean requests where attention is likely memory-bound and "prefill" to
        # mean requests where attention is likely compute-bound,
        # TODO(lucas): figure out a better naming here)
        decodes = []
        prefills = []

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            if num_tokens <= self.decode_threshold:
                decodes.append(i)
            else:
                prefills.append(i)

        # We hope that this is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new requests are generally appended to the
        # persistent batch so already should be at the back)
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back"
        # i.e. past where the last decode should be in the reordered batch, with
        # prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based on
        # the above loop
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        first_prefill = 0
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, i, we can swap it
            # with the prefill closest to the front of the batch
            if decodes[num_decodes - i] >= num_decodes:
                input_batch.swap_states(prefills[first_prefill],
                                        decodes[num_decodes - i])
                first_prefill += 1
                modified_batch = True
            else:
                break

        # Save for next `build` call
        # TODO(lucas): this is a bit of a hack, we should probably have a
        # better way of doing this
        return modified_batch

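To make the reordering above concrete, a small self-contained sketch of the same swap strategy on plain lists (the toy index list stands in for `input_batch.swap_states`; it is illustrative only):

```python
def reorder(num_tokens_per_req, decode_threshold=1):
    # Mirrors reorder_batch: small ("decode") requests move to the front with
    # the minimum number of swaps; already-placed decodes are left alone.
    decodes = [i for i, n in enumerate(num_tokens_per_req) if n <= decode_threshold]
    prefills = [i for i, n in enumerate(num_tokens_per_req) if n > decode_threshold]
    order = list(range(len(num_tokens_per_req)))
    first_prefill = 0
    for i in range(1, min(len(decodes), len(prefills)) + 1):
        if decodes[-i] >= len(decodes):
            order[prefills[first_prefill]], order[decodes[-i]] = \
                order[decodes[-i]], order[prefills[first_prefill]]
            first_prefill += 1
        else:
            break
    return order

# e.g. [5, 1, 7, 1]: the decodes at indices 1 and 3 end up in the first two slots.
assert reorder([5, 1, 7, 1]) == [3, 1, 2, 0]
```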
    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        model: nn.Module,
    ) -> AscendSFAMetadata:
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=self.decode_threshold)
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

        # Note(simon): be careful about the CPU <> GPU memory movement in this
        # function. We should avoid GPU -> CPU sync as much as possible because
        # it blocks on all previous kernels.
        device = self.device

        block_table = (common_attn_metadata.block_table_tensor[:num_reqs])
        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens].to(
            device, non_blocking=True)
        input_positions = common_attn_metadata.positions[:num_actual_tokens].long()

        if self.cos_cache is None:
            self.cos_cache = model.model.layers[0].self_attn.rotary_emb.cos_cached
            self.sin_cache = model.model.layers[0].self_attn.rotary_emb.sin_cached
        if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
            self.cos_cache = self.cos_cache.to(self.model_config.dtype)  # type: ignore
            self.sin_cache = self.sin_cache.to(self.model_config.dtype)  # type: ignore

        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        query_lens = query_seq_lens_cpu[:num_reqs]
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
        num_computed_tokens_cpu = (seq_lens - query_lens)

        prefill_metadata = None
        chunked_context_metadata = None
        if num_prefills > 0:
            reqs_start = num_decodes  # prefill_start
            tokens_start = num_decode_tokens
            max_query_len = query_lens[reqs_start:].max().item()
            max_seq_lens = seq_lens[reqs_start:].max().item()
            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]

            context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
            max_context_len_cpu = context_lens_cpu.max().item()
            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
            if self.chunked_prefill_enabled and max_context_len_cpu > 0:
                max_context_chunk = (self.chunked_prefill_workspace_size //
                                     num_prefills_with_context_cpu)
                max_context_chunk = round_down(max_context_chunk,
                                               self.block_size)

                assert max_context_chunk > 0
                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
                chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
                    .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
                chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                       chunk_starts + max_context_chunk)
                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
                cu_seq_lens_cpu = torch.zeros(num_chunks,
                                              num_prefills + 1,
                                              dtype=torch.int32,
                                              pin_memory=True)
                torch.cumsum(chunk_seq_lens,
                             dim=1,
                             out=cu_seq_lens_cpu[:, 1:],
                             dtype=torch.int32)
                chunked_context_metadata = \
                    AscendSFAPrefillMetadata.ChunkedContextMetadata(
                        cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
                        starts=chunk_starts.to(device, non_blocking=True),
                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                        chunk_seq_lens=chunk_seq_lens,
                        workspace=self.chunked_prefill_workspace,
                    )
            prefill_input_positions = input_positions[tokens_start:]
            cos = self.cos_cache[prefill_input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)
            sin = self.sin_cache[prefill_input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)
            actual_query_lens = torch.tensor(query_lens[reqs_start:],
                                             dtype=torch.int32).npu()
            query_lens_prefill_sfa = torch.cumsum(actual_query_lens,
                                                  dim=0).to(torch.int32)
            seq_lens_prefill_sfa = seq_lens[reqs_start:].to(torch.int32).npu()
            prefill_metadata = AscendSFAPrefillMetadata(
                attn_mask=common_attn_metadata.attn_mask,
                query_lens=query_lens_prefill_sfa,
                seq_lens=seq_lens_prefill_sfa,
                context_lens=seq_lens[reqs_start:],
                input_positions=prefill_input_positions,
                block_table=block_table[reqs_start:, ...],
                max_query_len=max_query_len,
                max_seq_lens=max_seq_lens,
                query_start_loc=prefill_query_start_loc,
                chunked_context=chunked_context_metadata,
                sin=sin,
                cos=cos,
            )

        decode_metadata = None
        if num_decodes > 0:
            # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to(
                torch.int32).npu()
            max_seq_lens = seq_lens[:num_decodes].max().item()
            seq_lens = seq_lens[:num_decodes].to(torch.int32).npu()
            input_positions = input_positions[:num_decode_tokens]
            block_table = block_table[:num_decodes, ...]
            seq_lens_list = seq_lens.tolist()

            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)
            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)

            decode_metadata = AscendSFADecodeMetadata(
                input_positions=input_positions,
                block_table=block_table,
                seq_lens=seq_lens,
                seq_lens_list=seq_lens_list,
                max_seq_lens=max_seq_lens,
                attn_mask=common_attn_metadata.spec_attn_mask,
                actual_seq_lengths_q=actual_seq_lengths_q,
                sin=sin,
                cos=cos)

        return self.metadata_cls(  # type: ignore
            num_actual_tokens=num_actual_tokens,
            query_lens=query_lens.tolist(),
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            attn_mask=common_attn_metadata.attn_mask,
            attn_state=common_attn_metadata.attn_state,
            prefill=prefill_metadata,
            decode=decode_metadata,
            query_start_loc=query_start_loc,
            block_tables=block_table,
            seq_lens=seq_lens,
            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
        )

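A small worked example of the chunked-context bookkeeping in `build()` above; it is pure tensor arithmetic (no NPU needed) and the sizes are made up for illustration:

```python
import torch
from vllm.utils import cdiv, round_down

context_lens = torch.tensor([300, 50], dtype=torch.int32)  # two prefills with cached context
workspace, block_size = 256, 64
max_chunk = round_down(workspace // len(context_lens), block_size)  # 128
num_chunks = cdiv(int(context_lens.max()), max_chunk)               # 3
starts = torch.arange(num_chunks).unsqueeze(1).expand(-1, 2) * max_chunk
ends = torch.min(context_lens.unsqueeze(0), starts + max_chunk)
chunk_seq_lens = (ends - starts).clamp(min=0)
# chunk 0 covers tokens [0, 128) of each request, chunk 1 covers [128, 256), ...
print(chunk_seq_lens)  # tensor([[128,  50], [128,   0], [ 44,   0]])
```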
class PrefillSFAPreprocessResult(NamedTuple):
    q_nope: Optional[torch.Tensor] = None
    q_pe: Optional[torch.Tensor] = None
    k_nope: Optional[torch.Tensor] = None
    k_pe: Optional[torch.Tensor] = None
    topk_indices: Optional[torch.Tensor] = None
    query_states: Optional[torch.Tensor] = None
    key_states: Optional[torch.Tensor] = None


class DecodeSFAPreprocessResult(NamedTuple):
    q_nope: Optional[torch.Tensor] = None
    q_pe: Optional[torch.Tensor] = None
    # nope_cache: Optional[torch.Tensor] = None
    # rope_cache: Optional[torch.Tensor] = None
    topk_indices: Optional[torch.Tensor] = None
    query_states: Optional[torch.Tensor] = None
    key_states: Optional[torch.Tensor] = None
    bsz: Optional[int] = None


class AscendSFAImpl(MLAAttentionImpl):
    """
    NOTE: Please read the comment at the top of the file before trying to
    understand this class
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: int,
        alibi_slopes: Optional[list[float]],
        sliding_window: Optional[int],
        kv_cache_dtype: str,
        logits_soft_cap: Optional[float],
        attn_type: str,
        kv_sharing_target_layer_name: Optional[str],
        **kwargs,
    ) -> None:
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_kv_heads
        self.kv_cache_dtype = kv_cache_dtype

        # MLA Args
        self.q_lora_rank = kwargs['q_lora_rank']
        self.kv_lora_rank = kwargs['kv_lora_rank']
        self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
        self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
        self.qk_head_dim = kwargs['qk_head_dim']
        self.v_head_dim = kwargs['v_head_dim']
        self.rotary_emb = kwargs['rotary_emb']
        self.q_proj = kwargs['q_proj']
        self.kv_b_proj = kwargs['kv_b_proj']
        self.o_proj = kwargs['o_proj']
        self.indexer = kwargs['indexer']
        self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
        self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
        self.q_a_proj = kwargs.get('q_a_proj', None)
        self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        self.num_heads_per_rank = self.num_heads // self.tp_size
        if self.q_a_proj is not None:
            self.q_b_proj = self.q_proj
        else:
            self.q_b_proj = None

        ascend_config = get_ascend_config()
        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
        self.enable_prefetch = ascend_config.enable_prefetch
        self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz

        vllm_config = get_current_vllm_config()
        self.ring_mla_mask_size = 512
        self.prefill_mask = None

        # indexer param
        self.dim = self.indexer.dim
        self.n_heads: int = self.indexer.n_heads  # 64
        self.head_dim: int = self.indexer.head_dim  # 128
        self.index_topk: int = self.indexer.index_topk  # 2048
        self.wq_b = self.indexer.wq_b
        self.wk = self.indexer.wk
        self.weights_proj = self.indexer.weights_proj
        self.k_norm = self.indexer.k_norm
        self.softmax_scale = self.indexer.softmax_scale

        # Adapt torch air graph mode with spec decoding.
        speculative_config = vllm_config.speculative_config
        if speculative_config is not None:
            self.spec_token_num = speculative_config.num_speculative_tokens
            assert self.spec_token_num > 0

        self.cp_size = 1

    def process_weights_after_loading(self, act_dtype: torch.dtype):

        def get_layer_weight(layer):
            WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
            for attr in WEIGHT_NAMES:
                if hasattr(layer, attr):
                    return getattr(layer, attr)
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute:"
                f" {WEIGHT_NAMES}.")

        def get_and_maybe_dequant_weights(layer: LinearBase):
            if not isinstance(layer.quant_method, UnquantizedLinearMethod):
                # NOTE: This should only be used offline, since it's O(N^3)
                eye = torch.eye(layer.input_size_per_partition,
                                dtype=act_dtype,
                                device=get_layer_weight(layer).device)
                dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
                del eye
                # standardize to (output, input)
                return dequant_weights.T
            return layer.weight

        # We currently do not have quantized bmm's which are needed for
        # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
        # the bmm's in 16-bit; the extra memory overhead of this is fairly low
        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
        assert kv_b_proj_weight.shape == (
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
                f"{kv_b_proj_weight.shape=}, "
                f"{self.kv_lora_rank=}, "
                f"{self.num_heads=}, "
                f"{self.qk_nope_head_dim=}, "
                f"{self.v_head_dim=}")
        kv_b_proj_weight = kv_b_proj_weight.view(
            self.kv_lora_rank,
            self.num_heads,
            self.qk_nope_head_dim + self.v_head_dim,
        )

        self.kv_b_proj_w_k, self.kv_b_proj_w_v = kv_b_proj_weight.split(
            [self.qk_nope_head_dim, self.v_head_dim], dim=-1)

        # Convert from (L, N, V) to (N, L, V)
        self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous()
        # Convert from (L, N, P) to (N, P, L)
        self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous()

        # Waiting for BMM NZ support
        # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
        # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)

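A shape-level sketch of the `kv_b_proj` weight split performed above; the dimensions are illustrative DeepSeek-like values, not read from any checkpoint:

```python
import torch

kv_lora_rank, num_heads, qk_nope_head_dim, v_head_dim = 512, 16, 128, 128
w = torch.randn(kv_lora_rank, num_heads, qk_nope_head_dim + v_head_dim)
w_k, w_v = w.split([qk_nope_head_dim, v_head_dim], dim=-1)
w_v = w_v.transpose(0, 1).contiguous()   # (N, L, V): per-head value up-projection
w_k = w_k.permute(1, 2, 0).contiguous()  # (N, P, L): per-head absorption matrix for q_nope
assert w_v.shape == (num_heads, kv_lora_rank, v_head_dim)
assert w_k.shape == (num_heads, qk_nope_head_dim, kv_lora_rank)
```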
    def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata,
                        need_gather_q_kv):
        # SFA Preprocess:
        # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
        # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
        # 3. If need_gather_q_kv, perform all_gather.
        # 4. Preprocess decode tokens, write kv cache and get:
        #    decode_ql_nope, decode_q_pe, decode_k_pe, decode_k_nope
        # 5. Preprocess prefill tokens, write kv cache and get:
        #    prefill_q_nope, prefill_q_pe, prefill_k_nope, prefill_k_pe, prefill_value
        has_decode = attn_metadata.num_decodes > 0
        has_prefill = attn_metadata.num_prefills > 0

        num_decode_tokens = attn_metadata.num_decode_tokens
        num_actual_tokens = attn_metadata.num_actual_tokens
        if need_gather_q_kv:
            # q_c = get_tp_group().all_gather(q_c, 0)
            # kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
            hidden_states = get_tp_group().all_gather(hidden_states, 0)
        # hidden_states_decode = hidden_states[:num_decode_tokens]
        # if self.q_a_proj is not None:
        #     npu_prefetch(self.q_a_proj.weight,
        #                  hidden_states,
        #                  enabled=self.enable_prefetch)
        #     ckq = self.q_a_proj(hidden_states)  # q down
        #     q_c = self.q_a_layernorm(ckq)  # q down layernorm
        # else:
        #     q_c = hidden_states

        # kv_no_split = self.kv_a_proj_with_mqa(hidden_states)  # c_kv
        # Process for shared_expert_dp

        decode_preprocess_res = None
        prefill_preprocess_res = None
        # Preprocess for decode tokens
        if has_decode:
            q_len = 1
            hidden_states_decode = hidden_states[:num_decode_tokens]
            decode_kq = self.q_a_proj(hidden_states_decode)  # q down
            decode_q_c = self.q_a_layernorm(decode_kq)  # q down layernorm
            decode_kv_no_split = self.kv_a_proj_with_mqa(
                hidden_states_decode)  # c_kv

            # decode_q_c = q_c[:num_decode_tokens]
            decode_slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens]
            # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens]

            decode_q = self.q_b_proj(decode_q_c)
            bsz, _ = decode_q.shape
            decode_q = decode_q.view(bsz, self.num_heads, 1, self.qk_head_dim)
            decode_q_nope, decode_q_pe = torch.split(
                decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
                dim=-1)
            decode_q_nope = decode_q_nope.view(
                -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
            decode_q_nope = (torch.matmul(
                decode_q_nope, self.kv_b_proj_w_k).transpose(1, 0).view(
                    bsz, q_len, self.num_heads, self.kv_lora_rank))

            # stream2 kv
            key_cache = kv_cache[0]
            value_cache = kv_cache[1]
            cos = attn_metadata.decode.cos
            sin = attn_metadata.decode.sin
            cos_q, sin_q = cos, sin
            cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
            sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)

            decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze(1)
            decode_k_rope, decode_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
                decode_kv_no_split,
                self.kv_a_layernorm.weight,
                cos,
                sin,
                decode_slot_mapping.to(torch.int64),
                value_cache,
                key_cache,
                c_kv_scale=None,
                epsilon=self.kv_a_layernorm.variance_epsilon,
                cache_mode='PA')  # adapter NZ
            # nz_block_size = 16
            # KVCACHE_NZ_DIM = 16
            # decode_k_nope = decode_k_nope.view(block_num, 1, self.kv_lora_rank // nz_block_size, block_size, nz_block_size)
            # decode_k_rope = decode_k_rope.view(block_num, 1, self.qk_rope_head_dim // KVCACHE_NZ_DIM, block_size, KVCACHE_NZ_DIM)

            decode_q_pe = torch_npu.npu_interleave_rope(decode_q_pe, cos,
                                                        sin)  # BNSD

            decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
                                               self.kv_lora_rank)
            decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)

            topk_indices = self.indexer_select(hidden_states_decode,
                                               decode_q_c,
                                               attn_metadata=attn_metadata,
                                               kv_cache=kv_cache)

            query_states = (decode_q_nope, decode_q_pe)
            key_states = (decode_k_nope, decode_k_rope)
            decode_preprocess_res = DecodeSFAPreprocessResult(
                q_nope=decode_q_nope,
                q_pe=decode_q_pe,
                # nope_cache=nope_cache,
                # rope_cache=rope_cache,
                topk_indices=topk_indices,
                query_states=query_states,
                key_states=key_states,
                bsz=bsz,
            )

        # Preprocess for prefill tokens
        if has_prefill:
            bsz = 1

            hidden_states_prefill = hidden_states[
                num_decode_tokens:num_actual_tokens]
            prefill_kq = self.q_a_proj(hidden_states_prefill)  # q down
            prefill_q_c = self.q_a_layernorm(prefill_kq)  # q down layernorm
            prefill_kv_no_split = self.kv_a_proj_with_mqa(
                hidden_states_prefill)  # c_kv

            # prefill_q_c = q_c[
            #     num_decode_tokens:num_actual_tokens]
            prefill_slot_mapping = attn_metadata.slot_mapping[
                num_decode_tokens:num_actual_tokens]
            # decode_kv_no_split = decode_kv_no_split[:num_decode_tokens]

            prefill_slot_mapping = attn_metadata.slot_mapping[
                num_decode_tokens:num_actual_tokens]
            # prefill_kv_no_split = kv_no_split[
            #     num_decode_tokens:num_actual_tokens]
            # prefill_qr = prefill_q_c[num_decode_tokens:num_actual_tokens]
            prefill_qr = prefill_q_c
            prefill_q = self.q_b_proj(prefill_qr)
            prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim)
            prefill_q_nope, prefill_q_pe = torch.split(
                prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
                dim=-1)
            prefill_q_nope = prefill_q_nope.view(
                -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
            prefill_q_nope = (torch.matmul(
                prefill_q_nope, self.kv_b_proj_w_k).transpose(1, 0).view(
                    -1, self.num_heads, self.kv_lora_rank))
            prefill_q_pe = prefill_q_pe.unsqueeze(2)

            # stream2 kv

            nope_cache = kv_cache[0]
            rope_cache = kv_cache[1]
            cos = attn_metadata.prefill.cos
            sin = attn_metadata.prefill.sin
            cos_q, sin_q = cos, sin

            # cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
            # sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)

            prefill_q_pe = torch_npu.npu_interleave_rope(
                prefill_q_pe, cos_q, sin_q)  # BNSD
            prefill_q_pe = prefill_q_pe.squeeze(2)  # BSH
            # q[..., self.qk_nope_head_dim:] = prefill_q_pe  # TODO:????

            prefill_latent_cache = prefill_kv_no_split  # (B,S,N,D)
            prefill_k_pe, prefill_k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
                prefill_latent_cache.view(
                    -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
                self.kv_a_layernorm.weight,
                cos.view(-1, 1, 1, self.qk_rope_head_dim),
                sin.view(-1, 1, 1, self.qk_rope_head_dim),
                prefill_slot_mapping.to(torch.int64),
                rope_cache,
                nope_cache,
                k_rope_scale=None,
                c_kv_scale=None,
                k_rope_offset=None,
                c_kv_offset=None,
                epsilon=self.kv_a_layernorm.variance_epsilon,
                cache_mode="PA")

            topk_indices = self.indexer_select(x=hidden_states_prefill,
                                               qr=prefill_qr,
                                               kv_cache=kv_cache,
                                               attn_metadata=attn_metadata)
            query_states = (prefill_q_nope, prefill_q_pe)
            key_states = (prefill_k_nope, prefill_k_pe)
            prefill_preprocess_res = PrefillSFAPreprocessResult(
                q_nope=prefill_q_nope,
                q_pe=prefill_q_pe,
                topk_indices=topk_indices,
                k_nope=prefill_k_nope,
                k_pe=prefill_k_pe,
                query_states=query_states,
                key_states=key_states,
            )

        return decode_preprocess_res, prefill_preprocess_res

    def forward(
        self,
        hidden_states: torch.Tensor,  # query in unified attn
        kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        attn_metadata: M,
        need_gather_q_kv: bool = False,
        output: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        assert output is not None, "Output tensor must be provided."
        if attn_metadata is None:
            # Profiling run.
            return output
        num_actual_tokens = attn_metadata.num_actual_tokens
        assert attn_metadata.num_decodes is not None and \
            attn_metadata.num_prefills is not None and \
            attn_metadata.num_decode_tokens is not None
        num_decode_tokens = attn_metadata.num_decode_tokens
        # Inputs and outputs may be padded for CUDA graphs
        output = output[:num_actual_tokens, ...]
        o_proj_input_shape = (num_actual_tokens,
                              self.num_heads * self.v_head_dim)
        o_proj_input = torch.empty(o_proj_input_shape,
                                   dtype=hidden_states.dtype,
                                   device=hidden_states.device)

        # SFA Preprocess
        decode_preprocess_res, prefill_preprocess_res = self._sfa_preprocess(
            hidden_states, kv_cache, attn_metadata, need_gather_q_kv)

        if decode_preprocess_res is not None:
            # bsz, q_len, _, _ = query_states[0].shape
            decode_attn_output = self.apply_attention_fusion(
                query_states=decode_preprocess_res.query_states,
                key_states=decode_preprocess_res.key_states,
                attn_metadata=attn_metadata,
                topk_indices=decode_preprocess_res.topk_indices)
            o_proj_input[:num_decode_tokens] = decode_attn_output

        if prefill_preprocess_res is not None:
            prefill_attn_output = self.apply_attention_fusion(
                query_states=prefill_preprocess_res.query_states,
                key_states=prefill_preprocess_res.key_states,
                attn_metadata=attn_metadata,
                topk_indices=prefill_preprocess_res.topk_indices)
            o_proj_input[num_decode_tokens:] = prefill_attn_output

        output[...] = self.mla_epilog(o_proj_input, absorb=True)
        return output

    def apply_attention_fusion(self, query_states, key_states, topk_indices,
                               attn_metadata: M):
        # repeat k/v heads if n_kv_heads < n_heads
        q_nope, q_pe = query_states
        k_nope, k_rope = key_states

        if attn_metadata.prefill is not None:

            prefill_metadata = attn_metadata.prefill

            slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
                query=q_nope,
                key=k_nope,
                value=k_nope,
                sparse_indices=topk_indices,
                scale_value=self.scale,
                sparse_block_size=1,
                block_table=prefill_metadata.block_table,
                actual_seq_lengths_query=prefill_metadata.query_lens,
                actual_seq_lengths_kv=prefill_metadata.seq_lens,
                query_rope=q_pe,
                key_rope=k_rope,
                layout_query="TND",
                layout_kv="PA_BSND",
                sparse_mode=3,
            )

        elif attn_metadata.decode is not None:
            decode_metadata = attn_metadata.decode

            slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
                query=q_nope,
                key=k_nope,
                value=k_nope,
                sparse_indices=topk_indices,
                scale_value=self.scale,
                sparse_block_size=1,
                block_table=attn_metadata.decode.block_table,
                actual_seq_lengths_query=decode_metadata.actual_seq_lengths_q,
                actual_seq_lengths_kv=decode_metadata.seq_lens,
                query_rope=q_pe,
                key_rope=k_rope,
                layout_query="TND",
                layout_kv="PA_BSND",
                sparse_mode=3,
            )
            slc_fa_fusion = slc_fa_fusion.squeeze(1)

        slc_fa_fusion = slc_fa_fusion.transpose(0, 1)

        # input shape [N//attn_tp_size, T(bs*q_len), D]
        # output shape [T(bs*q_len), N//attn_tp_size, D]
        attn_output = torch.matmul(slc_fa_fusion,
                                   self.kv_b_proj_w_v).transpose(1, 0).reshape(
                                       -1, self.num_heads * self.v_head_dim)
        # Note: Considering the fusion rules of TBMM, attn_output shape requires a 3-dim shape, and
        # with appropriate tensor stride for the later 'view' operation if oproj_tp_size > 1.
        # after reshape: [T(bs*q_len), 1, N//attn_tp_size*D]
        # attn_output = attn_output.reshape(-1, self.num_heads * self.v_head_dim)

        return attn_output

    def mla_epilog(self,
                   attn_output: torch.Tensor = None,
                   absorb: bool = False):
        # TODO: need to check
        attn_output = self.o_proj(attn_output.reshape(attn_output.shape[0], -1),
                                  is_prefill=True,
                                  is_force_scatter=False)

        return attn_output

    def indexer_select(
        self,
        x: torch.Tensor,
        qr: torch.Tensor,
        kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
        attn_metadata: M,
    ):
        if attn_metadata.prefill is not None:
            cos = attn_metadata.prefill.cos
            sin = attn_metadata.prefill.sin
            actual_seq_lengths_query = attn_metadata.prefill.query_lens
            actual_seq_lengths_key = attn_metadata.prefill.seq_lens
            block_table = attn_metadata.prefill.block_table
        elif attn_metadata.decode is not None:
            cos = attn_metadata.decode.cos
            sin = attn_metadata.decode.sin
            actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
            actual_seq_lengths_key = attn_metadata.decode.seq_lens
            block_table = attn_metadata.decode.block_table

        cos_q, sin_q = cos, sin
        cos = cos.view(-1, 1, 1, self.qk_rope_head_dim)
        sin = sin.view(-1, 1, 1, self.qk_rope_head_dim)

        # q process in new stream
        q = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
        q = q.view(-1, self.n_heads, self.head_dim)  # [b,s,64,128]
        q_pe, q_nope = torch.split(
            q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
            dim=-1)  # [b,s,64,64+64]

        q_pe = q_pe.unsqueeze(2)
        q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
        q_pe = q_pe.squeeze(2)
        q = torch.cat([q_pe, q_nope], dim=-1)  # [b*s,64,128]

        k_proj = self.wk(x)  # [b,s,7168] @ [7168,128] = [b,s,128]
        k = self.k_norm(k_proj).unsqueeze(1)
        k_pe, k_nope = torch.split(
            k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
            dim=-1)  # [b,s,64+64]

        k_pe = k_pe.unsqueeze(2)
        k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
        k_pe = k_pe.squeeze(2)

        k = torch.cat([k_pe, k_nope], dim=-1)  # [b*s,128]

        if kv_cache is not None:
            torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
                                             attn_metadata.slot_mapping.view(-1, 1),
                                             k.view(-1, k.shape[-1]))  # b, s, n, d

        weights = self.weights_proj(x)

        topk_indices = torch.ops.custom.npu_lightning_indexer(
            query=q,
            key=kv_cache[2],
            weights=weights,
            actual_seq_lengths_query=actual_seq_lengths_query,
            actual_seq_lengths_key=actual_seq_lengths_key,
            block_table=block_table,
            layout_query="TND",
            layout_key="PA_BSND",
            sparse_count=2048,
            sparse_mode=3)
        return topk_indices
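The connector changes below distinguish MLA from the new SFA layout by the number of tensors each layer keeps in its KV cache. A small hedged sketch of that convention (tensor shapes are invented for illustration):

```python
import torch

# MLA stores (k_normed, k_pe); SFA adds a third per-token indexer cache (k_idx).
mla_cache = (torch.empty(8, 128, 1, 512), torch.empty(8, 128, 1, 64))
sfa_cache = (torch.empty(8, 128, 1, 512), torch.empty(8, 128, 1, 64),
             torch.empty(8, 128, 1, 128))

def detect(kv_cache_tuple):
    use_sfa = len(kv_cache_tuple) == 3
    use_mla = (not use_sfa
               and kv_cache_tuple[0].size(-1) != kv_cache_tuple[1].size(-1))
    return use_mla, use_sfa

assert detect(mla_cache) == (True, False)
assert detect(sfa_cache) == (False, True)
```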
@@ -493,8 +493,11 @@ class LLMDataDistCMgrConnectorWorker():
        assert self.local_agent_metadata is not None
        kv_cache_dtype = first_kv_cache.dtype
        self.use_mla: bool = first_kv_cache_tuple[0].size(
            -1) != first_kv_cache_tuple[1].size(-1)
            -1) != first_kv_cache_tuple[1].size(-1) and len(
                first_kv_cache_tuple) == 2
        self.use_sfa: bool = len(first_kv_cache_tuple) == 3
        # MLA case. [2 (k_normed, k_pe), num_blocks, ...]
        # SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
        # MHA case. [2 (k and v), num_blocks, ...]
        self.num_blocks = first_kv_cache.shape[0]
        block_rank = 3  # [block_size, latent_dim]
@@ -540,6 +543,58 @@ class LLMDataDistCMgrConnectorWorker():
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
                )
        elif self.use_sfa:
            cache_k_normed_addr_list = []
            cache_k_pe_addr_list = []
            cache_k_idx_addr_list = []
            k_normed = None
            k_pe = None
            k_idx = None
            for cache_or_caches in kv_caches.values():
                assert len(cache_or_caches) > 1
                k_normed, k_pe, k_idx = cache_or_caches[0], cache_or_caches[
                    1], cache_or_caches[2]
                cache_k_normed_addr_list.append(k_normed.data_ptr())
                cache_k_pe_addr_list.append(k_pe.data_ptr())
                cache_k_idx_addr_list.append(k_idx.data_ptr())
            self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list,
                               cache_k_idx_addr_list)

            cache_desc_k_normed = CacheDesc(
                len(self.cache_addr[0]), [*k_normed.shape],
                TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
            cache_desc_k_pe = CacheDesc(
                len(self.cache_addr[1]), [*k_pe.shape],
                TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
            cache_desc_k_idx = CacheDesc(
                len(self.cache_addr[2]), [*k_idx.shape],
                TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
            cache_key_k_normed = BlocksCacheKey(
                cluster_id=int(self.local_agent_metadata.cluster_id), model_id=0)
            cache_key_k_pe = BlocksCacheKey(
                cluster_id=int(self.local_agent_metadata.cluster_id), model_id=1)
            cache_key_k_idx = BlocksCacheKey(
                cluster_id=int(self.local_agent_metadata.cluster_id), model_id=2)
            self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe,
                               cache_desc_k_idx)
            self.cache_key = (cache_key_k_normed, cache_key_k_pe,
                              cache_key_k_idx)
            try:
                cache_k_normed = self.cache_manager.register_blocks_cache(
                    self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
                cache_k_pe = self.cache_manager.register_blocks_cache(
                    self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
                cache_k_idx = self.cache_manager.register_blocks_cache(
                    self.cache_desc[2], self.cache_addr[2], self.cache_key[2])
                self.cache = (cache_k_normed, cache_k_pe, cache_k_idx)
                logger.info("LLMDataDistWorker: End of register Paged Cache.")
            except (TypeError, ValueError):
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
                )
        else:
            for cache_or_caches in kv_caches.values():
                for cache in cache_or_caches:
@@ -826,6 +881,38 @@ class LLMDataDistCMgrConnectorWorker():
                raise RuntimeError(
                    "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
                )
        elif self.use_sfa:
            remote_cache_key_k_normed = BlocksCacheKey(
                cluster_id=remote_cluster_id, model_id=0)
            remote_cache_key_k_pe = BlocksCacheKey(
                cluster_id=remote_cluster_id, model_id=1)
            remote_cache_key_k_idx = BlocksCacheKey(
                cluster_id=remote_cluster_id, model_id=2)
            logger.info("Try pull blocks from remote server")
            try:
                self.cache_manager.pull_blocks(
                    remote_cache_key_k_normed,
                    self.cache[0],  # type: ignore[has-type]
                    remote_block_ids,
                    local_block_ids)
                self.cache_manager.pull_blocks(
                    remote_cache_key_k_pe,
                    self.cache[1],  # type: ignore[has-type]
                    remote_block_ids,
                    local_block_ids)
                self.cache_manager.pull_blocks(
                    remote_cache_key_k_idx,
                    self.cache[2],  # type: ignore[has-type]
                    remote_block_ids,
                    local_block_ids)
            except (TypeError, ValueError):
                raise RuntimeError(
                    f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}"  # type: ignore[has-type]
                )
            except LLMException:
                raise RuntimeError(
                    "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
                )
        else:
            remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
            logger.info("Try pull blocks from remote server")
@@ -30,6 +30,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
@@ -238,6 +239,7 @@ class KVCacheRecvingThread(threading.Thread):
        self.block_len = block_len
        # TODO(jianzs): find a better way to detect MLA.
        self.use_mla = len(block_len) == 2
        self.use_sfa = len(block_len) == 3

        self.request_queue: queue.Queue[Any] = queue.Queue()
        # TODO(jianzs): make this configurable
@@ -349,8 +351,12 @@ class KVCacheRecvingThread(threading.Thread):
        src_list, dst_list, length_list = [], [], []
        for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
                zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
            block_len = (self.block_len[k % 2]
                         if self.use_mla else self.block_len[0])
            if self.use_mla:
                block_len = (self.block_len[k % 2])
            elif self.use_sfa:
                block_len = (self.block_len[k % 3])
            else:
                block_len = (self.block_len[0])
            for i, remote_block_id in enumerate(grouped_remote_block_ids):
                local_block_ids = grouped_local_block_ids[i]
                src = src_layer_base_addr + local_block_ids[0] * block_len
@@ -567,6 +573,7 @@ class MooncakeConnectorScheduler:

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        self.vllm_config = vllm_config
        self.ascend_config = get_ascend_config()
        self.block_size = vllm_config.cache_config.block_size
        self.engine_id = engine_id
        logger.info("Initializing Mooncake Scheduler %s", engine_id)
@@ -726,7 +733,7 @@ class MooncakeConnectorScheduler:
        assert "tp_size" in decode_parallel_config.keys()
        self._decode_tp_size = decode_parallel_config["tp_size"]

        if self.vllm_config.model_config.use_mla:
        if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
            return self._decode_tp_size
        else:
            # TODO support mha and gqa
@@ -847,7 +854,9 @@ class MooncakeConnectorWorker:

        # TODO(tms): Find a more robust way to detect and handle MLA
        self.use_mla = first_kv_cache_tuple[0].size(
            -1) != first_kv_cache_tuple[1].size(-1)
            -1) != first_kv_cache_tuple[1].size(-1) and len(
                first_kv_cache_tuple) == 2
        self.use_sfa = len(first_kv_cache_tuple) == 3
        if self.use_mla:
            # MLA case. [num_block, block_size, 1, hidden_dim]
            self.num_blocks = first_kv_cache.shape[0]
@@ -861,6 +870,21 @@ class MooncakeConnectorWorker:
            logger.info(
                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
                self.num_blocks, block_shape_norm, block_shape_pe)
        elif self.use_sfa:
            self.num_blocks = first_kv_cache.shape[0]
            block_rank = 3  # [block_size, latent_dim]
            block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
            block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
            block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
            self.block_len = [
                first_kv_cache[0].element_size() * math.prod(block_shape_norm),
                first_kv_cache[1].element_size() * math.prod(block_shape_pe),
                first_kv_cache[2].element_size() * math.prod(block_shape_k)
            ]
            logger.info(
                "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
                self.num_blocks, block_shape_norm, block_shape_pe,
                block_shape_k)
        else:
            # [num_block, block_size, num_head, hidden_dim]
            self.num_blocks = first_kv_cache.shape[0]
@@ -871,8 +895,9 @@ class MooncakeConnectorWorker:
        logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                    block_shape)

        logger.info("Registering KV_Caches. use_mla: %s, shape %s",
                    self.use_mla, first_kv_cache.shape)
        logger.info(
            "Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
            self.use_mla, self.use_sfa, first_kv_cache.shape)

        self.kv_caches = kv_caches
        kv_caches_base_addr = []
@@ -884,9 +909,16 @@ class MooncakeConnectorWorker:
                    region_len = self.num_blocks * self.block_len[i % 2]
                    kv_caches_base_addr.append(base_addr)
                    self._register(base_addr, region_len)
            elif self.use_sfa:
                for i, cache in enumerate(cache_or_caches, 0):
                    base_addr = cache.data_ptr()
                    region_len = self.num_blocks * self.block_len[i % 3]
                    kv_caches_base_addr.append(base_addr)
                    self._register(base_addr, region_len)
            else:
                cache_list = [cache_or_caches
                              ] if self.use_mla else cache_or_caches
                cache_list = [
                    cache_or_caches
                ] if self.use_mla or self.use_sfa else cache_or_caches
                for cache in cache_list:
                    base_addr = cache.data_ptr()
                    region_len = self.num_blocks * self.block_len[0]
@@ -162,6 +162,13 @@ env_variables: Dict[str, Callable[[], Any]] = {
    # Whether to enable msMonitor tool to monitor the performance of vllm-ascend.
    "MSMONITOR_USE_DAEMON":
    lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))),
    # Timeout (in seconds) for delayed KVCache block release. In the prefill
    # node, if a request is marked for delayed KV block release and the blocks
    # are not freed within this timeout, they will be forcibly released.
    "VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT":
    lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
    "VLLM_ASCEND_ENABLE_MLAPO":
    lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
}

# end-env-vars-definition
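A hedged usage note for the new environment variables above; the values are examples, not recommendations, and they should be set in the serving process's environment before vLLM starts:

```python
import os

# Illustrative values only.
os.environ["VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT"] = "250"  # seconds before forced release
os.environ["VLLM_ASCEND_ENABLE_MLAPO"] = "0"                  # disabled by default
```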
@@ -37,6 +37,10 @@ def register_model():
        "DeepseekV3ForCausalLM",
        "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")

    ModelRegistry.register_model(
        "DeepseekV32ForCausalLM",
        "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")

    ModelRegistry.register_model(
        "DeepSeekMTPModel",
        "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
@@ -60,6 +60,8 @@ from vllm.model_executor.models.utils import (PPMissingLayer,

from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.mla import AscendMLAModules
from vllm_ascend.models.layers.sfa import (AscendSFAModules,
                                           AscendSparseFlashAttention, Indexer)
from vllm_ascend.ops.fused_moe import AscendFusedMoE
@@ -244,6 +246,180 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
        return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata)


class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention):

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % self.tp_size == 0
        self.num_local_heads = num_heads // self.tp_size
        self.layers = config.num_hidden_layers
        self.first_k_dense_replace = config.first_k_dense_replace

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.prefix = prefix
        self.debug_layer_idx = int(self.prefix.split(".")[-2])

        ascend_config = get_ascend_config()
        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

        if self.q_lora_rank is not None:
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_a_proj",
                return_bias=False,
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
                return_bias=False,
            )
        else:
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
                return_bias=False,
            )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
            return_bias=False,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
            return_bias=False,
        )
        self.o_proj = CustomDeepseekV2RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
            return_bias=False,
        )

        if rope_scaling:
            rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.dim: int = config.hidden_size  # 7168
        # TODO(zzzzwwjj): wait transformers add these params
        self.n_heads: int = 64  # 64
        self.head_dim: int = 128  # 128
        self.index_topk: int = 2048  # 2048
        self.indexer = Indexer(
            config,
            quant_config=quant_config,
            dim=self.dim,
            n_heads=self.n_heads,
            head_dim=self.head_dim,
            index_topk=self.index_topk,
            prefix=f"{prefix}.indexer",
        )

sfa_modules = AscendSFAModules(
|
||||
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
|
||||
q_a_layernorm=self.q_a_layernorm
|
||||
if self.q_lora_rank is not None else None,
|
||||
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
|
||||
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=self.kv_a_layernorm,
|
||||
kv_b_proj=self.kv_b_proj,
|
||||
o_proj=self.o_proj,
|
||||
rotary_emb=self.rotary_emb,
|
||||
indexer=self.indexer)
|
||||
|
||||
self.sfa_attn = AscendSparseFlashAttention(
|
||||
self.hidden_size,
|
||||
self.enable_shared_expert_dp,
|
||||
self.debug_layer_idx,
|
||||
self.first_k_dense_replace,
|
||||
self.tp_size,
|
||||
sfa_modules,
|
||||
self.num_local_heads,
|
||||
self.scaling,
|
||||
self.layers,
|
||||
self.kv_lora_rank,
|
||||
self.qk_rope_head_dim,
|
||||
self.q_lora_rank,
|
||||
self.qk_nope_head_dim,
|
||||
self.qk_head_dim,
|
||||
self.v_head_dim,
|
||||
cache_config,
|
||||
quant_config,
|
||||
prefix,
|
||||
)
|
||||
self.prefix = prefix
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata)
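For orientation, the head-dimension bookkeeping in the class above reduces to simple arithmetic. A minimal sketch, assuming the DeepSeek V3.2 defaults that appear later in this PR (qk_nope_head_dim=128, qk_rope_head_dim=64, kv_lora_rank=512):

# Hedged sketch, not part of the diff.
qk_nope_head_dim = 128
qk_rope_head_dim = 64
kv_lora_rank = 512

qk_head_dim = qk_nope_head_dim + qk_rope_head_dim      # 192: per-head query/key width
scaling = qk_head_dim ** -0.5                           # base softmax scale, before the YaRN mscale correction
kv_cache_head_size = kv_lora_rank + qk_rope_head_dim   # 576: latent vector + rope slice cached per token
print(qk_head_dim, round(scaling, 4), kv_cache_head_size)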
|
||||
|
||||
|
||||
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
||||
@@ -253,6 +429,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
parallel_config = vllm_config.parallel_config
|
||||
ascend_config = get_ascend_config()
|
||||
|
||||
self.hidden_size = config.hidden_size
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
@@ -268,7 +445,10 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
# TODO: enable mla in vllm-ascend
|
||||
if model_config.use_mla:
|
||||
attn_cls = CustomDeepseekV2MLAAttention
|
||||
if ascend_config.use_sfa:
|
||||
attn_cls = CustomDeepseekV2SFAAttention
|
||||
else:
|
||||
attn_cls = CustomDeepseekV2MLAAttention
|
||||
else:
|
||||
attn_cls = DeepseekV2Attention
|
||||
self.self_attn = attn_cls(
|
||||
|
||||
233
vllm_ascend/models/layers/sfa.py
Normal file

@@ -0,0 +1,233 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.utils import direct_register_custom_op
|
||||
|
||||
|
||||
@dataclass
|
||||
class AscendSFAModules:
|
||||
q_a_proj: Optional[torch.nn.Module]
|
||||
q_a_layernorm: Optional[torch.nn.Module]
|
||||
q_proj: Optional[torch.nn.Module]
|
||||
kv_a_proj_with_mqa: torch.nn.Module
|
||||
kv_a_layernorm: torch.nn.Module
|
||||
kv_b_proj: torch.nn.Module
|
||||
o_proj: torch.nn.Module
|
||||
rotary_emb: torch.nn.Module
|
||||
indexer: torch.nn.Module
|
||||
|
||||
|
||||
class AscendSparseFlashAttention(MultiHeadLatentAttention):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
enable_shared_expert_dp: bool,
|
||||
debug_layer_idx: int,
|
||||
first_k_dense_replace: int,
|
||||
tp_size: int,
|
||||
sfa_modules: AscendSFAModules,
|
||||
num_local_heads: int,
|
||||
scaling: float,
|
||||
layers: int,
|
||||
kv_lora_rank: int,
|
||||
qk_rope_head_dim: int,
|
||||
q_lora_rank: Optional[int],
|
||||
qk_nope_head_dim: int,
|
||||
qk_head_dim: int,
|
||||
v_head_dim: int,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
nn.Module.__init__(self)
|
||||
self.hidden_size = hidden_size
|
||||
self.enable_shared_expert_dp = enable_shared_expert_dp
|
||||
self.debug_layer_idx = debug_layer_idx
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.tp_size = tp_size
|
||||
self.num_local_heads = num_local_heads
|
||||
self.layers = layers
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.qk_head_dim = qk_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.prefix = prefix
|
||||
|
||||
self.sfa_attn = Attention(
|
||||
num_heads=self.num_local_heads,
|
||||
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
scale=scaling,
|
||||
num_kv_heads=1,
|
||||
cache_config=cache_config,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
use_mla=True,
|
||||
use_sfa=True,
|
||||
# SFA Args
|
||||
q_lora_rank=self.q_lora_rank,
|
||||
kv_lora_rank=self.kv_lora_rank,
|
||||
qk_nope_head_dim=self.qk_nope_head_dim,
|
||||
qk_rope_head_dim=self.qk_rope_head_dim,
|
||||
qk_head_dim=self.qk_head_dim,
|
||||
v_head_dim=self.v_head_dim,
|
||||
rotary_emb=sfa_modules.rotary_emb,
|
||||
q_a_proj=sfa_modules.q_a_proj,
|
||||
q_a_layernorm=sfa_modules.q_a_layernorm,
|
||||
q_proj=sfa_modules.q_proj,
|
||||
kv_a_proj_with_mqa=sfa_modules.kv_a_proj_with_mqa,
|
||||
kv_a_layernorm=sfa_modules.kv_a_layernorm,
|
||||
kv_b_proj=sfa_modules.kv_b_proj,
|
||||
o_proj=sfa_modules.o_proj,
|
||||
indexer=sfa_modules.indexer)
|
||||
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: Optional[torch.Tensor] = None,
|
||||
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
|
||||
num_tokens = hidden_states.shape[0]
|
||||
need_gather_q_kv = False
|
||||
if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
|
||||
# Simulate all gather to calculate output shape
|
||||
num_tokens = num_tokens * self.tp_size
|
||||
need_gather_q_kv = True
|
||||
if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
|
||||
output_shape = hidden_states.shape
|
||||
else:
|
||||
rows = num_tokens // self.tp_size
|
||||
if num_tokens % self.tp_size:
|
||||
rows += 1
|
||||
output_shape = (rows, hidden_states.shape[1])
|
||||
# FIXME: This does not seem right, should make sure the buffer is fixed
|
||||
output = torch.empty(output_shape,
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, output,
|
||||
self.prefix)
|
||||
output = output.view(-1, output_shape[-1])
|
||||
return output
|
||||
|
||||
|
||||
def sfa_forward(
|
||||
hidden_states: torch.Tensor,
|
||||
need_gather_q_kv: bool,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
if forward_context.attn_metadata:
|
||||
attn_metadata = forward_context.attn_metadata[self.sfa_attn.layer_name]
|
||||
else:
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
kv_cache = self.sfa_attn.kv_cache[forward_context.virtual_engine]
|
||||
self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
|
||||
need_gather_q_kv, output)
|
||||
return
|
||||
|
||||
|
||||
class Indexer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
dim: int = 7168,
|
||||
n_heads: int = 64,
|
||||
head_dim: int = 128,
|
||||
index_topk: int = 2048,
|
||||
q_lora_rank: int = 1536,
|
||||
rope_head_dim: int = 64,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: Optional[str] = ""):
|
||||
super().__init__()
|
||||
|
||||
self.dim: int = dim # 7168
|
||||
self.n_heads: int = n_heads # 64
|
||||
self.head_dim: int = head_dim # 128
|
||||
self.rope_head_dim: int = rope_head_dim # 64
|
||||
self.index_topk: int = index_topk # 2048
|
||||
self.q_lora_rank: int = q_lora_rank # 1536
|
||||
self.wq_b = ReplicatedLinear(
|
||||
self.q_lora_rank,
|
||||
self.n_heads * self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wq_b",
|
||||
return_bias=False,
|
||||
)
|
||||
self.wk = ReplicatedLinear(
|
||||
self.dim,
|
||||
self.head_dim,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.wk",
|
||||
return_bias=False,
|
||||
)
|
||||
self.weights_proj = ReplicatedLinear(
|
||||
self.dim,
|
||||
self.n_heads,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.weights_proj",
|
||||
return_bias=False,
|
||||
)
|
||||
self.k_norm = nn.LayerNorm(self.head_dim)
|
||||
self.softmax_scale = self.head_dim**-0.5
|
||||
|
||||
def forward(self):
|
||||
return
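The Indexer above only declares its projections (its forward pass is driven from inside the SFA attention backend), so the widths implied by the defaults are easy to sanity-check. A minimal sketch, assuming the constructor defaults shown above:

# Hedged sketch, not part of the diff.
dim, n_heads, head_dim, q_lora_rank = 7168, 64, 128, 1536

wq_b_out = n_heads * head_dim   # 8192: 64 index-query heads of width 128 per token
wk_out = head_dim               # 128: one shared index key per token
weights_out = n_heads           # 64: per-head mixing weights from weights_proj
softmax_scale = head_dim ** -0.5
print(wq_b_out, wk_out, weights_out, round(softmax_scale, 4))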
|
||||
|
||||
|
||||
def sfa_forward_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
need_gather_q_kv: bool,
|
||||
output: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="sfa_forward",
|
||||
op_func=sfa_forward,
|
||||
mutates_args=["output"],
|
||||
fake_impl=sfa_forward_fake,
|
||||
dispatch_key="PrivateUse1",
|
||||
)
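For context, registering sfa_forward as a custom op lets the attention call be treated as an opaque node under graph capture, with the result written into a caller-owned buffer (mutates_args=["output"]) rather than returned. A minimal sketch of the calling convention, with placeholder shapes and a hypothetical layer name:

import torch

# Hedged sketch, not part of the diff: the real call needs an initialized
# forward context and NPU kv caches, so it is left commented out here.
hidden_states = torch.randn(4, 7168)        # placeholder activations
output = torch.empty_like(hidden_states)    # buffer the op fills in place
# torch.ops.vllm.sfa_forward(hidden_states, False, output, "model.layers.3.self_attn")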
|
||||
@@ -15,6 +15,10 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import vllm_ascend.patch.platform.patch_common.patch_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
|
||||
import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa
|
||||
import vllm_ascend.patch.platform.patch_common.patch_multimodal_merge # noqa
|
||||
import vllm_ascend.patch.platform.patch_common.patch_transformers_utils # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa
|
||||
|
||||
313
vllm_ascend/patch/platform/patch_common/patch_config.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import ast
|
||||
|
||||
import vllm.envs as envs
|
||||
from transformers import PretrainedConfig
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.config.speculative import SpeculativeConfig
|
||||
from vllm.logger import logger
|
||||
|
||||
|
||||
# mypy: ignore-errors
|
||||
@property
|
||||
def is_deepseek_mla(self: ModelConfig):
|
||||
if not hasattr(self.hf_text_config, "model_type"):
|
||||
return False
|
||||
elif self.hf_text_config.model_type in \
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
|
||||
'kimi_k2', 'longcat_flash', 'deepseek_v32'):
|
||||
return self.hf_text_config.kv_lora_rank is not None
|
||||
elif self.hf_text_config.model_type == 'eagle':
|
||||
# if the model is an EAGLE module, check for the
|
||||
# underlying architecture
|
||||
return self.hf_text_config.model.model_type in \
|
||||
('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \
|
||||
and self.hf_text_config.kv_lora_rank is not None
|
||||
return False
|
||||
|
||||
|
||||
@staticmethod
|
||||
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
|
||||
if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
|
||||
hf_config.model_type = "deepseek_mtp"
|
||||
if hf_config.model_type == "deepseek_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["DeepSeekMTPModel"]
|
||||
})
|
||||
|
||||
if hf_config.architectures[0] == "MiMoForCausalLM":
|
||||
hf_config.model_type = "mimo_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["MiMoMTPModel"]
|
||||
})
|
||||
|
||||
if hf_config.architectures[0] == "Glm4MoeForCausalLM":
|
||||
hf_config.model_type = "glm4_moe_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"num_hidden_layers": 0,
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Glm4MoeMTPModel"]
|
||||
})
|
||||
|
||||
if hf_config.model_type == "ernie4_5_moe":
|
||||
hf_config.model_type = "ernie_mtp"
|
||||
if hf_config.model_type == "ernie_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["ErnieMTPModel"]
|
||||
})
|
||||
|
||||
if hf_config.model_type == "qwen3_next":
|
||||
hf_config.model_type = "qwen3_next_mtp"
|
||||
if hf_config.model_type == "qwen3_next_mtp":
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", None)
|
||||
hf_config.update({
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["Qwen3NextMTP"]
|
||||
})
|
||||
if hf_config.model_type == "longcat_flash":
|
||||
hf_config.model_type = "longcat_flash_mtp"
|
||||
n_predict = getattr(hf_config, "num_nextn_predict_layers", 1)
|
||||
hf_config.update({
|
||||
"n_predict": n_predict,
|
||||
"architectures": ["LongCatFlashMTPModel"]
|
||||
})
|
||||
|
||||
return hf_config
|
||||
|
||||
|
||||
def __post_init__(self):
|
||||
|
||||
# Note: "method" is a new parameter that helps to extend the
|
||||
# configuration of non-model-based proposers, and the "model" parameter
|
||||
# will be used to set the draft model, eagle head, or additional weight
|
||||
# when needed. If users do not specify "method", the speculative method
|
||||
# will be detected automatically if possible. If the speculative method
|
||||
# can not be detected, it will be considered as the "draft_model" by
|
||||
# default.
|
||||
|
||||
if self.model is None and self.num_speculative_tokens is not None:
|
||||
# TODO(Shangming): Refactor mtp configuration logic when supporting
|
||||
if (self.target_model_config
|
||||
and self.target_model_config.hf_text_config.model_type
|
||||
in ("deepseek_v3", "deepseek_v32", "mimo", "ernie4_5_moe",
|
||||
"qwen3_next")):
|
||||
# use the draft model from the same model:
|
||||
self.model = self.target_model_config.model
|
||||
# Align the quantization of draft model for cases such as
|
||||
# --quantization fp8 with a bf16 checkpoint.
|
||||
if not self.quantization:
|
||||
self.quantization = self.target_model_config.quantization
|
||||
elif self.method in ("ngram", "[ngram]"):
|
||||
self.model = "ngram"
|
||||
else:
|
||||
raise ValueError("num_speculative_tokens was provided but without "
|
||||
"speculative model.")
|
||||
|
||||
# Automatically configure the method for ngram when "model" is used
|
||||
# instead of "method"
|
||||
if self.method is None and (self.model is not None
|
||||
and self.model in ("ngram", "[ngram]")):
|
||||
self.method = "ngram"
|
||||
|
||||
if self.method in ("ngram", "[ngram]"):
|
||||
# Unified to "ngram" internally
|
||||
self.method = "ngram"
|
||||
# Set default values if not provided
|
||||
if (self.prompt_lookup_min is None and self.prompt_lookup_max is None):
|
||||
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
|
||||
self.prompt_lookup_min = 5
|
||||
self.prompt_lookup_max = 5
|
||||
elif self.prompt_lookup_min is None:
|
||||
assert self.prompt_lookup_max is not None
|
||||
self.prompt_lookup_min = self.prompt_lookup_max
|
||||
elif self.prompt_lookup_max is None:
|
||||
assert self.prompt_lookup_min is not None
|
||||
self.prompt_lookup_max = self.prompt_lookup_min
|
||||
|
||||
# Validate values
|
||||
if self.prompt_lookup_min < 1:
|
||||
raise ValueError(
|
||||
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
|
||||
if self.prompt_lookup_max < 1:
|
||||
raise ValueError(
|
||||
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
|
||||
if self.prompt_lookup_min > self.prompt_lookup_max:
|
||||
raise ValueError(
|
||||
f"prompt_lookup_min={self.prompt_lookup_min} must "
|
||||
f"be <= prompt_lookup_max={self.prompt_lookup_max}")
|
||||
|
||||
# TODO: current we still need extract vocab_size from target model
|
||||
# config, in future, we may try refactor it out, and set
|
||||
# draft related config as None here.
|
||||
self.draft_model_config = self.target_model_config
|
||||
self.draft_parallel_config = self.target_parallel_config
|
||||
else:
|
||||
self.prompt_lookup_max = 0
|
||||
self.prompt_lookup_min = 0
|
||||
|
||||
if self.model is not None:
|
||||
# TODO: Move this import to the top once `ModelConfig`
|
||||
# lives in `vllm.config.model`.
|
||||
from vllm.config import ModelConfig
|
||||
self.draft_model_config = ModelConfig(
|
||||
model=self.model,
|
||||
runner="draft",
|
||||
tokenizer=self.target_model_config.tokenizer,
|
||||
tokenizer_mode=self.target_model_config.tokenizer_mode,
|
||||
trust_remote_code=self.target_model_config.trust_remote_code,
|
||||
allowed_local_media_path=self.target_model_config.
|
||||
allowed_local_media_path,
|
||||
allowed_media_domains=self.target_model_config.
|
||||
allowed_media_domains,
|
||||
dtype=self.target_model_config.dtype,
|
||||
seed=self.target_model_config.seed,
|
||||
revision=self.revision,
|
||||
code_revision=self.code_revision,
|
||||
tokenizer_revision=self.target_model_config.tokenizer_revision,
|
||||
spec_target_max_model_len=self.target_model_config.
|
||||
max_model_len,
|
||||
quantization=self.quantization,
|
||||
enforce_eager=self.target_model_config.enforce_eager,
|
||||
max_logprobs=self.target_model_config.max_logprobs,
|
||||
hf_overrides=SpeculativeConfig.hf_config_override,
|
||||
)
|
||||
|
||||
# Automatically detect the method
|
||||
if self.method in ('eagle', 'eagle3'):
|
||||
pass
|
||||
# examples:
|
||||
# yuhuili/EAGLE-LLaMA3-Instruct-8B
|
||||
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
|
||||
# AngelSlim/Qwen3-8B_eagle3
|
||||
elif "eagle-" in self.draft_model_config.model.lower():
|
||||
self.method = "eagle"
|
||||
elif "eagle3" in self.draft_model_config.model.lower():
|
||||
self.method = "eagle3"
|
||||
elif self.draft_model_config.hf_config.model_type == "medusa":
|
||||
self.method = "medusa"
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"mlp_speculator"):
|
||||
self.method = "mlp_speculator"
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
|
||||
self.method = "deepseek_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Deepseek MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type == "ernie_mtp"):
|
||||
self.method = "ernie_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Ernie MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type ==
|
||||
"qwen3_next_mtp"):
|
||||
self.method = "qwen3_next_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"All Qwen3Next MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
elif (self.draft_model_config.hf_config.model_type
|
||||
in ("longcat_flash_mtp")):
|
||||
self.method = "longcat_flash_mtp"
|
||||
if self.num_speculative_tokens > 1:
|
||||
logger.warning(
|
||||
"LongCat MTP models only have " \
|
||||
"one layer. Might need some code changes " \
|
||||
"to support multiple layers."
|
||||
)
|
||||
else:
|
||||
self.method = "draft_model"
|
||||
raise NotImplementedError(
|
||||
"Speculative decoding with draft model is not "
|
||||
"supported yet. Please consider using other "
|
||||
"speculative decoding methods such as ngram, medusa, "
|
||||
"eagle, or deepseek_mtp.")
|
||||
|
||||
# Replace hf_config for EAGLE draft_model
|
||||
if self.method in ("eagle", "eagle3"):
|
||||
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
|
||||
raise ValueError(
|
||||
"Chunked prefill and EAGLE are not compatible "
|
||||
"when using V0.")
|
||||
|
||||
from vllm.transformers_utils.configs import SpeculatorsConfig
|
||||
from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
||||
|
||||
if isinstance(self.draft_model_config.hf_config,
|
||||
(EAGLEConfig, SpeculatorsConfig)):
|
||||
pass
|
||||
else:
|
||||
eagle_config = EAGLEConfig(
|
||||
self.draft_model_config.hf_config,
|
||||
method=self.method,
|
||||
model_type="eagle")
|
||||
self.draft_model_config.hf_config = eagle_config
|
||||
|
||||
if (self.num_speculative_tokens is not None
|
||||
and hasattr(self.draft_model_config.hf_config,
|
||||
"num_lookahead_tokens")):
|
||||
self.draft_model_config.hf_config.num_lookahead_tokens = \
|
||||
self.num_speculative_tokens
|
||||
|
||||
n_predict = getattr(self.draft_model_config.hf_config, "n_predict",
|
||||
None)
|
||||
if n_predict is not None:
|
||||
if self.num_speculative_tokens is None:
|
||||
# Default to max value defined in draft model config.
|
||||
self.num_speculative_tokens = n_predict
|
||||
elif self.num_speculative_tokens > n_predict and \
|
||||
self.num_speculative_tokens % n_predict != 0:
|
||||
# Ensure divisibility for MTP module reuse.
|
||||
raise ValueError(
|
||||
f"num_speculative_tokens:{self.num_speculative_tokens}"
|
||||
f" must be divisible by {n_predict=}")
|
||||
|
||||
if self.speculative_token_tree is None:
|
||||
# Generate chain of tokens.
|
||||
self.speculative_token_tree = str([
|
||||
(i + 1) * (0, ) for i in range(self.num_speculative_tokens)
|
||||
])
|
||||
else:
|
||||
# Sort the token tree breadth-first.
|
||||
tree_choices = ast.literal_eval(self.speculative_token_tree)
|
||||
self.speculative_token_tree = str(
|
||||
sorted(tree_choices, key=lambda t: (len(t), t)))
|
||||
|
||||
self.draft_tensor_parallel_size = \
|
||||
SpeculativeConfig._verify_and_get_draft_tp(
|
||||
self.target_parallel_config,
|
||||
self.draft_tensor_parallel_size,
|
||||
self.draft_model_config.hf_config
|
||||
)
|
||||
|
||||
self.draft_model_config.max_model_len = (
|
||||
SpeculativeConfig._maybe_override_draft_max_model_len(
|
||||
self.max_model_len,
|
||||
self.draft_model_config.max_model_len,
|
||||
self.target_model_config.max_model_len,
|
||||
))
|
||||
|
||||
self.draft_parallel_config = (
|
||||
SpeculativeConfig.create_draft_parallel_config(
|
||||
self.target_parallel_config,
|
||||
self.draft_tensor_parallel_size))
|
||||
|
||||
|
||||
ModelConfig.is_deepseek_mla = is_deepseek_mla
|
||||
SpeculativeConfig.__post_init__ = __post_init__
|
||||
SpeculativeConfig.hf_config_override = hf_config_override
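The net effect of the three assignments above is that DeepSeek V3.2 checkpoints are treated as MLA-style models and their MTP weights are recognized for speculative decoding. A minimal sketch of the property check, using a stand-in namespace instead of a real ModelConfig:

from types import SimpleNamespace

# Hedged sketch, not part of the diff: SimpleNamespace mimics the parsed
# hf_text_config of a deepseek_v32 checkpoint.
fake_cfg = SimpleNamespace(
    hf_text_config=SimpleNamespace(model_type="deepseek_v32", kv_lora_rank=512))
assert is_deepseek_mla.fget(fake_cfg) is True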
|
||||
@@ -6,6 +6,8 @@ from vllm.model_executor.models.config import MambaModelConfig
|
||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
|
||||
@classmethod
|
||||
def verify_and_update_config(cls, vllm_config) -> None:
|
||||
@@ -22,6 +24,7 @@ def verify_and_update_config(cls, vllm_config) -> None:
|
||||
logger = init_logger(__name__)
|
||||
# Enable FULL_AND_PIECEWISE by default
|
||||
MambaModelConfig.verify_and_update_config(vllm_config)
|
||||
ascend_config = get_ascend_config()
|
||||
|
||||
cache_config = vllm_config.cache_config
|
||||
model_config = vllm_config.model_config
|
||||
@@ -38,7 +41,7 @@ def verify_and_update_config(cls, vllm_config) -> None:
|
||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||
head_size=model_config.get_head_size(),
|
||||
dtype=kv_cache_dtype,
|
||||
use_mla=model_config.use_mla).page_size_bytes
|
||||
use_mla=model_config.use_mla or ascend_config.use_sfa).page_size_bytes
|
||||
|
||||
model_cls, _ = ModelRegistry.resolve_model_cls(
|
||||
model_config.architecture,
|
||||
|
||||
@@ -0,0 +1,200 @@
|
||||
import vllm.transformers_utils.configs
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
from vllm.transformers_utils import config
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class DeepseekV3Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the DeepSeek-V3.
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 129280):
|
||||
Vocabulary size of the DeepSeek model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`DeepseekV3Model`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
moe_intermediate_size (`int`, *optional*, defaults to 1407):
|
||||
Dimension of the MoE representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer decoder.
|
||||
num_nextn_predict_layers (`int`, *optional*, defaults to 1):
|
||||
Number of nextn predict layers in the DeepSeekV3 Model.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
n_shared_experts (`int`, *optional*, defaults to None):
|
||||
Number of shared experts, None means dense model.
|
||||
n_routed_experts (`int`, *optional*, defaults to None):
|
||||
Number of routed experts, None means dense model.
|
||||
routed_scaling_factor (`float`, *optional*, defaults to 1.0):
|
||||
Scaling factor for routed experts.
|
||||
topk_method (`str`, *optional*, defaults to `greedy`):
|
||||
Topk method used in routed gate.
|
||||
n_group (`int`, *optional*, defaults to None):
|
||||
Number of groups for routed experts.
|
||||
topk_group (`int`, *optional*, defaults to None):
|
||||
Number of selected groups for each token (ensuring the selected experts are only within `topk_group` groups).
|
||||
num_experts_per_tok (`int`, *optional*, defaults to None):
|
||||
Number of selected experts, None means dense model.
|
||||
moe_layer_freq (`int`, *optional*, defaults to 1):
|
||||
The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers.
|
||||
first_k_dense_replace (`int`, *optional*, defaults to 0):
|
||||
Number of dense layers in shallow layers (embed->dense->dense->...->dense->moe->moe...->lm_head).
|
||||
\--k dense layers--/
|
||||
norm_topk_prob (`bool`, *optional*, defaults to False):
|
||||
Whether to normalize the weights of the routed experts.
|
||||
scoring_func (`str`, *optional*, defaults to 'softmax'):
|
||||
Method of computing expert weights.
|
||||
aux_loss_alpha (`float`, *optional*, defaults to 0.001):
|
||||
Auxiliary loss weight coefficient.
|
||||
seq_aux (`bool`, *optional*, defaults to `True`):
|
||||
Whether to compute the auxiliary loss for each individual sample.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA), otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
pad_token_id (`int`, *optional*):
|
||||
Padding token id.
|
||||
bos_token_id (`int`, *optional*, defaults to 1):
|
||||
Beginning of stream token id.
|
||||
eos_token_id (`int`, *optional*, defaults to 2):
|
||||
End of stream token id.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling
|
||||
strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is
|
||||
`{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update
|
||||
`max_position_embeddings` to the expected new maximum.
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
```python
|
||||
>>> from transformers import DeepseekV3Model, DeepseekV3Config
|
||||
>>> # Initializing a Deepseek-V3 style configuration
|
||||
>>> configuration = DeepseekV3Config()
>>> # Initializing a model from the configuration
>>> model = DeepseekV3Model(configuration)
>>> # Accessing the model configuration
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "deepseek_v3"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=129280,
|
||||
hidden_size=7168,
|
||||
intermediate_size=18432,
|
||||
moe_intermediate_size=2048,
|
||||
num_hidden_layers=61,
|
||||
num_nextn_predict_layers=1,
|
||||
num_attention_heads=128,
|
||||
num_key_value_heads=128,
|
||||
n_shared_experts=1,
|
||||
n_routed_experts=256,
|
||||
ep_size=1,
|
||||
routed_scaling_factor=2.5,
|
||||
kv_lora_rank=512,
|
||||
q_lora_rank=1536,
|
||||
qk_rope_head_dim=64,
|
||||
v_head_dim=128,
|
||||
qk_nope_head_dim=128,
|
||||
topk_method='noaux_tc',
|
||||
n_group=8,
|
||||
topk_group=4,
|
||||
num_experts_per_tok=8,
|
||||
moe_layer_freq=1,
|
||||
first_k_dense_replace=3,
|
||||
norm_topk_prob=True,
|
||||
scoring_func='sigmoid',
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=4096,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=None,
|
||||
bos_token_id=0,
|
||||
eos_token_id=1,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000.0,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.moe_intermediate_size = moe_intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_nextn_predict_layers = num_nextn_predict_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.n_shared_experts = n_shared_experts
|
||||
self.n_routed_experts = n_routed_experts
|
||||
self.ep_size = ep_size
|
||||
self.routed_scaling_factor = routed_scaling_factor
|
||||
self.kv_lora_rank = kv_lora_rank
|
||||
self.q_lora_rank = q_lora_rank
|
||||
self.qk_rope_head_dim = qk_rope_head_dim
|
||||
self.v_head_dim = v_head_dim
|
||||
self.qk_nope_head_dim = qk_nope_head_dim
|
||||
self.topk_method = topk_method
|
||||
self.n_group = n_group
|
||||
self.topk_group = topk_group
|
||||
self.num_experts_per_tok = num_experts_per_tok
|
||||
self.moe_layer_freq = moe_layer_freq
|
||||
self.first_k_dense_replace = first_k_dense_replace
|
||||
self.norm_topk_prob = norm_topk_prob
|
||||
self.scoring_func = scoring_func
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
vllm.transformers_utils.configs.__all__.append("DeepseekV3Config")
|
||||
vllm.transformers_utils.configs.DeepseekV3Config = DeepseekV3Config
|
||||
config._CONFIG_REGISTRY["deepseek_v32"] = "DeepseekV3Config"
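The three lines above are what let a checkpoint whose config.json declares model_type "deepseek_v32" be parsed without trust_remote_code: the model type is routed through vLLM's config registry to the DeepseekV3Config defined in this file. A minimal sketch of the resulting lookup, assuming the patch module has been imported:

import vllm.transformers_utils.configs
from vllm.transformers_utils import config

# Hedged sketch, not part of the diff.
assert config._CONFIG_REGISTRY["deepseek_v32"] == "DeepseekV3Config"
assert hasattr(vllm.transformers_utils.configs, "DeepseekV3Config")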
|
||||
@@ -20,6 +20,10 @@ from vllm.triton_utils import HAS_TRITON
|
||||
if HAS_TRITON:
|
||||
import vllm_ascend.patch.worker.patch_common.patch_triton
|
||||
|
||||
# isort: off
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_attention_layer # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_weight_loader # noqa
|
||||
|
||||
202
vllm_ascend/patch/worker/patch_common/patch_attention_layer.py
Normal file
@@ -0,0 +1,202 @@
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
from torch import nn
|
||||
from vllm.attention import Attention, AttentionType, get_attn_backend
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import backend_name_to_enum
|
||||
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||
from vllm.config import CacheConfig, get_current_vllm_config
|
||||
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.quantization.base_config import \
|
||||
QuantizationConfig
|
||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
|
||||
class AscendAttention(Attention, nn.Module, AttentionLayerBase):
|
||||
"""Attention layer.
|
||||
|
||||
This class takes query, key, and value tensors as input. The input tensors
|
||||
can either contain prompt tokens or generation tokens.
|
||||
The class does the following:
|
||||
|
||||
1. Store the input key and value tensors in the KV cache.
|
||||
2. Perform (multi-head/multi-query/grouped-query) attention.
|
||||
3. Return the output tensor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
head_size: int,
|
||||
scale: float,
|
||||
num_kv_heads: Optional[int] = None,
|
||||
alibi_slopes: Optional[List[float]] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
logits_soft_cap: Optional[float] = None,
|
||||
per_layer_sliding_window: Optional[int] = None,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
prefix: str = "",
|
||||
attn_type: str = AttentionType.DECODER,
|
||||
kv_sharing_target_layer_name: Optional[str] = None,
|
||||
attn_backend: Optional[type[AttentionBackend]] = None,
|
||||
**extra_impl_args,
|
||||
) -> None:
|
||||
"""
|
||||
The KV cache is stored inside this class and is accessed via
|
||||
`self.kv_cache`.
|
||||
"""
|
||||
nn.Module.__init__(self)
|
||||
AttentionLayerBase.__init__(self)
|
||||
|
||||
if per_layer_sliding_window is not None:
|
||||
# per-layer sliding window
|
||||
sliding_window = per_layer_sliding_window
|
||||
elif cache_config is not None:
|
||||
# model-level sliding window
|
||||
sliding_window = cache_config.sliding_window
|
||||
else:
|
||||
sliding_window = None
|
||||
|
||||
if cache_config is not None:
|
||||
kv_cache_dtype = cache_config.cache_dtype
|
||||
block_size = cache_config.block_size
|
||||
is_attention_free = cache_config.is_attention_free
|
||||
calculate_kv_scales = cache_config.calculate_kv_scales
|
||||
else:
|
||||
kv_cache_dtype = "auto"
|
||||
block_size = 16
|
||||
is_attention_free = False
|
||||
calculate_kv_scales = False
|
||||
if num_kv_heads is None:
|
||||
num_kv_heads = num_heads
|
||||
assert num_heads % num_kv_heads == 0, \
|
||||
f"num_heads ({num_heads}) is not " \
|
||||
f"divisible by num_kv_heads ({num_kv_heads})"
|
||||
|
||||
# The default k/v_scale is set to 1.0. This is ignored
|
||||
# when kv-cache is not fp8, and should be used with
|
||||
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
|
||||
# expect the pre-quantized k/v_scale to be loaded along
|
||||
# with the model weights.
|
||||
self.kv_cache_dtype = kv_cache_dtype
|
||||
self.calculate_kv_scales = calculate_kv_scales
|
||||
self._k_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
self._v_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
# FlashAttn doesn't support quantizing the kv-cache only
|
||||
# but requires q to be quantized as well.
|
||||
self._q_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
|
||||
|
||||
# We also keep q/k/v_scale on host (cpu) memory for attention
|
||||
# backends that require the scales to be on host instead of on device.
|
||||
# e.g. Flashinfer
|
||||
self._q_scale_float = 1.0
|
||||
self._k_scale_float = 1.0
|
||||
self._v_scale_float = 1.0
|
||||
|
||||
# The output scale on host memory. This should be the input scale of
|
||||
# the quant op after this attention layer.
|
||||
self._o_scale_float: Optional[float] = None
|
||||
|
||||
self.use_mla = use_mla
|
||||
self.num_heads = num_heads
|
||||
self.head_size = head_size
|
||||
self.num_kv_heads = num_kv_heads
|
||||
self.sliding_window = sliding_window
|
||||
self.has_sink = extra_impl_args.get("sinks") is not None
|
||||
|
||||
quant_method = quant_config.get_quant_method(
|
||||
self, prefix=prefix) if quant_config else None
|
||||
if quant_method is not None and not isinstance(
|
||||
quant_method, UnquantizedLinearMethod):
|
||||
assert isinstance(quant_method, BaseKVCacheMethod)
|
||||
# TODO (mgoin): kv cache dtype should be specified in the FP8
|
||||
# checkpoint config and become the "auto" behavior
|
||||
if self.kv_cache_dtype == "fp8_e5m2":
|
||||
raise ValueError("fp8_e5m2 kv-cache is not supported with "
|
||||
"fp8 checkpoints.")
|
||||
# If quantization is enabled, we make "k_scale" and "v_scale"
|
||||
# parameters so that it can be loaded from the model checkpoint.
|
||||
# The k/v_scale will then be converted back to native float32
|
||||
# values after weight loading.
|
||||
self.quant_method = quant_method
|
||||
self.quant_method.create_weights(self)
|
||||
|
||||
# During model initialization, the default dtype is set as the model
|
||||
# weight and activation dtype.
|
||||
dtype = torch.get_default_dtype()
|
||||
if attn_backend is None:
|
||||
if vllm_version_is("0.10.2"):
|
||||
self.attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
is_attention_free,
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa,
|
||||
has_sink=self.has_sink)
|
||||
else:
|
||||
self.attn_backend = get_attn_backend(head_size,
|
||||
dtype,
|
||||
kv_cache_dtype,
|
||||
block_size,
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa,
|
||||
has_sink=self.has_sink)
|
||||
else:
|
||||
self.attn_backend = attn_backend
|
||||
|
||||
impl_cls = self.attn_backend.get_impl_cls()
|
||||
self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
|
||||
alibi_slopes, sliding_window, kv_cache_dtype,
|
||||
logits_soft_cap, attn_type,
|
||||
kv_sharing_target_layer_name, **extra_impl_args)
|
||||
self.backend = backend_name_to_enum(self.attn_backend.get_name())
|
||||
self.dtype = dtype
|
||||
|
||||
# For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
|
||||
# torch.compile works by registering the attention as one giant
|
||||
# opaque custom op. For other platforms, we directly call them
|
||||
# and let torch.compile handle them.
|
||||
self.use_direct_call = not current_platform.opaque_attention_op()
|
||||
|
||||
self.use_output = self.attn_backend.accept_output_buffer
|
||||
compilation_config = get_current_vllm_config().compilation_config
|
||||
if prefix in compilation_config.static_forward_context:
|
||||
raise ValueError(f"Duplicate layer name: {prefix}")
|
||||
compilation_config.static_forward_context[prefix] = self
|
||||
self.layer_name = prefix
|
||||
self.attn_type = attn_type
|
||||
|
||||
if kv_sharing_target_layer_name is not None:
|
||||
validate_kv_sharing_target(
|
||||
prefix,
|
||||
kv_sharing_target_layer_name,
|
||||
compilation_config.static_forward_context,
|
||||
)
|
||||
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
|
||||
|
||||
# use a placeholder kv cache tensor during init, which will be replaced
|
||||
# by bind_kv_cache
|
||||
# this variable will not be accessed if use_direct_call is True
|
||||
self.kv_cache = [
|
||||
torch.tensor([]) for _ in range(get_current_vllm_config(
|
||||
).parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
|
||||
self.query_quant = None
|
||||
|
||||
|
||||
vllm.attention.Attention = AscendAttention
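Replacing vllm.attention.Attention at import time means every model built afterwards constructs the Ascend subclass, so the extra use_sfa keyword threaded through this PR is accepted without touching upstream model code. A minimal sketch of the observable effect, assuming this patch module has already been imported:

import vllm

# Hedged sketch, not part of the diff.
assert vllm.attention.Attention is AscendAttention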
|
||||
@@ -0,0 +1,181 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# mypy: ignore-errors
|
||||
from functools import cache
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
import vllm.envs as envs
|
||||
from vllm.attention.backends.abstract import AttentionBackend
|
||||
from vllm.attention.selector import (backend_name_to_enum,
|
||||
get_global_forced_attn_backend)
|
||||
from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import resolve_obj_by_qualname
|
||||
|
||||
from vllm_ascend.utils import vllm_version_is
|
||||
|
||||
if vllm_version_is("0.10.2"):
|
||||
|
||||
def get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool = False,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||
# value to be returned from the cache if the value changes between calls.
|
||||
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
|
||||
# private function.
|
||||
return _cached_get_attn_backend(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
block_size=block_size,
|
||||
is_attention_free=is_attention_free,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa,
|
||||
has_sink=has_sink,
|
||||
)
|
||||
|
||||
@cache
|
||||
def _cached_get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
is_attention_free: bool,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
# If there are no attention layers (e.g. we are running Mamba),
|
||||
# use the placeholder NO_ATTENTION
|
||||
if is_attention_free:
|
||||
from vllm.attention.backends.placeholder_attn import \
|
||||
PlaceholderAttentionBackend
|
||||
return PlaceholderAttentionBackend
|
||||
|
||||
# Check whether a particular choice of backend was
|
||||
# previously forced.
|
||||
#
|
||||
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
|
||||
# ENVIRONMENT VARIABLE.
|
||||
selected_backend = None
|
||||
backend_by_global_setting: Optional[_Backend] = (
|
||||
get_global_forced_attn_backend())
|
||||
if backend_by_global_setting is not None:
|
||||
selected_backend = backend_by_global_setting
|
||||
else:
|
||||
# Check the environment variable and override if specified
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
if selected_backend is None:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: '{backend_by_env_var}'. "
|
||||
f"Valid backends are: {list(_Backend.__members__.keys())}"
|
||||
)
|
||||
|
||||
# get device-specific attn_backend
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
selected_backend, head_size, dtype, kv_cache_dtype, block_size,
|
||||
use_v1, use_mla, use_sfa, has_sink)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}"
|
||||
)
|
||||
return resolve_obj_by_qualname(attention_cls)
|
||||
else:
|
||||
|
||||
def get_attn_backend( # type: ignore[misc]
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
"""Selects which attention backend to use and lazily imports it."""
|
||||
# Accessing envs.* behind an @lru_cache decorator can cause the wrong
|
||||
# value to be returned from the cache if the value changes between calls.
|
||||
# To avoid this, we read envs.VLLM_USE_V1 here and pass it explicitly to the
|
||||
# private function.
|
||||
return _cached_get_attn_backend(
|
||||
head_size=head_size,
|
||||
dtype=dtype,
|
||||
kv_cache_dtype=kv_cache_dtype,
|
||||
block_size=block_size,
|
||||
use_v1=envs.VLLM_USE_V1,
|
||||
use_mla=use_mla,
|
||||
use_sfa=use_sfa,
|
||||
has_sink=has_sink,
|
||||
)
|
||||
|
||||
@cache
|
||||
def _cached_get_attn_backend(
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
kv_cache_dtype: Optional[str],
|
||||
block_size: int,
|
||||
use_v1: bool = False,
|
||||
use_mla: bool = False,
|
||||
use_sfa: bool = False,
|
||||
has_sink: bool = False,
|
||||
) -> type[AttentionBackend]:
|
||||
# Check whether a particular choice of backend was
|
||||
# previously forced.
|
||||
#
|
||||
# THIS SELECTION OVERRIDES THE VLLM_ATTENTION_BACKEND
|
||||
# ENVIRONMENT VARIABLE.
|
||||
selected_backend = None
|
||||
backend_by_global_setting: Optional[_Backend] = (
|
||||
get_global_forced_attn_backend())
|
||||
if backend_by_global_setting is not None:
|
||||
selected_backend = backend_by_global_setting
|
||||
else:
|
||||
# Check the environment variable and override if specified
|
||||
backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND
|
||||
if backend_by_env_var is not None:
|
||||
selected_backend = backend_name_to_enum(backend_by_env_var)
|
||||
if selected_backend is None:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend: '{backend_by_env_var}'. "
|
||||
f"Valid backends are: {list(_Backend.__members__.keys())}"
|
||||
)
|
||||
|
||||
# get device-specific attn_backend
|
||||
attention_cls = current_platform.get_attn_backend_cls(
|
||||
selected_backend, head_size, dtype, kv_cache_dtype, block_size,
|
||||
use_v1, use_mla, use_sfa, has_sink)
|
||||
if not attention_cls:
|
||||
raise ValueError(
|
||||
f"Invalid attention backend for {current_platform.device_name}"
|
||||
)
|
||||
return resolve_obj_by_qualname(attention_cls)
|
||||
|
||||
|
||||
vllm.attention.get_attn_backend = get_attn_backend
|
||||
vllm.attention.selector._cached_get_attn_backend = _cached_get_attn_backend
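One practical consequence of the functools.cache wrapper above is that backend selection is resolved once per argument tuple for the lifetime of the process. A minimal sketch of how a test that switches backends (or flips use_sfa) would reset it, mirroring the reset hook vLLM's selector already exposes:

import vllm.attention.selector as selector

# Hedged sketch, not part of the diff.
selector._cached_get_attn_backend.cache_clear()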
|
||||
110
vllm_ascend/patch/worker/patch_common/patch_attentionspec.py
Normal file
@@ -0,0 +1,110 @@
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
from typing_extensions import Self
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.utils import cdiv, get_dtype_size
|
||||
from vllm.v1.core.single_type_kv_cache_manager import (FullAttentionManager,
|
||||
spec_manager_map)
|
||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AttentionSpec(KVCacheSpec):
|
||||
num_kv_heads: int
|
||||
head_size: int
|
||||
dtype: torch.dtype
|
||||
use_mla: bool
|
||||
use_sfa: bool
|
||||
|
||||
@property
|
||||
def page_size_bytes(self) -> int:
|
||||
# For MLA we only store a single latent vector
|
||||
coef = 1 if self.use_mla else 2
|
||||
sfa_bytes = 128 * self.block_size * get_dtype_size(
|
||||
self.dtype) if self.use_sfa else 0
|
||||
|
||||
return coef * self.block_size * self.num_kv_heads * self.head_size \
|
||||
* get_dtype_size(self.dtype) + sfa_bytes
|
||||
|
||||
|
||||
vllm.v1.kv_cache_interface.AttentionSpec = AttentionSpec
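The page_size_bytes formula above is worth a worked example. With the DeepSeek V3.2 figures used elsewhere in this PR (use_mla=True so coef=1, num_kv_heads=1, head_size=576 from kv_lora_rank 512 plus rope 64, bfloat16 so 2 bytes per element) and use_sfa=True adding the 128-wide indexer slice per block:

# Hedged sketch, not part of the diff: block_size=128 is an assumed value.
block_size, num_kv_heads, head_size, dtype_size = 128, 1, 576, 2
coef = 1                                      # MLA stores a single latent vector
sfa_bytes = 128 * block_size * dtype_size     # 32768
page_size = coef * block_size * num_kv_heads * head_size * dtype_size + sfa_bytes
assert page_size == 147456 + 32768 == 180224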


@dataclass(frozen=True)
class AscendFullAttentionSpec(FullAttentionSpec, AttentionSpec):
sliding_window: Optional[int] = None
attention_chunk_size: Optional[int] = None
"""
When the hybrid allocator is disabled and the model contains both full
attention layers and sliding window attention layers, sliding window
attention is regarded as full attention in the KV cache manager
(blocks are allocated for all tokens), while it is computed as sliding
window attention in the model runner.
In this case, we use FullAttentionSpec and record the sliding window size.
Default to None for not using sliding window attention.
"""

def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
dcp_world_size = \
vllm_config.parallel_config.decode_context_parallel_size
# Note(hc): each dcp rank only needs to save
# (max_model_len // dcp_world_size) tokens locally.
if dcp_world_size > 1:
max_model_len = cdiv(max_model_len, dcp_world_size)
return cdiv(max_model_len, self.block_size) * self.page_size_bytes

@classmethod
def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
if len(window_sizes) == 0:
return None
elif len(window_sizes) == 1:
return window_sizes.pop()
else:
raise ValueError(
"All attention layers in the same KV cache group must have the "
"same window size.")

@classmethod
def merge(cls, specs: list[Self]) -> Self:
"""
Merge a list of FullAttentionSpec objects into a single
FullAttentionSpec object.
"""
assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
"All attention layers in the same KV cache group must be "
"FullAttentionSpec.")

sliding_window = set(spec.sliding_window for spec in specs
if spec.sliding_window is not None)
attention_chunk_size = set(spec.attention_chunk_size for spec in specs
if spec.attention_chunk_size is not None)
merged_spec = cls(
block_size=specs[0].block_size,
num_kv_heads=specs[0].num_kv_heads,
head_size=specs[0].head_size,
dtype=specs[0].dtype,
use_mla=specs[0].use_mla,
use_sfa=specs[0].use_sfa,
sliding_window=cls.merge_window_sizes(sliding_window),
attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
)
for spec in specs:
for f in fields(AttentionSpec):
assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
"All attention layers in the same KV cache group must have "
"the same attention spec.")
assert (
(merged_spec.sliding_window is not None) +
(merged_spec.attention_chunk_size is not None) <= 1
), ("Model with both sliding window layers and chunked local attention "
"layers is not supported.")
return merged_spec


spec_manager_map.update({AscendFullAttentionSpec: FullAttentionManager})

vllm.v1.kv_cache_interface.FullAttentionSpec = AscendFullAttentionSpec
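
For the decode-context-parallel note in max_memory_usage_bytes above, a quick worked example may help (a sketch only; the model-length and page-size numbers are assumptions, not values from this patch):

def cdiv(a: int, b: int) -> int:
    # Same ceiling division the patch imports from vllm.utils.
    return -(-a // b)

max_model_len, block_size, page_size = 163840, 128, 147456
for dcp_world_size in (1, 4):
    # Each dcp rank only stores its shard of the sequence locally.
    local_len = cdiv(max_model_len, dcp_world_size) if dcp_world_size > 1 else max_model_len
    print(dcp_world_size, cdiv(local_len, block_size) * page_size)
# dcp=1 -> 188743680 bytes per sequence; dcp=4 -> 47185920 bytes per rank.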
@@ -300,6 +300,7 @@ class NPUPlatform(Platform):
block_size,
use_v1,
use_mla,
use_sfa,
has_sink=False):
if not use_v1:
raise ValueError("vLLM Ascend does not support V0 engine.")
@@ -307,21 +308,28 @@ class NPUPlatform(Platform):
ascend_config = get_ascend_config()

if use_mla and ascend_config.enable_shared_expert_dp:
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
if use_mla and not use_sfa:
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
if use_mla and use_sfa:
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"

use_torchair = ascend_config.torchair_graph_config.enabled
# choose attention backend based on use_mla, use_sfa and use_torchair
backend_map = {
(True, True):
(True, False, True):
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
(True, False):
(True, False, False):
"vllm_ascend.attention.mla_v1.AscendMLABackend",
(False, True):
(False, False, True):
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
(False, False):
"vllm_ascend.attention.attention_v1.AscendAttentionBackend"
(False, False, False):
"vllm_ascend.attention.attention_v1.AscendAttentionBackend",
(True, True, False):
"vllm_ascend.attention.sfa_v1.AscendSFABackend",
(True, True, True):
"vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
}
return backend_map[(use_mla, use_torchair)]
return backend_map[(use_mla, use_sfa, use_torchair)]
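
The table above now keys on a (use_mla, use_sfa, use_torchair) triple instead of a pair. A minimal sketch of the resulting selection, with the backend paths copied from the diff (the helper function itself is only illustrative):

def select_ascend_attention_backend(use_mla: bool, use_sfa: bool, use_torchair: bool) -> str:
    # Mirrors the backend_map above; the SFA entries are new in this patch.
    backend_map = {
        (True, False, True): "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
        (True, False, False): "vllm_ascend.attention.mla_v1.AscendMLABackend",
        (True, True, False): "vllm_ascend.attention.sfa_v1.AscendSFABackend",
        (True, True, True): "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
        (False, False, True): "vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
        (False, False, False): "vllm_ascend.attention.attention_v1.AscendAttentionBackend",
    }
    return backend_map[(use_mla, use_sfa, use_torchair)]

# DeepSeek V3.2 (MLA + SFA) in eager mode resolves to the new sparse attention backend.
assert select_ascend_attention_backend(True, True, False).endswith("AscendSFABackend")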

@classmethod
def get_punica_wrapper(cls) -> str:

@@ -603,10 +603,11 @@ class MtpProposer(Proposer):
torch.npu.set_compile_mode(jit_compile=False)
if not self.runner.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(self.model,
dynamic=True,
fullgraph=True,
backend=npu_backend)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=not get_ascend_config().use_sfa,
fullgraph=True,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
@@ -627,7 +628,7 @@ class MtpProposer(Proposer):
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=True,
dynamic=not get_ascend_config().use_sfa,
fullgraph=True,
cache_dir=TORCHAIR_CACHE_DIR,
config=config,
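
Both compile paths above flip `dynamic` off whenever SFA is enabled, so the sparse-attention graphs are traced with static shapes. A minimal sketch of that pattern (the model and npu_backend arguments are placeholders, not objects defined in this patch):

import torch

def compile_for_torchair(model, npu_backend, use_sfa: bool):
    # With SFA the graph is compiled with static shapes (dynamic=False);
    # otherwise dynamic shape tracing stays enabled as before.
    return torch.compile(model,
                         dynamic=not use_sfa,
                         fullgraph=True,
                         backend=npu_backend)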

@@ -67,7 +67,9 @@ from vllm.model_executor.models.utils import (
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
from vllm.sequence import IntermediateTensors

from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.models.layers.sfa import Indexer
from vllm_ascend.quantization.quant_config import AscendLinearMethod
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
@@ -435,6 +437,7 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
decoder_layer=None,
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
@@ -630,6 +633,225 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
output_shape=output_shape)


class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention):

def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
decoder_layer=None,
) -> None:
nn.Module.__init__(self)
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim

self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank

self.num_heads = num_heads
self.tp_size = get_tensor_model_parallel_world_size()
assert num_heads % self.tp_size == 0
self.num_local_heads = num_heads // self.tp_size
self.layers = config.num_hidden_layers
self.first_k_dense_replace = config.first_k_dense_replace

self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])

ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(
self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj",
return_bias=False,
)
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(
q_lora_rank,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj",
return_bias=False,
)
else:
self.q_proj = ColumnParallelLinear(
self.hidden_size,
self.num_heads * self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj",
return_bias=False,
)

self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa",
return_bias=False,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj",
return_bias=False,
)
if (config.n_routed_experts is not None
and self.debug_layer_idx >= config.first_k_dense_replace
and self.debug_layer_idx % config.moe_layer_freq == 0
and (ascend_config.multistream_overlap_shared_expert
or self.enable_shared_expert_dp)):
self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
return_bias=False,
)
else:
self.o_proj = TorchairDeepseekV2RowParallelLinear(
self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
return_bias=False,
)

if rope_scaling:
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale

self.dim: int = config.hidden_size # 7168
# TODO(zzzzwwjj): wait transformers add these params
self.n_heads: int = 64 # 64
self.head_dim: int = 128 # 128
self.index_topk: int = 2048 # 2048
self.indexer = Indexer(
config,
quant_config=quant_config,
dim=self.dim,
n_heads=self.n_heads,
head_dim=self.head_dim,
index_topk=self.index_topk,
prefix=f"{prefix}.indexer",
)

self.sfa_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=self.scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
use_sfa=True,
# SFA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
q_a_layernorm=self.q_a_layernorm
if self.q_lora_rank is not None else None,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
kv_a_layernorm=self.kv_a_layernorm,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
indexer=self.indexer,
decoder_layer=decoder_layer,
)

def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: Optional[torch.Tensor] = None,
attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
forward_context = get_forward_context()
if not self.torchair_graph_enabled:
if forward_context.attn_metadata is not None and isinstance(
forward_context.attn_metadata, dict):
attn_metadata = next(
iter(forward_context.attn_metadata.values()), None)
else:
attn_metadata = forward_context.attn_metadata
if kv_cache is None:
kv_cache = self.sfa_attn.kv_cache[
forward_context.virtual_engine]

num_tokens = hidden_states.shape[0]
need_gather_q_kv = False
# if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
# # Simulate all gather to calculate output shape
# num_tokens = num_tokens * self.tp_size
# need_gather_q_kv = True
if not self.enable_shared_expert_dp or self.debug_layer_idx != self.first_k_dense_replace:
output_shape = hidden_states.shape
if self.enable_shared_expert_dp and (
self.debug_layer_idx == self.first_k_dense_replace
or self.debug_layer_idx == self.layers):
rows = num_tokens // self.tp_size
if num_tokens % self.tp_size:
rows += 1
output_shape = (rows, hidden_states.shape[1])
output = torch.empty(output_shape,
dtype=hidden_states.dtype,
device=hidden_states.device)
self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
need_gather_q_kv, output)
output = output.view(-1, output_shape[-1])
return output


class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):

def __init__(
@@ -654,9 +876,16 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tp_group().rank_in_group
ascend_config = get_ascend_config()
self.use_mla = False
self.use_sfa = False
# TODO: enable mla in vllm-ascend
if model_config.use_mla:
attn_cls = TorchairDeepseekV2MLAAttention
if ascend_config.use_sfa:
attn_cls = TorchairDeepseekV2SFAAttention
self.use_sfa = True
else:
attn_cls = TorchairDeepseekV2MLAAttention # type: ignore[assignment]
self.use_mla = True
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
@@ -675,6 +904,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.self_attn",
decoder_layer=self,
)

if (config.n_routed_experts is not None
@@ -715,21 +945,34 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
replace_allreduce: bool = False,
) -> torch.Tensor:
# Self Attention
if attn_metadata is not None and attn_metadata.num_decodes > 0:
mla_moe_communication = self.mla_moe_communication and replace_allreduce
if attn_metadata is not None:
decoding_condition_met = (
not attn_metadata.is_prefill if self.use_sfa else
attn_metadata.num_decodes > 0 if self.use_mla else False)
mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce
else:
mla_moe_communication = False
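
The nested conditional expression above packs three cases into one line. Spelled out, the intent is the following (a sketch, assuming the same metadata fields as in the diff):

def is_decoding(attn_metadata, use_sfa: bool, use_mla: bool) -> bool:
    # SFA metadata exposes is_prefill; MLA metadata counts decode requests;
    # any other attention type never takes the MoE-communication fast path.
    if use_sfa:
        return not attn_metadata.is_prefill
    if use_mla:
        return attn_metadata.num_decodes > 0
    return False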
if residual is None:

forward_context = get_forward_context()
if (envs.VLLM_ASCEND_ENABLE_MLAPO
and isinstance(self.self_attn, TorchairDeepseekV2SFAAttention)
and attn_metadata is not None
and not forward_context.with_prefill):
if residual is not None:
hidden_states = hidden_states + residual
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
previous_hidden_states, previous_residual = hidden_states, residual
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
# Dispose hidden_states and residual from the previous layer
# to save npu memory because they're no longer used.
dispose_tensor(previous_hidden_states)
dispose_tensor(previous_residual)
if residual is None:
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
previous_hidden_states, previous_residual = hidden_states, residual
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
# Dispose hidden_states and residual from the previous layer
# to save npu memory because they're no longer used.
dispose_tensor(previous_hidden_states)
dispose_tensor(previous_residual)
if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
hidden_states = tensor_model_parallel_all_gather(hidden_states,
dim=0)

@@ -48,8 +48,8 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
class NPUTorchairModelRunner(NPUModelRunner):

def __init__(self, vllm_config: VllmConfig, device: torch.device):
ascend_config = get_ascend_config()
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
self.ascend_config = get_ascend_config()
self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp
super().__init__(vllm_config, device)
if self.speculative_config:
self.actual_seq_lengths_q = list(
@@ -66,10 +66,10 @@ class NPUTorchairModelRunner(NPUModelRunner):
self.new_kv_cache_bytes = -1
self.torchair_compiled_model = None # type: ignore
self.torchair_compiled_models = {} # type: ignore
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
self.use_cached_kv_cache_bytes = ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.use_cached_npu_graph = self.ascend_config.torchair_graph_config.use_cached_graph
self.use_cached_kv_cache_bytes = self.ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
self.torchair_graph_batch_sizes = self.ascend_config.torchair_graph_config.graph_batch_sizes
if self.ascend_config.torchair_graph_config.graph_batch_sizes_init:
self.init_torchair_graph_batch_sizes()

self.update_torchair_graph_batch_sizes()
@@ -362,22 +362,23 @@ class NPUTorchairModelRunner(NPUModelRunner):
communication_adaptation_310p()

config = torchair.CompilerConfig()
if get_ascend_config().torchair_graph_config.mode:
config.mode = get_ascend_config().torchair_graph_config.mode
if self.ascend_config.torchair_graph_config.mode:
config.mode = self.ascend_config.torchair_graph_config.mode
config.experimental_config.frozen_parameter = \
get_ascend_config().torchair_graph_config.enable_frozen_parameter
self.ascend_config.torchair_graph_config.enable_frozen_parameter
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
# disable it on 300I Duo platform now.
config.experimental_config.tiling_schedule_optimize = not is_310p()
config.experimental_config.enable_view_optimize = \
get_ascend_config().torchair_graph_config.enable_view_optimize
self.ascend_config.torchair_graph_config.enable_view_optimize
torch.npu.set_compile_mode(jit_compile=False)
if not self.use_cached_npu_graph:
npu_backend = torchair.get_npu_backend(compiler_config=config)
self.torchair_compiled_model = torch.compile(self.model,
dynamic=True,
fullgraph=True,
backend=npu_backend)
self.torchair_compiled_model = torch.compile(
self.model,
dynamic=not self.ascend_config.use_sfa,
fullgraph=True,
backend=npu_backend)
return self.torchair_compiled_model
else:
# Generate a new forward proxy code object to prevent the invalidation of
@@ -398,7 +399,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
self.torchair_compiled_models[
batch_size] = torchair.inference.cache_compile(
self.model.__dict__[forward_proxy_name],
dynamic=True,
dynamic=not self.ascend_config.use_sfa,
fullgraph=True,
cache_dir=TORCHAIR_CACHE_DIR,
config=config,
1330
vllm_ascend/torchair/torchair_sfa.py
Normal file
File diff suppressed because it is too large
@@ -165,6 +165,11 @@ def register_torchair_model():
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
)

ModelRegistry.register_model(
"DeepseekV32ForCausalLM",
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
)

ModelRegistry.register_model(
"Qwen2ForCausalLM",
"vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM")

@@ -285,8 +285,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.intermediate_tensors: Optional[IntermediateTensors] = None
self.runner_only_attn_layers: set[str] = set()

ascend_config = get_ascend_config()
if ascend_config.ascend_scheduler_config.enabled:
self.ascend_config = get_ascend_config()
if self.ascend_config.ascend_scheduler_config.enabled:
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
else:
self.chunked_prefill_enabled = True
@@ -298,6 +298,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.cache_config.cache_dtype]
# use_hybrid_blocks: if hybrid blocks is used.
self.use_hybrid_blocks: bool = False
self.need_accepted_tokens: bool = False

self.is_multimodal_model = self.model_config.is_multimodal_model
self.is_pooling_model = self.model_config.pooler_config is not None
@@ -315,7 +316,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
use_sfa=self.ascend_config.use_sfa)
else:
self.attn_backend = get_attn_backend(
0,
@@ -323,7 +324,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
None,
self.block_size,
use_mla=self.model_config.use_mla,
)
use_sfa=self.ascend_config.use_sfa)
if torch.version.cann.startswith("8.3"):
self.attn_mask_builder = AttentionMaskBuilder(
self.scheduler_config.max_num_batched_tokens, self.dtype,
@@ -457,7 +458,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
dtype=torch.bool,
device=self.device,
)
self.dynamic_eplb = ascend_config.dynamic_eplb
self.dynamic_eplb = self.ascend_config.dynamic_eplb
if self.dynamic_eplb:
self.is_eplb_warmuped = False
self.eplb_loader = D2DExpertWeightLoader()
@@ -890,15 +891,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
def _make_attention_mask(self, seq_lens, position,
attn_state) -> torch.Tensor:
# Chunk Prefill situation.
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
if torch.version.cann.startswith("8.3"):
return self.attn_mask_builder.get_splitfuse_attn_mask()
else:
return self.attn_mask_builder.get_splitfuse_attn_mask(
seq_lens, position, self.dtype, self.device)

# Prefill without cache situation.
elif attn_state == AscendAttentionState.PrefillNoCache:
max_seq_len = max(seq_lens, default=0)
max_seq_len = max(seq_lens.max().item(), 0)
return self.attn_mask_builder.get_attn_mask(
max_seq_len, self.dtype, self.device)
# Prefill with cache hit.
@@ -1252,7 +1254,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
req_ids = self.input_batch.req_ids
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
max_num_scheduled_tokens = max(tokens)
max_num_scheduled_tokens = num_scheduled_tokens.max()
num_valid_tokens = np.array([
num_tokens -
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
@@ -1376,8 +1378,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
positions_cpu = self.positions_cpu[:num_input_tokens]
positions = self.positions[:num_input_tokens]
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
num_valid_tokens)
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
position=positions_cpu,
attn_state=attn_state)
@@ -1477,7 +1477,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_computed_tokens_cpu = (
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
spec_decode_common_attn_metadata = None
if use_spec_decode:
if use_spec_decode and self.need_accepted_tokens:
self.num_accepted_tokens.np[:num_reqs] = (
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
self.num_accepted_tokens.np[num_reqs:].fill(1)
@@ -1550,7 +1550,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
model=self.model,
**extra_attn_metadata_args)

if self.vllm_config.model_config.use_mla:
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
attn_metadata_i.num_input_tokens = num_input_tokens
for layer_name in attn_group.layer_names:
attn_metadata[layer_name] = attn_metadata_i
@@ -2060,7 +2060,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
sampling_metadata,
)
sampler_output.sampled_token_ids = output_token_ids
self._update_states_after_model_execute(output_token_ids)
if self.need_accepted_tokens:
self._update_states_after_model_execute(output_token_ids)

discard_sampled_tokens_req_indices: list[int] = []
# TODO(woosuk): The following loop can be slow since it iterates over
@@ -2683,10 +2684,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.kv_cache_config = kv_cache_config
self.initialize_attn_backend(kv_cache_config)
self.use_hybrid_blocks = (len(self.attn_groups) > 1)
# NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
if vllm_version_is("0.10.2"):
self.need_accepted_tokens = any([
isinstance(
self.kv_cache_config.kv_cache_groups[0].kv_cache_spec,
MambaSpec) for attn_group in self.attn_groups
])
else:
self.need_accepted_tokens = any([
isinstance(attn_group[0].kv_cache_spec, MambaSpec)
for attn_group in self.attn_groups
])

self.may_reinitialize_input_batch(kv_cache_config)

if self.model_config.is_deepseek_mla:
kv_caches = self.initialize_kv_cache_tensors_deepseek(
if self.ascend_config.is_deepseek_sfa:
kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
kv_cache_config)
elif self.model_config.is_deepseek_mla:
kv_caches = self.initialize_kv_cache_tensors_deepseek_mla(
kv_cache_config)
else:
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
@@ -2701,7 +2718,116 @@ class NPUModelRunner(LoRAModelRunnerMixin):
offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):]

def initialize_kv_cache_tensors_deepseek(
def initialize_kv_cache_tensors_deepseek_sfa(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
kv_cache_sizes = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
assert len(kv_cache_tensor.shared_by) == 1, (
"KV cache tensor shared by multiple layers is not supported in "
"NPU.")
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size

kv_caches: Dict[str, torch.Tensor] = {}
for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
if vllm_version_is("0.10.2"):
kv_cache_spec, group = group
else:
kv_cache_spec = group.kv_cache_spec
attn_backend = group.backend
for layer_name in group.layer_names:
if layer_name in self.runner_only_attn_layers:
continue
tensor_size = kv_cache_sizes[layer_name]
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
if self.vllm_config.additional_config.get(
"kv_cache_dtype", None) == 'int8':
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
elif hasattr(
attn_backend, "get_supported_block_size"
) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa:
block_size = attn_backend.get_supported_block_size()[0]
block_size_chunk = kv_cache_spec.block_size // block_size
kv_cache_shape = attn_backend.get_kv_cache_shape(
num_blocks * block_size_chunk, block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
else:
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
dtype = kv_cache_spec.dtype

alignment = 2 * 1024 * 1024
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
nope_dim = head_size - rope_dim
nope_cache_shape = (num_blocks, block_size, num_kv_heads,
nope_dim)
rope_cache_shape = (num_blocks, block_size, num_kv_heads,
rope_dim)
#### k cache
# TODO(zzzzwwjj): wait transformers add these params
k_cache_shape = (num_blocks, block_size, 1, 128)
if self.vllm_config.kv_transfer_config is None:
# For the non-disaggregated PD scenario, allocate the kv cache in the normal way
rope_cache = torch.zeros(rope_cache_shape,
dtype=dtype,
device=self.device)
nope_cache = torch.zeros(nope_cache_shape,
dtype=dtype,
device=self.device)
rope_cache = self._convert_torch_format(rope_cache)
nope_cache = self._convert_torch_format(nope_cache)

#### k cache
k_cache = torch.zeros(k_cache_shape,
dtype=dtype,
device=self.device)
k_cache = self._convert_torch_format(k_cache)
else:

# In order to transfer the kv cache through the register_memory api from llmdatadist, the memory
# address should be aligned to 2M. In most cases torch_npu can allocate 2M-aligned memory, but
# we found some exceptions during testing, so we manually align the memory here; this part
# of the code may consume 2M * 2 * elem_size extra memory per layer.
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
nope_allocate_shape_alignment = nope_allocate_shape + alignment
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
rope_allocate_shape_alignment = rope_allocate_shape + alignment

nope_cache = torch.zeros(nope_allocate_shape_alignment,
dtype=dtype,
device=self.device)
rope_cache = torch.zeros(rope_allocate_shape_alignment,
dtype=dtype,
device=self.device)
#### k cache
# TODO(zzzzwwjj): wait transformers add these params
k_allocate_shape = num_blocks * block_size * 1 * 128
k_allocate_shape_alignment = k_allocate_shape + alignment
k_cache = torch.zeros(k_allocate_shape_alignment,
dtype=dtype,
device=self.device)

nope_cache = self._align_memory(
nope_cache,
alignment)[:nope_allocate_shape].view(nope_cache_shape)
rope_cache = self._align_memory(
rope_cache,
alignment)[:rope_allocate_shape].view(rope_cache_shape)
k_cache = self._align_memory(
k_cache,
alignment)[:k_allocate_shape].view(k_cache_shape)

kv_caches[layer_name] = (nope_cache, rope_cache, k_cache)
bind_kv_cache(kv_caches,
self.compilation_config.static_forward_context,
self.kv_caches)

return kv_caches
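
For the disaggregated-prefill branch above, each cache is over-allocated by one 2 MiB alignment unit and then sliced from the first aligned address. A minimal sketch of that rounding, mirroring the _align_memory slicing shown earlier in this diff (illustrative only):

import torch

def align_to(tensor: torch.Tensor, alignment: int = 2 * 1024 * 1024) -> torch.Tensor:
    # Round the start address up to the next multiple of `alignment` and
    # return the view of `tensor` that begins at that aligned element.
    data_ptr = tensor.data_ptr()
    aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
    offset = (aligned_addr - data_ptr) // tensor.element_size()
    return tensor[int(offset):]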

def initialize_kv_cache_tensors_deepseek_mla(
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
kv_cache_sizes = {}
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
@@ -3217,6 +3343,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):

block_size = self.vllm_config.cache_config.block_size
use_mla = self.vllm_config.model_config.use_mla
use_sfa = self.ascend_config.use_sfa
kv_cache_spec: dict[str, KVCacheSpec] = {}
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
for layer_name, attn_module in attn_layers.items():
@@ -3243,7 +3370,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_kv_heads=attn_module.num_kv_heads,
head_size=attn_module.head_size,
dtype=self.kv_cache_dtype,
use_mla=use_mla)
use_mla=use_mla,
use_sfa=use_sfa)
elif attn_module.attn_type in (AttentionType.ENCODER,
AttentionType.ENCODER_ONLY):
# encoder-only attention does not need KV cache.

@@ -43,7 +43,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
from vllm.v1.worker.worker_base import WorkerBase

import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
from vllm_ascend.platform import NPUPlatform
@@ -88,6 +88,17 @@ class NPUWorker(WorkerBase):
# init ascend config and soc version
init_ascend_config(vllm_config)
init_ascend_soc_version()
if get_ascend_config().use_sfa:
# Direct import instead of using try_register_lib to ensure proper error handling when
# custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments)
# yapf: disable
import custom_ops # type: ignore # noqa

# yapf: enable
logger.info(
"custom_ops module loaded successfully. Custom operators like "
"torch.ops.custom.npu_sparse_flash_attention are now available."
)

super().__init__(vllm_config=vllm_config,
local_rank=local_rank,