[BugFix]add all2all when dp_size > 1 && downgrade npu_dequant_swiglu_quant (#819)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
1. This PR introduces native `all_to_all` communication operator to fix
`allgather` bugs when dp_size > 1. Besides, it adds a naive
implementation of force-load-balance when doing profile runs.
2. The operator `npu_dequant_swiglu_quant` only supports input
hidden_states with dtype `torch.int32`. This tensor occupies space of
`global_bs * seq_len * topk * hidden_size`, which might be very large as
`ep_size` grows. Therefore we need to disable this operator and use
original `swiglu` && `quantize`.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By performing offline inference:

![image](https://github.com/user-attachments/assets/e003d5dc-0753-41ae-9303-e87f73ac6828)

---------

Signed-off-by: angazenn <zengyanjia@huawei.com>
Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
Angazenn
2025-05-15 09:19:55 +08:00
committed by GitHub
parent 68fb63428b
commit 1e67089bc9
7 changed files with 317 additions and 80 deletions

View File

@@ -15,15 +15,19 @@
# limitations under the License.
#
import os
from typing import Any, Callable, Dict, List, Optional
import torch
import torch.distributed as dist
import torch_npu
from vllm.distributed import GroupCoordinator
import vllm_ascend.envs as envs_ascend
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import select_experts
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
w1: torch.Tensor,
@@ -68,24 +72,18 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w1],
split_item=3,
scale=[w1_scale],
per_token_scale=[pertoken_scale],
split_item=2,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]
output_dtype=w2_scale.dtype)[0]
# act_fn: swiglu
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
x=hidden_states,
weight_scale=w1_scale,
activation_scale=pertoken_scale,
bias=None,
quant_scale=None,
quant_offset=None,
group_index=group_list,
activate_left=True,
quant_mode=1,
)
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
@@ -201,6 +199,132 @@ def fused_experts_with_mc2(
return hidden_states
# currently expert parallelism implemented with all2all
# is under-optimized.
def fused_experts_with_all2all(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens, _ = hidden_states.shape
num_experts = w1.shape[0]
device = hidden_states.device
if expert_map is not None:
global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_group.world_size
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=device).view(top_k, -1).permute(
1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
global_expert_tokens = torch.bincount(expanded_expert_idx,
minlength=global_num_experts)
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
-1).sum(-1)
gather_sizes = torch.empty_like(scatter_sizes)
dist.all_to_all_single(gather_sizes,
scatter_sizes,
group=ep_group.device_group)
scatter_size_list = scatter_sizes.cpu().tolist()
gather_size_list = gather_sizes.cpu().tolist()
expanded_expert_idx = expanded_expert_idx % local_num_experts
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
scatter_size_list,
gather_size_list)
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
scatter_size_list,
gather_size_list)
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
sorted_local_expert_idx, local_num_experts).to(torch.int64)
hidden_states = hidden_states[sorted_idx]
group_list_type = 0
else:
row_idx_len = num_tokens * top_k
row_idx = torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=topk_weights.device).view(
top_k, -1).permute(1, 0).contiguous()
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0
hidden_states_wrapper = [hidden_states]
del hidden_states
hidden_states = apply_mlp(hidden_states_wrapper,
w1,
w1_scale,
w2,
w2_scale,
expert_tokens,
group_list_type=group_list_type)
if expert_map is not None:
resorted_idx = torch.argsort(sorted_idx)
hidden_states = hidden_states[resorted_idx]
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
gather_size_list,
scatter_size_list)
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
else:
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w1_scale: torch.Tensor,
@@ -387,10 +511,10 @@ class AscendW8A8DynamicFusedMoEMethod:
def __init__(self):
self.transpose_weight = True
ep_group = get_ep_group()
self.ep_group = get_ep_group()
try:
device_group = ep_group.device_group
device_group = self.ep_group.device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
@@ -457,6 +581,8 @@ class AscendW8A8DynamicFusedMoEMethod:
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: bool = True,
enable_force_load_balance: bool = True,
dp_size: int = 1,
**kwargs,
) -> torch.Tensor:
assert router_logits.shape[
@@ -491,7 +617,13 @@ class AscendW8A8DynamicFusedMoEMethod:
e_score_correction_bias=e_score_correction_bias,
)
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
if VLLM_ENABLE_MC2 and not is_prefill:
return fused_experts_with_mc2(
hidden_states=x,
w1=layer.w13_weight,
@@ -503,7 +635,7 @@ class AscendW8A8DynamicFusedMoEMethod:
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
else:
elif dp_size == 1:
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
@@ -513,6 +645,17 @@ class AscendW8A8DynamicFusedMoEMethod:
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
else:
return fused_experts_with_all2all(hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map,
ep_group=self.ep_group)
def process_weights_after_loading(self, layer):
if self.transpose_weight:
@@ -521,7 +664,7 @@ class AscendW8A8DynamicFusedMoEMethod:
layer.w2_weight.data = layer.w2_weight.data.transpose(
1, 2).contiguous()
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
layer.w13_weight_scale.data.shape[0], -1).to(torch.float32)
layer.w13_weight_scale.data.shape[0], -1)
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
layer.w13_weight_offset.data.shape[0], -1)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(