[BugFix]add all2all when dp_size > 1 && downgrade npu_dequant_swiglu_quant (#819)
<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? 1. This PR introduces a native `all_to_all` communication operator to fix `allgather` bugs when dp_size > 1. It also adds a naive force-load-balance implementation for profile runs. 2. The operator `npu_dequant_swiglu_quant` only supports input hidden_states with dtype `torch.int32`. This tensor occupies space proportional to `global_bs * seq_len * topk * hidden_size`, which can become very large as `ep_size` grows. Therefore we need to disable this operator and use the original `swiglu` and `quantize` operators instead. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? By performing offline inference:  --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
@@ -321,14 +321,15 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
dp_size: int = 1,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.quant_method.apply(layer, x, router_logits, top_k,
|
||||
renormalize, use_grouped_topk,
|
||||
global_num_experts, expert_map,
|
||||
topk_group, num_expert_group,
|
||||
custom_routing_function, scoring_func,
|
||||
e_score_correction_bias, is_prefill)
|
||||
return self.quant_method.apply(
|
||||
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
|
||||
global_num_experts, expert_map, topk_group, num_expert_group,
|
||||
custom_routing_function, scoring_func, e_score_correction_bias,
|
||||
is_prefill, enable_force_load_balance, dp_size)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
|
||||
Reference in New Issue
Block a user