[BugFix]add all2all when dp_size > 1 && downgrade npu_dequant_swiglu_quant (#819)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
1. This PR introduces native `all_to_all` communication operator to fix
`allgather` bugs when dp_size > 1. Besides, it adds a naive
implementation of force-load-balance when doing profile runs.
2. The operator `npu_dequant_swiglu_quant` only supports input
hidden_states with dtype `torch.int32`. This tensor occupies space of
`global_bs * seq_len * topk * hidden_size`, which might be very large as
`ep_size` grows. Therefore we need to disable this operator and use
original `swiglu` && `quantize`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By performing offline inference:

![image](https://github.com/user-attachments/assets/e003d5dc-0753-41ae-9303-e87f73ac6828)

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
Angazenn
2025-05-15 09:19:55 +08:00
committed by GitHub
parent 68fb63428b
commit 1e67089bc9
7 changed files with 317 additions and 80 deletions

View File

@@ -18,7 +18,6 @@
from typing import Callable, Optional
import torch
import torch.distributed as dist
import torch_npu
from vllm.config import get_current_vllm_config
from vllm.distributed import tensor_model_parallel_all_reduce
@@ -636,6 +635,7 @@ class AscendFusedMoE(FusedMoE):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
is_prefill: bool,
enable_force_load_balance: bool = False,
top_k=None):
assert self.quant_method is not None
@@ -644,17 +644,8 @@ class AscendFusedMoE(FusedMoE):
else:
real_top_k = self.top_k
if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
elif USING_LCCL_COM: # type: ignore
hidden_states = get_dp_group().all_gather(
hidden_states, 0, False)
router_logits = get_dp_group().all_gather(
router_logits, 0, False)
else:
hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0)
if VLLM_ENABLE_MC2 and not is_prefill:
...
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
@@ -671,17 +662,12 @@ class AscendFusedMoE(FusedMoE):
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
is_prefill=is_prefill)
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance,
dp_size=self.dp_size)
if self.dp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
...
else:
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
final_hidden_states,
"sum",
scatter_dim=0,
group=get_dp_group().device_group)
if VLLM_ENABLE_MC2 and not is_prefill:
...
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
final_hidden_states = tensor_model_parallel_all_reduce(