From f8b52fe9509e4b90e384d9780042765c99b50c63 Mon Sep 17 00:00:00 2001 From: whx <56632993+whx-sjtu@users.noreply.github.com> Date: Mon, 20 Oct 2025 15:31:34 +0800 Subject: [PATCH] [Model][1/N] Delete deepseek v2/v3 modeling codes. (#3189) This PR deletes model codes of deepseek_v2 and deepseek_v3 to reuse the model file from vLLM. vLLM Ascend now uses custom ops register way instead of model file hard-coding. - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: whx-sjtu <2952154980@qq.com> --- tests/ut/attention/test_mla_v1.py | 1 + tests/ut/models/conftest.py | 16 +- tests/ut/models/test_deepseek_mtp.py | 2 - tests/ut/models/test_deepseek_v2.py | 130 ---- tests/ut/torchair/test_torchair_mla.py | 1 + vllm_ascend/attention/mla_v1.py | 70 +- vllm_ascend/models/__init__.py | 8 - vllm_ascend/models/deepseek_v2.py | 599 ------------------ vllm_ascend/models/layers/mla.py | 142 +++-- vllm_ascend/ops/common_fused_moe.py | 4 + vllm_ascend/quantization/quant_config.py | 7 +- .../torchair/models/torchair_deepseek_v2.py | 3 +- vllm_ascend/torchair/torchair_mla.py | 3 +- 13 files changed, 143 insertions(+), 843 deletions(-) delete mode 100644 tests/ut/models/test_deepseek_v2.py delete mode 100644 vllm_ascend/models/deepseek_v2.py diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index d6dd091..5e492e9 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -302,6 +302,7 @@ class TestAscendMLAImpl(TestBase): "v_head_dim": 128, "rotary_emb": MagicMock(), "q_proj": MagicMock(), + "q_b_proj": MagicMock(), "kv_b_proj": MagicMock(), "o_proj": MagicMock(), "kv_a_proj_with_mqa": MagicMock(), diff --git a/tests/ut/models/conftest.py b/tests/ut/models/conftest.py index 3edac77..88b8cfa 100644 --- a/tests/ut/models/conftest.py +++ b/tests/ut/models/conftest.py @@ -90,13 +90,7 @@ def mock_distributed(): mock_vllm_config.scheduler_config = Mock(max_num_seqs=256) mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None) - with patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \ - patch("vllm_ascend.models.deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \ - patch("vllm_ascend.models.deepseek_v2.get_tp_group", return_value=tp_group), \ - patch("vllm_ascend.models.deepseek_v2.get_pp_group", return_value=pp_group), \ - patch("vllm_ascend.models.deepseek_v2.get_pp_group", - return_value=Mock(is_first_rank=False, is_last_rank=False)), \ - patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \ + with patch("vllm_ascend.ops.common_fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \ patch("vllm_ascend.ops.moe.token_dispatcher.torch.distributed.get_rank", return_value=0), \ patch("vllm_ascend.ops.moe.token_dispatcher.get_ascend_soc_version", return_value=None), \ patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group, @@ -104,11 +98,3 @@ def mock_distributed(): patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \ patch("torch.npu.current_device", return_value=0): yield - - -@pytest.fixture -def mock_forward_context(): - forward_context = Mock(in_profile_run=False, with_prefill=False) - with patch("vllm_ascend.models.deepseek_v2.get_forward_context", - return_value=forward_context): - yield diff --git a/tests/ut/models/test_deepseek_mtp.py b/tests/ut/models/test_deepseek_mtp.py index f55c3f9..6e2d868 100644 --- a/tests/ut/models/test_deepseek_mtp.py +++ b/tests/ut/models/test_deepseek_mtp.py @@ -37,8 +37,6 @@ class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase): mocker.patch( "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", return_value=None) - mocker.patch("vllm_ascend.models.deepseek_v2.get_ascend_config", - return_value=mocker.Mock()) mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "0", None) mocker_deepseek_v2_decode_layer.assert_called_once() diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py deleted file mode 100644 index cf97c62..0000000 --- a/tests/ut/models/test_deepseek_v2.py +++ /dev/null @@ -1,130 +0,0 @@ -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# -from unittest.mock import MagicMock, Mock, patch - -import pytest -import torch -from vllm.config import CacheConfig -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead - -from vllm_ascend import ascend_config -from vllm_ascend.models.deepseek_v2 import (CustomDeepseekV2MLAAttention, - CustomDeepseekV2RowParallelLinear) - - -@pytest.mark.parametrize("cls", [CustomDeepseekV2RowParallelLinear]) -def test_row_parallel_linear(cls, mock_distributed): - linear = cls(input_size=128, output_size=64, bias=False, quant_config=None) - linear.quant_method = Mock() - linear.quant_method.apply.return_value = torch.randn(2, 4, 64) - input_ = torch.randn(2, 4, 128) - with patch("vllm_ascend.models.deepseek_v2.split_tensor_along_last_dim", - return_value=[torch.randn(2, 4, 64)]): - linear.input_is_parallel = False - output = linear(input_, is_prefill=True) - assert output[0].shape == (2, 4, 64) - - linear.input_is_parallel = True - output = linear(input_, is_prefill=False) - assert output[0].shape == (2, 4, 64) - - -@patch("vllm_ascend.models.layers.mla.get_forward_context") -@patch("torch.ops.vllm.mla_forward") -@patch("torch_npu.npu_rms_norm") -def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_mla_forward, - mock_forward_context, - mock_distributed, base_config): - mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) - # Make a fake ascend config because of the AscendLinearBase - vllm_config = MagicMock() - vllm_config.additional_config = None - vllm_config.parallel_config.enable_expert_parallel = False - vllm_config.parallel_config.tensor_parallel_size = 1 - vllm_config.kv_transfer_config = None - ascend_config.init_ascend_config(vllm_config) - dummy_forward_context = MagicMock() - dummy_forward_context.sp_enabled = False - mock_forward_context.return_value = dummy_forward_context - - attn = CustomDeepseekV2MLAAttention(config=base_config, - hidden_size=128, - num_heads=8, - qk_nope_head_dim=16, - qk_rope_head_dim=16, - v_head_dim=32, - q_lora_rank=16, - kv_lora_rank=16, - cache_config=CacheConfig(), - quant_config=None, - prefix="layers.0.self_attn") - assert attn.debug_layer_idx == 0 - - x = torch.randn(2, 4, 128) - positions = torch.arange(4).repeat(2, 1) - with patch.object(attn.mla_attn, - "__call__", - return_value=torch.randn(2, 4, 128)): - attn(positions, x) - mock_mla_forward.assert_called_once() - - attn = CustomDeepseekV2MLAAttention(config=base_config, - hidden_size=128, - num_heads=8, - qk_nope_head_dim=16, - qk_rope_head_dim=16, - v_head_dim=32, - q_lora_rank=None, - kv_lora_rank=16, - prefix="layers.1.self_attn") - assert hasattr(attn, "q_proj") - ascend_config._ASCEND_CONFIG = None - - -def test_deepseek_v2_lmhead(mock_distributed, vllm_config): - # 创建一个简单的配置对象 - class SimpleConfig: - - def __init__(self): - self.vocab_size = 10000 - self.hidden_size = 128 - - config = SimpleConfig() - - # Make a fake ascend config because of the AscendLinearBase - vllm_config = MagicMock() - vllm_config.additional_config = None - vllm_config.parallel_config.enable_expert_parallel = False - vllm_config.parallel_config.tensor_parallel_size = 1 - vllm_config.kv_transfer_config = None - ascend_config.init_ascend_config(vllm_config) - - # 直接创建lmhead和logits_processor - lmhead = ParallelLMHead(config.vocab_size, config.hidden_size) - logits_processor = LogitsProcessor(config.vocab_size) - - # 创建模拟输出 - mock_output = torch.randn(2, 4, config.hidden_size) - mock_logits = torch.randn(2, 4, config.vocab_size) - - # 直接测试logits_processor - with patch.object(lmhead.quant_method, "apply", return_value=mock_logits): - with patch.object(logits_processor, - "_gather_logits", - return_value=mock_logits): - logits = logits_processor(lmhead, mock_output) - assert logits.shape == (2, 4, config.vocab_size) - ascend_config._ASCEND_CONFIG = None diff --git a/tests/ut/torchair/test_torchair_mla.py b/tests/ut/torchair/test_torchair_mla.py index ec8ddfd..3dd1d2f 100644 --- a/tests/ut/torchair/test_torchair_mla.py +++ b/tests/ut/torchair/test_torchair_mla.py @@ -525,6 +525,7 @@ class TestAscendMLATorchairImpl(TestBase): "v_head_dim": 128, "rotary_emb": MagicMock(), "q_proj": MagicMock(), + "q_b_proj": MagicMock(), "kv_b_proj": MagicMock(), "o_proj": MagicMock(), "kv_a_proj_with_mqa": MagicMock(), diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 36a8f02..2196858 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -536,12 +536,13 @@ class AscendMLAImpl(MLAAttentionImpl): self.qk_head_dim = kwargs['qk_head_dim'] self.v_head_dim = kwargs['v_head_dim'] self.rotary_emb = kwargs['rotary_emb'] - self.q_proj = kwargs['q_proj'] + self.fused_qkv_a_proj = kwargs.get('fused_qkv_a_proj', None) + self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[ + 'q_b_proj'] self.kv_b_proj = kwargs['kv_b_proj'] self.o_proj = kwargs['o_proj'] self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None) self.kv_a_layernorm = kwargs.get('kv_a_layernorm', None) - self.q_a_proj = kwargs.get('q_a_proj', None) self.q_a_layernorm = kwargs.get('q_a_layernorm', None) self.num_queries_per_kv = self.num_heads // self.num_kv_heads self.tp_size = get_tensor_model_parallel_world_size() @@ -648,36 +649,46 @@ class AscendMLAImpl(MLAAttentionImpl): self._process_weights_for_fused_mlapo(act_dtype) def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype): - kv_a_proj_wt = self.kv_a_proj_with_mqa.weight.data - kv_a_proj_wt = kv_a_proj_wt.t().contiguous() + kv_a_proj_wt = self.fused_qkv_a_proj.weight.data[ + ..., self.q_lora_rank:].contiguous() + q_a_proj_wt = self.fused_qkv_a_proj.weight.data[ + ..., :self.q_lora_rank].contiguous() + kv_a_proj_wt = kv_a_proj_wt.contiguous() kv_a_proj_wt = trans_rope_weight(kv_a_proj_wt, self.qk_rope_head_dim) - kv_a_proj_wt = kv_a_proj_wt.t().contiguous() - wd_qkv = torch.cat((kv_a_proj_wt, self.q_a_proj.weight.data), dim=-1) + kv_a_proj_wt = kv_a_proj_wt.contiguous() + wd_qkv = torch.cat((kv_a_proj_wt, q_a_proj_wt), dim=-1) wd_qkv = wd_qkv.t().contiguous() wd_qkv = transdata(wd_qkv, block_size=(16, 32)).unsqueeze(0).contiguous() self.wd_qkv = torch_npu.npu_format_cast(wd_qkv, 29) - kv_a_proj_deq_scl = self.kv_a_proj_with_mqa.deq_scale + kv_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[ + self.q_lora_rank:].contiguous() + q_a_proj_deq_scl = self.fused_qkv_a_proj.deq_scale[:self. + q_lora_rank].contiguous( + ) kv_a_proj_deq_scl = kv_a_proj_deq_scl.reshape( self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() kv_a_proj_deq_scl = trans_rope_weight(kv_a_proj_deq_scl, self.qk_rope_head_dim) kv_a_proj_deq_scl = kv_a_proj_deq_scl.view( self.kv_lora_rank + self.qk_rope_head_dim).contiguous() - self.deq_scale_qkv = torch.cat( - (kv_a_proj_deq_scl, self.q_a_proj.deq_scale), dim=-1).contiguous() + self.deq_scale_qkv = torch.cat((kv_a_proj_deq_scl, q_a_proj_deq_scl), + dim=-1).contiguous() - kv_a_proj_qt_bias = self.kv_a_proj_with_mqa.quant_bias + kv_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[ + self.q_lora_rank:].contiguous() + q_a_proj_qt_bias = self.fused_qkv_a_proj.quant_bias[:self. + q_lora_rank].contiguous( + ) kv_a_proj_qt_bias = kv_a_proj_qt_bias.reshape( self.kv_lora_rank + self.qk_rope_head_dim, -1).contiguous() kv_a_proj_qt_bias = trans_rope_weight(kv_a_proj_qt_bias, self.qk_rope_head_dim) kv_a_proj_qt_bias = kv_a_proj_qt_bias.view( self.kv_lora_rank + self.qk_rope_head_dim).contiguous() - self.quant_bias_qkv = torch.cat( - (kv_a_proj_qt_bias, self.q_a_proj.quant_bias), - dim=-1).contiguous() + self.quant_bias_qkv = torch.cat((kv_a_proj_qt_bias, q_a_proj_qt_bias), + dim=-1).contiguous() wu_q = self.q_proj.weight.data wu_q = wu_q.t().reshape(self.num_heads, @@ -704,22 +715,22 @@ class AscendMLAImpl(MLAAttentionImpl): self.qb_qt_bias = qb_qt_bias.reshape( self.num_heads * (self.qk_nope_head_dim + self.qk_rope_head_dim)) - device = self.q_a_proj.weight.device + device = self.q_proj.weight.device self.gamma0 = torch.ones( - [self.q_a_proj.weight.shape[-1]], + [self.fused_qkv_a_proj.weight.shape[-1]], dtype=act_dtype, device=device, ) self.beta0 = torch.zeros( - [self.q_a_proj.weight.shape[-1]], + [self.fused_qkv_a_proj.weight.shape[-1]], dtype=act_dtype, device=device, ) self.gamma1 = self.q_a_layernorm.weight.data self.beta1 = self.q_a_layernorm.bias.data self.gamma2 = self.kv_a_layernorm.weight.data - self.quant_scale0 = self.q_a_proj.input_scale.data - self.quant_offset0 = self.q_a_proj.input_offset.data + self.quant_scale0 = self.fused_qkv_a_proj.input_scale.data + self.quant_offset0 = self.fused_qkv_a_proj.input_offset.data self.quant_scale1 = self.q_proj.input_scale.data self.quant_offset1 = self.q_proj.input_offset.data self.ctkv_scale = torch.tensor([1], dtype=act_dtype, device=device) @@ -1122,21 +1133,26 @@ class AscendMLAImpl(MLAAttentionImpl): has_prefill = attn_metadata.num_prefills > 0 num_decode_tokens = attn_metadata.num_decode_tokens num_actual_tokens = attn_metadata.num_actual_tokens - if self.q_a_proj is not None: - maybe_npu_prefetch(inputs=self.q_a_proj.weight, + if self.fused_qkv_a_proj is not None: + maybe_npu_prefetch(inputs=self.fused_qkv_a_proj.weight, dependency=hidden_states, enabled=self.enable_prefetch) - ckq = self.q_a_proj(hidden_states)[0] - q_c = self.q_a_layernorm(ckq) + qkv_lora = self.fused_qkv_a_proj(hidden_states)[0] + q_c, kv_no_split = qkv_lora.split( + [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], + dim=-1, + ) + q_c = self.q_a_layernorm(q_c) else: q_c = hidden_states + kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] - kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0] # Process for Flash Comm V1 q_c = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( q_c, need_gather_q_kv) kv_no_split = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( kv_no_split, need_gather_q_kv) + decode_preprocess_res = None prefill_preprocess_res = None if has_prefill: @@ -1264,14 +1280,18 @@ class AscendMLAImpl(MLAAttentionImpl): max_size=MAX_O_PROJ_PREFETCH_SIZE, enabled=self.enable_prefetch) - output[...] = self.o_proj(o_proj_input)[0] + output[...] = self.o_proj(o_proj_input, + is_prefill=prefill_preprocess_res + is not None)[0] else: with torch.npu.stream(current_ms_metadata.comm_stream): maybe_npu_prefetch(inputs=self.o_proj.weight, dependency=o_proj_input, max_size=MAX_O_PROJ_PREFETCH_SIZE, enabled=self.enable_prefetch) - output[...] = self.o_proj(o_proj_input)[0] + output[...] = self.o_proj(o_proj_input, + is_prefill=prefill_preprocess_res + is not None)[0] current_ms_metadata.after_comm_event.record() del o_proj_input diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py index 2fabca6..15cddd0 100644 --- a/vllm_ascend/models/__init__.py +++ b/vllm_ascend/models/__init__.py @@ -29,14 +29,6 @@ def register_model(): "vllm_ascend.models.qwen2_5_vl_without_padding:AscendQwen2_5_VLForConditionalGeneration_Without_Padding" ) - ModelRegistry.register_model( - "DeepseekV2ForCausalLM", - "vllm_ascend.models.deepseek_v2:CustomDeepseekV2ForCausalLM") - - ModelRegistry.register_model( - "DeepseekV3ForCausalLM", - "vllm_ascend.models.deepseek_v2:CustomDeepseekV3ForCausalLM") - ModelRegistry.register_model( "DeepseekV32ForCausalLM", "vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM") diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py deleted file mode 100644 index ddbb55f..0000000 --- a/vllm_ascend/models/deepseek_v2.py +++ /dev/null @@ -1,599 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. -# -# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX -# and OPT implementations in this library. It has been modified from its -# original forms to accommodate minor architectural differences compared -# to GPT-NeoX and OPT used by the Meta AI team that trained the model. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# # Adapted from -# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py -# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py -# """Inference-only DeepseekV2/DeepseekV3 model.""" - -from typing import Any, Dict, Iterable, Optional, Union - -import torch -from torch import nn -from transformers import PretrainedConfig -from vllm.attention import AttentionMetadata -from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, VllmConfig -from vllm.distributed import (divide, get_pp_group, - get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size, - get_tp_group, split_tensor_along_last_dim, - tensor_model_parallel_all_reduce) -from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (WEIGHT_LOADER_V2_SUPPORTED, - ColumnParallelLinear, - ReplicatedLinear, - RowParallelLinear) -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.mla import MultiHeadLatentAttention -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import ( - default_weight_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.deepseek_v2 import \ - yarn_get_mscale # noqa: E501 -from vllm.model_executor.models.deepseek_v2 import ( - DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2ForCausalLM, - DeepseekV2MLAAttention, DeepseekV2MLP, DeepseekV2Model, DeepseekV2MoE, - get_spec_layer_idx_from_weight_name) -from vllm.model_executor.models.utils import ( - PPMissingLayer, is_pp_missing_parameter, - make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) -from vllm.model_executor.utils import set_weight_attrs - -from vllm_ascend.ascend_config import get_ascend_config -from vllm_ascend.models.layers.mla import AscendMLAModules -from vllm_ascend.ops.common_fused_moe import AscendFusedMoE -from vllm_ascend.ops.linear import AscendLinearBase - - -@support_torch_compile -class AscendDeepseekV2Model(DeepseekV2Model, nn.Module): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - # Rewrite this init func mainly for removing cuda-hard code - nn.Module.__init__(self) - - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - - self.vocab_size = config.vocab_size - topk_indices_buffer = None - - if get_pp_group().is_first_rank: - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=f"{prefix}.embed_tokens") - else: - self.embed_tokens = PPMissingLayer() - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, - topk_indices_buffer), - prefix=f"{prefix}.layers") - - if get_pp_group().is_last_rank: - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - else: - self.norm = PPMissingLayer() - self.make_empty_intermediate_tensors = ( - make_empty_intermediate_tensors_factory( - ["hidden_states", "residual"], config.hidden_size)) - - -class CustomDeepseekV2RowParallelLinear(RowParallelLinear): - - def __init__( - self, - input_size: int, - output_size: int, - bias: bool = True, - input_is_parallel: bool = True, - skip_bias_add: bool = False, - params_dtype: Optional[torch.dtype] = None, - reduce_results: bool = True, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - *, - return_bias: bool = True, - disable_tp: bool = False, - ): - # Divide the weight matrix along the first dimension. - self.tp_rank = (get_tensor_model_parallel_rank() - if not disable_tp else 0) - self.tp_size = (get_tensor_model_parallel_world_size() - if not disable_tp else 1) - self.input_size_per_partition = divide(input_size, self.tp_size) - self.output_size_per_partition = output_size - self.output_partition_sizes = [output_size] - - AscendLinearBase.__init__(self, - input_size, - output_size, - skip_bias_add, - params_dtype, - quant_config, - prefix, - return_bias=return_bias, - disable_tp=disable_tp) - - self.input_is_parallel = input_is_parallel - self.reduce_results = reduce_results - - assert self.quant_method is not None - self.quant_method.create_weights( - layer=self, - input_size_per_partition=self.input_size_per_partition, - output_partition_sizes=self.output_partition_sizes, - input_size=self.input_size, - output_size=self.output_size, - params_dtype=self.params_dtype, - weight_loader=( - self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) - if not reduce_results and (bias and not skip_bias_add): - raise ValueError("When not reduce the results, adding bias to the " - "results can lead to incorrect results") - - if bias: - self.bias = nn.Parameter( - torch.empty(self.output_size, dtype=params_dtype)) - set_weight_attrs(self.bias, { - "output_dim": 0, - "weight_loader": self.weight_loader, - }) - else: - self.register_parameter("bias", None) - self.update_param_tp_status() - - def forward( - self, - input_, - is_prefill=True, - is_force_scatter=False - ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]: - if self.input_is_parallel: - input_parallel = input_ - else: - tp_rank = get_tensor_model_parallel_rank() - splitted_input = split_tensor_along_last_dim( - input_, num_partitions=self.tp_size) - input_parallel = splitted_input[tp_rank].contiguous() - - # Matrix multiply. - assert self.quant_method is not None - # Only fuse bias add into GEMM for rank 0 (this ensures that - # bias will not get added more than once in TP>1 case) - bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias - output_parallel = self.quant_method.apply(self, - input_parallel, - bias=bias_) - if self.reduce_results and self.tp_size > 1: - output = tensor_model_parallel_all_reduce(output_parallel) - else: - output = output_parallel - - output_bias = self.bias if self.skip_bias_add else None - - if not self.return_bias: - return output - return output, output_bias - - -class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): - - def __init__( - self, - config: PretrainedConfig, - hidden_size: int, - num_heads: int, - qk_nope_head_dim: int, - qk_rope_head_dim: int, - v_head_dim: int, - q_lora_rank: Optional[int], - kv_lora_rank: int, - rope_theta: float = 10000, - rope_scaling: Optional[Dict[str, Any]] = None, - max_position_embeddings: int = 8192, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - ) -> None: - nn.Module.__init__(self) - self.hidden_size = hidden_size - self.qk_nope_head_dim = qk_nope_head_dim - self.qk_rope_head_dim = qk_rope_head_dim - self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim - self.v_head_dim = v_head_dim - - self.q_lora_rank = q_lora_rank - self.kv_lora_rank = kv_lora_rank - - self.num_heads = num_heads - self.tp_size = get_tensor_model_parallel_world_size() - assert num_heads % self.tp_size == 0 - self.num_local_heads = num_heads // self.tp_size - self.layers = config.num_hidden_layers - self.first_k_dense_replace = config.first_k_dense_replace - - self.scaling = self.qk_head_dim**-0.5 - self.rope_theta = rope_theta - self.max_position_embeddings = max_position_embeddings - - self.prefix = prefix - self.debug_layer_idx = int(self.prefix.split(".")[-2]) - - ascend_config = get_ascend_config() - self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp - - if self.q_lora_rank is not None: - self.q_a_proj = ReplicatedLinear(self.hidden_size, - self.q_lora_rank, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_a_proj") - self.q_a_layernorm = RMSNorm(self.q_lora_rank, - eps=config.rms_norm_eps) - self.q_b_proj = ColumnParallelLinear(q_lora_rank, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_b_proj") - else: - self.q_proj = ColumnParallelLinear(self.hidden_size, - self.num_heads * - self.qk_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.q_proj") - - self.kv_a_proj_with_mqa = ReplicatedLinear( - self.hidden_size, - self.kv_lora_rank + self.qk_rope_head_dim, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_a_proj_with_mqa") - self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, - eps=config.rms_norm_eps) - self.kv_b_proj = ColumnParallelLinear( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim), - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.kv_b_proj") - self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim, - self.hidden_size, - bias=False, - quant_config=quant_config, - prefix=f"{prefix}.o_proj") - - if rope_scaling: - rope_scaling["rope_type"] = 'deepseek_yarn' - self.rotary_emb = get_rope(qk_rope_head_dim, - rotary_dim=qk_rope_head_dim, - max_position=max_position_embeddings, - base=rope_theta, - rope_scaling=rope_scaling, - is_neox_style=False) - if rope_scaling: - mscale_all_dim = rope_scaling.get("mscale_all_dim", False) - scaling_factor = rope_scaling["factor"] - mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) - self.scaling = self.scaling * mscale * mscale - self.indexer = None - - mla_modules = AscendMLAModules( - q_a_proj=self.q_a_proj if self.q_lora_rank is not None else None, - q_a_layernorm=self.q_a_layernorm - if self.q_lora_rank is not None else None, - q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, - kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, - kv_a_layernorm=self.kv_a_layernorm, - kv_b_proj=self.kv_b_proj, - o_proj=self.o_proj, - rotary_emb=self.rotary_emb, - indexer=None, - is_sparse=False, - ) - - self.mla_attn = MultiHeadLatentAttention( - self.hidden_size, - self.enable_shared_expert_dp, - self.debug_layer_idx, - self.first_k_dense_replace, - self.tp_size, - mla_modules, - self.num_local_heads, - self.scaling, - self.layers, - self.kv_lora_rank, - self.qk_rope_head_dim, - self.q_lora_rank, - self.qk_nope_head_dim, - self.qk_head_dim, - self.v_head_dim, - cache_config, - quant_config, - prefix, - ) - - def forward( - self, - positions: torch.Tensor, - hidden_states: torch.Tensor, - kv_cache: Optional[torch.Tensor] = None, - attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor: - return self.mla_attn(positions, hidden_states, kv_cache, attn_metadata) - - -class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer): - - def __init__(self, - vllm_config: VllmConfig, - prefix: str, - topk_indices_buffer=None) -> None: - nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - model_config = vllm_config.model_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - parallel_config = vllm_config.parallel_config - - self.hidden_size = config.hidden_size - rope_theta = getattr(config, "rope_theta", 10000) - rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) - # DecoderLayers are created with `make_layers` which passes the prefix - # with the layer's index. - layer_idx = int(prefix.split(sep='.')[-1]) - self.layer_idx = layer_idx - self.layers = config.num_hidden_layers - self.tp_size = get_tensor_model_parallel_world_size() - self.tp_rank = get_tp_group().rank_in_group - # TODO: enable mla in vllm-ascend - if model_config.use_mla: - attn_cls = CustomDeepseekV2MLAAttention - else: - attn_cls = DeepseekV2Attention - self.self_attn = attn_cls( - config=config, - hidden_size=self.hidden_size, - num_heads=config.num_attention_heads, - qk_nope_head_dim=config.qk_nope_head_dim, - qk_rope_head_dim=config.qk_rope_head_dim, - v_head_dim=config.v_head_dim, - q_lora_rank=config.q_lora_rank - if hasattr(config, "q_lora_rank") else None, - kv_lora_rank=config.kv_lora_rank, - rope_theta=rope_theta, - rope_scaling=rope_scaling, - max_position_embeddings=max_position_embeddings, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.self_attn", - ) - - if (config.n_routed_experts is not None - and layer_idx >= config.first_k_dense_replace - and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekV2MoE( - config=config, - parallel_config=parallel_config, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - if self.mlp.gate.e_score_correction_bias is not None: - self.mlp.gate.e_score_correction_bias.data = ( - self.mlp.gate.e_score_correction_bias.data.to( - dtype=torch.get_default_dtype())) - else: - self.mlp = DeepseekV2MLP( - hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - hidden_act=config.hidden_act, - quant_config=quant_config, - prefix=f"{prefix}.mlp", - ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.routed_scaling_factor = config.routed_scaling_factor - self.first_k_dense_replace = config.first_k_dense_replace - self.tp_group = get_tp_group().device_group - - -class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - quant_config = vllm_config.quant_config - self.config = config - self.quant_config = quant_config - - # `packed_modules_mapping` needs to be modified before - # initializing DeepseekV2Model, as it is passed inplace to - # quantization config init and may be used to select the - # quant_method for relevant layers during initialization. - self.fuse_qkv_a_proj = hasattr( - config, "q_lora_rank") and config.q_lora_rank is not None - if self.fuse_qkv_a_proj: - self.packed_modules_mapping["fused_qkv_a_proj"] = [ - "q_a_proj", - "kv_a_proj_with_mqa", - ] - - self.model = AscendDeepseekV2Model(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) - if get_pp_group().is_last_rank: - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "lm_head")) - else: - self.lm_head = PPMissingLayer() - self.logits_processor = LogitsProcessor(config.vocab_size) - self.make_empty_intermediate_tensors = ( - self.model.make_empty_intermediate_tensors) - self.expert_weights: list[Any] = [] - - # Set MoE hyperparameters - self.num_moe_layers = (config.num_hidden_layers - - config.first_k_dense_replace) - self.num_expert_groups = config.n_group - - self.moe_layers: list[FusedMoE] = [] - example_moe = None - for layer in self.model.layers: - if isinstance(layer, PPMissingLayer): - continue - - assert isinstance(layer, DeepseekV2DecoderLayer) - if isinstance(layer.mlp, DeepseekV2MoE): - # Pick last one layer since the first ones may be dense layers. - example_moe = layer.mlp - self.moe_layers.append(layer.mlp.experts) - - if example_moe is None: - raise RuntimeError("No DeepseekV2MoE layer found in model.layers.") - - self.num_logical_experts = example_moe.n_logical_experts - self.num_physical_experts = example_moe.n_physical_experts - self.num_local_physical_experts = example_moe.n_local_physical_experts - self.num_routed_experts = example_moe.n_routed_experts - self.num_shared_experts = example_moe.n_shared_experts - self.num_redundant_experts = example_moe.n_redundant_experts - - # NOTE: This `load_weights` is mainly copied from - # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5 - # to fix CI, and it is different from the implementation in main - # TODO: support eplb style load_weights - def load_weights(self, weights: Iterable[tuple[str, - torch.Tensor]]) -> set[str]: - """""" - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - ("gate_up_proj", "gate_proj", 0), - ("gate_up_proj", "up_proj", 1), - ] - - # Params for weights, fp8 weight scales, fp8 activation scales - # (param_name, weight_name, expert_id, shard_id) - expert_params_mapping = AscendFusedMoE.make_expert_params_mapping( - ckpt_gate_proj_name="gate_proj", - ckpt_down_proj_name="down_proj", - ckpt_up_proj_name="up_proj", - num_experts=self.config.n_routed_experts) - - params_dict = dict(self.named_parameters()) - loaded_params: set[str] = set() - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if "module" in name: - continue - - spec_layer = get_spec_layer_idx_from_weight_name(self.config, name) - if spec_layer is not None: - continue # skip spec decode layers for main model - - for (param_name, weight_name, shard_id) in stacked_params_mapping: - # Skip non-stacked layers and experts (experts handled below). - if weight_name not in name: - continue - # We have mlp.experts[0].gate_proj in the checkpoint. - # Since we handle the experts below in expert_params_mapping, - # we need to skip here BEFORE we update the name, otherwise - # name will be updated to mlp.experts[0].gate_up_proj, which - # will then be updated below in expert_params_mapping - # for mlp.experts[0].gate_gate_up_proj, which breaks load. - if (("mlp.experts." in name) and name not in params_dict): - continue - name = name.replace(weight_name, param_name) - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - for mapping in expert_params_mapping: - param_name, weight_name, expert_id, shard_id = mapping - if weight_name not in name: - continue - name = name.replace(weight_name, param_name) - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = param.weight_loader - weight_loader(param, - loaded_weight, - name, - shard_id=shard_id, - expert_id=expert_id, - return_success=False) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - - # Remapping the name of FP8 kv-scale. - name = maybe_remap_kv_scale_name(name, params_dict) - if name is None: - continue - - if is_pp_missing_parameter(name, self): - continue - - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params - - -class CustomDeepseekV3ForCausalLM(CustomDeepseekV2ForCausalLM): - pass - - -DeepseekV2DecoderLayer.__init__ = CustomDeepseekV2DecoderLayer.__init__ diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py index 5b7bc46..d701631 100644 --- a/vllm_ascend/models/layers/mla.py +++ b/vllm_ascend/models/layers/mla.py @@ -19,97 +19,119 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import dataclass from typing import Optional import torch from torch import nn -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, get_current_vllm_config +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import ForwardContext, get_forward_context -from vllm.model_executor.layers.mla import MultiHeadLatentAttention +from vllm.model_executor.layers.mla import MLAModules from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.utils import direct_register_custom_op +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.utils import vllm_version_is -@dataclass -class AscendMLAModules: - q_a_proj: Optional[torch.nn.Module] - q_a_layernorm: Optional[torch.nn.Module] - q_proj: Optional[torch.nn.Module] - kv_a_proj_with_mqa: torch.nn.Module - kv_a_layernorm: torch.nn.Module - kv_b_proj: torch.nn.Module - o_proj: torch.nn.Module - rotary_emb: torch.nn.Module - indexer: Optional[torch.nn.Module] - is_sparse: bool +if vllm_version_is("0.11.0"): + from vllm.attention import Attention + from vllm.model_executor.layers.mla import \ + MultiHeadLatentAttention as MultiHeadLatentAttentionWrapper +else: + from vllm.attention.layer import MLAAttention + from vllm.model_executor.layers.mla import MultiHeadLatentAttentionWrapper -class AscendMultiHeadLatentAttention(MultiHeadLatentAttention): +# TODO(whx): adapt v0.11.0 and DSA +class AscendMultiHeadLatentAttention(MultiHeadLatentAttentionWrapper): def __init__( self, hidden_size: int, - enable_shared_expert_dp: bool, - debug_layer_idx: int, - first_k_dense_replace: int, - tp_size: int, - mla_modules: AscendMLAModules, - num_local_heads: int, - scaling: float, - layers: int, - kv_lora_rank: int, - qk_rope_head_dim: int, - q_lora_rank: Optional[int], + num_heads: int, + scale: float, qk_nope_head_dim: int, - qk_head_dim: int, + qk_rope_head_dim: int, v_head_dim: int, + q_lora_rank: Optional[int], + kv_lora_rank: int, + mla_modules: MLAModules, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", ) -> None: nn.Module.__init__(self) self.hidden_size = hidden_size - self.enable_shared_expert_dp = enable_shared_expert_dp - self.debug_layer_idx = debug_layer_idx - self.first_k_dense_replace = first_k_dense_replace - self.tp_size = tp_size - self.num_local_heads = num_local_heads - self.layers = layers self.kv_lora_rank = kv_lora_rank self.qk_rope_head_dim = qk_rope_head_dim self.q_lora_rank = q_lora_rank self.qk_nope_head_dim = qk_nope_head_dim - self.qk_head_dim = qk_head_dim + self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim self.v_head_dim = v_head_dim self.prefix = prefix + hf_config = get_current_vllm_config().model_config.hf_config + self.enable_shared_expert_dp = get_ascend_config( + ).enable_shared_expert_dp + self.debug_layer_idx = int(self.prefix.split(".")[-2]) + self.first_k_dense_replace = hf_config.first_k_dense_replace + self.tp_size = get_tensor_model_parallel_world_size() + self.layers = hf_config.num_hidden_layers - self.mla_attn = Attention( - num_heads=self.num_local_heads, - head_size=self.kv_lora_rank + self.qk_rope_head_dim, - scale=scaling, - num_kv_heads=1, - cache_config=cache_config, - quant_config=quant_config, - prefix=f"{prefix}.attn", - use_mla=True, - # MLA Args - q_lora_rank=self.q_lora_rank, - kv_lora_rank=self.kv_lora_rank, - qk_nope_head_dim=self.qk_nope_head_dim, - qk_rope_head_dim=self.qk_rope_head_dim, - qk_head_dim=self.qk_head_dim, - v_head_dim=self.v_head_dim, - rotary_emb=mla_modules.rotary_emb, - q_a_proj=mla_modules.q_a_proj, - q_a_layernorm=mla_modules.q_a_layernorm, - q_proj=mla_modules.q_proj, - kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, - kv_a_layernorm=mla_modules.kv_a_layernorm, - kv_b_proj=mla_modules.kv_b_proj, - o_proj=mla_modules.o_proj, - ) + if vllm_version_is("0.11.0"): + self.mla_attn = Attention( + num_heads=num_heads, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + scale=scale, + num_kv_heads=1, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + use_mla=True, + # MLA Args + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + qk_head_dim=self.qk_head_dim, + rotary_emb=mla_modules.rotary_emb, + fused_qkv_a_proj=mla_modules.fused_qkv_a_proj, + q_b_proj=mla_modules.q_b_proj, + q_a_layernorm=mla_modules.q_a_layernorm, + q_proj=mla_modules.q_proj, + kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, + kv_a_layernorm=mla_modules.kv_a_layernorm, + kv_b_proj=mla_modules.kv_b_proj, + o_proj=mla_modules.o_proj, + ) + else: + self.mla_attn = MLAAttention( + num_heads=self.num_heads, + scale=scale, + head_size=self.kv_lora_rank + self.qk_rope_head_dim, + qk_nope_head_dim=self.qk_nope_head_dim, + qk_rope_head_dim=self.qk_rope_head_dim, + v_head_dim=self.v_head_dim, + q_lora_rank=self.q_lora_rank, + kv_lora_rank=self.kv_lora_rank, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.attn", + kv_b_proj=mla_modules.kv_b_proj, + use_sparse=mla_modules.is_sparse, + indexer=mla_modules.indexer, + # extra args + qk_head_dim=self.qk_head_dim, + rotary_emb=mla_modules.rotary_emb, + fused_qkv_a_proj=mla_modules.fused_qkv_a_proj, + q_b_proj=mla_modules.q_b_proj, + q_a_layernorm=mla_modules.q_a_layernorm, + q_proj=mla_modules.q_proj, + kv_a_proj_with_mqa=mla_modules.kv_a_proj_with_mqa, + kv_a_layernorm=mla_modules.kv_a_layernorm, + o_proj=mla_modules.o_proj, + ) compilation_config = get_current_vllm_config().compilation_config if prefix in compilation_config.static_forward_context: diff --git a/vllm_ascend/ops/common_fused_moe.py b/vllm_ascend/ops/common_fused_moe.py index bac07b2..1335be5 100644 --- a/vllm_ascend/ops/common_fused_moe.py +++ b/vllm_ascend/ops/common_fused_moe.py @@ -182,6 +182,10 @@ class AscendFusedMoE(FusedMoE): self.expert_map_path = ascend_config.expert_map_path self.global_redundant_expert_num = ascend_config.init_redundancy_expert self.global_num_experts = num_experts + self.global_redundant_expert_num + if self.custom_routing_function is None and self.e_score_correction_bias is not None: + vllm_config = get_current_vllm_config() + self.e_score_correction_bias.data = self.e_score_correction_bias.data.to( + dtype=vllm_config.model_config.dtype) # static eplb initializing with expert_map_path if self.expert_map_path and os.path.exists( self.expert_map_path) and os.access(self.expert_map_path, diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index f484400..c8d8f1f 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -176,12 +176,14 @@ packed_modules_model_mapping = { "deepseek_v2": { "gate_up_proj": ["gate_proj", "up_proj"], "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] }, "deepseek_v3": { "gate_up_proj": ["gate_proj", "up_proj"], "experts": - ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"], + "fused_qkv_a_proj": ["q_a_proj", "kv_a_proj_with_mqa"] }, "deepseek_v32": { "gate_up_proj": ["gate_proj", "up_proj"], @@ -274,6 +276,7 @@ class AscendLinearMethod(LinearMethodBase): # disable warning param.ignore_warning = True layer.register_parameter(pertensor_name, param) + param.weight_loader = extra_weight_attrs.get("weight_loader") perchannel_dict = self.quant_method.get_perchannel_param( output_size_per_partition, params_dtype) diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py index 462e05b..7f09c52 100644 --- a/vllm_ascend/torchair/models/torchair_deepseek_v2.py +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -570,7 +570,8 @@ class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention): qk_head_dim=self.qk_head_dim, v_head_dim=self.v_head_dim, rotary_emb=self.rotary_emb, - q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj, + q_proj=self.q_proj if self.q_lora_rank is None else None, + q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, kv_a_proj_with_mqa=self.kv_a_proj_with_mqa, kv_a_layernorm=self.kv_a_layernorm, kv_b_proj=self.kv_b_proj, diff --git a/vllm_ascend/torchair/torchair_mla.py b/vllm_ascend/torchair/torchair_mla.py index 4269727..57179c9 100644 --- a/vllm_ascend/torchair/torchair_mla.py +++ b/vllm_ascend/torchair/torchair_mla.py @@ -656,7 +656,8 @@ class AscendMLATorchairImpl(MLAAttentionImpl): self.qk_head_dim = kwargs['qk_head_dim'] self.v_head_dim = kwargs['v_head_dim'] self.rotary_emb = kwargs['rotary_emb'] - self.q_proj = kwargs['q_proj'] + self.q_proj = kwargs['q_proj'] if self.q_lora_rank is None else kwargs[ + 'q_b_proj'] self.kv_b_proj = kwargs['kv_b_proj'] self.o_proj = kwargs['o_proj'] self.kv_a_proj_with_mqa = kwargs.get('kv_a_proj_with_mqa', None)