[bugfixed] fix the bug when run the inference of quantized ds-w8a8-mtp (#2134)
When running inference of ds-w8a8-mtp, it reported "'ParallelLMHead' object has
no attribute 'params_dtype'".
1. Add a wrapper for vocab_parallel_embedding, fixing the bug when running
deepseek-w8a8-mtp.
Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a
---------
Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
This commit is contained in:
@@ -22,6 +22,39 @@ import torch_npu
|
|||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||||
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||||
|
DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig)
|
||||||
|
|
||||||
|
|
||||||
|
# func refers to vocabParallelEmbedding.__init__
|
||||||
|
def wrapper_vocab_parallel_embedding_init(func):
|
||||||
|
|
||||||
|
def init(
|
||||||
|
self,
|
||||||
|
num_embeddings: int,
|
||||||
|
embedding_dim: int,
|
||||||
|
params_dtype: Optional[torch.dtype] = None,
|
||||||
|
org_num_embeddings: Optional[int] = None,
|
||||||
|
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||||
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
|
prefix: str = "",
|
||||||
|
):
|
||||||
|
func(
|
||||||
|
self,
|
||||||
|
num_embeddings,
|
||||||
|
embedding_dim,
|
||||||
|
params_dtype,
|
||||||
|
org_num_embeddings,
|
||||||
|
padding_size,
|
||||||
|
quant_config,
|
||||||
|
prefix,
|
||||||
|
)
|
||||||
|
# TODO: Contact vLLM maintainers to add a `params_dtype` attribute to the `VocabParallelEmbedding` class.
|
||||||
|
if params_dtype is None:
|
||||||
|
params_dtype = torch.get_default_dtype()
|
||||||
|
self.params_dtype = params_dtype
|
||||||
|
|
||||||
|
return init
|
||||||
|
|
||||||
|
|
||||||
# func refers to RMSNorm.__init__
|
# func refers to RMSNorm.__init__
|
||||||
|
|||||||
@@ -22,7 +22,8 @@ from typing import Any, Dict, List, Optional
|
|||||||
|
|
||||||
from vllm.logger import logger
|
from vllm.logger import logger
|
||||||
|
|
||||||
from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
|
from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
|
||||||
|
wrapper_vocab_parallel_embedding_init)
|
||||||
from .w4a8_dynamic import AscendW4A8DynamicLinearMethod
|
from .w4a8_dynamic import AscendW4A8DynamicLinearMethod
|
||||||
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
||||||
AscendW8A8LinearMethod)
|
AscendW8A8LinearMethod)
|
||||||
@@ -75,6 +76,9 @@ class VLLMAscendQuantizer:
|
|||||||
VLLMAscendQuantizer.apply_patch(
|
VLLMAscendQuantizer.apply_patch(
|
||||||
"vllm.model_executor.layers.layernorm.RMSNorm",
|
"vllm.model_executor.layers.layernorm.RMSNorm",
|
||||||
"forward_oot", [wrapper_rmsnorm_forward_oot])
|
"forward_oot", [wrapper_rmsnorm_forward_oot])
|
||||||
|
VLLMAscendQuantizer.apply_patch(
|
||||||
|
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding",
|
||||||
|
"__init__", [wrapper_vocab_parallel_embedding_init])
|
||||||
break
|
break
|
||||||
VLLMAscendQuantizer.patched = True
|
VLLMAscendQuantizer.patched = True
|
||||||
logger.info("Using the vLLM Ascend Quantizer version now!")
|
logger.info("Using the vLLM Ascend Quantizer version now!")
|
||||||
|
|||||||
Reference in New Issue
Block a user