[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it? Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries. - Prepare, dispatch, MLP, and quant stages now have clearer ownership. - Main MoE path no longer depends on business `kwargs.get(...)` lookups. - Comm and dispatcher interfaces are request-only on the main path. - UTs can assert stage-level fields directly instead of inferring behavior indirectly. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed. --------- Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -25,7 +25,8 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods
|
||||
from vllm_ascend.quantization.methods.base import QuantType
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||
from vllm_ascend.quantization.quant_type import QuantType
|
||||
|
||||
from .experts_selector import select_experts
|
||||
from .moe_comm_method import AllGatherCommImpl310
|
||||
@@ -93,13 +94,17 @@ class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod):
|
||||
|
||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||
final_hidden_states = moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
fused_experts_input=build_fused_experts_input(
|
||||
hidden_states=x,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
quant_type=QuantType.NONE,
|
||||
dynamic_eplb=False,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
),
|
||||
)
|
||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||
final_hidden_states += zero_expert_result
|
||||
@@ -218,9 +223,13 @@ class AscendFusedMoE310(FusedMoE):
|
||||
assert self.quant_method is not None
|
||||
assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported."
|
||||
|
||||
hidden_states, router_logits, _, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
|
||||
prepare_output = _EXTRA_CTX.moe_comm_method.prepare(
|
||||
hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type
|
||||
)
|
||||
hidden_states = prepare_output.hidden_states
|
||||
router_logits = prepare_output.router_logits
|
||||
pertoken_scale = prepare_output.pertoken_scale
|
||||
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
|
||||
|
||||
# Matrix multiply.
|
||||
fused_experts_results: FusedExpertsResult = self.quant_method.apply(
|
||||
@@ -238,12 +247,13 @@ class AscendFusedMoE310(FusedMoE):
|
||||
global_num_experts=self.global_num_experts,
|
||||
expert_map=self.local_expert_map,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
pertoken_scale=pertoken_scale,
|
||||
)
|
||||
|
||||
routed_out = _EXTRA_CTX.moe_comm_method.finalize(
|
||||
hidden_states=fused_experts_results.routed_out,
|
||||
reduce_results=self.reduce_results,
|
||||
context_metadata=context_metadata,
|
||||
padded_hidden_states_shape=padded_hidden_states_shape,
|
||||
)
|
||||
|
||||
return routed_out
|
||||
|
||||
@@ -17,8 +17,8 @@ from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
|
||||
|
||||
from .moe_mlp import unified_apply_mlp
|
||||
from .token_dispatcher import TokenDispatcherWithAllGather310
|
||||
@@ -35,52 +35,12 @@ class AllGatherCommImpl310(AllGatherCommImpl):
|
||||
to handle the token-to-expert mapping and communication efficiently.
|
||||
"""
|
||||
|
||||
def fused_experts( # type: ignore[override]
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
use_int8_w8a8: bool = False,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
) -> FusedExpertsResult:
|
||||
# This method is overridden to use the 310p-specific unified_apply_mlp
|
||||
# which provides optimized MLP computation for the 310p platform
|
||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||
assert moe_comm_method is not None, "Missing communication context"
|
||||
def __init__(self, moe_config):
|
||||
super().__init__(moe_config)
|
||||
self.use_fusion_ops = False
|
||||
|
||||
dispatch_results = self.token_dispatcher.token_dispatch(
|
||||
hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
)
|
||||
|
||||
mlp_output = unified_apply_mlp(
|
||||
hidden_states=dispatch_results.hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale,
|
||||
group_list=dispatch_results.group_list,
|
||||
group_list_type=dispatch_results.group_list_type,
|
||||
with_quant=use_int8_w8a8,
|
||||
)
|
||||
|
||||
combine_results = self.token_dispatcher.token_combine(
|
||||
hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
|
||||
)
|
||||
|
||||
return FusedExpertsResult(
|
||||
routed_out=combine_results.routed_out,
|
||||
group_list_type=dispatch_results.group_list_type,
|
||||
expert_tokens=dispatch_results.group_list,
|
||||
)
|
||||
def _apply_mlp(self, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
|
||||
return unified_apply_mlp(mlp_compute_input=mlp_compute_input)
|
||||
|
||||
def _get_token_dispatcher(self):
|
||||
return TokenDispatcherWithAllGather310(
|
||||
|
||||
@@ -18,6 +18,8 @@
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
|
||||
|
||||
|
||||
def quant_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -66,17 +68,20 @@ def unquant_apply_mlp(
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unified_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
w1_scale: torch.Tensor | None = None,
|
||||
w2_scale: torch.Tensor | None = None,
|
||||
group_list_type: int = 1,
|
||||
with_quant: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if with_quant:
|
||||
def unified_apply_mlp(*, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
|
||||
hidden_states = mlp_compute_input.hidden_states
|
||||
w1 = mlp_compute_input.weights.w1
|
||||
w2 = mlp_compute_input.weights.w2
|
||||
w1_scale = mlp_compute_input.weights.w1_scale
|
||||
w2_scale = mlp_compute_input.weights.w2_scale
|
||||
group_list = mlp_compute_input.group_list
|
||||
group_list_type = mlp_compute_input.group_list_type
|
||||
assert isinstance(w1, torch.Tensor)
|
||||
assert isinstance(w2, torch.Tensor)
|
||||
|
||||
if mlp_compute_input.quant.is_quant:
|
||||
assert isinstance(w1_scale, torch.Tensor)
|
||||
assert isinstance(w2_scale, torch.Tensor)
|
||||
assert w1_scale is not None and w2_scale is not None
|
||||
return quant_apply_mlp(
|
||||
hidden_states=hidden_states,
|
||||
@@ -87,7 +92,11 @@ def unified_apply_mlp(
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
)
|
||||
else:
|
||||
return unquant_apply_mlp(
|
||||
hidden_states=hidden_states, w1=w1, w2=w2, group_list=group_list, group_list_type=group_list_type
|
||||
)
|
||||
|
||||
return unquant_apply_mlp(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
)
|
||||
|
||||
@@ -25,26 +25,27 @@
|
||||
import torch
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather, TokenDispatchResult
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEAllGatherCombineMetadata, MoETokenDispatchInput
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import MoETokenDispatchOutput, TokenDispatcherWithAllGather
|
||||
|
||||
|
||||
class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def token_dispatch( # type: ignore[override]
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
token_dispatch_input: MoETokenDispatchInput,
|
||||
):
|
||||
self.original_shape = hidden_states.shape
|
||||
hidden_states = token_dispatch_input.hidden_states
|
||||
topk_weights = token_dispatch_input.topk_weights
|
||||
topk_ids = token_dispatch_input.topk_ids
|
||||
expert_map = token_dispatch_input.routing.expert_map
|
||||
apply_router_weight_on_input = token_dispatch_input.routing.apply_router_weight_on_input
|
||||
restore_shape = hidden_states.shape
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
if self.apply_router_weight_on_input:
|
||||
if apply_router_weight_on_input:
|
||||
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
@@ -66,13 +67,16 @@ class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
|
||||
)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 1 # `count` mode
|
||||
context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
|
||||
|
||||
return TokenDispatchResult(
|
||||
return MoETokenDispatchOutput(
|
||||
hidden_states=sorted_hidden_states,
|
||||
group_list=expert_tokens,
|
||||
group_list_type=group_list_type,
|
||||
context_metadata=context_metadata,
|
||||
combine_metadata=MoEAllGatherCombineMetadata(
|
||||
topk_weights=topk_weights,
|
||||
expanded_row_idx=expanded_row_idx,
|
||||
restore_shape=restore_shape,
|
||||
),
|
||||
)
|
||||
|
||||
def moe_init_routing(self, x, expert_idx, active_num, active_expert_range):
|
||||
|
||||
@@ -25,6 +25,7 @@ from vllm.distributed import get_ep_group
|
||||
from vllm_ascend._310p.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||
from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType
|
||||
|
||||
from .registry import register_scheme
|
||||
@@ -95,7 +96,9 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
|
||||
log2phy: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
pertoken_scale: Any | None = None,
|
||||
**kwargs,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
||||
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
||||
@@ -128,15 +131,19 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
|
||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||
|
||||
final_hidden_states = moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
use_int8_w8a8=True,
|
||||
fused_experts_input=build_fused_experts_input(
|
||||
hidden_states=x,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
quant_type=self.quant_type,
|
||||
dynamic_eplb=False,
|
||||
expert_map=expert_map,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
),
|
||||
)
|
||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||
final_hidden_states += zero_expert_result
|
||||
|
||||
Reference in New Issue
Block a user