BugFix: Resolve shape mismatch in eplb update and calculation issues in quant_apply_mlp (#4777)

## Description
This PR addresses two key issues in the MoE module when redundant
experts are enabled, and fixes a calculation precision bug in the
forward inference of quantized MLP:

### 1. Shape Mismatch in EPLB Expert Map Update
- **Root Cause**: 
When redundant experts are turned on, a shape inconsistency occurs
during the expert map update in `Vllm_adaptor`:
- The shape of `self.expert_map_per_layer[layer_id]` is
`[num_physical_experts,]` (aligned with physical expert count).
- The shape of `updated_expert_map` is `[num_logical_experts,]` (aligned
with logical expert count).
- Indices in `self.expert_map_per_layer[layer_id]` that exceed the
logical expert count cannot be properly mapped, leading to tensor shape
mismatch errors.
- The same shape mismatch exists in the `log2phy` map update (between
`self.log2phy_map_per_layer[layer_id]` and `updated_log2phy_map`).

- **Fix**:
- Fix the shape initialization of `expert_map_per_layer` and
`log2phy_map_per_layer` to be consistently set to
`[num_physical_experts,]` across the module lifecycle.
- Align the shape of `updated_expert_map` and `updated_log2phy_map` with
the pre-initialized physical-expert-sized tensors during update
operations, ensuring shape consistency for index mapping.

### 2. Calculation Precision Issue in Quantized MoE MLP Forward
Inference
- **Root Cause**:
In the forward pass of `moe_mlp`, the
`torch_npu.npu_dequant_swiglu_quant` operator only accepts group lists
in **Count format** as input. However, the group list provided by
`quant_apply_mlp` was in **Cumsum format**, which caused operator input
format mismatch and degraded calculation precision.

- **Fix**:
- Convert the cumsum-formatted group list from `quant_apply_mlp` to
Count format before passing it to `torch_npu.npu_dequant_swiglu_quant`.
- Ensure the input format of the dequantization operator meets its
requirements, restoring the expected calculation precision for quantized
MoE MLP layers.

## Impact
- Resolves shape mismatch errors in EPLB expert/log2phy map updates when
redundant experts are enabled, ensuring stable expert routing.
- Fixes quantized MoE MLP forward precision issues on NPU, aligning
operator input formats with NPU kernel requirements.
- No breaking changes to existing interfaces; the fixes are
backward-compatible for scenarios without redundant experts enabled.

---------

Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: Mercykid-bash <ruanche0218@gmail.com>
Co-authored-by: Che Ruan <cr623@ic.ac.uk>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
Mercykid-bash
2025-12-09 15:46:58 +08:00
committed by GitHub
parent 695e5c9ebc
commit 8f45f9ce29
5 changed files with 12 additions and 18 deletions

View File

@@ -50,10 +50,6 @@ class D2DExpertWeightLoader:
)
return
# If neither send nor receive task is needed for this layer on this rank, return
if not (expert_send_info or expert_recv_info):
return
self.updated_expert_map = updated_expert_map
self.layer_id = layer_id

View File

@@ -48,7 +48,7 @@ class ExpertLoadBalancer(object):
def generate_expert_placement_map(self):
expert_placement_map = torch.full(
(self.layers_num, self.ranks_num, self.global_expert_num),
(self.layers_num, self.ranks_num, self.num_experts),
-1,
dtype=torch.int32,
)
@@ -71,7 +71,7 @@ class ExpertLoadBalancer(object):
result_dict[key] = []
result_dict[key].append(idx)
log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
log2phy_map = torch.full((self.ranks_num, self.num_experts),
-1,
dtype=torch.int32)
for rank in range(self.ranks_num):

View File

@@ -105,6 +105,9 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list=group_list,
output_dtype=torch.int32)[0]
# act_fn: swiglu
group_diff = torch.diff(group_list)
new_group = torch.cat([group_list[0].unsqueeze(0), group_diff],
dim=0)
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
@@ -112,7 +115,7 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
group_index=new_group,
activate_left=True,
quant_mode=1,
)

View File

@@ -122,18 +122,14 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
expert_map: torch.Tensor,
global_redundant_expert_num: int = 0,
):
if self.with_quant:
quant_mode = 2
moe_expert_num = len(expert_map)
else:
quant_mode = 0
moe_expert_num = len(expert_map)
quant_mode = 2 if self.with_quant else 0
self.moe_expert_num = len(expert_map) + global_redundant_expert_num
kwargs_mc2 = {
"x": hidden_states,
"expert_ids": topk_ids,
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"moe_expert_num": self.moe_expert_num,
"global_bs": 0,
"expert_token_nums_type": 0,
}
@@ -229,7 +225,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
assert self.topk_weights is not None
assert self.topk_ids is not None
assert self.output is not None
moe_expert_num = len(self.expert_map)
# moeCombine
kwargs_mc2 = {
"expand_x": hidden_states,
@@ -237,7 +232,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"expert_scales": self.topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"moe_expert_num": self.moe_expert_num,
"global_bs": 0,
}
if self.with_quant:
@@ -359,7 +354,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
hidden_states = hidden_states * \
topk_weights.to(hidden_states.dtype)
if expert_map is not None:
global_num_experts = len(expert_map)
global_num_experts = len(expert_map) + global_redundant_expert_num
mask = (expert_map[topk_ids] != -1)
self.topk_weights = topk_weights * mask
first_expert_idx = get_ep_group(

View File

@@ -250,7 +250,7 @@ class AscendW8A8DynamicFusedMoEMethod:
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale_fp32,
w1_scale=layer.w13_weight_scale.to(torch.float32),
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,