[bugfixed] fix the bug when running the inference of quantized ds-w8a8-mtp (#2134)

When running the inference of ds-w8a8-mtp, it reported "'ParallelLMHead' has
no attribute 'params_dtype'".

1. add wrapper of vocab_parallel_embedding, fixed the bugs when running
deepseek-w8a8-mtp

Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>

- vLLM version: v0.10.0
- vLLM main:
ad57f23f6a

---------

Signed-off-by: curryliu <120010041@link.cuhk.edu.cn>
This commit is contained in:
liu
2025-08-04 15:16:42 +08:00
committed by GitHub
parent 4b3a210c33
commit 688350a3bb
2 changed files with 38 additions and 1 deletions

View File

@@ -22,6 +22,39 @@ import torch_npu
from vllm.logger import logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import UnquantizedLinearMethod
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, QuantizationConfig)
# func refers to VocabParallelEmbedding.__init__
def wrapper_vocab_parallel_embedding_init(func):
    """Wrap ``VocabParallelEmbedding.__init__`` to record ``params_dtype``.

    Quantized MTP models (e.g. DeepSeek w8a8 MTP) read
    ``params_dtype`` from the embedding/LM-head layer, but upstream vLLM's
    ``VocabParallelEmbedding`` does not store it. The returned ``init``
    delegates to the original constructor unchanged and then saves the
    resolved dtype on the instance.
    """

    def init(
        self,
        num_embeddings: int,
        embedding_dim: int,
        params_dtype: Optional[torch.dtype] = None,
        org_num_embeddings: Optional[int] = None,
        padding_size: int = DEFAULT_VOCAB_PADDING_SIZE,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        # Delegate to the original constructor first, exactly as before.
        func(self, num_embeddings, embedding_dim, params_dtype,
             org_num_embeddings, padding_size, quant_config, prefix)
        # TODO: Contact vLLM maintainers to add a `params_dtype` attribute
        # to the `VocabParallelEmbedding` class.
        # None means "use the process-wide default dtype".
        self.params_dtype = (torch.get_default_dtype()
                             if params_dtype is None else params_dtype)

    return init
# func refers to RMSNorm.__init__

View File

@@ -22,7 +22,8 @@ from typing import Any, Dict, List, Optional
from vllm.logger import logger
from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
from .func_wrapper import (wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init,
wrapper_vocab_parallel_embedding_init)
from .w4a8_dynamic import AscendW4A8DynamicLinearMethod
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod)
@@ -75,6 +76,9 @@ class VLLMAscendQuantizer:
VLLMAscendQuantizer.apply_patch(
"vllm.model_executor.layers.layernorm.RMSNorm",
"forward_oot", [wrapper_rmsnorm_forward_oot])
VLLMAscendQuantizer.apply_patch(
"vllm.model_executor.layers.vocab_parallel_embedding.VocabParallelEmbedding",
"__init__", [wrapper_vocab_parallel_embedding_init])
break
VLLMAscendQuantizer.patched = True
logger.info("Using the vLLM Ascend Quantizer version now!")