BugFix: Resolve shape mismatch in eplb update and calculation issues in quant_apply_mlp (#4777)
## Description

This PR addresses two issues in the MoE module: a shape mismatch in the EPLB map update when redundant experts are enabled, and a calculation precision bug in the forward inference of the quantized MLP.

### 1. Shape Mismatch in EPLB Expert Map Update

- **Root Cause**: When redundant experts are turned on, a shape inconsistency occurs during the expert map update in `Vllm_apaptor`:
  - The shape of `self.expert_map_per_layer[layer_id]` is `[num_physical_experts,]` (aligned with the physical expert count).
  - The shape of `updated_expert_map` is `[num_logical_experts,]` (aligned with the logical expert count).
  - Indices in `self.expert_map_per_layer[layer_id]` that exceed the logical expert count cannot be mapped, which leads to tensor shape mismatch errors.
  - The same shape mismatch exists in the `log2phy` map update (between `self.log2phy_map_per_layer[layer_id]` and `updated_log2phy_map`).
- **Fix**:
  - Initialize `expert_map_per_layer` and `log2phy_map_per_layer` with the shape `[num_physical_experts,]` and keep that shape across the module lifecycle.
  - Align the shape of `updated_expert_map` and `updated_log2phy_map` with the pre-initialized physical-expert-sized tensors during update operations, so that index mapping always works on tensors of the same shape.

### 2. Calculation Precision Issue in Quantized MoE MLP Forward Inference

- **Root Cause**: In the forward pass of `moe_mlp`, the `torch_npu.npu_dequant_swiglu_quant` operator only accepts group lists in **Count format** as input. However, the group list provided by `quant_apply_mlp` was in **Cumsum format**, so the operator received input in the wrong format and calculation precision degraded.
- **Fix**:
  - Convert the Cumsum-formatted group list from `quant_apply_mlp` to Count format before passing it to `torch_npu.npu_dequant_swiglu_quant`.
  - With the operator receiving the format it requires, the expected calculation precision for quantized MoE MLP layers is restored.
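The difference between the two group-list formats can be sketched in plain Python. This is an illustration only, not the code from this PR: the helper names `cumsum_to_count` and `count_to_cumsum` are hypothetical, and plain lists stand in for the NPU tensors that the real code operates on. It assumes that a Cumsum-format list holds the running total of tokens per expert and a Count-format list holds the per-expert token counts.

```python
from itertools import accumulate
from typing import List


def cumsum_to_count(group_list_cumsum: List[int]) -> List[int]:
    """Turn running totals of tokens per expert into per-expert counts."""
    counts = []
    previous = 0
    for total in group_list_cumsum:
        counts.append(total - previous)
        previous = total
    return counts


def count_to_cumsum(group_list_count: List[int]) -> List[int]:
    """Inverse conversion, used here only as a round-trip check."""
    return list(accumulate(group_list_count))


cumsum = [3, 5, 5, 9]  # the third expert received no tokens
counts = cumsum_to_count(cumsum)
print(counts)  # [3, 2, 0, 4]
assert count_to_cumsum(counts) == cumsum
```

The two formats carry the same information, but an operator that reads a Cumsum-format list as counts would see the wrong number of tokens for every expert after the first, which is consistent with the precision loss described above.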
## Impact

- Resolves shape mismatch errors in EPLB expert/log2phy map updates when redundant experts are enabled, ensuring stable expert routing.
- Fixes quantized MoE MLP forward precision issues on NPU, aligning operator input formats with NPU kernel requirements.
- No breaking changes to existing interfaces; the fixes are backward-compatible for scenarios without redundant experts enabled.

---------

Signed-off-by: Che Ruan <cr623@ic.ac.uk>
Signed-off-by: Mercykid-bash <ruanche0218@gmail.com>
Co-authored-by: Che Ruan <cr623@ic.ac.uk>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -48,7 +48,7 @@ class ExpertLoadBalancer(object):
 
     def generate_expert_placement_map(self):
         expert_placement_map = torch.full(
-            (self.layers_num, self.ranks_num, self.global_expert_num),
+            (self.layers_num, self.ranks_num, self.num_experts),
             -1,
             dtype=torch.int32,
         )
@@ -71,7 +71,7 @@ class ExpertLoadBalancer(object):
                 result_dict[key] = []
             result_dict[key].append(idx)
 
-        log2phy_map = torch.full((self.ranks_num, self.global_expert_num),
+        log2phy_map = torch.full((self.ranks_num, self.num_experts),
                                  -1,
                                  dtype=torch.int32)
         for rank in range(self.ranks_num):
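
Both hunks change only the last dimension of a map that is pre-filled with `-1`, so that dimension has to match the id space used to index it. The following sketch shows the idea with nested lists instead of tensors. It is not the `ExpertLoadBalancer` code: the `placement` data, the helper names `build_expert_map` and `group_physical_slots`, and the meaning assigned to the map entries (local slot on a rank, `-1` when the expert is not hosted there) are assumptions made for illustration.

```python
from typing import Dict, List

# Hypothetical placement for one layer: each inner list holds the expert ids
# hosted by one rank. Experts 0 and 3 each have one redundant copy.
placement: List[List[int]] = [[0, 1, 3], [2, 3, 0]]
num_experts = 4  # size of the id space that indexes the map's last dimension


def build_expert_map(placement: List[List[int]], num_experts: int) -> List[List[int]]:
    """Per rank: expert id -> local slot on that rank, or -1 if not hosted."""
    expert_map = [[-1] * num_experts for _ in placement]
    for rank, hosted in enumerate(placement):
        for local_slot, expert_id in enumerate(hosted):
            expert_map[rank][expert_id] = local_slot
    return expert_map


def group_physical_slots(placement: List[List[int]]) -> Dict[int, List[int]]:
    """Expert id -> positions of all its copies in the flattened placement."""
    result_dict: Dict[int, List[int]] = {}
    flattened = [expert_id for hosted in placement for expert_id in hosted]
    for idx, key in enumerate(flattened):
        if key not in result_dict:
            result_dict[key] = []
        result_dict[key].append(idx)
    return result_dict


expert_map = build_expert_map(placement, num_experts)
slots = group_physical_slots(placement)
print(expert_map)  # [[0, 1, -1, 2], [2, -1, 0, 1]]
print(slots)       # {0: [0, 5], 1: [1], 3: [2, 4], 2: [3]}
```

If the last dimension were sized for a different id space than the one used as the index, the assignment inside `build_expert_map` would either raise an `IndexError` or leave unused trailing entries. This is the kind of inconsistency the shape change is meant to avoid.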