[Refactor] Adjustments to moe_comm_method selection process (#3001)
### What this PR does / why we need it?
Fix issues mentioned in
https://github.com/vllm-project/vllm-ascend/pull/2791 and some minor
refactoring.
1. Use Enum instead of string.
2. Avoid setting a new property to forward_context in
AscendFusedMoE.forward().
3. Enabling TokenDispatcherWithMoge.
4. Remove redundant code.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing:
1. Enable/Disable EP
2. Aclgraph & eager
- vLLM version: v0.10.2
- vLLM main:
9607d5eb44
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -23,106 +23,23 @@ from vllm.config import CompilationLevel, get_current_vllm_config
|
||||
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import \
|
||||
FusedMoEParallelConfig # isort: skip
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.eplb.core.eplb_utils import (determine_default_expert_map,
|
||||
determine_default_log2phy_map)
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
|
||||
AlltoAllCommImpl, MC2CommImpl,
|
||||
NaiveMulticastCommImpl)
|
||||
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
|
||||
|
||||
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
||||
|
||||
|
||||
def fused_experts_moge(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
global_num_experts: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
|
||||
Args:
|
||||
hidden_states: Hidden states of shape (num_tokens, hidden_size).
|
||||
w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size).
|
||||
w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size).
|
||||
topk_weights: Routing weights of shape (num_tokens, top_k).
|
||||
topk_ids: Selected expert IDs of shape (num_tokens, top_k).
|
||||
top_k: Number of experts to select.
|
||||
expert_map: Expert mapping of shape (num_experts,).
|
||||
|
||||
Returns:
|
||||
hidden_states: Hidden states after routing.
|
||||
"""
|
||||
ep_size = moe_parallel_config.ep_size
|
||||
local_num_experts = global_num_experts // ep_size
|
||||
local_num_group = top_k // ep_size
|
||||
|
||||
bsz, _ = hidden_states.shape
|
||||
flatten_topk_ids = topk_ids.view(-1)
|
||||
sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
|
||||
sorted_topk_ids = sorted_topk_ids.to(torch.int32)
|
||||
sorted_hidden_states = hidden_states.index_select(
|
||||
0, sorted_topk_ids // local_num_group)
|
||||
|
||||
experts_id = torch.arange(0,
|
||||
local_num_experts,
|
||||
dtype=topk_ids.dtype,
|
||||
device=topk_ids.device)
|
||||
num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
|
||||
torch.float32).sum(0)
|
||||
topk_scales = topk_weights.view(-1).index_select(
|
||||
0, sorted_topk_ids).unsqueeze(-1)
|
||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
||||
|
||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||
x=[sorted_hidden_states],
|
||||
weight=[w1],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
if is_310p():
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
|
||||
torch.float16)
|
||||
else:
|
||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||
gate_up_out *= topk_scales
|
||||
|
||||
down_out_list = torch_npu.npu_grouped_matmul(
|
||||
x=[gate_up_out],
|
||||
weight=[w2],
|
||||
split_item=2,
|
||||
group_list_type=0,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
)[0]
|
||||
|
||||
unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
|
||||
unsorted_hidden_states = down_out_list.index_select(0, unsorted_topk_ids)
|
||||
final_hidden_states = unsorted_hidden_states.reshape(
|
||||
bsz, top_k // ep_size, -1).sum(1)
|
||||
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
||||
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
|
||||
|
||||
@@ -178,20 +95,6 @@ def forward_oot(
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
if topk_ids.shape[1] < top_k or is_310p():
|
||||
assert global_num_experts is not None
|
||||
return fused_experts_moge(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
moe_parallel_config=self.moe.moe_parallel_config,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -277,13 +180,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
||||
|
||||
for method in {
|
||||
AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl,
|
||||
NaiveMulticastCommImpl
|
||||
}:
|
||||
setattr(
|
||||
self, method.__name__.lower(),
|
||||
method(moe_config=self.moe_config)) # type: ignore[abstract]
|
||||
setup_moe_comm_method(self.moe_config)
|
||||
|
||||
def update_expert_map(self, new_expert_map):
|
||||
self.expert_map = new_expert_map
|
||||
@@ -307,8 +204,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
outputs since each rank only has partial outputs.
|
||||
"""
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
||||
if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
|
||||
return final_hidden_states
|
||||
else:
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states)
|
||||
@@ -318,10 +215,6 @@ class AscendFusedMoE(FusedMoE):
|
||||
assert self.quant_method is not None
|
||||
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
||||
|
||||
forward_context.moe_comm_method = getattr(self, moe_comm_method_name)
|
||||
|
||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||
hidden_states=hidden_states, router_logits=router_logits)
|
||||
|
||||
@@ -449,8 +342,8 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
|
||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_method_name = forward_context.moe_comm_method_name
|
||||
if moe_comm_method_name in {"alltoallcommimpl", "mc2commimpl"}:
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2}:
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
|
||||
_, fused_out = AscendFusedMoE.forward(
|
||||
|
||||
Reference in New Issue
Block a user