[refactor] refactor common_fused_moe.py (#2706)

### What this PR does / why we need it?
1. Move the prepare/finalize operations out of `moe_comm_method` into `/ops/moe/fused_moe_prepare_and_finalize`.
2. Adapt `moe_comm_method` to the `token_dispatcher` interface (the resulting call flow is sketched below).
3. Move `moe_comm_method`, `experts_selector`, `token_dispatcher`, and `fused_moe_prepare_and_finalize` under `/ops/moe`.
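
After this refactor, callers no longer invoke a module-level `fused_experts` helper; they resolve the active communication backend from the forward context and delegate to it. A minimal sketch of the intended call flow, using names from the diff below (the wrapper function and the trimmed `select_experts` argument list are illustrative, not the real signatures):

```python
import torch
from vllm.forward_context import get_forward_context

from vllm_ascend.ops.moe.experts_selector import select_experts


def moe_forward_sketch(layer: torch.nn.Module, x: torch.Tensor,
                       router_logits: torch.Tensor,
                       top_k: int) -> torch.Tensor:
    # Expert selection now also returns row_idx, which is threaded
    # through to the communication backend.
    topk_weights, topk_ids, row_idx = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
    )

    # The active backend (AllGatherCommImpl / AlltoAllCommImpl /
    # MC2CommImpl) is resolved from the forward context instead of
    # being imported and called directly.
    moe_comm_method = get_forward_context().moe_comm_method
    return moe_comm_method.fused_experts(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        row_idx=row_idx,
    )
```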
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
E2E tests and unit tests.

- vLLM version: v0.10.1.1
- vLLM main:
f4962a6d55

Signed-off-by: weichen <calvin_zhu0210@outlook.com>
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Commit a041d4f328 (parent 1a82b16355), authored by weichen on 2025-09-08 20:09:50 +08:00, committed by GitHub.
21 changed files with 1052 additions and 932 deletions.


@@ -15,7 +15,7 @@
 # limitations under the License.
 #
 
-from typing import Any, Callable, Optional
+from typing import Callable, Optional
 
 import torch
 import torch_npu
@@ -28,118 +28,16 @@ from vllm.model_executor.layers.fused_moe.layer import (
     FusedMoE, UnquantizedFusedMoEMethod)
 
 from vllm_ascend.ascend_config import get_ascend_config
-from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
-                                                     AlltoAllCommImpl,
-                                                     MC2CommImpl)
 from vllm_ascend.distributed.parallel_state import get_mc2_group
-from vllm_ascend.ops.layers.experts_selector import select_experts
-from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
-    setup_token_dispatchers
+from vllm_ascend.ops.moe.experts_selector import select_experts
+from vllm_ascend.ops.moe.moe_comm_method import (AllGatherCommImpl,
+                                                 AlltoAllCommImpl, MC2CommImpl)
+from vllm_ascend.ops.moe.token_dispatcher import setup_token_dispatchers
 from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, vllm_version_is
 
 original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
 
 
-def fused_experts(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    topk_weights: torch.Tensor,
-    topk_ids: torch.Tensor,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    use_int8_w8a8: bool = False,
-    use_int4_w4a8: bool = False,
-    global_num_experts: Optional[int] = None,
-    expert_map: Optional[torch.Tensor] = None,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    w1_scale_bias: torch.Tensor = None,
-    w2_scale_bias: torch.Tensor = None,
-    # For TorchAir graph
-    is_torchair: bool = False,
-    # For Cube/Vector parallel
-    shared_experts: Optional[Any] = None,
-    quantized_x_for_share: Optional[Any] = None,
-    dynamic_scale_for_share: Optional[Any] = None,
-    # For load balance
-    log2phy: torch.Tensor = None,
-    global_redundant_expert_num: int = 0,
-) -> torch.Tensor:
-    # Check constraints
-    assert hidden_states.shape[1] == w1.shape[1], (
-        f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
-    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
-    assert w2.stride(-1) == 1, "Stride of last dimension must be 1"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-
-    if (use_int8_w8a8 or use_int4_w4a8):
-        assert w1_scale is not None and w2_scale is not None, \
-            "INT8 quantization requires weight scales."
-        w1_scale = w1_scale.to(torch.float32)
-        down_scale = [w2_scale]
-        down_output_dtype = w2_scale.dtype
-    else:
-        down_scale = None
-        down_output_dtype = None
-
-    moe_comm_method = get_forward_context().moe_comm_method
-    assert moe_comm_method is not None, "Missing communication context"
-
-    num_experts = w1.shape[0]
-
-    permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute(
-        hidden_states, topk_ids, topk_weights, expert_map, num_experts,
-        use_int8_w8a8 or use_int4_w4a8)
-
-    gate_up_output = torch_npu.npu_grouped_matmul(
-        x=[permuted_hidden_states],
-        weight=[w1],
-        split_item=2,
-        group_list_type=group_list_type,
-        group_type=0,
-        group_list=expert_tokens,
-        output_dtype=torch.int32 if use_int8_w8a8 else None,
-    )[0]
-
-    if (use_int8_w8a8 or use_int4_w4a8):
-        activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant(
-            x=gate_up_output,
-            weight_scale=w1_scale,
-            activation_scale=dynamic_scale,
-            bias=None,
-            quant_scale=None,
-            quant_offset=None,
-            group_index=expert_tokens,
-            activate_left=True,
-            quant_mode=1,
-        )
-        activated_output_scale = [activated_output_scale]
-    else:
-        activated_output = torch_npu.npu_swiglu(gate_up_output)
-        activated_output_scale = None
-
-    down_output = torch_npu.npu_grouped_matmul(
-        x=[activated_output],
-        weight=[w2],
-        scale=down_scale,
-        per_token_scale=activated_output_scale,
-        split_item=2,
-        group_list_type=group_list_type,
-        group_type=0,
-        group_list=expert_tokens,
-        output_dtype=down_output_dtype,
-    )[0]
-
-    moe_comm_method.unpermute(down_output, hidden_states)
-
-    return hidden_states
-
-
 def fused_experts_moge(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
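
The function removed above is the inline pipeline (permute, grouped gate-up matmul, SwiGLU, grouped down matmul, unpermute) that now lives behind `moe_comm_method.fused_experts`. For readers unfamiliar with `npu_grouped_matmul`: it conceptually runs one matmul per expert over contiguous row groups of the permuted activations. A rough dense-PyTorch equivalent, assuming `group_list` holds per-expert row counts (the `group_list_type` flag in the code above selects between count and cumulative-sum encodings; quantized dtypes and per-token scales are ignored here, and the helper name is illustrative):

```python
import torch


def grouped_matmul_reference(x: torch.Tensor, w: torch.Tensor,
                             group_list: torch.Tensor) -> torch.Tensor:
    """Per-expert matmul over contiguous row groups of x.

    x:          (num_tokens, hidden) activations, permuted so tokens
                routed to the same expert sit in contiguous rows.
    w:          (num_experts, hidden, out) stacked expert weights.
    group_list: (num_experts,) number of rows assigned to each expert.
    """
    outputs, start = [], 0
    for expert_id, count in enumerate(group_list.tolist()):
        # Each expert multiplies only its own slice of rows.
        outputs.append(x[start:start + count] @ w[expert_id])
        start += count
    return torch.cat(outputs, dim=0)
```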
@@ -259,7 +157,7 @@ def forward_oot_v01011(
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-    topk_weights, topk_ids, _ = select_experts(
+    topk_weights, topk_ids, row_idx = select_experts(
         hidden_states=x,
         router_logits=router_logits,
         top_k=top_k,
@@ -287,15 +185,15 @@
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input)
 
-    return fused_experts(
-        hidden_states=x,
-        w1=layer.w13_weight,
-        w2=layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        global_num_experts=global_num_experts,
-        expert_map=expert_map,
-    )
+    moe_comm_method = get_forward_context().moe_comm_method
+    return moe_comm_method.fused_experts(hidden_states=x,
+                                         w1=layer.w13_weight,
+                                         w2=layer.w2_weight,
+                                         topk_weights=topk_weights,
+                                         topk_ids=topk_ids,
+                                         row_idx=row_idx,
+                                         global_num_experts=global_num_experts,
+                                         expert_map=expert_map)
 
 
 def forward_oot(
@@ -321,7 +219,7 @@ def forward_oot(
         logical_to_physical_map: Optional[torch.Tensor] = None,
         logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
 
-    topk_weights, topk_ids, _ = select_experts(
+    topk_weights, topk_ids, row_idx = select_experts(
         hidden_states=x,
         router_logits=router_logits,
         top_k=top_k,
@@ -349,15 +247,15 @@
         expert_map=expert_map,
         apply_router_weight_on_input=apply_router_weight_on_input)
 
-    return fused_experts(
-        hidden_states=x,
-        w1=layer.w13_weight,
-        w2=layer.w2_weight,
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        global_num_experts=global_num_experts,
-        expert_map=expert_map,
-    )
+    moe_comm_method = get_forward_context().moe_comm_method
+    return moe_comm_method.fused_experts(hidden_states=x,
+                                         w1=layer.w13_weight,
+                                         w2=layer.w2_weight,
+                                         topk_weights=topk_weights,
+                                         topk_ids=topk_ids,
+                                         row_idx=row_idx,
+                                         global_num_experts=global_num_experts,
+                                         expert_map=expert_map)
 
 
 def process_weights_after_loading(self, layer):