From 509319276940b405ebdb53d40b8827d10034ece8 Mon Sep 17 00:00:00 2001
From: realliujiaxu
Date: Thu, 13 Nov 2025 11:02:31 +0800
Subject: [PATCH] [Bugfix] fix mtp profile run error where main model and mtp model use different quantization (#4102)

### What this PR does / why we need it?
In PR https://github.com/vllm-project/vllm-ascend/pull/3420, we initially placed the quantization type (quant_type) in the MoECommMethod class. However, since MoECommMethod follows a singleton pattern, it could not accommodate scenarios where different layers of the model use different quantization approaches (e.g., MTP modules running in floating point while the main model runs quantized).

In this PR, we move the quantization type into the AscendFusedMoE class and pass it to MoECommMethod as a parameter on each call.
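The sketch below illustrates the new ownership in simplified form. The class names mirror the patch, but the bodies are placeholders rather than the actual vllm-ascend implementations:

```python
from enum import Enum


class QuantType(Enum):
    NONE = 0
    W8A8 = 1
    W4A8 = 2


class MoECommMethod:
    """Shared by every MoE layer (one instance per comm type), so it no
    longer stores any per-layer quantization state."""

    def prepare(self, hidden_states, router_logits,
                quant_type: QuantType = QuantType.NONE):
        # The caller's quant_type decides whether activations are dynamically
        # quantized before dispatch; placeholder logic only.
        if quant_type != QuantType.NONE:
            pass  # e.g. torch_npu.npu_dynamic_quant(hidden_states) in the real code
        return hidden_states, router_logits


class AscendFusedMoE:
    """Each layer resolves its own quant_type from its quant method once and
    forwards it on every call."""

    def __init__(self, comm: MoECommMethod, quant_type: QuantType):
        self.comm = comm
        self.quant_type = quant_type  # e.g. W8A8 for the main model, NONE for MTP

    def forward(self, hidden_states, router_logits):
        return self.comm.prepare(hidden_states, router_logits,
                                 quant_type=self.quant_type)
```

Because the shared comm method is now stateless with respect to quantization, the MTP layers (QuantType.NONE) and the quantized main-model layers (W8A8/W4A8) can use the same singleton during the profile run without interfering with each other.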
### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
```bash
export HCCL_BUFFSIZE=1024
export VLLM_VERSION=0.11.0
vllm serve /home/data/DeepSeek-R1_w8a8/ \
    --data-parallel-size 2 \
    --tensor-parallel-size 8 \
    --enable-expert-parallel \
    --served-model-name dsv3 \
    --max-model-len 32768 \
    --max-num-batched-tokens 4096 \
    --max-num-seqs 16 \
    --quantization ascend \
    --trust-remote-code \
    --gpu-memory-utilization 0.9 \
    --speculative-config '{"num_speculative_tokens": 2, "method":"deepseek_mtp"}'
```

- vLLM version: v0.11.0
- vLLM main: https://github.com/vllm-project/vllm/commit/83f478bb19489b41e9d208b47b4bb5a95ac171ac

---------

Signed-off-by: realliujiaxu
---
 tests/ut/ops/test_moe_comm_method.py          |  7 +-
 vllm_ascend/ops/fused_moe/fused_moe.py        | 26 +++++-
 vllm_ascend/ops/fused_moe/moe_comm_method.py  | 85 +++++++------------
 vllm_ascend/ops/fused_moe/prepare_finalize.py | 38 +++++----
 vllm_ascend/quantization/w4a8_dynamic.py      |  1 +
 vllm_ascend/quantization/w8a8_dynamic.py      |  1 +
 6 files changed, 82 insertions(+), 76 deletions(-)

diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py
index b3e5cc74..f258f8e7 100644
--- a/tests/ut/ops/test_moe_comm_method.py
+++ b/tests/ut/ops/test_moe_comm_method.py
@@ -7,6 +7,7 @@
 from tests.ut.base import TestBase
 from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
                                                        AlltoAllCommImpl,
                                                        MC2CommImpl)
+from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType


 class TestMoECommMethod(TestBase):
@@ -67,7 +68,7 @@ class TestMoECommMethod(TestBase):

         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
-            hidden_states, router_logits, False, False)
+            hidden_states, router_logits, False, False, QuantType.NONE)

         # Test finalize method
         comm_impl.finalize(h_out,
@@ -115,7 +116,7 @@
         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
-            hidden_states, router_logits, False, False)
+            hidden_states, router_logits, False, False, QuantType.NONE)

         # Test finalize method
         comm_impl.finalize(h_out,
@@ -165,7 +166,7 @@
         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
-            hidden_states, router_logits, False, False)
+            hidden_states, router_logits, False, False, QuantType.NONE)

    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_current_vllm_config")
    @patch("vllm_ascend.ops.fused_moe.moe_comm_method.get_forward_context")
diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py
index 14b615ba..113cd47e 100644
--- a/vllm_ascend/ops/fused_moe/fused_moe.py
+++ b/vllm_ascend/ops/fused_moe/fused_moe.py
@@ -37,6 +37,11 @@ from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
 from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
+from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
+from vllm_ascend.quantization.w4a8_dynamic import \
+    AscendW4A8DynamicFusedMoEMethod
+from vllm_ascend.quantization.w8a8_dynamic import \
+    AscendW8A8DynamicFusedMoEMethod
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, enable_sp, is_310p,
                                is_enable_nz, npu_stream_switch,
                                shared_expert_dp_enabled,
@@ -289,7 +294,23 @@ class AscendFusedMoE(FusedMoE):

         self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

-        setup_moe_comm_method(self.moe_config, self.quant_method)
+        setup_moe_comm_method(self.moe_config)
+        self.quant_type = self._get_quant_type()
+
+    def _get_quant_type(self) -> QuantType:
+        quant_method = self.quant_method
+        if not hasattr(quant_method,
+                       "quant_method") or quant_method.quant_method is None:
+            return QuantType.NONE
+
+        method = quant_method.quant_method
+
+        if isinstance(method, AscendW8A8DynamicFusedMoEMethod):
+            return QuantType.W8A8
+        elif isinstance(method, AscendW4A8DynamicFusedMoEMethod):
+            return QuantType.W4A8
+        else:
+            return QuantType.NONE

     def update_expert_map(self, new_expert_map):
         self.expert_map = new_expert_map
@@ -334,7 +355,8 @@ class AscendFusedMoE(FusedMoE):
             hidden_states=hidden_states,
             router_logits=router_logits,
             replace_allreduce=forward_context.sp_enabled,
-            enable_shared_expert_dp=self.enable_shared_expert_dp)
+            enable_shared_expert_dp=self.enable_shared_expert_dp,
+            quant_type=self.quant_type)

         if isinstance(hidden_states, tuple):
             hidden_states, pertoken_scale = hidden_states
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index 6094aafb..c89eb1df 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -31,10 +31,6 @@ from vllm_ascend.ops.fused_moe.prepare_finalize import (
 from vllm_ascend.ops.fused_moe.token_dispatcher import (
     TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather,
     TokenDispatcherWithMC2, TokenDispatcherWithMoge)
-from vllm_ascend.quantization.w4a8_dynamic import \
-    AscendW4A8DynamicFusedMoEMethod
-from vllm_ascend.quantization.w8a8_dynamic import \
-    AscendW8A8DynamicFusedMoEMethod

 _MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}

@@ -44,54 +40,37 @@ def get_moe_comm_method(
     return _MoECommMethods.get(moe_comm_type, None)


-def setup_moe_comm_method(moe_config, quant_method):
-    _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(
-        moe_config, quant_method)
-    _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(
-        moe_config, quant_method)
-    _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config, quant_method)
+def setup_moe_comm_method(moe_config):
+    _MoECommMethods[MoECommType.ALLTOALL] = AlltoAllCommImpl(moe_config)
+    _MoECommMethods[MoECommType.ALLGATHER] = AllGatherCommImpl(moe_config)
+    _MoECommMethods[MoECommType.MC2] = MC2CommImpl(moe_config)
     _MoECommMethods[MoECommType.NAIVE_MULTICAST] = NaiveMulticastCommImpl(
-        moe_config, quant_method)
+        moe_config)


 class MoECommMethod(ABC):
     """Base class for MoE communication methods."""

-    def __init__(self, moe_config: FusedMoEConfig, quant_method=None):
+    def __init__(self, moe_config: FusedMoEConfig):
         self.model_type = get_current_vllm_config(
         ).model_config.hf_config.model_type
         self.moe_config = moe_config
         self.token_dispatcher = self._get_token_dispatcher()
-        self.quant_type = self._get_quant_type(quant_method)
-        self.with_quant = self.quant_type != QuantType.NONE
         self.prepare_finalize = self._get_prepare_finalize()

-    def _get_quant_type(self, quant_method) -> QuantType:
-        if not hasattr(quant_method,
-                       "quant_method") or quant_method.quant_method is None:
-            return QuantType.NONE
-
-        method = quant_method.quant_method
-
-        if isinstance(method, AscendW8A8DynamicFusedMoEMethod):
-            return QuantType.W8A8
-        elif isinstance(method, AscendW4A8DynamicFusedMoEMethod):
-            return QuantType.W4A8
-        else:
-            return QuantType.NONE
-
     def prepare(
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False
+        replace_allreduce: bool = False,
+        quant_type: QuantType = QuantType.NONE,
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
         hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
             hidden_states, router_logits, enable_shared_expert_dp,
-            replace_allreduce)
+            replace_allreduce, quant_type)

         return hidden_states, router_logits, mc2_mask, context_metadata

     def finalize(self,
@@ -112,6 +91,8 @@
         topk_ids: torch.Tensor,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
+        use_int8_w8a8: bool = False,
+        use_int4_w4a8: bool = False,
         global_num_experts: Optional[int] = None,
         expert_map: Optional[torch.Tensor] = None,
         w1_scale: Optional[torch.Tensor] = None,
@@ -151,29 +132,29 @@
             dynamic_scale_for_share=dynamic_scale_for_share,
             mc2_mask=mc2_mask,
             apply_router_weight_on_input=apply_router_weight_on_input,
-            with_quant=self.with_quant,
+            with_quant=use_int8_w8a8 or use_int4_w4a8,
             dynamic_eplb=dynamic_eplb,
             pertoken_scale=pertoken_scale)

         permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type, topk_scales, context_metadata = \
             results["hidden_states"], results["group_list"], results.get("dynamic_scale"), results["group_list_type"], results.get("topk_scales"), results.get("context_metadata")

-        mlp_output = unified_apply_mlp(
-            hidden_states=permuted_hidden_states,
-            w1=w1,
-            w1_scale=w1_scale,
-            w2=w2,
-            w2_scale=w2_scale,
-            group_list=expert_tokens,
-            dynamic_scale=dynamic_scale,
-            group_list_type=group_list_type,
-            w1_scale_bias=w1_scale_bias,
-            w2_scale_bias=w2_scale_bias,
-            topk_scales=topk_scales,
-            with_quant=self.with_quant,
-            fusion=self.quant_type == QuantType.W8A8,
-            need_trans=need_trans,
-            dynamic_eplb=dynamic_eplb)
+        mlp_output = unified_apply_mlp(hidden_states=permuted_hidden_states,
+                                       w1=w1,
+                                       w1_scale=w1_scale,
+                                       w2=w2,
+                                       w2_scale=w2_scale,
+                                       group_list=expert_tokens,
+                                       dynamic_scale=dynamic_scale,
+                                       group_list_type=group_list_type,
+                                       w1_scale_bias=w1_scale_bias,
+                                       w2_scale_bias=w2_scale_bias,
+                                       topk_scales=topk_scales,
+                                       with_quant=use_int8_w8a8
+                                       or use_int4_w4a8,
+                                       fusion=use_int8_w8a8,
+                                       need_trans=need_trans,
+                                       dynamic_eplb=dynamic_eplb)

         final_hidden_states = self.token_dispatcher.token_combine(
             hidden_states=mlp_output, context_metadata=context_metadata)
@@ -226,8 +207,7 @@ class AllGatherCommImpl(MoECommMethod):
             num_local_experts=self.moe_config.num_local_experts)

     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAllGather(self.moe_config,
-                                               self.quant_type)
+        return PrepareAndFinalizeWithAllGather(self.moe_config)


 class MC2CommImpl(MoECommMethod):
@@ -244,7 +224,7 @@
         return TokenDispatcherWithMC2()

     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithMC2(self.moe_config, self.quant_type)
+        return PrepareAndFinalizeWithMC2(self.moe_config)


 class AlltoAllCommImpl(MoECommMethod):
@@ -264,7 +244,7 @@
             num_local_experts=self.moe_config.num_local_experts)

     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithAll2All(self.moe_config, self.quant_type)
+        return PrepareAndFinalizeWithAll2All(self.moe_config)


 class NaiveMulticastCommImpl(MoECommMethod):
@@ -293,5 +273,4 @@
             num_local_experts=self.moe_config.num_local_experts)

     def _get_prepare_finalize(self):
-        return PrepareAndFinalizeWithNaiveMulticast(self.moe_config,
-                                                    self.quant_type)
+        return PrepareAndFinalizeWithNaiveMulticast(self.moe_config)
diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py
index f54d4579..46640006 100644
--- a/vllm_ascend/ops/fused_moe/prepare_finalize.py
+++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py
@@ -53,11 +53,8 @@ class PrepareAndFinalize(ABC):
     sizes, ranks, and communication settings.
     """

-    def __init__(self,
-                 moe_config: FusedMoEConfig,
-                 quant_type: QuantType = QuantType.NONE):
+    def __init__(self, moe_config: FusedMoEConfig):
         self.moe_config = moe_config
-        self.quant_type = quant_type

     @abstractmethod
     def prepare(
@@ -65,7 +62,8 @@
         self,
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False
+        replace_allreduce: bool = False,
+        quant_type: QuantType = QuantType.NONE
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
         """
@@ -79,6 +77,7 @@
             router_logits (torch.Tensor): Router outputs, shape [num_tokens, num_experts]
             enable_shared_expert_dp (bool): Skip DP communication for shared experts
             replace_allreduce (bool): Bypass default all-reduce behavior
+            quant_type: none, w8a8 or w4a8

         Returns:
             Tuple of:
@@ -117,10 +116,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
     Will be used when num_tokens exceed mc2's limitation (512 tokens/rank).
     """

-    def __init__(self,
-                 moe_config: FusedMoEConfig,
-                 quant_type: QuantType = QuantType.NONE):
-        super().__init__(moe_config, quant_type)
+    def __init__(self, moe_config: FusedMoEConfig):
+        super().__init__(moe_config)
         self._restore_tp_across_dp()

     def _restore_tp_across_dp(self):
@@ -133,7 +130,8 @@
         hidden_states: torch.Tensor,
         router_logits: torch.Tensor,
         enable_shared_expert_dp: bool = False,
-        replace_allreduce: bool = False
+        replace_allreduce: bool = False,
+        quant_type=QuantType.NONE
     ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
                Optional[torch.Tensor]]:
         """
@@ -211,10 +209,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
     Relies on `mc2_mask` and `padded_num_tokens` from forward_context for alignment.
""" - def __init__(self, - moe_config: FusedMoEConfig, - quant_type: QuantType = QuantType.NONE): - super().__init__(moe_config, quant_type) + def __init__(self, moe_config: FusedMoEConfig): + super().__init__(moe_config) self._restore_tp_across_dp() def _restore_tp_across_dp(self): @@ -231,7 +227,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All): hidden_states: torch.Tensor, router_logits: torch.Tensor, enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False + replace_allreduce: bool = False, + quant_type=QuantType.NONE ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ @@ -312,6 +309,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): router_logits: torch.Tensor, enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, + quant_type=QuantType.NONE ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ @@ -322,7 +320,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): Tuple of (global_hidden_states, global_router_logits, None) """ if enable_sp(): - return self._prepare_with_ep_group(hidden_states, router_logits) + return self._prepare_with_ep_group(hidden_states, router_logits, + quant_type) return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, @@ -332,10 +331,11 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): self, hidden_states: torch.Tensor, router_logits: torch.Tensor, + quant_type=QuantType.NONE ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: pertoken_scale = None - if self.quant_type == QuantType.W8A8: + if quant_type == QuantType.W8A8: hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( hidden_states) pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( @@ -356,6 +356,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): router_logits: torch.Tensor, enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, + quant_type=QuantType.NONE ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ @@ -500,7 +501,8 @@ class PrepareAndFinalizeWithNaiveMulticast(PrepareAndFinalize): hidden_states: torch.Tensor, router_logits: torch.Tensor, enable_shared_expert_dp: bool = False, - replace_allreduce: bool = False + replace_allreduce: bool = False, + quant_type=QuantType.NONE ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: """ diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index cd889a04..77f0f4b2 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -386,6 +386,7 @@ class AscendW4A8DynamicFusedMoEMethod: w2_scale_bias=layer.w2_scale_bias, topk_weights=topk_weights, topk_ids=topk_ids, + use_int4_w4a8=True, expert_map=expert_map, log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index b5f10c4d..8bef2567 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -256,6 +256,7 @@ class AscendW8A8DynamicFusedMoEMethod: w2_scale=layer.w2_weight_scale, topk_weights=topk_weights, topk_ids=topk_ids, + use_int8_w8a8=True, expert_map=expert_map, log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num,