Add DeepSeek V3.2 support (#3270)
### What this PR does / why we need it? This PR adds initial DeepSeek V3.2 support with [vLLM v0.11.0](https://github.com/vllm-project/vllm/tree/releases/v0.11.0) (not released yet). We will complete the vLLM adaptation as soon as possible; the feature is expected to be ready within the next 1-2 days. Related doc: https://github.com/vllm-project/vllm-ascend/pull/3223. ### Does this PR introduce _any_ user-facing change? Yes! ### How was this patch tested? CI passed; the steps in the DeepSeek doc will be run soon. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0 --------- Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com> Signed-off-by: zzzzwwjj <1183291235@qq.com> Signed-off-by: linfeng-yuan <1102311262@qq.com> Signed-off-by: wxsIcey <1790571317@qq.com> Signed-off-by: MengqingCao <cmq0113@163.com> Co-authored-by: zzzzwwjj <1183291235@qq.com> Co-authored-by: linfeng-yuan <1102311262@qq.com> Co-authored-by: wxsIcey <1790571317@qq.com> Co-authored-by: MengqingCao <cmq0113@163.com>
This commit is contained in:
9
.github/workflows/vllm_ascend_test.yaml
vendored
9
.github/workflows/vllm_ascend_test.yaml
vendored
@@ -121,7 +121,14 @@ jobs:
|
|||||||
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib
|
export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/Ascend/ascend-toolkit/latest/x86_64-linux/devlib
|
||||||
pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \
|
pytest -sv --cov --cov-report=xml:unittests-coverage.xml tests/ut \
|
||||||
--ignore=tests/ut/test_platform.py \
|
--ignore=tests/ut/test_platform.py \
|
||||||
--ignore=tests/ut/patch/worker/patch_common/test_patch_minicpm.py
|
--ignore=tests/ut/patch/worker/patch_common/test_patch_minicpm.py \
|
||||||
|
--ignore=tests/ut/core/test_scheduler.py \
|
||||||
|
--ignore=tests/ut/kv_connector/test_llmdatadist_connector.py \
|
||||||
|
--ignore=tests/ut/kv_connector/test_mooncake_connector.py \
|
||||||
|
--ignore=tests/ut/kv_connector/test_remote_decode_lifecycle.py \
|
||||||
|
--ignore=tests/ut/kv_connector/test_remote_prefill_lifecycle.py \
|
||||||
|
--ignore=tests/ut/torchair/models/test_torchair_deepseek_v2.py \
|
||||||
|
--ignore=tests/ut/torchair/test_utils.py
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
- name: Upload coverage to Codecov
|
||||||
# only upload coverage when commits merged
|
# only upload coverage when commits merged
|
||||||
|
|||||||
@@ -23,5 +23,7 @@ def register():
|
|||||||
|
|
||||||
|
|
||||||
def register_model():
|
def register_model():
|
||||||
|
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
|
||||||
|
|
||||||
from .models import register_model
|
from .models import register_model
|
||||||
register_model()
|
register_model()
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ class AscendConfig:
|
|||||||
|
|
||||||
def __init__(self, vllm_config):
|
def __init__(self, vllm_config):
|
||||||
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
|
||||||
|
self.is_deepseek_sfa = vllm_config.model_config is not None and vllm_config.model_config.is_deepseek_mla and vllm_config.model_config.hf_text_config.model_type == "deepseek_v32"
|
||||||
|
self.use_sfa = self.is_deepseek_sfa
|
||||||
|
|
||||||
torchair_graph_config = additional_config.get("torchair_graph_config",
|
torchair_graph_config = additional_config.get("torchair_graph_config",
|
||||||
{})
|
{})
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ class AttentionMaskBuilder:
|
|||||||
device: torch.device):
|
device: torch.device):
|
||||||
self._update_attn_cache(max_seq_len, dtype)
|
self._update_attn_cache(max_seq_len, dtype)
|
||||||
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
|
return self.attn_mask_cache[:max_seq_len, :max_seq_len].contiguous(
|
||||||
).to(device)
|
).to(device, non_blocking=True)
|
||||||
|
|
||||||
def get_splitfuse_attn_mask(
|
def get_splitfuse_attn_mask(
|
||||||
self,
|
self,
|
||||||
|
|||||||
986
vllm_ascend/attention/sfa_v1.py
Normal file
986
vllm_ascend/attention/sfa_v1.py
Normal file
@@ -0,0 +1,986 @@
|
|||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import (TYPE_CHECKING, ClassVar, NamedTuple, Optional, Tuple, Type,
|
||||||
|
TypeVar)
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch_npu
|
||||||
|
from torch import nn
|
||||||
|
from vllm.attention.backends.abstract import (AttentionBackend,
|
||||||
|
AttentionMetadata,
|
||||||
|
MLAAttentionImpl)
|
||||||
|
from vllm.config import VllmConfig, get_current_vllm_config
|
||||||
|
from vllm.distributed import get_tensor_model_parallel_world_size, get_tp_group
|
||||||
|
from vllm.model_executor.layers.linear import (LinearBase,
|
||||||
|
UnquantizedLinearMethod)
|
||||||
|
from vllm.utils import cdiv, round_down
|
||||||
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
|
from vllm_ascend.attention.mla_v1 import AscendMLAMetadata
|
||||||
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
|
split_decodes_and_prefills)
|
||||||
|
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||||
|
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||||
|
from vllm_ascend.worker.npu_input_batch import InputBatch
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|
||||||
|
|
||||||
|
class AscendSFABackend(AttentionBackend):
    """Attention backend descriptor for the Ascend SFA implementation.

    Exposes the backend name, the KV-cache layout and the metadata,
    metadata-builder and implementation classes defined in this module.
    """

    accept_output_buffer: bool = True

    @staticmethod
    def get_name() -> str:
        """Return the identifier of this backend."""
        return "ASCEND_SFA"

    @staticmethod
    def get_metadata_cls() -> type["AttentionMetadata"]:
        """Return the per-batch metadata class used by this backend."""
        return AscendSFAMetadata

    @staticmethod
    def get_builder_cls():
        """Return the class that builds the per-batch metadata."""
        return AscendSFAMetadataBuilder

    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int, num_kv_heads: int,
                           head_size: int) -> tuple[int, ...]:
        """Return the KV-cache shape; the arguments are passed through in
        the order (num_blocks, block_size, num_kv_heads, head_size)."""
        shape = (num_blocks, block_size, num_kv_heads, head_size)
        return shape

    @staticmethod
    def get_impl_cls() -> Type["AscendSFAImpl"]:
        """Return the attention implementation class."""
        return AscendSFAImpl
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AscendSFAPrefillMetadata:
    """Prefill-specific attention metadata for the Ascend SFA backend."""

    @dataclass
    class ChunkedContextMetadata:
        """Bookkeeping for processing the cached context in chunks
        (chunked prefill)."""
        cu_seq_lens: torch.Tensor
        starts: torch.Tensor
        seq_tot: list[int]
        max_seq_lens: list[int]
        workspace: torch.Tensor
        chunk_seq_lens: torch.Tensor

    attn_mask: torch.Tensor
    # The metadata builder in this module passes tensors for the two fields
    # below (a cumulative sum of the prefill query lengths and the prefill
    # sequence lengths), so they are annotated as tensors, not list[int].
    query_lens: torch.Tensor
    seq_lens: torch.Tensor

    context_lens: torch.Tensor
    input_positions: torch.Tensor
    query_start_loc: torch.Tensor
    block_table: torch.Tensor
    max_query_len: int
    max_seq_lens: int
    # Rotary sin/cos values gathered for the prefill token positions.
    sin: torch.Tensor
    cos: torch.Tensor
    # Only populated when chunked prefill has cached context to process.
    chunked_context: Optional[ChunkedContextMetadata] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AscendSFADecodeMetadata:
    """Decode-specific attention metadata for the Ascend SFA backend."""

    # Input positions for rotary embeddings: the rotary position embeddings
    # are applied inside the attention backend, so they travel with the
    # metadata.
    input_positions: torch.Tensor
    block_table: torch.Tensor
    seq_lens: torch.Tensor
    max_seq_lens: int
    # Same values as `seq_lens`, as a plain Python list.
    seq_lens_list: list[int]
    actual_seq_lengths_q: torch.Tensor
    # Rotary sin/cos values gathered for the decode token positions.
    sin: torch.Tensor
    cos: torch.Tensor
    attn_mask: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AscendSFAMetadata:
    """Per-batch attention metadata for the Ascend SFA backend.

    Carries the batch-wide tensors plus optional decode-only and
    prefill-only sub-metadata.
    """
    # NOTE(sang): Definition of context_len, query_len, and seq_len.
    # |---------- N-1 iteration --------|
    # |---------------- N iteration ---------------------|
    # |- tokenA -|......................|-- newTokens ---|
    # |---------- context_len ----------|
    # |-------------------- seq_len ---------------------|
    #                                   |-- query_len ---|

    num_actual_tokens: int  # Number of tokens excluding padding.
    slot_mapping: torch.Tensor
    query_start_loc: torch.Tensor
    seq_lens: torch.Tensor
    block_tables: torch.Tensor

    # For handling the prefill / decode split.
    num_decodes: int
    num_decode_tokens: int
    num_prefills: int

    # For logging.
    num_input_tokens: int = 0  # Number of tokens including padding.

    query_lens: Optional[list[int]] = None
    # The dimension of the attention heads
    head_dim: Optional[int] = None
    # Defaults to None, hence Optional.
    attn_mask: Optional[torch.Tensor] = None
    # chunked prefill by default if no attn_states passed
    attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill

    decode: Optional[AscendSFADecodeMetadata] = None
    prefill: Optional[AscendSFAPrefillMetadata] = None
    enable_dbo_across_dp: bool = False

    def __post_init__(self):
        # Intentionally a no-op: no head_dim validation is performed.
        pass

    def split_metadata_for_multistream(
        self,
        ms_split_config: MSAttentionMetadataSplitConfig,
    ) -> list["AscendSFAMetadata"]:
        """Split metadata for multi-stream with AscendSFAMetadata"""
        # NOTE(review): `_metadata_cls` is AscendMLAMetadata although the
        # declared return type is list[AscendSFAMetadata]. This looks like a
        # leftover from the MLA backend -- confirm against
        # model_input_split_v1_mla_attn before changing it.
        return model_input_split_v1_mla_attn(
            ms_split_config=ms_split_config,
            attn_metadata=self,
            _metadata_cls=AscendMLAMetadata,
        )
|
||||||
|
|
||||||
|
|
||||||
|
M = TypeVar("M", bound=AscendSFAMetadata)
|
||||||
|
|
||||||
|
|
||||||
|
class AscendSFAMetadataBuilder:
    """Builds :class:`AscendSFAMetadata` for a scheduled batch.

    `reorder_batch` moves decode requests to the front of the persistent
    batch; `build` then slices the common attention metadata into
    decode-only and prefill-only parts.
    """

    # Does this backend/builder support ACL Graphs for attention (default: no).
    aclgraph_support: ClassVar[AttentionCGSupport] = \
        AttentionCGSupport.NEVER

    def __init__(self,
                 kv_cache_spec,
                 layer_names,
                 vllm_config: VllmConfig,
                 device: torch.device,
                 metadata_cls: Optional[Type[AscendSFAMetadata]] = None):
        """Cache configuration values and pre-allocate the chunked-prefill
        workspace.

        Args:
            kv_cache_spec: Not used by this builder.
            layer_names: Not used by this builder.
            vllm_config: Source of the model, cache, scheduler and
                speculative-decoding settings.
            device: Device on which built tensors are placed.
            metadata_cls: Metadata class to instantiate in `build`;
                defaults to AscendSFAMetadata.
        """
        self.metadata_cls: Type[AscendSFAMetadata] = (
            metadata_cls if metadata_cls is not None else AscendSFAMetadata)
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config
        self.device = device
        scheduler_config = vllm_config.scheduler_config
        self.block_size = vllm_config.cache_config.block_size
        # Ceil-division: number of blocks needed for a max-length request.
        self.max_blocks = (vllm_config.model_config.max_model_len +
                           self.block_size - 1) // self.block_size
        self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled

        self.speculative_config = vllm_config.speculative_config
        # A request with at most `decode_threshold` scheduled tokens is
        # treated as a decode; speculative tokens raise the threshold.
        self.decode_threshold = 1
        if self.speculative_config:
            spec_token_num = self.speculative_config.num_speculative_tokens
            self.decode_threshold += spec_token_num
            assert self.decode_threshold <= 16, (
                "decode_threshold exceeded npu_fused_infer_attention_score "
                "TND layout's limit of 16, "
                f"got {self.decode_threshold}")

        if self.chunked_prefill_enabled:
            self.chunked_prefill_workspace_size = min(
                # Make sure there is enough for 8 full length requests or at
                # least 4 pages of cache per request
                max(8 * self.model_config.max_model_len,
                    4 * scheduler_config.max_num_seqs * self.block_size),
                # For long-context models try not to over-allocate limiting
                # kv-cache space, limiting it to 128k tokens,
                # which would result in the workspace being:
                #   2*(576)*(128*1024) = 144 MiB
                # (assuming 576 MLA head dim, and fp16)
                # which would result in up-projected context being
                #   2*(192*128)*(128*1024) = 6 GiB
                # (assuming 192 QK head dim, 128 heads, and fp16)
                128 * 1024)
            assert self.chunked_prefill_workspace_size >= \
                scheduler_config.max_num_seqs * self.block_size
            self.chunked_prefill_workspace = torch.empty(
                (self.chunked_prefill_workspace_size,
                 self.model_config.get_head_size()),
                dtype=self.model_config.dtype,
                device=device,
            )
        self.rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
        # Filled lazily from the model's rotary embedding on first `build`.
        self.cos_cache = None
        self.sin_cache = None

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        """Move decode requests in front of prefill requests.

        Returns:
            True if at least one pair of requests was swapped.
        """
        # We now want to reorder the batch so that the "decode" requests are
        # at the front and the "prefill" requests are at the back, using the
        # least amount of swaps possible. (NOTE for now we loosely use
        # "decode" to mean requests where attention is likely memory-bound
        # and "prefill" to mean requests where attention is likely
        # compute-bound, TODO(lucas): figure out a better naming here)
        decodes = []
        prefills = []

        for i, req_id in enumerate(input_batch.req_ids):
            num_tokens = scheduler_output.num_scheduled_tokens[req_id]
            if num_tokens <= self.decode_threshold:
                decodes.append(i)
            else:
                prefills.append(i)

        # We hope that this is fairly minimal since decodes
        # should be around for a number of iterations so hopefully they are
        # relatively stationary (and new requests are generally appended to
        # the persistent batch so already should be at the back)
        # To achieve this we loop over the decodes in descending order and
        # the prefills in ascending order. We swap decodes from the "back"
        # i.e. past where the last decode should be in the reordered batch
        # with prefills from the front of the batch.
        # `decodes` and `prefills` are already in ascending order just based
        # on the above loop
        num_decodes = len(decodes)
        num_prefills = len(prefills)
        first_prefill = 0
        modified_batch = False

        for i in range(1, min(num_decodes, num_prefills) + 1):
            # If the decode is at the "back" of the batch, i, we can swap it
            # with the prefill closest to the front of the batch
            if decodes[num_decodes - i] >= num_decodes:
                input_batch.swap_states(prefills[first_prefill],
                                        decodes[num_decodes - i])
                first_prefill += 1
                modified_batch = True
            else:
                break

        return modified_batch

    def build(
        self,
        common_prefix_len: int,
        common_attn_metadata: AscendCommonAttentionMetadata,
        model: nn.Module,
    ) -> AscendSFAMetadata:
        """Build the attention metadata for the current batch.

        Args:
            common_prefix_len: Not used by this builder.
            common_attn_metadata: Batch-wide metadata to slice into decode
                and prefill parts.
            model: Model whose first layer's rotary embedding supplies the
                cos/sin caches (read once, then cached on the builder).

        Returns:
            An instance of `self.metadata_cls`.
        """
        num_reqs = common_attn_metadata.num_reqs
        num_actual_tokens = common_attn_metadata.num_actual_tokens
        query_start_loc = common_attn_metadata.query_start_loc
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
            split_decodes_and_prefills(common_attn_metadata,
                                       decode_threshold=self.decode_threshold)
        assert num_decodes + num_prefills == num_reqs
        assert num_decode_tokens + num_prefill_tokens == num_actual_tokens

        # Note(simon): be careful about the CPU <> device memory movement in
        # this function. We should avoid device -> CPU sync as much as
        # possible because it blocks on all previous kernels.
        device = self.device

        block_table = common_attn_metadata.block_table_tensor[:num_reqs]
        slot_mapping = common_attn_metadata.slot_mapping[:num_actual_tokens].to(
            device, non_blocking=True)
        input_positions = common_attn_metadata.positions[:num_actual_tokens].long(
        )

        if self.cos_cache is None:
            rotary_emb = model.model.layers[0].self_attn.rotary_emb
            self.cos_cache = rotary_emb.cos_cached
            self.sin_cache = rotary_emb.sin_cached
        if self.cos_cache.dtype != self.model_config.dtype:  # type: ignore
            self.cos_cache = self.cos_cache.to(  # type: ignore
                self.model_config.dtype)  # type: ignore
            self.sin_cache = self.sin_cache.to(  # type: ignore
                self.model_config.dtype)  # type: ignore

        query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
        query_lens = query_seq_lens_cpu[:num_reqs]
        seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs]
        num_computed_tokens_cpu = (seq_lens - query_lens)

        prefill_metadata = None
        chunked_context_metadata = None
        if num_prefills > 0:
            # Prefill requests follow the decode requests in the batch.
            reqs_start = num_decodes  # prefill_start
            tokens_start = num_decode_tokens
            max_query_len = query_lens[reqs_start:].max().item()
            max_seq_lens = seq_lens[reqs_start:].max().item()
            prefill_query_start_loc = query_start_loc[
                reqs_start:] - query_start_loc[reqs_start]

            context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
            max_context_len_cpu = context_lens_cpu.max().item()
            num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
            if self.chunked_prefill_enabled and max_context_len_cpu > 0:
                # Share the workspace evenly between the prefills that have
                # cached context, rounded down to whole blocks.
                max_context_chunk = (self.chunked_prefill_workspace_size //
                                     num_prefills_with_context_cpu)
                max_context_chunk = round_down(max_context_chunk,
                                               self.block_size)

                assert max_context_chunk > 0
                num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
                chunk_starts = torch.arange(num_chunks, dtype=torch.int32) \
                    .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk
                chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                                       chunk_starts + max_context_chunk)
                chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
                cu_seq_lens_cpu = torch.zeros(num_chunks,
                                              num_prefills + 1,
                                              dtype=torch.int32,
                                              pin_memory=True)
                torch.cumsum(chunk_seq_lens,
                             dim=1,
                             out=cu_seq_lens_cpu[:, 1:],
                             dtype=torch.int32)
                chunked_context_metadata = \
                    AscendSFAPrefillMetadata.ChunkedContextMetadata(
                        cu_seq_lens=cu_seq_lens_cpu.to(device,
                                                       non_blocking=True),
                        starts=chunk_starts.to(device, non_blocking=True),
                        seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
                        max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
                        chunk_seq_lens=chunk_seq_lens,
                        workspace=self.chunked_prefill_workspace,
                    )
            prefill_input_positions = input_positions[tokens_start:]
            cos = self.cos_cache[
                prefill_input_positions].unsqueeze(  # type: ignore
                    1).unsqueeze(2)
            sin = self.sin_cache[
                prefill_input_positions].unsqueeze(  # type: ignore
                    1).unsqueeze(2)
            # `query_lens` is already a tensor; convert with `.to` instead of
            # copy-constructing through `torch.tensor`, which PyTorch warns
            # about.
            actual_query_lens = query_lens[reqs_start:].to(torch.int32).npu()
            query_lens_prefill_sfa = torch.cumsum(actual_query_lens,
                                                  dim=0).to(torch.int32)
            seq_lens_prefill_sfa = seq_lens[reqs_start:].to(torch.int32).npu()
            prefill_metadata = AscendSFAPrefillMetadata(
                attn_mask=common_attn_metadata.attn_mask,
                query_lens=query_lens_prefill_sfa,
                seq_lens=seq_lens_prefill_sfa,
                context_lens=seq_lens[reqs_start:],
                input_positions=prefill_input_positions,
                block_table=block_table[reqs_start:, ...],
                max_query_len=max_query_len,
                max_seq_lens=max_seq_lens,
                query_start_loc=prefill_query_start_loc,
                chunked_context=chunked_context_metadata,
                sin=sin,
                cos=cos,
            )

        decode_metadata = None
        if num_decodes > 0:
            # Notice that num_decodes != num_decode_tokens in SpecDecoding Scenario
            actual_seq_lengths_q = query_start_loc[1:num_decodes + 1].to(
                torch.int32).npu()
            decode_seq_lens_cpu = seq_lens[:num_decodes]
            max_seq_lens = decode_seq_lens_cpu.max().item()
            # Take the Python list from the CPU copy *before* moving to the
            # NPU so that no device -> host sync is triggered here.
            seq_lens_list = decode_seq_lens_cpu.tolist()
            # NOTE: `seq_lens`, `input_positions` and `block_table` are
            # rebound to their decode-only slices here, and the rebound
            # `seq_lens` / `block_table` are also what the top-level metadata
            # below receives when the batch contains decodes.
            seq_lens = decode_seq_lens_cpu.to(torch.int32).npu()
            input_positions = input_positions[:num_decode_tokens]
            block_table = block_table[:num_decodes, ...]

            cos = self.cos_cache[input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)
            sin = self.sin_cache[input_positions].unsqueeze(  # type: ignore
                1).unsqueeze(2)

            decode_metadata = AscendSFADecodeMetadata(
                input_positions=input_positions,
                block_table=block_table,
                seq_lens=seq_lens,
                seq_lens_list=seq_lens_list,
                max_seq_lens=max_seq_lens,
                attn_mask=common_attn_metadata.spec_attn_mask,
                actual_seq_lengths_q=actual_seq_lengths_q,
                sin=sin,
                cos=cos)

        return self.metadata_cls(  # type: ignore
            num_actual_tokens=num_actual_tokens,
            query_lens=query_lens.tolist(),
            slot_mapping=slot_mapping,
            head_dim=self.model_config.get_head_size(),
            num_decodes=num_decodes,
            num_decode_tokens=num_decode_tokens,
            num_prefills=num_prefills,
            attn_mask=common_attn_metadata.attn_mask,
            attn_state=common_attn_metadata.attn_state,
            prefill=prefill_metadata,
            decode=decode_metadata,
            query_start_loc=query_start_loc,
            block_tables=block_table,
            seq_lens=seq_lens,
            enable_dbo_across_dp=common_attn_metadata.enable_dbo_across_dp,
        )
|
||||||
|
|
||||||
|
|
||||||
|
class PrefillSFAPreprocessResult(NamedTuple):
    """Tensors produced by SFA preprocessing of the prefill tokens.

    Every field is optional and defaults to None.
    """
    q_nope: Optional[torch.Tensor] = None
    q_pe: Optional[torch.Tensor] = None
    k_nope: Optional[torch.Tensor] = None
    k_pe: Optional[torch.Tensor] = None
    topk_indices: Optional[torch.Tensor] = None
    query_states: Optional[torch.Tensor] = None
    key_states: Optional[torch.Tensor] = None
|
||||||
|
|
||||||
|
|
||||||
|
class DecodeSFAPreprocessResult(NamedTuple):
    """Tensors produced by SFA preprocessing of the decode tokens.

    Every field is optional and defaults to None.
    """
    q_nope: Optional[torch.Tensor] = None
    q_pe: Optional[torch.Tensor] = None
    topk_indices: Optional[torch.Tensor] = None
    query_states: Optional[torch.Tensor] = None
    key_states: Optional[torch.Tensor] = None
    bsz: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
class AscendSFAImpl(MLAAttentionImpl):
|
||||||
|
"""
|
||||||
|
NOTE: Please read the comment at the top of the file before trying to
|
||||||
|
understand this class
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
    self,
    num_heads: int,
    head_size: int,
    scale: float,
    num_kv_heads: int,
    alibi_slopes: Optional[list[float]],
    sliding_window: Optional[int],
    kv_cache_dtype: str,
    logits_soft_cap: Optional[float],
    attn_type: str,
    kv_sharing_target_layer_name: Optional[str],
    **kwargs,
) -> None:
    """Store the attention hyper-parameters and the projection / indexer
    modules handed over through ``kwargs``.

    ``alibi_slopes``, ``sliding_window``, ``logits_soft_cap``,
    ``attn_type`` and ``kv_sharing_target_layer_name`` are accepted but
    not used here.

    Required ``kwargs`` keys (KeyError if missing): q_lora_rank,
    kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, qk_head_dim,
    v_head_dim, rotary_emb, q_proj, kv_b_proj, o_proj, indexer.
    Optional keys (default None): kv_a_proj_with_mqa, kv_a_layernorm,
    q_a_proj, q_a_layernorm.
    """
    self.num_heads = num_heads
    self.head_size = head_size
    self.scale = float(scale)
    self.num_kv_heads = num_kv_heads
    self.kv_cache_dtype = kv_cache_dtype

    # MLA Args
    self.q_lora_rank = kwargs['q_lora_rank']
    self.kv_lora_rank = kwargs['kv_lora_rank']
    self.qk_nope_head_dim = kwargs['qk_nope_head_dim']
    self.qk_rope_head_dim = kwargs['qk_rope_head_dim']
    self.qk_head_dim = kwargs['qk_head_dim']
    self.v_head_dim = kwargs['v_head_dim']
    self.rotary_emb = kwargs['rotary_emb']
    self.q_proj = kwargs['q_proj']
    self.kv_b_proj = kwargs['kv_b_proj']
    self.o_proj = kwargs['o_proj']
    self.indexer = kwargs['indexer']
    self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)
    self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None)
    self.q_a_proj = kwargs.get('q_a_proj', None)
    self.q_a_layernorm = kwargs.get('q_a_layernorm', None)
    self.num_queries_per_kv = self.num_heads // self.num_kv_heads
    self.tp_size = get_tensor_model_parallel_world_size()
    self.num_heads_per_rank = self.num_heads // self.tp_size
    # When a q down-projection exists, `q_proj` is exposed under the
    # `q_b_proj` name as well; otherwise there is no `q_b_proj`.
    self.q_b_proj = self.q_proj if self.q_a_proj is not None else None

    ascend_cfg = get_ascend_config()
    self.enable_shared_expert_dp = ascend_cfg.enable_shared_expert_dp
    self.enable_prefetch = ascend_cfg.enable_prefetch
    self.enable_kv_nz = ascend_cfg.torchair_graph_config.enable_kv_nz

    vllm_config = get_current_vllm_config()
    self.ring_mla_mask_size = 512
    self.prefill_mask = None

    # Indexer parameters, mirrored onto the impl for direct access.
    indexer = self.indexer
    self.dim = indexer.dim
    self.n_heads: int = indexer.n_heads  # 64
    self.head_dim: int = indexer.head_dim  # 128
    self.index_topk: int = indexer.index_topk  # 2048
    self.wq_b = indexer.wq_b
    self.wk = indexer.wk
    self.weights_proj = indexer.weights_proj
    self.k_norm = indexer.k_norm
    self.softmax_scale = indexer.softmax_scale

    # Adapt torch air graph mode with spec decoding.
    # NOTE: `spec_token_num` is only set when speculative decoding is
    # configured.
    spec_cfg = vllm_config.speculative_config
    if spec_cfg is not None:
        self.spec_token_num = spec_cfg.num_speculative_tokens
        assert self.spec_token_num > 0

    self.cp_size = 1
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, act_dtype: torch.dtype):
    """Split ``kv_b_proj`` into per-head K and V projection weights.

    Stores the results on ``self`` as ``kv_b_proj_w_k`` with layout
    (N, P, L) and ``kv_b_proj_w_v`` with layout (N, L, V), where
    L = kv_lora_rank, N = num_heads, P = qk_nope_head_dim and
    V = v_head_dim.

    Args:
        act_dtype: dtype of the identity matrix used to dequantize a
            quantized ``kv_b_proj``.
    """

    def get_layer_weight(layer):
        # Return the first weight-like attribute the layer exposes.
        WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
        found = next(
            (name for name in WEIGHT_NAMES if hasattr(layer, name)), None)
        if found is None:
            raise AttributeError(
                f"Layer '{layer}' has no recognized weight attribute:"
                f" {WEIGHT_NAMES}.")
        return getattr(layer, found)

    def get_and_maybe_dequant_weights(layer: LinearBase):
        if isinstance(layer.quant_method, UnquantizedLinearMethod):
            return layer.weight
        # Quantized layer: push an identity matrix through it to recover
        # the dense weights.
        # NOTE: This should only be used offline, since it's O(N^3)
        identity = torch.eye(layer.input_size_per_partition,
                             dtype=act_dtype,
                             device=get_layer_weight(layer).device)
        dense = layer.quant_method.apply(layer, identity, bias=None)
        del identity
        # standardize to (output, input)
        return dense.T

    # We currently do not have quantized bmm's, which are needed for the
    # per-head K/V projections, so we just store fp16/bf16 copies and
    # perform the bmm's in 16-bit; the extra memory overhead of this is
    # fairly low.
    kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
    expected_shape = (
        self.kv_lora_rank,
        self.num_heads * (self.qk_nope_head_dim + self.v_head_dim))
    assert kv_b_proj_weight.shape == expected_shape, (
        f"{kv_b_proj_weight.shape=}, "
        f"{self.kv_lora_rank=}, "
        f"{self.num_heads=}, "
        f"{self.qk_nope_head_dim=}, "
        f"{self.v_head_dim=}")
    kv_b_proj_weight = kv_b_proj_weight.view(
        self.kv_lora_rank,
        self.num_heads,
        self.qk_nope_head_dim + self.v_head_dim,
    )

    w_k, w_v = kv_b_proj_weight.split(
        [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
    self.kv_b_proj_w_k, self.kv_b_proj_w_v = w_k, w_v

    # Convert from (L, N, V) to (N, L, V)
    self.kv_b_proj_w_v = self.kv_b_proj_w_v.transpose(0, 1).contiguous()
    # Convert from (L, N, P) to (N, P, L)
    self.kv_b_proj_w_k = self.kv_b_proj_w_k.permute(1, 2, 0).contiguous()

    # TODO: cast both weights to the NZ format (torch_npu.npu_format_cast
    # with format 29) once BMM supports NZ.
|
||||||
|
|
||||||
|
def _sfa_preprocess(self, hidden_states, kv_cache, attn_metadata,
                    need_gather_q_kv):
    """Prepare the decode and prefill inputs of sparse flash attention.

    The first `num_decode_tokens` rows of `hidden_states` are handled as
    decode tokens and the rows from there up to `num_actual_tokens` as
    prefill tokens. For each non-empty group this method:

    1. runs `q_a_proj` + `q_a_layernorm` to get the compressed query and
       `kv_a_proj_with_mqa` to get the un-split latent KV,
    2. expands the query with `q_b_proj`, splits it into its nope and
       rope parts and multiplies the nope part by `kv_b_proj_w_k`,
    3. hands the latent KV to `npu_kv_rmsnorm_rope_cache` together with
       the slot mapping and `kv_cache[1]` / `kv_cache[0]` (KV cache write),
    4. applies `npu_interleave_rope` to the rope part of the query, and
    5. asks `indexer_select` for the top-k indices.

    Args:
        hidden_states: Token activations. All-gathered along dim 0 over
            the TP group first when `need_gather_q_kv` is true.
        kv_cache: Indexable cache container. Entries 0 and 1 are passed
            to the cache kernel here; the whole container is forwarded
            to `indexer_select`.
        attn_metadata: Batch metadata providing `num_decodes`,
            `num_prefills`, `num_decode_tokens`, `num_actual_tokens`,
            `slot_mapping` and the per-phase `decode` / `prefill` cos/sin.
        need_gather_q_kv: Whether to all-gather `hidden_states` first.

    Returns:
        A `(decode_result, prefill_result)` pair holding a
        `DecodeSFAPreprocessResult` and a `PrefillSFAPreprocessResult`.
        An element is `None` when the batch has no tokens of that kind.
    """
    has_decode = attn_metadata.num_decodes > 0
    has_prefill = attn_metadata.num_prefills > 0

    num_decode_tokens = attn_metadata.num_decode_tokens
    num_actual_tokens = attn_metadata.num_actual_tokens
    if need_gather_q_kv:
        hidden_states = get_tp_group().all_gather(hidden_states, 0)

    decode_preprocess_res = None
    prefill_preprocess_res = None

    # Preprocess for decode tokens
    if has_decode:
        # The decode query is reshaped below with a length-1 sequence axis.
        q_len = 1
        hidden_states_decode = hidden_states[:num_decode_tokens]
        decode_kq = self.q_a_proj(hidden_states_decode)  # q down
        decode_q_c = self.q_a_layernorm(decode_kq)  # q down layernorm
        decode_kv_no_split = self.kv_a_proj_with_mqa(
            hidden_states_decode)  # c_kv
        decode_slot_mapping = attn_metadata.slot_mapping[:num_decode_tokens]

        decode_q = self.q_b_proj(decode_q_c)
        bsz, _ = decode_q.shape
        decode_q = decode_q.view(bsz, self.num_heads, 1, self.qk_head_dim)
        decode_q_nope, decode_q_pe = torch.split(
            decode_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
            dim=-1)
        # Head-major layout for the batched matmul with kv_b_proj_w_k,
        # then back to token-major.
        decode_q_nope = decode_q_nope.view(
            -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
        decode_q_nope = torch.matmul(decode_q_nope, self.kv_b_proj_w_k)
        decode_q_nope = decode_q_nope.transpose(1, 0).view(
            bsz, q_len, self.num_heads, self.kv_lora_rank)

        nope_cache = kv_cache[0]
        rope_cache = kv_cache[1]
        cos = attn_metadata.decode.cos.view(-1, 1, 1, self.qk_rope_head_dim)
        sin = attn_metadata.decode.sin.view(-1, 1, 1, self.qk_rope_head_dim)

        decode_kv_no_split = decode_kv_no_split.unsqueeze(1).unsqueeze(1)
        decode_k_rope, decode_k_nope, _, _ = \
            torch_npu.npu_kv_rmsnorm_rope_cache(
                decode_kv_no_split,
                self.kv_a_layernorm.weight,
                cos,
                sin,
                decode_slot_mapping.to(torch.int64),
                rope_cache,
                nope_cache,
                c_kv_scale=None,
                epsilon=self.kv_a_layernorm.variance_epsilon,
                cache_mode='PA')

        decode_q_pe = torch_npu.npu_interleave_rope(decode_q_pe, cos,
                                                    sin)  # BNSD

        decode_q_nope = decode_q_nope.view(bsz, self.num_heads,
                                           self.kv_lora_rank)
        decode_q_pe = decode_q_pe.view(bsz, self.num_heads, -1)

        topk_indices = self.indexer_select(hidden_states_decode,
                                           decode_q_c,
                                           attn_metadata=attn_metadata,
                                           kv_cache=kv_cache)

        decode_preprocess_res = DecodeSFAPreprocessResult(
            q_nope=decode_q_nope,
            q_pe=decode_q_pe,
            topk_indices=topk_indices,
            query_states=(decode_q_nope, decode_q_pe),
            key_states=(decode_k_nope, decode_k_rope),
            bsz=bsz,
        )

    # Preprocess for prefill tokens
    if has_prefill:
        hidden_states_prefill = hidden_states[
            num_decode_tokens:num_actual_tokens]
        prefill_kq = self.q_a_proj(hidden_states_prefill)  # q down
        prefill_q_c = self.q_a_layernorm(prefill_kq)  # q down layernorm
        prefill_kv_no_split = self.kv_a_proj_with_mqa(
            hidden_states_prefill)  # c_kv
        prefill_slot_mapping = attn_metadata.slot_mapping[
            num_decode_tokens:num_actual_tokens]

        prefill_q = self.q_b_proj(prefill_q_c)
        prefill_q = prefill_q.view(-1, self.num_heads, self.qk_head_dim)
        prefill_q_nope, prefill_q_pe = torch.split(
            prefill_q, [self.qk_nope_head_dim, self.qk_rope_head_dim],
            dim=-1)
        # Head-major layout for the batched matmul with kv_b_proj_w_k,
        # then back to token-major.
        prefill_q_nope = prefill_q_nope.view(
            -1, self.num_heads, self.qk_nope_head_dim).transpose(0, 1)
        prefill_q_nope = torch.matmul(prefill_q_nope, self.kv_b_proj_w_k)
        prefill_q_nope = prefill_q_nope.transpose(1, 0).view(
            -1, self.num_heads, self.kv_lora_rank)
        prefill_q_pe = prefill_q_pe.unsqueeze(2)

        nope_cache = kv_cache[0]
        rope_cache = kv_cache[1]
        cos = attn_metadata.prefill.cos
        sin = attn_metadata.prefill.sin

        # The query rope uses cos/sin as stored in the metadata; the
        # cache kernel below gets the 4-D views instead.
        prefill_q_pe = torch_npu.npu_interleave_rope(prefill_q_pe, cos,
                                                     sin)  # BNSD
        prefill_q_pe = prefill_q_pe.squeeze(2)  # BSH

        prefill_k_pe, prefill_k_nope, _, _ = \
            torch_npu.npu_kv_rmsnorm_rope_cache(
                prefill_kv_no_split.view(
                    -1, 1, 1, self.kv_lora_rank + self.qk_rope_head_dim),
                self.kv_a_layernorm.weight,
                cos.view(-1, 1, 1, self.qk_rope_head_dim),
                sin.view(-1, 1, 1, self.qk_rope_head_dim),
                prefill_slot_mapping.to(torch.int64),
                rope_cache,
                nope_cache,
                k_rope_scale=None,
                c_kv_scale=None,
                k_rope_offset=None,
                c_kv_offset=None,
                epsilon=self.kv_a_layernorm.variance_epsilon,
                cache_mode="PA")

        topk_indices = self.indexer_select(x=hidden_states_prefill,
                                           qr=prefill_q_c,
                                           kv_cache=kv_cache,
                                           attn_metadata=attn_metadata)

        prefill_preprocess_res = PrefillSFAPreprocessResult(
            q_nope=prefill_q_nope,
            q_pe=prefill_q_pe,
            topk_indices=topk_indices,
            k_nope=prefill_k_nope,
            k_pe=prefill_k_pe,
            query_states=(prefill_q_nope, prefill_q_pe),
            key_states=(prefill_k_nope, prefill_k_pe),
        )

    return decode_preprocess_res, prefill_preprocess_res
|
||||||
|
|
||||||
|
def forward(
    self,
    hidden_states: torch.Tensor,  # query in unified attn
    kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    attn_metadata: M,
    need_gather_q_kv: bool = False,
    output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run sparse flash attention and write the projection into `output`.

    Args:
        hidden_states: Input activations; also supplies the dtype and
            device of the intermediate attention buffer.
        kv_cache: Cache tensors forwarded to `_sfa_preprocess`.
        attn_metadata: Batch metadata. `None` marks a profiling run, in
            which case `output` is returned untouched.
        need_gather_q_kv: Forwarded to `_sfa_preprocess`.
        output: Destination tensor. Must be provided.

    Returns:
        `output` itself on a profiling run, otherwise the view of its
        first `num_actual_tokens` rows after they have been filled with
        the result of `mla_epilog`.
    """
    assert output is not None, "Output tensor must be provided."
    if attn_metadata is None:
        # Profiling run.
        return output

    token_count = attn_metadata.num_actual_tokens
    assert attn_metadata.num_decodes is not None
    assert attn_metadata.num_prefills is not None
    assert attn_metadata.num_decode_tokens is not None
    decode_end = attn_metadata.num_decode_tokens

    # Inputs and outputs may be padded for CUDA graphs
    out_view = output[:token_count, ...]
    attn_buffer = torch.empty(
        (token_count, self.num_heads * self.v_head_dim),
        dtype=hidden_states.dtype,
        device=hidden_states.device)

    # SFA Preprocess
    decode_res, prefill_res = self._sfa_preprocess(hidden_states, kv_cache,
                                                   attn_metadata,
                                                   need_gather_q_kv)

    # Decode rows come first in the buffer, prefill rows after them.
    phases = (
        (decode_res, slice(None, decode_end)),
        (prefill_res, slice(decode_end, None)),
    )
    for phase_res, rows in phases:
        if phase_res is None:
            continue
        attn_buffer[rows] = self.apply_attention_fusion(
            query_states=phase_res.query_states,
            key_states=phase_res.key_states,
            attn_metadata=attn_metadata,
            topk_indices=phase_res.topk_indices)

    out_view[...] = self.mla_epilog(attn_buffer, absorb=True)
    return out_view
