[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?

Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- The main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
```diff
@@ -255,28 +255,34 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
         enable_force_load_balance: bool = False,
         log2phy: torch.Tensor | None = None,
         global_redundant_expert_num=0,
-        **kwargs,
+        pertoken_scale: torch.Tensor | None = None,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        mc2_mask: torch.Tensor | None = None,
     ) -> torch.Tensor:
         return self.quant_method.apply(
-            layer,
-            x,
-            router_logits,
-            top_k,
-            renormalize,
-            use_grouped_topk,
-            global_num_experts,
-            expert_map,
-            topk_group,
-            num_expert_group,
-            custom_routing_function,
-            scoring_func,
-            routed_scaling_factor,
-            e_score_correction_bias,
-            is_prefill,
-            enable_force_load_balance,
-            log2phy,
-            global_redundant_expert_num,
-            **kwargs,
+            layer=layer,
+            x=x,
+            router_logits=router_logits,
+            top_k=top_k,
+            renormalize=renormalize,
+            use_grouped_topk=use_grouped_topk,
+            global_num_experts=global_num_experts,
+            expert_map=expert_map,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            routed_scaling_factor=routed_scaling_factor,
+            e_score_correction_bias=e_score_correction_bias,
+            is_prefill=is_prefill,
+            enable_force_load_balance=enable_force_load_balance,
+            log2phy=log2phy,
+            global_redundant_expert_num=global_redundant_expert_num,
+            pertoken_scale=pertoken_scale,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            mc2_mask=mc2_mask,
         )

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
```
Reference in New Issue
Block a user