[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?

Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- The main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- Unit tests (UTs) can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -24,6 +24,13 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
|
||||
MoEFusedExpertsInput,
|
||||
MoEMlpComputeInput,
|
||||
MoEPrepareOutput,
|
||||
build_mlp_compute_input,
|
||||
build_token_dispatch_input,
|
||||
)
|
||||
from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
||||
PrepareAndFinalize,
|
||||
PrepareAndFinalizeWithAll2All,
|
||||
@@ -36,8 +43,7 @@ from vllm_ascend.ops.fused_moe.token_dispatcher import (
|
||||
TokenDispatcherWithAllGather,
|
||||
TokenDispatcherWithMC2,
|
||||
)
|
||||
from vllm_ascend.quantization.methods.base import QuantType
|
||||
from vllm_ascend.quantization.quant_parser import parse_mxfp_quant_params
|
||||
from vllm_ascend.quantization.quant_type import QuantType
|
||||
|
||||
_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
|
||||
|
||||
@@ -90,131 +96,70 @@ class MoECommMethod(ABC):
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type: QuantType = QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
|
||||
hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce, quant_type
|
||||
) -> MoEPrepareOutput:
|
||||
return self.prepare_finalize.prepare(
|
||||
hidden_states,
|
||||
router_logits,
|
||||
enable_shared_expert_dp,
|
||||
replace_allreduce,
|
||||
quant_type,
|
||||
)
|
||||
return hidden_states, router_logits, mc2_mask, context_metadata
|
||||
|
||||
def finalize(
|
||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
padded_hidden_states_shape: torch.Size | None = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, context_metadata)
|
||||
hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, padded_hidden_states_shape)
|
||||
return hidden_states
|
||||
|
||||
def fused_experts(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor | list[torch.Tensor],
|
||||
w2: torch.Tensor | list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
w1_bias: torch.Tensor = None,
|
||||
w2_bias: torch.Tensor = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: list[torch.Tensor] | None = None,
|
||||
w2_scale: list[torch.Tensor] | None = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: torch.Tensor | None = None,
|
||||
w2_offset: torch.Tensor | None = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
fused_experts_input: MoEFusedExpertsInput,
|
||||
):
|
||||
# Check constraints
|
||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
|
||||
assert fused_experts_input.hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
|
||||
|
||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||
assert moe_comm_method is not None, "Missing communication context"
|
||||
|
||||
before_dispatch_evt = torch.npu.current_stream().record_event()
|
||||
# Apply log2phy if needed
|
||||
if log2phy is not None:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced
|
||||
# by different quantization modes will be consolidated into a dataclass in a follow-up.
|
||||
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
|
||||
dispatch_with_quant = use_int8_w8a8 or use_int4_w4a8 or use_mxfp_quant
|
||||
act_quant_type, weight_quant_type, scale_type, per_token_scale_type, round_mode = parse_mxfp_quant_params(
|
||||
**kwargs
|
||||
routed_topk_ids = fused_experts_input.topk_ids
|
||||
if fused_experts_input.routing.log2phy is not None:
|
||||
routed_topk_ids = fused_experts_input.routing.log2phy[routed_topk_ids]
|
||||
|
||||
token_dispatch_input = build_token_dispatch_input(
|
||||
fused_experts_input=fused_experts_input,
|
||||
topk_ids=routed_topk_ids,
|
||||
)
|
||||
token_dispatch_output = self.token_dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
|
||||
|
||||
mlp_compute_input = build_mlp_compute_input(
|
||||
fused_experts_input=fused_experts_input,
|
||||
token_dispatch_output=token_dispatch_output,
|
||||
use_fusion_ops=self.use_fusion_ops,
|
||||
)
|
||||
|
||||
dispatch_kwargs = {
|
||||
"hidden_states": hidden_states,
|
||||
"topk_weights": topk_weights,
|
||||
"topk_ids": topk_ids,
|
||||
"expert_map": expert_map,
|
||||
"global_redundant_expert_num": self.moe_config.global_redundant_expert_num,
|
||||
"mc2_mask": mc2_mask,
|
||||
"apply_router_weight_on_input": apply_router_weight_on_input,
|
||||
"dynamic_eplb": dynamic_eplb,
|
||||
"pertoken_scale": pertoken_scale,
|
||||
}
|
||||
|
||||
if isinstance(self.token_dispatcher, TokenDispatcherWithMC2):
|
||||
dispatch_kwargs["with_quant"] = dispatch_with_quant
|
||||
dispatch_kwargs["comm_quant_mode"] = kwargs.get("comm_quant_mode")
|
||||
dispatch_kwargs["y_dtype"] = act_quant_type if use_mxfp_quant else None
|
||||
dispatch_kwargs["use_mxfp_quant"] = use_mxfp_quant
|
||||
else:
|
||||
dispatch_kwargs["with_quant"] = use_int8_w8a8 or use_int4_w4a8
|
||||
|
||||
dispatch_results = self.token_dispatcher.token_dispatch(**dispatch_kwargs)
|
||||
|
||||
mlp_output = unified_apply_mlp(
|
||||
hidden_states=dispatch_results.hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
w1_bias=w1_bias,
|
||||
w2_bias=w2_bias,
|
||||
activation=activation,
|
||||
group_list=dispatch_results.group_list,
|
||||
dynamic_scale=dispatch_results.dynamic_scale,
|
||||
group_list_type=dispatch_results.group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
w1_offset=w1_offset,
|
||||
w2_offset=w2_offset,
|
||||
topk_scales=dispatch_results.topk_scales,
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16 or use_mxfp_quant,
|
||||
fusion=(use_int8_w8a8 or use_mxfp_quant) and self.use_fusion_ops,
|
||||
need_trans=need_trans,
|
||||
dynamic_eplb=dynamic_eplb,
|
||||
use_mxfp_quant=use_mxfp_quant,
|
||||
act_quant_type=act_quant_type,
|
||||
weight_quant_type=weight_quant_type,
|
||||
scale_type=scale_type,
|
||||
per_token_scale_type=per_token_scale_type,
|
||||
round_mode=round_mode,
|
||||
use_bf16=(hidden_states.dtype == torch.bfloat16),
|
||||
rollback_quant_config=kwargs.get("rollback_quant_config"),
|
||||
)
|
||||
mlp_output = self._apply_mlp(mlp_compute_input)
|
||||
|
||||
before_combine_evt = torch.npu.current_stream().record_event()
|
||||
combine_results = self.token_dispatcher.token_combine(
|
||||
hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
|
||||
routed_out = self.token_dispatcher.token_combine(
|
||||
hidden_states=mlp_output,
|
||||
combine_metadata=token_dispatch_output.combine_metadata,
|
||||
)
|
||||
|
||||
return FusedExpertsResult(
|
||||
routed_out=combine_results.routed_out,
|
||||
routed_out=routed_out,
|
||||
before_dispatch_evt=before_dispatch_evt,
|
||||
before_combine_evt=before_combine_evt,
|
||||
group_list_type=dispatch_results.group_list_type,
|
||||
expert_tokens=dispatch_results.group_list,
|
||||
group_list_type=token_dispatch_output.group_list_type,
|
||||
expert_tokens=token_dispatch_output.group_list,
|
||||
)
|
||||
|
||||
def _apply_mlp(self, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
|
||||
return unified_apply_mlp(mlp_compute_input=mlp_compute_input)
|
||||
|
||||
@abstractmethod
|
||||
def _get_token_dispatcher(self) -> MoETokenDispatcher:
|
||||
raise NotImplementedError("_get_token_dispatcher function not implemented.")
|
||||
@@ -317,54 +262,32 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
|
||||
def fused_experts(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor | list[torch.Tensor],
|
||||
w2: torch.Tensor | list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
w1_bias: torch.Tensor = None,
|
||||
w2_bias: torch.Tensor = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: list[torch.Tensor] | None = None,
|
||||
w2_scale: list[torch.Tensor] | None = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: torch.Tensor | None = None,
|
||||
w2_offset: torch.Tensor | None = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
**kwargs,
|
||||
fused_experts_input: MoEFusedExpertsInput,
|
||||
):
|
||||
assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
||||
assert not (fused_experts_input.weights.w1_scale is None or fused_experts_input.weights.w2_scale is None), (
|
||||
"w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
||||
)
|
||||
|
||||
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), (
|
||||
"token_dispatcher must be an instance of TokenDispatcherWithMC2."
|
||||
)
|
||||
|
||||
# Apply log2phy if needed
|
||||
if log2phy is not None:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
topk_ids = fused_experts_input.topk_ids
|
||||
if fused_experts_input.routing.log2phy is not None:
|
||||
topk_ids = fused_experts_input.routing.log2phy[topk_ids]
|
||||
|
||||
expert_tokens = None
|
||||
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
|
||||
out = torch.empty_like(hidden_states)
|
||||
out = torch.empty_like(fused_experts_input.hidden_states)
|
||||
torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore
|
||||
x=hidden_states,
|
||||
weight1=w1,
|
||||
weight2=w2,
|
||||
x=fused_experts_input.hidden_states,
|
||||
weight1=fused_experts_input.weights.w1,
|
||||
weight2=fused_experts_input.weights.w2,
|
||||
expert_idx=topk_ids,
|
||||
scale1=w1_scale,
|
||||
scale2=w2_scale,
|
||||
probs=topk_weights.to(torch.float32),
|
||||
scale1=fused_experts_input.weights.w1_scale,
|
||||
scale2=fused_experts_input.weights.w2_scale,
|
||||
probs=fused_experts_input.topk_weights.to(torch.float32),
|
||||
group=self.token_dispatcher.moe_all_to_all_group_name,
|
||||
max_output_size=65536,
|
||||
out=out,
|
||||
@@ -372,16 +295,16 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
)
|
||||
expert_tokens = self.expert_token_nums
|
||||
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
|
||||
assert expert_map is not None, "expert_map cannot be None."
|
||||
assert fused_experts_input.routing.expert_map is not None, "expert_map cannot be None."
|
||||
out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
|
||||
x=hidden_states,
|
||||
x=fused_experts_input.hidden_states,
|
||||
expert_ids=topk_ids,
|
||||
gmm1_permuted_weight=w1,
|
||||
gmm1_permuted_weight_scale=w1_scale,
|
||||
gmm2_weight=w2,
|
||||
gmm2_weight_scale=w2_scale,
|
||||
gmm1_permuted_weight=fused_experts_input.weights.w1,
|
||||
gmm1_permuted_weight_scale=fused_experts_input.weights.w1_scale,
|
||||
gmm2_weight=fused_experts_input.weights.w2,
|
||||
gmm2_weight_scale=fused_experts_input.weights.w2_scale,
|
||||
expert_smooth_scales=None,
|
||||
expert_scales=topk_weights.to(torch.float32),
|
||||
expert_scales=fused_experts_input.topk_weights.to(torch.float32),
|
||||
group_ep=self.token_dispatcher.moe_all_to_all_group_name,
|
||||
ep_rank_size=self.token_dispatcher.ep_world_size,
|
||||
ep_rank_id=self.token_dispatcher.ep_rank_id,
|
||||
|
||||
Reference in New Issue
Block a user