[main] [refactor] refactor fused_moe.py to enable token_dispatchers (#2570)

### What this PR does / why we need it?
Enable token_dispatcher to replace fused_experts_with_xxx in eager mode
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut


- vLLM version: v0.10.1.1
- vLLM main:
704432af3c

Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: sherie <963372609@qq.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Co-authored-by: shiyuan680 <72335504+shiyuan680@users.noreply.github.com>
This commit is contained in:
weichen
2025-08-28 10:13:35 +08:00
committed by GitHub
parent 936c102105
commit 320edde2df
10 changed files with 1066 additions and 1639 deletions

View File

@@ -22,21 +22,18 @@
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Optional
from typing import Any, Dict, Optional
import torch
import torch_npu
from vllm.distributed.parallel_state import get_ep_group
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.distributed.tensor_parallel import (
all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp,
all_to_all_sp2hp, gather_from_sequence_parallel_region,
reduce_scatter_last_dim_to_tensor_parallel_region)
from vllm_ascend.ops.comm_utils import async_all_to_all
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
@@ -460,6 +457,31 @@ class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher):
return output, None
_Dispatchers: Dict[str, Any] = {}
def _register_token_dispatcher(dispatcher: Any):
_Dispatchers[dispatcher.__class__.__name__] = dispatcher
def get_token_dispatcher(name: str):
return _Dispatchers.get(name)
def setup_token_dispatchers(ep_size: int, **kwargs):
existing_dispatchers = set(_Dispatchers.keys())
if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs))
elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
elif ep_size >= 16:
if "TokenDispatcherWithAll2AllV" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs))
if "TokenDispatcherWithMC2" not in existing_dispatchers:
_register_token_dispatcher(TokenDispatcherWithMC2(**kwargs))
class MoETokenDispatcher(ABC):
def __init__(self, **kwargs) -> None:
@@ -484,18 +506,19 @@ class MoETokenDispatcher(ABC):
return get_ep_group().world_size
@abstractmethod
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None,
):
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
@@ -516,40 +539,39 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
self.ep_rank_id = get_mc2_group().rank_in_group
self.ep_world_size = get_mc2_group().world_size
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_dispatch_v2 = hasattr(torch_npu,
"npu_moe_distribute_dispatch_v2")
self.need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
or self.torchair_graph_enabled)
self.need_extra_args = (
get_ascend_soc_version() == AscendSocVersion.A3)
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
self.a3_need_extra_args = \
get_ascend_soc_version() == AscendSocVersion.A3
self.output = None
self.dynamic_scale = None
self.assist_info_for_combine = None
self.ep_recv_counts = None
self.shared_act = None
self.topk_ids = None
self.topk_weights = None
self.shared_experts = None
self.mc2_mask = None
def get_dispatch_mc2_kwargs(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
global_redundant_expert_num: int = 0):
quant_mode = 0
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
def get_dispatch_mc2_kwargs(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
global_redundant_expert_num: int = 0,
):
if self.with_quant:
quant_mode = 2
if (expert_map is not None):
moe_expert_num = len(expert_map) + global_redundant_expert_num
else:
moe_expert_num = global_redundant_expert_num
else:
quant_mode = 0
moe_expert_num = len(expert_map)
kwargs_mc2 = {
"x": hidden_states,
@@ -575,28 +597,30 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
})
if self.a3_need_extra_args and self.enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
"x_active_mask": self.mc2_mask,
})
kwargs_mc2.update(stage1_kwargs)
return kwargs_mc2
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None,
):
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
self.expert_map = expert_map
self.topk_ids = topk_ids
self.topk_weights = topk_weights
self.shared_experts = shared_experts
self.mc2_mask = mc2_mask
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
topk_ids, expert_map,
@@ -606,28 +630,27 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, self.dynamic_scale, self.assist_info_for_combine, \
expand_x, dynamic_scale, self.assist_info_for_combine, \
expert_token_nums, self.ep_recv_counts = self.output[0:5]
if self.with_quant:
if shared_experts is not None:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(shared_gate_up, expand_x)
shared_act_out = shared_experts.act_fn(
(shared_gate_up, shared_dequant_scale))
self.shared_act, self.swiglu_out_scale = \
shared_act_out[0], shared_act_out[1]
shared_act_out = shared_experts.act_fn(
(shared_gate_up, shared_dequant_scale))
self.shared_act, self.swiglu_out_scale = \
shared_act_out[0], shared_act_out[1]
else:
if shared_experts is not None:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(hidden_states, topk_weights)
shared_gate_up, _ = shared_experts.gate_up_proj(
hidden_states)
npu_wait_tensor(shared_gate_up, expand_x)
self.shared_act = shared_experts.act_fn(shared_gate_up)
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
self.shared_act = shared_experts.act_fn(shared_gate_up)
group_list_type = 1
return group_list_type, expand_x, expert_token_nums
return {
"group_list_type": group_list_type,
"hidden_states": expand_x,
"group_list": expert_token_nums,
"dynamic_scale": dynamic_scale,
}
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor):
assert self.expert_map is not None
@@ -635,8 +658,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
assert self.topk_ids is not None
assert self.output is not None
moe_expert_num = len(self.expert_map)
forward_context = get_forward_context()
mc2_mask = forward_context.mc2_mask
# moeCombine
kwargs_mc2 = {
"expand_x": hidden_states,
@@ -677,7 +698,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
})
if self.a3_need_extra_args and self.enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": mc2_mask,
"x_active_mask": self.mc2_mask,
})
kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2
@@ -685,7 +706,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states)
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
@@ -695,15 +715,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
return hidden_states
else:
if self.with_quant:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(self.shared_act, hidden_states)
shared_hidden_states, _ = self.shared_experts.down_proj(
(self.shared_act, self.swiglu_out_scale))
shared_hidden_states, _ = self.shared_experts.down_proj(
(self.shared_act, self.swiglu_out_scale))
else:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(self.shared_act, hidden_states)
shared_hidden_states, _ = self.shared_experts.down_proj(
self.shared_act)
shared_hidden_states, _ = self.shared_experts.down_proj(
self.shared_act)
return hidden_states, shared_hidden_states
@@ -711,13 +727,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_router_weight_on_input = kwargs.get(
"apply_router_weight_on_input")
self.top_k = kwargs.get("top_k")
self.apply_router_weight_on_input = False
self.max_num_tokens = kwargs.get("max_num_tokens")
ep_size = kwargs.get("ep_size")
if ep_size is not None:
self.num_experts_local = self.num_experts // ep_size
self.num_experts_local = kwargs.get("num_local_experts", 0)
self.sorted_weights = None
self.expanded_row_idx = None
self.sorted_token_indices = None
@@ -727,20 +739,20 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
self.topk_weights = None
self.topk_ids = None
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None,
):
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
self.original_shape = hidden_states.shape
# assert len(original_shape) == 2
num_tokens = hidden_states.shape[:-1].numel()
dtype = hidden_states.dtype
@@ -748,9 +760,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
self.expert_map = expert_map
self.topk_weights = topk_weights
self.topk_ids = topk_ids
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
# ], "Only float32, float16, and bfsloat16 are supported"
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
@@ -803,19 +813,13 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
sorted_hidden_states = hidden_states[self.sorted_token_indices]
if self.with_quant:
group_list_type = 1
expert_tokens = token_counts
else:
expert_tokens = torch.cumsum(token_counts,
dim=0,
dtype=torch.int64)
group_list_type = 0
else:
row_idx_len = num_tokens * self.top_k
row_idx = (torch.arange(0,
row_idx_len,
dtype=torch.int32,
device=device).view(self.top_k,
-1).permute(
1, 0).contiguous())
active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens
sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
hidden_states,
@@ -827,18 +831,23 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
expanded_expert_idx, self.num_experts_local)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 0
return group_list_type, sorted_hidden_states, expert_tokens
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": expert_tokens,
}
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
assert self.mask is not None
assert self.sorted_token_indices is not None
assert self.sorted_weights is not None
assert self.original_shape is not None
dtype = hidden_states.dtype
device = hidden_states.device
if self.expert_map is not None:
assert self.mask is not None
assert self.sorted_token_indices is not None
assert self.sorted_weights is not None
weighted_down_out = hidden_states * \
self.sorted_weights.unsqueeze(1)
@@ -887,7 +896,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
expanded_src_to_dst_row=self.expanded_row_idx,
export_for_source_row=self.topk_ids,
)
return final_hidden_states
@@ -895,29 +903,27 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
def __init__(self, **kwargs):
super(MoETokenDispatcher, self).__init__(**kwargs)
self.apply_router_weight_on_input = kwargs.get(
"apply_router_weight_on_input")
ep_size = kwargs.get("ep_size")
self.local_ep = ep_size
assert self.local_ep is not None
super().__init__(**kwargs)
self.apply_router_weight_on_input = False
self.local_ep = 1
self.local_num_experts = self.num_experts // self.local_ep
self.local_num_group = self.top_k // self.local_ep
self.bsz = None
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None,
):
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
@@ -932,7 +938,7 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
flatten_topk_ids = topk_ids.view(-1)
self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32)
self.sorted_hidden_states = hidden_states.index_select(
sorted_hidden_states = hidden_states.index_select(
0, self.sorted_topk_ids // self.local_num_group)
experts_id = torch.arange(0,
@@ -942,15 +948,20 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher):
num_tokens_per_expert = (
flatten_topk_ids.unsqueeze(-1) == experts_id).to(
torch.float32).sum(0)
self.topk_scales = topk_weights.view(-1).index_select(
topk_scales = topk_weights.view(-1).index_select(
0, self.sorted_topk_ids).unsqueeze(-1)
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
return hidden_states, group_list
group_list_type = 0
return {
"group_list_type": group_list_type,
"hidden_states": sorted_hidden_states,
"group_list": group_list,
"topk_scales": topk_scales,
}
def token_combine(self,
hidden_states: torch.Tensor,
bias: torch.Tensor = None):
assert self.local_ep is not None
unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to(
torch.int32)
unsorted_hidden_states = hidden_states.index_select(
@@ -1009,18 +1020,19 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
self.local_expert_indices[i + 1] -
1), "local_expert_indices must be continuous"
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.Tensor] = None,
):
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
row_idx: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
log2phy: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[torch.Tensor] = None,
shared_gate_up: Optional[torch.Tensor] = None,
shared_dequant_scale: Optional[torch.Tensor] = None,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False):
self.hidden_shape = hidden_states.shape
self.topk_weights = topk_weights
assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights"