[Feature] Enable inference support for Deepseekr1-w8a8-MTP (#1994)
Support inference of the Deepseekr1-w8a8-mtp model, whose MTP layers use a
statically quantized shared_head.
- vLLM version: v0.9.2
- vLLM main: 6eca337ce0
Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
This commit is contained in:
@@ -28,8 +28,8 @@ from vllm.model_executor.layers.layernorm import RMSNorm
|
|||||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.layers.sampler import get_sampler
|
from vllm.model_executor.layers.sampler import get_sampler
|
||||||
from vllm.model_executor.layers.vocab_parallel_embedding import \
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
VocabParallelEmbedding
|
ParallelLMHead, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.models.deepseek_mtp import (
|
from vllm.model_executor.models.deepseek_mtp import (
|
||||||
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
|
DeepSeekMTP, DeepSeekMultiTokenPredictor, DeepSeekMultiTokenPredictorLayer,
|
||||||
SharedHead)
|
SharedHead)
|
||||||
@@ -40,6 +40,20 @@ from vllm.sequence import IntermediateTensors
|
|||||||
from .deepseek_v2 import CustomDeepseekV2DecoderLayer
|
from .deepseek_v2 import CustomDeepseekV2DecoderLayer
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDeepSeekShareHead(SharedHead):
    """Shared head (RMSNorm + LM head) whose LM head receives a weight prefix.

    The constructor builds ``norm`` and ``head`` itself and forwards both
    ``quant_config`` and a ``"<prefix>.head"``-style prefix (built with
    ``maybe_prefix``) to ``ParallelLMHead``.

    Args:
        config: Model config; ``hidden_size``, ``rms_norm_eps`` and
            ``vocab_size`` are read from it.
        quant_config: Optional quantization config handed to the LM head.
        prefix: Name prefix of this module; the head is registered under
            ``maybe_prefix(prefix, "head")``.
    """

    def __init__(self,
                 config: PretrainedConfig,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = "") -> None:
        # Only the nn.Module base initialiser runs here; SharedHead.__init__
        # is intentionally not called, so the submodules below are the only
        # ones created (presumably so the head can be given a prefix for
        # quantization lookup -- confirm against SharedHead upstream).
        nn.Module.__init__(self)

        hidden_size = config.hidden_size
        self.norm = RMSNorm(hidden_size, eps=config.rms_norm_eps)

        head_prefix = maybe_prefix(prefix, "head")
        self.head = ParallelLMHead(
            config.vocab_size,
            hidden_size,
            quant_config=quant_config,
            prefix=head_prefix,
        )
|
||||||
|
|
||||||
|
|
||||||
class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -61,7 +75,10 @@ class CustomDeepSeekMultiTokenPredictorLayer(DeepSeekMultiTokenPredictorLayer):
|
|||||||
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
self.eh_proj = nn.Linear(config.hidden_size * 2,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
bias=False)
|
bias=False)
|
||||||
self.shared_head = SharedHead(config=config, quant_config=quant_config)
|
self.shared_head = CustomDeepSeekShareHead(config=config,
|
||||||
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(
|
||||||
|
prefix, "shared_head"))
|
||||||
self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
|
self.mtp_block = CustomDeepseekV2DecoderLayer(config, prefix,
|
||||||
model_config,
|
model_config,
|
||||||
cache_config,
|
cache_config,
|
||||||
|
|||||||
@@ -868,7 +868,9 @@ class CustomDeepseekV2ForCausalLM(DeepseekV2ForCausalLM):
|
|||||||
if get_pp_group().is_last_rank:
|
if get_pp_group().is_last_rank:
|
||||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||||
config.hidden_size,
|
config.hidden_size,
|
||||||
quant_config=quant_config)
|
quant_config=quant_config,
|
||||||
|
prefix=maybe_prefix(
|
||||||
|
prefix, "lm_head"))
|
||||||
else:
|
else:
|
||||||
self.lm_head = PPMissingLayer()
|
self.lm_head = PPMissingLayer()
|
||||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ from vllm.model_executor.layers.quantization import \
|
|||||||
from vllm.model_executor.layers.quantization.base_config import (
|
from vllm.model_executor.layers.quantization.base_config import (
|
||||||
QuantizationConfig, QuantizeMethodBase)
|
QuantizationConfig, QuantizeMethodBase)
|
||||||
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
UnquantizedEmbeddingMethod, VocabParallelEmbedding)
|
||||||
from vllm.model_executor.parameter import PerTensorScaleParameter
|
from vllm.model_executor.parameter import PerTensorScaleParameter
|
||||||
from vllm.model_executor.utils import set_weight_attrs
|
from vllm.model_executor.utils import set_weight_attrs
|
||||||
|
|
||||||
@@ -107,6 +109,12 @@ class AscendQuantConfig(QuantizationConfig):
|
|||||||
return AscendUnquantizedFusedMoEMethod()
|
return AscendUnquantizedFusedMoEMethod()
|
||||||
return AscendFusedMoEMethod(self, prefix,
|
return AscendFusedMoEMethod(self, prefix,
|
||||||
self.packed_modules_mapping)
|
self.packed_modules_mapping)
|
||||||
|
elif isinstance(layer, VocabParallelEmbedding):
|
||||||
|
if self.is_layer_skipped_ascend(prefix,
|
||||||
|
self.packed_modules_mapping):
|
||||||
|
return UnquantizedEmbeddingMethod()
|
||||||
|
return AscendEmbeddingMethod(self, prefix,
|
||||||
|
self.packed_modules_mapping)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def is_layer_skipped_ascend(
|
def is_layer_skipped_ascend(
|
||||||
@@ -319,3 +327,18 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
|||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||||
self.quant_method.process_weights_after_loading(layer)
|
self.quant_method.process_weights_after_loading(layer)
|
||||||
|
|
||||||
|
|
||||||
|
class AscendEmbeddingMethod(AscendLinearMethod):
    """Quantization method for embedding layers on Ascend.

    On construction, a quantizer is looked up through
    ``AscendQuantizer.get_quantizer`` and the method object it returns from
    ``build_linear_method()`` is stored as ``quant_method``.

    Args:
        quant_config: The Ascend quantization config; its
            ``quant_description`` is passed to the quantizer lookup.
        prefix: Layer name prefix forwarded to the quantizer lookup.
        packed_modules_mapping: Mapping forwarded unchanged to the quantizer
            lookup.
    """

    def __init__(self, quant_config: AscendQuantConfig, prefix: str,
                 packed_modules_mapping: Dict[str, Any]) -> None:
        quantizer = AscendQuantizer.get_quantizer(
            quant_config.quant_description,
            prefix,
            packed_modules_mapping,
        )
        self.quantizer = quantizer
        # The embedding path uses the quantizer's linear method.
        self.quant_method = quantizer.build_linear_method()
|
||||||
Reference in New Issue
Block a user