[Model][2/N] Remove deepseek_mtp modeling. (#3561)
This PR is step 2 of the DeepSeek model refactoring: it removes the custom deepseek_mtp modeling from vllm-ascend, so the MTP path now relies on vLLM's native DeepSeekMTP implementation.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: whx-sjtu <2952154980@qq.com>
@@ -306,6 +306,7 @@ class TestAscendMLAImpl(TestBase):
             "kv_b_proj": MagicMock(),
             "o_proj": MagicMock(),
             "kv_a_proj_with_mqa": MagicMock(),
+            "fused_qkv_a_proj": MagicMock(),
             "kv_a_layernorm": kv_a_layernorm,
         }
 
@@ -511,7 +512,6 @@ class TestAscendMLAImpl(TestBase):
         attn_metadata.prefill.cos = torch.randn(2, 64)
         attn_metadata.prefill.sin = torch.randn(2, 64)
 
-        self.impl.q_a_proj = MagicMock()
         self.impl.q_a_layernorm = MagicMock()
         self.impl.q_a_layernorm.return_value = torch.randn(
             attn_metadata.num_actual_tokens, self.impl.num_heads,
@@ -519,7 +519,14 @@ class TestAscendMLAImpl(TestBase):
         self.impl.kv_a_proj_with_mqa = MagicMock()
         self.impl.kv_a_proj_with_mqa.return_value = [
             torch.randn(num_prefill_tokens, self.impl.num_heads,
-                        self.impl.qk_nope_head_dim + self.impl.kv_lora_rank)
+                        self.impl.qk_rope_head_dim + self.impl.kv_lora_rank)
+        ]
+        self.impl.fused_qkv_a_proj = MagicMock()
+        self.impl.fused_qkv_a_proj.return_value = [
+            torch.randn(
+                num_prefill_tokens, self.impl.num_heads,
+                self.impl.qk_rope_head_dim + self.impl.kv_lora_rank +
+                self.impl.q_lora_rank)
         ]
         self.impl.q_proj = MagicMock()
         self.impl.q_proj.return_value = [
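(The mocked fused_qkv_a_proj above reflects that the separate q_a/kv_a down-projections are fused into one GEMM whose output packs all components along the last dimension. A minimal sketch of how such a packed output is typically split; the rank values here are illustrative, not taken from this PR:)

import torch

# Illustrative ranks only; real values come from the model config.
q_lora_rank, kv_lora_rank, qk_rope_head_dim = 1536, 512, 64
num_tokens = 4
fused_out = torch.randn(
    num_tokens, q_lora_rank + kv_lora_rank + qk_rope_head_dim)

# Recover the three logical projections from the packed last dimension.
q_a, kv_a, k_pe = fused_out.split(
    [q_lora_rank, kv_lora_rank, qk_rope_head_dim], dim=-1)
assert q_a.shape == (num_tokens, q_lora_rank)
assert k_pe.shape == (num_tokens, qk_rope_head_dim)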
@@ -1,198 +0,0 @@
-import pytest
-import torch
-from pytest_mock import MockerFixture
-from transformers import PretrainedConfig
-from vllm.config import CacheConfig, ModelConfig, VllmConfig
-
-from tests.ut.base import PytestBase
-from vllm_ascend.models.deepseek_mtp import (
-    CustomDeepSeekMTP, CustomDeepSeekMultiTokenPredictor,
-    CustomDeepSeekMultiTokenPredictorLayer)
-
-
-class TestCustomDeepSeekMultiTokenPredictorLayer(PytestBase):
-
-    @pytest.fixture
-    def setup_mtp_layer(self, mocker: MockerFixture, vllm_config: VllmConfig,
-                        mock_distributed):
-        config = PretrainedConfig(vocab_size=1000,
-                                  hidden_size=768,
-                                  rms_norm_eps=1e-5)
-        mocker.patch("vllm_ascend.models.deepseek_mtp.get_current_vllm_config",
-                     return_value=vllm_config)
-        mocker.patch(
-            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
-            return_value=None)
-        mocker.patch("vllm.model_executor.layers.layernorm.RMSNorm.__init__",
-                     return_value=None)
-        mocker.patch(
-            "vllm.model_executor.models.deepseek_mtp.SharedHead.__init__",
-            return_value=None)
-        mocker.patch(
-            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekShareHead.__init__",
-            return_value=None)
-        mocker_deepseek_v2_decode_layer = mocker.patch(
-            "vllm.model_executor.models.deepseek_v2.DeepseekV2DecoderLayer.__init__",
-            return_value=None)
-        mocker.patch(
-            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
-            return_value=None)
-
-        mtp_layer = CustomDeepSeekMultiTokenPredictorLayer(config, "0", None)
-        mocker_deepseek_v2_decode_layer.assert_called_once()
-        return mtp_layer
-
-    def test_init(self, mocker: MockerFixture, setup_mtp_layer):
-        mtp_layer = setup_mtp_layer
-        assert isinstance(mtp_layer, CustomDeepSeekMultiTokenPredictorLayer)
-
-    def test_forward(self, mocker: MockerFixture, setup_mtp_layer):
-        mtp_layer = setup_mtp_layer
-        mocker.patch("torch.nn.Module.__setattr__")
-        mocker.patch("torch.nn.Module.__getattr__")
-        mocker.patch("torch.nn.Module.__delattr__")
-        mocker.patch.object(mtp_layer,
-                            'eh_proj',
-                            return_value=torch.randn(2, 3, 768))
-        mocker.patch("torch.cat", return_value=torch.randn(2, 3, 768))
-        mocker.patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad",
-                     lambda x, label: x)
-        mtp_layer.mtp_block.return_value = (torch.randn(2, 3, 768),
-                                            torch.randn(2, 3, 768))
-
-        input_ids = torch.tensor([[1, 2, 3], [4, 5, 6]])
-        positions = torch.tensor([[0, 1, 2], [0, 1, 2]])
-        kv_cache = torch.randn(2, 3, 768)
-        previous_hidden_states = torch.randn(2, 3, 768)
-        inputs_embeds = torch.tensor([[1.0, 2.0, 3.0]])
-
-        output = mtp_layer(input_ids, positions, kv_cache, None,
-                           previous_hidden_states, inputs_embeds, 0)
-        assert output.shape == (2, 3, 768)
-
-
-class TestCustomDeepSeekMultiTokenPredictor(PytestBase):
-
-    @pytest.fixture
-    def setup_predictor(self, mocker: MockerFixture):
-        mock_vllm_config = mocker.MagicMock(spec=VllmConfig)
-        mock_model_config = mocker.MagicMock(spec=ModelConfig)
-        mock_hf_config = mocker.MagicMock()
-        mock_hf_config.num_hidden_layers = 12
-        mock_hf_config.num_nextn_predict_layers = 3
-        mock_hf_config.vocab_size = 30000
-        mock_model_config.hf_config = mock_hf_config
-        mock_vllm_config.model_config = mock_model_config
-        mock_vllm_config.cache_config = CacheConfig()
-        mock_vllm_config.quant_config = mocker.MagicMock()
-        mocker.patch(
-            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
-            return_value=None)
-        mocker.patch(
-            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__init__",
-            return_value=None)
-        mocker.patch(
-            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
-            return_value=None)
-        mocker.patch("vllm_ascend.utils.get_ascend_config",
-                     return_value=mocker.Mock())
-
-        predictor = CustomDeepSeekMultiTokenPredictor(
-            vllm_config=mock_vllm_config)
-        return predictor
-
-    def test_init(self, mocker: MockerFixture, setup_predictor):
-        predictor = setup_predictor
-        assert predictor.num_mtp_layers == 3
-        assert isinstance(predictor, CustomDeepSeekMultiTokenPredictor)
-
-    @pytest.mark.parametrize(
-        'kv_caches, inputs_embeds',
-        [(torch.tensor([[[0.1, 0.2, 0.3]]]), torch.tensor([[0.1, 0.2, 0.3]]))])
-    def test_forward(self, mocker: MockerFixture, setup_predictor, kv_caches,
-                     inputs_embeds):
-        predictor = setup_predictor
-        mock_layer = mocker.MagicMock()
-        mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
-        predictor.layers_list = [mock_layer]
-
-        # todo: need or not?
-        # predictor.num_mtp_layers = 1
-        input_ids = torch.tensor([[1, 2, 3]])
-        positions = torch.tensor([[0, 1, 2]])
-        mocker.patch(
-            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
-            return_value=torch.tensor([[1.0, 2.0, 3.0]]))
-        output = predictor.forward(input_ids, positions, kv_caches, None, None,
-                                   inputs_embeds, 0)
-        mock_layer.assert_called_once()
-        assert torch.allclose(output, torch.tensor([1.0, 2.0, 3.0]))
-
-    def test_compute_logits(self, mocker: MockerFixture, setup_predictor):
-        hidden_states = torch.tensor([[1, 2, 3], [4, 5, 6]])
-        predictor = setup_predictor
-
-        mock_layer = mocker.MagicMock()
-        mock_layer.return_value = torch.tensor([1.0, 2.0, 3.0])
-        predictor.layers_list = [mock_layer]
-        mocker.patch("torch.nn.Module.__setattr__")
-        mocker.patch("torch.nn.Module.__getattr__")
-        mocker.patch("torch.nn.Module.__delattr__")
-        mocker.patch(
-            "vllm.model_executor.layers.logits_processor.LogitsProcessor.__init__",
-            return_value=None)
-        predictor.logits_processor.return_value = torch.tensor([1.0, 2.0, 3.0])
-
-        result_logits = predictor.compute_logits(hidden_states=hidden_states,
-                                                 sampling_metadata=None)
-        predictor.logits_processor.assert_called_once()
-        assert torch.allclose(result_logits, torch.tensor([1.0, 2.0, 3.0]))
-
-
-class TestCustomDeepSeekMTP(PytestBase):
-
-    @pytest.fixture
-    def setup_mtp(self, mocker: MockerFixture):
-        vllm_config = mocker.MagicMock()
-        vllm_config.model_config.hf_config.num_hidden_layers = 12
-        vllm_config.model_config.hf_config.num_nextn_predict_layers = 3
-        vllm_config.cache_config = mocker.MagicMock()
-        vllm_config.quant_config = mocker.MagicMock()
-
-        mocker.patch("torch.nn.Module.__setattr__")
-        mocker.patch("torch.nn.Module.__getattr__")
-        mocker.patch("torch.nn.Module.__delattr__")
-        mocker.patch(
-            "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding.__init__",
-            return_value=None)
-        mocker.patch(
-            "vllm_ascend.models.deepseek_mtp.CustomDeepSeekMultiTokenPredictorLayer.__call__",
-            return_value=None)
-        mocker.patch(
-            "vllm_ascend.ops.vocab_parallel_embedding.AscendVocabParallelEmbedding.__init__",
-            return_value=None)
-        mocker.patch("vllm_ascend.utils.get_ascend_config",
-                     return_value=mocker.Mock())
-
-        mtp = CustomDeepSeekMTP(vllm_config=vllm_config)
-        return mtp
-
-    def test_init(self, mocker: MockerFixture, setup_mtp):
-        mtp = setup_mtp
-        assert isinstance(mtp, CustomDeepSeekMTP)
-
-    def test_forward(self, mocker: MockerFixture, setup_mtp):
-        mocker.patch("torch.ops.vllm.maybe_all_gather_and_maybe_unpad",
-                     lambda x, label: x)
-        input_ids = torch.tensor([[1, 2, 3]])
-        positions = torch.tensor([[0, 1, 2]])
-        kv_caches = [torch.tensor([[0.1, 0.2, 0.3]])]
-        previous_hidden_states = torch.tensor([[0.1, 0.2, 0.3]])
-        inputs_embeds = torch.tensor([[0.1, 0.2, 0.3]])
-        spec_step_idx = 0
-        setup_mtp.model.return_value = torch.tensor([[1.0, 2.0, 3.0]])
-
-        output = setup_mtp.forward(input_ids, positions, kv_caches, None,
-                                   previous_hidden_states, inputs_embeds,
-                                   spec_step_idx)
-        assert torch.allclose(output, torch.tensor([[1.0, 2.0, 3.0]]))
@@ -11,6 +11,7 @@ from vllm.attention.backends.abstract import (AttentionBackend,
 from vllm.config import VllmConfig, get_current_vllm_config
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.forward_context import ForwardContext, get_forward_context
+from vllm.logger import logger
 from vllm.model_executor.layers.linear import (LinearBase,
                                                UnquantizedLinearMethod)
 from vllm.utils import cdiv, round_down
@@ -29,6 +30,7 @@ from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
 from vllm_ascend.ops.weight_prefetch import maybe_npu_prefetch
+from vllm_ascend.quantization.w8a8 import AscendW8A8LinearMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
                                is_enable_nz)
 from vllm_ascend.worker.npu_input_batch import InputBatch
@@ -557,6 +559,7 @@ class AscendMLAImpl(MLAAttentionImpl):
         self.prefill_mask = None
 
         self.speculative_config = vllm_config.speculative_config
+        self.enable_mlapo = envs.VLLM_ASCEND_ENABLE_MLAPO
 
     def _v_up_proj(self, x):
         if self.W_UV.shape[0] * self.W_UV.shape[1] < 65536:
@@ -654,7 +657,17 @@ class AscendMLAImpl(MLAAttentionImpl):
         # self.W_UV.data = torch_npu.npu_format_cast(self.W_UV.data, 29)
         # self.W_UK_T.data = torch_npu.npu_format_cast(self.W_UK_T.data, 29)
 
-        if envs.VLLM_ASCEND_ENABLE_MLAPO:
+        # Currently mlapo only supports W8A8 quantization in MLA scenario
+        # TODO(whx): modify this limitation when mlapo supports floating point
+        if self.fused_qkv_a_proj is None or not isinstance(
+                getattr(self.fused_qkv_a_proj.quant_method, 'quant_method',
+                        None), AscendW8A8LinearMethod):
+            self.enable_mlapo = False
+            logger.warning(
+                "Currently mlapo only supports W8A8 quantization in MLA scenario."
+                "Some layers in your model are not quantized with W8A8,"
+                "thus mlapo is disabled for these layers.")
+        if self.enable_mlapo:
             self._process_weights_for_fused_mlapo(act_dtype)
 
     def _process_weights_for_fused_mlapo(self, act_dtype: torch.dtype):
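(The new guard degrades gracefully per layer instead of failing: it inspects the quant method attached to the fused projection and keeps mlapo enabled only for W8A8 layers. A standalone sketch of the same probe pattern, using stand-in classes rather than the real vllm-ascend types:)

class W8A8Method:  # stand-in for AscendW8A8LinearMethod
    pass

class QuantMethodWrapper:  # stand-in for the layer's quant-method wrapper
    quant_method = W8A8Method()

class FusedProj:
    quant_method = QuantMethodWrapper()

def mlapo_supported(fused_qkv_a_proj) -> bool:
    # Mirrors the diff: unwrap quant_method.quant_method and type-check it.
    inner = getattr(
        getattr(fused_qkv_a_proj, "quant_method", None), "quant_method", None)
    return isinstance(inner, W8A8Method)

assert mlapo_supported(FusedProj())
assert not mlapo_supported(object())  # unquantized layer: mlapo disabled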
@@ -1229,7 +1242,7 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         # MLA Preprocess
         forward_context = get_forward_context()
-        if (envs.VLLM_ASCEND_ENABLE_MLAPO and
+        if (self.enable_mlapo and
                 (attn_metadata is None or not forward_context.with_prefill)):
             decode_preprocess_res, prefill_preprocess_res = self._mla_decode_preprocess(
                 hidden_states, kv_cache, attn_metadata)
@@ -33,10 +33,6 @@ def register_model():
         "DeepseekV32ForCausalLM",
         "vllm_ascend.models.deepseek_v3_2:CustomDeepseekV3ForCausalLM")
 
-    ModelRegistry.register_model(
-        "DeepSeekMTPModel",
-        "vllm_ascend.models.deepseek_mtp:CustomDeepSeekMTP")
-
     # There is no PanguProMoEForCausalLM in vLLM, so we should register it before vLLM config initialization
     # to make sure the model can be loaded correctly. This register step can be removed once vLLM support PanguProMoEForCausalLM.
     ModelRegistry.register_model(
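(With this registration removed, the DeepSeekMTPModel architecture resolves through vLLM's own model registry to its native DeepSeekMTP; the MtpProposer hunks below switch to that class directly.)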
@@ -1,209 +0,0 @@
-#
-# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# Adapted from vllm/model_executor/models/deepseek_mtp.py
-# Copyright 2023 The vLLM team.
-#
-# This file is a part of the vllm-ascend project.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from typing import List, Optional
-
-import torch
-import torch.nn as nn
-from transformers import PretrainedConfig
-from vllm.attention.backends.abstract import AttentionMetadata
-from vllm.compilation.decorators import support_torch_compile
-from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
-                         get_current_vllm_config)
-from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.logits_processor import LogitsProcessor
-from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead, VocabParallelEmbedding)
-from vllm.model_executor.models.deepseek_mtp import (
-    DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
-    SharedHead)
-from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer
-from vllm.model_executor.models.utils import maybe_prefix
-from vllm.sequence import IntermediateTensors
-
-
-class CustomDeepSeekShareHead(SharedHead):
-
-    def __init__(self,
-                 config: PretrainedConfig,
-                 quant_config: Optional[QuantizationConfig] = None,
-                 prefix: str = "") -> None:
-        nn.Module.__init__(self)
-        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.head = ParallelLMHead(config.vocab_size,
-                                   config.hidden_size,
-                                   quant_config=quant_config,
-                                   prefix=maybe_prefix(prefix, "head"))
-
-
-class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
-
-    def __init__(
-        self,
-        config: PretrainedConfig,
-        prefix: str,
-        model_config: ModelConfig,
-        cache_config: Optional[CacheConfig] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-    ) -> None:
-        nn.Module.__init__(self)
-        vllm_config = get_current_vllm_config()
-
-        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.eh_proj = nn.Linear(config.hidden_size * 2,
-                                 config.hidden_size,
-                                 bias=False)
-        self.shared_head = CustomDeepSeekShareHead(config=config,
-                                                   quant_config=quant_config,
-                                                   prefix=maybe_prefix(
-                                                       prefix, "shared_head"))
-        self.mtp_block = DeepseekV2DecoderLayer(vllm_config=vllm_config,
-                                                prefix=prefix)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
-        previous_hidden_states: torch.Tensor,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        spec_step_index: int = 0,
-    ) -> torch.Tensor:
-        assert inputs_embeds is not None
-        inputs_embeds = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            inputs_embeds, True)
-        # masking inputs at position 0, as not needed by MTP
-        inputs_embeds = torch.where((positions == 0).unsqueeze(-1),
-                                    torch.zeros_like(inputs_embeds),
-                                    inputs_embeds)
-        inputs_embeds = self.enorm(inputs_embeds)
-        previous_hidden_states = self.hnorm(previous_hidden_states)
-
-        hidden_states = self.eh_proj(
-            torch.cat([inputs_embeds, previous_hidden_states], dim=-1))
-
-        hidden_states, residual = self.mtp_block(positions=positions,
-                                                 hidden_states=hidden_states,
-                                                 residual=None)
-        hidden_states = residual + hidden_states
-        return hidden_states
-
-
-class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        nn.Module.__init__(self)
-        config = vllm_config.model_config.hf_config
-        self.mtp_start_layer_idx = config.num_hidden_layers
-        self.num_mtp_layers = config.num_nextn_predict_layers
-        # to map the exact layer index from weights
-        self.layers = torch.nn.ModuleDict({
-            str(idx):
-            CustomDeepSeekMultiTokenPredictorLayer(
-                config,
-                f"{prefix}.layers.{idx}",
-                model_config=vllm_config.model_config,
-                cache_config=vllm_config.cache_config,
-                quant_config=vllm_config.quant_config,
-            )
-            for idx in range(self.mtp_start_layer_idx,
-                             self.mtp_start_layer_idx + self.num_mtp_layers)
-        })
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size,
-            config.hidden_size,
-        )
-
-        # Note: torch._dynamo.exc.Unsupported: builtin: str
-        self.layers_list = [
-            self.layers[str(idx)]
-            for idx in range(self.mtp_start_layer_idx,
-                             self.mtp_start_layer_idx + self.num_mtp_layers)
-        ]
-        self.logits_processor = LogitsProcessor(config.vocab_size)
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: torch.Tensor,
-        attn_metadata: AttentionMetadata,
-        previous_hidden_states: torch.Tensor,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        spec_step_idx: int = 0,
-    ) -> torch.Tensor:
-        if inputs_embeds is None:
-            inputs_embeds = self.embed_tokens(input_ids)
-        current_step_idx = (spec_step_idx % self.num_mtp_layers)
-        step_kv_cache = kv_caches[
-            current_step_idx] if kv_caches is not None else None
-        return self.layers_list[current_step_idx](
-            input_ids,
-            positions,
-            step_kv_cache,
-            attn_metadata,
-            previous_hidden_states,
-            inputs_embeds,
-            current_step_idx,
-        )
-
-    def compute_logits(
-        self,
-        hidden_states: torch.Tensor,
-        sampling_metadata=None,  # type: ignore
-        spec_step_idx: int = 0,
-    ) -> torch.Tensor:
-        current_step_idx = (spec_step_idx % self.num_mtp_layers)
-        mtp_layer = self.layers_list[current_step_idx]
-        logits = self.logits_processor(mtp_layer.shared_head.head,
-                                       mtp_layer.shared_head(hidden_states),
-                                       sampling_metadata)
-        return logits
-
-
-@support_torch_compile
-class CustomDeepSeekMTP(DeepSeekMTP):
-
-    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
-        nn.Module.__init__(self)
-        self.config = vllm_config.model_config.hf_config
-        self.model = CustomDeepSeekMultiTokenPredictor(vllm_config=vllm_config,
-                                                       prefix=maybe_prefix(
-                                                           prefix, "model"))
-
-    def forward(
-        self,
-        input_ids: torch.Tensor,
-        positions: torch.Tensor,
-        kv_caches: Optional[List[torch.Tensor]] = None,
-        attn_metadata: Optional[AttentionMetadata] = None,
-        previous_hidden_states: Optional[torch.Tensor] = None,
-        intermediate_tensors: Optional[IntermediateTensors] = None,
-        inputs_embeds: Optional[torch.Tensor] = None,
-        spec_step_idx: int = 0,
-    ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, kv_caches,
-                                   attn_metadata, previous_hidden_states,
-                                   inputs_embeds, spec_step_idx)
-        hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
-            hidden_states, True)
-        return hidden_states
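(What this deleted file implemented now lives upstream in vLLM. Its core is the MTP combine step; a minimal sketch of just that computation, with an illustrative hidden size and torch's built-in RMSNorm standing in for vLLM's:)

import torch
import torch.nn as nn

# Illustrative hidden size; the real value comes from the HF config.
hidden_size = 8
enorm = nn.RMSNorm(hidden_size)  # torch>=2.4; stands in for vLLM's RMSNorm
hnorm = nn.RMSNorm(hidden_size)
eh_proj = nn.Linear(hidden_size * 2, hidden_size, bias=False)

inputs_embeds = torch.randn(2, 3, hidden_size)
previous_hidden_states = torch.randn(2, 3, hidden_size)

# Normalize both streams, concatenate, and project back to hidden_size,
# as the deleted CustomDeepSeekMultiTokenPredictorLayer.forward did.
hidden_states = eh_proj(
    torch.cat([enorm(inputs_embeds), hnorm(previous_hidden_states)], dim=-1))
assert hidden_states.shape == (2, 3, hidden_size)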
@@ -56,6 +56,14 @@ class AscendQuantConfig(QuantizationConfig):
     def __init__(self, quant_config: Dict[str, Any]):
         super().__init__()
         self.quant_description = quant_config
+        # TODO(whx): remove this adaptation after adding "shared_head"
+        # to prefix of DeepSeekShareHead in vLLM.
+        extra_quant_dict = {}
+        for k in self.quant_description.keys():
+            if "shared_head" in k:
+                new_k = k.replace(".shared_head.", ".")
+                extra_quant_dict[new_k] = self.quant_description[k]
+        self.quant_description.update(extra_quant_dict)
 
     def __repr__(self) -> str:
         return "AscendQuantConfig:\n" + super().__repr__()
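(The adaptation above duplicates quantization entries so lookups succeed whether or not a weight name contains the shared_head segment. A minimal sketch with hypothetical weight names, not taken from a real checkpoint:)

# Hypothetical quant description entries; real keys come from the checkpoint.
quant_description = {
    "model.layers.61.shared_head.head.weight": "W8A8",
    "model.layers.61.mlp.gate_proj.weight": "W8A8",
}
extra = {
    k.replace(".shared_head.", "."): v
    for k, v in quant_description.items() if "shared_head" in k
}
quant_description.update(extra)
# Both spellings now resolve:
assert quant_description["model.layers.61.head.weight"] == "W8A8"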
@@ -11,6 +11,7 @@ from vllm.forward_context import BatchDescriptor, get_forward_context
 from vllm.model_executor.model_loader import get_model_loader
 from vllm.model_executor.model_loader.utils import (
     process_weights_after_loading, set_default_torch_dtype)
+from vllm.model_executor.models.deepseek_mtp import DeepSeekMTP
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
@@ -18,7 +19,6 @@ from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
-from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
     TorchairDeepSeekMTP
@@ -86,7 +86,7 @@ class MtpProposer(Proposer):
                 self.model = TorchairDeepSeekMTP(
                     vllm_config=self.vllm_config).to(target_device)
             else:
-                self.model = CustomDeepSeekMTP(
+                self.model = DeepSeekMTP(
                     vllm_config=self.vllm_config).to(target_device)
 
         draft_attn_layer_names = (
@@ -184,7 +184,7 @@ class MtpProposer(Proposer):
             else:
                 self.model(input_ids=input_ids,
                            positions=positions,
-                           previous_hidden_states=previous_hidden_states)
+                           hidden_states=previous_hidden_states)
             if with_prefill:
                 break
 
@@ -470,9 +470,8 @@ class MtpProposer(Proposer):
             hidden_states = self.model(
                 input_ids=self.input_ids[:num_input_tokens],
                 positions=self.positions[:num_input_tokens],
-                previous_hidden_states=self.
-                hidden_states[:num_input_tokens],
-                kv_caches=self.runner.kv_caches[-1:])
+                hidden_states=self.hidden_states[:num_input_tokens]
+            )
 
         num_indices = last_token_indices.shape[0]
         if lmhead_tp_enable():
@@ -485,7 +484,7 @@ class MtpProposer(Proposer):
                 (0, max_num_reqs_across_dp - num_indices))
 
         sample_hidden_states = hidden_states[last_token_indices]
-        logits = self.model.compute_logits(sample_hidden_states, None)
+        logits = self.model.compute_logits(sample_hidden_states)
         if lmhead_tp_enable() and num_indices < logits.shape[0]:
             logits = logits[:num_indices]
         draft_token_ids = logits.argmax(dim=-1)
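(The MtpProposer call-site changes above track the interface of vLLM's native DeepSeekMTP: previous-step activations are passed via the hidden_states keyword instead of previous_hidden_states, the explicit kv_caches argument is dropped, and compute_logits is called without a sampling_metadata argument.)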