[BugFix]add all2all when dp_size > 1 && downgrade npu_dequant_swiglu_quant (#819)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
1. This PR introduces a native `all_to_all` communication operator to fix
`allgather` bugs when `dp_size > 1`. It also adds a naive implementation
of forced load balancing across experts during profile runs.
2. The operator `npu_dequant_swiglu_quant` only supports input
hidden_states with dtype `torch.int32`. That tensor has
`global_bs * seq_len * topk * hidden_size` elements, which can become very
large as `ep_size` grows. Therefore we disable this operator and fall back
to the original separate `swiglu` and `quantize` steps.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By performing offline inference:

![image](https://github.com/user-attachments/assets/e003d5dc-0753-41ae-9303-e87f73ac6828)

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
Angazenn
2025-05-15 09:19:55 +08:00
committed by GitHub
parent 68fb63428b
commit 1e67089bc9
7 changed files with 317 additions and 80 deletions

View File

@@ -205,50 +205,66 @@ class CustomDeepseekV2MoE(nn.Module):
)
CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
vllm_config = get_current_vllm_config()
self.dp_size = get_dp_group().world_size
batch_size = vllm_config.scheduler_config.max_num_seqs
params_dtype = torch.get_default_dtype()
self.final_hidden_states = torch.zeros(
[batch_size, config.hidden_size], dtype=params_dtype, device="npu")
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
# Forward pass of the MoE layer: derive the prefill / load-balance flags from
# the forward context, route tokens through the gate and experts, combine the
# result across tensor-parallel ranks, and add the shared-expert output.
#
# NOTE(review): this span looks like a flattened diff hunk (indentation and
# +/- markers lost). Two routing/expert implementations are visible below;
# presumably the first is the removed pre-change code and the second is the
# added replacement -- confirm against the actual file before relying on it.
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
attn_metadata = get_forward_context().attn_metadata
# when profile runs, force experts to load balanced tokens
# to avoid high memory consumption on a single rank.
# TODO: need a better flag to indicate whether in profile run or not.
if attn_metadata is None:
# for profile run
is_prefill = True
enable_force_load_balance = True
else:
is_prefill = attn_metadata.num_prefills > 0
enable_force_load_balance = False
# The input is unpacked as a 2-D shape and viewed as (-1, hidden_dim).
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# NOTE(review): from here down to the all_reduce branch appears to be the
# pre-change path (chunk by TP rank under VLLM_ENABLE_MC2 during decode, then
# all_gather_into_tensor into self.final_hidden_states, otherwise all_reduce).
if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill):
chunks = torch.chunk(hidden_states,
get_tp_group().world_size,
dim=0)
hidden_states = chunks[get_tp_group().rank_in_group]
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
if self.tp_size > 1:
if VLLM_ENABLE_MC2 and not is_prefill:
dist.all_gather_into_tensor(self.final_hidden_states,
final_hidden_states, self.tp_group)
final_hidden_states = self.final_hidden_states
else:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
# NOTE(review): shared_output is only bound inside this branch, yet it is read
# unconditionally near the end of the function. In the lines visible here
# nothing else assigns it, so n_shared_experts being None would presumably
# raise UnboundLocalError -- confirm whether it is initialised elsewhere.
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# NOTE(review): what follows appears to be the post-change path: every TP rank
# runs the gate and experts on its own slice of the tokens, then the slices
# are gathered back together.
if self.tp_size > 1:
# pass
num_tokens, hidden_size = hidden_states.shape
# With fewer tokens than TP ranks, pad up to tp_size rows so that each rank
# receives a non-empty chunk. The padding rows come from torch.empty and are
# therefore uninitialised; they are sliced off again after the gather.
if num_tokens < self.tp_size:
target_size = self.tp_size
new_hidden_states = torch.empty([target_size, hidden_size],
dtype=hidden_states.dtype,
device=hidden_states.device)
new_hidden_states[:num_tokens] = hidden_states
hidden_states = new_hidden_states
chunk_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
local_hidden_states = chunk_hidden_states[self.tp_rank]
else:
local_hidden_states = hidden_states
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(local_hidden_states)
router_hidden_states = self.experts(
hidden_states=local_hidden_states,
router_logits=router_logits,
is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance,
) * self.routed_scaling_factor
if self.tp_size > 1:
# NOTE(review): the gather writes into the tensor_split outputs, which are
# presumably views of hidden_states, so the input buffer would be overwritten
# in place -- confirm no later reader depends on it. tensor_split can also
# yield unequal chunk sizes when the token count is not divisible by tp_size;
# verify that the all_gather backend accepts differently sized tensors.
dist.all_gather(list(chunk_hidden_states), router_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
# Drop the padding rows added above for the num_tokens < tp_size case.
if num_tokens < self.tp_size:
final_hidden_states = final_hidden_states[:num_tokens]
else:
final_hidden_states = router_hidden_states
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
return final_hidden_states.view(num_tokens, hidden_dim)