From ca8007f584141d3a59b2bcbd4f8ba269c9b7e252 Mon Sep 17 00:00:00 2001
From: curryliu <99582471+Irving11-BKN@users.noreply.github.com>
Date: Tue, 29 Jul 2025 18:51:57 +0800
Subject: [PATCH] [Feature] Enable inference support for Deepseekr1-w8a8-MTP (#1994)

Support inference of the DeepSeek-R1 w8a8 MTP model with a statically
quantized shared_head in the MTP layers.

- vLLM version: v0.9.2
- vLLM main: https://github.com/vllm-project/vllm/commit/6eca337ce09e7dfa05ce57c4183ddb5d4488c85e

Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
---
 vllm_ascend/models/deepseek_mtp.py       | 23 ++++++++++++++++++++---
 vllm_ascend/models/deepseek_v2.py        |  4 +++-
 vllm_ascend/quantization/quant_config.py | 23 +++++++++++++++++++++++
 3 files changed, 46 insertions(+), 4 deletions(-)

diff --git a/vllm_ascend/models/deepseek_mtp.py b/vllm_ascend/models/deepseek_mtp.py
index 3cbc62e..bd78115 100644
--- a/vllm_ascend/models/deepseek_mtp.py
+++ b/vllm_ascend/models/deepseek_mtp.py
@@ -28,8 +28,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.sampler import get_sampler
-from vllm.model_executor.layers.vocab_parallel_embedding import \
-    VocabParallelEmbedding
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.models.deepseek_mtp import (
     DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
     SharedHead)
@@ -40,6 +40,20 @@ from vllm.sequence import IntermediateTensors
 from .deepseek_v2 import CustomDeepseekV2DecoderLayer
 
 
+class CustomDeepSeekShareHead(SharedHead):
+
+    def __init__(self,
+                 config: PretrainedConfig,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = "") -> None:
+        nn.Module.__init__(self)
+        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.head = ParallelLMHead(config.vocab_size,
+                                   config.hidden_size,
+                                   quant_config=quant_config,
+                                   prefix=maybe_prefix(prefix, "head"))
+
+
 class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
 
     def __init__(
@@ -61,7 +75,10 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
         self.eh_proj = nn.Linear(config.hidden_size * 2,
                                  config.hidden_size,
                                  bias=False)
-        self.shared_head = SharedHead(config=config, quant_config=quant_config)
+        self.shared_head = CustomDeepSeekShareHead(config=config,
+                                                   quant_config=quant_config,
+                                                   prefix=maybe_prefix(
+                                                       prefix, "shared_head"))
         self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
                                                       model_config,
                                                       cache_config,
diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py
index cb0649a..129e5eb 100644
--- a/vllm_ascend/models/deepseek_v2.py
+++ b/vllm_ascend/models/deepseek_v2.py
@@ -868,7 +868,9 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
         if get_pp_group().is_last_rank:
             self.lm_head = ParallelLMHead(config.vocab_size,
                                           config.hidden_size,
-                                          quant_config=quant_config)
+                                          quant_config=quant_config,
+                                          prefix=maybe_prefix(
+                                              prefix, "lm_head"))
         else:
             self.lm_head = PPMissingLayer()
         self.logits_processor = LogitsProcessor(config.vocab_size)
diff --git a/vllm_ascend/quantization/quant_config.py b/vllm_ascend/quantization/quant_config.py
index 7c7ee58..5984dc7 100644
--- a/vllm_ascend/quantization/quant_config.py
+++ b/vllm_ascend/quantization/quant_config.py
@@ -34,6 +34,8 @@ from vllm.model_executor.layers.quantization import \
     register_quantization_config
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig, QuantizeMethodBase)
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    UnquantizedEmbeddingMethod, VocabParallelEmbedding)
 from vllm.model_executor.parameter import PerTensorScaleParameter
 from vllm.model_executor.utils import set_weight_attrs
@@ -107,6 +109,12 @@ class AscendQuantConfig(QuantizationConfig):
                 return AscendUnquantizedFusedMoEMethod()
             return AscendFusedMoEMethod(self, prefix,
                                         self.packed_modules_mapping)
+        elif isinstance(layer, VocabParallelEmbedding):
+            if self.is_layer_skipped_ascend(prefix,
+                                            self.packed_modules_mapping):
+                return UnquantizedEmbeddingMethod()
+            return AscendEmbeddingMethod(self, prefix,
+                                         self.packed_modules_mapping)
         return None
 
     def is_layer_skipped_ascend(
@@ -319,3 +327,18 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
             self.quant_method.process_weights_after_loading(layer)
+
+
+class AscendEmbeddingMethod(AscendLinearMethod):
+    """Embedding method for Ascend quantization.
+    This class calls AscendQuantizer to search for a specific quantization
+    implementation supported on Ascend hardware for embedding methods.
+    Args:
+        quant_config: The Ascend quantization config.
+    """
+
+    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
+                 packed_modules_mapping: Dict[str, Any]) -> None:
+        self.quantizer = AscendQuantizer.get_quantizer(
+            quant_config.quant_description, prefix, packed_modules_mapping)
+        self.quant_method = self.quantizer.build_linear_method()
\ No newline at end of file
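
Note on the dispatch path (reviewer sketch, not part of the patch): in
vLLM, ParallelLMHead subclasses VocabParallelEmbedding, so once this
patch passes quant_config and a prefix into ParallelLMHead, the new
VocabParallelEmbedding branch in AscendQuantConfig.get_quant_method()
routes the MTP shared_head (and lm_head) to AscendEmbeddingMethod. A
minimal check of that assumption:

    # Sketch: verify the subclass relationship the new branch relies on.
    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead, VocabParallelEmbedding)

    # If this holds, isinstance(layer, VocabParallelEmbedding) in
    # get_quant_method() also matches ParallelLMHead layers such as the
    # MTP shared_head.head added above.
    assert issubclass(ParallelLMHead, VocabParallelEmbedding)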
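Usage sketch (illustrative assumptions, not part of the patch): loading
a locally quantized DeepSeek-R1 w8a8 checkpoint with MTP speculative
decoding through the vLLM offline API. The checkpoint path and
tensor_parallel_size below are placeholders; quantization="ascend" and
the "deepseek_mtp" speculative method come from vllm-ascend / vLLM.

    from vllm import LLM, SamplingParams

    llm = LLM(
        model="/path/to/DeepSeek-R1-w8a8",  # placeholder local w8a8 checkpoint
        quantization="ascend",  # route layers through AscendQuantConfig
        tensor_parallel_size=16,  # placeholder; match your NPU setup
        speculative_config={
            "method": "deepseek_mtp",  # draft tokens with the MTP layer
            "num_speculative_tokens": 1,  # DeepSeek-R1 ships one MTP layer
        },
    )

    outputs = llm.generate(["The capital of France is"],
                           SamplingParams(temperature=0.0, max_tokens=32))
    print(outputs[0].outputs[0].text)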