From 3fc31ee1cbdf0c0d11efc4da5fd865bb2d077c4b Mon Sep 17 00:00:00 2001 From: linfeng-yuan <1102311262@qq.com> Date: Mon, 18 Aug 2025 15:00:37 +0800 Subject: [PATCH] [1/N][refactor] torchair deepseek modeling refactor (#2384) ### What this PR does / why we need it? Move torchair related model arch into torchair moduel to make the code clear. Next step we'll remove all torchair related code outside of torchair moduel. ### Does this PR introduce _any_ user-facing change? No. - vLLM version: v0.10.0 - vLLM main: https://github.com/vllm-project/vllm/commit/08d5f7113a024818b2867782c2539794b7aa162b Signed-off-by: linfeng-yuan <1102311262@qq.com> --- .../models/test_torchair_deepseek_mtp.py | 180 +++ .../models/test_torchair_deepseek_v2.py | 324 +++++ tests/ut/torchair/test_utils.py | 45 + vllm_ascend/torchair/models/__init__.py | 0 .../torchair/models/torchair_deepseek_mtp.py | 218 ++++ .../torchair/models/torchair_deepseek_v2.py | 1047 +++++++++++++++++ .../torchair/models/torchair_deepseek_v3.py | 28 + vllm_ascend/torchair/torchair_model_runner.py | 2 + vllm_ascend/torchair/utils.py | 19 + 9 files changed, 1863 insertions(+) create mode 100644 tests/ut/torchair/models/test_torchair_deepseek_mtp.py create mode 100644 tests/ut/torchair/models/test_torchair_deepseek_v2.py create mode 100644 vllm_ascend/torchair/models/__init__.py create mode 100644 vllm_ascend/torchair/models/torchair_deepseek_mtp.py create mode 100644 vllm_ascend/torchair/models/torchair_deepseek_v2.py create mode 100644 vllm_ascend/torchair/models/torchair_deepseek_v3.py diff --git a/tests/ut/torchair/models/test_torchair_deepseek_mtp.py b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py new file mode 100644 index 0000000..1c1e6c7 --- /dev/null +++ b/tests/ut/torchair/models/test_torchair_deepseek_mtp.py @@ -0,0 +1,180 @@ +import pytest +import torch +from pytest_mock import MockerFixture +from transformers import PretrainedConfig +from vllm.config import CacheConfig, ModelConfig, VllmConfig + 
+from tests.ut.base import PytestBase +from vllm_ascend.torchair.models.torchair_deepseek_mtp import ( + TorchairDeepSeekMTP, TorchairDeepSeekMultiTokenPredictor, + TorchairDeepSeekMultiTokenPredictorLayer) + + +class TestTorchairDeepSeekMultiTokenPredictorLayer(PytestBase): + + @pytest.fixture + def setup_mtp_layer(self, mocker: MockerFixture): + config = PretrainedConfig(vocab_size=1000, + hidden_size=768, + rms_norm_eps=1e-5) + mocker.patch( + "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__", + return_value=None) + mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__", + return_value=None) + mocker.patch( + "vllm.model_executor.models.deepseek_mtp.SharedHead.__init__", + return_value=None) + mocker.patch( + "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekShareHead.__init__", + return_value=None) + mocker_deepseek_v2_decode_layer = mocker.patch( + "vllm_ascend.torchair.models.torchair_deepseek_v2.TorchairDeepseekV2DecoderLayer.__init__", + return_value=None) + + mtp_layer = TorchairDeepSeekMultiTokenPredictorLayer(config, "", None) + mocker_deepseek_v2_decode_layer.assert_called_once() + return mtp_layer + + def test_init(self, mocker: MockerFixture, setup_mtp_layer): + mtp_layer = setup_mtp_layer + assert isinstance(mtp_layer, TorchairDeepSeekMultiTokenPredictorLayer) + + def test_forward(self, mocker: MockerFixture, setup_mtp_layer): + mtp_layer = setup_mtp_layer + mocker.patch("torch.nn.Module.__setattr__") + mocker.patch("torch.nn.Module.__getattr__") + mocker.patch("torch.nn.Module.__delattr__") + mocker.patch.object(mtp_layer, + 'eh_proj', + return_value=torch.randn(2, 3, 768)) + mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768)) + mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768), + torch.randn(2, 3, 768)) + + input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]) + positions = torch.tensor([[0, 1, 2], [0, 1, 2]]) + kv_cache = torch.randn(2, 3, 768) + 
previous_hidden_states = torch.randn(2, 3, 768) + inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]]) + + output = mtp_layer(input_ids, positions, kv_cache, None, + previous_hidden_states, inputs_embeds, 0) + assert output.shape == (2, 3, 768) + + +class TestTorchairDeepSeekMultiTokenPredictor(PytestBase): + + @pytest.fixture + def setup_predictor(self, mocker: MockerFixture): + mock_vllm_config = mocker.MagicMock(spec=VllmConfig) + mock_model_config = mocker.MagicMock(spec=ModelConfig) + mock_hf_config = mocker.MagicMock() + mock_hf_config.num_hidden_layers = 12 + mock_hf_config.num_nextn_predict_layers = 3 + mock_hf_config.vocab_size = 30000 + mock_model_config.hf_config = mock_hf_config + mock_vllm_config.model_config = mock_model_config + mock_vllm_config.cache_config = CacheConfig() + mock_vllm_config.quant_config = mocker.MagicMock() + mocker.patch( + "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__", + return_value=None) + mocker.patch( + "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__init__", + return_value=None) + + predictor = TorchairDeepSeekMultiTokenPredictor( + vllm_config=mock_vllm_config) + return predictor + + def test_init(self, mocker: MockerFixture, setup_predictor): + predictor = setup_predictor + assert predictor.num_mtp_layers == 3 + assert isinstance(predictor, TorchairDeepSeekMultiTokenPredictor) + + @pytest.mark.parametrize( + 'kv_caches, inputs_embeds', + [(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))]) + def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches, + inputs_embeds): + predictor = setup_predictor + mock_layer = mocker.MagicMock() + mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0]) + predictor.layers_list = [mock_layer] + + # todo: need or not? 
+ # predictor.num_mtp_layers = 1 + input_ids = torch.tensor([[1, 2, 3]]) + positions = torch.tensor([[0, 1, 2]]) + mocker.patch( + "vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__", + return_value=torch.tensor([[1.0, 2.0, 3.0]])) + output = predictor.forward(input_ids, positions, kv_caches, None, None, + inputs_embeds, 0) + mock_layer.assert_called_once() + assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0])) + + def test_compute_logits(self, mocker: MockerFixture, setup_predictor): + hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]]) + predictor = setup_predictor + + mock_layer = mocker.MagicMock() + mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0]) + predictor.layers_list = [mock_layer] + mocker.patch("torch.nn.Module.__setattr__") + mocker.patch("torch.nn.Module.__getattr__") + mocker.patch("torch.nn.Module.__delattr__") + mocker.patch( + "vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__", + return_value=None) + predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0]) + + result_logits = predictor.compute_logits(hidden_states=hidden_states, + sampling_metadata=None) + predictor.logits_processor.assert_called_once() + assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0])) + + +class TestTorchairDeepSeekMTP(PytestBase): + + @pytest.fixture + def setup_mtp(self, mocker: MockerFixture): + vllm_config = mocker.MagicMock() + vllm_config.model_config.hf_config.num_hidden_layers = 12 + vllm_config.model_config.hf_config.num_nextn_predict_layers = 3 + vllm_config.cache_config = mocker.MagicMock() + vllm_config.quant_config = mocker.MagicMock() + + mocker.patch("torch.nn.Module.__setattr__") + mocker.patch("torch.nn.Module.__getattr__") + mocker.patch("torch.nn.Module.__delattr__") + mocker.patch( + "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__", + return_value=None) + mocker.patch( + 
"vllm_ascend.torchair.models.torchair_deepseek_mtp.TorchairDeepSeekMultiTokenPredictorLayer.__call__", + return_value=None) + mocker.patch("vllm.model_executor.layers.sampler.get_sampler", + return_value=None) + + mtp = TorchairDeepSeekMTP(vllm_config=vllm_config) + return mtp + + def test_init(self, mocker: MockerFixture, setup_mtp): + mtp = setup_mtp + assert isinstance(mtp, TorchairDeepSeekMTP) + + def test_forward(self, mocker: MockerFixture, setup_mtp): + input_ids = torch.tensor([[1, 2, 3]]) + positions = torch.tensor([[0, 1, 2]]) + kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])] + previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]]) + inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]]) + spec_step_idx = 0 + setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]]) + + output = setup_mtp.forward(input_ids, positions, kv_caches, None, + previous_hidden_states, inputs_embeds, + spec_step_idx) + assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]])) \ No newline at end of file diff --git a/tests/ut/torchair/models/test_torchair_deepseek_v2.py b/tests/ut/torchair/models/test_torchair_deepseek_v2.py new file mode 100644 index 0000000..0a4ae8c --- /dev/null +++ b/tests/ut/torchair/models/test_torchair_deepseek_v2.py @@ -0,0 +1,324 @@ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. 
+# +from types import SimpleNamespace +from unittest.mock import Mock, patch + +import pytest +import torch +from transformers import PretrainedConfig +from vllm.config import CacheConfig +from vllm.distributed.parallel_state import GroupCoordinator + +from vllm_ascend.torchair.models.torchair_deepseek_v2 import ( + TorchairDeepseekV2DecoderLayer, TorchairDeepseekV2ForCausalLM, + TorchairDeepseekV2MergedReplicatedLinear, TorchairDeepseekV2MLAAttention, + TorchairDeepseekV2MLP, TorchairDeepseekV2MoE, + TorchairDeepseekV2RowParallelLinear, + TorchairDeepseekV2RowParallelLinearReplaceAllreduce, + TorchairDeepseekV2SiluAndMul) + + +@pytest.fixture +def base_config(): + config = PretrainedConfig( + hidden_size=128, + num_attention_heads=8, + num_hidden_layers=2, + intermediate_size=256, + hidden_act="silu", + rms_norm_eps=1e-6, + rope_theta=10000.0, + max_position_embeddings=2048, + n_routed_experts=4, + n_shared_experts=1, + moe_intermediate_size=256, + num_experts_per_tok=2, + routed_scaling_factor=1.0, + first_k_dense_replace=0, + moe_layer_freq=1, + kv_lora_rank=16, + qk_nope_head_dim=16, + qk_rope_head_dim=16, + v_head_dim=32, + topk_method="noaux_tc", + scoring_func="softmax", + norm_topk_prob=True, + n_group=1, + topk_group=1, + vocab_size=10000, + ) + return config + + +@pytest.fixture +def vllm_config(base_config): + model_config = SimpleNamespace( + hf_config=base_config, + tensor_parallel_size=1, + dtype=torch.float32, + use_mla=False, + quant_config=None, + max_model_len=2048, + ) + + cache_config = CacheConfig() + vllm_config = Mock() + vllm_config.model_config = model_config + vllm_config.cache_config = cache_config + vllm_config.quant_config = None + return vllm_config + + +@pytest.fixture +def mock_distributed(): + tp_group = Mock(spec=GroupCoordinator) + tp_group.rank_in_group = 0 + tp_group.world_size = 1 + tp_group.device_group = Mock() + + dp_group = Mock(spec=GroupCoordinator) + dp_group.rank_in_group = 0 + dp_group.world_size = 1 + + ep_group = 
Mock(spec=GroupCoordinator) + ep_group.rank_in_group = 0 + ep_group.world_size = 1 + + pp_group = Mock(spec=GroupCoordinator) + pp_group.rank_in_group = 0 + pp_group.world_size = 1 + + mock_vllm_config = Mock() + mock_vllm_config.scheduler_config = Mock(max_num_seqs=256) + mock_vllm_config.model_config = Mock(max_model_len=2048, quant_config=None) + + with patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_rank", return_value=0), \ + patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tensor_model_parallel_world_size", return_value=1), \ + patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_tp_group", return_value=tp_group), \ + patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_ep_group", return_value=ep_group), \ + patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_dp_group", return_value=dp_group), \ + patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", return_value=pp_group), \ + patch("vllm_ascend.torchair.models.torchair_deepseek_v2.get_pp_group", + return_value=Mock(is_first_rank=False, is_last_rank=False)), \ + patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \ + patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group, + _PP=pp_group), \ + patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group): + yield + + +@pytest.fixture +def mock_forward_context(): + forward_context = Mock(in_profile_run=False, with_prefill=False) + with patch( + "vllm_ascend.torchair.models.torchair_deepseek_v2.get_forward_context", + return_value=forward_context): + yield + + +def test_torchair_deepseek_v2_silu_and_mul(): + torch.set_default_device("cpu") + + silu = TorchairDeepseekV2SiluAndMul() + assert silu.weight_scale is None + + x = torch.randn(2, 4) + output = silu.forward_oot(x) + assert output.shape == (2, 2) + + weight_scale = Mock(return_value=torch.tensor(0.1)) + silu = 
TorchairDeepseekV2SiluAndMul(weight_scale=weight_scale) + quant_x = torch.randint(-128, 127, (2, 4), dtype=torch.int32) + dynamic_scale = torch.randn(2, 1) + with patch("torch_npu.npu_dequant_swiglu_quant", + return_value=torch.randn(2, 4)): + output = silu.forward_oot((quant_x, dynamic_scale)) + assert output.shape == (2, 4) + + +def test_torchair_deepseek_v2_merged_replicated_linear(mock_distributed): + linear = TorchairDeepseekV2MergedReplicatedLinear(input_size=128, + output_sizes=[64, 64], + bias=False, + quant_config=None) + assert linear.output_sizes == [64, 64] + + param = Mock() + param.data = torch.zeros(128, 128) + param.output_dim = 1 + param.is_gguf_weight = False + param.is_gguf_weight_type = False + loaded_weight = torch.randn(128, 64) + linear.weight_loader(param, loaded_weight, loaded_shard_id=0) + + with pytest.raises(AssertionError): + linear.weight_loader(param, torch.randn(128, 32), loaded_shard_id=0) + + +@pytest.mark.parametrize("cls", [ + TorchairDeepseekV2RowParallelLinearReplaceAllreduce, + TorchairDeepseekV2RowParallelLinear +]) +def test_row_parallel_linear(cls, mock_distributed): + linear = cls(input_size=128, output_size=64, bias=False, quant_config=None) + linear.quant_method = Mock() + linear.quant_method.apply.return_value = torch.randn(2, 4, 64) + + input_ = torch.randn(2, 4, 128) + with patch( + "vllm_ascend.torchair.models.torchair_deepseek_v2.split_tensor_along_last_dim", + return_value=[torch.randn(2, 4, 64)]): + linear.input_is_parallel = False + output = linear(input_, is_prefill=True) + assert output[0].shape == (2, 4, 64) + + linear.input_is_parallel = True + output = linear(input_, is_prefill=False) + assert output[0].shape == (2, 4, 64) + + +def test_torchair_deepseek_v2_mlp(mock_distributed, base_config): + mlp = TorchairDeepseekV2MLP(hidden_size=128, + intermediate_size=256, + hidden_act="silu", + quant_config=None) + assert isinstance(mlp.act_fn, TorchairDeepseekV2SiluAndMul) + + x = torch.randn(2, 4, 128) + output = 
mlp(x) + assert output.shape == (2, 4, 128) + + with patch( + "vllm_ascend.torchair.models.torchair_deepseek_v2.QuantizationConfig" + ) as mock_quant_config: + mock_quant_config.name = "w8a8dynamic" + with pytest.raises(NotImplementedError): + TorchairDeepseekV2MLP(hidden_size=128, + intermediate_size=256, + hidden_act="silu", + quant_config=mock_quant_config, + force_replicate=False) + with pytest.raises(ValueError): + TorchairDeepseekV2MLP(hidden_size=128, + intermediate_size=256, + hidden_act="relu", + quant_config=None) + + +def test_torchair_deepseek_v2_moe(mock_distributed, base_config, + mock_forward_context): + base_config.n_shared_experts = 1 + moe = TorchairDeepseekV2MoE(config=base_config, + quant_config=None, + prefix="mlp") + assert moe.top_k == 2 + + x = torch.randn(2, 4, 128) + attn_metadata = Mock(num_prefills=1) + with patch("vllm_ascend.ops.fused_moe.AscendFusedMoE.__call__", + return_value=(torch.randn(2, 4, 128), torch.randn(2, 4, 128))): + output = moe(x, attn_metadata) + assert output.shape == (2, 4, 128) + + +@patch("torch_npu.npu_rms_norm") +def test_torchair_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed, + base_config): + mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) + + attn = TorchairDeepseekV2MLAAttention(config=base_config, + hidden_size=128, + num_heads=8, + qk_nope_head_dim=16, + qk_rope_head_dim=16, + v_head_dim=32, + q_lora_rank=16, + kv_lora_rank=16, + cache_config=CacheConfig(), + quant_config=None, + prefix="layers.0.self_attn") + assert attn.debug_layer_idx == 0 + + x = torch.randn(2, 4, 128) + positions = torch.arange(4).repeat(2, 1) + with patch.object(attn.mla_attn, + "__call__", + return_value=torch.randn(2, 4, 128)): + with pytest.raises(AssertionError): + attn(positions, x) + + attn = TorchairDeepseekV2MLAAttention(config=base_config, + hidden_size=128, + num_heads=8, + qk_nope_head_dim=16, + qk_rope_head_dim=16, + v_head_dim=32, + q_lora_rank=None, + kv_lora_rank=16, + 
prefix="layers.1.self_attn") + assert hasattr(attn, "q_proj") + + +@patch("torch_npu.npu_add_rms_norm") +@patch("torch_npu.npu_rms_norm") +def test_torchair_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm, + mock_distributed, base_config, + vllm_config): + mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) + mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128), + torch.randn(2, 128)) + base_config.n_routed_experts = 4 + layer = TorchairDeepseekV2DecoderLayer( + config=base_config, + prefix="layers.0", + model_config=vllm_config.model_config, + cache_config=CacheConfig(), + quant_config=None) + assert isinstance(layer.mlp, TorchairDeepseekV2MoE) + + x = torch.randn(2, 4, 128) + positions = torch.arange(4).repeat(2, 1) + + with patch.object(layer.self_attn, "forward", Mock(return_value=torch.randn(2, 4, 128))), \ + patch.object(layer.mlp, "forward", Mock(return_value=torch.randn(2, 4, 128))): + hidden_states, residual = layer(positions, x, None) + assert hidden_states.shape == (2, 4, 128) + + base_config.n_routed_experts = None + layer = TorchairDeepseekV2DecoderLayer( + config=base_config, + prefix="layers.0", + model_config=vllm_config.model_config, + quant_config=None) + assert isinstance(layer.mlp, TorchairDeepseekV2MLP) + + +def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config): + model = TorchairDeepseekV2ForCausalLM(vllm_config=vllm_config) + + input_ids = torch.randint(0, 10000, (2, 4)) + positions = torch.arange(4).repeat(2, 1) + with patch.object(model.model, + "forward", + return_value=torch.randn(2, 4, 128)): + output = model(input_ids, positions) + assert output.shape == (2, 4, 128) + + weights = [("model.embed_tokens.weight", torch.randn(10000, 128))] + with patch( + "vllm.model_executor.model_loader.weight_utils.default_weight_loader" + ): + loaded = model.load_weights(weights) + assert loaded is not None diff --git a/tests/ut/torchair/test_utils.py b/tests/ut/torchair/test_utils.py 
index b2b3c65..367a9a4 100644 --- a/tests/ut/torchair/test_utils.py +++ b/tests/ut/torchair/test_utils.py @@ -1,4 +1,6 @@ import os +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import MagicMock, patch from tests.ut.base import TestBase from vllm_ascend.torchair import utils @@ -26,3 +28,46 @@ class TestTorchairUtils(TestBase): "Delete torchair cache dir failed") self.assertFalse(utils.check_kv_cache_bytes_cache_exist(), "Delete kv cache bytes cache dir failed") + + def test_torchair_cache_dir_multiple_ranks(self): + ranks = [0, 1, 2, 3] + values = [100, 200, 300, 400] + + with ThreadPoolExecutor() as executor: + executor.map(utils.write_kv_cache_bytes_to_file, ranks, values) + for rank, expected in zip(ranks, values): + self.assertEqual(expected, + utils.read_kv_cache_bytes_from_file(rank)) + utils.delete_torchair_cache_file() + + self.assertFalse(utils.check_torchair_cache_exist(), + "Delete torchair cache dir failed") + self.assertFalse(utils.check_kv_cache_bytes_cache_exist(), + "Delete kv cache bytes cache dir failed") + + @patch('vllm.ModelRegistry') + def test_register_torchair_model(self, mock_model_registry): + mock_registry = MagicMock() + mock_model_registry.return_value = mock_registry + utils.register_torchair_model() + + self.assertEqual(mock_model_registry.register_model.call_count, 3) + call_args_list = mock_model_registry.register_model.call_args_list + + expected_registrations = [ + ("DeepSeekMTPModel", + "vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP" + ), + ("DeepseekV2ForCausalLM", + "vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM" + ), + ("DeepseekV3ForCausalLM", + "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM" + ) + ] + + for i, (expected_name, + expected_path) in enumerate(expected_registrations): + args, kwargs = call_args_list[i] + self.assertEqual(args[0], expected_name) + self.assertEqual(args[1], expected_path) diff --git 
a/vllm_ascend/torchair/models/__init__.py b/vllm_ascend/torchair/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/torchair/models/torchair_deepseek_mtp.py b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py new file mode 100644 index 0000000..6cb98a5 --- /dev/null +++ b/vllm_ascend/torchair/models/torchair_deepseek_mtp.py @@ -0,0 +1,218 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Adapted from vllm/model_executor/models/deepseek_mtp.py +# Copyright 2023 The vLLM team. +# +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from typing import List, Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import CacheConfig, ModelConfig, VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.models.deepseek_mtp import ( + DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, + SharedHead) +from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from vllm_ascend.torchair.models.torchair_deepseek_v2 import \ + TorchairDeepseekV2DecoderLayer + + +class TorchairDeepSeekShareHead(SharedHead): + + def __init__(self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: + nn.Module.__init__(self) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "head")) + + +class TorchairDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer + ): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + model_config: ModelConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + nn.Module.__init__(self) + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + 
bias=False) + self.shared_head = TorchairDeepSeekShareHead(config=config, + quant_config=quant_config, + prefix=maybe_prefix( + prefix, + "shared_head")) + self.mtp_block = TorchairDeepseekV2DecoderLayer( + config, prefix, model_config, cache_config, quant_config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_index: int = 0, + ) -> torch.Tensor: + assert inputs_embeds is not None + # masking inputs at position 0, as not needed by MTP + inputs_embeds = torch.where((positions == 0).unsqueeze(-1), + torch.zeros_like(inputs_embeds), + inputs_embeds) + inputs_embeds = self.enorm(inputs_embeds) + previous_hidden_states = self.hnorm(previous_hidden_states) + + hidden_states = self.eh_proj( + torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) + + hidden_states, residual = self.mtp_block(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + residual=None) + hidden_states = residual + hidden_states + return hidden_states + + +class TorchairDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + self.mtp_start_layer_idx = config.num_hidden_layers + self.num_mtp_layers = config.num_nextn_predict_layers + # to map the exact layer index from weights + self.layers = torch.nn.ModuleDict({ + str(idx): + TorchairDeepSeekMultiTokenPredictorLayer( + config, + f"{prefix}.layers.{idx}", + model_config=vllm_config.model_config, + cache_config=vllm_config.cache_config, + quant_config=vllm_config.quant_config, + ) + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + }) + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + 
config.hidden_size, + ) + + # Note: torch._dynamo.exc.Unsupported: builtin: str + self.layers_list = [ + self.layers[str(idx)] + for idx in range(self.mtp_start_layer_idx, + self.mtp_start_layer_idx + self.num_mtp_layers) + ] + self.logits_processor = LogitsProcessor(config.vocab_size) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: torch.Tensor, + attn_metadata: AttentionMetadata, + previous_hidden_states: torch.Tensor, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + current_step_idx = (spec_step_idx % self.num_mtp_layers) + step_kv_cache = kv_caches[ + current_step_idx] if kv_caches is not None else None + return self.layers_list[current_step_idx]( + input_ids, + positions, + step_kv_cache, + attn_metadata, + previous_hidden_states, + inputs_embeds, + current_step_idx, + ) + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + spec_step_idx: int = 0, + ) -> torch.Tensor: + current_step_idx = (spec_step_idx % self.num_mtp_layers) + mtp_layer = self.layers_list[current_step_idx] + logits = self.logits_processor(mtp_layer.shared_head.head, + mtp_layer.shared_head(hidden_states), + sampling_metadata) + return logits + + +class TorchairDeepSeekMTP(DeepSeekMTP): + # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized; + # NOTE 2.The description file generated by the current msmodelslim tool does not have + # MTP layer info. Please manually add it and set the value to FLOAT. 
+ packed_modules_mapping = { + "gate_up_proj": ["gate_proj", "up_proj"], + "experts": + ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"] + } + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + nn.Module.__init__(self) + self.config = vllm_config.model_config.hf_config + self.model = TorchairDeepSeekMultiTokenPredictor( + vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")) + + self.sampler = get_sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: Optional[List[torch.Tensor]] = None, + attn_metadata: Optional[AttentionMetadata] = None, + previous_hidden_states: Optional[torch.Tensor] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + spec_step_idx: int = 0, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, previous_hidden_states, + inputs_embeds, spec_step_idx) + return hidden_states diff --git a/vllm_ascend/torchair/models/torchair_deepseek_v2.py b/vllm_ascend/torchair/models/torchair_deepseek_v2.py new file mode 100644 index 0000000..c695796 --- /dev/null +++ b/vllm_ascend/torchair/models/torchair_deepseek_v2.py @@ -0,0 +1,1047 @@ +# SPDX-License-Identifier: Apache-2.0 +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# Copyright 2023 DeepSeek-AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# # Adapted from +# # vllm-project/vllm/blob/main/vllm/model_executor/models/deepseek_v2.py +# # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py +# """Inference-only DeepseekV2/DeepseekV3 model.""" + +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch +import torch_npu +from torch import nn +from transformers import PretrainedConfig +from vllm.attention import Attention, AttentionMetadata +from vllm.config import (CacheConfig, ModelConfig, VllmConfig, + get_current_vllm_config) +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + get_tp_group, split_tensor_along_last_dim, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_reduce_scatter) +from vllm.distributed.parallel_state import get_dp_group, get_ep_group +from vllm.forward_context import get_forward_context +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + ReplicatedLinear, + RowParallelLinear, + UnquantizedLinearMethod) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.rotary_embedding import get_rope +from 
vllm.model_executor.layers.sampler import get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.deepseek_v2 import \ + DeepseekV2ForCausalLM # noqa: E501 +from vllm.model_executor.models.deepseek_v2 import \ + yarn_get_mscale # noqa: E501 +from vllm.model_executor.models.deepseek_v2 import ( + DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention, + get_spec_layer_idx_from_weight_name) +from vllm.model_executor.models.utils import ( + PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.sequence import IntermediateTensors + +from vllm_ascend.ascend_config import get_ascend_config +from vllm_ascend.ops.fused_moe import AscendFusedMoE +from vllm_ascend.quantization.quant_config import AscendLinearMethod +from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod +from vllm_ascend.utils import dispose_tensor, npu_prefetch + + +class TorchairDeepseekV2SiluAndMul(SiluAndMul): + + def __init__(self, + *, + weight_scale: Optional[Callable[[], torch.Tensor]] = None): + super().__init__() + self.weight_scale = weight_scale + + def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor, + torch.Tensor]]): + if isinstance(x, tuple): + assert self.weight_scale is not None + # For AscendW8A8DynamicLinearMethod: + # a dynamic scale is passed along with the quantized value. 
class TorchairDeepseekV2SiluAndMul(SiluAndMul):
    """SiluAndMul that can additionally consume the (quantized, scale) pair
    emitted by AscendW8A8DynamicLinearMethod, fusing dequant + swiglu + quant
    into a single NPU kernel."""

    def __init__(self,
                 *,
                 weight_scale: Optional[Callable[[], torch.Tensor]] = None):
        super().__init__()
        # Lazy callable: the fp32 weight scale only exists after
        # process_weights_after_loading, so it is fetched at call time.
        self.weight_scale = weight_scale

    def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
                                                       torch.Tensor]]):
        if isinstance(x, tuple):
            assert self.weight_scale is not None
            # For AscendW8A8DynamicLinearMethod:
            # a dynamic scale is passed along with the quantized value.
            quantized_x, dynamic_scale = x
            return torch_npu.npu_dequant_swiglu_quant(
                x=quantized_x,
                weight_scale=self.weight_scale(),
                activation_scale=dynamic_scale,
                activate_left=True,
                quant_mode=1)
        else:
            return super().forward_oot(x)


class TorchairDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
    """Replicated linear whose fused output is the concatenation of several
    logical sub-projections (e.g. gate_proj + up_proj), loaded shard by
    shard from the checkpoint."""

    def __init__(
        self,
        input_size: int,
        output_sizes: list[int],
        bias: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        # Remember per-shard widths so weight_loader() can place each
        # checkpoint shard at the correct offset of the fused weight.
        self.output_sizes = output_sizes
        super().__init__(input_size,
                         sum(output_sizes),
                         bias=bias,
                         quant_config=quant_config,
                         prefix=prefix)

    def weight_loader(self, param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor, loaded_shard_id: int):
        """Copy one checkpoint shard into its slice of the fused parameter.

        Args:
            param: the fused parameter owned by this layer.
            loaded_weight: the shard tensor from the checkpoint.
            loaded_shard_id: index into ``self.output_sizes``.
        """
        # With no support for GGUF format yet.
        assert not getattr(param, "is_gguf_weight", False)
        assert not getattr(param, "is_gguf_weight_type", False)

        assert loaded_shard_id < len(self.output_sizes)
        shard_offset = sum(self.output_sizes[:loaded_shard_id])
        shard_size = self.output_sizes[loaded_shard_id]
        shard = param.data.narrow(param.output_dim, shard_offset, shard_size)

        # FIX: added the missing space between the two f-string fragments so
        # the assertion message reads "...size X to a parameter..." instead of
        # the garbled "...size Xto a parameter...".
        assert shard.size() == loaded_weight.size(), (
            f"Tried to load weights of size {loaded_weight.size()} "
            f"to a parameter shard of id {loaded_shard_id} size {shard.size()}"
        )
        shard.copy_(loaded_weight)


class TorchairDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
    """Row-parallel linear that replaces the trailing all-reduce with a
    reduce-scatter whenever the token count divides evenly across TP ranks
    (or is force-padded to), saving communication volume."""

    def forward(
        self,
        input_,
        is_prefill=True,
        is_force_scatter=False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
        if self.input_is_parallel:
            input_parallel = input_
        else:
            tp_rank = get_tensor_model_parallel_rank()
            splitted_input = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)
            input_parallel = splitted_input[tp_rank].contiguous()

        # Matrix multiply.
        assert self.quant_method is not None
        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in TP>1 case)
        bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=bias_)
        if self.reduce_results and self.tp_size > 1:
            num_tokens = output_parallel.shape[0]
            if is_force_scatter and num_tokens % self.tp_size:
                # (-num_tokens % tp_size) is the number of rows needed to
                # round up to a multiple of tp_size (Python modulo semantics).
                output_parallel = nn.functional.pad(
                    output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
            if is_force_scatter or (not is_prefill
                                    and output_parallel.shape[0] % self.tp_size
                                    == 0):
                output = tensor_model_parallel_reduce_scatter(output_parallel,
                                                              dim=0)
            else:
                output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None

        if not self.return_bias:
            return output
        return output, output_bias
class TorchairDeepseekV2RowParallelLinear(RowParallelLinear):
    """Row-parallel linear that accepts (and ignores) the torchair-specific
    ``is_prefill``/``is_force_scatter`` flags; partial results are always
    combined with a plain all-reduce."""

    def forward(
        self,
        input_,
        is_prefill=True,
        is_force_scatter=False
    ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
        # Shard the input ourselves when the caller hands us the full tensor.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            rank = get_tensor_model_parallel_rank()
            input_parallel = split_tensor_along_last_dim(
                input_, num_partitions=self.tp_size)[rank].contiguous()

        assert self.quant_method is not None
        # Fuse the bias add into the GEMM on rank 0 only, so it is not
        # accumulated more than once by the all-reduce below (TP>1 case).
        fused_bias = (None if (self.tp_rank > 0 or self.skip_bias_add) else
                      self.bias)
        output_parallel = self.quant_method.apply(self,
                                                  input_parallel,
                                                  bias=fused_bias)
        if self.reduce_results and self.tp_size > 1:
            output = tensor_model_parallel_all_reduce(output_parallel)
        else:
            output = output_parallel

        output_bias = self.bias if self.skip_bias_add else None
        return output if not self.return_bias else (output, output_bias)


class TorchairDeepseekV2MLP(nn.Module):
    """Dense (non-MoE) DeepSeek-V2 MLP: gate_up projection -> SiluAndMul ->
    down projection. With ``force_replicate`` the projections are replicated
    instead of tensor-parallel (used by multistream-MoE / shared-expert-DP)."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        force_replicate: bool = False,
        prefix: str = "",
    ) -> None:
        super().__init__()
        if not force_replicate:
            self.gate_up_proj = MergedColumnParallelLinear(
                hidden_size, [intermediate_size] * 2,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.gate_up_proj")
            self.down_proj = RowParallelLinear(intermediate_size,
                                               hidden_size,
                                               bias=False,
                                               quant_config=quant_config,
                                               reduce_results=reduce_results,
                                               prefix=f"{prefix}.down_proj")
        else:
            self.gate_up_proj = TorchairDeepseekV2MergedReplicatedLinear(
                hidden_size, [intermediate_size] * 2,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.gate_up_proj")
            self.down_proj = ReplicatedLinear(intermediate_size,
                                              hidden_size,
                                              bias=False,
                                              quant_config=quant_config,
                                              prefix=f"{prefix}.down_proj")
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

        quant_method = self.gate_up_proj.quant_method
        if isinstance(quant_method, UnquantizedLinearMethod):
            self.act_fn = TorchairDeepseekV2SiluAndMul()
        elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
                quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
            # TODO(sdmyzlp): Currently preserved as before:
            # 1. The only quantization supported for silu is W8A8Dynamic
            # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
            #
            # A more general configuration scheme (e.g. via a tweaked
            # quant_config) could replace this hard-coding.
            self.act_fn = TorchairDeepseekV2SiluAndMul(
                # Lazy binding: `weight_scale_fp32` is only accessible
                # after `process_weights_after_loading`.
                weight_scale=lambda: self.gate_up_proj.weight_scale_fp32)
            # Consumed by AscendW8A8DynamicLinearMethod.apply().
            self.gate_up_proj._ascend_quant_config = {
                "output_dtype": torch.int32,
                "pertoken_scale": False,
                "return_scale": True,
            }
            self.down_proj._ascend_quant_config = {
                "output_dtype": torch.bfloat16,
                "pertoken_scale": True,
                "return_scale": False,
            }
        else:
            raise NotImplementedError(
                f"Quantization with [{type(quant_method)}] is NOT supported")

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        activated = self.act_fn(gate_up)
        out, _ = self.down_proj(activated)
        return out
" + "Only silu is supported for now.") + + quant_method = self.gate_up_proj.quant_method + if isinstance(quant_method, UnquantizedLinearMethod): + self.act_fn = TorchairDeepseekV2SiluAndMul() + elif (isinstance(quant_method, AscendLinearMethod) and isinstance( + quant_method.quant_method, AscendW8A8DynamicLinearMethod)): + # TODO(sdmyzlp): Currently preserved as before: + # 1. The only quantization supported for silu is W8A8Dynamic + # 2. Output dtype of gate_up/down is fixed to be int32/bfloat16 + # + # Maybe one can implement a better and more general configuration + # scheme, e.g. by somehow passing around the tweaked `quant_config` + self.act_fn = TorchairDeepseekV2SiluAndMul( + # Use lazy binding, for `weight_scale_fp32` is accessible + # only after `process_weights_after_loading`. + weight_scale=lambda: self.gate_up_proj.weight_scale_fp32) + # To be consumed by AscendW8A8DynamicLinearMethod.apply() + self.gate_up_proj._ascend_quant_config = { + "output_dtype": torch.int32, + "pertoken_scale": False, + "return_scale": True, + } + self.down_proj._ascend_quant_config = { + "output_dtype": torch.bfloat16, + "pertoken_scale": True, + "return_scale": False, + } + else: + raise NotImplementedError( + f"Quantization with [{type(quant_method)}] is NOT supported") + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class TorchairDeepseekV2MoE(nn.Module): + + top_k: int + + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ): + super().__init__() + self.tp_size = get_tensor_model_parallel_world_size() + self.routed_scaling_factor = config.routed_scaling_factor + self.n_shared_experts = config.n_shared_experts + if self.tp_size > config.n_routed_experts: + raise ValueError( + f"Tensor parallel size {self.tp_size} is greater than " + f"the number of experts {config.n_routed_experts}.") + + if config.hidden_act != 
"silu": + raise ValueError(f"Unsupported activation: {config.hidden_act}. " + "Only silu is supported for now.") + + ascend_config = get_ascend_config() + self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_multistream_moe = \ + ascend_config.torchair_graph_config.enable_multistream_moe and \ + self.torchair_graph_enabled + + self.gate = ReplicatedLinear(config.hidden_size, + config.n_routed_experts, + bias=False, + quant_config=None, + prefix=f"{prefix}.gate") + if config.topk_method == "noaux_tc": + self.gate.e_score_correction_bias = nn.Parameter( + torch.empty(config.n_routed_experts)) + else: + self.gate.e_score_correction_bias = None + + self.experts = AscendFusedMoE( + num_experts=config.n_routed_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + reduce_results=False, + renormalize=config.norm_topk_prob, + quant_config=quant_config, + use_grouped_topk=True, + num_expert_group=config.n_group, + topk_group=config.topk_group, + prefix=f"{prefix}.experts", + scoring_func=config.scoring_func, + e_score_correction_bias=self.gate.e_score_correction_bias) + + if config.n_shared_experts is not None: + self.all_reduce_merge = self.experts.all_reduce_merge + reduce_results = not self.all_reduce_merge + intermediate_size = (config.moe_intermediate_size * + config.n_shared_experts) + enable_shared_expert_dp = ascend_config.enable_shared_expert_dp + self.shared_experts = TorchairDeepseekV2MLP( + hidden_size=config.hidden_size, + intermediate_size=intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + reduce_results=reduce_results, + force_replicate=self.enable_multistream_moe + or enable_shared_expert_dp, + prefix=f"{prefix}.shared_experts", + ) + else: + self.shared_experts = None # type: ignore + TorchairDeepseekV2MoE.top_k = config.num_experts_per_tok + + self.dp_size = get_dp_group().world_size + + self.tp_group = 
get_tp_group().device_group + self.tp_rank = get_tp_group().rank_in_group + self.ep_group = get_ep_group() + self.kv_consumer = None + transfer_config = get_current_vllm_config().kv_transfer_config + if transfer_config is not None: + self.kv_consumer = transfer_config.kv_role == "kv_consumer" + + self.params_dtype = torch.get_default_dtype() + self.rm_router_logits = self.experts.rm_router_logits + + def forward(self, + hidden_states: torch.Tensor, + attn_metadata: Optional[AttentionMetadata] = None, + replace_allreduce: bool = False) -> torch.Tensor: + + forward_context = get_forward_context() + # when profile runs, force experts to load balanced tokens + # to avoid high memory consumption on a single rank. + + enable_force_load_balance = forward_context.in_profile_run + + is_prefill = forward_context.with_prefill + + # If this node is kv_consumer, we force the moe always runs in decode path to make sure + # the behaviour aligned between dummy_run and normal model_execute. + if self.kv_consumer: + is_prefill = False + enable_force_load_balance = False + + # router_logits: (num_tokens, n_experts) + router_logits = None + if not self.rm_router_logits and not self.enable_multistream_moe: + router_logits, _ = self.gate(hidden_states) + + experts_hidden_states = self.experts( + hidden_states=hidden_states, + router_logits=router_logits, + is_prefill=is_prefill, + top_k=TorchairDeepseekV2MoE.top_k, + enable_force_load_balance=enable_force_load_balance, + shared_experts=self.shared_experts, + gate=self.gate, + replace_allreduce=replace_allreduce) + + hidden_states = ( + experts_hidden_states[0] * self.routed_scaling_factor + + experts_hidden_states[1]) + if self.all_reduce_merge: + # When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + + return hidden_states + + +class 
class TorchairDeepseekV2MLAAttention(DeepseekV2MLAAttention):
    """MLA attention specialised for the Ascend torchair graph path.

    Differences from the upstream vLLM implementation: NPU weight prefetch
    for multistream MLA, optional allreduce-replacing o_proj for multistream
    MoE / shared-expert DP, and a direct ``mla_attn.impl.forward`` call when
    the torchair graph is enabled.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        hidden_size: int,
        num_heads: int,
        qk_nope_head_dim: int,
        qk_rope_head_dim: int,
        v_head_dim: int,
        q_lora_rank: Optional[int],
        kv_lora_rank: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        # Deliberately bypass DeepseekV2MLAAttention.__init__ and rebuild all
        # submodules with the Ascend-specific linear variants.
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.qk_nope_head_dim = qk_nope_head_dim
        self.qk_rope_head_dim = qk_rope_head_dim
        self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
        self.v_head_dim = v_head_dim

        self.q_lora_rank = q_lora_rank
        self.kv_lora_rank = kv_lora_rank

        self.num_heads = num_heads
        self.tp_size = get_tensor_model_parallel_world_size()
        assert num_heads % self.tp_size == 0
        self.num_local_heads = num_heads // self.tp_size
        self.layers = config.num_hidden_layers
        self.first_k_dense_replace = config.first_k_dense_replace

        self.scaling = self.qk_head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.prefix = prefix
        # Layer index recovered from the prefix "...layers.<idx>.self_attn".
        self.debug_layer_idx = int(self.prefix.split(".")[-2])

        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
        self.enable_multistream_mla = \
            ascend_config.torchair_graph_config.enable_multistream_mla
        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

        if self.q_lora_rank is not None:
            # Low-rank query path: down-project, normalise, up-project.
            self.q_a_proj = ReplicatedLinear(self.hidden_size,
                                             self.q_lora_rank,
                                             bias=False,
                                             quant_config=quant_config,
                                             prefix=f"{prefix}.q_a_proj")
            self.q_a_layernorm = RMSNorm(self.q_lora_rank,
                                         eps=config.rms_norm_eps)
            self.q_b_proj = ColumnParallelLinear(q_lora_rank,
                                                 self.num_heads *
                                                 self.qk_head_dim,
                                                 bias=False,
                                                 quant_config=quant_config,
                                                 prefix=f"{prefix}.q_b_proj")
        else:
            self.q_proj = ColumnParallelLinear(self.hidden_size,
                                               self.num_heads *
                                               self.qk_head_dim,
                                               bias=False,
                                               quant_config=quant_config,
                                               prefix=f"{prefix}.q_proj")

        self.kv_a_proj_with_mqa = ReplicatedLinear(
            self.hidden_size,
            self.kv_lora_rank + self.qk_rope_head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_a_proj_with_mqa")
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
                                      eps=config.rms_norm_eps)
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.kv_b_proj")
        # MoE layers under multistream-MoE / shared-expert-DP use the
        # reduce-scatter o_proj variant; everything else uses plain all-reduce.
        if (config.n_routed_experts is not None
                and self.debug_layer_idx >= config.first_k_dense_replace
                and self.debug_layer_idx % config.moe_layer_freq == 0
                and (ascend_config.torchair_graph_config.enable_multistream_moe
                     or self.enable_shared_expert_dp)):
            self.o_proj = TorchairDeepseekV2RowParallelLinearReplaceAllreduce(
                self.num_heads * self.v_head_dim,
                self.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.o_proj")
        else:
            self.o_proj = TorchairDeepseekV2RowParallelLinear(
                self.num_heads * self.v_head_dim,
                self.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=f"{prefix}.o_proj")

        if rope_scaling:
            rope_scaling["rope_type"] = 'deepseek_yarn'
        self.rotary_emb = get_rope(qk_rope_head_dim,
                                   rotary_dim=qk_rope_head_dim,
                                   max_position=max_position_embeddings,
                                   base=rope_theta,
                                   rope_scaling=rope_scaling,
                                   is_neox_style=False)
        if rope_scaling:
            mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
            scaling_factor = rope_scaling["factor"]
            mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
            self.scaling = self.scaling * mscale * mscale

        # In the MLA backend, kv_cache includes both k_c and
        # pe (i.e. decoupled position embeddings). In particular,
        # the concat_and_cache_mla op requires
        #     k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
        # i.e.
        #     kv_lora_rank + qk_rope_head_dim == head_size
        self.mla_attn = Attention(
            num_heads=self.num_local_heads,
            head_size=self.kv_lora_rank + self.qk_rope_head_dim,
            scale=self.scaling,
            num_kv_heads=1,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
            use_mla=True,
            # MLA Args
            q_lora_rank=self.q_lora_rank,
            kv_lora_rank=self.kv_lora_rank,
            qk_nope_head_dim=self.qk_nope_head_dim,
            qk_rope_head_dim=self.qk_rope_head_dim,
            qk_head_dim=self.qk_head_dim,
            v_head_dim=self.v_head_dim,
            rotary_emb=self.rotary_emb,
            q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
            kv_a_proj_with_mqa=self.kv_a_proj_with_mqa,
            kv_a_layernorm=self.kv_a_layernorm,
            kv_b_proj=self.kv_b_proj,
            o_proj=self.o_proj,
        )

    def forward(
            self,
            positions: torch.Tensor,
            hidden_states: torch.Tensor,
            kv_cache: Optional[torch.Tensor] = None,
            attn_metadata: Optional[AttentionMetadata] = None) -> torch.Tensor:
        forward_context = get_forward_context()
        # Multistream MLA only pays off for decode-only batches.
        enable_multistream_mla = (self.enable_multistream_mla
                                  and attn_metadata is not None
                                  and not forward_context.with_prefill
                                  and attn_metadata.num_decodes > 0)
        forward_kwargs = {"enable_multistream_mla": enable_multistream_mla}
        if self.q_lora_rank is not None:
            # Prefetch q_a_proj weights to hide HBM latency on the NPU.
            npu_prefetch(self.q_a_proj.weight,
                         hidden_states,
                         enabled=enable_multistream_mla)
            ckq = self.q_a_proj(hidden_states)[0]
            hidden_states_or_q_c = self.q_a_layernorm(ckq)
            forward_kwargs['ckq'] = ckq
        else:
            hidden_states_or_q_c = hidden_states
        if self.torchair_graph_enabled:
            # Torchair graph mode: call the attention impl directly with a
            # pre-allocated output buffer.
            output_shape = hidden_states.shape
            output = torch.empty(output_shape,
                                 dtype=hidden_states_or_q_c.dtype,
                                 device=hidden_states_or_q_c.device)
            forward_kwargs['output'] = output
            output = self.mla_attn.impl.forward(self.mla_attn,
                                                hidden_states_or_q_c,
                                                hidden_states, None, kv_cache,
                                                attn_metadata,
                                                **forward_kwargs)
            return output.view(-1, output_shape[-1])
        else:
            kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
            if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
                # Shared-expert DP keeps activations scattered across TP
                # ranks between layers; gather them back before attention.
                hidden_states_or_q_c = get_tp_group().all_gather(
                    hidden_states_or_q_c, 0)
                kv_no_split = get_tp_group().all_gather(kv_no_split, 0)

            kv_c, k_pe = kv_no_split.split(
                [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
            kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
            if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
                output_shape = hidden_states.shape
            else:
                # ceil(num_tokens / tp_size) rows per rank after scatter.
                num_tokens = hidden_states_or_q_c.shape[0]
                rows = num_tokens // self.tp_size
                if num_tokens % self.tp_size:
                    rows += 1
                output_shape = (rows, hidden_states.shape[1])
            return self.mla_attn(hidden_states_or_q_c,
                                 kv_c_normed,
                                 k_pe,
                                 output_shape=output_shape)
class TorchairDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
    """DeepSeek-V2 decoder layer for the torchair graph path, wiring the
    Ascend MLA attention and MoE/MLP variants together with the extra
    gather/scatter hops used by multistream MoE and shared-expert DP."""

    def __init__(
        self,
        config: PretrainedConfig,
        prefix: str,
        model_config: ModelConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        # Bypass the parent __init__ so Ascend-specific submodules are used.
        nn.Module.__init__(self)
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        # DecoderLayers are created with `make_layers` which passes the prefix
        # with the layer's index.
        layer_idx = int(prefix.split(sep='.')[-1])
        self.layer_idx = layer_idx
        self.layers = config.num_hidden_layers
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tp_group().rank_in_group
        ascend_config = get_ascend_config()
        # TODO: enable mla in vllm-ascend
        attn_cls = (TorchairDeepseekV2MLAAttention
                    if model_config.use_mla else DeepseekV2Attention)
        self.self_attn = attn_cls(
            config=config,
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            qk_nope_head_dim=config.qk_nope_head_dim,
            qk_rope_head_dim=config.qk_rope_head_dim,
            v_head_dim=config.v_head_dim,
            q_lora_rank=config.q_lora_rank
            if hasattr(config, "q_lora_rank") else None,
            kv_lora_rank=config.kv_lora_rank,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )

        # MoE layers start at first_k_dense_replace with moe_layer_freq stride;
        # earlier layers stay dense.
        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = TorchairDeepseekV2MoE(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
            self.mla_moe_communication = ascend_config.torchair_graph_config.enable_multistream_moe \
                and model_config.use_mla and self.tp_size > 1
        else:
            self.mlp = TorchairDeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
            self.mla_moe_communication = False
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)
        self.routed_scaling_factor = config.routed_scaling_factor
        self.first_k_dense_replace = config.first_k_dense_replace
        self.tp_group = get_tp_group().device_group
        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
        kv_cache: Optional[torch.Tensor] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        replace_allreduce: bool = False,
    ) -> torch.Tensor:
        # The gather/scatter dance below only applies to decode batches.
        if attn_metadata is not None and attn_metadata.num_decodes > 0:
            mla_moe_communication = self.mla_moe_communication and replace_allreduce
        else:
            mla_moe_communication = False

        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            prev_hidden, prev_residual = hidden_states, residual
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
            # Dispose hidden_states and residual from the previous layer
            # to save npu memory because they're no longer used.
            dispose_tensor(prev_hidden)
            dispose_tensor(prev_residual)
        if mla_moe_communication and self.layer_idx > self.first_k_dense_replace:
            hidden_states = tensor_model_parallel_all_gather(hidden_states,
                                                             dim=0)

        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
        )

        if mla_moe_communication and residual.shape[0] != hidden_states.shape[
                0]:
            # Attention reduce-scattered its output; scatter residual too.
            residual = torch.tensor_split(residual, self.tp_size,
                                          dim=0)[self.tp_rank]

        if hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # We scale both hidden_states and residual before
            # rmsnorm, and rmsnorm result would not affect by scale.
            hidden_states *= 1. / self.routed_scaling_factor
            if self.layer_idx == 0:
                # The residual is shared by all layers, we only scale it on
                # first layer.
                residual *= 1. / self.routed_scaling_factor

        tp_size = get_tensor_model_parallel_world_size()
        if self.enable_shared_expert_dp and (
                self.layer_idx == self.first_k_dense_replace
                or self.layer_idx == self.layers) and tp_size > 1:
            # Entering the shared-expert-DP region: pad to a multiple of
            # tp_size and keep only this rank's chunk of the residual.
            num_tokens, _ = residual.shape
            if num_tokens % tp_size:
                residual = nn.functional.pad(residual,
                                             (0, 0, 0, -num_tokens % tp_size))
            tp_rank = get_tensor_model_parallel_rank()
            residual = torch.tensor_split(residual, tp_size, dim=0)[tp_rank]

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)

        if isinstance(self.mlp, TorchairDeepseekV2MoE):
            hidden_states = self.mlp(hidden_states,
                                     attn_metadata,
                                     replace_allreduce=mla_moe_communication)
        else:
            hidden_states = self.mlp(hidden_states)

        if isinstance(self.mlp, TorchairDeepseekV2MLP
                      ) and hidden_states.dtype == torch.float16:
            # Fix FP16 overflow
            # Scaling the DeepseekV2MLP output, it is the input of
            # input_layernorm of next decoder layer.
            # The scaling of DeepseekV2MOE output would be done in the forward
            # of DeepseekV2MOE
            hidden_states *= 1. / self.routed_scaling_factor
        if mla_moe_communication and self.layer_idx == self.layers - 1:
            hidden_states = tensor_model_parallel_all_gather(hidden_states,
                                                             dim=0)
            residual = tensor_model_parallel_all_gather(residual, dim=0)

        # for last layer of main model and mtp layer.
        if self.enable_shared_expert_dp and self.layer_idx >= (
                self.layers - 1) and tp_size > 1:
            hidden_states = get_tp_group().all_gather(hidden_states, 0)
            residual = get_tp_group().all_gather(residual, 0)

        attn_metadata = get_forward_context().attn_metadata
        if attn_metadata is not None:
            num_tokens = attn_metadata.num_actual_tokens
        else:
            num_tokens = hidden_states.shape[0]

        # Drop padding rows introduced by the gather/scatter round-trips.
        if num_tokens < hidden_states.shape[0]:
            hidden_states = hidden_states[:num_tokens]
            residual = residual[:num_tokens]

        return hidden_states, residual
class TorchairDeepseekV2Model(nn.Module):
    """Embedding + stacked TorchairDeepseekV2DecoderLayers + final norm,
    with pipeline-parallel missing-layer placeholders."""

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.tp_size = get_tensor_model_parallel_world_size()

        if get_pp_group().is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=f"{prefix}.embed_tokens")
        else:
            self.embed_tokens = PPMissingLayer()

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: TorchairDeepseekV2DecoderLayer(
                config,
                prefix,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
            ),
            prefix=f"{prefix}.layers")

        if get_pp_group().is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        # MoE allreduce replacement is only valid when the token count splits
        # evenly across TP ranks.
        replace_allreduce = hidden_states.shape[0] % self.tp_size == 0

        for idx in range(self.start_layer, self.end_layer):
            layer = self.layers[idx]
            layer_kv = (kv_caches[idx - self.start_layer]
                        if kv_caches is not None else None)
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
                layer_kv,
                attn_metadata,
                replace_allreduce=replace_allreduce)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })

        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class TorchairDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
    """Torchair DeepSeek-V2 causal LM wrapper with custom weight loading."""

    # add `packed_modules_mapping` in `DeepseekV2ForCausalLM` to support
    # weight merging
    packed_modules_mapping = {
        "gate_up_proj": ["gate_proj", "up_proj"],
        "experts":
        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        # Bypass DeepseekV2ForCausalLM.__init__ to install the torchair model.
        nn.Module.__init__(self)
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = TorchairDeepseekV2Model(vllm_config=vllm_config,
                                             prefix=maybe_prefix(
                                                 prefix, "model"))
        if get_pp_group().is_last_rank:
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size,
                                          quant_config=quant_config,
                                          prefix=maybe_prefix(
                                              prefix, "lm_head"))
        else:
            self.lm_head = PPMissingLayer()
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.sampler = get_sampler()
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    # NOTE: This `load_weights` is mainly copied from
    # https://github.com/vllm-project/vllm/commit/07b8fae219b1fff51ef115c38c44b51395be5bb5
    # to fix CI, and it is different from the implementation in main
    # TODO: support eplb style load_weights
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint weights into stacked / expert / plain params."""
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        expert_params_mapping = AscendFusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.config.n_routed_experts)

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "module" in name:
                continue

            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            if spec_layer is not None:
                continue  # skip spec decode layers for main model

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if (("mlp.experts." in name) and name not in params_dict):
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(param,
                                  loaded_weight,
                                  name,
                                  shard_id=shard_id,
                                  expert_id=expert_id,
                                  return_success=False)
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue

                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: Optional[List[torch.Tensor]] = None,
        attn_metadata: Optional[AttentionMetadata] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        return self.model(input_ids, positions, kv_caches, attn_metadata,
                          intermediate_tensors, inputs_embeds)
from vllm_ascend.torchair.models.torchair_deepseek_v2 import \
    TorchairDeepseekV2ForCausalLM


class TorchairDeepseekV3ForCausalLM(TorchairDeepseekV2ForCausalLM):
    """DeepSeek-V3 for torchair graph mode.

    The torchair V3 model shares its implementation with the torchair V2
    model, so this subclass adds nothing beyond the architecture name
    used for registry lookup.
    """
    pass


def register_torchair_model():
    """Register torchair model implementations with vLLM's ModelRegistry.

    Overrides the architecture names normally resolved by vLLM so that,
    when the torchair model runner is active, DeepSeek V2/V3 and the MTP
    speculative-decode model load the torchair-specific classes instead
    of the default ones.
    """
    # Imported lazily to avoid pulling in vLLM at module-import time.
    from vllm import ModelRegistry

    ModelRegistry.register_model(
        "DeepSeekMTPModel",
        "vllm_ascend.torchair.models.torchair_deepseek_mtp:TorchairDeepSeekMTP"
    )

    ModelRegistry.register_model(
        "DeepseekV2ForCausalLM",
        "vllm_ascend.torchair.models.torchair_deepseek_v2:TorchairDeepseekV2ForCausalLM"
    )

    ModelRegistry.register_model(
        "DeepseekV3ForCausalLM",
        "vllm_ascend.torchair.models.torchair_deepseek_v3:TorchairDeepseekV3ForCausalLM"
    )