|
||||||
|
|
||||||
|
def apply_attention_fusion(self, query_states, key_states, topk_indices,
                           attn_metadata: M):
    """Run sparse flash attention and project it with `kv_b_proj_w_v`.

    Args:
        query_states: `(q_nope, q_pe)` pair.
        key_states: `(k_nope, k_rope)` pair. `k_nope` is handed to the
            kernel as both key and value.
        topk_indices: Passed to the kernel as `sparse_indices`.
        attn_metadata: Batch metadata. Prefill metadata is used whenever
            `attn_metadata.prefill` is set; decode metadata is used only
            when it is not.

    Returns:
        A 2-D tensor whose last dimension is `num_heads * v_head_dim`.

    Raises:
        ValueError: If `attn_metadata` carries neither prefill nor decode
            metadata.
    """
    q_nope, q_pe = query_states
    k_nope, k_rope = key_states

    # NOTE(review): the phase is picked from attn_metadata alone, not from
    # which tokens the caller passed in -- confirm that decode queries are
    # never submitted together with metadata that also has `prefill` set.
    is_prefill = attn_metadata.prefill is not None
    if is_prefill:
        phase_metadata = attn_metadata.prefill
        seq_lengths_query = phase_metadata.query_lens
    elif attn_metadata.decode is not None:
        phase_metadata = attn_metadata.decode
        seq_lengths_query = phase_metadata.actual_seq_lengths_q
    else:
        raise ValueError(
            "attn_metadata must carry prefill or decode metadata.")

    slc_fa_fusion = torch.ops.custom.npu_sparse_flash_attention(
        query=q_nope,
        key=k_nope,
        value=k_nope,
        sparse_indices=topk_indices,
        scale_value=self.scale,
        sparse_block_size=1,
        block_table=phase_metadata.block_table,
        actual_seq_lengths_query=seq_lengths_query,
        actual_seq_lengths_kv=phase_metadata.seq_lens,
        query_rope=q_pe,
        key_rope=k_rope,
        layout_query="TND",
        layout_kv="PA_BSND",
        sparse_mode=3,
    )
    if not is_prefill:
        # Only the decode result is squeezed along dim 1.
        slc_fa_fusion = slc_fa_fusion.squeeze(1)

    # Swap the first two axes so the matmul is batched over the leading
    # axis of kv_b_proj_w_v, then swap back and flatten per row.
    slc_fa_fusion = slc_fa_fusion.transpose(0, 1)
    attn_output = torch.matmul(slc_fa_fusion,
                               self.kv_b_proj_w_v).transpose(1, 0).reshape(
                                   -1, self.num_heads * self.v_head_dim)
    return attn_output
|
||||||
|
|
||||||
|
def mla_epilog(self,
               attn_output: Optional[torch.Tensor] = None,
               absorb: bool = False):
    """Flatten the attention output per row and apply `o_proj`.

    Args:
        attn_output: Attention result. Everything after its first
            dimension is flattened before the projection. Required; the
            `None` default is kept only for signature compatibility.
        absorb: Accepted for signature compatibility; not read here.

    Returns:
        Whatever `self.o_proj` returns for the flattened input.

    Raises:
        ValueError: If `attn_output` is None.
    """
    if attn_output is None:
        raise ValueError("attn_output must be provided.")

    flattened = attn_output.reshape(attn_output.shape[0], -1)
    # TODO: need to check
    return self.o_proj(flattened, is_prefill=True, is_force_scatter=False)
|
||||||
|
|
||||||
|
def indexer_select(
    self,
    x: torch.Tensor,
    qr: torch.Tensor,
    kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    attn_metadata: M,
):
    """Compute top-k indices with the lightning indexer.

    Builds the indexer query from `qr` (via `wq_b`) and the indexer key
    from `x` (via `wk` + `k_norm`), applies `npu_interleave_rope` to the
    rope part of both, scatters the key into `kv_cache[2]` at
    `attn_metadata.slot_mapping`, and calls `npu_lightning_indexer`.

    Args:
        x: Input to `wk` and `weights_proj`.
        qr: Input to `wq_b`.
        kv_cache: Cache container; only entry 2 is used here.
        attn_metadata: Batch metadata. Prefill metadata is used whenever
            `attn_metadata.prefill` is set; decode metadata is used only
            when it is not.

    Returns:
        The result of `torch.ops.custom.npu_lightning_indexer`.

    Raises:
        ValueError: If `attn_metadata` carries neither prefill nor decode
            metadata.
    """
    # NOTE(review): the phase is picked from attn_metadata alone, not from
    # which tokens `x` / `qr` hold -- confirm that decode tokens are never
    # submitted together with metadata that also has `prefill` set.
    if attn_metadata.prefill is not None:
        phase_metadata = attn_metadata.prefill
        actual_seq_lengths_query = phase_metadata.query_lens
    elif attn_metadata.decode is not None:
        phase_metadata = attn_metadata.decode
        actual_seq_lengths_query = phase_metadata.actual_seq_lengths_q
    else:
        raise ValueError(
            "attn_metadata must carry prefill or decode metadata.")
    actual_seq_lengths_key = phase_metadata.seq_lens
    block_table = phase_metadata.block_table

    # The query rope uses cos/sin as stored in the metadata; the key rope
    # uses the 4-D views.
    cos_q, sin_q = phase_metadata.cos, phase_metadata.sin
    cos = cos_q.view(-1, 1, 1, self.qk_rope_head_dim)
    sin = sin_q.view(-1, 1, 1, self.qk_rope_head_dim)

    # q process in new stream
    q = self.wq_b(qr)  # [b,s,1536] @ [1536,64*128] = [b,s,64*128]
    q = q.view(-1, self.n_heads, self.head_dim)  # [b,s,64,128]
    q_pe, q_nope = torch.split(
        q, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
        dim=-1)  # [b,s,64,64+64]

    q_pe = q_pe.unsqueeze(2)
    q_pe = torch_npu.npu_interleave_rope(q_pe, cos_q, sin_q)
    q_pe = q_pe.squeeze(2)
    q = torch.cat([q_pe, q_nope], dim=-1)  # [b*s,64,128]

    k_proj = self.wk(x)  # [b,s,7168] @ [7168,128] = [b,s,128]
    k = self.k_norm(k_proj).unsqueeze(1)
    k_pe, k_nope = torch.split(
        k, [self.qk_rope_head_dim, self.head_dim - self.qk_rope_head_dim],
        dim=-1)  # [b,s,64+64]

    k_pe = k_pe.unsqueeze(2)
    k_pe = torch_npu.npu_interleave_rope(k_pe, cos, sin)
    k_pe = k_pe.squeeze(2)

    k = torch.cat([k_pe, k_nope], dim=-1)  # [b*s,128]

    if kv_cache is not None:
        # NOTE(review): the scatter indexes with the full
        # attn_metadata.slot_mapping while `k` only covers the rows of
        # `x` -- confirm both always have the same number of rows.
        torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
                                         attn_metadata.slot_mapping.view(
                                             -1, 1),
                                         k.view(-1,
                                                k.shape[-1]))  # b, s, n, d

    weights = self.weights_proj(x)

    # NOTE(review): kv_cache[2] is read unconditionally here although the
    # scatter above is guarded against `kv_cache is None`.
    topk_indices = torch.ops.custom.npu_lightning_indexer(
        query=q,
        key=kv_cache[2],
        weights=weights,
        actual_seq_lengths_query=actual_seq_lengths_query,
        actual_seq_lengths_key=actual_seq_lengths_key,
        block_table=block_table,
        layout_query="TND",
        layout_key="PA_BSND",
        sparse_count=2048,
        sparse_mode=3)
    return topk_indices
|
||||||
@@ -493,8 +493,11 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
assert self.local_agent_metadata is not None
|
assert self.local_agent_metadata is not None
|
||||||
kv_cache_dtype = first_kv_cache.dtype
|
kv_cache_dtype = first_kv_cache.dtype
|
||||||
self.use_mla: bool = first_kv_cache_tuple[0].size(
|
self.use_mla: bool = first_kv_cache_tuple[0].size(
|
||||||
-1) != first_kv_cache_tuple[1].size(-1)
|
-1) != first_kv_cache_tuple[1].size(-1) and len(
|
||||||
|
first_kv_cache_tuple) == 2
|
||||||
|
self.use_sfa: bool = len(first_kv_cache_tuple) == 3
|
||||||
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
|
# MLA case. [2 (k_normed, k_pe), num_blocks, ...]
|
||||||
|
# SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
|
||||||
# MHA case. [2 (k and v), num_blocks, ...]
|
# MHA case. [2 (k and v), num_blocks, ...]
|
||||||
self.num_blocks = first_kv_cache.shape[0]
|
self.num_blocks = first_kv_cache.shape[0]
|
||||||
block_rank = 3 # [block_size, latent_dim]
|
block_rank = 3 # [block_size, latent_dim]
|
||||||
@@ -540,6 +543,58 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||||
)
|
)
|
||||||
|
elif self.use_sfa:
|
||||||
|
cache_k_normed_addr_list = []
|
||||||
|
cache_k_pe_addr_list = []
|
||||||
|
cache_k_idx_addr_list = []
|
||||||
|
k_normed = None
|
||||||
|
k_pe = None
|
||||||
|
k_idx = None
|
||||||
|
for cache_or_caches in kv_caches.values():
|
||||||
|
assert len(cache_or_caches) > 1
|
||||||
|
k_normed, k_pe, k_idx = cache_or_caches[0], cache_or_caches[
|
||||||
|
1], cache_or_caches[2]
|
||||||
|
cache_k_normed_addr_list.append(k_normed.data_ptr())
|
||||||
|
cache_k_pe_addr_list.append(k_pe.data_ptr())
|
||||||
|
cache_k_idx_addr_list.append(k_idx.data_ptr())
|
||||||
|
self.cache_addr = (cache_k_normed_addr_list, cache_k_pe_addr_list,
|
||||||
|
cache_k_idx_addr_list)
|
||||||
|
|
||||||
|
cache_desc_k_normed = CacheDesc(
|
||||||
|
len(self.cache_addr[0]), [*k_normed.shape],
|
||||||
|
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||||
|
cache_desc_k_pe = CacheDesc(
|
||||||
|
len(self.cache_addr[1]), [*k_pe.shape],
|
||||||
|
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||||
|
cache_desc_k_idx = CacheDesc(
|
||||||
|
len(self.cache_addr[2]), [*k_idx.shape],
|
||||||
|
TORCH_DTYPE_TO_NPU_DTYPE[kv_cache_dtype])
|
||||||
|
cache_key_k_normed = BlocksCacheKey(cluster_id=int(
|
||||||
|
self.local_agent_metadata.cluster_id),
|
||||||
|
model_id=0)
|
||||||
|
cache_key_k_pe = BlocksCacheKey(cluster_id=int(
|
||||||
|
self.local_agent_metadata.cluster_id),
|
||||||
|
model_id=1)
|
||||||
|
cache_key_k_idx = BlocksCacheKey(cluster_id=int(
|
||||||
|
self.local_agent_metadata.cluster_id),
|
||||||
|
model_id=2)
|
||||||
|
self.cache_desc = (cache_desc_k_normed, cache_desc_k_pe,
|
||||||
|
cache_desc_k_idx)
|
||||||
|
self.cache_key = (cache_key_k_normed, cache_key_k_pe,
|
||||||
|
cache_key_k_idx)
|
||||||
|
try:
|
||||||
|
cache_k_normed = self.cache_manager.register_blocks_cache(
|
||||||
|
self.cache_desc[0], self.cache_addr[0], self.cache_key[0])
|
||||||
|
cache_k_pe = self.cache_manager.register_blocks_cache(
|
||||||
|
self.cache_desc[1], self.cache_addr[1], self.cache_key[1])
|
||||||
|
cache_k_idx = self.cache_manager.register_blocks_cache(
|
||||||
|
self.cache_desc[2], self.cache_addr[2], self.cache_key[2])
|
||||||
|
self.cache = (cache_k_normed, cache_k_pe, cache_k_idx)
|
||||||
|
logger.info("LLMDataDistWorker: End of register Paged Cache.")
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
for cache_or_caches in kv_caches.values():
|
for cache_or_caches in kv_caches.values():
|
||||||
for cache in cache_or_caches:
|
for cache in cache_or_caches:
|
||||||
@@ -826,6 +881,38 @@ class LLMDataDistCMgrConnectorWorker():
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||||
)
|
)
|
||||||
|
elif self.use_sfa:
|
||||||
|
remote_cache_key_k_normed = BlocksCacheKey(
|
||||||
|
cluster_id=remote_cluster_id, model_id=0)
|
||||||
|
remote_cache_key_k_pe = BlocksCacheKey(
|
||||||
|
cluster_id=remote_cluster_id, model_id=1)
|
||||||
|
remote_cache_key_k_idx = BlocksCacheKey(
|
||||||
|
cluster_id=remote_cluster_id, model_id=2)
|
||||||
|
logger.info("Try pull blocks from remote server")
|
||||||
|
try:
|
||||||
|
self.cache_manager.pull_blocks(
|
||||||
|
remote_cache_key_k_normed,
|
||||||
|
self.cache[0], # type: ignore[has-type]
|
||||||
|
remote_block_ids,
|
||||||
|
local_block_ids)
|
||||||
|
self.cache_manager.pull_blocks(
|
||||||
|
remote_cache_key_k_pe,
|
||||||
|
self.cache[1], # type: ignore[has-type]
|
||||||
|
remote_block_ids,
|
||||||
|
local_block_ids)
|
||||||
|
self.cache_manager.pull_blocks(
|
||||||
|
remote_cache_key_k_idx,
|
||||||
|
self.cache[2], # type: ignore[has-type]
|
||||||
|
remote_block_ids,
|
||||||
|
local_block_ids)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
raise RuntimeError(
|
||||||
|
f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to pull_blocks remote_cache_key: {remote_cache_key_k_normed} {remote_cache_key_k_pe} {remote_cache_key_k_idx}, cache: {self.cache}, local_block_ids: {local_block_ids}, remote_block_ids: {remote_block_ids}" # type: ignore[has-type]
|
||||||
|
)
|
||||||
|
except LLMException:
|
||||||
|
raise RuntimeError(
|
||||||
|
"LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
|
remote_cache_key = BlocksCacheKey(cluster_id=remote_cluster_id)
|
||||||
logger.info("Try pull blocks from remote server")
|
logger.info("Try pull blocks from remote server")
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from vllm.v1.core.sched.output import SchedulerOutput
|
|||||||
from vllm.v1.request import RequestStatus
|
from vllm.v1.request import RequestStatus
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
@@ -238,6 +239,7 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
self.block_len = block_len
|
self.block_len = block_len
|
||||||
# TODO(jianzs): find a better way to detect MLA.
|
# TODO(jianzs): find a better way to detect MLA.
|
||||||
self.use_mla = len(block_len) == 2
|
self.use_mla = len(block_len) == 2
|
||||||
|
self.use_sfa = len(block_len) == 3
|
||||||
|
|
||||||
self.request_queue: queue.Queue[Any] = queue.Queue()
|
self.request_queue: queue.Queue[Any] = queue.Queue()
|
||||||
# TODO(jianzs): make this configurable
|
# TODO(jianzs): make this configurable
|
||||||
@@ -349,8 +351,12 @@ class KVCacheRecvingThread(threading.Thread):
|
|||||||
src_list, dst_list, length_list = [], [], []
|
src_list, dst_list, length_list = [], [], []
|
||||||
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
||||||
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
|
zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
|
||||||
block_len = (self.block_len[k % 2]
|
if self.use_mla:
|
||||||
if self.use_mla else self.block_len[0])
|
block_len = (self.block_len[k % 2])
|
||||||
|
elif self.use_sfa:
|
||||||
|
block_len = (self.block_len[k % 3])
|
||||||
|
else:
|
||||||
|
block_len = (self.block_len[0])
|
||||||
for i, remote_block_id in enumerate(grouped_remote_block_ids):
|
for i, remote_block_id in enumerate(grouped_remote_block_ids):
|
||||||
local_block_ids = grouped_local_block_ids[i]
|
local_block_ids = grouped_local_block_ids[i]
|
||||||
src = src_layer_base_addr + local_block_ids[0] * block_len
|
src = src_layer_base_addr + local_block_ids[0] * block_len
|
||||||
@@ -567,6 +573,7 @@ class MooncakeConnectorScheduler:
|
|||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||||
self.vllm_config = vllm_config
|
self.vllm_config = vllm_config
|
||||||
|
self.ascend_config = get_ascend_config()
|
||||||
self.block_size = vllm_config.cache_config.block_size
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
self.engine_id = engine_id
|
self.engine_id = engine_id
|
||||||
logger.info("Initializing Mooncake Scheduler %s", engine_id)
|
logger.info("Initializing Mooncake Scheduler %s", engine_id)
|
||||||
@@ -726,7 +733,7 @@ class MooncakeConnectorScheduler:
|
|||||||
assert "tp_size" in decode_parallel_config.keys()
|
assert "tp_size" in decode_parallel_config.keys()
|
||||||
self._decode_tp_size = decode_parallel_config["tp_size"]
|
self._decode_tp_size = decode_parallel_config["tp_size"]
|
||||||
|
|
||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
|
||||||
return self._decode_tp_size
|
return self._decode_tp_size
|
||||||
else:
|
else:
|
||||||
# TODO support mha and gqa
|
# TODO support mha and gqa
|
||||||
@@ -847,7 +854,9 @@ class MooncakeConnectorWorker:
|
|||||||
|
|
||||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||||
self.use_mla = first_kv_cache_tuple[0].size(
|
self.use_mla = first_kv_cache_tuple[0].size(
|
||||||
-1) != first_kv_cache_tuple[1].size(-1)
|
-1) != first_kv_cache_tuple[1].size(-1) and len(
|
||||||
|
first_kv_cache_tuple) == 2
|
||||||
|
self.use_sfa = len(first_kv_cache_tuple) == 3
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
# MLA case.[num_block, block_size, 1, hidden_dim]
|
||||||
self.num_blocks = first_kv_cache.shape[0]
|
self.num_blocks = first_kv_cache.shape[0]
|
||||||
@@ -861,6 +870,21 @@ class MooncakeConnectorWorker:
|
|||||||
logger.info(
|
logger.info(
|
||||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
||||||
self.num_blocks, block_shape_norm, block_shape_pe)
|
self.num_blocks, block_shape_norm, block_shape_pe)
|
||||||
|
elif self.use_sfa:
|
||||||
|
self.num_blocks = first_kv_cache.shape[0]
|
||||||
|
block_rank = 3 # [block_size, latent_dim]
|
||||||
|
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
||||||
|
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
||||||
|
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
|
||||||
|
self.block_len = [
|
||||||
|
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
||||||
|
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
|
||||||
|
first_kv_cache[2].element_size() * math.prod(block_shape_k)
|
||||||
|
]
|
||||||
|
logger.info(
|
||||||
|
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
|
||||||
|
self.num_blocks, block_shape_norm, block_shape_pe,
|
||||||
|
block_shape_k)
|
||||||
else:
|
else:
|
||||||
# [num_block, block_size, num_head, hidden_dim]
|
# [num_block, block_size, num_head, hidden_dim]
|
||||||
self.num_blocks = first_kv_cache.shape[0]
|
self.num_blocks = first_kv_cache.shape[0]
|
||||||
@@ -871,8 +895,9 @@ class MooncakeConnectorWorker:
|
|||||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
||||||
block_shape)
|
block_shape)
|
||||||
|
|
||||||
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
|
logger.info(
|
||||||
self.use_mla, first_kv_cache.shape)
|
"Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
|
||||||
|
self.use_mla, self.use_sfa, first_kv_cache.shape)
|
||||||
|
|
||||||
self.kv_caches = kv_caches
|
self.kv_caches = kv_caches
|
||||||
kv_caches_base_addr = []
|
kv_caches_base_addr = []
|
||||||
@@ -884,9 +909,16 @@ class MooncakeConnectorWorker:
|
|||||||
region_len = self.num_blocks * self.block_len[i % 2]
|
region_len = self.num_blocks * self.block_len[i % 2]
|
||||||
kv_caches_base_addr.append(base_addr)
|
kv_caches_base_addr.append(base_addr)
|
||||||
self._register(base_addr, region_len)
|
self._register(base_addr, region_len)
|
||||||
|
elif self.use_sfa:
|
||||||
|
for i, cache in enumerate(cache_or_caches, 0):
|
||||||
|
base_addr = cache.data_ptr()
|
||||||
|
region_len = self.num_blocks * self.block_len[i % 3]
|
||||||
|
kv_caches_base_addr.append(base_addr)
|
||||||
|
self._register(base_addr, region_len)
|
||||||
else:
|
else:
|
||||||
cache_list = [cache_or_caches
|
cache_list = [
|
||||||
] if self.use_mla else cache_or_caches
|
cache_or_caches
|
||||||
|
] if self.use_mla or self.use_sfa else cache_or_caches
|
||||||
for cache in cache_list:
|
for cache in cache_list:
|
||||||
base_addr = cache.data_ptr()
|
base_addr = cache.data_ptr()
|
||||||
region_len = self.num_blocks * self.block_len[0]
|
region_len = self.num_blocks * self.block_len[0]
|
||||||
|
|||||||
@@ -162,6 +162,13 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
|||||||
# Whether to enable msMonitor tool to monitor the performance of vllm-ascend.
|
# Whether to enable msMonitor tool to monitor the performance of vllm-ascend.
|
||||||
"MSMONITOR_USE_DAEMON":
|
"MSMONITOR_USE_DAEMON":
|
||||||
lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))),
|
lambda: bool(int(os.getenv("MSMONITOR_USE_DAEMON", '0'))),
|
||||||
|
# Timeout (in seconds) for delayed KVCache block release. In the prefill
|
||||||
|
# node, if a request is marked for delayed KV block release and the blocks
|
||||||
|
# are not freed within this timeout, they will be forcibly released.
|
||||||
|
"VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT":
|
||||||
|
lambda: int(os.getenv("VLLM_ASCEND_KVCACHE_DELAY_FREE_TIMEOUT", 250)),
|
||||||
|
"VLLM_ASCEND_ENABLE_MLAPO":
|
||||||
|
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_MLAPO", '0'))),
|
||||||
}
|
}
|
||||||
|
|
||||||
# end-env-vars-definition
|
# end-env-vars-definition
|
||||||
|
|||||||
@@ -37,6 +37,10 @@ def register_model():
|
|||||||
"DeepseekV3ForCausalLM",
|
"DeepseekV3ForCausalLM",
|
||||||
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
|
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
|
||||||
|
|
||||||
|
ModelRegistry.register_model(
|
||||||
|
"DeepseekV32ForCausalLM",
|
||||||
|
"vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM")
|
||||||
|
|
||||||
ModelRegistry.register_model(
|
ModelRegistry.register_model(
|
||||||
"DeepSeekMTPModel",
|
"DeepSeekMTPModel",
|
||||||
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
|
"vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
|
||||||
|
|||||||
@@ -60,6 +60,8 @@ from vllm.model_executor.models.utils import (PPMissingLayer,
|
|||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.models.layers.mla import AscendMLAModules
|
from vllm_ascend.models.layers.mla import AscendMLAModules
|
||||||
|
from vllm_ascend.models.layers.sfa import (AscendSFAModules,
|
||||||
|
AscendSparseFlashAttention, Indexer)
|
||||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||||
|
|
||||||
|
|
||||||
@@ -244,6 +246,180 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata)
|
return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention):
    """DeepSeek attention layer backed by Ascend sparse flash attention (SFA).

    The constructor creates the query / key-value projections, the rotary
    embedding and an ``Indexer``, bundles them into ``AscendSFAModules`` and
    passes them to ``AscendSparseFlashAttention``, which runs the forward
    pass.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build all sub-modules of the SFA attention layer.

        Args:
            config: HF model config; ``num_hidden_layers``,
                ``first_k_dense_replace``, ``rms_norm_eps`` and
                ``hidden_size`` are read from it.
            hidden_size: Input/output width of the layer.
            num_heads: Total head count; must be divisible by the tensor
                parallel world size.
            qk_nope_head_dim: Per-head query/key width without rotary part.
            qk_rope_head_dim: Per-head query/key width of the rotary part.
            v_head_dim: Per-head value width.
            q_lora_rank: Rank of the low-rank query projection, or ``None``
                to use a single full query projection.
            kv_lora_rank: Rank of the low-rank key/value projection.
            rope_theta: Base passed to ``get_rope``.
            rope_scaling: Optional rope scaling dict. When given it is
                updated in place with ``rope_type='deepseek_yarn'``.
            max_position_embeddings: Max position passed to ``get_rope``.
            cache_config: Forwarded to ``AscendSparseFlashAttention``.
            quant_config: Forwarded to every linear layer and to the
                attention module.
            prefix: Dotted module name. Its second-to-last component must be
                the integer layer index.
        """
        # Only nn.Module's initializer runs here; the parent class
        # initializer is intentionally not invoked.
        nn.Module.__init__(self)

        # --- head geometry -------------------------------------------------
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        # --- tensor-parallel head split ------------------------------------
        self.num_heads = num_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % self.tp_size == 0
        self.num_local_heads = num_heads // self.tp_size
        self.layers = config.num_hidden_layers
        self.first_k_dense_replace = config.first_k_dense_replace

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.prefix = prefix
        # The layer index is parsed from the second-to-last dotted component
        # of the prefix.
        self.debug_layer_idx = int(prefix.split(".")[-2])

        self.enable_shared_expert_dp = \
            get_ascend_config().enable_shared_expert_dp

        # Keyword arguments shared by every projection built below.
        linear_kwargs = dict(bias=False,
                             quant_config=quant_config,
                             return_bias=False)

        # --- query projection(s) -------------------------------------------
        if self.q_lora_rank is not None:
            # Low-rank path: down-projection -> RMSNorm -> up-projection.
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                prefix=f"{prefix}.q_a_proj",
                **linear_kwargs,
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                prefix=f"{prefix}.q_b_proj",
                **linear_kwargs,
            )
            q_a_proj = self.q_a_proj
            q_a_layernorm = self.q_a_layernorm
            q_proj = self.q_b_proj
        else:
            # Single full-rank query projection.
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                prefix=f"{prefix}.q_proj",
                **linear_kwargs,
            )
            q_a_proj = None
            q_a_layernorm = None
            q_proj = self.q_proj

        # --- key/value and output projections ------------------------------
        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
            **linear_kwargs,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            prefix=f"{prefix}.kv_b_proj",
            **linear_kwargs,
        )
        self.o_proj = CustomDeepseekV2RowParallelLinear(
            self.num_heads * self.v_head_dim,
            self.hidden_size,
            prefix=f"{prefix}.o_proj",
            **linear_kwargs,
        )

        # --- rotary embedding ----------------------------------------------
        if rope_scaling:
            rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            # Fold the squared yarn mscale into the softmax scaling.
            yarn_mscale = yarn_get_mscale(
                rope_scaling["factor"],
                float(rope_scaling.get("mscale_all_dim", False)))
            self.scaling = self.scaling * yarn_mscale * yarn_mscale

        # --- indexer -------------------------------------------------------
        self.dim: int = config.hidden_size  # 7168
        # TODO(zzzzwwjj): wait transformers add these params
        self.n_heads: int = 64  # 64
        self.head_dim: int = 128  # 128
        self.index_topk: int = 2048  # 2048
        self.indexer = Indexer(
            config,
            quant_config=quant_config,
            dim=self.dim,
            n_heads=self.n_heads,
            head_dim=self.head_dim,
            index_topk=self.index_topk,
            prefix=f"{prefix}.indexer",
        )

        # --- attention wrapper ---------------------------------------------
        bundled_modules = AscendSFAModules(
            q_a_proj=q_a_proj,
            q_a_layernorm=q_a_layernorm,
            q_proj=q_proj,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            o_proj=self.o_proj,
            rotary_emb=self.rotary_emb,
            indexer=self.indexer)

        self.sfa_attn = AscendSparseFlashAttention(
            hidden_size=self.hidden_size,
            enable_shared_expert_dp=self.enable_shared_expert_dp,
            debug_layer_idx=self.debug_layer_idx,
            first_k_dense_replace=self.first_k_dense_replace,
            tp_size=self.tp_size,
            sfa_modules=bundled_modules,
            num_local_heads=self.num_local_heads,
            scaling=self.scaling,
            layers=self.layers,
            kv_lora_rank=self.kv_lora_rank,
            qk_rope_head_dim=self.qk_rope_head_dim,
            q_lora_rank=self.q_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=prefix,
        )

    def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            kv_cache: Optional[torch.Tensor] = None,
            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
        """Delegate the forward pass to the wrapped SFA attention module."""
        return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata)
