[EPLB][bugfix] Bugfix for fused mc2 (#6794)

### What this PR does / why we need it?
This pull request fixes a bug in the fused MC2 path of the EPLB (Expert Parallelism Load Balancing) system that affected W8A8 dynamic quantization and MoE communication. With dynamic EPLB enabled and `VLLM_ASCEND_ENABLE_FUSED_MC2=1`, the fused int64-packed scales were passed as a single tensor rather than as per-expert lists, so EPLB could not rebalance them. The W8A8 dynamic quantization method now builds `fused_w1_scale_list` / `fused_w2_scale_list` per expert, the EPLB adaptor registers them in `expert_weight_names`, the MoE communication method selection drops the dynamic-EPLB guard on `dispatch_ffn_combine`, and `FusedExpertsResult.group_list_type` now defaults to `1` instead of being set per branch.
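
For context, the heart of the fix is packaging the fused scales per expert, mirroring how the existing fp32 scale lists are built. A minimal standalone sketch of that pattern, under illustrative assumptions (8 experts, 16 packed values per expert; `fused_w1_scale` here is a toy tensor, not the real layer attribute):

```python
import torch

# Illustrative shapes: 8 experts, 16 packed int64 scale values per expert.
num_experts, scale_width = 8, 16
fused_w1_scale = torch.arange(num_experts * scale_width, dtype=torch.int64)

# Dynamic EPLB swaps weights expert by expert, so every tensor it manages
# must be a list with one entry per expert. View the fused scale as
# (num_experts, -1), unbind along dim 0, and clone so each expert owns
# its own storage and can be replaced independently.
fused_w1_scale_list = [
    s.clone() for s in fused_w1_scale.view(num_experts, -1).unbind(dim=0)
]
assert len(fused_w1_scale_list) == num_experts
assert fused_w1_scale_list[0].shape == (scale_width,)
```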
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main: 83b47f67b1

Signed-off-by: Spicy-Stick <873805887@qq.com>
Signed-off-by: root <root@localhost.localdomain>
Commit 23bf5d4d48 (parent 06ec136f08)
Authored by JIACHENG XU on 2026-03-09 11:26:57 +08:00; committed by GitHub
5 changed files with 50 additions and 28 deletions


```diff
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
 import torch
 from vllm_ascend.eplb.adaptor.vllm_adaptor import VllmEplbAdaptor
+from vllm_ascend.quantization.methods.base import QuantType
 from transformers import DeepseekV2Config
@@ -17,6 +18,8 @@ class TestVllmAdaptor(unittest.TestCase):
         mock_model.get_expert_map.return_value = [i for i in range(n_routed_experts)]
         mock_model.get_log2phy_map.return_value = [i for i in range(n_routed_experts)]
         self.model = mock_model
+        num_dense_layers = getattr(config, "first_k_dense_replace", 0)
+        self.model.model.layers[num_dense_layers].mlp.experts.quant_type = QuantType.W8A8
         self.mock_rank = patch("vllm_ascend.eplb.adaptor.vllm_adaptor.dist.get_rank", return_value=0).start()
         self.mock_size = patch("vllm_ascend.eplb.adaptor.vllm_adaptor.dist.get_world_size", return_value=4).start()
```


```diff
@@ -9,7 +9,6 @@ from vllm.distributed import get_dp_group, get_ep_group, get_tensor_model_parall
 from vllm.forward_context import BatchDescriptor, get_forward_context, set_forward_context
 import vllm_ascend.envs as envs_ascend
-from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.utils import (
     AscendDeviceType,
     enable_sp,
@@ -243,11 +242,10 @@ def select_moe_comm_method(num_tokens: int, vllm_config: VllmConfig, is_draft_mo
         moe_comm_type = MoECommType.ALLGATHER
     elif soc_version in {AscendDeviceType.A3}:
-        dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
         # TODO: drop the EP-size guard when dispatch_ffn_combine supports larger EP sizes
         # TODO: drop speculative method guard when dispatch_gmm_combine_decode supports w16a16
         fused_mc2_enable = envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 and quant_type == "w8a8_dynamic"
-        dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model) and (not dynamic_eplb)
+        dispatch_ffn_combine_enable = get_ep_group().world_size <= 32 and (not is_draft_model)
         if num_tokens <= mc2_tokens_capacity:
             fused_decode_enable = fused_mc2_enable
             if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
```
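
Read as plain logic, the decode-path gating on A3 after this change looks roughly like the sketch below. Only the guard expressions come from the diff; the function name, branch priority, and returned labels are illustrative assumptions:

```python
def pick_decode_path(num_tokens: int, mc2_tokens_capacity: int,
                     ep_world_size: int, is_draft_model: bool,
                     fused_mc2_env: int, quant_type: str) -> str:
    # Fused MC2 additionally requires w8a8_dynamic quantization.
    fused_mc2_enable = bool(fused_mc2_env) and quant_type == "w8a8_dynamic"
    # The EP-size and draft-model guards remain; the dynamic-EPLB guard is
    # gone now that EPLB tracks the fused scale lists.
    dispatch_ffn_combine_enable = ep_world_size <= 32 and not is_draft_model
    if num_tokens <= mc2_tokens_capacity:
        if fused_mc2_enable:
            return "FUSED_MC2"
        if dispatch_ffn_combine_enable:
            return "DISPATCH_FFN_COMBINE"
        return "MC2"
    return "ALLGATHER"

print(pick_decode_path(64, 512, 16, False, 1, "w8a8_dynamic"))  # FUSED_MC2
```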


```diff
@@ -22,6 +22,9 @@ import torch
 import torch
 import torch.distributed as dist
 from vllm.logger import logger
+import vllm_ascend.envs as envs_ascend
+from vllm_ascend.quantization.methods.base import QuantType

 class VllmEplbAdaptor:
     def __init__(self, model, **args):
@@ -59,12 +62,19 @@ class VllmEplbAdaptor:
     def init_expert_param_per_layer(self):
         self.param_dict = dict()
         if self.model.quant_config is not None:
+            quant_type = self.model.model.layers[self.num_dense_layers].mlp.experts.quant_type
+            if quant_type == QuantType.W8A8:
                 self.expert_weight_names = [
                     "w13_weight_list",
                     "w2_weight_list",
                     "w13_weight_scale_fp32_list",
                     "w2_weight_scale_list",
                 ]
+                if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+                    self.expert_weight_names.append("fused_w1_scale_list")
+                    self.expert_weight_names.append("fused_w2_scale_list")
+            else:
+                raise ValueError(f"EPLB not support {quant_type}")
         else:
             self.expert_weight_names = ["w13_weight", "w2_weight"]
```


```diff
@@ -70,7 +70,7 @@ class FusedExpertsResult:
     before_dispatch_evt: torch.npu.Event | None = None
     before_combine_evt: torch.npu.Event | None = None
     # For dynamic_eplb
-    group_list_type: int | None = None
+    group_list_type: int = 1
     expert_tokens: torch.Tensor | None = None
@@ -355,7 +355,6 @@ class FusedMC2CommImpl(MoECommMethod):
         if log2phy is not None:
             topk_ids = log2phy[topk_ids]
-        group_list_type = None
         expert_tokens = None
         if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
             out = torch.empty_like(hidden_states)
@@ -375,7 +374,6 @@ class FusedMC2CommImpl(MoECommMethod):
             expert_tokens = self.expert_token_nums
         elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
             assert expert_map is not None, "expert_map cannot be None."
-            group_list_type = 1
             out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode(  # type: ignore
                 x=hidden_states,
                 expert_ids=topk_ids,
@@ -393,4 +391,4 @@ class FusedMC2CommImpl(MoECommMethod):
             )
         else:
             raise ValueError(f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
-        return FusedExpertsResult(routed_out=out, group_list_type=group_list_type, expert_tokens=expert_tokens)
+        return FusedExpertsResult(routed_out=out, expert_tokens=expert_tokens)
```
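
Moving `group_list_type` into the dataclass default removes the per-branch bookkeeping the deleted lines performed. A toy illustration with the field set reduced for brevity; the usual grouped-matmul convention (0 for cumulative offsets, 1 for per-group counts) is assumed here, not confirmed by the diff:

```python
from dataclasses import dataclass

import torch

@dataclass
class FusedExpertsResult:
    routed_out: torch.Tensor | None = None
    # Both fused MC2 branches produce per-group token counts, so the
    # group-list convention defaults to 1 (counts, per the assumption
    # above) instead of being assigned separately in each branch.
    group_list_type: int = 1
    expert_tokens: torch.Tensor | None = None

res = FusedExpertsResult(routed_out=torch.zeros(4))
assert res.group_list_type == 1  # no explicit argument at the call site
```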


```diff
@@ -235,28 +235,28 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
         topk_weights = topk_weights.to(self.in_dtype)
         moe_comm_method = get_forward_context().moe_comm_method
-        if self.dynamic_eplb:
-            w1 = layer.w13_weight_list
-            w1_scale = layer.w13_weight_scale_fp32_list
-            w2 = layer.w2_weight_list
-            w2_scale = layer.w2_weight_scale_list
-        else:
-            w1 = [layer.w13_weight]
-            w1_scale = [layer.w13_weight_scale_fp32]
-            w2 = [layer.w2_weight]
-            w2_scale = [layer.w2_weight_scale]
         fused_scale_flag = (
             get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
             and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
         )
+        if self.dynamic_eplb:
+            w1 = layer.w13_weight_list
+            w1_scale = layer.fused_w1_scale_list if fused_scale_flag else layer.w13_weight_scale_fp32_list
+            w2 = layer.w2_weight_list
+            w2_scale = layer.fused_w2_scale_list if fused_scale_flag else layer.w2_weight_scale_list
+        else:
+            w1 = [layer.w13_weight]
+            w1_scale = [layer.fused_w1_scale] if fused_scale_flag else [layer.w13_weight_scale_fp32]
+            w2 = [layer.w2_weight]
+            w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale]
         final_hidden_states = moe_comm_method.fused_experts(
             hidden_states=x,
             pertoken_scale=pertoken_scale,
             w1=w1,
-            w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
+            w1_scale=w1_scale,
             w2=w2,
-            w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
+            w2_scale=w2_scale,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             use_int8_w8a8=True,
@@ -282,6 +282,7 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
         layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1)
         layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1)
+        if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
             layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
             layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
@@ -292,9 +293,21 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
             weight.clone() for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
         ]
         layer.w2_weight_scale_list = [weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0)]
+        if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+            layer.fused_w1_scale_list = [
+                weight.clone()
+                for weight in layer.fused_w1_scale.view(len(layer.w13_weight_list), -1).data.unbind(dim=0)
+            ]
+            layer.fused_w2_scale_list = [
+                weight.clone()
+                for weight in layer.fused_w2_scale.view(len(layer.w2_weight_list), -1).data.unbind(dim=0)
+            ]
         del layer.w13_weight
         del layer.w2_weight
         del layer.w13_weight_scale
         del layer.w13_weight_scale_fp32
         del layer.w2_weight_scale
+        if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
+            del layer.fused_w1_scale
+            del layer.fused_w2_scale
         torch.npu.empty_cache()
```
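
Finally, the environment-flag gating assumed throughout these diffs, as a standalone sketch; the printed hints paraphrase requirements visible in the hunks above, and values other than 0, 1, or 2 raise, matching the `ValueError` branch:

```python
import os

# 1 -> fused dispatch/FFN/combine path with int64-packed scales
# 2 -> dispatch_gmm_combine_decode, which needs expert_map and
#      returns its own expert_tokens
mode = int(os.getenv("VLLM_ASCEND_ENABLE_FUSED_MC2", "0"))
if mode == 1:
    print("fused MC2 mode 1: fused_w*_scale(_list) must exist on the layer")
elif mode == 2:
    print("fused MC2 mode 2: expert_map required; expert_tokens returned")
elif mode != 0:
    raise ValueError(f"Wrong value of VLLM_ASCEND_ENABLE_FUSED_MC2={mode}")
```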