[EPLB][Bugfix] Dispatch Allgather use log2phy if enable eplb (#5933)

### What this PR does / why we need it?
1. Move the expert-mapping (log2phy) logic forward so it lives in one place instead of requiring shotgun changes in every dispatcher (see the sketch below).
2. Disable updates of the expert map.
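
For context, `log2phy` is the EPLB table that maps each logical expert id to a physical expert slot (redundant copies included), and remapping the router's `topk_ids` through it is a single gather. A minimal sketch of that remap, with made-up expert counts and table values (the real table is produced by EPLB):

```python
import torch

# Hypothetical setup: 4 logical experts, with logical expert 0 also served by a
# redundant physical copy, so there are 5 physical expert slots in total.
log2phy = torch.tensor([4, 1, 2, 3])  # logical expert id -> physical expert id

# Router output: top-2 logical expert ids for two tokens.
topk_ids = torch.tensor([[0, 2],
                         [1, 3]])

# The remap is a plain index/gather; everything downstream (dispatch, combine)
# then works purely in terms of physical expert ids.
topk_ids = log2phy[topk_ids]
print(topk_ids)  # tensor([[4, 2], [1, 3]])
```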

### How was this patch tested?
A2:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| GPQA_diamond | 53064e | accuracy | gen | 73.23 |

A3:
| dataset | version | metric | mode | vllm-api-general-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 83.33 |


- vLLM version: v0.13.0
- vLLM main: 11b6af5280

Signed-off-by: shenchuxiaofugui <1311027364@qq.com>
LI SHENGYONG authored on 2026-01-19 09:24:25 +08:00, committed by GitHub
Commit bc1f6713e7 (parent 9fed2636cb)
6 changed files with 6 additions and 53 deletions


@@ -150,7 +150,6 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
global_num_experts=global_num_experts,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb,


@@ -105,7 +105,6 @@ class MoECommMethod(ABC):
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
@@ -128,12 +127,15 @@ class MoECommMethod(ABC):
assert moe_comm_method is not None, "Missing communication context"
before_dispatch_evt = torch.npu.current_stream().record_event()
# Apply log2phy if needed
if log2phy is not None:
topk_ids = log2phy[topk_ids]
dispatch_results = self.token_dispatcher.token_dispatch(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
log2phy=log2phy,
global_redundant_expert_num=self.moe_config.
global_redundant_expert_num,
mc2_mask=mc2_mask,
@@ -278,7 +280,6 @@ class FusedMC2CommImpl(MoECommMethod):
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
global_num_experts: Optional[int] = None,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,

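Taken together, the hunks above move the remap into the shared `MoECommMethod` dispatch path, so each dispatcher's `token_dispatch` no longer needs (or accepts) a `log2phy` argument. A toy sketch of that division of responsibility, using made-up stand-in names rather than the real classes:

```python
from typing import Optional

import torch


class ToyDispatcher:
    """Stand-in for a token dispatcher (MC2 / AllGather / All2AllV): it only
    ever sees physical expert ids and has no log2phy parameter of its own."""

    def token_dispatch(self, hidden_states: torch.Tensor,
                       topk_ids: torch.Tensor) -> torch.Tensor:
        # Real dispatchers permute and exchange tokens; irrelevant for the toy.
        return topk_ids


def dispatch(hidden_states: torch.Tensor, topk_ids: torch.Tensor,
             log2phy: Optional[torch.Tensor],
             dispatcher: ToyDispatcher) -> torch.Tensor:
    # The one remaining place where the logical -> physical remap happens,
    # shared by every dispatcher behind it.
    if log2phy is not None:
        topk_ids = log2phy[topk_ids]
    return dispatcher.token_dispatch(hidden_states, topk_ids)


out = dispatch(torch.zeros(2, 8), torch.tensor([[0, 2], [1, 3]]),
               torch.tensor([4, 1, 2, 3]), ToyDispatcher())
print(out)  # tensor([[4, 2], [1, 3]])
```
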

@@ -80,7 +80,6 @@ class MoETokenDispatcher(ABC):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
@@ -188,7 +187,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
@@ -197,10 +195,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
pertoken_scale: Optional[torch.Tensor] = None):
self.with_quant = with_quant
# Apply log2phy if needed
if log2phy is not None:
topk_ids = log2phy[topk_ids]
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
topk_ids, expert_map,
mc2_mask,
@@ -309,7 +303,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
@@ -429,7 +422,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
@@ -439,9 +431,6 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
self.with_quant = with_quant
self.hidden_shape = hidden_states.shape
if log2phy is not None:
topk_ids = log2phy[topk_ids]
(
permutated_local_input_tokens,
reversed_local_input_permutation_mapping,