diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 2f1ce8a6..2fabca60 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -39,7 +39,7 @@ def register_model(): ModelRegistry.register_model( "DeepseekV32ForCausalLM", - "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") + "vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM") ModelRegistry.register_model( "DeepSeekMTPModel", diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 573f48e7..ddbb55f0 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -62,12 +62,9 @@ from vllm.model_executor.models.utils import ( PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) from vllm.model_executor.utils import set_weight_attrs -from vllm.platforms import current_platform from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.models.layers.mla import AscendMLAModules -from vllm_ascend.models.layers.sfa import (AscendSFAModules, - AscendSparseFlashAttention, Indexer) from vllm_ascend.ops.common_fused_moe import AscendFusedMoE from vllm_ascend.ops.linear import AscendLinearBase @@ -84,16 +81,7 @@ class AscendDeepseekV2Model(DeepseekV2Model, nn.Module): self.config = config self.vocab_size = config.vocab_size - self.is_v32 = hasattr(config, "index_topk") - if self.is_v32: - topk_tokens = config.index_topk - topk_indices_buffer = torch.empty( - vllm_config.scheduler_config.max_num_batched_tokens, - topk_tokens, - dtype=torch.int32, - device=current_platform.device_type) - else: - topk_indices_buffer = None + topk_indices_buffer = None if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( @@ -332,7 +320,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): o_proj=self.o_proj, rotary_emb=self.rotary_emb, indexer=None, - is_sparse=hasattr(config, "index_topk"), + is_sparse=False, ) self.mla_attn = MultiHeadLatentAttention( @@ -365,180 +353,6 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata) -class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention): - - def __init__( - self, - config: PretrainedConfig, - hidden_size: int, - num_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: Optional[int], - kv_lora_rank: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - nn.Module.__init__(self) - self.hidden_size = hidden_size - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.v_head_dim = v_head_dim - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - - self.num_heads = num_heads - self.tp_size = get_tensor_model_parallel_world_size() - assert num_heads % self.tp_size == 0 - self.num_local_heads = num_heads // self.tp_size - self.layers = config.num_hidden_layers - self.first_k_dense_replace = config.first_k_dense_replace - - self.scaling = self.qk_head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - - ascend_config = 
get_ascend_config() - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear( - self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj", - return_bias=False, - ) - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear( - q_lora_rank, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj", - return_bias=False, - ) - else: - self.q_proj = ColumnParallelLinear( - self.hidden_size, - self.num_heads * self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj", - return_bias=False, - ) - - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa", - return_bias=False, - ) - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj", - return_bias=False, - ) - self.o_proj = CustomDeepseekV2RowParallelLinear( - self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj", - return_bias=False, - ) - - if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) - if rope_scaling: - mscale_all_dim = rope_scaling.get("mscale_all_dim", False) - scaling_factor = rope_scaling["factor"] - mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) - self.scaling = self.scaling * mscale * mscale - - self.dim: int = config.hidden_size # 7168 - # TODO(zzzzwwjj): wait transformers add these params - self.n_heads: int = 64 # 64 - self.head_dim: int = 128 # 128 - self.index_topk: int = 2048 # 2048 - self.indexer = Indexer( - config, - quant_config=quant_config, - dim=self.dim, - n_heads=self.n_heads, - head_dim=self.head_dim, - index_topk=self.index_topk, - prefix=f"{prefix}.indexer", - ) - - sfa_modules = AscendSFAModules( - q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, - q_a_layernorm=self.q_a_layernorm - if self.q_lora_rank is not None else None, - q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, - kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, - kv_a_layernorm=self.kv_a_layernorm, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - rotary_emb=self.rotary_emb, - indexer=self.indexer) - - self.sfa_attn = AscendSparseFlashAttention( - self.hidden_size, - self.enable_shared_expert_dp, - self.debug_layer_idx, - self.first_k_dense_replace, - self.tp_size, - sfa_modules, - self.num_local_heads, - self.scaling, - self.layers, - self.kv_lora_rank, - self.qk_rope_head_dim, - self.q_lora_rank, - self.qk_nope_head_dim, - self.qk_head_dim, - self.v_head_dim, - cache_config, - quant_config, - prefix, - ) - self.prefix = prefix - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata) - - class 
CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): def __init__(self, @@ -566,10 +380,7 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): self.tp_rank = get_tp_group().rank_in_group # TODO: enable mla in vllm-ascend if model_config.use_mla: - if hasattr(model_config.hf_config, "index_topk"): - attn_cls = CustomDeepseekV2SFAAttention - else: - attn_cls = CustomDeepseekV2MLAAttention + attn_cls = CustomDeepseekV2MLAAttention else: attn_cls = DeepseekV2Attention self.self_attn = attn_cls( diff --git a/vllm_ascend/models/deepseek_v3.py b/vllm_ascend/models/deepseek_v3.py deleted file mode 100644 index e69de29b..00000000 diff --git a/vllm_ascend/models/deepseek_v3_2.py b/vllm_ascend/models/deepseek_v3_2.py new file mode 100644 index 00000000..adeca893 --- /dev/null +++ b/vllm_ascend/models/deepseek_v3_2.py @@ -0,0 +1,633 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# # Adapted from +# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py +# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py +# """Inference-only DeepseekV2/DeepseekV3 model.""" + +from typing import Any, Dict, Iterable, Optional, Union + +import torch +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import AttentionMetadata +from vllm.compilation.decorators import support_torch_compile +from vllm.config import CacheConfig, VllmConfig +from vllm.distributed import (divide, get_pp_group, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, split_tensor_along_last_dim, + tensor_model_parallel_all_reduce) +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, + ColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_v2 import \ + yarn_get_mscale # noqa: E501 +from vllm.model_executor.models.deepseek_v2 import ( + DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, + DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE, + get_spec_layer_idx_from_weight_name) +from vllm.model_executor.models.utils import ( + PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.model_executor.utils import set_weight_attrs +from vllm.platforms import current_platform + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.models.layers.sfa import (AscendSFAModules, + AscendSparseFlashAttention, Indexer) +from vllm_ascend.ops.common_fused_moe import AscendFusedMoE +from vllm_ascend.ops.linear import AscendLinearBase + + +@support_torch_compile +class AscendDeepseekV2Model(DeepseekV2Model, nn.Module): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + # Rewrite this init func mainly for removing cuda-hard code + nn.Module.__init__(self) + + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + + self.vocab_size = config.vocab_size + assert hasattr(config, "index_topk") + topk_tokens = config.index_topk + topk_indices_buffer = torch.empty( + vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device=current_platform.device_type) + + if get_pp_group().is_first_rank: + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=f"{prefix}.embed_tokens") + else: + self.embed_tokens = PPMissingLayer() + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, + topk_indices_buffer), + prefix=f"{prefix}.layers") + + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, 
eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + + +class CustomDeepseekV2RowParallelLinear(RowParallelLinear): + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + input_is_parallel: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + reduce_results: bool = True, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + *, + return_bias: bool = True, + disable_tp: bool = False, + ): + # Divide the weight matrix along the first dimension. + self.tp_rank = (get_tensor_model_parallel_rank() + if not disable_tp else 0) + self.tp_size = (get_tensor_model_parallel_world_size() + if not disable_tp else 1) + self.input_size_per_partition = divide(input_size, self.tp_size) + self.output_size_per_partition = output_size + self.output_partition_sizes = [output_size] + + AscendLinearBase.__init__(self, + input_size, + output_size, + skip_bias_add, + params_dtype, + quant_config, + prefix, + return_bias=return_bias, + disable_tp=disable_tp) + + self.input_is_parallel = input_is_parallel + self.reduce_results = reduce_results + + assert self.quant_method is not None + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.input_size_per_partition, + output_partition_sizes=self.output_partition_sizes, + input_size=self.input_size, + output_size=self.output_size, + params_dtype=self.params_dtype, + weight_loader=( + self.weight_loader_v2 if self.quant_method.__class__.__name__ + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) + if not reduce_results and (bias and not skip_bias_add): + raise ValueError("When not reduce the results, adding bias to the " + "results can lead to incorrect results") + + if bias: + self.bias = nn.Parameter( + torch.empty(self.output_size, dtype=params_dtype)) + set_weight_attrs(self.bias, { + "output_dim": 0, + "weight_loader": self.weight_loader, + }) + else: + self.register_parameter("bias", None) + self.update_param_tp_status() + + def forward( + self, + input_, + is_prefill=True, + is_force_scatter=False + ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: + if self.input_is_parallel: + input_parallel = input_ + else: + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. 
+ assert self.quant_method is not None + # Only fuse bias add into GEMM for rank 0 (this ensures that + # bias will not get added more than once in TP>1 case) + bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias + output_parallel = self.quant_method.apply(self, + input_parallel, + bias=bias_) + if self.reduce_results and self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output_parallel) + else: + output = output_parallel + + output_bias = self.bias if self.skip_bias_add else None + + if not self.return_bias: + return output + return output, output_bias + + +class CustomDeepseekV2SFAAttention(DeepseekV2MLAAttention): + + def __init__( + self, + config: PretrainedConfig, + hidden_size: int, + num_heads: int, + qk_nope_head_dim: int, + qk_rope_head_dim: int, + v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + nn.Module.__init__(self) + self.hidden_size = hidden_size + self.qk_nope_head_dim = qk_nope_head_dim + self.qk_rope_head_dim = qk_rope_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim + self.v_head_dim = v_head_dim + + self.q_lora_rank = q_lora_rank + self.kv_lora_rank = kv_lora_rank + + self.num_heads = num_heads + self.tp_size = get_tensor_model_parallel_world_size() + assert num_heads % self.tp_size == 0 + self.num_local_heads = num_heads // self.tp_size + self.layers = config.num_hidden_layers + self.first_k_dense_replace = config.first_k_dense_replace + + self.scaling = self.qk_head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.prefix = prefix + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + + ascend_config = get_ascend_config() + self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + + if self.q_lora_rank is not None: + self.q_a_proj = ReplicatedLinear( + self.hidden_size, + self.q_lora_rank, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_a_proj", + return_bias=False, + ) + self.q_a_layernorm = RMSNorm(self.q_lora_rank, + eps=config.rms_norm_eps) + self.q_b_proj = ColumnParallelLinear( + q_lora_rank, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_b_proj", + return_bias=False, + ) + else: + self.q_proj = ColumnParallelLinear( + self.hidden_size, + self.num_heads * self.qk_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.q_proj", + return_bias=False, + ) + + self.kv_a_proj_with_mqa = ReplicatedLinear( + self.hidden_size, + self.kv_lora_rank + self.qk_rope_head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_a_proj_with_mqa", + return_bias=False, + ) + self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, + eps=config.rms_norm_eps) + self.kv_b_proj = ColumnParallelLinear( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.kv_b_proj", + return_bias=False, + ) + self.o_proj = CustomDeepseekV2RowParallelLinear( + self.num_heads * self.v_head_dim, + self.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + return_bias=False, + ) + + if rope_scaling: + rope_scaling["rope_type"] = 'deepseek_yarn' + self.rotary_emb = get_rope(qk_rope_head_dim, + 
rotary_dim=qk_rope_head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=False) + if rope_scaling: + mscale_all_dim = rope_scaling.get("mscale_all_dim", False) + scaling_factor = rope_scaling["factor"] + mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) + self.scaling = self.scaling * mscale * mscale + + self.dim: int = config.hidden_size # 7168 + # TODO(zzzzwwjj): wait transformers add these params + self.n_heads: int = 64 # 64 + self.head_dim: int = 128 # 128 + self.index_topk: int = 2048 # 2048 + self.indexer = Indexer( + config, + quant_config=quant_config, + dim=self.dim, + n_heads=self.n_heads, + head_dim=self.head_dim, + index_topk=self.index_topk, + prefix=f"{prefix}.indexer", + ) + + sfa_modules = AscendSFAModules( + q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, + q_a_layernorm=self.q_a_layernorm + if self.q_lora_rank is not None else None, + q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, + kv_a_layernorm=self.kv_a_layernorm, + kv_b_proj=self.kv_b_proj, + o_proj=self.o_proj, + rotary_emb=self.rotary_emb, + indexer=self.indexer) + + self.sfa_attn = AscendSparseFlashAttention( + self.hidden_size, + self.enable_shared_expert_dp, + self.debug_layer_idx, + self.first_k_dense_replace, + self.tp_size, + sfa_modules, + self.num_local_heads, + self.scaling, + self.layers, + self.kv_lora_rank, + self.qk_rope_head_dim, + self.q_lora_rank, + self.qk_nope_head_dim, + self.qk_head_dim, + self.v_head_dim, + cache_config, + quant_config, + prefix, + ) + self.prefix = prefix + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: Optional[torch.Tensor] = None, + attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: + return self.sfa_attn(positions, hidden_states, kv_cache, attn_metadata) + + +class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): + + def __init__(self, + vllm_config: VllmConfig, + prefix: str, + topk_indices_buffer=None) -> None: + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + model_config = vllm_config.model_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + parallel_config = vllm_config.parallel_config + + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # DecoderLayers are created with `make_layers` which passes the prefix + # with the layer's index. 
+ layer_idx = int(prefix.split(sep='.')[-1]) + self.layer_idx = layer_idx + self.layers = config.num_hidden_layers + self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tp_group().rank_in_group + # TODO: enable mla in vllm-ascend + if model_config.use_mla: + attn_cls = CustomDeepseekV2SFAAttention + else: + attn_cls = DeepseekV2Attention + self.self_attn = attn_cls( + config=config, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + qk_nope_head_dim=config.qk_nope_head_dim, + qk_rope_head_dim=config.qk_rope_head_dim, + v_head_dim=config.v_head_dim, + q_lora_rank=config.q_lora_rank + if hasattr(config, "q_lora_rank") else None, + kv_lora_rank=config.kv_lora_rank, + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) + + if (config.n_routed_experts is not None + and layer_idx >= config.first_k_dense_replace + and layer_idx % config.moe_layer_freq == 0): + self.mlp = DeepseekV2MoE( + config=config, + parallel_config=parallel_config, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + if self.mlp.gate.e_score_correction_bias is not None: + self.mlp.gate.e_score_correction_bias.data = ( + self.mlp.gate.e_score_correction_bias.data.to( + dtype=torch.get_default_dtype())) + else: + self.mlp = DeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.routed_scaling_factor = config.routed_scaling_factor + self.first_k_dense_replace = config.first_k_dense_replace + self.tp_group = get_tp_group().device_group + + +class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + self.config = config + self.quant_config = quant_config + + # `packed_modules_mapping` needs to be modified before + # initializing DeepseekV2Model, as it is passed inplace to + # quantization config init and may be used to select the + # quant_method for relevant layers during initialization. 
+ self.fuse_qkv_a_proj = hasattr( + config, "q_lora_rank") and config.q_lora_rank is not None + if self.fuse_qkv_a_proj: + self.packed_modules_mapping["fused_qkv_a_proj"] = [ + "q_a_proj", + "kv_a_proj_with_mqa", + ] + + self.model = AscendDeepseekV2Model(vllm_config=vllm_config, + prefix=maybe_prefix( + prefix, "model")) + if get_pp_group().is_last_rank: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, "lm_head")) + else: + self.lm_head = PPMissingLayer() + self.logits_processor = LogitsProcessor(config.vocab_size) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + self.expert_weights: list[Any] = [] + + # Set MoE hyperparameters + self.num_moe_layers = (config.num_hidden_layers - + config.first_k_dense_replace) + self.num_expert_groups = config.n_group + + self.moe_layers: list[FusedMoE] = [] + example_moe = None + for layer in self.model.layers: + if isinstance(layer, PPMissingLayer): + continue + + assert isinstance(layer, DeepseekV2DecoderLayer) + if isinstance(layer.mlp, DeepseekV2MoE): + # Pick last one layer since the first ones may be dense layers. + example_moe = layer.mlp + self.moe_layers.append(layer.mlp.experts) + + if example_moe is None: + raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") + + self.num_logical_experts = example_moe.n_logical_experts + self.num_physical_experts = example_moe.n_physical_experts + self.num_local_physical_experts = example_moe.n_local_physical_experts + self.num_routed_experts = example_moe.n_routed_experts + self.num_shared_experts = example_moe.n_shared_experts + self.num_redundant_experts = example_moe.n_redundant_experts + + # NOTE: This `load_weights` is mainly copied from + # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5 + # to fix CI, and it is different from the implementation in main + # TODO: support eplb style load_weights + def load_weights(self, weights: Iterable[tuple[str, + torch.Tensor]]) -> set[str]: + """""" + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + # Params for weights, fp8 weight scales, fp8 activation scales + # (param_name, weight_name, expert_id, shard_id) + expert_params_mapping = AscendFusedMoE.make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.n_routed_experts) + + params_dict = dict(self.named_parameters()) + loaded_params: set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "module" in name: + continue + + spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) + if spec_layer is not None: + continue # skip spec decode layers for main model + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + # Skip non-stacked layers and experts (experts handled below). + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if (("mlp.experts." 
in name) and name not in params_dict): + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + return_success=False) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): + pass + + +DeepseekV2DecoderLayer.__init__ = CustomDeepseekV2DecoderLayer.__init__
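
Note (not part of the patch): the new deepseek_v3_2.py path allocates a shared top-k indices buffer unconditionally, sized by the scheduler's max batched tokens and config.index_topk, and passes it to every decoder layer so the SFA Indexer can record the per-token selected key indices. Below is a minimal sketch of that sizing only, assuming an illustrative max_num_batched_tokens of 8192; index_topk of 2048 matches the hard-coded default in CustomDeepseekV2SFAAttention, and the real code places the buffer on the NPU via current_platform.device_type rather than CPU.

import torch

# Illustrative values: max_num_batched_tokens comes from
# vllm_config.scheduler_config in the real model; 8192 is an assumption here.
max_num_batched_tokens = 8192
index_topk = 2048  # config.index_topk / the hard-coded SFA default

# One int32 slot per (token, selected key) pair; allocated once in
# AscendDeepseekV2Model and handed to each DeepseekV2DecoderLayer.
topk_indices_buffer = torch.empty(max_num_batched_tokens,
                                  index_topk,
                                  dtype=torch.int32)  # CPU in this sketch

print(topk_indices_buffer.shape)        # torch.Size([8192, 2048])
print(topk_indices_buffer.numel() * 4)  # 67108864 bytes, roughly 64 MiB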