[BugFix] add all2all when dp_size > 1 && downgrade npu_dequant_swiglu_quant (#819)

### What this PR does / why we need it?
1. This PR introduces a native `all_to_all` communication operator to fix
`allgather` bugs when `dp_size > 1`. It also adds a naive
force-load-balance implementation for profile runs (see the first sketch
after this list).
2. The operator `npu_dequant_swiglu_quant` only supports input
hidden_states with dtype `torch.int32`. That tensor occupies
`global_bs * seq_len * topk * hidden_size` elements, which can become very
large as `ep_size` grows. We therefore disable this operator and fall back
to the original `swiglu` and `quantize` operations (see the second sketch
below).
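As a rough illustration of change 1, the sketch below shows the two ideas in isolation: exchanging variable-sized token chunks between DP ranks with `torch.distributed.all_to_all_single` (instead of gathering everything on every rank), and forcing a uniform expert distribution during profile runs. This is a minimal sketch, not the PR's actual code; the helper names `dispatch_tokens` and `force_balanced_topk_ids` are hypothetical.

```python
import torch
import torch.distributed as dist


def dispatch_tokens(hidden_states: torch.Tensor,
                    input_split_sizes: list,
                    output_split_sizes: list,
                    group=None) -> torch.Tensor:
    # all_to_all: each rank sends input_split_sizes[r] rows to rank r and
    # receives output_split_sizes[r] rows back, so no rank ever has to
    # hold the full global batch (unlike allgather).
    output = hidden_states.new_empty(
        (sum(output_split_sizes), hidden_states.shape[-1]))
    dist.all_to_all_single(output,
                           hidden_states,
                           output_split_sizes=output_split_sizes,
                           input_split_sizes=input_split_sizes,
                           group=group)
    return output


def force_balanced_topk_ids(topk_ids: torch.Tensor,
                            global_num_experts: int) -> torch.Tensor:
    # Naive force-load-balance for profile runs: discard the router's
    # choices and route tokens to uniformly random experts, so the
    # profiled memory/compute reflects a balanced expert load.
    return torch.randint_like(topk_ids, global_num_experts)
```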
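For change 2, the fallback decomposes the fused kernel into a plain SwiGLU followed by quantization. Below is a generic PyTorch sketch of that decomposition, assuming symmetric per-token int8 quantization; the real Ascend `quantize` op may use a different scheme and rounding.

```python
import torch
import torch.nn.functional as F


def swiglu(gate_up: torch.Tensor) -> torch.Tensor:
    # SwiGLU: split the fused gate/up projection and gate with SiLU.
    gate, up = gate_up.chunk(2, dim=-1)
    return F.silu(gate) * up


def quantize_per_token(x: torch.Tensor):
    # Symmetric int8 per-token quantization (an assumption for
    # illustration; Ascend's quantize op may differ).
    scale = x.abs().amax(dim=-1, keepdim=True).clamp_min(1e-8) / 127.0
    q = torch.round(x / scale).clamp_(-128, 127).to(torch.int8)
    return q, scale
```

Unlike the fused path, this keeps the activations in the working dtype and avoids materializing the large `torch.int32` hidden_states that `npu_dequant_swiglu_quant` requires.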

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By performing offline inference:

![image](https://github.com/user-attachments/assets/e003d5dc-0753-41ae-9303-e87f73ac6828)

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit (1e67089bc9, parent 68fb63428b) was authored by Angazenn on
2025-05-15 09:19:55 +08:00 and committed by GitHub.
7 changed files with 317 additions and 80 deletions


@@ -321,14 +321,15 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
         scoring_func: str = "softmax",
         e_score_correction_bias: Optional[torch.Tensor] = None,
         is_prefill: bool = True,
+        enable_force_load_balance: bool = False,
+        dp_size: int = 1,
         **kwargs,
     ) -> torch.Tensor:
-        return self.quant_method.apply(layer, x, router_logits, top_k,
-                                       renormalize, use_grouped_topk,
-                                       global_num_experts, expert_map,
-                                       topk_group, num_expert_group,
-                                       custom_routing_function, scoring_func,
-                                       e_score_correction_bias, is_prefill)
+        return self.quant_method.apply(
+            layer, x, router_logits, top_k, renormalize, use_grouped_topk,
+            global_num_experts, expert_map, topk_group, num_expert_group,
+            custom_routing_function, scoring_func, e_score_correction_bias,
+            is_prefill, enable_force_load_balance, dp_size)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if hasattr(self.quant_method, "process_weights_after_loading"):
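To make the plumbing above concrete, here is a toy stand-in showing how a `quant_method.apply` could consume the two new arguments; the class and body are assumptions for illustration, not `AscendFusedMoEMethod`'s real quantized implementation.

```python
import torch


class DemoQuantMethod:

    def apply(self, layer, x, router_logits, top_k, renormalize,
              use_grouped_topk, global_num_experts, expert_map,
              topk_group, num_expert_group, custom_routing_function,
              scoring_func, e_score_correction_bias, is_prefill,
              enable_force_load_balance=False, dp_size=1):
        # Plain top-k routing as a placeholder for the real router.
        topk_weights, topk_ids = router_logits.topk(top_k, dim=-1)
        if enable_force_load_balance:
            # Profile runs: overwrite the router's choices so every
            # expert sees a comparable token count.
            topk_ids = torch.randint_like(topk_ids, global_num_experts)
        # With dp_size > 1 the PR dispatches tokens via all_to_all here
        # (see the first sketch in the description above).
        return x, topk_weights, topk_ids
```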