diff --git a/vllm_ascend/quantization/func_wrapper.py b/vllm_ascend/quantization/func_wrapper.py
index 77ecca2..8357695 100644
--- a/vllm_ascend/quantization/func_wrapper.py
+++ b/vllm_ascend/quantization/func_wrapper.py
@@ -22,6 +22,39 @@
 import torch_npu
 from vllm.logger import logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig)
+
+
+# func refers to vocabParallelEmbedding.__init__
+def wrapper_vocab_parallel_embedding_init(func):
+
+    def init(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        params_dtype: Optional[torch.dtype] = None,
+        org_num_embeddings: Optional[int] = None,
+        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        func(
+            self,
+            num_embeddings,
+            embedding_dim,
+            params_dtype,
+            org_num_embeddings,
+            padding_size,
+            quant_config,
+            prefix,
+        )
+        # TODO: Contact vLLM maintainers to add a `params_dtype` attribute to the `VocabParallelEmbedding` class.
+        if params_dtype is None:
+            params_dtype = torch.get_default_dtype()
+        self.params_dtype = params_dtype
+
+    return init
 
 # func refers to RMSNorm.__init__
diff --git a/vllm_ascend/quantization/quantizer.py b/vllm_ascend/quantization/quantizer.py
index e61593d..90c7512 100644
--- a/vllm_ascend/quantization/quantizer.py
+++ b/vllm_ascend/quantization/quantizer.py
@@ -22,7 +22,8 @@
 from typing import Any, Dict, List, Optional
 
 from vllm.logger import logger
-from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
+from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
+                           wrapper_vocab_parallel_embedding_init)
 from .w4a8_dynamic import AscendW4A8DynamicLinearMethod
 from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
                    AscendW8A8LinearMethod)
@@ -75,6 +76,9 @@ class VLLMAscendQuantizer:
             VLLMAscendQuantizer.apply_patch(
                 "vllm.model_executor.layers.layernorm.RMSNorm", "forward_oot",
                 [wrapper_rmsnorm_forward_oot])
+            VLLMAscendQuantizer.apply_patch(
+                "vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding",
+                "__init__", [wrapper_vocab_parallel_embedding_init])
             break
         VLLMAscendQuantizer.patched = True
         logger.info("Using the vLLM Ascend Quantizer version now!")