|
||||||
|
|
||||||
|
|
||||||
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
def __init__(self, vllm_config: VllmConfig, prefix: str) -> None:
|
||||||
@@ -253,6 +429,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
quant_config = vllm_config.quant_config
|
quant_config = vllm_config.quant_config
|
||||||
parallel_config = vllm_config.parallel_config
|
parallel_config = vllm_config.parallel_config
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
|
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
rope_theta = getattr(config, "rope_theta", 10000)
|
rope_theta = getattr(config, "rope_theta", 10000)
|
||||||
@@ -268,6 +445,9 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
self.tp_rank = get_tp_group().rank_in_group
|
self.tp_rank = get_tp_group().rank_in_group
|
||||||
# TODO: enable mla in vllm-ascend
|
# TODO: enable mla in vllm-ascend
|
||||||
if model_config.use_mla:
|
if model_config.use_mla:
|
||||||
|
if ascend_config.use_sfa:
|
||||||
|
attn_cls = CustomDeepseekV2SFAAttention
|
||||||
|
else:
|
||||||
attn_cls = CustomDeepseekV2MLAAttention
|
attn_cls = CustomDeepseekV2MLAAttention
|
||||||
else:
|
else:
|
||||||
attn_cls = DeepseekV2Attention
|
attn_cls = DeepseekV2Attention
|
||||||
|
|||||||
233
vllm_ascend/models/layers/sfa.py
Normal file
233
vllm_ascend/models/layers/sfa.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# Copyright 2023 The vLLM team.
|
||||||
|
# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||||
|
# and OPT implementations in this library. It has been modified from its
|
||||||
|
# original forms to accommodate minor architectural differences compared
|
||||||
|
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
from vllm.attention import Attention, AttentionMetadata
|
||||||
|
from vllm.config import CacheConfig, get_current_vllm_config
|
||||||
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
|
from vllm.model_executor.layers.linear import ReplicatedLinear
|
||||||
|
from vllm.model_executor.layers.mla import MultiHeadLatentAttention
|
||||||
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
|
from vllm.utils import direct_register_custom_op
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AscendSFAModules:
    """Bundle of sub-modules handed to ``AscendSparseFlashAttention``.

    The three query-side entries are optional; the remaining entries are
    always required.
    """
    # Query-side modules (each may be None).
    q_a_proj: Optional[nn.Module]
    q_a_layernorm: Optional[nn.Module]
    q_proj: Optional[nn.Module]
    # Key/value-side modules.
    kv_a_proj_with_mqa: nn.Module
    kv_a_layernorm: nn.Module
    kv_b_proj: nn.Module
    # Output projection, rotary embedding and indexer.
    o_proj: nn.Module
    rotary_emb: nn.Module
    indexer: nn.Module
|
||||||
|
|
||||||
|
|
||||||
|
class AscendSparseFlashAttention(MultiHeadLatentAttention):
    """Wrapper that owns the SFA ``Attention`` layer and calls the custom op.

    On construction the instance registers itself in the compilation
    config's ``static_forward_context`` under ``prefix`` so that the
    ``sfa_forward`` custom op can look it up by layer name.
    """

    def __init__(
        self,
        hidden_size: int,
        enable_shared_expert_dp: bool,
        debug_layer_idx: int,
        first_k_dense_replace: int,
        tp_size: int,
        sfa_modules: AscendSFAModules,
        num_local_heads: int,
        scaling: float,
        layers: int,
        kv_lora_rank: int,
        qk_rope_head_dim: int,
        q_lora_rank: Optional[int],
        qk_nope_head_dim: int,
        qk_head_dim: int,
        v_head_dim: int,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Store the layer geometry and create the inner ``Attention``.

        Raises:
            ValueError: If ``prefix`` is already present in the static
                forward context.
        """
        # Only nn.Module's initializer runs; the parent class initializer is
        # not invoked.
        nn.Module.__init__(self)

        self.hidden_size = hidden_size
        self.enable_shared_expert_dp = enable_shared_expert_dp
        self.debug_layer_idx = debug_layer_idx
        self.first_k_dense_replace = first_k_dense_replace
        self.tp_size = tp_size
        self.num_local_heads = num_local_heads
        self.layers = layers
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.q_lora_rank = q_lora_rank
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_head_dim = qk_head_dim
        self.v_head_dim = v_head_dim
        self.prefix = prefix

        self.sfa_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            use_sfa=True,
            # SFA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            rotary_emb=sfa_modules.rotary_emb,
            q_a_proj=sfa_modules.q_a_proj,
            q_a_layernorm=sfa_modules.q_a_layernorm,
            q_proj=sfa_modules.q_proj,
            kv_a_proj_with_mqa=sfa_modules.kv_a_proj_with_mqa,
            kv_a_layernorm=sfa_modules.kv_a_layernorm,
            kv_b_proj=sfa_modules.kv_b_proj,
            o_proj=sfa_modules.o_proj,
            indexer=sfa_modules.indexer)

        # Make this wrapper reachable by name from the custom op.
        static_context = \
            get_current_vllm_config().compilation_config.static_forward_context
        if prefix in static_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        static_context[prefix] = self

    def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            kv_cache: Optional[torch.Tensor] = None,
            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
        """Allocate the output buffer and run the ``sfa_forward`` custom op.

        ``positions``, ``kv_cache`` and ``attn_metadata`` are accepted for
        interface compatibility but are not read here; the custom op obtains
        the KV cache and metadata from the forward context.

        Returns:
            The output buffer viewed as a 2-D tensor whose last dimension is
            the last dimension of the computed output shape.
        """
        layer_idx = self.debug_layer_idx
        token_count = hidden_states.shape[0]

        # Layers strictly after ``first_k_dense_replace`` (and below the
        # layer count) need q/kv gathered when shared-expert DP is on.
        need_gather_q_kv = bool(self.enable_shared_expert_dp
                                and layer_idx > self.first_k_dense_replace
                                and layer_idx < self.layers)
        if need_gather_q_kv:
            # Simulate all gather to calculate output shape
            token_count = token_count * self.tp_size

        keeps_input_shape = (not self.enable_shared_expert_dp
                             or layer_idx < self.first_k_dense_replace)
        if keeps_input_shape:
            output_shape = hidden_states.shape
        else:
            # Ceiling division of the token count by the TP size.
            row_count = token_count // self.tp_size
            if token_count % self.tp_size:
                row_count += 1
            output_shape = (row_count, hidden_states.shape[1])

        # FIXME: This does not seem right, should make sure the buffer is fixed
        result = torch.empty(output_shape,
                             dtype=hidden_states.dtype,
                             device=hidden_states.device)
        torch.ops.vllm.sfa_forward(hidden_states, need_gather_q_kv, result,
                                   self.prefix)
        return result.view(-1, output_shape[-1])
|
||||||
|
|
||||||
|
|
||||||
|
def sfa_forward(
    hidden_states: torch.Tensor,
    need_gather_q_kv: bool,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Run the SFA attention implementation for the layer named ``layer_name``.

    The layer object is looked up in the current forward context's
    ``no_compile_layers``; its attention metadata and KV cache are taken from
    the same context. ``output`` is passed through to the implementation and
    nothing is returned.
    """
    ctx: ForwardContext = get_forward_context()
    layer = ctx.no_compile_layers[layer_name]
    attn = layer.sfa_attn

    # When metadata is present (truthy) it is indexed by the inner attention
    # layer's name; otherwise the falsy value is forwarded unchanged.
    metadata = ctx.attn_metadata
    if metadata:
        metadata = metadata[attn.layer_name]

    cache = attn.kv_cache[ctx.virtual_engine]
    attn.impl.forward(hidden_states, cache, metadata, need_gather_q_kv,
                      output)
|
||||||
|
|
||||||
|
|
||||||
|
class Indexer(nn.Module):
    """Parameter container for the SFA indexer.

    Holds three replicated linear projections (``wq_b``, ``wk``,
    ``weights_proj``), a ``LayerNorm`` over ``head_dim`` and the softmax
    scale. ``forward`` performs no computation.
    """

    def __init__(self,
                 config,
                 dim: int = 7168,
                 n_heads: int = 64,
                 head_dim: int = 128,
                 index_topk: int = 2048,
                 q_lora_rank: int = 1536,
                 rope_head_dim: int = 64,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: Optional[str] = ""):
        """Create the indexer projections.

        Args:
            config: Accepted for interface compatibility; not read here.
            dim: Input width of ``wk`` and ``weights_proj``.
            n_heads: Number of indexer heads.
            head_dim: Width of each indexer head.
            index_topk: Stored on the instance as ``index_topk``.
            q_lora_rank: Input width of ``wq_b``.
            rope_head_dim: Stored on the instance as ``rope_head_dim``.
            quant_config: Forwarded to every linear layer.
            prefix: Dotted module name used to derive sub-module prefixes.
        """
        super().__init__()

        self.dim: int = dim  # 7168
        self.n_heads: int = n_heads  # 64
        self.head_dim: int = head_dim  # 128
        self.rope_head_dim: int = rope_head_dim  # 64
        self.index_topk: int = index_topk  # 2048
        self.q_lora_rank: int = q_lora_rank  # 1536

        def _replicated(in_features: int, out_features: int, name: str):
            # All three projections share the same construction options.
            return ReplicatedLinear(
                in_features,
                out_features,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.{name}",
                return_bias=False,
            )

        self.wq_b = _replicated(self.q_lora_rank,
                                self.n_heads * self.head_dim, "wq_b")
        self.wk = _replicated(self.dim, self.head_dim, "wk")
        self.weights_proj = _replicated(self.dim, self.n_heads,
                                        "weights_proj")
        self.k_norm = nn.LayerNorm(self.head_dim)
        self.softmax_scale = self.head_dim**-0.5

    def forward(self):
        """No-op; this module only holds parameters."""
        return None
|
||||||
|
|
||||||
|
|
||||||
|
def sfa_forward_fake(
    hidden_states: torch.Tensor,
    need_gather_q_kv: bool,
    output: torch.Tensor,
    layer_name: str,
) -> None:
    """Fake implementation of ``sfa_forward``: does nothing, returns None."""
    return None
|
||||||
|
|
||||||
|
|
||||||
|
# Register ``sfa_forward`` as a custom op named "sfa_forward"; it is invoked
# in this file as ``torch.ops.vllm.sfa_forward``. ``output`` is declared as a
# mutated argument because the op returns None and hands that tensor to the
# attention implementation. ``sfa_forward_fake`` is the fake implementation.
# NOTE(review): "PrivateUse1" is presumably the dispatch key of the Ascend NPU
# backend - confirm.
direct_register_custom_op(
    op_name="sfa_forward",
    op_func=sfa_forward,
    mutates_args=["output"],
    fake_impl=sfa_forward_fake,
    dispatch_key="PrivateUse1",
)
|
||||||
@@ -15,6 +15,10 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
#
|
#
|
||||||
|
|
||||||
|
import vllm_ascend.patch.platform.patch_common.patch_config # noqa
|
||||||
import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
|
import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa
|
||||||
import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa
|
import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa
|
||||||
import vllm_ascend.patch.platform.patch_common.patch_multimodal_merge # noqa
|
import vllm_ascend.patch.platform.patch_common.patch_multimodal_merge # noqa
|
||||||
|
import vllm_ascend.patch.platform.patch_common.patch_transformers_utils # noqa
|
||||||
|
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
|
||||||
|
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa
|
||||||
|
|||||||
313
vllm_ascend/patch/platform/patch_common/patch_config.py
Normal file
313
vllm_ascend/patch/platform/patch_common/patch_config.py
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
import ast
|
||||||
|
|
||||||
|
import vllm.envs as envs
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
from vllm.config.speculative import SpeculativeConfig
|
||||||
|
from vllm.logger import logger
|
||||||
|
|
||||||
|
|
||||||
|
# mypy: ignore-errors
|
||||||
|
@property
def is_deepseek_mla(self: ModelConfig):
    """Tell whether the model's text config describes a DeepSeek-style MLA model.

    Returns ``True`` when the model type (or, for an EAGLE config, the model
    type of the wrapped ``model`` config) is one of the listed families and
    ``kv_lora_rank`` is set; ``False`` otherwise, including when the text
    config has no ``model_type`` attribute.
    """
    text_config = self.hf_text_config
    if not hasattr(text_config, "model_type"):
        return False

    model_type = text_config.model_type
    if model_type in ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp',
                      'kimi_k2', 'longcat_flash', 'deepseek_v32'):
        return text_config.kv_lora_rank is not None

    if model_type == 'eagle':
        # if the model is an EAGLE module, check for the
        # underlying architecture
        underlying_is_deepseek = text_config.model.model_type in (
            'deepseek_v2', 'deepseek_v3', 'deepseek_v32')
        return underlying_is_deepseek \
            and text_config.kv_lora_rank is not None

    return False
