[EPLB][bugfix] Bugfix for fused mc2 (#6794)

### What this PR does / why we need it?
This pull request fixes a bug in the fused MC2 path of the EPLB (Expert Parallelism Load Balancing) system affecting W8A8 quantization and MoE communication: the fused int64 weight scales (`fused_w1_scale`/`fused_w2_scale`) were created and passed to `fused_experts` unconditionally. They are now built only when `VLLM_ASCEND_ENABLE_FUSED_MC2` is enabled, selected only when the MoE communication type is `FUSED_MC2`, and per-expert fused scale lists are added for the dynamic-EPLB path.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.15.0
- vLLM main:
83b47f67b1

Signed-off-by: Spicy-Stick <873805887@qq.com>
Signed-off-by: root <root@localhost.localdomain>
This commit is contained in:
JIACHENG XU
2026-03-09 11:26:57 +08:00
committed by GitHub
parent 06ec136f08
commit 23bf5d4d48
5 changed files with 50 additions and 28 deletions

View File

@@ -235,28 +235,28 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
topk_weights = topk_weights.to(self.in_dtype)
moe_comm_method = get_forward_context().moe_comm_method
if self.dynamic_eplb:
w1 = layer.w13_weight_list
w1_scale = layer.w13_weight_scale_fp32_list
w2 = layer.w2_weight_list
w2_scale = layer.w2_weight_scale_list
else:
w1 = [layer.w13_weight]
w1_scale = [layer.w13_weight_scale_fp32]
w2 = [layer.w2_weight]
w2_scale = [layer.w2_weight_scale]
fused_scale_flag = (
get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1
)
if self.dynamic_eplb:
w1 = layer.w13_weight_list
w1_scale = layer.fused_w1_scale_list if fused_scale_flag else layer.w13_weight_scale_fp32_list
w2 = layer.w2_weight_list
w2_scale = layer.fused_w2_scale_list if fused_scale_flag else layer.w2_weight_scale_list
else:
w1 = [layer.w13_weight]
w1_scale = [layer.fused_w1_scale] if fused_scale_flag else [layer.w13_weight_scale_fp32]
w2 = [layer.w2_weight]
w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale]
final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x,
pertoken_scale=pertoken_scale,
w1=w1,
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
w1_scale=w1_scale,
w2=w2,
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
w2_scale=w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int8_w8a8=True,
@@ -282,8 +282,9 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(layer.w2_weight_offset.data.shape[0], -1)
layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
layer.fused_w1_scale = scale_from_float_to_int64(layer.w13_weight_scale.data)
layer.fused_w2_scale = scale_from_float_to_int64(layer.w2_weight_scale.data)
if self.dynamic_eplb:
layer.w13_weight_list = [weight.clone() for weight in layer.w13_weight.data.unbind(dim=0)]
@@ -292,9 +293,21 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
weight.clone() for weight in layer.w13_weight_scale_fp32.data.unbind(dim=0)
]
layer.w2_weight_scale_list = [weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0)]
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
layer.fused_w1_scale_list = [
weight.clone()
for weight in layer.fused_w1_scale.view(len(layer.w13_weight_list), -1).data.unbind(dim=0)
]
layer.fused_w2_scale_list = [
weight.clone()
for weight in layer.fused_w2_scale.view(len(layer.w2_weight_list), -1).data.unbind(dim=0)
]
del layer.w13_weight
del layer.w2_weight
del layer.w13_weight_scale
del layer.w13_weight_scale_fp32
del layer.w2_weight_scale
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
del layer.fused_w1_scale
del layer.fused_w2_scale
torch.npu.empty_cache()