[BugFix] Fix accuracy bugs for unquantized deepseekv3 models (#897)
### What this PR does / why we need it? This PR fixes two accuracy bugs incurred by PR #819 when running deepseekv3 series models: 1. #819 adds `all_to_all` communication in quantized cases, but `all_gather` && `reduce_scatter` are removed in both of quantized and unquantized cases. When running unquantized deepseekv3 models with `ep_size == world_size`, the moe modules fail to communicate. Therefore, this PR adds `all_to_all` communication on unquantized situation to solve this accuracy issue. 2. Use `ep_size` rather than `dp_size` to decide whether to use `all_to_all` in moe. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed with new added/existing test. --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
@@ -323,14 +323,13 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
dp_size: int = 1,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.quant_method.apply(
|
||||
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
|
||||
global_num_experts, expert_map, topk_group, num_expert_group,
|
||||
custom_routing_function, scoring_func, e_score_correction_bias,
|
||||
is_prefill, enable_force_load_balance, dp_size)
|
||||
is_prefill, enable_force_load_balance)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
|
||||
@@ -582,7 +582,6 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
dp_size: int = 1,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
@@ -635,7 +634,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
||||
elif dp_size == 1:
|
||||
elif self.ep_group.world_size == 1:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
@@ -646,6 +645,10 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
else:
|
||||
# The current implementation of deepseek moe splits hidden_states
|
||||
# according to tp_size before they are feed into fused_moe module.
|
||||
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
||||
# dispatch/combine tokens.
|
||||
return fused_experts_with_all2all(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
|
||||
Reference in New Issue
Block a user