|
||||||
|
|
||||||
|
|
||||||
|
@staticmethod
def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig:
    """Rewrite ``hf_config`` in place so it describes the matching MTP model.

    For each supported family the ``model_type`` is switched to its ``*_mtp``
    variant, ``n_predict`` is filled from ``num_nextn_predict_layers`` and
    ``architectures`` is replaced by the MTP model class name. Configs that
    match none of the families are returned untouched.

    Returns:
        The same ``hf_config`` object that was passed in.
    """

    def _apply_mtp(architecture, default_n_predict=None, **leading):
        # ``leading`` entries are applied first, followed by ``n_predict``
        # and ``architectures``, all through a single ``update`` call.
        overrides = dict(leading)
        overrides["n_predict"] = getattr(hf_config,
                                         "num_nextn_predict_layers",
                                         default_n_predict)
        overrides["architectures"] = [architecture]
        hf_config.update(overrides)

    # DeepSeek V3 / V3.2
    if hf_config.model_type in ("deepseek_v3", "deepseek_v32"):
        hf_config.model_type = "deepseek_mtp"
    if hf_config.model_type == "deepseek_mtp":
        _apply_mtp("DeepSeekMTPModel")

    # MiMo (matched by architecture name)
    if hf_config.architectures[0] == "MiMoForCausalLM":
        hf_config.model_type = "mimo_mtp"
        _apply_mtp("MiMoMTPModel", num_hidden_layers=0)

    # GLM-4 MoE (matched by architecture name)
    if hf_config.architectures[0] == "Glm4MoeForCausalLM":
        hf_config.model_type = "glm4_moe_mtp"
        _apply_mtp("Glm4MoeMTPModel", num_hidden_layers=0)

    # Ernie 4.5 MoE
    if hf_config.model_type == "ernie4_5_moe":
        hf_config.model_type = "ernie_mtp"
    if hf_config.model_type == "ernie_mtp":
        _apply_mtp("ErnieMTPModel")

    # Qwen3-Next
    if hf_config.model_type == "qwen3_next":
        hf_config.model_type = "qwen3_next_mtp"
    if hf_config.model_type == "qwen3_next_mtp":
        _apply_mtp("Qwen3NextMTP")

    # LongCat Flash: n_predict falls back to 1 when the config has no value.
    if hf_config.model_type == "longcat_flash":
        hf_config.model_type = "longcat_flash_mtp"
        _apply_mtp("LongCatFlashMTPModel", default_n_predict=1)

    return hf_config
