diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py index ff814c0..c55234b 100644 --- a/tests/ut/attention/test_mla_v1.py +++ b/tests/ut/attention/test_mla_v1.py @@ -306,6 +306,7 @@ class TestAscendMLAImpl(TestBase): "kv_b_proj": MagicMock(), "o_proj": MagicMock(), "kv_a_proj_with_mqa": MagicMock(), + "fused_qkv_a_proj": MagicMock(), "kv_a_layernorm": kv_a_layernorm, } @@ -511,7 +512,6 @@ class TestAscendMLAImpl(TestBase): attn_metadata.prefill.cos = torch.randn(2, 64) attn_metadata.prefill.sin = torch.randn(2, 64) - self.impl.q_a_proj = MagicMock() self.impl.q_a_layernorm = MagicMock() self.impl.q_a_layernorm.return_value = torch.randn( attn_metadata.num_actual_tokens, self.impl.num_heads, @@ -519,7 +519,14 @@ class TestAscendMLAImpl(TestBase): self.impl.kv_a_proj_with_mqa = MagicMock() self.impl.kv_a_proj_with_mqa.return_value = [ torch.randn(num_prefill_tokens, self.impl.num_heads, - self.impl.qk_nope_head_dim + self.impl.kv_lora_rank) + self.impl.qk_rope_head_dim + self.impl.kv_lora_rank) + ] + self.impl.fused_qkv_a_proj = MagicMock() + self.impl.fused_qkv_a_proj.return_value = [ + torch.randn( + num_prefill_tokens, self.impl.num_heads, + self.impl.qk_rope_head_dim + self.impl.kv_lora_rank + + self.impl.q_lora_rank) ] self.impl.q_proj = MagicMock() self.impl.q_proj.return_value = [ diff --git a/tests/ut/models/test_deepseek_mtp.py b/tests/ut/models/test_deepseek_mtp.py deleted file mode 100644 index 6e2d868..0000000 --- a/tests/ut/models/test_deepseek_mtp.py +++ /dev/null @@ -1,198 +0,0 @@ -import pytest -import torch -from pytest_mock import MockerFixture -from transformers import PretrainedConfig -from vllm.config import CacheConfig, ModelConfig, VllmConfig - -from tests.ut.base import PytestBase -from vllm_ascend.models.deepseek_mtp import ( - CustomDeepSeekMTP, CustomDeepSeekMultiTokenPredictor, - CustomDeepSeekMultiTokenPredictorLayer) - - -class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase): - - @pytest.fixture - def setup_mtp_layer(self, mocker: MockerFixture, vllm_config: VllmConfig, - mock_distributed): - config = PretrainedConfig(vocab_size=1000, - hidden_size=768, - rms_norm_eps=1e-5) - mocker.patch("vllm_ascend.models.deepseek_mtp.get_current_vllm_config", - return_value=vllm_config) - mocker.patch( - "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__", - return_value=None) - mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__", - return_value=None) - mocker.patch( - "vllm.model_executor.models.deepseek_mtp.SharedHead.__init__", - return_value=None) - mocker.patch( - "vllm_ascend.models.deepseek_mtp.CustomDeepSeekShareHead.__init__", - return_value=None) - mocker_deepseek_v2_decode_layer = mocker.patch( - "vllm.model_executor.models.deepseek_v2.DeepseekV2DecoderLayer.__init__", - return_value=None) - mocker.patch( - "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", - return_value=None) - - mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "0", None) - mocker_deepseek_v2_decode_layer.assert_called_once() - return mtp_layer - - def test_init(self, mocker: MockerFixture, setup_mtp_layer): - mtp_layer = setup_mtp_layer - assert isinstance(mtp_layer, CustomDeepSeekMultiTokenPredictorLayer) - - def test_forward(self, mocker: MockerFixture, setup_mtp_layer): - mtp_layer = setup_mtp_layer - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - 
mocker.patch.object(mtp_layer, - 'eh_proj', - return_value=torch.randn(2, 3, 768)) - mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768)) - mocker.patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad", - lambda x, label: x) - mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768), - torch.randn(2, 3, 768)) - - input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]]) - positions = torch.tensor([[0, 1, 2], [0, 1, 2]]) - kv_cache = torch.randn(2, 3, 768) - previous_hidden_states = torch.randn(2, 3, 768) - inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]]) - - output = mtp_layer(input_ids, positions, kv_cache, None, - previous_hidden_states, inputs_embeds, 0) - assert output.shape == (2, 3, 768) - - -class TestCustomDeepSeekMultiTokenPredictor(PytestBase): - - @pytest.fixture - def setup_predictor(self, mocker: MockerFixture): - mock_vllm_config = mocker.MagicMock(spec=VllmConfig) - mock_model_config = mocker.MagicMock(spec=ModelConfig) - mock_hf_config = mocker.MagicMock() - mock_hf_config.num_hidden_layers = 12 - mock_hf_config.num_nextn_predict_layers = 3 - mock_hf_config.vocab_size = 30000 - mock_model_config.hf_config = mock_hf_config - mock_vllm_config.model_config = mock_model_config - mock_vllm_config.cache_config = CacheConfig() - mock_vllm_config.quant_config = mocker.MagicMock() - mocker.patch( - "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__", - return_value=None) - mocker.patch( - "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__", - return_value=None) - mocker.patch( - "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", - return_value=None) - mocker.patch("vllm_ascend.utils.get_ascend_config", - return_value=mocker.Mock()) - - predictor = CustomDeepSeekMultiTokenPredictor( - vllm_config=mock_vllm_config) - return predictor - - def test_init(self, mocker: MockerFixture, setup_predictor): - predictor = setup_predictor - assert predictor.num_mtp_layers == 3 - assert isinstance(predictor, CustomDeepSeekMultiTokenPredictor) - - @pytest.mark.parametrize( - 'kv_caches, inputs_embeds', - [(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))]) - def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches, - inputs_embeds): - predictor = setup_predictor - mock_layer = mocker.MagicMock() - mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0]) - predictor.layers_list = [mock_layer] - - # todo: need or not? 
- # predictor.num_mtp_layers = 1 - input_ids = torch.tensor([[1, 2, 3]]) - positions = torch.tensor([[0, 1, 2]]) - mocker.patch( - "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__", - return_value=torch.tensor([[1.0, 2.0, 3.0]])) - output = predictor.forward(input_ids, positions, kv_caches, None, None, - inputs_embeds, 0) - mock_layer.assert_called_once() - assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0])) - - def test_compute_logits(self, mocker: MockerFixture, setup_predictor): - hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]]) - predictor = setup_predictor - - mock_layer = mocker.MagicMock() - mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0]) - predictor.layers_list = [mock_layer] - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - mocker.patch( - "vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__", - return_value=None) - predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0]) - - result_logits = predictor.compute_logits(hidden_states=hidden_states, - sampling_metadata=None) - predictor.logits_processor.assert_called_once() - assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0])) - - -class TestCustomDeepSeekMTP(PytestBase): - - @pytest.fixture - def setup_mtp(self, mocker: MockerFixture): - vllm_config = mocker.MagicMock() - vllm_config.model_config.hf_config.num_hidden_layers = 12 - vllm_config.model_config.hf_config.num_nextn_predict_layers = 3 - vllm_config.cache_config = mocker.MagicMock() - vllm_config.quant_config = mocker.MagicMock() - - mocker.patch("torch.nn.Module.__setattr__") - mocker.patch("torch.nn.Module.__getattr__") - mocker.patch("torch.nn.Module.__delattr__") - mocker.patch( - "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__", - return_value=None) - mocker.patch( - "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__", - return_value=None) - mocker.patch( - "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__", - return_value=None) - mocker.patch("vllm_ascend.utils.get_ascend_config", - return_value=mocker.Mock()) - - mtp = CustomDeepSeekMTP(vllm_config=vllm_config) - return mtp - - def test_init(self, mocker: MockerFixture, setup_mtp): - mtp = setup_mtp - assert isinstance(mtp, CustomDeepSeekMTP) - - def test_forward(self, mocker: MockerFixture, setup_mtp): - mocker.patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad", - lambda x, label: x) - input_ids = torch.tensor([[1, 2, 3]]) - positions = torch.tensor([[0, 1, 2]]) - kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])] - previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]]) - inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]]) - spec_step_idx = 0 - setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]]) - - output = setup_mtp.forward(input_ids, positions, kv_caches, None, - previous_hidden_states, inputs_embeds, - spec_step_idx) - assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]])) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index 356165e..f6627ef 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend, from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import get_tensor_model_parallel_world_size from vllm.forward_context import ForwardContext, 
get_forward_context
+from vllm.logger import logger
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -29,6 +30,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                is_enable_nz)
 from vllm_ascend.worker.npu_input_batch import InputBatch
@@ -557,6 +559,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.prefill_mask = None
 
         self.speculative_config = vllm_config.speculative_config
+        self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
 
     def _v_up_proj(self, x):
         if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
@@ -654,7 +657,17 @@ class AscendMLAImpl(MLAAttentionImpl):
         # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
         # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
 
-        if envs.VLLM_ASCEND_ENABLE_MLAPO:
+        # Currently mlapo only supports W8A8 quantization in the MLA scenario.
+        # TODO(whx): modify this limitation when mlapo supports floating point
+        if self.fused_qkv_a_proj is None or not isinstance(
+                getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
+                        None), AscendW8A8LinearMethod):
+            self.enable_mlapo = False
+            logger.warning(
+                "Currently mlapo only supports W8A8 quantization in the MLA "
+                "scenario. Some layers in your model are not quantized with "
+                "W8A8, so mlapo is disabled for these layers.")
+        if self.enable_mlapo:
             self._process_weights_for_fused_mlapo(act_dtype)
 
     def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
@@ -1229,7 +1242,7 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         # MLA Preprocess
         forward_context = get_forward_context()
-        if (envs.VLLM_ASCEND_ENABLE_MLAPO and
+        if (self.enable_mlapo and
                 (attn_metadata is None or not forward_context.with_prefill)):
             decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
                 hidden_states, kv_cache, attn_metadata)
diff --git a/vllm_ascend/models/__init__.py b/vllm_ascend/models/__init__.py
index 15cddd0..2ebbdeb 100644
--- a/vllm_ascend/models/__init__.py
+++ b/vllm_ascend/models/__init__.py
@@ -33,10 +33,6 @@ def register_model():
         "DeepseekV32ForCausalLM",
         "vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM")
 
-    ModelRegistry.register_model(
-        "DeepSeekMTPModel",
-        "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
-
     # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
     # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
     ModelRegistry.register_model(
diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py
deleted file mode 100644
index 7fbec3b..0000000
--- a/vllm_ascend/models/deepseek_mtp.py
+++ /dev/null
@@ -1,209 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# Adapted from vllm/model_executor/models/deepseek_mtp.py
-# Copyright 2023 The vLLM team.
-#
-# This file is a part of the vllm-ascend project.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List, Optional - -import torch -import torch.nn as nn -from transformers import PretrainedConfig -from vllm.attention.backends.abstract import AttentionMetadata -from vllm.compilation.decorators import support_torch_compile -from vllm.config import (CacheConfig, ModelConfig, VllmConfig, - get_current_vllm_config) -from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.models.deepseek_mtp import ( - DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer, - SharedHead) -from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer -from vllm.model_executor.models.utils import maybe_prefix -from vllm.sequence import IntermediateTensors - - -class CustomDeepSeekShareHead(SharedHead): - - def __init__(self, - config: PretrainedConfig, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = "") -> None: - nn.Module.__init__(self) - self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config, - prefix=maybe_prefix(prefix, "head")) - - -class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer): - - def __init__( - self, - config: PretrainedConfig, - prefix: str, - model_config: ModelConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - ) -> None: - nn.Module.__init__(self) - vllm_config = get_current_vllm_config() - - self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) - self.eh_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=False) - self.shared_head = CustomDeepSeekShareHead(config=config, - quant_config=quant_config, - prefix=maybe_prefix( - prefix, "shared_head")) - self.mtp_block = DeepseekV2DecoderLayer(vllm_config=vllm_config, - prefix=prefix) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_cache: torch.Tensor, - attn_metadata: AttentionMetadata, - previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - spec_step_index: int = 0, - ) -> torch.Tensor: - assert inputs_embeds is not None - inputs_embeds = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - inputs_embeds, True) - # masking inputs at position 0, as not needed by MTP - inputs_embeds = torch.where((positions == 0).unsqueeze(-1), - torch.zeros_like(inputs_embeds), - inputs_embeds) - inputs_embeds = self.enorm(inputs_embeds) - previous_hidden_states = self.hnorm(previous_hidden_states) - - hidden_states = self.eh_proj( - torch.cat([inputs_embeds, previous_hidden_states], dim=-1)) - - hidden_states, residual = self.mtp_block(positions=positions, - hidden_states=hidden_states, - residual=None) - hidden_states = residual + hidden_states - 
return hidden_states - - -class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - config = vllm_config.model_config.hf_config - self.mtp_start_layer_idx = config.num_hidden_layers - self.num_mtp_layers = config.num_nextn_predict_layers - # to map the exact layer index from weights - self.layers = torch.nn.ModuleDict({ - str(idx): - CustomDeepSeekMultiTokenPredictorLayer( - config, - f"{prefix}.layers.{idx}", - model_config=vllm_config.model_config, - cache_config=vllm_config.cache_config, - quant_config=vllm_config.quant_config, - ) - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - }) - self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, - config.hidden_size, - ) - - # Note: torch._dynamo.exc.Unsupported: builtin: str - self.layers_list = [ - self.layers[str(idx)] - for idx in range(self.mtp_start_layer_idx, - self.mtp_start_layer_idx + self.num_mtp_layers) - ] - self.logits_processor = LogitsProcessor(config.vocab_size) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: torch.Tensor, - attn_metadata: AttentionMetadata, - previous_hidden_states: torch.Tensor, - inputs_embeds: Optional[torch.Tensor] = None, - spec_step_idx: int = 0, - ) -> torch.Tensor: - if inputs_embeds is None: - inputs_embeds = self.embed_tokens(input_ids) - current_step_idx = (spec_step_idx % self.num_mtp_layers) - step_kv_cache = kv_caches[ - current_step_idx] if kv_caches is not None else None - return self.layers_list[current_step_idx]( - input_ids, - positions, - step_kv_cache, - attn_metadata, - previous_hidden_states, - inputs_embeds, - current_step_idx, - ) - - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata=None, # type: ignore - spec_step_idx: int = 0, - ) -> torch.Tensor: - current_step_idx = (spec_step_idx % self.num_mtp_layers) - mtp_layer = self.layers_list[current_step_idx] - logits = self.logits_processor(mtp_layer.shared_head.head, - mtp_layer.shared_head(hidden_states), - sampling_metadata) - return logits - - -@support_torch_compile -class CustomDeepSeekMTP(DeepSeekMTP): - - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - nn.Module.__init__(self) - self.config = vllm_config.model_config.hf_config - self.model = CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_config, - prefix=maybe_prefix( - prefix, "model")) - - def forward( - self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: Optional[List[torch.Tensor]] = None, - attn_metadata: Optional[AttentionMetadata] = None, - previous_hidden_states: Optional[torch.Tensor] = None, - intermediate_tensors: Optional[IntermediateTensors] = None, - inputs_embeds: Optional[torch.Tensor] = None, - spec_step_idx: int = 0, - ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, previous_hidden_states, - inputs_embeds, spec_step_idx) - hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( - hidden_states, True) - return hidden_states diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py index c8d8f1f..dea8c0d 100644 --- a/vllm_ascend/quantization/quant_config.py +++ b/vllm_ascend/quantization/quant_config.py @@ -56,6 +56,14 @@ class AscendQuantConfig(QuantizationConfig): def __init__(self, quant_config: Dict[str, Any]): super().__init__() self.quant_description = quant_config + # TODO(whx): 
remove this adaptation after adding "shared_head" + # to prefix of DeepSeekShareHead in vLLM. + extra_quant_dict = {} + for k in self.quant_description.keys(): + if "shared_head" in k: + new_k = k.replace(".shared_head.", ".") + extra_quant_dict[new_k] = self.quant_description[k] + self.quant_description.update(extra_quant_dict) def __repr__(self) -> str: return "AscendQuantConfig:\n" + super().__repr__() diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index a3baabf..97b094e 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -11,6 +11,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context from vllm.model_executor.model_loader import get_model_loader from vllm.model_executor.model_loader.utils import ( process_weights_after_loading, set_default_torch_dtype) +from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.spec_decode.metadata import SpecDecodeMetadata @@ -18,7 +19,6 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType from vllm_ascend.torchair.models.torchair_deepseek_mtp import \ TorchairDeepSeekMTP @@ -86,7 +86,7 @@ class MtpProposer(Proposer): self.model = TorchairDeepSeekMTP( vllm_config=self.vllm_config).to(target_device) else: - self.model = CustomDeepSeekMTP( + self.model = DeepSeekMTP( vllm_config=self.vllm_config).to(target_device) draft_attn_layer_names = ( @@ -184,7 +184,7 @@ class MtpProposer(Proposer): else: self.model(input_ids=input_ids, positions=positions, - previous_hidden_states=previous_hidden_states) + hidden_states=previous_hidden_states) if with_prefill: break @@ -470,9 +470,8 @@ class MtpProposer(Proposer): hidden_states = self.model( input_ids=self.input_ids[:num_input_tokens], positions=self.positions[:num_input_tokens], - previous_hidden_states=self. - hidden_states[:num_input_tokens], - kv_caches=self.runner.kv_caches[-1:]) + hidden_states=self.hidden_states[:num_input_tokens] + ) num_indices = last_token_indices.shape[0] if lmhead_tp_enable(): @@ -485,7 +484,7 @@ class MtpProposer(Proposer): (0, max_num_reqs_across_dp - num_indices)) sample_hidden_states = hidden_states[last_token_indices] - logits = self.model.compute_logits(sample_hidden_states, None) + logits = self.model.compute_logits(sample_hidden_states) if lmhead_tp_enable() and num_indices < logits.shape[0]: logits = logits[:num_indices] draft_token_ids = logits.argmax(dim=-1)
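
A minimal standalone sketch of the shared_head key remapping that AscendQuantConfig.__init__ now performs; the key names and values below are hypothetical, the real quant_description is loaded from the Ascend quantization checkpoint:

    # Sketch of the shared_head prefix adaptation: keys stored under
    # ".shared_head." are duplicated without that segment so they also match
    # vLLM's DeepSeekMTP prefix layout. Keys and values are illustrative only.
    quant_description = {
        "model.layers.61.shared_head.head.weight": "W8A8",
        "model.layers.61.self_attn.q_a_proj.weight": "W8A8",
    }

    extra_quant_dict = {}
    for k in quant_description.keys():
        if "shared_head" in k:
            extra_quant_dict[k.replace(".shared_head.", ".")] = quant_description[k]
    quant_description.update(extra_quant_dict)

    # Both the original and the remapped key now resolve to the same entry.
    assert quant_description["model.layers.61.head.weight"] == "W8A8"

The dict is only read inside the loop and updated afterwards, so iterating over keys() while building the extra mapping is safe.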