[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)

### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business
`**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring
behavior indirectly.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-03-20 23:23:57 +08:00
committed by GitHub
parent c860535246
commit 88d03a783f
33 changed files with 2146 additions and 947 deletions

View File

@@ -28,6 +28,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
@@ -343,7 +344,10 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
enable_force_load_balance: bool = False,
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
**kwargs,
pertoken_scale: torch.Tensor | None = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor:
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
"Number of global experts mismatch (excluding redundancy)"
@@ -377,20 +381,26 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
moe_comm_method = _EXTRA_CTX.moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=[layer.w13_weight],
w2=[layer.w2_weight],
w1_scale=[layer.w13_weight_scale],
w2_scale=[layer.w2_weight_scale],
w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None,
w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int4_w4a8=True,
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask"),
fused_experts_input=build_fused_experts_input(
hidden_states=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1=[layer.w13_weight],
w2=[layer.w2_weight],
quant_type=self.quant_type,
dynamic_eplb=self.dynamic_eplb,
expert_map=expert_map,
global_redundant_expert_num=global_redundant_expert_num,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
log2phy=log2phy,
pertoken_scale=pertoken_scale,
activation=activation,
w1_scale=[layer.w13_weight_scale],
w2_scale=[layer.w2_weight_scale],
w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None,
w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None,
)
)
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):