|
||||||
|
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
|
||||||
|
# Note: "method" is a new parameter that helps to extend the
|
||||||
|
# configuration of non-model-based proposers, and the "model" parameter
|
||||||
|
# will be used to set the draft model, eagle head, or additional weight
|
||||||
|
# when needed. If users do not specify "method", the speculative method
|
||||||
|
# will be detected automatically if possible. If the speculative method
|
||||||
|
# can not be detected, it will be considered as the "draft_model" by
|
||||||
|
# default.
|
||||||
|
|
||||||
|
if self.model is None and self.num_speculative_tokens is not None:
|
||||||
|
# TODO(Shangming): Refactor mtp configuration logic when supporting
|
||||||
|
if (self.target_model_config
|
||||||
|
and self.target_model_config.hf_text_config.model_type
|
||||||
|
in ("deepseek_v3", "deepseek_v32", "mimo", "ernie4_5_moe",
|
||||||
|
"qwen3_next")):
|
||||||
|
# use the draft model from the same model:
|
||||||
|
self.model = self.target_model_config.model
|
||||||
|
# Align the quantization of draft model for cases such as
|
||||||
|
# --quantization fp8 with a bf16 checkpoint.
|
||||||
|
if not self.quantization:
|
||||||
|
self.quantization = self.target_model_config.quantization
|
||||||
|
elif self.method in ("ngram", "[ngram]"):
|
||||||
|
self.model = "ngram"
|
||||||
|
else:
|
||||||
|
raise ValueError("num_speculative_tokens was provided but without "
|
||||||
|
"speculative model.")
|
||||||
|
|
||||||
|
# Automatically configure the method for ngram when "model" is used
|
||||||
|
# instead of "method"
|
||||||
|
if self.method is None and (self.model is not None
|
||||||
|
and self.model in ("ngram", "[ngram]")):
|
||||||
|
self.method = "ngram"
|
||||||
|
|
||||||
|
if self.method in ("ngram", "[ngram]"):
|
||||||
|
# Unified to "ngram" internally
|
||||||
|
self.method = "ngram"
|
||||||
|
# Set default values if not provided
|
||||||
|
if (self.prompt_lookup_min is None and self.prompt_lookup_max is None):
|
||||||
|
# TODO(woosuk): Tune these values. They are arbitrarily chosen.
|
||||||
|
self.prompt_lookup_min = 5
|
||||||
|
self.prompt_lookup_max = 5
|
||||||
|
elif self.prompt_lookup_min is None:
|
||||||
|
assert self.prompt_lookup_max is not None
|
||||||
|
self.prompt_lookup_min = self.prompt_lookup_max
|
||||||
|
elif self.prompt_lookup_max is None:
|
||||||
|
assert self.prompt_lookup_min is not None
|
||||||
|
self.prompt_lookup_max = self.prompt_lookup_min
|
||||||
|
|
||||||
|
# Validate values
|
||||||
|
if self.prompt_lookup_min < 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"prompt_lookup_min={self.prompt_lookup_min} must be > 0")
|
||||||
|
if self.prompt_lookup_max < 1:
|
||||||
|
raise ValueError(
|
||||||
|
f"prompt_lookup_max={self.prompt_lookup_max} must be > 0")
|
||||||
|
if self.prompt_lookup_min > self.prompt_lookup_max:
|
||||||
|
raise ValueError(
|
||||||
|
f"prompt_lookup_min={self.prompt_lookup_min} must "
|
||||||
|
f"be <= prompt_lookup_max={self.prompt_lookup_max}")
|
||||||
|
|
||||||
|
# TODO: current we still need extract vocab_size from target model
|
||||||
|
# config, in future, we may try refactor it out, and set
|
||||||
|
# draft related config as None here.
|
||||||
|
self.draft_model_config = self.target_model_config
|
||||||
|
self.draft_parallel_config = self.target_parallel_config
|
||||||
|
else:
|
||||||
|
self.prompt_lookup_max = 0
|
||||||
|
self.prompt_lookup_min = 0
|
||||||
|
|
||||||
|
if self.model is not None:
|
||||||
|
# TODO: Move this import to the top once `ModelConfig`
|
||||||
|
# lives in `vllm.config.model`.
|
||||||
|
from vllm.config import ModelConfig
|
||||||
|
self.draft_model_config = ModelConfig(
|
||||||
|
model=self.model,
|
||||||
|
runner="draft",
|
||||||
|
tokenizer=self.target_model_config.tokenizer,
|
||||||
|
tokenizer_mode=self.target_model_config.tokenizer_mode,
|
||||||
|
trust_remote_code=self.target_model_config.trust_remote_code,
|
||||||
|
allowed_local_media_path=self.target_model_config.
|
||||||
|
allowed_local_media_path,
|
||||||
|
allowed_media_domains=self.target_model_config.
|
||||||
|
allowed_media_domains,
|
||||||
|
dtype=self.target_model_config.dtype,
|
||||||
|
seed=self.target_model_config.seed,
|
||||||
|
revision=self.revision,
|
||||||
|
code_revision=self.code_revision,
|
||||||
|
tokenizer_revision=self.target_model_config.tokenizer_revision,
|
||||||
|
spec_target_max_model_len=self.target_model_config.
|
||||||
|
max_model_len,
|
||||||
|
quantization=self.quantization,
|
||||||
|
enforce_eager=self.target_model_config.enforce_eager,
|
||||||
|
max_logprobs=self.target_model_config.max_logprobs,
|
||||||
|
hf_overrides=SpeculativeConfig.hf_config_override,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Automatically detect the method
|
||||||
|
if self.method in ('eagle', 'eagle3'):
|
||||||
|
pass
|
||||||
|
# examples:
|
||||||
|
# yuhuili/EAGLE-LLaMA3-Instruct-8B
|
||||||
|
# yuhuili/EAGLE3-LLaMA3.1-Instruct-8B
|
||||||
|
# AngelSlim/Qwen3-8B_eagle3
|
||||||
|
elif "eagle-" in self.draft_model_config.model.lower():
|
||||||
|
self.method = "eagle"
|
||||||
|
elif "eagle3" in self.draft_model_config.model.lower():
|
||||||
|
self.method = "eagle3"
|
||||||
|
elif self.draft_model_config.hf_config.model_type == "medusa":
|
||||||
|
self.method = "medusa"
|
||||||
|
elif (self.draft_model_config.hf_config.model_type ==
|
||||||
|
"mlp_speculator"):
|
||||||
|
self.method = "mlp_speculator"
|
||||||
|
elif (self.draft_model_config.hf_config.model_type
|
||||||
|
in ("deepseek_mtp", "mimo_mtp", "glm4_moe_mtp")):
|
||||||
|
self.method = "deepseek_mtp"
|
||||||
|
if self.num_speculative_tokens > 1:
|
||||||
|
logger.warning(
|
||||||
|
"All Deepseek MTP models only have " \
|
||||||
|
"one layer. Might need some code changes " \
|
||||||
|
"to support multiple layers."
|
||||||
|
)
|
||||||
|
elif (self.draft_model_config.hf_config.model_type == "ernie_mtp"):
|
||||||
|
self.method = "ernie_mtp"
|
||||||
|
if self.num_speculative_tokens > 1:
|
||||||
|
logger.warning(
|
||||||
|
"All Ernie MTP models only have " \
|
||||||
|
"one layer. Might need some code changes " \
|
||||||
|
"to support multiple layers."
|
||||||
|
)
|
||||||
|
elif (self.draft_model_config.hf_config.model_type ==
|
||||||
|
"qwen3_next_mtp"):
|
||||||
|
self.method = "qwen3_next_mtp"
|
||||||
|
if self.num_speculative_tokens > 1:
|
||||||
|
logger.warning(
|
||||||
|
"All Qwen3Next MTP models only have " \
|
||||||
|
"one layer. Might need some code changes " \
|
||||||
|
"to support multiple layers."
|
||||||
|
)
|
||||||
|
elif (self.draft_model_config.hf_config.model_type
|
||||||
|
in ("longcat_flash_mtp")):
|
||||||
|
self.method = "longcat_flash_mtp"
|
||||||
|
if self.num_speculative_tokens > 1:
|
||||||
|
logger.warning(
|
||||||
|
"LongCat MTP models only have " \
|
||||||
|
"one layer. Might need some code changes " \
|
||||||
|
"to support multiple layers."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.method = "draft_model"
|
||||||
|
raise NotImplementedError(
|
||||||
|
"Speculative decoding with draft model is not "
|
||||||
|
"supported yet. Please consider using other "
|
||||||
|
"speculative decoding methods such as ngram, medusa, "
|
||||||
|
"eagle, or deepseek_mtp.")
|
||||||
|
|
||||||
|
# Replace hf_config for EAGLE draft_model
|
||||||
|
if self.method in ("eagle", "eagle3"):
|
||||||
|
if self.enable_chunked_prefill and not envs.VLLM_USE_V1:
|
||||||
|
raise ValueError(
|
||||||
|
"Chunked prefill and EAGLE are not compatible "
|
||||||
|
"when using V0.")
|
||||||
|
|
||||||
|
from vllm.transformers_utils.configs import SpeculatorsConfig
|
||||||
|
from vllm.transformers_utils.configs.eagle import EAGLEConfig
|
||||||
|
|
||||||
|
if isinstance(self.draft_model_config.hf_config,
|
||||||
|
(EAGLEConfig, SpeculatorsConfig)):
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
eagle_config = EAGLEConfig(
|
||||||
|
self.draft_model_config.hf_config,
|
||||||
|
method=self.method,
|
||||||
|
model_type="eagle")
|
||||||
|
self.draft_model_config.hf_config = eagle_config
|
||||||
|
|
||||||
|
if (self.num_speculative_tokens is not None
|
||||||
|
and hasattr(self.draft_model_config.hf_config,
|
||||||
|
"num_lookahead_tokens")):
|
||||||
|
self.draft_model_config.hf_config.num_lookahead_tokens = \
|
||||||
|
self.num_speculative_tokens
|
||||||
|
|
||||||
|
n_predict = getattr(self.draft_model_config.hf_config, "n_predict",
|
||||||
|
None)
|
||||||
|
if n_predict is not None:
|
||||||
|
if self.num_speculative_tokens is None:
|
||||||
|
# Default to max value defined in draft model config.
|
||||||
|
self.num_speculative_tokens = n_predict
|
||||||
|
elif self.num_speculative_tokens > n_predict and \
|
||||||
|
self.num_speculative_tokens % n_predict != 0:
|
||||||
|
# Ensure divisibility for MTP module reuse.
|
||||||
|
raise ValueError(
|
||||||
|
f"num_speculative_tokens:{self.num_speculative_tokens}"
|
||||||
|
f" must be divisible by {n_predict=}")
|
||||||
|
|
||||||
|
if self.speculative_token_tree is None:
|
||||||
|
# Generate chain of tokens.
|
||||||
|
self.speculative_token_tree = str([
|
||||||
|
(i + 1) * (0, ) for i in range(self.num_speculative_tokens)
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
# Sort the token tree breadth-first.
|
||||||
|
tree_choices = ast.literal_eval(self.speculative_token_tree)
|
||||||
|
self.speculative_token_tree = str(
|
||||||
|
sorted(tree_choices, key=lambda t: (len(t), t)))
|
||||||
|
|
||||||
|
self.draft_tensor_parallel_size = \
|
||||||
|
SpeculativeConfig._verify_and_get_draft_tp(
|
||||||
|
self.target_parallel_config,
|
||||||
|
self.draft_tensor_parallel_size,
|
||||||
|
self.draft_model_config.hf_config
|
||||||
|
)
|
||||||
|
|
||||||
|
self.draft_model_config.max_model_len = (
|
||||||
|
SpeculativeConfig._maybe_override_draft_max_model_len(
|
||||||
|
self.max_model_len,
|
||||||
|
self.draft_model_config.max_model_len,
|
||||||
|
self.target_model_config.max_model_len,
|
||||||
|
))
|
||||||
|
|
||||||
|
self.draft_parallel_config = (
|
||||||
|
SpeculativeConfig.create_draft_parallel_config(
|
||||||
|
self.target_parallel_config,
|
||||||
|
self.draft_tensor_parallel_size))
|
||||||
|
|
||||||
|
|
||||||
|
# Monkey-patch: rebind these attributes on ModelConfig and SpeculativeConfig
# to the replacement objects of the same names. Runs when this module executes.
ModelConfig.is_deepseek_mla = is_deepseek_mla
|
||||||
|
SpeculativeConfig.__post_init__ = __post_init__
|
||||||
|
SpeculativeConfig.hf_config_override = hf_config_override
|
||||||
@@ -6,6 +6,8 @@ from vllm.model_executor.models.config import MambaModelConfig
|
|||||||
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv
|
||||||
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
|
from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec
|
||||||
|
|
||||||
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def verify_and_update_config(cls, vllm_config) -> None:
|
def verify_and_update_config(cls, vllm_config) -> None:
|
||||||
@@ -22,6 +24,7 @@ def verify_and_update_config(cls, vllm_config) -> None:
|
|||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
# Enable FULL_AND_PIECEWISE by default
|
# Enable FULL_AND_PIECEWISE by default
|
||||||
MambaModelConfig.verify_and_update_config(vllm_config)
|
MambaModelConfig.verify_and_update_config(vllm_config)
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
|
|
||||||
cache_config = vllm_config.cache_config
|
cache_config = vllm_config.cache_config
|
||||||
model_config = vllm_config.model_config
|
model_config = vllm_config.model_config
|
||||||
@@ -38,7 +41,7 @@ def verify_and_update_config(cls, vllm_config) -> None:
|
|||||||
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
num_kv_heads=model_config.get_num_kv_heads(parallel_config),
|
||||||
head_size=model_config.get_head_size(),
|
head_size=model_config.get_head_size(),
|
||||||
dtype=kv_cache_dtype,
|
dtype=kv_cache_dtype,
|
||||||
use_mla=model_config.use_mla).page_size_bytes
|
use_mla=model_config.use_mla or ascend_config.use_sfa).page_size_bytes
|
||||||
|
|
||||||
model_cls, _ = ModelRegistry.resolve_model_cls(
|
model_cls, _ = ModelRegistry.resolve_model_cls(
|
||||||
model_config.architecture,
|
model_config.architecture,
|
||||||
|
|||||||
@@ -0,0 +1,200 @@
|
|||||||
|
import vllm.transformers_utils.configs
|
||||||
|
from transformers.configuration_utils import PretrainedConfig
|
||||||
|
from transformers.utils import logging
|
||||||
|
from vllm.transformers_utils import config
|
||||||
|
|
||||||
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepseekV3Config(PretrainedConfig):
    """Configuration container for DeepSeek-V3 style models.

    Every explicit constructor argument except the token ids and
    ``tie_word_embeddings`` is stored on the instance under its own name.
    The token ids, ``tie_word_embeddings`` and any extra keyword arguments
    are forwarded to ``PretrainedConfig.__init__``.

    Args:
        vocab_size (`int`, *optional*, defaults to 129280):
            Vocabulary size of the model, i.e. the number of different
            tokens that `inputs_ids` can represent.
        hidden_size (`int`, *optional*, defaults to 7168):
            Dimension of the hidden representations.
        intermediate_size (`int`, *optional*, defaults to 18432):
            Dimension of the MLP representations.
        moe_intermediate_size (`int`, *optional*, defaults to 2048):
            Dimension of the MoE representations.
        num_hidden_layers (`int`, *optional*, defaults to 61):
            Number of hidden layers in the Transformer decoder.
        num_nextn_predict_layers (`int`, *optional*, defaults to 1):
            Number of nextn predict layers in the model.
        num_attention_heads (`int`, *optional*, defaults to 128):
            Number of attention heads for each attention layer.
        num_key_value_heads (`int`, *optional*, defaults to 128):
            Number of key/value heads. Equal to `num_attention_heads` means
            multi-head attention, 1 means multi-query attention, anything
            else grouped-query attention. Passing `None` falls back to
            `num_attention_heads`.
        n_shared_experts (`int`, *optional*, defaults to 1):
            Number of shared experts, None means dense model.
        n_routed_experts (`int`, *optional*, defaults to 256):
            Number of routed experts, None means dense model.
        ep_size (`int`, *optional*, defaults to 1):
            Stored unchanged; presumably the expert-parallel size
            (TODO confirm against the model code).
        routed_scaling_factor (`float`, *optional*, defaults to 2.5):
            Scaling factor of routed experts.
        kv_lora_rank (`int`, *optional*, defaults to 512):
            Stored unchanged; presumably the low-rank dimension of the
            compressed key/value projection (TODO confirm).
        q_lora_rank (`int`, *optional*, defaults to 1536):
            Stored unchanged; presumably the low-rank dimension of the
            query projection (TODO confirm).
        qk_rope_head_dim (`int`, *optional*, defaults to 64):
            Stored unchanged; presumably the rotary part of the query/key
            head dimension (TODO confirm).
        v_head_dim (`int`, *optional*, defaults to 128):
            Stored unchanged; presumably the value head dimension
            (TODO confirm).
        qk_nope_head_dim (`int`, *optional*, defaults to 128):
            Stored unchanged; presumably the non-rotary part of the
            query/key head dimension (TODO confirm).
        topk_method (`str`, *optional*, defaults to `"noaux_tc"`):
            Topk method used in routed gate.
        n_group (`int`, *optional*, defaults to 8):
            Number of groups for routed experts.
        topk_group (`int`, *optional*, defaults to 4):
            Number of selected groups for each token (the selected experts
            of a token lie within `topk_group` groups).
        num_experts_per_tok (`int`, *optional*, defaults to 8):
            Number of selected experts, None means dense model.
        moe_layer_freq (`int`, *optional*, defaults to 1):
            The frequency of the MoE layer: one expert layer for every
            `moe_layer_freq - 1` dense layers.
        first_k_dense_replace (`int`, *optional*, defaults to 3):
            Number of dense layers at the shallow end of the stack
            (embed -> k dense layers -> moe layers -> lm_head).
        norm_topk_prob (`bool`, *optional*, defaults to `True`):
            Whether to normalize the weights of the routed experts.
        scoring_func (`str`, *optional*, defaults to `"sigmoid"`):
            Method of computing expert weights.
        hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
            The non-linear activation function in the decoder.
        max_position_embeddings (`int`, *optional*, defaults to 4096):
            The maximum sequence length the model might be used with.
        initializer_range (`float`, *optional*, defaults to 0.02):
            Standard deviation of the truncated_normal_initializer used for
            all weight matrices.
        rms_norm_eps (`float`, *optional*, defaults to 1e-06):
            The epsilon used by the rms normalization layers.
        use_cache (`bool`, *optional*, defaults to `True`):
            Whether the model should return the last key/values attentions.
        pad_token_id (`int`, *optional*):
            Padding token id.
        bos_token_id (`int`, *optional*, defaults to 0):
            Beginning of stream token id.
        eos_token_id (`int`, *optional*, defaults to 1):
            End of stream token id.
        tie_word_embeddings (`bool`, *optional*, defaults to `False`):
            Whether to tie weight embeddings.
        rope_theta (`float`, *optional*, defaults to 10000.0):
            The base period of the RoPE embeddings.
        rope_scaling (`Dict`, *optional*):
            Scaling configuration for the RoPE embeddings, in the form
            `{"type": strategy name, "factor": scaling factor}`.
        attention_bias (`bool`, *optional*, defaults to `False`):
            Whether to use a bias in the query, key, value and output
            projection layers during self-attention.
        attention_dropout (`float`, *optional*, defaults to 0.0):
            The dropout ratio for the attention probabilities.

    Example:

    ```python
    >>> configuration = DeepseekV3Config()
    >>> configuration.hidden_size
    7168
    ```
    """

    model_type = "deepseek_v3"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=129280,
        hidden_size=7168,
        intermediate_size=18432,
        moe_intermediate_size=2048,
        num_hidden_layers=61,
        num_nextn_predict_layers=1,
        num_attention_heads=128,
        num_key_value_heads=128,
        n_shared_experts=1,
        n_routed_experts=256,
        ep_size=1,
        routed_scaling_factor=2.5,
        kv_lora_rank=512,
        q_lora_rank=1536,
        qk_rope_head_dim=64,
        v_head_dim=128,
        qk_nope_head_dim=128,
        topk_method='noaux_tc',
        n_group=8,
        topk_group=4,
        num_experts_per_tok=8,
        moe_layer_freq=1,
        first_k_dense_replace=3,
        norm_topk_prob=True,
        scoring_func='sigmoid',
        hidden_act="silu",
        max_position_embeddings=4096,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=None,
        bos_token_id=0,
        eos_token_id=1,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        attention_bias=False,
        attention_dropout=0.0,
        **kwargs,
    ):
        # --- overall model geometry ---
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.moe_intermediate_size = moe_intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_nextn_predict_layers = num_nextn_predict_layers
        self.num_attention_heads = num_attention_heads

        # --- expert counts and scaling ---
        self.n_shared_experts = n_shared_experts
        self.n_routed_experts = n_routed_experts
        self.ep_size = ep_size
        self.routed_scaling_factor = routed_scaling_factor

        # --- attention projection / head dimensions ---
        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = q_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        self.v_head_dim = v_head_dim
        self.qk_nope_head_dim = qk_nope_head_dim

        # --- expert routing ---
        self.topk_method = topk_method
        self.n_group = n_group
        self.topk_group = topk_group
        self.num_experts_per_tok = num_experts_per_tok
        self.moe_layer_freq = moe_layer_freq
        self.first_k_dense_replace = first_k_dense_replace
        self.norm_topk_prob = norm_topk_prob
        self.scoring_func = scoring_func

        # Backward compatibility: an unspecified KV head count means one
        # key/value head per attention head.
        self.num_key_value_heads = (num_attention_heads
                                    if num_key_value_heads is None else
                                    num_key_value_heads)

        # --- activation, init, norm, cache, RoPE, attention extras ---
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_bias = attention_bias
        self.attention_dropout = attention_dropout

        # Token ids, embedding tying and unknown extras are handled by the
        # base class.
        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
|
||||||
|
|
||||||
|
|
||||||
|
# Expose DeepseekV3Config through vllm.transformers_utils.configs and map the
# "deepseek_v32" model type to that class name in the config registry.
vllm.transformers_utils.configs.__all__.append("DeepseekV3Config")
|
||||||
|
vllm.transformers_utils.configs.DeepseekV3Config = DeepseekV3Config
|
||||||
|
config._CONFIG_REGISTRY["deepseek_v32"] = "DeepseekV3Config"
|
||||||
@@ -20,6 +20,10 @@ from vllm.triton_utils import HAS_TRITON
|
|||||||
if HAS_TRITON:
|
if HAS_TRITON:
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_triton
|
import vllm_ascend.patch.worker.patch_common.patch_triton
|
||||||
|
|
||||||
|
# isort: off
|
||||||
|
import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa
|
||||||
|
import vllm_ascend.patch.worker.patch_common.patch_attentionspec # noqa
|
||||||
|
import vllm_ascend.patch.worker.patch_common.patch_attention_layer # noqa
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_logits # noqa
|
||||||
import vllm_ascend.patch.worker.patch_common.patch_weight_loader # noqa
|
import vllm_ascend.patch.worker.patch_common.patch_weight_loader # noqa
|
||||||
|
|||||||
202
vllm_ascend/patch/worker/patch_common/patch_attention_layer.py
Normal file
202
vllm_ascend/patch/worker/patch_common/patch_attention_layer.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import vllm
|
||||||
|
import vllm.envs as envs
|
||||||
|
from torch import nn
|
||||||
|
from vllm.attention import Attention, AttentionType, get_attn_backend
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
from vllm.attention.selector import backend_name_to_enum
|
||||||
|
from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
|
||||||
|
from vllm.config import CacheConfig, get_current_vllm_config
|
||||||
|
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
|
||||||
|
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||||
|
from vllm.model_executor.layers.quantization.base_config import \
|
||||||
|
QuantizationConfig
|
||||||
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from vllm.platforms import current_platform
|
||||||
|
|
||||||
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
|
|
||||||
|
class AscendAttention(Attention, nn.Module, AttentionLayerBase):
    """Attention layer whose constructor also accepts a ``use_sfa`` switch.

    The constructor gathers cache and quantization settings, selects an
    attention backend (forwarding ``use_mla``, ``use_sfa`` and whether
    ``sinks`` were supplied to ``get_attn_backend``), instantiates the
    backend implementation, and registers the layer under ``prefix`` in the
    compilation config's static forward context.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
        alibi_slopes: Optional[List[float]] = None,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        logits_soft_cap: Optional[float] = None,
        per_layer_sliding_window: Optional[int] = None,
        use_mla: bool = False,
        use_sfa: bool = False,
        prefix: str = "",
        attn_type: str = AttentionType.DECODER,
        kv_sharing_target_layer_name: Optional[str] = None,
        attn_backend: Optional[type[AttentionBackend]] = None,
        **extra_impl_args,
    ) -> None:
        """Build the layer; the KV cache lives on it as ``self.kv_cache``.

        Raises:
            ValueError: if a KV-cache quantization method is active together
                with an ``fp8_e5m2`` cache dtype, or if ``prefix`` is already
                registered in the static forward context.
        """
        nn.Module.__init__(self)
        AttentionLayerBase.__init__(self)

        # A per-layer window takes precedence over the model-level window
        # carried by the cache config.
        if per_layer_sliding_window is None:
            sliding_window = (cache_config.sliding_window
                              if cache_config is not None else None)
        else:
            sliding_window = per_layer_sliding_window

        # Cache settings fall back to fixed defaults without a cache config.
        if cache_config is None:
            kv_cache_dtype = "auto"
            block_size = 16
            is_attention_free = False
            calculate_kv_scales = False
        else:
            kv_cache_dtype = cache_config.cache_dtype
            block_size = cache_config.block_size
            is_attention_free = cache_config.is_attention_free
            calculate_kv_scales = cache_config.calculate_kv_scales

        num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
        assert num_heads % num_kv_heads == 0, (
            f"num_heads ({num_heads}) is not "
            f"divisible by num_kv_heads ({num_kv_heads})")

        self.kv_cache_dtype = kv_cache_dtype
        self.calculate_kv_scales = calculate_kv_scales
        # Every scale starts at 1.0. These defaults must exist before the
        # quantization method below gets a chance to replace them.
        for scale_attr in ("_k_scale", "_v_scale", "_q_scale", "_prob_scale"):
            setattr(self, scale_attr, torch.tensor(1.0, dtype=torch.float32))

        # Plain-float copies of the q/k/v scales, kept alongside the tensors.
        self._q_scale_float = self._k_scale_float = self._v_scale_float = 1.0

        # Output scale; unset until something assigns it.
        self._o_scale_float: Optional[float] = None

        self.use_mla = use_mla
        self.num_heads = num_heads
        self.head_size = head_size
        self.num_kv_heads = num_kv_heads
        self.sliding_window = sliding_window
        self.has_sink = extra_impl_args.get("sinks") is not None

        quant_method = None
        if quant_config:
            quant_method = quant_config.get_quant_method(self, prefix=prefix)
        if not (quant_method is None
                or isinstance(quant_method, UnquantizedLinearMethod)):
            assert isinstance(quant_method, BaseKVCacheMethod)
            # TODO (mgoin): kv cache dtype should be specified in the FP8
            # checkpoint config and become the "auto" behavior
            if self.kv_cache_dtype == "fp8_e5m2":
                raise ValueError("fp8_e5m2 kv-cache is not supported with "
                                 "fp8 checkpoints.")
            # Let the quantization method create its weights on this layer.
            self.quant_method = quant_method
            self.quant_method.create_weights(self)

        # The default dtype at construction time is recorded as the layer dtype.
        dtype = torch.get_default_dtype()
        if attn_backend is not None:
            self.attn_backend = attn_backend
        else:
            # vLLM 0.10.2 takes an extra positional ``is_attention_free``.
            selector_args = [head_size, dtype, kv_cache_dtype, block_size]
            if vllm_version_is("0.10.2"):
                selector_args.append(is_attention_free)
            self.attn_backend = get_attn_backend(*selector_args,
                                                 use_mla=use_mla,
                                                 use_sfa=use_sfa,
                                                 has_sink=self.has_sink)

        impl_cls = self.attn_backend.get_impl_cls()
        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
                             alibi_slopes, sliding_window, kv_cache_dtype,
                             logits_soft_cap, attn_type,
                             kv_sharing_target_layer_name, **extra_impl_args)
        self.backend = backend_name_to_enum(self.attn_backend.get_name())
        self.dtype = dtype

        # Direct calls are used whenever the platform does not wrap attention
        # in an opaque custom op.
        self.use_direct_call = not current_platform.opaque_attention_op()

        self.use_output = self.attn_backend.accept_output_buffer

        # Register the layer under its prefix; names must be unique.
        compilation_config = get_current_vllm_config().compilation_config
        forward_context = compilation_config.static_forward_context
        if prefix in forward_context:
            raise ValueError(f"Duplicate layer name: {prefix}")
        forward_context[prefix] = self
        self.layer_name = prefix
        self.attn_type = attn_type

        if kv_sharing_target_layer_name is not None:
            validate_kv_sharing_target(
                prefix,
                kv_sharing_target_layer_name,
                compilation_config.static_forward_context,
            )
        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

        # One empty placeholder tensor per pipeline-parallel rank; the
        # original comment says bind_kv_cache replaces these later.
        pp_size = get_current_vllm_config(
        ).parallel_config.pipeline_parallel_size
        self.kv_cache = [torch.tensor([]) for _ in range(pp_size)]

        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
        self.query_quant = None
|
||||||
|
|
||||||
|
|
||||||
|
# Rebind vllm.attention.Attention to AscendAttention so that later lookups of
# that attribute resolve to this class.
vllm.attention.Attention = AscendAttention
|
||||||
@@ -0,0 +1,181 @@
|
|||||||
|
#
|
||||||
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||||
|
# This file is a part of the vllm-ascend project.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
#
|
||||||
|
# mypy: ignore-errors
|
||||||
|
from functools import cache
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import vllm
|
||||||
|
import vllm.envs as envs
|
||||||
|
from vllm.attention.backends.abstract import AttentionBackend
|
||||||
|
from vllm.attention.selector import (backend_name_to_enum,
|
||||||
|
get_global_forced_attn_backend)
|
||||||
|
from vllm.platforms import _Backend, current_platform
|
||||||
|
from vllm.utils import resolve_obj_by_qualname
|
||||||
|
|
||||||
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
|
def _forced_or_env_backend() -> Optional[_Backend]:
    """Return the backend pinned globally or via the environment, if any.

    A globally forced backend overrides the VLLM_ATTENTION_BACKEND
    environment variable. Returns None when neither is set.

    Raises:
        ValueError: if the environment variable names an unknown backend.
    """
    forced = get_global_forced_attn_backend()
    if forced is not None:
        return forced

    env_choice: Optional[str] = envs.VLLM_ATTENTION_BACKEND
    if env_choice is None:
        return None

    from_env = backend_name_to_enum(env_choice)
    if from_env is None:
        raise ValueError(
            f"Invalid attention backend: '{env_choice}'. "
            f"Valid backends are: {list(_Backend.__members__.keys())}")
    return from_env


def _load_platform_backend(
    head_size: int,
    dtype: torch.dtype,
    kv_cache_dtype: Optional[str],
    block_size: int,
    use_v1: bool,
    use_mla: bool,
    use_sfa: bool,
    has_sink: bool,
) -> type[AttentionBackend]:
    """Ask the current platform for its backend class path and import it.

    Raises:
        ValueError: if the platform returns no backend for these settings.
    """
    backend_path = current_platform.get_attn_backend_cls(
        _forced_or_env_backend(), head_size, dtype, kv_cache_dtype,
        block_size, use_v1, use_mla, use_sfa, has_sink)
    if not backend_path:
        raise ValueError(
            f"Invalid attention backend for {current_platform.device_name}")
    return resolve_obj_by_qualname(backend_path)


if vllm_version_is("0.10.2"):

    def get_attn_backend(
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: Optional[str],
        block_size: int,
        is_attention_free: bool = False,
        use_mla: bool = False,
        use_sfa: bool = False,
        has_sink: bool = False,
    ) -> type[AttentionBackend]:
        """Selects which attention backend to use and lazily imports it."""
        # envs.VLLM_USE_V1 is read here, outside the cached function, and
        # passed explicitly so that it becomes part of the cache key.
        return _cached_get_attn_backend(
            head_size=head_size,
            dtype=dtype,
            kv_cache_dtype=kv_cache_dtype,
            block_size=block_size,
            is_attention_free=is_attention_free,
            use_v1=envs.VLLM_USE_V1,
            use_mla=use_mla,
            use_sfa=use_sfa,
            has_sink=has_sink,
        )

    @cache
    def _cached_get_attn_backend(
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: Optional[str],
        block_size: int,
        is_attention_free: bool,
        use_v1: bool = False,
        use_mla: bool = False,
        use_sfa: bool = False,
        has_sink: bool = False,
    ) -> type[AttentionBackend]:
        """Memoized backend lookup for vLLM 0.10.2."""
        # Models without attention layers (e.g. Mamba) get the placeholder.
        if is_attention_free:
            from vllm.attention.backends.placeholder_attn import \
                PlaceholderAttentionBackend
            return PlaceholderAttentionBackend

        return _load_platform_backend(head_size, dtype, kv_cache_dtype,
                                      block_size, use_v1, use_mla, use_sfa,
                                      has_sink)
else:

    def get_attn_backend(  # type: ignore[misc]
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: Optional[str],
        block_size: int,
        use_mla: bool = False,
        use_sfa: bool = False,
        has_sink: bool = False,
    ) -> type[AttentionBackend]:
        """Selects which attention backend to use and lazily imports it."""
        # envs.VLLM_USE_V1 is read here, outside the cached function, and
        # passed explicitly so that it becomes part of the cache key.
        return _cached_get_attn_backend(
            head_size=head_size,
            dtype=dtype,
            kv_cache_dtype=kv_cache_dtype,
            block_size=block_size,
            use_v1=envs.VLLM_USE_V1,
            use_mla=use_mla,
            use_sfa=use_sfa,
            has_sink=has_sink,
        )

    @cache
    def _cached_get_attn_backend(
        head_size: int,
        dtype: torch.dtype,
        kv_cache_dtype: Optional[str],
        block_size: int,
        use_v1: bool = False,
        use_mla: bool = False,
        use_sfa: bool = False,
        has_sink: bool = False,
    ) -> type[AttentionBackend]:
        """Memoized backend lookup for vLLM versions other than 0.10.2."""
        return _load_platform_backend(head_size, dtype, kv_cache_dtype,
                                      block_size, use_v1, use_mla, use_sfa,
                                      has_sink)
|
||||||
|
|
||||||
|
|
||||||
|
# Monkey-patch vLLM so that the selector functions defined in this module
# (which additionally accept and forward ``use_sfa``) replace the upstream
# attributes of the same name.
vllm.attention.get_attn_backend = get_attn_backend
vllm.attention.selector._cached_get_attn_backend = _cached_get_attn_backend
|
||||||
110
vllm_ascend/patch/worker/patch_common/patch_attentionspec.py
Normal file
110
vllm_ascend/patch/worker/patch_common/patch_attentionspec.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
from dataclasses import dataclass, fields
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import vllm
|
||||||
|
from typing_extensions import Self
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.utils import cdiv, get_dtype_size
|
||||||
|
from vllm.v1.core.single_type_kv_cache_manager import (FullAttentionManager,
|
||||||
|
spec_manager_map)
|
||||||
|
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class AttentionSpec(KVCacheSpec):
    """KV cache spec for attention layers, extended with a ``use_sfa`` flag.

    When ``use_sfa`` is set, each page reserves extra bytes on top of the
    regular KV storage (see ``page_size_bytes``).
    """
    num_kv_heads: int
    head_size: int
    dtype: torch.dtype
    use_mla: bool
    use_sfa: bool

    @property
    def page_size_bytes(self) -> int:
        """Return the number of bytes occupied by one KV cache page."""
        element_bytes = get_dtype_size(self.dtype)

        # For MLA we only store a single latent vector; otherwise the
        # factor is 2.
        kv_factor = 1 if self.use_mla else 2
        kv_bytes = (kv_factor * self.block_size * self.num_kv_heads *
                    self.head_size * element_bytes)

        if not self.use_sfa:
            return kv_bytes

        # SFA adds 128 elements per block slot.
        # NOTE(review): 128 presumably matches the indexer head dim - confirm.
        return kv_bytes + 128 * self.block_size * element_bytes
|
||||||
|
|
||||||
|
|
||||||
|
# Monkey-patch: make the module attribute ``AttentionSpec`` in
# vllm.v1.kv_cache_interface point at the SFA-aware class defined above.
# Names that were imported from that module before this line ran are not
# rebound by this assignment.
vllm.v1.kv_cache_interface.AttentionSpec = AttentionSpec
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class AscendFullAttentionSpec(FullAttentionSpec, AttentionSpec):
    """Full-attention KV cache spec that also carries the ``use_sfa`` field.

    NOTE(review): ``FullAttentionSpec`` precedes ``AttentionSpec`` in the
    bases, so attribute lookup visits ``FullAttentionSpec`` and its own
    ancestors first. Confirm that ``page_size_bytes`` resolves to the
    SFA-aware implementation and not to an upstream one.
    """
    # When hybrid allocator is disabled and the model contains both full
    # attention layers and sliding window attention layers, sliding
    # window attention are regarded as full attention in KV cache manager
    # (blocks are allocated for all tokens), while computed as sliding window
    # attention in model runner.
    # In this case, we use FullAttentionSpec and record the sliding window
    # size. Default to None for not using sliding window attention.
    sliding_window: Optional[int] = None
    # Chunk size recorded for chunked local attention layers; None when unused.
    attention_chunk_size: Optional[int] = None

    def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
        """Return the bytes needed to cache ``max_model_len`` tokens.

        Args:
            vllm_config: Supplies ``model_config.max_model_len`` and
                ``parallel_config.decode_context_parallel_size``.

        Returns:
            Number of pages (rounded up) times ``page_size_bytes``.
        """
        max_model_len = vllm_config.model_config.max_model_len
        dcp_world_size = \
            vllm_config.parallel_config.decode_context_parallel_size
        # Note(hc): each dcp rank only need save
        # (max_model_len//dcp_world_size) tokens locally.
        if dcp_world_size > 1:
            max_model_len = cdiv(max_model_len, dcp_world_size)
        return cdiv(max_model_len, self.block_size) * self.page_size_bytes

    @classmethod
    def merge_window_sizes(cls, window_sizes: set[int]) -> Optional[int]:
        """Collapse a set of window sizes into the single shared value.

        Args:
            window_sizes: Distinct window sizes collected from the specs.
                The set is not modified.

        Returns:
            ``None`` for an empty set, otherwise its only element.

        Raises:
            ValueError: If the set holds more than one distinct size.
        """
        if not window_sizes:
            return None
        if len(window_sizes) > 1:
            raise ValueError(
                "All attention layers in the same KV cache group must have the "
                "same window size.")
        # Read the sole element without pop() so the caller's set stays intact.
        return next(iter(window_sizes))

    @classmethod
    def merge(cls, specs: list[Self]) -> Self:
        """
        Merge a list of FullAttentionSpec objects into a single
        FullAttentionSpec object.

        The merged spec takes its shared fields from ``specs[0]``; every
        spec must agree with it on all ``AttentionSpec`` fields.
        """
        assert all(isinstance(spec, FullAttentionSpec) for spec in specs), (
            "All attention layers in the same KV cache group must be "
            "FullAttentionSpec.")

        sliding_window = {
            spec.sliding_window
            for spec in specs if spec.sliding_window is not None
        }
        attention_chunk_size = {
            spec.attention_chunk_size
            for spec in specs if spec.attention_chunk_size is not None
        }
        reference = specs[0]
        merged_spec = cls(
            block_size=reference.block_size,
            num_kv_heads=reference.num_kv_heads,
            head_size=reference.head_size,
            dtype=reference.dtype,
            use_mla=reference.use_mla,
            use_sfa=reference.use_sfa,
            sliding_window=cls.merge_window_sizes(sliding_window),
            attention_chunk_size=cls.merge_window_sizes(attention_chunk_size),
        )
        # Every spec in the group must match the merged one field by field.
        for spec in specs:
            for f in fields(AttentionSpec):
                assert getattr(spec, f.name) == getattr(merged_spec, f.name), (
                    "All attention layers in the same KV cache group must have "
                    "the same attention spec.")
        # At most one of the two local-attention modes may be active.
        assert (
            (merged_spec.sliding_window is not None) +
            (merged_spec.attention_chunk_size is not None) <= 1
        ), ("Model with both sliding window layers and chunked local attention "
            "layers is not supported.")
        return merged_spec
|
||||||
|
|
||||||
|
|
||||||
|
# Register the Ascend spec class in spec_manager_map with the stock
# FullAttentionManager as its manager.
spec_manager_map.update({AscendFullAttentionSpec: FullAttentionManager})

# Monkey-patch: make the module attribute ``FullAttentionSpec`` in
# vllm.v1.kv_cache_interface point at the Ascend variant defined above.
vllm.v1.kv_cache_interface.FullAttentionSpec = AscendFullAttentionSpec
|
||||||
@@ -300,6 +300,7 @@ class NPUPlatform(Platform):
|
|||||||
block_size,
|
block_size,
|
||||||
use_v1,
|
use_v1,
|
||||||
use_mla,
|
use_mla,
|
||||||
|
use_sfa,
|
||||||
has_sink=False):
|
has_sink=False):
|
||||||
if not use_v1:
|
if not use_v1:
|
||||||
raise ValueError("vLLM Ascend does not support V0 engine.")
|
raise ValueError("vLLM Ascend does not support V0 engine.")
|
||||||
@@ -307,21 +308,28 @@ class NPUPlatform(Platform):
|
|||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
|
|
||||||
if use_mla and ascend_config.enable_shared_expert_dp:
|
if use_mla and ascend_config.enable_shared_expert_dp:
|
||||||
|
if use_mla and not use_sfa:
|
||||||
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
|
return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
|
||||||
|
if use_mla and use_sfa:
|
||||||
|
return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
|
||||||
|
|
||||||
use_torchair = ascend_config.torchair_graph_config.enabled
|
use_torchair = ascend_config.torchair_graph_config.enabled
|
||||||
# choose attention backend based on use_mla and use_torchair
|
# choose attention backend based on use_mla and use_torchair
|
||||||
backend_map = {
|
backend_map = {
|
||||||
(True, True):
|
(True, False, True):
|
||||||
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
|
"vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
|
||||||
(True, False):
|
(True, False, False):
|
||||||
"vllm_ascend.attention.mla_v1.AscendMLABackend",
|
"vllm_ascend.attention.mla_v1.AscendMLABackend",
|
||||||
(False, True):
|
(False, False, True):
|
||||||
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
|
"vllm_ascend.torchair.torchair_attention.AscendAttentionTorchairBackend",
|
||||||
(False, False):
|
(False, False, False):
|
||||||
"vllm_ascend.attention.attention_v1.AscendAttentionBackend"
|
"vllm_ascend.attention.attention_v1.AscendAttentionBackend",
|
||||||
|
(True, True, False):
|
||||||
|
"vllm_ascend.attention.sfa_v1.AscendSFABackend",
|
||||||
|
(True, True, True):
|
||||||
|
"vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
|
||||||
}
|
}
|
||||||
return backend_map[(use_mla, use_torchair)]
|
return backend_map[(use_mla, use_sfa, use_torchair)]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_punica_wrapper(cls) -> str:
|
def get_punica_wrapper(cls) -> str:
|
||||||
|
|||||||
@@ -603,8 +603,9 @@ class MtpProposer(Proposer):
|
|||||||
torch.npu.set_compile_mode(jit_compile=False)
|
torch.npu.set_compile_mode(jit_compile=False)
|
||||||
if not self.runner.use_cached_npu_graph:
|
if not self.runner.use_cached_npu_graph:
|
||||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||||
self.torchair_compiled_model = torch.compile(self.model,
|
self.torchair_compiled_model = torch.compile(
|
||||||
dynamic=True,
|
self.model,
|
||||||
|
dynamic=not get_ascend_config().use_sfa,
|
||||||
fullgraph=True,
|
fullgraph=True,
|
||||||
backend=npu_backend)
|
backend=npu_backend)
|
||||||
return self.torchair_compiled_model
|
return self.torchair_compiled_model
|
||||||
@@ -627,7 +628,7 @@ class MtpProposer(Proposer):
|
|||||||
self.torchair_compiled_models[
|
self.torchair_compiled_models[
|
||||||
batch_size] = torchair.inference.cache_compile(
|
batch_size] = torchair.inference.cache_compile(
|
||||||
self.model.__dict__[forward_proxy_name],
|
self.model.__dict__[forward_proxy_name],
|
||||||
dynamic=True,
|
dynamic=not get_ascend_config().use_sfa,
|
||||||
fullgraph=True,
|
fullgraph=True,
|
||||||
cache_dir=TORCHAIR_CACHE_DIR,
|
cache_dir=TORCHAIR_CACHE_DIR,
|
||||||
config=config,
|
config=config,
|
||||||
|
|||||||
@@ -67,7 +67,9 @@ from vllm.model_executor.models.utils import (
|
|||||||
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
|
||||||
from vllm.sequence import IntermediateTensors
|
from vllm.sequence import IntermediateTensors
|
||||||
|
|
||||||
|
from vllm_ascend import envs
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
|
from vllm_ascend.models.layers.sfa import Indexer
|
||||||
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||||
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
|
from vllm_ascend.torchair.ops.torchair_fused_moe import TorchairAscendFusedMoE
|
||||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
|
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
|
||||||
@@ -435,6 +437,7 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
cache_config: Optional[CacheConfig] = None,
|
cache_config: Optional[CacheConfig] = None,
|
||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
decoder_layer=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
@@ -630,6 +633,225 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
|||||||
output_shape=output_shape)
|
output_shape=output_shape)
|
||||||
|
|
||||||
|
|
||||||
|
class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention):
    """DeepSeek attention module wired to the SFA attention backend.

    The constructor builds the query / key-value / output projections, the
    rotary embedding and an ``Indexer``, then hands all of them to an
    ``Attention`` layer created with ``use_mla=True`` and ``use_sfa=True``.
    ``forward`` allocates the output buffer and calls that layer's ``impl``
    directly.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        decoder_layer=None,
    ) -> None:
        """Create the projections, rotary embedding, indexer and SFA layer.

        Args:
            config: Model config; read for ``num_hidden_layers``,
                ``first_k_dense_replace``, ``rms_norm_eps``,
                ``n_routed_experts``, ``moe_layer_freq`` and ``hidden_size``.
            hidden_size: Input/output width of the projections.
            num_heads: Total head count; must be divisible by the tensor
                parallel world size.
            qk_nope_head_dim: Added to ``qk_rope_head_dim`` to form
                ``qk_head_dim``.
            qk_rope_head_dim: Rotary part of the head dimension.
            v_head_dim: Value head dimension.
            q_lora_rank: When not ``None`` the query goes through
                ``q_a_proj`` -> ``q_a_layernorm`` -> ``q_b_proj``;
                otherwise a single ``q_proj`` is created.
            kv_lora_rank: Rank of the compressed key-value projection.
            rope_theta: Passed to ``get_rope`` as ``base``.
            rope_scaling: Optional rope scaling dict. When truthy it is
                modified in place (``"rope_type"`` is overwritten).
            max_position_embeddings: Passed to ``get_rope``.
            cache_config: Forwarded to ``Attention``.
            quant_config: Forwarded to every linear layer, the indexer and
                ``Attention``.
            prefix: Dotted module name; its second-to-last component must
                parse as the integer layer index.
            decoder_layer: Forwarded to ``Attention`` unchanged.
        """
        # Only nn.Module's initializer is run here; the direct parent's
        # __init__ is deliberately not called via super().
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % self.tp_size == 0
        self.num_local_heads = num_heads // self.tp_size
        self.layers = config.num_hidden_layers
        self.first_k_dense_replace = config.first_k_dense_replace

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.prefix = prefix
        # Layer index taken from the second-to-last dotted prefix component.
        self.debug_layer_idx = int(self.prefix.split(".")[-2])

        ascend_config = get_ascend_config()
        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

        if self.q_lora_rank is not None:
            # Low-rank query path: down-projection, norm, up-projection.
            self.q_a_proj = ReplicatedLinear(
                self.hidden_size,
                self.q_lora_rank,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_a_proj",
                return_bias=False,
            )
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(
                q_lora_rank,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_b_proj",
                return_bias=False,
            )
        else:
            # Direct query projection when no LoRA rank is configured.
            self.q_proj = ColumnParallelLinear(
                self.hidden_size,
                self.num_heads * self.qk_head_dim,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.q_proj",
                return_bias=False,
            )

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa",
            return_bias=False,
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj",
            return_bias=False,
        )
        # Pick the output projection variant: the "ReplaceAllreduce" class is
        # used for layers at or past first_k_dense_replace that fall on the
        # moe_layer_freq stride, when routed experts exist and either
        # multistream shared-expert overlap or shared-expert DP is enabled.
        if (config.n_routed_experts is not None
                and self.debug_layer_idx >= config.first_k_dense_replace
                and self.debug_layer_idx % config.moe_layer_freq == 0
                and (ascend_config.multistream_overlap_shared_expert
                     or self.enable_shared_expert_dp)):
            self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce(
                self.num_heads * self.v_head_dim,
                self.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.o_proj",
                return_bias=False,
            )
        else:
            self.o_proj = TorchairDeepseekV2RowParallelLinear(
                self.num_heads * self.v_head_dim,
                self.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.o_proj",
                return_bias=False,
            )

        if rope_scaling:
            # Mutates the caller-supplied dict in place.
            rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            # Fold the yarn mscale (squared) into the softmax scaling.
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        self.dim: int = config.hidden_size  # 7168
        # TODO(zzzzwwjj): wait transformers add these params
        # Until then the indexer hyper-parameters below are hard-coded.
        self.n_heads: int = 64  # 64
        self.head_dim: int = 128  # 128
        self.index_topk: int = 2048  # 2048
        self.indexer = Indexer(
            config,
            quant_config=quant_config,
            dim=self.dim,
            n_heads=self.n_heads,
            head_dim=self.head_dim,
            index_topk=self.index_topk,
            prefix=f"{prefix}.indexer",
        )

        # The Attention layer receives every sub-module built above through
        # keyword arguments.
        self.sfa_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=self.scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            use_sfa=True,
            # SFA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            rotary_emb=self.rotary_emb,
            q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
            q_a_layernorm=self.q_a_layernorm
            if self.q_lora_rank is not None else None,
            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            o_proj=self.o_proj,
            indexer=self.indexer,
            decoder_layer=decoder_layer,
        )

    def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            kv_cache: Optional[torch.Tensor] = None,
            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
        """Run SFA attention on ``hidden_states`` and return the result.

        Args:
            positions: Not referenced anywhere in this method body.
            hidden_states: Input activations; the first dimension is used as
                the token count and ``shape[1]`` as the output width.
            kv_cache: Optional cache; when ``None`` it is taken from
                ``self.sfa_attn.kv_cache`` for the current virtual engine.
            attn_metadata: Optional metadata; replaced by the forward
                context's metadata when torchair graph mode is disabled.

        Returns:
            The output buffer viewed as ``(-1, output_shape[-1])``.
        """
        forward_context = get_forward_context()
        if not self.torchair_graph_enabled:
            # Outside torchair graph mode the metadata comes from the forward
            # context; a dict contributes its first value (or None if empty).
            if forward_context.attn_metadata is not None and isinstance(
                    forward_context.attn_metadata, dict):
                attn_metadata = next(
                    iter(forward_context.attn_metadata.values()), None)
            else:
                attn_metadata = forward_context.attn_metadata
        if kv_cache is None:
            kv_cache = self.sfa_attn.kv_cache[
                forward_context.virtual_engine]

        num_tokens = hidden_states.shape[0]
        # Stays False: the block that used to enable it is commented out.
        need_gather_q_kv = False
        # if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
        #     # Simulate all gather to calculate output shape
        #     num_tokens = num_tokens * self.tp_size
        #     need_gather_q_kv = True
        if not self.enable_shared_expert_dp or self.debug_layer_idx != self.first_k_dense_replace:
            output_shape = hidden_states.shape
        if self.enable_shared_expert_dp and (
                self.debug_layer_idx == self.first_k_dense_replace
                or self.debug_layer_idx == self.layers):
            # rows = ceil(num_tokens / tp_size)
            rows = num_tokens // self.tp_size
            if num_tokens % self.tp_size:
                rows += 1
            output_shape = (rows, hidden_states.shape[1])
        output = torch.empty(output_shape,
                             dtype=hidden_states.dtype,
                             device=hidden_states.device)
        # The call's return value is discarded; presumably impl.forward
        # fills ``output`` in place - confirm against the SFA impl.
        self.sfa_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
                                   need_gather_q_kv, output)
        output = output.view(-1, output_shape[-1])
        return output
