[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?

Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- The main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
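To make the pattern concrete before the hunks below: a minimal sketch of what the typed request object plus its builder could look like. The field names mirror the kwargs visible in this diff; the `FusedExpertsInput` class name, the `Any` types, and the comments are assumptions, not the PR's actual definitions.

```python
# Hypothetical sketch of the request object behind build_fused_experts_input();
# field names mirror the kwargs in the diff, class name and types are assumed.
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class FusedExpertsInput:
    hidden_states: Any             # activations entering the MoE block
    topk_weights: Any              # router weights per selected expert
    topk_ids: Any                  # router expert indices
    w1: Any                        # gate/up projection weights (w13)
    w2: Any                        # down projection weights
    quant_type: Any                # QuantType.NONE on the unquantized path
    dynamic_eplb: bool             # expert-load-balancing toggle
    expert_map: Optional[Any]      # local <-> global expert id mapping
    apply_router_weight_on_input: bool


def build_fused_experts_input(**fields: Any) -> FusedExpertsInput:
    # One construction/validation point instead of kwargs.get(...)
    # lookups scattered across the MoE path.
    return FusedExpertsInput(**fields)
```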
```diff
@@ -25,7 +25,8 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
 from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
 from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods
-from vllm_ascend.quantization.methods.base import QuantType
+from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
+from vllm_ascend.quantization.quant_type import QuantType
 
 from .experts_selector import select_experts
 from .moe_comm_method import AllGatherCommImpl310
```
```diff
@@ -93,13 +94,17 @@ class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod):
 
         moe_comm_method = _EXTRA_CTX.moe_comm_method
         final_hidden_states = moe_comm_method.fused_experts(
-            hidden_states=x,
-            w1=layer.w13_weight,
-            w2=layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
-            expert_map=expert_map,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            fused_experts_input=build_fused_experts_input(
+                hidden_states=x,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                quant_type=QuantType.NONE,
+                dynamic_eplb=False,
+                expert_map=expert_map,
+                apply_router_weight_on_input=apply_router_weight_on_input,
+            ),
         )
         if zero_expert_num > 0 and zero_expert_type is not None:
             final_hidden_states += zero_expert_result
```
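The call site above now hands `fused_experts` a single request object. On the consumer side, the point of the refactor is that a comm method can read fields explicitly instead of probing `**kwargs`. A minimal sketch with illustrative stand-ins; the `DemoCommMethod` class, its `_run_*` helpers, and the `W8A8` enum member are assumptions (only `QuantType.NONE` is visible in this diff):

```python
from enum import Enum


class QuantType(Enum):
    # Stand-in for vllm_ascend.quantization.quant_type.QuantType;
    # only NONE appears in this diff, W8A8 is an assumed member.
    NONE = "none"
    W8A8 = "w8a8"


class DemoCommMethod:
    """Illustrative comm method; not the vllm-ascend implementation."""

    def fused_experts(self, fused_experts_input):
        req = fused_experts_input
        # Explicit attribute access replaces kwargs.get("topk_ids") probing.
        if req.quant_type is QuantType.NONE:
            return self._run_unquantized(req)
        return self._run_quantized(req)

    def _run_unquantized(self, req):
        # Real code would launch the unquantized grouped-matmul kernels here.
        return req.hidden_states

    def _run_quantized(self, req):
        # Real code would also consume quantization scales from the request.
        return req.hidden_states
```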
```diff
@@ -218,9 +223,13 @@ class AscendFusedMoE310(FusedMoE):
         assert self.quant_method is not None
         assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported."
 
-        hidden_states, router_logits, _, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
+        prepare_output = _EXTRA_CTX.moe_comm_method.prepare(
             hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type
         )
+        hidden_states = prepare_output.hidden_states
+        router_logits = prepare_output.router_logits
+        pertoken_scale = prepare_output.pertoken_scale
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # Matrix multiply.
         fused_experts_results: FusedExpertsResult = self.quant_method.apply(
```
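`prepare()` now returns one output object rather than a positional tuple. The four fields read above are the only ones visible in the diff; a dataclass like the following would match that access pattern (the name `PrepareOutput` and the `Optional` defaults are assumptions):

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass(frozen=True)
class PrepareOutput:
    # The four fields below are exactly the ones the new code reads;
    # the class name and the Optional defaults are assumed.
    hidden_states: Any                  # possibly padded/redistributed activations
    router_logits: Any                  # logits aligned with hidden_states
    pertoken_scale: Optional[Any] = None              # set on quantized paths
    padded_hidden_states_shape: Optional[Any] = None  # lets finalize() un-pad
```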
```diff
@@ -238,12 +247,13 @@ class AscendFusedMoE310(FusedMoE):
             global_num_experts=self.global_num_experts,
             expert_map=self.local_expert_map,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
+            pertoken_scale=pertoken_scale,
         )
 
         routed_out = _EXTRA_CTX.moe_comm_method.finalize(
+            hidden_states=fused_experts_results.routed_out,
             reduce_results=self.reduce_results,
             context_metadata=context_metadata,
             padded_hidden_states_shape=padded_hidden_states_shape,
         )
 
         return routed_out
```
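This is also what makes the description's stage-level UT claim concrete: a test can build the request and assert its fields directly, with no kernel launch. A hedged sketch, assuming the hypothetical `build_fused_experts_input` / `QuantType` shapes sketched earlier are importable; plain lists stand in for tensors:

```python
# Hypothetical stage-level unit test: assert fields on the request object
# directly instead of inferring behavior from fused-kernel output.
def test_build_fused_experts_input_defaults():
    req = build_fused_experts_input(
        hidden_states=[[0.0]],   # list stand-ins for tensors, illustration only
        topk_weights=[[1.0]],
        topk_ids=[[0]],
        w1=[[0.0]],
        w2=[[0.0]],
        quant_type=QuantType.NONE,
        dynamic_eplb=False,
        expert_map=None,
        apply_router_weight_on_input=False,
    )
    # Stage-level assertions on the request fields themselves.
    assert req.quant_type is QuantType.NONE
    assert req.dynamic_eplb is False
    assert req.expert_map is None
```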