[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
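For context, a minimal before/after sketch of the shape of this change. The `**kwargs`-style "before" is illustrative rather than the exact removed code; the "after" uses the new builder from `moe_runtime_args.py` below, with illustrative tensor shapes.

```python
import torch

from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.quant_type import QuantType

# Before (illustrative): business parameters rode along in **kwargs, so each
# stage probed for what it needed and a mistyped key failed silently.
def fused_experts_legacy(hidden_states: torch.Tensor, **kwargs) -> None:
    expert_map = kwargs.get("expert_map", None)
    log2phy = kwargs.get("log2phy", None)

# After: the call site builds one typed request object with explicit
# sub-payloads (weights / routing / quant), and stages read declared fields.
request = build_fused_experts_input(
    hidden_states=torch.randn(4, 8),
    topk_weights=torch.randn(4, 2),
    topk_ids=torch.zeros(4, 2, dtype=torch.int32),
    w1=torch.randn(2, 8, 16),
    w2=torch.randn(2, 16, 8),
    quant_type=QuantType.W8A8,
    dynamic_eplb=False,
)
expert_map = request.routing.expert_map  # stage-level field, directly assertable in UTs
```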
vllm_ascend/ops/fused_moe/moe_runtime_args.py (new file, 244 lines)

@@ -0,0 +1,244 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Typed runtime contracts and builders for fused MoE execution.

This module is the single entry point for the runtime payloads used across the
fused MoE pipeline.

Relationship overview:

    stage params: reusable sub-payloads
        - MoERoutingParams
        - MoEQuantParams
            - internal MXFP leaf: MoEMxfpParams

    stage contracts: stage input/output payloads
        prepare
            -> MoEPrepareOutput

        fused_experts input
            -> MoEFusedExpertsInput
                |- weights: MoEWeights
                |- routing: MoERoutingParams
                |- quant: MoEQuantParams

        dispatch
            input  -> MoETokenDispatchInput
            output -> MoETokenDispatchOutput[TMoECombineMetadata]
                TMoECombineMetadata is one of:
                    - MoEAllGatherCombineMetadata
                    - MoEAllToAllCombineMetadata
                    - MoEMC2CombineMetadata

        mlp
            input -> MoEMlpComputeInput

        combine
            output -> torch.Tensor

The helper builders below adapt legacy call sites into these typed contracts.
Only the fused_moe package should need to know about the internal MXFP leaf
dataclass directly.
"""

from __future__ import annotations

import torch

import vllm_ascend.ops.fused_moe.moe_stage_params as _stage_params
from vllm_ascend.ops.fused_moe.moe_stage_contracts import (
    MoEAllGatherCombineMetadata,
    MoEAllToAllCombineMetadata,
    MoEFusedExpertsInput,
    MoEMC2CombineMetadata,
    MoEMlpComputeInput,
    MoEPrepareOutput,
    MoETokenDispatchInput,
    MoETokenDispatchOutput,
    MoEWeights,
    TMoECombineMetadata,
)
from vllm_ascend.ops.fused_moe.moe_stage_params import (
    MoEQuantParams,
    MoERoutingParams,
)
from vllm_ascend.quantization.quant_type import QuantType


def _build_mxfp_params(
    *,
    quant_type: QuantType,
    mxfp_act_quant_type: torch.dtype | None = None,
    mxfp_weight_quant_type: torch.dtype | None = None,
    mxfp_scale_dtype: torch.dtype | None = None,
    mxfp_per_token_scale_dtype: torch.dtype | None = None,
    mxfp_use_bf16: bool | None = None,
) -> _stage_params.MoEMxfpParams | None:
    if quant_type != QuantType.MXFP8:
        return None

    has_explicit_mxfp_args = any(
        value is not None
        for value in (
            mxfp_act_quant_type,
            mxfp_weight_quant_type,
            mxfp_scale_dtype,
            mxfp_per_token_scale_dtype,
            mxfp_use_bf16,
        )
    )
    if not has_explicit_mxfp_args:
        raise ValueError("primitive MXFP params are required when quant_type is QuantType.MXFP8.")

    return _stage_params.MoEMxfpParams(
        act_quant_type=mxfp_act_quant_type,
        weight_quant_type=mxfp_weight_quant_type,
        scale_dtype=mxfp_scale_dtype,
        per_token_scale_dtype=mxfp_per_token_scale_dtype,
        use_bf16=True if mxfp_use_bf16 is None else mxfp_use_bf16,
    )


def build_fused_experts_input(
    *,
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    w1: torch.Tensor | list[torch.Tensor],
    w2: torch.Tensor | list[torch.Tensor],
    quant_type: QuantType,
    dynamic_eplb: bool,
    expert_map: torch.Tensor | None = None,
    global_redundant_expert_num: int = 0,
    mc2_mask: torch.Tensor | None = None,
    apply_router_weight_on_input: bool = False,
    log2phy: torch.Tensor | None = None,
    pertoken_scale: torch.Tensor | None = None,
    activation: str = "silu",
    need_trans: bool = False,
    w1_bias: torch.Tensor | None = None,
    w2_bias: torch.Tensor | None = None,
    comm_quant_mode: int | None = None,
    mxfp_act_quant_type: torch.dtype | None = None,
    mxfp_weight_quant_type: torch.dtype | None = None,
    mxfp_scale_dtype: torch.dtype | None = None,
    mxfp_per_token_scale_dtype: torch.dtype | None = None,
    mxfp_use_bf16: bool | None = None,
    w1_scale: list[torch.Tensor] | torch.Tensor | None = None,
    w2_scale: list[torch.Tensor] | torch.Tensor | None = None,
    w1_scale_bias: torch.Tensor | None = None,
    w2_scale_bias: torch.Tensor | None = None,
    w1_offset: torch.Tensor | None = None,
    w2_offset: torch.Tensor | None = None,
) -> MoEFusedExpertsInput:
    return MoEFusedExpertsInput(
        hidden_states=hidden_states,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        weights=MoEWeights(
            w1=w1,
            w2=w2,
            w1_bias=w1_bias,
            w2_bias=w2_bias,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_scale_bias=w1_scale_bias,
            w2_scale_bias=w2_scale_bias,
            w1_offset=w1_offset,
            w2_offset=w2_offset,
        ),
        routing=MoERoutingParams(
            expert_map=expert_map,
            global_redundant_expert_num=global_redundant_expert_num,
            mc2_mask=mc2_mask,
            apply_router_weight_on_input=apply_router_weight_on_input,
            log2phy=log2phy,
            pertoken_scale=pertoken_scale,
        ),
        activation=activation,
        need_trans=need_trans,
        dynamic_eplb=dynamic_eplb,
        quant=MoEQuantParams(
            quant_type=quant_type,
            comm_quant_mode=comm_quant_mode,
            mxfp=_build_mxfp_params(
                quant_type=quant_type,
                mxfp_act_quant_type=mxfp_act_quant_type,
                mxfp_weight_quant_type=mxfp_weight_quant_type,
                mxfp_scale_dtype=mxfp_scale_dtype,
                mxfp_per_token_scale_dtype=mxfp_per_token_scale_dtype,
                mxfp_use_bf16=mxfp_use_bf16,
            ),
        ),
    )


def build_token_dispatch_input(
    *,
    fused_experts_input: MoEFusedExpertsInput,
    topk_ids: torch.Tensor | None = None,
) -> MoETokenDispatchInput:
    return MoETokenDispatchInput(
        hidden_states=fused_experts_input.hidden_states,
        topk_weights=fused_experts_input.topk_weights,
        topk_ids=fused_experts_input.topk_ids if topk_ids is None else topk_ids,
        routing=fused_experts_input.routing,
        quant=fused_experts_input.quant,
    )


def build_mlp_compute_input(
    *,
    fused_experts_input: MoEFusedExpertsInput,
    token_dispatch_output: MoETokenDispatchOutput[TMoECombineMetadata],
    use_fusion_ops: bool,
) -> MoEMlpComputeInput:
    if fused_experts_input.quant.is_mxfp and fused_experts_input.quant.mxfp is None:
        raise ValueError("fused_experts_input.quant.mxfp is required when quant_type is QuantType.MXFP8.")

    return MoEMlpComputeInput(
        hidden_states=token_dispatch_output.hidden_states,
        group_list=token_dispatch_output.group_list,
        group_list_type=token_dispatch_output.group_list_type,
        dynamic_scale=token_dispatch_output.dynamic_scale,
        topk_scales=token_dispatch_output.topk_scales,
        weights=fused_experts_input.weights,
        quant=fused_experts_input.quant,
        fusion=fused_experts_input.quant.quant_type in (QuantType.W8A8, QuantType.MXFP8) and use_fusion_ops,
        activation=fused_experts_input.activation,
        need_trans=fused_experts_input.need_trans,
        dynamic_eplb=fused_experts_input.dynamic_eplb,
    )


__all__ = [
    "MoEAllGatherCombineMetadata",
    "MoEAllToAllCombineMetadata",
    "MoEFusedExpertsInput",
    "MoEMC2CombineMetadata",
    "MoEMlpComputeInput",
    "MoEPrepareOutput",
    "MoEQuantParams",
    "MoERoutingParams",
    "MoETokenDispatchInput",
    "MoETokenDispatchOutput",
    "MoEWeights",
    "TMoECombineMetadata",
    "build_fused_experts_input",
    "build_token_dispatch_input",
    "build_mlp_compute_input",
]
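A rough sketch of how the three builders compose along the dispatch → mlp path follows. The `dispatcher` object is a hypothetical stand-in for the surrounding fused MoE code, which lives outside this file, so the middle line is not runnable as-is; tensor shapes are illustrative.

```python
import torch

from vllm_ascend.ops.fused_moe.moe_runtime_args import (
    build_fused_experts_input,
    build_mlp_compute_input,
    build_token_dispatch_input,
)
from vllm_ascend.quantization.quant_type import QuantType

# Build the typed request once, then derive each stage's input from it
# instead of re-threading **kwargs through every layer.
fused_input = build_fused_experts_input(
    hidden_states=torch.randn(4, 8),
    topk_weights=torch.randn(4, 2),
    topk_ids=torch.zeros(4, 2, dtype=torch.int32),
    w1=torch.randn(2, 8, 16),
    w2=torch.randn(2, 16, 8),
    quant_type=QuantType.W8A8,
    dynamic_eplb=False,
)

# Dispatch stage: request-only interface on the main path.
dispatch_input = build_token_dispatch_input(fused_experts_input=fused_input)
dispatch_output = dispatcher.token_dispatch(dispatch_input)  # `dispatcher` is a hypothetical stand-in

# MLP stage: assembled from the dispatch output plus the original request;
# the fusion flag is derived from quant_type and use_fusion_ops, not kwargs.
mlp_input = build_mlp_compute_input(
    fused_experts_input=fused_input,
    token_dispatch_output=dispatch_output,
    use_fusion_ops=True,
)
```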
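One more hedged sketch, showing the MXFP8 guard in `_build_mxfp_params` as seen through `build_fused_experts_input`. Tensor shapes and the dtype choice are assumptions for illustration only.

```python
import torch

from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.quant_type import QuantType

common = dict(
    hidden_states=torch.randn(4, 8),
    topk_weights=torch.randn(4, 2),
    topk_ids=torch.zeros(4, 2, dtype=torch.int32),
    w1=torch.randn(2, 8, 16),
    w2=torch.randn(2, 16, 8),
    dynamic_eplb=False,
)

# MXFP8 with no primitive MXFP argument: the builder refuses to guess.
try:
    build_fused_experts_input(quant_type=QuantType.MXFP8, **common)
except ValueError as err:
    print(err)  # primitive MXFP params are required when quant_type is QuantType.MXFP8.

# Any explicit MXFP field satisfies the guard; use_bf16 defaults to True.
req = build_fused_experts_input(
    quant_type=QuantType.MXFP8,
    mxfp_scale_dtype=torch.float32,  # assumed dtype, illustration only
    **common,
)
assert req.quant.mxfp is not None and req.quant.mxfp.use_bf16 is True
```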