[BUGFIX] main-sd-bugfix && [UT] add mtp UT (#593)
### What this PR does / why we need it?

This PR fixes several bugs in spec decode / MTP and adds an MTP e2e UT, `test_mtp_correctness.py`.

**vllm_ascend/attention/attention.py**
1. Handle the case where `self.attn_mask_cache` holds only one element, which arises when spec decode and chunked prefill are enabled together (see the first sketch below).

**vllm_ascend/distributed/parallel_state.py**
1. Remove two asserts, because the spec decode worker calls init_worker twice (see the second sketch below).

**vllm_ascend/models/deepseek_mtp.py**
1. Remove unused parameters.
2. Add w8a8 support to `CustomDeepSeekMTP` (see the sketch after the diff).

**vllm_ascend/quantization/quant_config.py**
1. Use `AscendUnquantizedFusedMoEMethod` instead of `UnquantizedFusedMoEMethod`.

**other**
1. Replace `from vllm.logger import init_logger` with `from vllm.logger import logger` across the whole vllm-ascend project.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

The new e2e UT `test_mtp_correctness.py` covers MTP correctness.

Signed-off-by: mengwei805 <mengwei25@huawei.com>
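The attention-mask fix is easiest to see in miniature. Below is a hedged sketch, not the actual `vllm_ascend/attention/attention.py` code: the `AttentionMaskStore` class and `get_mask` method are invented for illustration, but the guard shows the one-element-cache behaviour the fix adds.

```python
import torch


class AttentionMaskStore:
    """Hypothetical stand-in for the object that owns attn_mask_cache."""

    def __init__(self, masks: list[torch.Tensor]):
        # When spec decode and chunked prefill are enabled together, the
        # cache may be built with a single shared mask rather than one
        # mask per decode step.
        self.attn_mask_cache = masks

    def get_mask(self, step: int) -> torch.Tensor:
        # The fix: a one-element cache serves every step, instead of
        # indexing past its end and raising IndexError.
        if len(self.attn_mask_cache) == 1:
            return self.attn_mask_cache[0]
        return self.attn_mask_cache[step]
```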
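The parallel_state.py change follows the same defensive pattern. A minimal sketch, assuming a module-level group handle (`_EP_GROUP` and the function body are illustrative, not the real vllm-ascend code): with spec decode, init_worker runs once for the target model and once for the draft MTP model, so group initialization has to tolerate the second call instead of asserting it never happens.

```python
# Illustrative only: the real vllm_ascend/distributed/parallel_state.py
# differs; this shows why an `assert _EP_GROUP is None` had to go.
_EP_GROUP = None


def init_ascend_model_parallel(expert_parallel_size: int = 1) -> None:
    global _EP_GROUP
    # Before the fix, an assert here crashed the second init_worker call
    # made by the spec decode worker.
    if _EP_GROUP is not None:
        return  # already initialized; make the second call a no-op
    _EP_GROUP = object()  # placeholder for the real process group
```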
**vllm_ascend/models/deepseek_mtp.py**

```diff
@@ -1,6 +1,6 @@
 #
 # Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
-# Adapted from vllm/model_executor/models/qwen2_vl.py
+# Adapted from vllm/model_executor/models/deepseek_mtp.py
 # Copyright 2023 The vLLM team.
 #
 # This file is a part of the vllm-ascend project.
@@ -17,12 +17,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import Optional
 
 import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
-from vllm.attention.backends.abstract import AttentionMetadata
 from vllm.config import CacheConfig, ModelConfig, VllmConfig
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
@@ -70,8 +69,6 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_cache: torch.Tensor,
-        attn_metadata: AttentionMetadata,
         previous_hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_index: int = 0,
@@ -91,8 +88,6 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
 
         hidden_states, residual = self.mtp_block(positions=positions,
                                                  hidden_states=hidden_states,
-                                                 kv_cache=kv_cache,
-                                                 attn_metadata=attn_metadata,
                                                  residual=None)
         hidden_states = residual + hidden_states
         return hidden_states
@@ -130,8 +125,6 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        kv_caches: List[torch.Tensor],
-        attn_metadata: AttentionMetadata,
         previous_hidden_states: torch.Tensor,
         inputs_embeds: Optional[torch.Tensor] = None,
         spec_step_idx: int = 0,
@@ -140,8 +133,6 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
         return self.layers_list[current_step_idx](
             input_ids,
             positions,
-            kv_caches[current_step_idx],
-            attn_metadata,
             previous_hidden_states,
             inputs_embeds,
             current_step_idx,
@@ -162,6 +153,14 @@ class CustomDeepSeekMultiTokenPredictor(DeepSeekMultiTokenPredictor):
 
 
 class CustomDeepSeekMTP(DeepSeekMTP):
+    # NOTE 1.The quantized MTP layer of deepseek on the NPU is not quantized;
+    # NOTE 2.The description file generated by the current msmodelslim tool does not have
+    # MTP layer info. Please manually add it and set the value to FLOAT.
+    packed_modules_mapping = {
+        "gate_up_proj": ["gate_proj", "up_proj"],
+        "experts":
+        ["experts.0.gate_proj", "experts.0.up_proj", "experts.0.down_proj"]
+    }
 
     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         nn.Module.__init__(self)
```
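The `packed_modules_mapping` added in the last hunk is what lets a w8a8 checkpoint's separate `gate_proj`/`up_proj` shards land in the fused `gate_up_proj` parameter (the `experts` entry plays the same role for the MoE expert weights). A minimal sketch of the name translation involved; `remap_checkpoint_name` is a hypothetical helper, not vLLM's actual weight loader:

```python
from typing import Optional, Tuple

# Mirrors the mapping added to CustomDeepSeekMTP in the diff above.
packed_modules_mapping = {
    "gate_up_proj": ["gate_proj", "up_proj"],
}


def remap_checkpoint_name(name: str) -> Optional[Tuple[str, int]]:
    """Translate a checkpoint shard name to (fused param name, shard index)."""
    for fused, shards in packed_modules_mapping.items():
        for idx, shard in enumerate(shards):
            if shard in name:
                return name.replace(shard, fused), idx
    return None  # loads into a regular, unfused parameter


print(remap_checkpoint_name("mtp.layers.61.mlp.gate_proj.weight"))
# -> ('mtp.layers.61.mlp.gate_up_proj.weight', 0)
```

vLLM's real loader performs this translation inside `load_weights` via its stacked-parameter tables, and quantization configs also consult the class attribute when resolving per-module quant schemes, which appears to be why the mapping is needed for w8a8 support here.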