[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)

### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business
`**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- Unit tests can assert stage-level fields directly instead of
inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
linfeng-yuan
2026-03-20 23:23:57 +08:00
committed by GitHub
parent c860535246
commit 88d03a783f
33 changed files with 2146 additions and 947 deletions

View File

@@ -31,7 +31,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEPrepareOutput
from vllm_ascend.quantization.quant_type import QuantType
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
@@ -64,7 +65,7 @@ class PrepareAndFinalize(ABC):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type: QuantType = QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
) -> MoEPrepareOutput:
"""
Prepare tensors before MoE computation. May involve:
- Padding to align communication boundaries
@@ -79,16 +80,20 @@ class PrepareAndFinalize(ABC):
quant_type: none, w8a8, w4a8 or mxfp8
Returns:
Tuple of:
MoEPrepareOutput:
- processed hidden_states (may be padded/sliced/broadcasted)
- processed router_logits (may be recomputed or broadcasted)
- optional communication mask (e.g., mc2_mask for sparse ops)
- optional context metadata (e.g., saved split_hidden_states for finalization)
- optional padded hidden state shape for finalization
- optional per-token scale for quantized path
"""
raise NotImplementedError("Prepare not implemented.")
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
self,
hidden_states: torch.Tensor,
reduce_results: bool,
padded_hidden_states_shape: torch.Size | None = None,
) -> torch.Tensor:
"""
Finalize MoE output. May involve:
@@ -130,7 +135,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
) -> MoEPrepareOutput:
"""
Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size.
@@ -140,7 +145,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, None, context_metadata) — no mask used in All2All.
MoEPrepareOutput where `mc2_mask` is None for All2All path.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
@@ -162,12 +167,19 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
context_metadata = {"padded_hidden_states_shape": padded_hidden_states_shape}
return hidden_states, router_logits, None, context_metadata
return MoEPrepareOutput(
hidden_states=hidden_states,
router_logits=router_logits,
mc2_mask=None,
padded_hidden_states_shape=padded_hidden_states_shape,
pertoken_scale=None,
)
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
self,
hidden_states: torch.Tensor,
reduce_results: bool,
padded_hidden_states_shape: torch.Size | None = None,
) -> torch.Tensor:
"""
Finalization steps:
@@ -180,12 +192,11 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
assert context_metadata is not None
assert padded_hidden_states_shape is not None
# Cannot reuse `split_hidden_states` from prepare phase as it
# may share memory with original hidden_states. Since shared
# experts may use the original tensor, reusing it would cause
# in-place modification during all_gather, corrupting the data.
padded_hidden_states_shape = context_metadata["padded_hidden_states_shape"]
gathered_hidden_states = torch.empty(
padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype
)
@@ -227,7 +238,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
) -> MoEPrepareOutput:
"""
Preparation steps:
1. Fetch `mc2_mask` and target padding length from forward context.
@@ -238,7 +249,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, mc2_mask, context_metadata), possibly sliced/padded.
MoEPrepareOutput, possibly sliced/padded.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
@@ -267,11 +278,13 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
context_metadata = {
"padded_hidden_states_shape": padded_hidden_states_shape,
}
return hidden_states, router_logits, mc2_mask, context_metadata
return MoEPrepareOutput(
hidden_states=hidden_states,
router_logits=router_logits,
mc2_mask=mc2_mask,
padded_hidden_states_shape=padded_hidden_states_shape,
pertoken_scale=None,
)
class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
@@ -303,13 +316,13 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
) -> MoEPrepareOutput:
"""
Preparation steps:
AllGather hidden_states and router_logits to form global tensors.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
MoEPrepareOutput with global tensors.
"""
if enable_sp():
return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
@@ -318,7 +331,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
def _prepare_with_ep_group(
self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
) -> MoEPrepareOutput:
pertoken_scale = None
if quant_type == QuantType.W8A8:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
@@ -342,10 +355,13 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
if self.multistream_overlap_gate:
torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream)
if pertoken_scale is not None:
return (hidden_states, pertoken_scale), router_logits, None, None
return hidden_states, router_logits, None, None
return MoEPrepareOutput(
hidden_states=hidden_states,
router_logits=router_logits,
mc2_mask=None,
padded_hidden_states_shape=None,
pertoken_scale=pertoken_scale,
)
def _prepare_with_dp_group(
self,
@@ -354,7 +370,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
) -> MoEPrepareOutput:
"""
Preparation steps:
1. Fetch max token count across DP group from forward context.
@@ -362,7 +378,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
3. All-gather across DP group to form global input tensor.
Returns:
Tuple of (global_hidden_states, global_router_logits, None, None)
MoEPrepareOutput with global tensors.
"""
self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1:
@@ -396,10 +412,19 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
dim=0,
)
return hidden_states, router_logits, None, None
return MoEPrepareOutput(
hidden_states=hidden_states,
router_logits=router_logits,
mc2_mask=None,
padded_hidden_states_shape=None,
pertoken_scale=None,
)
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
self,
hidden_states: torch.Tensor,
reduce_results: bool,
padded_hidden_states_shape: torch.Size | None = None,
) -> torch.Tensor:
"""
Finalization steps: