[Refactor][MoE] remove redundant code after refactoring fused_moe (#2612)

### What this PR does / why we need it?
There are a lot of redundant codes related to moe here, and the
structure is not very clear.
We did the following things:

we have placed the relatively independent code related to apply_mlp into
a separate file;
removed the environment variables of alltoall_buffer and alltoall_seq.
Remove the code related to alltoall_buffer and alltoall_seq, and retain
the sole TokenDispatcher inheritance class.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
e2e&ut

- vLLM version: v0.10.1.1
- vLLM main:
4071c76cf3

---------

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Signed-off-by: weijinqian_v1 <weijinqian@huawei.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
weichen
2025-08-30 22:28:50 +08:00
committed by GitHub
parent 20ae71291d
commit 3a5fc5ee01
13 changed files with 417 additions and 1237 deletions

View File

@@ -38,15 +38,12 @@ from vllm.model_executor.layers.fused_moe.layer import (
from vllm.model_executor.layers.quantization.base_config import \
QuantizationConfig
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import (
MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig)
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
@@ -54,74 +51,6 @@ from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_ascend_soc_version,
get_rm_router_logits_state, is_310p)
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
def torchair_process_topk_ids(topk_ids: torch.Tensor, expert_num: int,
ep_size: int, max_row_per_ep_rank: int,
num_tokens: int,
top_k: int) -> tuple[torch.Tensor, torch.Tensor]:
original_total_elements = num_tokens * top_k
device = topk_ids.device
original_dtype = topk_ids.dtype
if original_total_elements == 0:
output_len = ep_size * max_row_per_ep_rank
topk_ids_pad = torch.full((output_len, ),
expert_num,
dtype=original_dtype,
device=device)
unpad_indices = torch.full((original_total_elements, ),
-1,
dtype=torch.long,
device=device)
return topk_ids_pad, unpad_indices
experts_per_ep_rank_val = expert_num // ep_size
if experts_per_ep_rank_val == 0:
raise ValueError(
"expert_num // ep_size is 0, which leads to division by zero in ep_rank calculation. "
"Ensure expert_num >= ep_size.")
assigned_ep_rank = (topk_ids.float() /
experts_per_ep_rank_val).to(original_dtype)
indices_arange = torch.arange(topk_ids.shape[0], device=device)
is_new_segment = torch.cat(
(torch.tensor([True], device=device), assigned_ep_rank[1:]
!= assigned_ep_rank[:-1]))
temp_start_markers = torch.full_like(indices_arange,
-1,
dtype=indices_arange.dtype)
temp_start_markers[is_new_segment] = indices_arange[is_new_segment]
start_offset_for_each_token = torch.cummax(temp_start_markers, dim=0)[0]
token_intra_ep_rank_idx = indices_arange - start_offset_for_each_token
is_kept_mask = token_intra_ep_rank_idx < max_row_per_ep_rank
cumsum_kept = torch.cumsum(is_kept_mask.float(), dim=0).to(torch.long)
indices_in_rec_cond_list_for_all = cumsum_kept - 1
unpad_indices = torch.where(
is_kept_mask, indices_in_rec_cond_list_for_all,
torch.tensor(-1, device=device, dtype=torch.long))
output_len = ep_size * max_row_per_ep_rank
topk_ids_pad = torch.full((output_len, ),
expert_num,
dtype=original_dtype,
device=device)
if topk_ids.shape[0] > 0:
all_destination_indices = assigned_ep_rank * max_row_per_ep_rank + token_intra_ep_rank_idx
temp_pad_buffer = torch.full((output_len + 1, ),
expert_num,
dtype=original_dtype,
device=device)
output_len_tensor = torch.tensor(output_len,
dtype=torch.long,
device=device)
scatter_indices = torch.where(is_kept_mask, all_destination_indices,
output_len_tensor)
temp_pad_buffer.scatter_(0, scatter_indices, topk_ids)
topk_ids_pad = temp_pad_buffer[:output_len]
return topk_ids_pad, unpad_indices
def torchair_fused_experts_with_mc2(
hidden_states: torch.Tensor,
@@ -459,130 +388,6 @@ def torchair_fused_experts_with_all2all(
return final_hidden_states
# currently expert parallelism implemented with all2all
# is under-optimized.
def torchair_fused_experts_with_all2all_buffer(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
max_model_len: int,
global_batch_size: int,
expert_map: torch.Tensor = None,
ep_group: GroupCoordinator = None,
):
original_shape = hidden_states.shape
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens, _ = hidden_states.shape
device = hidden_states.device
global_num_experts = len(expert_map)
local_num_experts = global_num_experts // ep_group.world_size
row_idx_len = num_tokens * top_k
row_idx = (torch.arange(0, row_idx_len, dtype=torch.int32,
device=device).view(top_k,
-1).permute(1, 0).contiguous())
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
row_idx=row_idx,
expert_idx=topk_ids,
active_num=num_tokens)
max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) *
max_model_len // ep_group.world_size +
1) * top_k * 2
expert_idx_buffer_scatter, unpad_indices = torchair_process_topk_ids(
expanded_expert_idx, global_num_experts, ep_group.world_size,
max_row_per_ep_rank, num_tokens, top_k)
hidden_states_pad_idx = torch.zeros(
expert_idx_buffer_scatter.shape,
dtype=expert_idx_buffer_scatter.dtype,
device=expert_idx_buffer_scatter.device)
non_pad_len = torch.sum((expert_idx_buffer_scatter
!= global_num_experts).to(torch.int32))
hidden_states_pad_idx[expert_idx_buffer_scatter !=
global_num_experts] = torch.arange(
non_pad_len,
dtype=expert_idx_buffer_scatter.dtype,
device=hidden_states.device)
hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx]
expert_idx_buffer_gather = torch.empty_like(
expert_idx_buffer_scatter,
dtype=expert_idx_buffer_scatter.dtype,
device=expert_idx_buffer_scatter.device)
hidden_states_buffer_gather = torch.empty_like(
hidden_states_buffer_scatter,
dtype=hidden_states_buffer_scatter.dtype,
device=hidden_states_buffer_scatter.device)
dist.all_to_all_single(expert_idx_buffer_gather,
expert_idx_buffer_scatter,
group=ep_group.device_group)
dist.all_to_all_single(hidden_states_buffer_gather,
hidden_states_buffer_scatter,
group=ep_group.device_group)
mask = expert_idx_buffer_gather != global_num_experts
local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * (
global_num_experts // ep_group.world_size)
hidden_states = hidden_states_buffer_gather[mask]
idx_type = local_expert_idx.dtype
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float())
sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
sorted_local_expert_idx, local_num_experts).to(torch.int64)
hidden_states = hidden_states[sorted_idx]
group_list_type = 0
hidden_states = torchair_apply_mlp(hidden_states,
w1,
w2,
expert_tokens,
group_list_type=group_list_type)
resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype)
hidden_states = hidden_states[resorted_idx]
hidden_states_scatter = torch.zeros(
(mask.shape[0], hidden_states.shape[1]),
dtype=hidden_states.dtype,
device=hidden_states.device)
hidden_states_scatter[mask] = hidden_states
hidden_states_gatter = torch.empty_like(
hidden_states_scatter,
dtype=hidden_states_scatter.dtype,
device=hidden_states_scatter.device)
dist.all_to_all_single(hidden_states_gatter,
hidden_states_scatter,
group=ep_group.device_group)
hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter !=
global_num_experts]
if hidden_states_gatter.shape[0] != row_idx_len:
hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]),
dtype=hidden_states.dtype,
device=hidden_states.device)
hidden_states[unpad_indices != -1] = hidden_states_gatter
else:
# TODO: Reorder device memory 2 times here, replace the current
hidden_states = hidden_states_gatter
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states
def torchair_fused_experts_moge(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -674,25 +479,6 @@ def torchair_fused_experts_moge(
return final_hidden_states
def torchair_fused_experts_with_all2allv(
token_dispatcher,
probs,
routing_map,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
):
# Enable moe alltoallv, it's a balanced policy for precision and efficiency.
(share_experts_output, dispatched_input,
tokens_per_expert) = (token_dispatcher.token_permutation(
hidden_states, probs, routing_map))
expert_output = torchair_apply_mlp(dispatched_input, w1, w2,
tokens_per_expert)
output, mlp_bias = token_dispatcher.token_unpermutation(expert_output)
return output
def torchair_fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -1120,28 +906,6 @@ class TorchairAscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
elif MOE_ALL2ALL_BUFFER:
return torchair_fused_experts_with_all2all_buffer(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
max_model_len=self.max_model_len,
global_batch_size=self.global_batch_size,
expert_map=expert_map,
ep_group=get_ep_group())
elif fused_moe_state == FusedMoEState.All2AllSeq:
token_dispatcher = kwargs.get("token_dispatcher")
return torchair_fused_experts_with_all2allv(
token_dispatcher=token_dispatcher,
probs=topk_weights,
routing_map=topk_ids,
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
)
else:
return torchair_fused_experts_with_all2all(
hidden_states=x,
@@ -1315,25 +1079,6 @@ class TorchairAscendFusedMoE(FusedMoE):
# NOTE: self.tp_group is not expert_tp_group
self.tp_group = get_tp_group().device_group
self.quant_method.create_weights(layer=self, **moe_quant_params)
self.token_dispatcher = None
if envs_ascend.VLLM_ASCEND_ENABLE_MOE_ALL2ALL_SEQ and isinstance(
self.quant_method, TorchairAscendUnquantizedFusedMoEMethod):
self.reduce_results = False
moe_dispatcher_config = (
MoEDispatcherConfig().set_num_moe_experts(
self.global_num_experts).set_num_local_experts(
self.local_num_experts).set_moe_router_topk(
top_k).set_group_topk(topk_group).
set_num_groups(num_expert_group).set_expert_bias(
e_score_correction_bias).set_scaling_factor(1.0).build())
self.token_dispatcher = MoEAlltoAllSeqOverLapDispatcher(
moe_dispatcher_config)
if envs_ascend.VLLM_ASCEND_ENABLE_DBO:
token_dispatcher1 = MoEAlltoAllSeqOverLapDispatcher(
moe_dispatcher_config)
self.token_dispatchers = [
self.token_dispatcher, token_dispatcher1
]
def naive_multicast(self, x: torch.Tensor,
cu_tokens_across_dp_cpu: torch.Tensor):
@@ -1486,7 +1231,6 @@ class TorchairAscendFusedMoE(FusedMoE):
shared_experts=shared_experts if self.torchair_graph_enabled
and self.enable_multistream_moe and not is_prefill else None,
mc2_mask=mc2_mask,
token_dispatcher=self.token_dispatcher,
quantized_x_for_share=quantized_x_for_share,
dynamic_scale_for_share=dynamic_scale_for_share,
)