[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)

### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business
`**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- Unit tests can assert stage-level fields directly (see the sketch below)
  instead of inferring behavior indirectly.
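
A minimal sketch of the request-object shape, with field names taken from the
`build_fused_experts_input(...)` call sites in this diff. The class name is
hypothetical; the real definition lives in
`vllm_ascend/ops/fused_moe/moe_runtime_args.py` and may group fields
differently:

```python
from dataclasses import dataclass

import torch

from vllm_ascend.quantization.quant_type import QuantType


@dataclass
class FusedExpertsInput:
    # Hypothetical name; fields mirror the keyword arguments passed to
    # build_fused_experts_input() at the call sites below.
    hidden_states: torch.Tensor
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    w1: torch.Tensor
    w2: torch.Tensor
    quant_type: QuantType
    dynamic_eplb: bool = False
    expert_map: torch.Tensor | None = None
    apply_router_weight_on_input: bool = False
    w1_scale: torch.Tensor | None = None
    w2_scale: torch.Tensor | None = None
```

Call sites build one of these per forward pass and hand it to
`fused_experts(fused_experts_input=...)`, so the comm method no longer reaches
into `**kwargs`.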

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Author: linfeng-yuan
Date: 2026-03-20 23:23:57 +08:00 (committed by GitHub)
Parent: c860535246
Commit: 88d03a783f
33 changed files with 2146 additions and 947 deletions


@@ -25,7 +25,8 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods
from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.quant_type import QuantType
from .experts_selector import select_experts
from .moe_comm_method import AllGatherCommImpl310
@@ -93,13 +94,17 @@ class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod):
moe_comm_method = _EXTRA_CTX.moe_comm_method
final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
fused_experts_input=build_fused_experts_input(
hidden_states=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1=layer.w13_weight,
w2=layer.w2_weight,
quant_type=QuantType.NONE,
dynamic_eplb=False,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
),
)
if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result
@@ -218,9 +223,13 @@ class AscendFusedMoE310(FusedMoE):
assert self.quant_method is not None
assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported."
hidden_states, router_logits, _, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
prepare_output = _EXTRA_CTX.moe_comm_method.prepare(
hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type
)
hidden_states = prepare_output.hidden_states
router_logits = prepare_output.router_logits
pertoken_scale = prepare_output.pertoken_scale
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
# Matrix multiply.
fused_experts_results: FusedExpertsResult = self.quant_method.apply(
@@ -238,12 +247,13 @@ class AscendFusedMoE310(FusedMoE):
global_num_experts=self.global_num_experts,
expert_map=self.local_expert_map,
apply_router_weight_on_input=self.apply_router_weight_on_input,
pertoken_scale=pertoken_scale,
)
routed_out = _EXTRA_CTX.moe_comm_method.finalize(
hidden_states=fused_experts_results.routed_out,
reduce_results=self.reduce_results,
context_metadata=context_metadata,
padded_hidden_states_shape=padded_hidden_states_shape,
)
return routed_out
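
The prepare stage likewise returns one result object. A rough sketch, inferred
only from the four fields read above; the class name is a placeholder and the
real type defined in this PR may carry more fields:

```python
from dataclasses import dataclass

import torch


@dataclass
class MoEPrepareOutput:
    # Placeholder name; fields mirror the prepare_output.* reads above.
    hidden_states: torch.Tensor
    router_logits: torch.Tensor
    pertoken_scale: torch.Tensor | None = None
    padded_hidden_states_shape: torch.Size | None = None  # assumed type
```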


@@ -17,8 +17,8 @@ from __future__ import annotations
import torch
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
from .moe_mlp import unified_apply_mlp
from .token_dispatcher import TokenDispatcherWithAllGather310
@@ -35,52 +35,12 @@ class AllGatherCommImpl310(AllGatherCommImpl):
to handle the token-to-expert mapping and communication efficiently.
"""
def fused_experts( # type: ignore[override]
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
use_int8_w8a8: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> FusedExpertsResult:
# This method is overridden to use the 310p-specific unified_apply_mlp
# which provides optimized MLP computation for the 310p platform
moe_comm_method = _EXTRA_CTX.moe_comm_method
assert moe_comm_method is not None, "Missing communication context"
def __init__(self, moe_config):
super().__init__(moe_config)
self.use_fusion_ops = False
dispatch_results = self.token_dispatcher.token_dispatch(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
mlp_output = unified_apply_mlp(
hidden_states=dispatch_results.hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
group_list=dispatch_results.group_list,
group_list_type=dispatch_results.group_list_type,
with_quant=use_int8_w8a8,
)
combine_results = self.token_dispatcher.token_combine(
hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
)
return FusedExpertsResult(
routed_out=combine_results.routed_out,
group_list_type=dispatch_results.group_list_type,
expert_tokens=dispatch_results.group_list,
)
def _apply_mlp(self, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
return unified_apply_mlp(mlp_compute_input=mlp_compute_input)
def _get_token_dispatcher(self):
return TokenDispatcherWithAllGather310(
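
With the full `fused_experts` override gone, the 310p impl only supplies the
platform hooks (`_apply_mlp`, `_get_token_dispatcher`); the shared
dispatch → MLP → combine orchestration presumably stays in the base
`AllGatherCommImpl`. A self-contained sketch of that hook shape, with
simplified names rather than the actual vllm_ascend classes:

```python
from abc import ABC, abstractmethod


class CommImplBase(ABC):
    """Owns the shared stage orchestration; subclasses plug in platform hooks."""

    def __init__(self):
        self.token_dispatcher = self._get_token_dispatcher()

    def fused_experts(self, fused_experts_input):
        # Shared path: dispatch -> expert MLP -> combine, driven by one typed
        # request object instead of per-call keyword arguments.
        dispatched = self.token_dispatcher.token_dispatch(fused_experts_input)
        mlp_out = self._apply_mlp(dispatched)
        return self.token_dispatcher.token_combine(mlp_out)

    @abstractmethod
    def _apply_mlp(self, dispatched):
        ...

    @abstractmethod
    def _get_token_dispatcher(self):
        ...
```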


@@ -18,6 +18,8 @@
import torch
import torch_npu
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
def quant_apply_mlp(
hidden_states: torch.Tensor,
@@ -66,17 +68,20 @@ def unquant_apply_mlp(
return hidden_states
def unified_apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
group_list_type: int = 1,
with_quant: bool = False,
) -> torch.Tensor:
if with_quant:
def unified_apply_mlp(*, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
hidden_states = mlp_compute_input.hidden_states
w1 = mlp_compute_input.weights.w1
w2 = mlp_compute_input.weights.w2
w1_scale = mlp_compute_input.weights.w1_scale
w2_scale = mlp_compute_input.weights.w2_scale
group_list = mlp_compute_input.group_list
group_list_type = mlp_compute_input.group_list_type
assert isinstance(w1, torch.Tensor)
assert isinstance(w2, torch.Tensor)
if mlp_compute_input.quant.is_quant:
assert isinstance(w1_scale, torch.Tensor)
assert isinstance(w2_scale, torch.Tensor)
assert w1_scale is not None and w2_scale is not None
return quant_apply_mlp(
hidden_states=hidden_states,
@@ -87,7 +92,11 @@ def unified_apply_mlp(
group_list=group_list,
group_list_type=group_list_type,
)
else:
return unquant_apply_mlp(
hidden_states=hidden_states, w1=w1, w2=w2, group_list=group_list, group_list_type=group_list_type
)
return unquant_apply_mlp(
hidden_states=hidden_states,
w1=w1,
w2=w2,
group_list=group_list,
group_list_type=group_list_type,
)
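
The attribute accesses above suggest `MoEMlpComputeInput` groups its data into
sub-objects. A rough sketch inferred from those reads; the sub-object names
are hypothetical and the actual definition in `moe_runtime_args.py` may
differ:

```python
from dataclasses import dataclass

import torch


@dataclass
class MoEMlpWeights:
    # Hypothetical grouping, inferred from `mlp_compute_input.weights.*`.
    w1: torch.Tensor
    w2: torch.Tensor
    w1_scale: torch.Tensor | None = None
    w2_scale: torch.Tensor | None = None


@dataclass
class MoEQuantArgs:
    # Hypothetical grouping, inferred from `mlp_compute_input.quant.is_quant`.
    is_quant: bool = False


@dataclass
class MoEMlpComputeInput:
    hidden_states: torch.Tensor
    weights: MoEMlpWeights
    quant: MoEQuantArgs
    group_list: torch.Tensor
    group_list_type: int = 1
```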


@@ -25,26 +25,27 @@
import torch
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather, TokenDispatchResult
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEAllGatherCombineMetadata, MoETokenDispatchInput
from vllm_ascend.ops.fused_moe.token_dispatcher import MoETokenDispatchOutput, TokenDispatcherWithAllGather
class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def token_dispatch( # type: ignore[override]
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
token_dispatch_input: MoETokenDispatchInput,
):
self.original_shape = hidden_states.shape
hidden_states = token_dispatch_input.hidden_states
topk_weights = token_dispatch_input.topk_weights
topk_ids = token_dispatch_input.topk_ids
expert_map = token_dispatch_input.routing.expert_map
apply_router_weight_on_input = token_dispatch_input.routing.apply_router_weight_on_input
restore_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel()
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
if apply_router_weight_on_input:
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
@@ -66,13 +67,16 @@ class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode
context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
return TokenDispatchResult(
return MoETokenDispatchOutput(
hidden_states=sorted_hidden_states,
group_list=expert_tokens,
group_list_type=group_list_type,
context_metadata=context_metadata,
combine_metadata=MoEAllGatherCombineMetadata(
topk_weights=topk_weights,
expanded_row_idx=expanded_row_idx,
restore_shape=restore_shape,
),
)
def moe_init_routing(self, x, expert_idx, active_num, active_expert_range):
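
The dispatcher's return value is typed as well. A sketch of the two objects,
with fields taken from the return statement above; the real definitions may
carry additional fields:

```python
from dataclasses import dataclass

import torch


@dataclass
class MoEAllGatherCombineMetadata:
    # Fields mirror the arguments passed above.
    topk_weights: torch.Tensor
    expanded_row_idx: torch.Tensor
    restore_shape: torch.Size


@dataclass
class MoETokenDispatchOutput:
    hidden_states: torch.Tensor
    group_list: torch.Tensor
    group_list_type: int
    combine_metadata: MoEAllGatherCombineMetadata
```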


@@ -25,6 +25,7 @@ from vllm.distributed import get_ep_group
from vllm_ascend._310p.fused_moe.experts_selector import select_experts
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType
from .registry import register_scheme
@@ -95,7 +96,9 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
pertoken_scale: Any | None = None,
**kwargs,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -128,15 +131,19 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
moe_comm_method = _EXTRA_CTX.moe_comm_method
final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
use_int8_w8a8=True,
fused_experts_input=build_fused_experts_input(
hidden_states=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1=layer.w13_weight,
w2=layer.w2_weight,
quant_type=self.quant_type,
dynamic_eplb=False,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
),
)
if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result