diff --git a/tests/ut/worker/test_worker_v1.py b/tests/ut/worker/test_worker_v1.py
index f4551de..31e986d 100644
--- a/tests/ut/worker/test_worker_v1.py
+++ b/tests/ut/worker/test_worker_v1.py
@@ -18,6 +18,7 @@ class TestNPUWorker(TestBase):
         self.model_config_mock = MagicMock(spec=ModelConfig)
         self.model_config_mock.dtype = torch.float16
         self.model_config_mock.trust_remote_code = False
+        self.model_config_mock.hf_config = None
 
         self.parallel_config_mock = MagicMock(spec=ParallelConfig)
 
diff --git a/vllm_ascend/__init__.py b/vllm_ascend/__init__.py
index f519417..74a8153 100644
--- a/vllm_ascend/__init__.py
+++ b/vllm_ascend/__init__.py
@@ -23,7 +23,6 @@ def register():
 
 
 def register_model():
-    import vllm_ascend.patch.worker.patch_common.patch_attention_selector  # noqa
     from .models import register_model
 
     register_model()
diff --git a/vllm_ascend/ascend_config.py b/vllm_ascend/ascend_config.py
index cbd905e..6a60695 100644
--- a/vllm_ascend/ascend_config.py
+++ b/vllm_ascend/ascend_config.py
@@ -34,8 +34,6 @@ class AscendConfig:
 
     def __init__(self, vllm_config):
         additional_config = vllm_config.additional_config if vllm_config.additional_config is not None else {}
-        self.is_deepseek_sfa = vllm_config.model_config is not None and vllm_config.model_config.is_deepseek_mla and vllm_config.model_config.hf_text_config.model_type == "deepseek_v32"
-        self.use_sfa = self.is_deepseek_sfa
 
         torchair_graph_config = additional_config.get("torchair_graph_config",
                                                       {})
diff --git a/vllm_ascend/attention/sfa_v1.py b/vllm_ascend/attention/sfa_v1.py
index edbd7cc..725a2be 100644
--- a/vllm_ascend/attention/sfa_v1.py
+++ b/vllm_ascend/attention/sfa_v1.py
@@ -510,7 +510,6 @@ class AscendSFAImpl(MLAAttentionImpl):
 
         ascend_config = get_ascend_config()
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
-        self.enable_prefetch = ascend_config.enable_prefetch
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
 
         vllm_config = get_current_vllm_config()
@@ -690,6 +689,8 @@ class AscendSFAImpl(MLAAttentionImpl):
             topk_indices = self.indexer_select(hidden_states_decode,
                                                decode_q_c,
                                                attn_metadata=attn_metadata,
+                                               cos=cos,
+                                               sin=sin,
                                                kv_cache=kv_cache)
 
             query_states = (decode_q_nope, decode_q_pe)
@@ -778,6 +779,8 @@ class AscendSFAImpl(MLAAttentionImpl):
             topk_indices = self.indexer_select(x=hidden_states_prefill,
                                                qr=prefill_qr,
                                                kv_cache=kv_cache,
+                                               cos=cos,
+                                               sin=sin,
                                                attn_metadata=attn_metadata)
             query_states = (prefill_q_nope, prefill_q_pe)
             key_states = (prefill_k_nope, prefill_k_pe)
@@ -920,17 +923,15 @@ class AscendSFAImpl(MLAAttentionImpl):
         x: torch.Tensor,
         qr: torch.Tensor,
         kv_cache: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+        cos,
+        sin,
         attn_metadata: M,
     ):
         if attn_metadata.prefill is not None:
-            cos = attn_metadata.prefill.cos
-            sin = attn_metadata.prefill.sin
             actual_seq_lengths_query = attn_metadata.prefill.query_lens
             actual_seq_lengths_key = attn_metadata.prefill.seq_lens
             block_table = attn_metadata.prefill.block_table
         elif attn_metadata.decode is not None:
-            cos = attn_metadata.decode.cos
-            sin = attn_metadata.decode.sin
             actual_seq_lengths_query = attn_metadata.decode.actual_seq_lengths_q
             actual_seq_lengths_key = attn_metadata.decode.seq_lens
             block_table = attn_metadata.decode.block_table
diff --git a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
index 64b3c11..1ec0311 100644
--- a/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
+++ b/vllm_ascend/distributed/llmdatadist_c_mgr_connector.py
@@ -501,7 +501,7 @@ class LLMDataDistCMgrConnectorWorker():
         self.use_mla: bool = first_kv_cache_tuple[0].size(
             -1) != first_kv_cache_tuple[1].size(-1) and len(
                 first_kv_cache_tuple) == 2
-        self.use_sfa: bool = len(first_kv_cache_tuple) == 3
+        self.use_sparse: bool = len(first_kv_cache_tuple) == 3
         # MLA case. [2 (k_normed, k_pe), num_blocks, ...]
         # SFA case. [3 (k_normed, k_pe, k_idx), num_blocks, ...]
         # MHA case. [2 (k and v), num_blocks, ...]
@@ -549,7 +549,7 @@ class LLMDataDistCMgrConnectorWorker():
             raise RuntimeError(
                 f"LLMDataDistCMgrConnectorWorker: Passing unexpected parameter to register_block_cache, receiving [cache_desc: {self.cache_desc}, cache_addr: {self.cache_addr}, cache_key: {self.cache_key}]"
             )
-        elif self.use_sfa:
+        elif self.use_sparse:
             cache_k_normed_addr_list = []
             cache_k_pe_addr_list = []
             cache_k_idx_addr_list = []
@@ -887,7 +887,7 @@ class LLMDataDistCMgrConnectorWorker():
                 raise RuntimeError(
                     "LLMDataDistCMgrConnectorWorker: Timeout during pull_blocks, you can try to increase the sync_kv_timeout config or checking your connect status"
                 )
-        elif self.use_sfa:
+        elif self.use_sparse:
             remote_cache_key_k_normed = BlocksCacheKey(
                 cluster_id=remote_cluster_id, model_id=0)
             remote_cache_key_k_pe = BlocksCacheKey(
diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py
index dcdfdf6..ebab077 100644
--- a/vllm_ascend/distributed/mooncake_connector.py
+++ b/vllm_ascend/distributed/mooncake_connector.py
@@ -242,7 +242,7 @@ class KVCacheRecvingThread(threading.Thread):
         self.block_len = block_len
         # TODO(jianzs): find a better way to detect MLA.
         self.use_mla = len(block_len) == 2
-        self.use_sfa = len(block_len) == 3
+        self.use_sparse = len(block_len) == 3
 
         self.request_queue: queue.Queue[Any] = queue.Queue()
         self.executor = ThreadPoolExecutor(max_workers=32)
@@ -373,7 +373,7 @@ class KVCacheRecvingThread(threading.Thread):
                 zip(local_kv_caches_base_addrs, remote_kv_caches_base_addrs)):
             if self.use_mla:
                 block_len = (self.block_len[k % 2])
-            elif self.use_sfa:
+            elif self.use_sparse:
                 block_len = (self.block_len[k % 3])
             else:
                 block_len = (self.block_len[0])
@@ -850,7 +850,8 @@ class MooncakeConnectorScheduler:
         assert "tp_size" in decode_parallel_config.keys()
         self._decode_tp_size = decode_parallel_config["tp_size"]
         num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads
-        if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
+        if self.vllm_config.model_config.use_mla or hasattr(
+                self.vllm_config.model_config.hf_config, "index_topk"):
             num_need_pulls = 1
         else:
             num_p_block_heads = max(
@@ -942,7 +943,7 @@ class MooncakeConnectorWorker:
         # kv_transfer variables
         self.vllm_config = vllm_config
         self.block_size = vllm_config.cache_config.block_size
-        if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+        if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
             self.num_need_pulls = 1
         else:
             num_d_block_heads = max(1,
@@ -995,7 +996,7 @@ class MooncakeConnectorWorker:
         self.use_mla = first_kv_cache_tuple[0].size(
             -1) != first_kv_cache_tuple[1].size(-1) and len(
                 first_kv_cache_tuple) == 2
-        self.use_sfa = len(first_kv_cache_tuple) == 3
+        self.use_sparse = len(first_kv_cache_tuple) == 3
         if self.use_mla:
             # MLA case.[num_block, block_size, 1, hidden_dim]
             self.num_blocks = first_kv_cache.shape[0]
@@ -1009,7 +1010,7 @@ class MooncakeConnectorWorker:
             logger.info(
                 "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
                 self.num_blocks, block_shape_norm, block_shape_pe)
-        elif self.use_sfa:
+        elif self.use_sparse:
             self.num_blocks = first_kv_cache.shape[0]
             block_rank = 3  # [block_size, latent_dim]
             block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
@@ -1037,8 +1038,8 @@ class MooncakeConnectorWorker:
             logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                         block_shape)
         logger.info(
-            "Registering KV_Caches. use_mla: %s, use_sfa: %s, shape %s",
-            self.use_mla, self.use_sfa, first_kv_cache.shape)
+            "Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
+            self.use_mla, self.use_sparse, first_kv_cache.shape)
 
         self.kv_caches = kv_caches
         kv_caches_base_addr = []
@@ -1050,7 +1051,7 @@ class MooncakeConnectorWorker:
                 region_len = self.num_blocks * self.block_len[i % 2]
                 kv_caches_base_addr.append(base_addr)
                 self._register(base_addr, region_len)
-        elif self.use_sfa:
+        elif self.use_sparse:
             for i, cache in enumerate(cache_or_caches, 0):
                 base_addr = cache.data_ptr()
                 region_len = self.num_blocks * self.block_len[i % 3]
@@ -1059,7 +1060,7 @@ class MooncakeConnectorWorker:
         else:
             cache_list = [
                 cache_or_caches
-            ] if self.use_mla or self.use_sfa else cache_or_caches
+            ] if self.use_mla or self.use_sparse else cache_or_caches
             for cache in cache_list:
                 base_addr = cache.data_ptr()
                 region_len = self.num_blocks * self.block_len[0]
@@ -1156,9 +1157,9 @@ class MooncakeConnectorWorker:
         sampled_nums = []
         ori_data = np.arange(self._prefill_tp_size)
         # random split prefill tp list
-        if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+        if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
             # use deepseek mla, num_key_value_heads == 128, but consider as 1
-            if self.vllm_config.model_config.is_deepseek_mla or self.ascend_config.use_sfa:
+            if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse:
                 num_kv_head = 1
             else:
                 num_kv_head = self.num_key_value_heads
@@ -1279,4 +1280,4 @@ def ensure_zmq_recv(
             logger.error(f"Receive failed after all retries: {e}")
             raise RuntimeError(
                 f"Failed to receive data after {max_retries} "
-                f"retries: {e}")
\ No newline at end of file
+                f"retries: {e}")
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index 4e0ead1..bc61522 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -31,6 +31,7 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.attention import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import (divide, get_pp_group,
                               get_tensor_model_parallel_rank,
@@ -47,7 +48,8 @@ from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mla import MultiHeadLatentAttention
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import (
     default_weight_loader, maybe_remap_kv_scale_name)
 from vllm.model_executor.models.deepseek_v2 import \
@@ -56,10 +58,11 @@ from vllm.model_executor.models.deepseek_v2 import (
     DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM,
     DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE,
     get_spec_layer_idx_from_weight_name)
-from vllm.model_executor.models.utils import (PPMissingLayer,
-                                              is_pp_missing_parameter,
-                                              maybe_prefix)
+from vllm.model_executor.models.utils import (
+    PPMissingLayer, is_pp_missing_parameter,
+    make_empty_intermediate_tensors_factory, make_layers, maybe_prefix)
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
 
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.models.layers.mla import AscendMLAModules
@@ -69,6 +72,53 @@ from vllm_ascend.ops.common_fused_moe import AscendFusedMoE
 from vllm_ascend.ops.linear import AscendLinearBase
 
 
+@support_torch_compile
+class AscendDeepseekV2Model(DeepseekV2Model, nn.Module):
+
+    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
+        # Rewrite this init func mainly for removing cuda-hard code
+        nn.Module.__init__(self)
+
+        config = vllm_config.model_config.hf_config
+        quant_config = vllm_config.quant_config
+        self.config = config
+
+        self.vocab_size = config.vocab_size
+        self.is_v32 = hasattr(config, "index_topk")
+        if self.is_v32:
+            topk_tokens = config.index_topk
+            topk_indices_buffer = torch.empty(
+                vllm_config.scheduler_config.max_num_batched_tokens,
+                topk_tokens,
+                dtype=torch.int32,
+                device=current_platform.device_type)
+        else:
+            topk_indices_buffer = None
+
+        if get_pp_group().is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.embed_tokens")
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.start_layer, self.end_layer, self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix,
+                                                  topk_indices_buffer),
+            prefix=f"{prefix}.layers")
+
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer()
+        self.make_empty_intermediate_tensors = (
+            make_empty_intermediate_tensors_factory(
+                ["hidden_states", "residual"], config.hidden_size))
+
+
 class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
 
     def __init__(
@@ -270,6 +320,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
             scaling_factor = rope_scaling["factor"]
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
+        self.indexer = None
 
         mla_modules = AscendMLAModules(
             q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None,
@@ -281,6 +332,8 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
             kv_b_proj=self.kv_b_proj,
             o_proj=self.o_proj,
             rotary_emb=self.rotary_emb,
+            indexer=None,
+            is_sparse=hasattr(config, "index_topk"),
         )
 
         self.mla_attn = MultiHeadLatentAttention(
@@ -499,7 +552,6 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
         cache_config = vllm_config.cache_config
         quant_config = vllm_config.quant_config
         parallel_config = vllm_config.parallel_config
-        ascend_config = get_ascend_config()
 
         self.hidden_size = config.hidden_size
         rope_theta = getattr(config, "rope_theta", 10000)
@@ -515,7 +567,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
         self.tp_rank = get_tp_group().rank_in_group
         # TODO: enable mla in vllm-ascend
         if model_config.use_mla:
-            if ascend_config.use_sfa:
+            if hasattr(model_config.hf_config, "index_topk"):
                 attn_cls = CustomDeepseekV2SFAAttention
             else:
                 attn_cls = CustomDeepseekV2MLAAttention
@@ -590,8 +642,9 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
             "kv_a_proj_with_mqa",
         ]
 
-        self.model = DeepseekV2Model(vllm_config=vllm_config,
-                                     prefix=maybe_prefix(prefix, "model"))
+        self.model = AscendDeepseekV2Model(vllm_config=vllm_config,
+                                           prefix=maybe_prefix(
+                                               prefix, "model"))
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(config.vocab_size,
                                           config.hidden_size,
diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py
index 57c91bd..e0c7e2d 100644
--- a/vllm_ascend/models/layers/mla.py
+++ b/vllm_ascend/models/layers/mla.py
@@ -42,6 +42,8 @@ class AscendMLAModules:
     kv_b_proj: torch.nn.Module
     o_proj: torch.nn.Module
     rotary_emb: torch.nn.Module
+    indexer: Optional[torch.nn.Module]
+    is_sparse: bool
 
 
 class AscendMultiHeadLatentAttention(MultiHeadLatentAttention):
diff --git a/vllm_ascend/models/layers/sfa.py b/vllm_ascend/models/layers/sfa.py
index f68281c..23b77c3 100644
--- a/vllm_ascend/models/layers/sfa.py
+++ b/vllm_ascend/models/layers/sfa.py
@@ -94,7 +94,7 @@ class AscendSparseFlashAttention(MultiHeadLatentAttention):
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_mla=True,
-            use_sfa=True,
+            use_sparse=True,
             # SFA Args
             q_lora_rank=self.q_lora_rank,
             kv_lora_rank=self.kv_lora_rank,
diff --git a/vllm_ascend/patch/platform/patch_common/__init__.py b/vllm_ascend/patch/platform/patch_common/__init__.py
index 30e887a..b1e7f4e 100644
--- a/vllm_ascend/patch/platform/patch_common/__init__.py
+++ b/vllm_ascend/patch/platform/patch_common/__init__.py
@@ -18,4 +18,3 @@
 import vllm_ascend.patch.platform.patch_common.patch_config  # noqa
 import vllm_ascend.patch.platform.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.platform.patch_common.patch_mamba_config  # noqa
-import vllm_ascend.patch.worker.patch_common.patch_attention_selector  # noqa
diff --git a/vllm_ascend/patch/worker/patch_common/__init__.py b/vllm_ascend/patch/worker/patch_common/__init__.py
index 99ec345..2e215b8 100644
--- a/vllm_ascend/patch/worker/patch_common/__init__.py
+++ b/vllm_ascend/patch/worker/patch_common/__init__.py
@@ -21,8 +21,6 @@ if HAS_TRITON:
     import vllm_ascend.patch.worker.patch_common.patch_triton
 
 # isort: off
-import vllm_ascend.patch.worker.patch_common.patch_attention_selector  # noqa
-import vllm_ascend.patch.worker.patch_common.patch_attention_layer  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_distributed  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_logits  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_roberta  # noqa
diff --git a/vllm_ascend/patch/worker/patch_common/patch_attention_layer.py b/vllm_ascend/patch/worker/patch_common/patch_attention_layer.py
deleted file mode 100644
index b778d8a..0000000
--- a/vllm_ascend/patch/worker/patch_common/patch_attention_layer.py
+++ /dev/null
@@ -1,188 +0,0 @@
-from typing import List, Optional
-
-import torch
-import vllm
-import vllm.envs as envs
-from torch import nn
-from vllm.attention import Attention, AttentionType, get_attn_backend
-from vllm.attention.backends.abstract import AttentionBackend
-from vllm.attention.selector import backend_name_to_enum
-from vllm.attention.utils.kv_sharing_utils import validate_kv_sharing_target
-from vllm.config import CacheConfig, get_current_vllm_config
-from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
-from vllm.model_executor.layers.linear import UnquantizedLinearMethod
-from vllm.model_executor.layers.quantization.base_config import \
-    QuantizationConfig
-from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
-from vllm.platforms import current_platform
-
-
-class AscendAttention(Attention, nn.Module, AttentionLayerBase):
-    """Attention layer.
-
-    This class takes query, key, and value tensors as input. The input tensors
-    can either contain prompt tokens or generation tokens.
-    The class does the following:
-
-    1. Store the input key and value tensors in the KV cache.
-    2. Perform (multi-head/multi-query/grouped-query) attention.
-    3. Return the output tensor.
-    """
-
-    def __init__(
-        self,
-        num_heads: int,
-        head_size: int,
-        scale: float,
-        num_kv_heads: Optional[int] = None,
-        alibi_slopes: Optional[List[float]] = None,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        logits_soft_cap: Optional[float] = None,
-        per_layer_sliding_window: Optional[int] = None,
-        use_mla: bool = False,
-        use_sfa: bool = False,
-        prefix: str = "",
-        attn_type: str = AttentionType.DECODER,
-        kv_sharing_target_layer_name: Optional[str] = None,
-        attn_backend: Optional[type[AttentionBackend]] = None,
-        **extra_impl_args,
-    ) -> None:
-        """
-        The KV cache is stored inside this class and is accessed via
-        `self.kv_cache`.
-        """
-        nn.Module.__init__(self)
-        AttentionLayerBase.__init__(self)
-
-        if per_layer_sliding_window is not None:
-            # per-layer sliding window
-            sliding_window = per_layer_sliding_window
-        elif cache_config is not None:
-            # model-level sliding window
-            sliding_window = cache_config.sliding_window
-        else:
-            sliding_window = None
-
-        if cache_config is not None:
-            kv_cache_dtype = cache_config.cache_dtype
-            block_size = cache_config.block_size
-            calculate_kv_scales = cache_config.calculate_kv_scales
-        else:
-            kv_cache_dtype = "auto"
-            block_size = 16
-            calculate_kv_scales = False
-        if num_kv_heads is None:
-            num_kv_heads = num_heads
-        assert num_heads % num_kv_heads == 0, \
-            f"num_heads ({num_heads}) is not " \
-            f"divisible by num_kv_heads ({num_kv_heads})"
-
-        # The default k/v_scale is set to 1.0. This is ignored
-        # when kv-cache is not fp8, and should be used with
-        # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
-        # expect the pre-quantized k/v_scale to be loaded along
-        # with the model weights.
-        self.kv_cache_dtype = kv_cache_dtype
-        self.calculate_kv_scales = calculate_kv_scales
-        self._k_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._v_scale = torch.tensor(1.0, dtype=torch.float32)
-        # FlashAttn doesn't support quantizing the kv-cache only
-        # but requires q to be quantized as well.
-        self._q_scale = torch.tensor(1.0, dtype=torch.float32)
-        self._prob_scale = torch.tensor(1.0, dtype=torch.float32)
-
-        # We also keep q/k/v_scale on host (cpu) memory for attention
-        # backends that require the scales to be on host instead of on device.
-        # e.g. Flashinfer
-        self._q_scale_float = 1.0
-        self._k_scale_float = 1.0
-        self._v_scale_float = 1.0
-
-        # The output scale on host memory. This should be the input scale of
-        # the quant op after this attention layer.
-        self._o_scale_float: Optional[float] = None
-
-        self.use_mla = use_mla
-        self.num_heads = num_heads
-        self.head_size = head_size
-        self.num_kv_heads = num_kv_heads
-        self.sliding_window = sliding_window
-        self.has_sink = extra_impl_args.get("sinks") is not None
-
-        quant_method = quant_config.get_quant_method(
-            self, prefix=prefix) if quant_config else None
-        if quant_method is not None and not isinstance(
-                quant_method, UnquantizedLinearMethod):
-            assert isinstance(quant_method, BaseKVCacheMethod)
-            # TODO (mgoin): kv cache dtype should be specified in the FP8
-            # checkpoint config and become the "auto" behavior
-            if self.kv_cache_dtype == "fp8_e5m2":
-                raise ValueError("fp8_e5m2 kv-cache is not supported with "
-                                 "fp8 checkpoints.")
-            # If quantization is enabled, we make "k_scale" and "v_scale"
-            # parameters so that it can be loaded from the model checkpoint.
-            # The k/v_scale will then be converted back to native float32
-            # values after weight loading.
-            self.quant_method = quant_method
-            self.quant_method.create_weights(self)
-
-        # During model initialization, the default dtype is set as the model
-        # weight and activation dtype.
-        dtype = torch.get_default_dtype()
-        if attn_backend is None:
-            self.attn_backend = get_attn_backend(head_size,
-                                                 dtype,
-                                                 kv_cache_dtype,
-                                                 block_size,
-                                                 use_mla=use_mla,
-                                                 use_sfa=use_sfa,
-                                                 has_sink=self.has_sink)
-        else:
-            self.attn_backend = attn_backend
-
-        impl_cls = self.attn_backend.get_impl_cls()
-        self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads,
-                             alibi_slopes, sliding_window, kv_cache_dtype,
-                             logits_soft_cap, attn_type,
-                             kv_sharing_target_layer_name, **extra_impl_args)
-        self.backend = backend_name_to_enum(self.attn_backend.get_name())
-        self.dtype = dtype
-
-        # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how
-        # torch.compile works by registering the attention as one giant
-        # opaque custom op. For other platforms, we directly call them
-        # and let torch.compile handle them.
-        self.use_direct_call = not current_platform.opaque_attention_op()
-
-        self.use_output = self.attn_backend.accept_output_buffer
-        compilation_config = get_current_vllm_config().compilation_config
-        if prefix in compilation_config.static_forward_context:
-            raise ValueError(f"Duplicate layer name: {prefix}")
-        compilation_config.static_forward_context[prefix] = self
-        self.layer_name = prefix
-        self.attn_type = attn_type
-
-        if kv_sharing_target_layer_name is not None:
-            validate_kv_sharing_target(
-                prefix,
-                kv_sharing_target_layer_name,
-                compilation_config.static_forward_context,
-            )
-        self.kv_sharing_target_layer_name = kv_sharing_target_layer_name
-
-        # use a placeholder kv cache tensor during init, which will be replaced
-        # by bind_kv_cache
-        # this variable will not be accessed if use_direct_call is True
-        self.kv_cache = [
-            torch.tensor([]) for _ in range(get_current_vllm_config(
-            ).parallel_config.pipeline_parallel_size)
-        ]
-
-        self.q_range = torch.tensor(envs.Q_SCALE_CONSTANT, dtype=torch.float32)
-        self.k_range = torch.tensor(envs.K_SCALE_CONSTANT, dtype=torch.float32)
-        self.v_range = torch.tensor(envs.V_SCALE_CONSTANT, dtype=torch.float32)
-        self.query_quant = None
-
-
-vllm.attention.Attention = AscendAttention
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index b00e8d6..b9eaf5f 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -164,6 +164,9 @@ class NPUPlatform(Platform):
             "kv_cache_dtype", None)
         if kv_cache_dtype is not None:
             vllm_config.cache_config.cache_dtype = kv_cache_dtype
+        elif model_config and hasattr(model_config.hf_config, "index_topk"):
+            vllm_config.cache_config.cache_dtype = str(
+                model_config.dtype).replace("torch.", "")
         if model_config is None:
             logger.warning("Model config is missing. This may indicate "
                            "that we are running a test case")
@@ -284,25 +287,27 @@ class NPUPlatform(Platform):
             vllm_config.scheduler_config = ascend_scheduler_config
 
     @classmethod
-    def get_attn_backend_cls(cls,
-                             selected_backend,
-                             head_size,
-                             dtype,
-                             kv_cache_dtype,
-                             block_size,
-                             use_v1,
-                             use_mla,
-                             use_sfa,
-                             has_sink=False):
+    def get_attn_backend_cls(
+        cls,
+        selected_backend,
+        head_size,
+        dtype,
+        kv_cache_dtype,
+        block_size,
+        use_v1,
+        use_mla,
+        has_sink=False,
+        use_sparse=False,
+    ):
         if not use_v1:
             raise ValueError("vLLM Ascend does not support V0 engine.")
 
         ascend_config = get_ascend_config()
         if use_mla and ascend_config.enable_shared_expert_dp:
-            if use_mla and not use_sfa:
+            if use_mla and not use_sparse:
                 return "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend"
-            if use_mla and use_sfa:
+            if use_mla and use_sparse:
                 return "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend"
 
         use_torchair = ascend_config.torchair_graph_config.enabled
@@ -321,7 +326,7 @@ class NPUPlatform(Platform):
             (True, True, True):
             "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
         }
-        return backend_map[(use_mla, use_sfa, use_torchair)]
+        return backend_map[(use_mla, use_sparse, use_torchair)]
 
     @classmethod
     def get_punica_wrapper(cls) -> str:
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 185b206..f484400 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -183,6 +183,11 @@ packed_modules_model_mapping = {
         "experts":
         ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
     },
+    "deepseek_v32": {
+        "gate_up_proj": ["gate_proj", "up_proj"],
+        "experts":
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
+    },
     # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
     # NOTE 2.The description file generated by the current msmodelslim tool does not have
     # MTP layer info. Please manually add it and set the value to FLOAT.
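
Reviewer note: every hunk above that drops `ascend_config.use_sfa` / `is_deepseek_sfa` swaps in the same probe
on the HuggingFace config. A minimal runnable sketch of that idiom (illustrative only, not part of the patch;
`SimpleNamespace` stands in for a transformers `PretrainedConfig`, and the `index_topk` value is invented):

    # Sketch: the detection idiom this patch standardizes on. DeepSeek V3.2
    # hf_configs carry an `index_topk` attribute; its presence is what now
    # marks a sparse-attention (SFA) model.
    from types import SimpleNamespace

    def is_sparse_model(hf_config) -> bool:
        # Mirrors `hasattr(model_config.hf_config, "index_topk")` used in the
        # hunks above. hasattr(None, ...) is simply False, which is why the
        # unit-test change at the top (hf_config = None) keeps the dense path
        # without extra guards.
        return hasattr(hf_config, "index_topk")

    assert not is_sparse_model(SimpleNamespace(model_type="deepseek_v3"))
    assert is_sparse_model(
        SimpleNamespace(model_type="deepseek_v32", index_topk=2048))
    assert not is_sparse_model(None)
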
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index f42d381..d5f1216 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -67,6 +67,8 @@ class MtpProposer(Proposer):
                                              1,
                                              device=self.runner.device,
                                              dtype=torch.int32)
+        self.use_sparse = hasattr(vllm_config.model_config.hf_config,
+                                  "index_topk")
 
     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
@@ -613,7 +615,7 @@ class MtpProposer(Proposer):
             npu_backend = torchair.get_npu_backend(compiler_config=config)
             self.torchair_compiled_model = torch.compile(
                 self.model,
-                dynamic=not get_ascend_config().use_sfa,
+                dynamic=not self.use_sparse,
                 fullgraph=True,
                 backend=npu_backend)
             return self.torchair_compiled_model
@@ -636,7 +638,7 @@ class MtpProposer(Proposer):
                 self.torchair_compiled_models[
                     batch_size] = torchair.inference.cache_compile(
                         self.model.__dict__[forward_proxy_name],
-                        dynamic=not get_ascend_config().use_sfa,
+                        dynamic=not self.use_sparse,
                         fullgraph=True,
                         cache_dir=TORCHAIR_CACHE_DIR,
                         config=config,
diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py
index 3b69e25..462e05b 100644
--- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py
+++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py
@@ -791,7 +791,7 @@ class TorchairDeepseekV2SFAAttention(DeepseekV2MLAAttention):
             quant_config=quant_config,
             prefix=f"{prefix}.attn",
             use_mla=True,
-            use_sfa=True,
+            use_sparse=True,
             # SFA Args
             q_lora_rank=self.q_lora_rank,
             kv_lora_rank=self.kv_lora_rank,
@@ -879,12 +879,12 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
         self.tp_rank = get_tp_group().rank_in_group
         ascend_config = get_ascend_config()
         self.use_mla = False
-        self.use_sfa = False
+        self.use_sparse = False
         # TODO: enable mla in vllm-ascend
         if model_config.use_mla:
-            if ascend_config.use_sfa:
+            if hasattr(model_config.hf_config, "index_topk"):
                 attn_cls = TorchairDeepseekV2SFAAttention
-                self.use_sfa = True
+                self.use_sparse = True
             else:
                 attn_cls = TorchairDeepseekV2MLAAttention  # type: ignore[assignment]
                 self.use_mla = True
@@ -950,7 +950,7 @@ class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
         forward_context = get_forward_context()
         if attn_metadata is not None:
             decoding_condition_met = (
-                not attn_metadata.is_prefill if self.use_sfa else
+                not attn_metadata.is_prefill if self.use_sparse else
                 not forward_context.with_prefill if self.use_mla else False)
             mla_moe_communication = decoding_condition_met and self.mla_moe_communication and replace_allreduce
         else:
diff --git a/vllm_ascend/torchair/torchair_model_runner.py b/vllm_ascend/torchair/torchair_model_runner.py
index daf6b5d..b29fdc2 100644
--- a/vllm_ascend/torchair/torchair_model_runner.py
+++ b/vllm_ascend/torchair/torchair_model_runner.py
@@ -376,7 +376,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
             npu_backend = torchair.get_npu_backend(compiler_config=config)
             self.torchair_compiled_model = torch.compile(
                 self.model,
-                dynamic=not self.ascend_config.use_sfa,
+                dynamic=not self.use_sparse,
                 fullgraph=True,
                 backend=npu_backend)
             return self.torchair_compiled_model
@@ -399,7 +399,7 @@ class NPUTorchairModelRunner(NPUModelRunner):
                 self.torchair_compiled_models[
                     batch_size] = torchair.inference.cache_compile(
                         self.model.__dict__[forward_proxy_name],
-                        dynamic=not self.ascend_config.use_sfa,
+                        dynamic=not self.use_sparse,
                         fullgraph=True,
                         cache_dir=TORCHAIR_CACHE_DIR,
                         config=config,
diff --git a/vllm_ascend/torchair/torchair_sfa.py b/vllm_ascend/torchair/torchair_sfa.py
index f4a00e5..abe30b7 100644
--- a/vllm_ascend/torchair/torchair_sfa.py
+++ b/vllm_ascend/torchair/torchair_sfa.py
@@ -738,7 +738,7 @@ class AscendSFATorchairImpl(MLAAttentionImpl):
 
         ascend_config = get_ascend_config()
         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
-        self.enable_prefetch = ascend_config.enable_prefetch
+        self.enable_prefetch = ascend_config.weight_prefetch_config.enabled
         self.enable_kv_nz = ascend_config.torchair_graph_config.enable_kv_nz
         if ascend_config.torchair_graph_config.enabled:
             self.graph_batch_size = ascend_config.torchair_graph_config.graph_batch_sizes[
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index e30fc66..95acbe4 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -309,13 +309,14 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                                          dtype=self.dtype,
                                          device=self.device)
         # Set up Attention
-        self.attn_backend = get_attn_backend(
-            0,
-            self.dtype,
-            None,
-            self.block_size,
-            use_mla=self.model_config.use_mla,
-            use_sfa=self.ascend_config.use_sfa)
+        self.use_sparse = hasattr(self.vllm_config.model_config.hf_config,
+                                  "index_topk")
+        self.attn_backend = get_attn_backend(0,
+                                             self.dtype,
+                                             None,
+                                             self.block_size,
+                                             use_mla=self.model_config.use_mla,
+                                             use_sparse=self.use_sparse)
         if torch.version.cann.startswith("8.3"):
             self.attn_mask_builder = AttentionMaskBuilder(
                 self.scheduler_config.max_num_batched_tokens, self.dtype,
@@ -871,7 +872,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if self.model_config.runner_type == "pooling" and self.model_config.pooler_config.pooling_type == "CLS":
             return self.attn_mask_builder.get_pooling_mask(self.device)
         # Chunk Prefill situation.
-        elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.ascend_config.use_sfa:
+        elif attn_state == AscendAttentionState.ChunkedPrefill and not self.vllm_config.model_config.use_mla and not self.use_sparse:
             if torch.version.cann.startswith("8.3"):
                 return self.attn_mask_builder.get_splitfuse_attn_mask()
             else:
@@ -1507,7 +1508,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 model=self.get_model(),
                 **extra_attn_metadata_args)
 
-            if self.vllm_config.model_config.use_mla or self.ascend_config.use_sfa:
+            if self.vllm_config.model_config.use_mla or self.use_sparse:
                 attn_metadata_i.num_input_tokens = num_input_tokens
             for layer_name in attn_group.layer_names:
                 attn_metadata[layer_name] = attn_metadata_i
@@ -2655,7 +2656,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         self.may_add_encoder_only_layers_to_kv_cache_config()
         self.initialize_attn_backend(kv_cache_config)
-        if self.ascend_config.is_deepseek_sfa:
+        if self.use_sparse:
             kv_caches = self.initialize_kv_cache_tensors_deepseek_sfa(
                 kv_cache_config)
         elif self.model_config.is_deepseek_mla:
@@ -2699,7 +2700,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                         kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
                 elif hasattr(
                         attn_backend, "get_supported_block_size"
-                ) and not self.model_config.is_deepseek_mla and not self.ascend_config.is_deepseek_sfa:
+                ) and not self.model_config.is_deepseek_mla and not self.use_sparse:
                     block_size = attn_backend.get_supported_block_size()[0]
                     block_size_chunk = kv_cache_spec.block_size // block_size
                     kv_cache_shape = attn_backend.get_kv_cache_shape(
@@ -3245,7 +3246,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
 
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
-        use_sfa = self.ascend_config.use_sfa
+        use_sparse = self.use_sparse
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         for layer_name, attn_module in attn_layers.items():
@@ -3267,7 +3268,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             # TODO(lucas): move the attention specs into the model layers like
             # the attention backends
             if attn_module.attn_type == AttentionType.DECODER:
-                if use_mla and not use_sfa:
+                if use_mla and not use_sparse:
                     kv_cache_spec[layer_name] = MLAAttentionSpec(
                         block_size=block_size,
                         num_kv_heads=attn_module.num_kv_heads,
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index f1acda3..e3ced0a 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -43,7 +43,7 @@ from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput,
 from vllm.v1.worker.worker_base import WorkerBase
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
+from vllm_ascend.ascend_config import init_ascend_config
 from vllm_ascend.device_allocator.camem import CaMemAllocator
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
@@ -88,7 +88,11 @@ class NPUWorker(WorkerBase):
         # init ascend config and soc version
         init_ascend_config(vllm_config)
         init_ascend_soc_version()
-        if get_ascend_config().use_sfa:
+        use_sparse = False
+        if vllm_config.model_config is not None:
+            use_sparse = hasattr(vllm_config.model_config.hf_config,
+                                 "index_topk")
+        if use_sparse:
             # Direct import instead of using try_register_lib to ensure proper error handling when
             # custom_ops is necessary but not available (e.g., in DeepSeek v3.2 deployments)
             # yapf: disable
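
Reviewer note: with `use_sfa` gone from `AscendConfig`, attention-backend selection in `platform.py` is keyed
entirely off the `(use_mla, use_sparse, use_torchair)` triple. A self-contained sketch of that lookup, under
stated assumptions: only the SFA torchair entry of `backend_map` is visible in this section's context lines,
the MLA torchair entry is inferred from the `enable_shared_expert_dp` branch, and `select_backend` is a
hypothetical stand-in for the real `get_attn_backend_cls` classmethod, which also handles eager and non-MLA
cases not shown here:

    # Sketch only: dispatch on the (use_mla, use_sparse, use_torchair) triple.
    backend_map = {
        (True, False, True):
        "vllm_ascend.torchair.torchair_mla.AscendMLATorchairBackend",
        (True, True, True):
        "vllm_ascend.torchair.torchair_sfa.AscendSFATorchairBackend",
    }

    def select_backend(use_mla: bool, use_sparse: bool,
                       use_torchair: bool) -> str:
        return backend_map[(use_mla, use_sparse, use_torchair)]

    assert select_backend(True, True, True).endswith("AscendSFATorchairBackend")

Keeping the matrix in one dict is what lets the patch delete the config-level flag: every call site now derives
sparsity from the model config (`index_topk`) instead of re-reading `AscendConfig`.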