From 8abe5178708977e605b974fff404caceac5b4c81 Mon Sep 17 00:00:00 2001 From: Mengqing Cao Date: Wed, 15 Oct 2025 17:48:58 +0800 Subject: [PATCH] [Refactor] Adapt deepseek-v3.2 to vllm 0.11.0 (#3432) ### What this PR does / why we need it? Adapt deepseek-v3.2 to vllm 0.11.0, removing the useless patch. The final goal is to remove all the patches and align the code arch to vllm, thus we need to do the following work in next prs. TODO: - [x] remove patch on attention spec - [ ] refactor the kvcache creation logic ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? 1. CI passed with existing test. 2. Test pass with deepseek-v3.2-exp - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 Signed-off-by: MengqingCao --- tests/ut/worker/test_worker_v1.py | 1 + vllm_ascend/__init__.py | 1 - vllm_ascend/ascend_config.py | 2 - vllm_ascend/attention/sfa_v1.py | 11 +- .../llmdatadist_c_mgr_connector.py | 6 +- vllm_ascend/distributed/mooncake_connector.py | 27 +-- vllm_ascend/models/deepseek_v2.py | 69 ++++++- vllm_ascend/models/layers/mla.py | 2 + vllm_ascend/models/layers/sfa.py | 2 +- .../patch/platform/patch_common/__init__.py | 1 - .../patch/worker/patch_common/__init__.py | 2 - .../patch_common/patch_attention_layer.py | 188 ------------------ vllm_ascend/platform.py | 31 +-- vllm_ascend/quantization/quant_config.py | 5 + vllm_ascend/spec_decode/mtp_proposer.py | 6 +- .../torchair/models/torchair_deepseek_v2.py | 10 +- vllm_ascend/torchair/torchair_model_runner.py | 4 +- vllm_ascend/torchair/torchair_sfa.py | 2 +- vllm_ascend/worker/model_runner_v1.py | 27 +-- vllm_ascend/worker/worker_v1.py | 8 +- 20 files changed, 143 insertions(+), 262 deletions(-) delete mode 100644 vllm_ascend/patch/worker/patch_common/patch_attention_layer.py diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py index f4551de..31e986d 100644 --- a/tests/ut/worker/test_worker_v1.py +++ b/tests/ut/worker/test_worker_v1.py @@ -18,6 +18,7 @@ class TestNPUWorker(TestBase): self.model_config_mock = MagicMock(spec=ModelConfig) self.model_config_mock.dtype = torch.float16 self.model_config_mock.trust_remote_code = False + self.model_config_mock.hf_config = None self.parallel_config_mock = MagicMock(spec=ParallelConfig) diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py index f519417..74a8153 100644 --- a/vllm_ascend/__init__.py +++ b/vllm_ascend/__init__.py @@ -23,7 +23,6 @@ def register(): def register_model(): - import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa from .models import register_model register_model() diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py index cbd905e..6a60695 100644 --- a/vllm_ascend/ascend_config.py +++ b/vllm_ascend/ascend_config.py @@ -34,8 +34,6 @@ class AscendConfig: def __init__(self, vllm_config): additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {} - self.is_deepseek_sfa = vllm_config.model_config is not None and vllm_config.model_config.is_deepseek_mla and vllm_config.model_config.hf_text_config.model_type == "deepseek_v32" - self.use_sfa = self.is_deepseek_sfa torchair_graph_config = additional_config.get("torchair_graph_config", {}) diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py index edbd7cc..725a2be 100644 --- a/vllm_ascend/attention/sfa_v1.py +++ b/vllm_ascend/attention/sfa_v1.py @@ -510,7 +510,6 @@ class AscendSFAImpl(MLAAttentionImpl): ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - self.enable_prefetch = ascend_config.enable_prefetch self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz vllm_config = get_current_vllm_config() @@ -690,6 +689,8 @@ class AscendSFAImpl(MLAAttentionImpl): topk_indices = self.indexer_select(hidden_states_decode, decode_q_c, attn_metadata=attn_metadata, + cos=cos, + sin=sin, kv_cache=kv_cache) query_states = (decode_q_nope, decode_q_pe) @@ -778,6 +779,8 @@ class AscendSFAImpl(MLAAttentionImpl): topk_indices = self.indexer_select(x=hidden_states_prefill, qr=prefill_qr, kv_cache=kv_cache, + cos=cos, + sin=sin, attn_metadata=attn_metadata) query_states = (prefill_q_nope, prefill_q_pe) key_states = (prefill_k_nope, prefill_k_pe) @@ -920,17 +923,15 @@ class AscendSFAImpl(MLAAttentionImpl): x: torch.Tensor, qr: torch.Tensor, kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + cos, + sin, attn_metadata: M, ): if attn_metadata.prefill is not None: - cos = attn_metadata.prefill.cos - sin = attn_metadata.prefill.sin actual_seq_lengths_query = attn_metadata.prefill.query_lens actual_seq_lengths_key = attn_metadata.prefill.seq_lens block_table = attn_metadata.prefill.block_table elif attn_metadata.decode is not None: - cos = attn_metadata.decode.cos - sin = attn_metadata.decode.sin actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q actual_seq_lengths_key = attn_metadata.decode.seq_lens block_table = attn_metadata.decode.block_table diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py index 64b3c11..1ec0311 100644 --- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py +++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py @@ -501,7 +501,7 @@ class LLMDataDistCMgrConnectorWorker(): self.use_mla: bool = first_kv_cache_tuple[0].size( -1) != first_kv_cache_tuple[1].size(-1) and len( first_kv_cache_tuple) == 2 - self.use_sfa: bool = len(first_kv_cache_tuple) == 3 + self.use_sparse: bool = len(first_kv_cache_tuple) == 3 # MLA case. [2 (k_normed, k_pe), num_blocks, ...] # SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...] # MHA case. [2 (k and v), num_blocks, ...] @@ -549,7 +549,7 @@ class LLMDataDistCMgrConnectorWorker(): raise RuntimeError( f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]" ) - elif self.use_sfa: + elif self.use_sparse: cache_k_normed_addr_list = [] cache_k_pe_addr_list = [] cache_k_idx_addr_list = [] @@ -887,7 +887,7 @@ class LLMDataDistCMgrConnectorWorker(): raise RuntimeError( "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status" ) - elif self.use_sfa: + elif self.use_sparse: remote_cache_key_k_normed = BlocksCacheKey( cluster_id=remote_cluster_id, model_id=0) remote_cache_key_k_pe = BlocksCacheKey( diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index dcdfdf6..ebab077 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -242,7 +242,7 @@ class KVCacheRecvingThread(threading.Thread): self.block_len = block_len # TODO(jianzs): find a better way to detect MLA. self.use_mla = len(block_len) == 2 - self.use_sfa = len(block_len) == 3 + self.use_sparse = len(block_len) == 3 self.request_queue: queue.Queue[Any] = queue.Queue() self.executor = ThreadPoolExecutor(max_workers=32) @@ -373,7 +373,7 @@ class KVCacheRecvingThread(threading.Thread): zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)): if self.use_mla: block_len = (self.block_len[k % 2]) - elif self.use_sfa: + elif self.use_sparse: block_len = (self.block_len[k % 3]) else: block_len = (self.block_len[0]) @@ -850,7 +850,8 @@ class MooncakeConnectorScheduler: assert "tp_size" in decode_parallel_config.keys() self._decode_tp_size = decode_parallel_config["tp_size"] num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads - if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa: + if self.vllm_config.model_config.use_mla or hasattr( + self.vllm_config.model_config.hf_config, "index_topk"): num_need_pulls = 1 else: num_p_block_heads = max( @@ -942,7 +943,7 @@ class MooncakeConnectorWorker: # kv_transfer variables self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size - if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa: + if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse: self.num_need_pulls = 1 else: num_d_block_heads = max(1, @@ -995,7 +996,7 @@ class MooncakeConnectorWorker: self.use_mla = first_kv_cache_tuple[0].size( -1) != first_kv_cache_tuple[1].size(-1) and len( first_kv_cache_tuple) == 2 - self.use_sfa = len(first_kv_cache_tuple) == 3 + self.use_sparse = len(first_kv_cache_tuple) == 3 if self.use_mla: # MLA case.[num_block, block_size, 1, hidden_dim] self.num_blocks = first_kv_cache.shape[0] @@ -1009,7 +1010,7 @@ class MooncakeConnectorWorker: logger.info( "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", self.num_blocks, block_shape_norm, block_shape_pe) - elif self.use_sfa: + elif self.use_sparse: self.num_blocks = first_kv_cache.shape[0] block_rank = 3 # [block_size, latent_dim] block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] @@ -1037,8 +1038,8 @@ class MooncakeConnectorWorker: logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, block_shape) logger.info( - "Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s", - self.use_mla, self.use_sfa, first_kv_cache.shape) + "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s", + self.use_mla, self.use_sparse, first_kv_cache.shape) self.kv_caches = kv_caches kv_caches_base_addr = [] @@ -1050,7 +1051,7 @@ class MooncakeConnectorWorker: region_len = self.num_blocks * self.block_len[i % 2] kv_caches_base_addr.append(base_addr) self._register(base_addr, region_len) - elif self.use_sfa: + elif self.use_sparse: for i, cache in enumerate(cache_or_caches, 0): base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[i % 3] @@ -1059,7 +1060,7 @@ class MooncakeConnectorWorker: else: cache_list = [ cache_or_caches - ] if self.use_mla or self.use_sfa else cache_or_caches + ] if self.use_mla or self.use_sparse else cache_or_caches for cache in cache_list: base_addr = cache.data_ptr() region_len = self.num_blocks * self.block_len[0] @@ -1156,9 +1157,9 @@ class MooncakeConnectorWorker: sampled_nums = [] ori_data = np.arange(self._prefill_tp_size) # random split prefill tp list - if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa: + if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.use_sparse: # use deepseek mla, num_key_value_heads == 128, but consider as 1 - if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa: + if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse: num_kv_head = 1 else: num_kv_head = self.num_key_value_heads @@ -1279,4 +1280,4 @@ def ensure_zmq_recv( logger.error(f"Receive failed after all retries: {e}") raise RuntimeError( f"Failed to receive data after {max_retries} " - f"retries: {e}") \ No newline at end of file + f"retries: {e}") diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 4e0ead1..bc61522 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -31,6 +31,7 @@ import torch from torch import nn from transformers import PretrainedConfig from vllm.attention import AttentionMetadata +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import (divide, get_pp_group, get_tensor_model_parallel_rank, @@ -47,7 +48,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.mla import MultiHeadLatentAttention from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name) from vllm.model_executor.models.deepseek_v2 import \ @@ -56,10 +58,11 @@ from vllm.model_executor.models.deepseek_v2 import ( DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE, get_spec_layer_idx_from_weight_name) -from vllm.model_executor.models.utils import (PPMissingLayer, - is_pp_missing_parameter, - maybe_prefix) +from vllm.model_executor.models.utils import ( + PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.models.layers.mla import AscendMLAModules @@ -69,6 +72,53 @@ from vllm_ascend.ops.common_fused_moe import AscendFusedMoE from vllm_ascend.ops.linear import AscendLinearBase +@support_torch_compile +class AscendDeepseekV2Model(DeepseekV2Model, nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Rewrite this init func mainly for removing cuda-hard code + nn.Module.__init__(self) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + + self.vocab_size = config.vocab_size + self.is_v32 = hasattr(config, "index_topk") + if self.is_v32: + topk_tokens = config.index_topk + topk_indices_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device=current_platform.device_type) + else: + topk_indices_buffer = None + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, + topk_indices_buffer), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + class CustomDeepseekV2RowParallelLinear(RowParallelLinear): def __init__( @@ -270,6 +320,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): scaling_factor = rope_scaling["factor"] mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale + self.indexer = None mla_modules = AscendMLAModules( q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, @@ -281,6 +332,8 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): kv_b_proj=self.kv_b_proj, o_proj=self.o_proj, rotary_emb=self.rotary_emb, + indexer=None, + is_sparse=hasattr(config, "index_topk"), ) self.mla_attn = MultiHeadLatentAttention( @@ -499,7 +552,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config - ascend_config = get_ascend_config() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -515,7 +567,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): self.tp_rank = get_tp_group().rank_in_group # TODO: enable mla in vllm-ascend if model_config.use_mla: - if ascend_config.use_sfa: + if hasattr(model_config.hf_config, "index_topk"): attn_cls = CustomDeepseekV2SFAAttention else: attn_cls = CustomDeepseekV2MLAAttention @@ -590,8 +642,9 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): "kv_a_proj_with_mqa", ] - self.model = DeepseekV2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = AscendDeepseekV2Model(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py index 57c91bd..e0c7e2d 100644 --- a/vllm_ascend/models/layers/mla.py +++ b/vllm_ascend/models/layers/mla.py @@ -42,6 +42,8 @@ class AscendMLAModules: kv_b_proj: torch.nn.Module o_proj: torch.nn.Module rotary_emb: torch.nn.Module + indexer: Optional[torch.nn.Module] + is_sparse: bool class AscendMultiHeadLatentAttention(MultiHeadLatentAttention): diff --git a/vllm_ascend/models/layers/sfa.py b/vllm_ascend/models/layers/sfa.py index f68281c..23b77c3 100644 --- a/vllm_ascend/models/layers/sfa.py +++ b/vllm_ascend/models/layers/sfa.py @@ -94,7 +94,7 @@ class AscendSparseFlashAttention(MultiHeadLatentAttention): quant_config=quant_config, prefix=f"{prefix}.attn", use_mla=True, - use_sfa=True, + use_sparse=True, # SFA Args q_lora_rank=self.q_lora_rank, kv_lora_rank=self.kv_lora_rank, diff --git a/vllm_ascend/patch/platform/patch_common/__init__.py b/vllm_ascend/patch/platform/patch_common/__init__.py index 30e887a..b1e7f4e 100644 --- a/vllm_ascend/patch/platform/patch_common/__init__.py +++ b/vllm_ascend/patch/platform/patch_common/__init__.py @@ -18,4 +18,3 @@ import vllm_ascend.patch.platform.patch_common.patch_config # noqa import vllm_ascend.patch.platform.patch_common.patch_distributed # noqa import vllm_ascend.patch.platform.patch_common.patch_mamba_config # noqa -import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py index 99ec345..2e215b8 100644 --- a/vllm_ascend/patch/worker/patch_common/__init__.py +++ b/vllm_ascend/patch/worker/patch_common/__init__.py @@ -21,8 +21,6 @@ if HAS_TRITON: import vllm_ascend.patch.worker.patch_common.patch_triton # isort: off -import vllm_ascend.patch.worker.patch_common.patch_attention_selector # noqa -import vllm_ascend.patch.worker.patch_common.patch_attention_layer # noqa import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa import vllm_ascend.patch.worker.patch_common.patch_logits # noqa import vllm_ascend.patch.worker.patch_common.patch_roberta # noqa diff --git a/vllm_ascend/patch/worker/patch_common/patch_attention_layer.py b/vllm_ascend/patch/worker/patch_common/patch_attention_layer.py deleted file mode 100644 index b778d8a..0000000 --- a/vllm_ascend/patch/worker/patch_common/patch_attention_layer.py +++ /dev/null @@ -1,188 +0,0 @@ -from typing import List, Optional - -import torch -import vllm -import vllm.envs as envs -from torch import nn -from vllm.attention import Attention, AttentionType, get_attn_backend -from vllm.attention.backends.abstract import AttentionBackend -from vllm.attention.selector import backend_name_to_enum -from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target -from vllm.config import CacheConfig, get_current_vllm_config -from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase -from vllm.model_executor.layers.linear import UnquantizedLinearMethod -from vllm.model_executor.layers.quantization.base_config import \ - QuantizationConfig -from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod -from vllm.platforms import current_platform - - -class AscendAttention(Attention, nn.Module, AttentionLayerBase): - """Attention layer. - - This class takes query, key, and value tensors as input. The input tensors - can either contain prompt tokens or generation tokens. - The class does the following: - - 1. Store the input key and value tensors in the KV cache. - 2. Perform (multi-head/multi-query/grouped-query) attention. - 3. Return the output tensor. - """ - - def __init__( - self, - num_heads: int, - head_size: int, - scale: float, - num_kv_heads: Optional[int] = None, - alibi_slopes: Optional[List[float]] = None, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - logits_soft_cap: Optional[float] = None, - per_layer_sliding_window: Optional[int] = None, - use_mla: bool = False, - use_sfa: bool = False, - prefix: str = "", - attn_type: str = AttentionType.DECODER, - kv_sharing_target_layer_name: Optional[str] = None, - attn_backend: Optional[type[AttentionBackend]] = None, - **extra_impl_args, - ) -> None: - """ - The KV cache is stored inside this class and is accessed via - `self.kv_cache`. - """ - nn.Module.__init__(self) - AttentionLayerBase.__init__(self) - - if per_layer_sliding_window is not None: - # per-layer sliding window - sliding_window = per_layer_sliding_window - elif cache_config is not None: - # model-level sliding window - sliding_window = cache_config.sliding_window - else: - sliding_window = None - - if cache_config is not None: - kv_cache_dtype = cache_config.cache_dtype - block_size = cache_config.block_size - calculate_kv_scales = cache_config.calculate_kv_scales - else: - kv_cache_dtype = "auto" - block_size = 16 - calculate_kv_scales = False - if num_kv_heads is None: - num_kv_heads = num_heads - assert num_heads % num_kv_heads == 0, \ - f"num_heads ({num_heads}) is not " \ - f"divisible by num_kv_heads ({num_kv_heads})" - - # The default k/v_scale is set to 1.0. This is ignored - # when kv-cache is not fp8, and should be used with - # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we - # expect the pre-quantized k/v_scale to be loaded along - # with the model weights. - self.kv_cache_dtype = kv_cache_dtype - self.calculate_kv_scales = calculate_kv_scales - self._k_scale = torch.tensor(1.0, dtype=torch.float32) - self._v_scale = torch.tensor(1.0, dtype=torch.float32) - # FlashAttn doesn't support quantizing the kv-cache only - # but requires q to be quantized as well. - self._q_scale = torch.tensor(1.0, dtype=torch.float32) - self._prob_scale = torch.tensor(1.0, dtype=torch.float32) - - # We also keep q/k/v_scale on host (cpu) memory for attention - # backends that require the scales to be on host instead of on device. - # e.g. Flashinfer - self._q_scale_float = 1.0 - self._k_scale_float = 1.0 - self._v_scale_float = 1.0 - - # The output scale on host memory. This should be the input scale of - # the quant op after this attention layer. - self._o_scale_float: Optional[float] = None - - self.use_mla = use_mla - self.num_heads = num_heads - self.head_size = head_size - self.num_kv_heads = num_kv_heads - self.sliding_window = sliding_window - self.has_sink = extra_impl_args.get("sinks") is not None - - quant_method = quant_config.get_quant_method( - self, prefix=prefix) if quant_config else None - if quant_method is not None and not isinstance( - quant_method, UnquantizedLinearMethod): - assert isinstance(quant_method, BaseKVCacheMethod) - # TODO (mgoin): kv cache dtype should be specified in the FP8 - # checkpoint config and become the "auto" behavior - if self.kv_cache_dtype == "fp8_e5m2": - raise ValueError("fp8_e5m2 kv-cache is not supported with " - "fp8 checkpoints.") - # If quantization is enabled, we make "k_scale" and "v_scale" - # parameters so that it can be loaded from the model checkpoint. - # The k/v_scale will then be converted back to native float32 - # values after weight loading. - self.quant_method = quant_method - self.quant_method.create_weights(self) - - # During model initialization, the default dtype is set as the model - # weight and activation dtype. - dtype = torch.get_default_dtype() - if attn_backend is None: - self.attn_backend = get_attn_backend(head_size, - dtype, - kv_cache_dtype, - block_size, - use_mla=use_mla, - use_sfa=use_sfa, - has_sink=self.has_sink) - else: - self.attn_backend = attn_backend - - impl_cls = self.attn_backend.get_impl_cls() - self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, - alibi_slopes, sliding_window, kv_cache_dtype, - logits_soft_cap, attn_type, - kv_sharing_target_layer_name, **extra_impl_args) - self.backend = backend_name_to_enum(self.attn_backend.get_name()) - self.dtype = dtype - - # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how - # torch.compile works by registering the attention as one giant - # opaque custom op. For other platforms, we directly call them - # and let torch.compile handle them. - self.use_direct_call = not current_platform.opaque_attention_op() - - self.use_output = self.attn_backend.accept_output_buffer - compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") - compilation_config.static_forward_context[prefix] = self - self.layer_name = prefix - self.attn_type = attn_type - - if kv_sharing_target_layer_name is not None: - validate_kv_sharing_target( - prefix, - kv_sharing_target_layer_name, - compilation_config.static_forward_context, - ) - self.kv_sharing_target_layer_name = kv_sharing_target_layer_name - - # use a placeholder kv cache tensor during init, which will be replaced - # by bind_kv_cache - # this variable will not be accessed if use_direct_call is True - self.kv_cache = [ - torch.tensor([]) for _ in range(get_current_vllm_config( - ).parallel_config.pipeline_parallel_size) - ] - - self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32) - self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32) - self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32) - self.query_quant = None - - -vllm.attention.Attention = AscendAttention diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py index b00e8d6..b9eaf5f 100644 --- a/vllm_ascend/platform.py +++ b/vllm_ascend/platform.py @@ -164,6 +164,9 @@ class NPUPlatform(Platform): "kv_cache_dtype", None) if kv_cache_dtype is not None: vllm_config.cache_config.cache_dtype = kv_cache_dtype + elif model_config and hasattr(model_config.hf_config, "index_topk"): + vllm_config.cache_config.cache_dtype = str( + model_config.dtype).replace("torch.", "") if model_config is None: logger.warning("Model config is missing. This may indicate " "that we are running a test case") @@ -284,25 +287,27 @@ class NPUPlatform(Platform): vllm_config.scheduler_config = ascend_scheduler_config @classmethod - def get_attn_backend_cls(cls, - selected_backend, - head_size, - dtype, - kv_cache_dtype, - block_size, - use_v1, - use_mla, - use_sfa, - has_sink=False): + def get_attn_backend_cls( + cls, + selected_backend, + head_size, + dtype, + kv_cache_dtype, + block_size, + use_v1, + use_mla, + has_sink=False, + use_sparse=False, + ): if not use_v1: raise ValueError("vLLM Ascend does not support V0 engine.") ascend_config = get_ascend_config() if use_mla and ascend_config.enable_shared_expert_dp: - if use_mla and not use_sfa: + if use_mla and not use_sparse: return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend" - if use_mla and use_sfa: + if use_mla and use_sparse: return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend" use_torchair = ascend_config.torchair_graph_config.enabled @@ -321,7 +326,7 @@ class NPUPlatform(Platform): (True, True, True): "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend", } - return backend_map[(use_mla, use_sfa, use_torchair)] + return backend_map[(use_mla, use_sparse, use_torchair)] @classmethod def get_punica_wrapper(cls) -> str: diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index 185b206..f484400 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -183,6 +183,11 @@ packed_modules_model_mapping = { "experts": ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] }, + "deepseek_v32": { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + }, # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; # NOTE 2.The description file generated by the current msmodelslim tool does not have # MTP layer info. Please manually add it and set the value to FLOAT. diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index f42d381..d5f1216 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -67,6 +67,8 @@ class MtpProposer(Proposer): 1, device=self.runner.device, dtype=torch.int32) + self.use_sparse = hasattr(vllm_config.model_config.hf_config, + "index_topk") def load_model(self, model) -> None: loader = get_model_loader(self.vllm_config.load_config) @@ -613,7 +615,7 @@ class MtpProposer(Proposer): npu_backend = torchair.get_npu_backend(compiler_config=config) self.torchair_compiled_model = torch.compile( self.model, - dynamic=not get_ascend_config().use_sfa, + dynamic=not self.use_sparse, fullgraph=True, backend=npu_backend) return self.torchair_compiled_model @@ -636,7 +638,7 @@ class MtpProposer(Proposer): self.torchair_compiled_models[ batch_size] = torchair.inference.cache_compile( self.model.__dict__[forward_proxy_name], - dynamic=not get_ascend_config().use_sfa, + dynamic=not self.use_sparse, fullgraph=True, cache_dir=TORCHAIR_CACHE_DIR, config=config, diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py index 3b69e25..462e05b 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -791,7 +791,7 @@ class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention): quant_config=quant_config, prefix=f"{prefix}.attn", use_mla=True, - use_sfa=True, + use_sparse=True, # SFA Args q_lora_rank=self.q_lora_rank, kv_lora_rank=self.kv_lora_rank, @@ -879,12 +879,12 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): self.tp_rank = get_tp_group().rank_in_group ascend_config = get_ascend_config() self.use_mla = False - self.use_sfa = False + self.use_sparse = False # TODO: enable mla in vllm-ascend if model_config.use_mla: - if ascend_config.use_sfa: + if hasattr(model_config.hf_config, "index_topk"): attn_cls = TorchairDeepseekV2SFAAttention - self.use_sfa = True + self.use_sparse = True else: attn_cls = TorchairDeepseekV2MLAAttention # type: ignore[assignment] self.use_mla = True @@ -950,7 +950,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): forward_context = get_forward_context() if attn_metadata is not None: decoding_condition_met = ( - not attn_metadata.is_prefill if self.use_sfa else + not attn_metadata.is_prefill if self.use_sparse else not forward_context.with_prefill if self.use_mla else False) mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce else: diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py index daf6b5d..b29fdc2 100644 --- a/vllm_ascend/torchair/torchair_model_runner.py +++ b/vllm_ascend/torchair/torchair_model_runner.py @@ -376,7 +376,7 @@ class NPUTorchairModelRunner(NPUModelRunner): npu_backend = torchair.get_npu_backend(compiler_config=config) self.torchair_compiled_model = torch.compile( self.model, - dynamic=not self.ascend_config.use_sfa, + dynamic=not self.use_sparse, fullgraph=True, backend=npu_backend) return self.torchair_compiled_model @@ -399,7 +399,7 @@ class NPUTorchairModelRunner(NPUModelRunner): self.torchair_compiled_models[ batch_size] = torchair.inference.cache_compile( self.model.__dict__[forward_proxy_name], - dynamic=not self.ascend_config.use_sfa, + dynamic=not self.use_sparse, fullgraph=True, cache_dir=TORCHAIR_CACHE_DIR, config=config, diff --git a/vllm_ascend/torchair/torchair_sfa.py b/vllm_ascend/torchair/torchair_sfa.py index f4a00e5..abe30b7 100644 --- a/vllm_ascend/torchair/torchair_sfa.py +++ b/vllm_ascend/torchair/torchair_sfa.py @@ -738,7 +738,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl): ascend_config = get_ascend_config() self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - self.enable_prefetch = ascend_config.enable_prefetch + self.enable_prefetch = ascend_config.weight_prefetch_config.enabled self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz if ascend_config.torchair_graph_config.enabled: self.graph_batch_size = ascend_config.torchair_graph_config.graph_batch_sizes[ diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index e30fc66..95acbe4 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -309,13 +309,14 @@ class NPUModelRunner(LoRAModelRunnerMixin): dtype=self.dtype, device=self.device) # Set up Attention - self.attn_backend = get_attn_backend( - 0, - self.dtype, - None, - self.block_size, - use_mla=self.model_config.use_mla, - use_sfa=self.ascend_config.use_sfa) + self.use_sparse = hasattr(self.vllm_config.model_config.hf_config, + "index_topk") + self.attn_backend = get_attn_backend(0, + self.dtype, + None, + self.block_size, + use_mla=self.model_config.use_mla, + use_sparse=self.use_sparse) if torch.version.cann.startswith("8.3"): self.attn_mask_builder = AttentionMaskBuilder( self.scheduler_config.max_num_batched_tokens, self.dtype, @@ -871,7 +872,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS": return self.attn_mask_builder.get_pooling_mask(self.device) # Chunk Prefill situation. - elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa: + elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse: if torch.version.cann.startswith("8.3"): return self.attn_mask_builder.get_splitfuse_attn_mask() else: @@ -1507,7 +1508,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): model=self.get_model(), **extra_attn_metadata_args) - if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa: + if self.vllm_config.model_config.use_mla or self.use_sparse: attn_metadata_i.num_input_tokens = num_input_tokens for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i @@ -2655,7 +2656,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.may_add_encoder_only_layers_to_kv_cache_config() self.initialize_attn_backend(kv_cache_config) - if self.ascend_config.is_deepseek_sfa: + if self.use_sparse: kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa( kv_cache_config) elif self.model_config.is_deepseek_mla: @@ -2699,7 +2700,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): kv_cache_spec.num_kv_heads, kv_cache_spec.head_size) elif hasattr( attn_backend, "get_supported_block_size" - ) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa: + ) and not self.model_config.is_deepseek_mla and not self.use_sparse: block_size = attn_backend.get_supported_block_size()[0] block_size_chunk = kv_cache_spec.block_size // block_size kv_cache_shape = attn_backend.get_kv_cache_shape( @@ -3245,7 +3246,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): block_size = self.vllm_config.cache_config.block_size use_mla = self.vllm_config.model_config.use_mla - use_sfa = self.ascend_config.use_sfa + use_sparse = self.use_sparse kv_cache_spec: dict[str, KVCacheSpec] = {} attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention) for layer_name, attn_module in attn_layers.items(): @@ -3267,7 +3268,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): # TODO(lucas): move the attention specs into the model layers like # the attention backends if attn_module.attn_type == AttentionType.DECODER: - if use_mla and not use_sfa: + if use_mla and not use_sparse: kv_cache_spec[layer_name] = MLAAttentionSpec( block_size=block_size, num_kv_heads=attn_module.num_kv_heads, diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index f1acda3..e3ced0a 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -43,7 +43,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, from vllm.v1.worker.worker_base import WorkerBase import vllm_ascend.envs as envs_ascend -from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config +from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.platform import NPUPlatform @@ -88,7 +88,11 @@ class NPUWorker(WorkerBase): # init ascend config and soc version init_ascend_config(vllm_config) init_ascend_soc_version() - if get_ascend_config().use_sfa: + use_sparse = False + if vllm_config.model_config is not None: + use_sparse = hasattr(vllm_config.model_config.hf_config, + "index_topk") + if use_sparse: # Direct import instead of using try_register_lib to ensure proper error handling when # custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments) # yapf: disable