[bugfixed] fix the bug when run the inference of quantized ds-w8a8-mtp (#2134)
When run the inference of ds-w8a8-mtp, it reported 'ParamllelLMhead has
no attribute 'params_dtype''.
1. add wrapper of vocab_parallel_embedding, fixed the bugs when running
deepseek-w8a8-mtp
Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a
---------
Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
This commit is contained in:
@@ -22,6 +22,39 @@ import torch_npu
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig)
|
||||
|
||||
|
||||
# func refers to vocabParallelEmbedding.__init__
|
||||
def wrapper_vocab_parallel_embedding_init(func):
|
||||
|
||||
def init(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
params_dtype: Optional[torch.dtype] = None,
|
||||
org_num_embeddings: Optional[int] = None,
|
||||
padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
func(
|
||||
self,
|
||||
num_embeddings,
|
||||
embedding_dim,
|
||||
params_dtype,
|
||||
org_num_embeddings,
|
||||
padding_size,
|
||||
quant_config,
|
||||
prefix,
|
||||
)
|
||||
# TODO: Contact vLLM maintainers to add a `params_dtype` attribute to the `VocabParallelEmbedding` class.
|
||||
if params_dtype is None:
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.params_dtype = params_dtype
|
||||
|
||||
return init
|
||||
|
||||
|
||||
# func refers to RMSNorm.__init__
|
||||
|
||||
@@ -22,7 +22,8 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from vllm.logger import logger
|
||||
|
||||
from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
|
||||
from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
|
||||
wrapper_vocab_parallel_embedding_init)
|
||||
from .w4a8_dynamic import AscendW4A8DynamicLinearMethod
|
||||
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod)
|
||||
@@ -75,6 +76,9 @@ class VLLMAscendQuantizer:
|
||||
VLLMAscendQuantizer.apply_patch(
|
||||
"vllm.model_executor.layers.layernorm.RMSNorm",
|
||||
"forward_oot", [wrapper_rmsnorm_forward_oot])
|
||||
VLLMAscendQuantizer.apply_patch(
|
||||
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding",
|
||||
"__init__", [wrapper_vocab_parallel_embedding_init])
|
||||
break
|
||||
VLLMAscendQuantizer.patched = True
|
||||
logger.info("Using the vLLM Ascend Quantizer version now!")
|
||||
|
||||
Reference in New Issue
Block a user