[Bugfix] fix mtp profile run error where main model and mtp model use different quantization (#4102)

### What this PR does / why we need it?
In PR https://github.com/vllm-project/vllm-ascend/pull/3420, we
initially placed the quantization type (quant_type) in the MoECommMethod
class. However, since MoECommMethod follows a singleton pattern, it
couldn't accommodate scenarios where different layers in the model might
use different quantization approaches (e.g., MTP modules using
floating-point computation while the main model employs quantized
computation).
In this PR, we've moved the quantization type to the AscendFusedMoe
class and pass it as a parameter to MoECommMethod.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
```bash
export HCCL_BUFFSIZE=1024
export VLLM_VERSION=0.11.0

vllm serve /home/data/DeepSeek-R1_w8a8/ \
 --data-parallel-size 2 \
 --tensor-parallel-size 8 \
 --enable-expert-parallel \
 --served-model-name dsv3 \
 --max-model-len 32768 \
 --max-num-batched-tokens 4096 \
 --max-num-seqs 16 \
 --quantization ascend \
 --trust-remote-code \
 --gpu-memory-utilization 0.9 \
 --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
```


- vLLM version: v0.11.0
- vLLM main:
83f478bb19

---------

Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
realliujiaxu
2025-11-13 11:02:31 +08:00
committed by GitHub
parent 17259cb265
commit 5093192769
6 changed files with 82 additions and 76 deletions

View File

@@ -31,10 +31,6 @@ from vllm_ascend.ops.fused_moe.prepare_finalize import (
from vllm_ascend.ops.fused_moe.token_dispatcher import (
TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
TokenDispatcherWithMC2, TokenDispatcherWithMoge)
from vllm_ascend.quantization.w4a8_dynamic import \
AscendW4A8DynamicFusedMoEMethod
from vllm_ascend.quantization.w8a8_dynamic import \
AscendW8A8DynamicFusedMoEMethod
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
@@ -44,54 +40,37 @@ def get_moe_comm_method(
return _MoECommMethods.get(moe_comm_type, None)
def setup_moe_comm_method(moe_config, quant_method):
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(
moe_config, quant_method)
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(
moe_config, quant_method)
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config, quant_method)
def setup_moe_comm_method(moe_config):
_MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
_MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
_MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
_MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
moe_config, quant_method)
moe_config)
class MoECommMethod(ABC):
"""Base class for MoE communication methods."""
def __init__(self, moe_config: FusedMoEConfig, quant_method=None):
def __init__(self, moe_config: FusedMoEConfig):
self.model_type = get_current_vllm_config(
).model_config.hf_config.model_type
self.moe_config = moe_config
self.token_dispatcher = self._get_token_dispatcher()
self.quant_type = self._get_quant_type(quant_method)
self.with_quant = self.quant_type != QuantType.NONE
self.prepare_finalize = self._get_prepare_finalize()
def _get_quant_type(self, quant_method) -> QuantType:
if not hasattr(quant_method,
"quant_method") or quant_method.quant_method is None:
return QuantType.NONE
method = quant_method.quant_method
if isinstance(method, AscendW8A8DynamicFusedMoEMethod):
return QuantType.W8A8
elif isinstance(method, AscendW4A8DynamicFusedMoEMethod):
return QuantType.W4A8
else:
return QuantType.NONE
def prepare(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False
replace_allreduce: bool = False,
quant_type: QuantType = QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
hidden_states, router_logits, enable_shared_expert_dp,
replace_allreduce)
replace_allreduce, quant_type)
return hidden_states, router_logits, mc2_mask, context_metadata
def finalize(self,
@@ -112,6 +91,8 @@ class MoECommMethod(ABC):
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
@@ -151,29 +132,29 @@ class MoECommMethod(ABC):
dynamic_scale_for_share=dynamic_scale_for_share,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=self.with_quant,
with_quant=use_int8_w8a8 or use_int4_w4a8,
dynamic_eplb=dynamic_eplb,
pertoken_scale=pertoken_scale)
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")
mlp_output = unified_apply_mlp(
hidden_states=permuted_hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=expert_tokens,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=topk_scales,
with_quant=self.with_quant,
fusion=self.quant_type == QuantType.W8A8,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb)
mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=expert_tokens,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
topk_scales=topk_scales,
with_quant=use_int8_w8a8
or use_int4_w4a8,
fusion=use_int8_w8a8,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb)
final_hidden_states = self.token_dispatcher.token_combine(
hidden_states=mlp_output, context_metadata=context_metadata)
@@ -226,8 +207,7 @@ class AllGatherCommImpl(MoECommMethod):
num_local_experts=self.moe_config.num_local_experts)
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAllGather(self.moe_config,
self.quant_type)
return PrepareAndFinalizeWithAllGather(self.moe_config)
class MC2CommImpl(MoECommMethod):
@@ -244,7 +224,7 @@ class MC2CommImpl(MoECommMethod):
return TokenDispatcherWithMC2()
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithMC2(self.moe_config, self.quant_type)
return PrepareAndFinalizeWithMC2(self.moe_config)
class AlltoAllCommImpl(MoECommMethod):
@@ -264,7 +244,7 @@ class AlltoAllCommImpl(MoECommMethod):
num_local_experts=self.moe_config.num_local_experts)
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAll2All(self.moe_config, self.quant_type)
return PrepareAndFinalizeWithAll2All(self.moe_config)
class NaiveMulticastCommImpl(MoECommMethod):
@@ -293,5 +273,4 @@ class NaiveMulticastCommImpl(MoECommMethod):
num_local_experts=self.moe_config.num_local_experts)
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithNaiveMulticast(self.moe_config,
self.quant_type)
return PrepareAndFinalizeWithNaiveMulticast(self.moe_config)