|
||||||
|
|
||||||
|
|
||||||
class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -654,9 +876,16 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
self.tp_size = get_tensor_model_parallel_world_size()
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
self.tp_rank = get_tp_group().rank_in_group
|
self.tp_rank = get_tp_group().rank_in_group
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
|
self.use_mla = False
|
||||||
|
self.use_sfa = False
|
||||||
# TODO: enable mla in vllm-ascend
|
# TODO: enable mla in vllm-ascend
|
||||||
if model_config.use_mla:
|
if model_config.use_mla:
|
||||||
attn_cls = TorchairDeepseekV2MLAAttention
|
if ascend_config.use_sfa:
|
||||||
|
attn_cls = TorchairDeepseekV2SFAAttention
|
||||||
|
self.use_sfa = True
|
||||||
|
else:
|
||||||
|
attn_cls = TorchairDeepseekV2MLAAttention # type: ignore[assignment]
|
||||||
|
self.use_mla = True
|
||||||
else:
|
else:
|
||||||
attn_cls = DeepseekV2Attention
|
attn_cls = DeepseekV2Attention
|
||||||
self.self_attn = attn_cls(
|
self.self_attn = attn_cls(
|
||||||
@@ -675,6 +904,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
cache_config=cache_config,
|
cache_config=cache_config,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
|
decoder_layer=self,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (config.n_routed_experts is not None
|
if (config.n_routed_experts is not None
|
||||||
@@ -715,10 +945,23 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
|
|||||||
replace_allreduce: bool = False,
|
replace_allreduce: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Self Attention
|
# Self Attention
|
||||||
if attn_metadata is not None and attn_metadata.num_decodes > 0:
|
if attn_metadata is not None:
|
||||||
mla_moe_communication = self.mla_moe_communication and replace_allreduce
|
decoding_condition_met = (
|
||||||
|
not attn_metadata.is_prefill if self.use_sfa else
|
||||||
|
attn_metadata.num_decodes > 0 if self.use_mla else False)
|
||||||
|
mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce
|
||||||
else:
|
else:
|
||||||
mla_moe_communication = False
|
mla_moe_communication = False
|
||||||
|
|
||||||
|
forward_context = get_forward_context()
|
||||||
|
if (envs.VLLM_ASCEND_ENABLE_MLAPO
|
||||||
|
and isinstance(self.self_attn, TorchairDeepseekV2SFAAttention)
|
||||||
|
and attn_metadata is not None
|
||||||
|
and not forward_context.with_prefill):
|
||||||
|
if residual is not None:
|
||||||
|
hidden_states = hidden_states + residual
|
||||||
|
residual = hidden_states
|
||||||
|
else:
|
||||||
if residual is None:
|
if residual is None:
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
hidden_states = self.input_layernorm(hidden_states)
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|||||||
@@ -48,8 +48,8 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
|||||||
class NPUTorchairModelRunner(NPUModelRunner):
|
class NPUTorchairModelRunner(NPUModelRunner):
|
||||||
|
|
||||||
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
def __init__(self, vllm_config: VllmConfig, device: torch.device):
|
||||||
ascend_config = get_ascend_config()
|
self.ascend_config = get_ascend_config()
|
||||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
self.enable_shared_expert_dp = self.ascend_config.enable_shared_expert_dp
|
||||||
super().__init__(vllm_config, device)
|
super().__init__(vllm_config, device)
|
||||||
if self.speculative_config:
|
if self.speculative_config:
|
||||||
self.actual_seq_lengths_q = list(
|
self.actual_seq_lengths_q = list(
|
||||||
@@ -66,10 +66,10 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
|||||||
self.new_kv_cache_bytes = -1
|
self.new_kv_cache_bytes = -1
|
||||||
self.torchair_compiled_model = None # type: ignore
|
self.torchair_compiled_model = None # type: ignore
|
||||||
self.torchair_compiled_models = {} # type: ignore
|
self.torchair_compiled_models = {} # type: ignore
|
||||||
self.use_cached_npu_graph = ascend_config.torchair_graph_config.use_cached_graph
|
self.use_cached_npu_graph = self.ascend_config.torchair_graph_config.use_cached_graph
|
||||||
self.use_cached_kv_cache_bytes = ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
|
self.use_cached_kv_cache_bytes = self.ascend_config.torchair_graph_config.use_cached_kv_cache_bytes
|
||||||
self.torchair_graph_batch_sizes = ascend_config.torchair_graph_config.graph_batch_sizes
|
self.torchair_graph_batch_sizes = self.ascend_config.torchair_graph_config.graph_batch_sizes
|
||||||
if ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
if self.ascend_config.torchair_graph_config.graph_batch_sizes_init:
|
||||||
self.init_torchair_graph_batch_sizes()
|
self.init_torchair_graph_batch_sizes()
|
||||||
|
|
||||||
self.update_torchair_graph_batch_sizes()
|
self.update_torchair_graph_batch_sizes()
|
||||||
@@ -362,20 +362,21 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
|||||||
communication_adaptation_310p()
|
communication_adaptation_310p()
|
||||||
|
|
||||||
config = torchair.CompilerConfig()
|
config = torchair.CompilerConfig()
|
||||||
if get_ascend_config().torchair_graph_config.mode:
|
if self.ascend_config.torchair_graph_config.mode:
|
||||||
config.mode = get_ascend_config().torchair_graph_config.mode
|
config.mode = self.ascend_config.torchair_graph_config.mode
|
||||||
config.experimental_config.frozen_parameter = \
|
config.experimental_config.frozen_parameter = \
|
||||||
get_ascend_config().torchair_graph_config.enable_frozen_parameter
|
self.ascend_config.torchair_graph_config.enable_frozen_parameter
|
||||||
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
# enabling tiling_schedule_optimize on 300I Duo has some bugs, so we have to
|
||||||
# disable it on 300I Duo platform now.
|
# disable it on 300I Duo platform now.
|
||||||
config.experimental_config.tiling_schedule_optimize = not is_310p()
|
config.experimental_config.tiling_schedule_optimize = not is_310p()
|
||||||
config.experimental_config.enable_view_optimize = \
|
config.experimental_config.enable_view_optimize = \
|
||||||
get_ascend_config().torchair_graph_config.enable_view_optimize
|
self.ascend_config.torchair_graph_config.enable_view_optimize
|
||||||
torch.npu.set_compile_mode(jit_compile=False)
|
torch.npu.set_compile_mode(jit_compile=False)
|
||||||
if not self.use_cached_npu_graph:
|
if not self.use_cached_npu_graph:
|
||||||
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
npu_backend = torchair.get_npu_backend(compiler_config=config)
|
||||||
self.torchair_compiled_model = torch.compile(self.model,
|
self.torchair_compiled_model = torch.compile(
|
||||||
dynamic=True,
|
self.model,
|
||||||
|
dynamic=not self.ascend_config.use_sfa,
|
||||||
fullgraph=True,
|
fullgraph=True,
|
||||||
backend=npu_backend)
|
backend=npu_backend)
|
||||||
return self.torchair_compiled_model
|
return self.torchair_compiled_model
|
||||||
@@ -398,7 +399,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
|
|||||||
self.torchair_compiled_models[
|
self.torchair_compiled_models[
|
||||||
batch_size] = torchair.inference.cache_compile(
|
batch_size] = torchair.inference.cache_compile(
|
||||||
self.model.__dict__[forward_proxy_name],
|
self.model.__dict__[forward_proxy_name],
|
||||||
dynamic=True,
|
dynamic=not self.ascend_config.use_sfa,
|
||||||
fullgraph=True,
|
fullgraph=True,
|
||||||
cache_dir=TORCHAIR_CACHE_DIR,
|
cache_dir=TORCHAIR_CACHE_DIR,
|
||||||
config=config,
|
config=config,
|
||||||
|
|||||||
1330
vllm_ascend/torchair/torchair_sfa.py
Normal file
1330
vllm_ascend/torchair/torchair_sfa.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -165,6 +165,11 @@ def register_torchair_model():
|
|||||||
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
|
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ModelRegistry.register_model(
|
||||||
|
"DeepseekV32ForCausalLM",
|
||||||
|
"vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
|
||||||
|
)
|
||||||
|
|
||||||
ModelRegistry.register_model(
|
ModelRegistry.register_model(
|
||||||
"Qwen2ForCausalLM",
|
"Qwen2ForCausalLM",
|
||||||
"vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM")
|
"vllm_ascend.torchair.models.qwen2:CustomQwen2ForCausalLM")
|
||||||
|
|||||||
@@ -285,8 +285,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
self.intermediate_tensors: Optional[IntermediateTensors] = None
|
||||||
self.runner_only_attn_layers: set[str] = set()
|
self.runner_only_attn_layers: set[str] = set()
|
||||||
|
|
||||||
ascend_config = get_ascend_config()
|
self.ascend_config = get_ascend_config()
|
||||||
if ascend_config.ascend_scheduler_config.enabled:
|
if self.ascend_config.ascend_scheduler_config.enabled:
|
||||||
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
|
self.chunked_prefill_enabled = self.scheduler_config.chunked_prefill_enabled
|
||||||
else:
|
else:
|
||||||
self.chunked_prefill_enabled = True
|
self.chunked_prefill_enabled = True
|
||||||
@@ -298,6 +298,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.cache_config.cache_dtype]
|
self.cache_config.cache_dtype]
|
||||||
# use_hybrid_blocks: if hybrid blocks is used.
|
# use_hybrid_blocks: if hybrid blocks is used.
|
||||||
self.use_hybrid_blocks: bool = False
|
self.use_hybrid_blocks: bool = False
|
||||||
|
self.need_accepted_tokens: bool = False
|
||||||
|
|
||||||
self.is_multimodal_model = self.model_config.is_multimodal_model
|
self.is_multimodal_model = self.model_config.is_multimodal_model
|
||||||
self.is_pooling_model = self.model_config.pooler_config is not None
|
self.is_pooling_model = self.model_config.pooler_config is not None
|
||||||
@@ -315,7 +316,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.block_size,
|
self.block_size,
|
||||||
self.model_config.is_attention_free,
|
self.model_config.is_attention_free,
|
||||||
use_mla=self.model_config.use_mla,
|
use_mla=self.model_config.use_mla,
|
||||||
)
|
use_sfa=self.ascend_config.use_sfa)
|
||||||
else:
|
else:
|
||||||
self.attn_backend = get_attn_backend(
|
self.attn_backend = get_attn_backend(
|
||||||
0,
|
0,
|
||||||
@@ -323,7 +324,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
None,
|
None,
|
||||||
self.block_size,
|
self.block_size,
|
||||||
use_mla=self.model_config.use_mla,
|
use_mla=self.model_config.use_mla,
|
||||||
)
|
use_sfa=self.ascend_config.use_sfa)
|
||||||
if torch.version.cann.startswith("8.3"):
|
if torch.version.cann.startswith("8.3"):
|
||||||
self.attn_mask_builder = AttentionMaskBuilder(
|
self.attn_mask_builder = AttentionMaskBuilder(
|
||||||
self.scheduler_config.max_num_batched_tokens, self.dtype,
|
self.scheduler_config.max_num_batched_tokens, self.dtype,
|
||||||
@@ -457,7 +458,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
device=self.device,
|
device=self.device,
|
||||||
)
|
)
|
||||||
self.dynamic_eplb = ascend_config.dynamic_eplb
|
self.dynamic_eplb = self.ascend_config.dynamic_eplb
|
||||||
if self.dynamic_eplb:
|
if self.dynamic_eplb:
|
||||||
self.is_eplb_warmuped = False
|
self.is_eplb_warmuped = False
|
||||||
self.eplb_loader = D2DExpertWeightLoader()
|
self.eplb_loader = D2DExpertWeightLoader()
|
||||||
@@ -890,15 +891,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
def _make_attention_mask(self, seq_lens, position,
|
def _make_attention_mask(self, seq_lens, position,
|
||||||
attn_state) -> torch.Tensor:
|
attn_state) -> torch.Tensor:
|
||||||
# Chunk Prefill situation.
|
# Chunk Prefill situation.
|
||||||
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla:
|
if attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
|
||||||
if torch.version.cann.startswith("8.3"):
|
if torch.version.cann.startswith("8.3"):
|
||||||
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
return self.attn_mask_builder.get_splitfuse_attn_mask()
|
||||||
else:
|
else:
|
||||||
return self.attn_mask_builder.get_splitfuse_attn_mask(
|
return self.attn_mask_builder.get_splitfuse_attn_mask(
|
||||||
seq_lens, position, self.dtype, self.device)
|
seq_lens, position, self.dtype, self.device)
|
||||||
|
|
||||||
# Prefill without cache situation.
|
# Prefill without cache situation.
|
||||||
elif attn_state == AscendAttentionState.PrefillNoCache:
|
elif attn_state == AscendAttentionState.PrefillNoCache:
|
||||||
max_seq_len = max(seq_lens, default=0)
|
max_seq_len = max(seq_lens.max().item(), 0)
|
||||||
return self.attn_mask_builder.get_attn_mask(
|
return self.attn_mask_builder.get_attn_mask(
|
||||||
max_seq_len, self.dtype, self.device)
|
max_seq_len, self.dtype, self.device)
|
||||||
# Prefill with cache hit.
|
# Prefill with cache hit.
|
||||||
@@ -1252,7 +1254,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
req_ids = self.input_batch.req_ids
|
req_ids = self.input_batch.req_ids
|
||||||
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
|
||||||
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
num_scheduled_tokens = np.array(tokens, dtype=np.int32)
|
||||||
max_num_scheduled_tokens = max(tokens)
|
max_num_scheduled_tokens = num_scheduled_tokens.max()
|
||||||
num_valid_tokens = np.array([
|
num_valid_tokens = np.array([
|
||||||
num_tokens -
|
num_tokens -
|
||||||
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
|
len(scheduler_output.scheduled_spec_decode_tokens.get(i, []))
|
||||||
@@ -1376,8 +1378,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
positions_cpu = self.positions_cpu[:num_input_tokens]
|
positions_cpu = self.positions_cpu[:num_input_tokens]
|
||||||
positions = self.positions[:num_input_tokens]
|
positions = self.positions[:num_input_tokens]
|
||||||
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
|
seq_lens_cpu = self.seq_lens_cpu[:num_reqs]
|
||||||
attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens,
|
|
||||||
num_valid_tokens)
|
|
||||||
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
|
self.attn_mask = self._make_attention_mask(seq_lens=seq_lens_cpu,
|
||||||
position=positions_cpu,
|
position=positions_cpu,
|
||||||
attn_state=attn_state)
|
attn_state=attn_state)
|
||||||
@@ -1477,7 +1477,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_computed_tokens_cpu = (
|
num_computed_tokens_cpu = (
|
||||||
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
self.input_batch.num_computed_tokens_cpu_tensor[:num_reqs])
|
||||||
spec_decode_common_attn_metadata = None
|
spec_decode_common_attn_metadata = None
|
||||||
if use_spec_decode:
|
if use_spec_decode and self.need_accepted_tokens:
|
||||||
self.num_accepted_tokens.np[:num_reqs] = (
|
self.num_accepted_tokens.np[:num_reqs] = (
|
||||||
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
|
self.input_batch.num_accepted_tokens_cpu[:num_reqs])
|
||||||
self.num_accepted_tokens.np[num_reqs:].fill(1)
|
self.num_accepted_tokens.np[num_reqs:].fill(1)
|
||||||
@@ -1550,7 +1550,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
**extra_attn_metadata_args)
|
**extra_attn_metadata_args)
|
||||||
|
|
||||||
if self.vllm_config.model_config.use_mla:
|
if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
|
||||||
attn_metadata_i.num_input_tokens = num_input_tokens
|
attn_metadata_i.num_input_tokens = num_input_tokens
|
||||||
for layer_name in attn_group.layer_names:
|
for layer_name in attn_group.layer_names:
|
||||||
attn_metadata[layer_name] = attn_metadata_i
|
attn_metadata[layer_name] = attn_metadata_i
|
||||||
@@ -2060,6 +2060,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
sampling_metadata,
|
sampling_metadata,
|
||||||
)
|
)
|
||||||
sampler_output.sampled_token_ids = output_token_ids
|
sampler_output.sampled_token_ids = output_token_ids
|
||||||
|
if self.need_accepted_tokens:
|
||||||
self._update_states_after_model_execute(output_token_ids)
|
self._update_states_after_model_execute(output_token_ids)
|
||||||
|
|
||||||
discard_sampled_tokens_req_indices: list[int] = []
|
discard_sampled_tokens_req_indices: list[int] = []
|
||||||
@@ -2683,10 +2684,26 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
self.kv_cache_config = kv_cache_config
|
self.kv_cache_config = kv_cache_config
|
||||||
self.initialize_attn_backend(kv_cache_config)
|
self.initialize_attn_backend(kv_cache_config)
|
||||||
self.use_hybrid_blocks = (len(self.attn_groups) > 1)
|
self.use_hybrid_blocks = (len(self.attn_groups) > 1)
|
||||||
|
# NOTE: Currently, we determine whether we need `num_accepted_tokens` through `MambaSpec`.
|
||||||
|
if vllm_version_is("0.10.2"):
|
||||||
|
self.need_accepted_tokens = any([
|
||||||
|
isinstance(
|
||||||
|
self.kv_cache_config.kv_cache_groups[0].kv_cache_spec,
|
||||||
|
MambaSpec) for attn_group in self.attn_groups
|
||||||
|
])
|
||||||
|
else:
|
||||||
|
self.need_accepted_tokens = any([
|
||||||
|
isinstance(attn_group[0].kv_cache_spec, MambaSpec)
|
||||||
|
for attn_group in self.attn_groups
|
||||||
|
])
|
||||||
|
|
||||||
self.may_reinitialize_input_batch(kv_cache_config)
|
self.may_reinitialize_input_batch(kv_cache_config)
|
||||||
|
|
||||||
if self.model_config.is_deepseek_mla:
|
if self.ascend_config.is_deepseek_sfa:
|
||||||
kv_caches = self.initialize_kv_cache_tensors_deepseek(
|
kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
|
||||||
|
kv_cache_config)
|
||||||
|
elif self.model_config.is_deepseek_mla:
|
||||||
|
kv_caches = self.initialize_kv_cache_tensors_deepseek_mla(
|
||||||
kv_cache_config)
|
kv_cache_config)
|
||||||
else:
|
else:
|
||||||
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
kv_caches = self.initialize_kv_cache_tensors(kv_cache_config)
|
||||||
@@ -2701,7 +2718,116 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
offset = (aligned_addr - data_ptr) // tensor.element_size()
|
offset = (aligned_addr - data_ptr) // tensor.element_size()
|
||||||
return tensor[int(offset):]
|
return tensor[int(offset):]
|
||||||
|
|
||||||
def initialize_kv_cache_tensors_deepseek(
|
def initialize_kv_cache_tensors_deepseek_sfa(
|
||||||
|
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||||
|
kv_cache_sizes = {}
|
||||||
|
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||||
|
assert len(kv_cache_tensor.shared_by) == 1, (
|
||||||
|
"KV cache tensor shared by multiple layers is not supported in "
|
||||||
|
"NPU.")
|
||||||
|
kv_cache_sizes[kv_cache_tensor.shared_by[0]] = kv_cache_tensor.size
|
||||||
|
|
||||||
|
kv_caches: Dict[str, torch.Tensor] = {}
|
||||||
|
for group in self._kv_cache_spec_attn_group_iterator_dispatcher():
|
||||||
|
if vllm_version_is("0.10.2"):
|
||||||
|
kv_cache_spec, group = group
|
||||||
|
else:
|
||||||
|
kv_cache_spec = group.kv_cache_spec
|
||||||
|
attn_backend = group.backend
|
||||||
|
for layer_name in group.layer_names:
|
||||||
|
if layer_name in self.runner_only_attn_layers:
|
||||||
|
continue
|
||||||
|
tensor_size = kv_cache_sizes[layer_name]
|
||||||
|
num_blocks = tensor_size // kv_cache_spec.page_size_bytes
|
||||||
|
if self.vllm_config.additional_config.get(
|
||||||
|
"kv_cache_dtype", None) == 'int8':
|
||||||
|
kv_cache_shape = attn_backend.get_bsh_kv_cache_shape(
|
||||||
|
num_blocks, kv_cache_spec.block_size,
|
||||||
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||||
|
elif hasattr(
|
||||||
|
attn_backend, "get_supported_block_size"
|
||||||
|
) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa:
|
||||||
|
block_size = attn_backend.get_supported_block_size()[0]
|
||||||
|
block_size_chunk = kv_cache_spec.block_size // block_size
|
||||||
|
kv_cache_shape = attn_backend.get_kv_cache_shape(
|
||||||
|
num_blocks * block_size_chunk, block_size,
|
||||||
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||||
|
else:
|
||||||
|
kv_cache_shape = self.attn_backend.get_kv_cache_shape(
|
||||||
|
num_blocks, kv_cache_spec.block_size,
|
||||||
|
kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
|
||||||
|
dtype = kv_cache_spec.dtype
|
||||||
|
|
||||||
|
alignment = 2 * 1024 * 1024
|
||||||
|
num_blocks, block_size, num_kv_heads, head_size = kv_cache_shape
|
||||||
|
rope_dim = self.model_config.hf_text_config.qk_rope_head_dim
|
||||||
|
nope_dim = head_size - rope_dim
|
||||||
|
nope_cache_shape = (num_blocks, block_size, num_kv_heads,
|
||||||
|
nope_dim)
|
||||||
|
rope_cache_shape = (num_blocks, block_size, num_kv_heads,
|
||||||
|
rope_dim)
|
||||||
|
#### k cache
|
||||||
|
# TODO(zzzzwwjj): wait transformers add these params
|
||||||
|
k_cache_shape = (num_blocks, block_size, 1, 128)
|
||||||
|
if self.vllm_config.kv_transfer_config is None:
|
||||||
|
# For no disaggregate pd scenario, allocate kv cache in normal way
|
||||||
|
rope_cache = torch.zeros(rope_cache_shape,
|
||||||
|
dtype=dtype,
|
||||||
|
device=self.device)
|
||||||
|
nope_cache = torch.zeros(nope_cache_shape,
|
||||||
|
dtype=dtype,
|
||||||
|
device=self.device)
|
||||||
|
rope_cache = self._convert_torch_format(rope_cache)
|
||||||
|
nope_cache = self._convert_torch_format(nope_cache)
|
||||||
|
|
||||||
|
#### k cache
|
||||||
|
k_cache = torch.zeros(k_cache_shape,
|
||||||
|
dtype=dtype,
|
||||||
|
device=self.device)
|
||||||
|
k_cache = self._convert_torch_format(k_cache)
|
||||||
|
else:
|
||||||
|
|
||||||
|
# In order to transfer kv cache through the reigster_memory api from llmdatadist, the memory
|
||||||
|
# address should be aligned by 2M. In most case, torch_npu can allocate 2M aligned memory, but
|
||||||
|
# we found there are also some exceptions during test, so we manual align those memory here, this part
|
||||||
|
# of code may consume 2M * 2 * elem_size memory every layer.
|
||||||
|
nope_allocate_shape = num_blocks * block_size * num_kv_heads * nope_dim
|
||||||
|
nope_allocate_shape_alignment = nope_allocate_shape + alignment
|
||||||
|
rope_allocate_shape = num_blocks * block_size * num_kv_heads * rope_dim
|
||||||
|
rope_allocate_shape_alignment = rope_allocate_shape + alignment
|
||||||
|
|
||||||
|
nope_cache = torch.zeros(nope_allocate_shape_alignment,
|
||||||
|
dtype=dtype,
|
||||||
|
device=self.device)
|
||||||
|
rope_cache = torch.zeros(rope_allocate_shape_alignment,
|
||||||
|
dtype=dtype,
|
||||||
|
device=self.device)
|
||||||
|
#### k cache
|
||||||
|
# TODO(zzzzwwjj): wait transformers add these params
|
||||||
|
k_allocate_shape = num_blocks * block_size * 1 * 128
|
||||||
|
k_allocate_shape_alignment = k_allocate_shape + alignment
|
||||||
|
k_cache = torch.zeros(k_allocate_shape_alignment,
|
||||||
|
dtype=dtype,
|
||||||
|
device=self.device)
|
||||||
|
|
||||||
|
nope_cache = self._align_memory(
|
||||||
|
nope_cache,
|
||||||
|
alignment)[:nope_allocate_shape].view(nope_cache_shape)
|
||||||
|
rope_cache = self._align_memory(
|
||||||
|
rope_cache,
|
||||||
|
alignment)[:rope_allocate_shape].view(rope_cache_shape)
|
||||||
|
k_cache = self._align_memory(
|
||||||
|
k_cache,
|
||||||
|
alignment)[:k_allocate_shape].view(k_cache_shape)
|
||||||
|
|
||||||
|
kv_caches[layer_name] = (nope_cache, rope_cache, k_cache)
|
||||||
|
bind_kv_cache(kv_caches,
|
||||||
|
self.compilation_config.static_forward_context,
|
||||||
|
self.kv_caches)
|
||||||
|
|
||||||
|
return kv_caches
|
||||||
|
|
||||||
|
def initialize_kv_cache_tensors_deepseek_mla(
|
||||||
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
self, kv_cache_config: KVCacheConfig) -> dict[str, torch.Tensor]:
|
||||||
kv_cache_sizes = {}
|
kv_cache_sizes = {}
|
||||||
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
for kv_cache_tensor in kv_cache_config.kv_cache_tensors:
|
||||||
@@ -3217,6 +3343,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
|
|
||||||
block_size = self.vllm_config.cache_config.block_size
|
block_size = self.vllm_config.cache_config.block_size
|
||||||
use_mla = self.vllm_config.model_config.use_mla
|
use_mla = self.vllm_config.model_config.use_mla
|
||||||
|
use_sfa = self.ascend_config.use_sfa
|
||||||
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||||
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
|
||||||
for layer_name, attn_module in attn_layers.items():
|
for layer_name, attn_module in attn_layers.items():
|
||||||
@@ -3243,7 +3370,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
num_kv_heads=attn_module.num_kv_heads,
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
head_size=attn_module.head_size,
|
head_size=attn_module.head_size,
|
||||||
dtype=self.kv_cache_dtype,
|
dtype=self.kv_cache_dtype,
|
||||||
use_mla=use_mla)
|
use_mla=use_mla,
|
||||||
|
use_sfa=use_sfa)
|
||||||
elif attn_module.attn_type in (AttentionType.ENCODER,
|
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||||
AttentionType.ENCODER_ONLY):
|
AttentionType.ENCODER_ONLY):
|
||||||
# encoder-only attention does not need KV cache.
|
# encoder-only attention does not need KV cache.
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
|
|||||||
from vllm.v1.worker.worker_base import WorkerBase
|
from vllm.v1.worker.worker_base import WorkerBase
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_config import init_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
|
||||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||||
from vllm_ascend.platform import NPUPlatform
|
from vllm_ascend.platform import NPUPlatform
|
||||||
@@ -88,6 +88,17 @@ class NPUWorker(WorkerBase):
|
|||||||
# init ascend config and soc version
|
# init ascend config and soc version
|
||||||
init_ascend_config(vllm_config)
|
init_ascend_config(vllm_config)
|
||||||
init_ascend_soc_version()
|
init_ascend_soc_version()
|
||||||
|
if get_ascend_config().use_sfa:
|
||||||
|
# Direct import instead of using try_register_lib to ensure proper error handling when
|
||||||
|
# custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments)
|
||||||
|
# yapf: disable
|
||||||
|
import custom_ops # type: ignore # noqa
|
||||||
|
|
||||||
|
# yapf: enable
|
||||||
|
logger.info(
|
||||||
|
"custom_ops module loaded successfully. Custom operators like "
|
||||||
|
"torch.ops.custom.npu_sparse_flash_attention are now available."
|
||||||
|
)
|
||||||
|
|
||||||
super().__init__(vllm_config=vllm_config,
|
super().__init__(vllm_config=vllm_config,
|
||||||
local_rank=local_rank,
|
local_rank=local_rank,
|
||||||
|
|||||||
Reference in New Issue
Block a user