diff --git a/vllm_ascend/patch/__init__.py b/vllm_ascend/patch/__init__.py index 7d3725f..575d3ac 100644 --- a/vllm_ascend/patch/__init__.py +++ b/vllm_ascend/patch/__init__.py @@ -146,3 +146,17 @@ # No, this need CANN add an aclnn shift operation # Future Plan: # Revert this when CANN support shift aclnn operation +# +# ** File: worker/patch_deepseek_mtp.py** +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# 1. `vllm.model_executor.models.deepseek_mtp.DeepSeekMultiTokenPredictorLayer.__init__` +# Why: +# '__init__' func of DeepSeekMultiTokenPredictorLayer didn't pass prefix to SharedHead. +# How: +# Replace with a new __init__. +# Use a new SharedHead which passes prefix to ParallelLMHead. +# Related PR (if no, explain why): +# https://github.com/vllm-project/vllm/pull/25805 +# Future Plan: +# Remove this patch when adapted vllm version contains the above PR. +# diff --git a/vllm_ascend/patch/worker/__init__.py b/vllm_ascend/patch/worker/__init__.py index fa7d195..1cad82f 100644 --- a/vllm_ascend/patch/worker/__init__.py +++ b/vllm_ascend/patch/worker/__init__.py @@ -27,3 +27,4 @@ import vllm_ascend.patch.worker.patch_roberta # noqa import vllm_ascend.patch.worker.patch_weight_loader # noqa import vllm_ascend.patch.worker.patch_multimodal_merge # noqa import vllm_ascend.patch.worker.patch_minicpm # noqa +import vllm_ascend.patch.worker.patch_deepseek_mtp # noqa diff --git a/vllm_ascend/patch/worker/patch_deepseek_mtp.py b/vllm_ascend/patch/worker/patch_deepseek_mtp.py new file mode 100644 index 0000000..f64ebbe --- /dev/null +++ b/vllm_ascend/patch/worker/patch_deepseek_mtp.py @@ -0,0 +1,55 @@ +import torch +import torch.nn as nn +from transformers import PretrainedConfig +from vllm.config import VllmConfig +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.models.deepseek_mtp import \ + DeepSeekMultiTokenPredictorLayer +from vllm.model_executor.models.deepseek_v2 import DeepseekV2DecoderLayer +from vllm.model_executor.models.utils import maybe_prefix + + +class SharedHead(nn.Module): + + def __init__( + self, + config: PretrainedConfig, + prefix: str, + quant_config: QuantizationConfig = None, + ) -> None: + super().__init__() + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.head = ParallelLMHead( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "head"), + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + return self.norm(hidden_states) + + +def predictor_init(self, vllm_config: VllmConfig, prefix: str) -> None: + nn.Module.__init__(self) + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.eh_proj = nn.Linear(config.hidden_size * 2, + config.hidden_size, + bias=False) + + # We don't need topk_indices_buffer in Ascend + topk_indices_buffer = None + self.shared_head = SharedHead(config=config, + prefix=prefix, + quant_config=quant_config) + self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix, + topk_indices_buffer) + + +DeepSeekMultiTokenPredictorLayer.__init__ = predictor_init