From 23bf5d4d48e6ec09e2b4f726279591a1b42f033b Mon Sep 17 00:00:00 2001
From: JIACHENG XU <56331162+Spicy-Stick@users.noreply.github.com>
Date: Mon, 9 Mar 2026 11:26:57 +0800
Subject: [PATCH] [EPLB][bugfix] Bugfix for fused mc2 (#6794)

### What this PR does / why we need it?
Fixes fused mc2 (`VLLM_ASCEND_ENABLE_FUSED_MC2`) when combined with dynamic EPLB (Expert Parallelism Load Balancing) on the w8a8_dynamic quantization path: the packed int64 scales (`fused_w1_scale`/`fused_w2_scale`) were created unconditionally and never tracked per expert, so expert rebalancing could not move them together with the expert weights (a condensed sketch of the flag's gating follows the change list). Concretely, this PR:
- creates the fused scales only when `VLLM_ASCEND_ENABLE_FUSED_MC2 == 1` and, under dynamic EPLB, also splits them into per-expert `fused_w1_scale_list`/`fused_w2_scale_list` registered in the EPLB adaptor's `expert_weight_names`;
- selects the fused scales inside the dynamic-EPLB branch of `AscendW8A8DynamicFusedMoEMethod` instead of overriding `w1_scale`/`w2_scale` at the `fused_experts` call site;
- drops the `dynamic_eplb` guard on `dispatch_ffn_combine_enable` in `select_moe_comm_method`, so dynamic EPLB no longer disables `dispatch_ffn_combine`;
- defaults `FusedExpertsResult.group_list_type` to 1, the value both fused mc2 branches use;
- raises an explicit error from `VllmEplbAdaptor.init_expert_param_per_layer` for quant types EPLB does not support.
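For quick reference, a minimal sketch of the flag's visible semantics; `fused_mc2_mode` is a hypothetical helper written for this description, not code from the patch:

```python
import os

# Hypothetical condensation of the gating visible in this patch: fused mc2
# applies only to the w8a8_dynamic quant type, and FusedMC2CommImpl accepts
# flag values 1 and 2, raising ValueError for any other non-zero value.
def fused_mc2_mode(quant_type: str) -> int:
    mode = int(os.environ.get("VLLM_ASCEND_ENABLE_FUSED_MC2", "0"))
    if mode == 0 or quant_type != "w8a8_dynamic":
        return 0
    if mode not in (1, 2):
        raise ValueError(f"Wrong value of VLLM_ASCEND_ENABLE_FUSED_MC2={mode}")
    return mode
```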
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/83b47f67b1dfad505606070ae4d9f83e50ad4ebd

Signed-off-by: Spicy-Stick <873805887@qq.com>
Signed-off-by: root
---
 tests/ut/eplb/adaptor/test_vllm_adaptor.py   |  3 ++
 vllm_ascend/ascend_forward_context.py        |  4 +-
 vllm_ascend/eplb/adaptor/vllm_adaptor.py     | 22 +++++++---
 vllm_ascend/ops/fused_moe/moe_comm_method.py |  6 +--
 .../quantization/methods/w8a8_dynamic.py     | 43 ++++++++++++-------
 5 files changed, 50 insertions(+), 28 deletions(-)

diff --git a/tests/ut/eplb/adaptor/test_vllm_adaptor.py b/tests/ut/eplb/adaptor/test_vllm_adaptor.py
index a5dc0559..f9a41b60 100644
--- a/tests/ut/eplb/adaptor/test_vllm_adaptor.py
+++ b/tests/ut/eplb/adaptor/test_vllm_adaptor.py
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
 import torch
 
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
+from vllm_ascend.quantization.methods.base import QuantType
 
 from transformers import DeepseekV2Config
 
@@ -17,6 +18,8 @@ class TestVllmAdaptor(unittest.TestCase):
         mock_model.get_expert_map.return_value = [i for i in range(n_routed_experts)]
         mock_model.get_log2phy_map.return_value = [i for i in range(n_routed_experts)]
         self.model = mock_model
+        num_dense_layers = getattr(config, "first_k_dense_replace", 0)
+        self.model.model.layers[num_dense_layers].mlp.experts.quant_type = QuantType.W8A8
 
         self.mock_rank = patch("vllm_ascend.eplb.adaptor.vllm_adaptor.dist.get_rank", return_value=0).start()
         self.mock_size = patch("vllm_ascend.eplb.adaptor.vllm_adaptor.dist.get_world_size", return_value=4).start()

diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 936fa603..116f562a 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -9,7 +9,6 @@ from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parall
 from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context
 
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import (
     AscendDeviceType,
     enable_sp,
@@ -243,11 +242,10 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
             moe_comm_type = MoECommType.ALLGATHER
     elif soc_version in {AscendDeviceType.A3}:
-        dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
         # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
         # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
         fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
-        dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model) and (not dynamic_eplb)
+        dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model)
 
         if num_tokens <= mc2_tokens_capacity:
             fused_decode_enable = fused_mc2_enable
             if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:

diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
index cf414ac5..7cd71f89 100644
--- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py
+++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py
@@ -22,6 +22,9 @@ import torch
 import torch.distributed as dist
 from vllm.logger import logger
 
+import vllm_ascend.envs as envs_ascend
+from vllm_ascend.quantization.methods.base import QuantType
+
 
 class VllmEplbAdaptor:
     def __init__(self, model, **args):
@@ -59,12 +62,19 @@ class VllmEplbAdaptor:
     def init_expert_param_per_layer(self):
         self.param_dict = dict()
         if self.model.quant_config is not None:
-            self.expert_weight_names = [
-                "w13_weight_list",
-                "w2_weight_list",
-                "w13_weight_scale_fp32_list",
-                "w2_weight_scale_list",
-            ]
+            quant_type = self.model.model.layers[self.num_dense_layers].mlp.experts.quant_type
+            if quant_type == QuantType.W8A8:
+                self.expert_weight_names = [
+                    "w13_weight_list",
+                    "w2_weight_list",
+                    "w13_weight_scale_fp32_list",
+                    "w2_weight_scale_list",
+                ]
+                if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+                    self.expert_weight_names.append("fused_w1_scale_list")
+                    self.expert_weight_names.append("fused_w2_scale_list")
+            else:
+                raise ValueError(f"EPLB does not support {quant_type}")
         else:
             self.expert_weight_names = ["w13_weight", "w2_weight"]
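Why the fused scales have to be in `expert_weight_names`: dynamic EPLB rebalances by moving per-expert tensors between expert slots, so every tensor the MoE kernel indexes per expert must be exposed as one of these per-layer lists. A minimal sketch with a hypothetical `swap_experts` helper (not the adaptor's real API):

```python
import torch

# Hypothetical illustration: rebalancing swaps entries of every registered
# per-expert list in lockstep. If fused_w1_scale stayed a single packed
# tensor, its rows would keep the pre-swap expert layout while w13_weight
# moved, corrupting the w8a8 dequantization.
def swap_experts(expert_params: dict[str, list[torch.Tensor]], a: int, b: int) -> None:
    for weight_list in expert_params.values():
        weight_list[a], weight_list[b] = weight_list[b], weight_list[a]
```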
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index 39e80342..14d72531 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -70,7 +70,7 @@ class FusedExpertsResult:
     before_dispatch_evt: torch.npu.Event | None = None
     before_combine_evt: torch.npu.Event | None = None
     # For dynamic_eplb
-    group_list_type: int | None = None
+    group_list_type: int = 1
     expert_tokens: torch.Tensor | None = None
 
 
@@ -355,7 +355,6 @@ class FusedMC2CommImpl(MoECommMethod):
         if log2phy is not None:
             topk_ids = log2phy[topk_ids]
 
-        group_list_type = None
         expert_tokens = None
         if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
             out = torch.empty_like(hidden_states)
@@ -375,7 +374,6 @@ class FusedMC2CommImpl(MoECommMethod):
             expert_tokens = self.expert_token_nums
         elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
             assert expert_map is not None, "expert_map cannot be None."
-            group_list_type = 1
             out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode(  # type: ignore
                 x=hidden_states,
                 expert_ids=topk_ids,
@@ -393,4 +391,4 @@ class FusedMC2CommImpl(MoECommMethod):
             )
         else:
             raise ValueError(f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
-        return FusedExpertsResult(routed_out=out, group_list_type=group_list_type, expert_tokens=expert_tokens)
+        return FusedExpertsResult(routed_out=out, expert_tokens=expert_tokens)
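With both fused mc2 branches reporting per-expert token counts the same way, `group_list_type` can default to 1 on the dataclass instead of being set branch by branch. Assuming the usual Ascend grouped-matmul convention (0 = cumulative offsets, 1 = per-group counts; this convention is an assumption, not stated in the diff), the two encodings relate as follows:

```python
import torch

# Assumed convention: group_list_type == 1 means `group_list` holds tokens
# per expert; group_list_type == 0 means the running (cumulative) offsets.
counts = torch.tensor([3, 0, 5, 2])      # type 1: tokens routed to each expert
offsets = torch.cumsum(counts, dim=0)    # type 0: cumulative end offsets
assert offsets.tolist() == [3, 3, 8, 10]
```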
diff --git a/vllm_ascend/quantization/methods/w8a8_dynamic.py b/vllm_ascend/quantization/methods/w8a8_dynamic.py
index b150d1a5..d9d838ec 100644
--- a/vllm_ascend/quantization/methods/w8a8_dynamic.py
+++ b/vllm_ascend/quantization/methods/w8a8_dynamic.py
@@ -235,28 +235,28 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
         topk_weights = topk_weights.to(self.in_dtype)
 
         moe_comm_method = get_forward_context().moe_comm_method
-        if self.dynamic_eplb:
-            w1 = layer.w13_weight_list
-            w1_scale = layer.w13_weight_scale_fp32_list
-            w2 = layer.w2_weight_list
-            w2_scale = layer.w2_weight_scale_list
-        else:
-            w1 = [layer.w13_weight]
-            w1_scale = [layer.w13_weight_scale_fp32]
-            w2 = [layer.w2_weight]
-            w2_scale = [layer.w2_weight_scale]
-
         fused_scale_flag = (
             get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
             and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
         )
+        if self.dynamic_eplb:
+            w1 = layer.w13_weight_list
+            w1_scale = layer.fused_w1_scale_list if fused_scale_flag else layer.w13_weight_scale_fp32_list
+            w2 = layer.w2_weight_list
+            w2_scale = layer.fused_w2_scale_list if fused_scale_flag else layer.w2_weight_scale_list
+        else:
+            w1 = [layer.w13_weight]
+            w1_scale = [layer.fused_w1_scale] if fused_scale_flag else [layer.w13_weight_scale_fp32]
+            w2 = [layer.w2_weight]
+            w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale]
+
         final_hidden_states = moe_comm_method.fused_experts(
             hidden_states=x,
             pertoken_scale=pertoken_scale,
             w1=w1,
-            w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
+            w1_scale=w1_scale,
             w2=w2,
-            w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
+            w2_scale=w2_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             use_int8_w8a8=True,
@@ -282,8 +282,9 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1)
         layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1)
 
-        layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
-        layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
+        if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+            layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
+            layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
 
         if self.dynamic_eplb:
             layer.w13_weight_list = [weight.clone() for weight in layer.w13_weight.data.unbind(dim=0)]
@@ -292,9 +293,21 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
                 weight.clone() for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
             ]
             layer.w2_weight_scale_list = [weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0)]
+            if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+                layer.fused_w1_scale_list = [
+                    weight.clone()
+                    for weight in layer.fused_w1_scale.view(len(layer.w13_weight_list), -1).data.unbind(dim=0)
+                ]
+                layer.fused_w2_scale_list = [
+                    weight.clone()
+                    for weight in layer.fused_w2_scale.view(len(layer.w2_weight_list), -1).data.unbind(dim=0)
+                ]
             del layer.w13_weight
             del layer.w2_weight
             del layer.w13_weight_scale
             del layer.w13_weight_scale_fp32
             del layer.w2_weight_scale
+            if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+                del layer.fused_w1_scale
+                del layer.fused_w2_scale
         torch.npu.empty_cache()
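The per-expert fused scale lists above come from reshaping the packed int64 scale to `(num_experts, -1)` and unbinding along dim 0; a standalone sketch of that pattern with illustrative shapes (the real tensors come from `scale_from_float_to_int64`):

```python
import torch

# Illustrative shapes only: split a packed per-layer scale into per-expert
# rows, mirroring fused_w1_scale.view(len(layer.w13_weight_list), -1)
# followed by .unbind(dim=0) and .clone() in the patch.
num_experts, scales_per_expert = 8, 16
packed = torch.arange(num_experts * scales_per_expert, dtype=torch.int64)
per_expert = [row.clone() for row in packed.view(num_experts, -1).unbind(dim=0)]
assert len(per_expert) == num_experts
assert per_expert[0].shape == (scales_per_expert,)
```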