[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?

Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
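For orientation, the diff below replaces the prepare-stage tuple with a `MoEPrepareOutput` request object. A minimal sketch of such a container, assuming a plain dataclass: the real definition lives in `vllm_ascend/ops/fused_moe/moe_runtime_args.py` and is not shown in this diff, so defaults and extra fields may differ.

```python
# Illustrative sketch only: the actual MoEPrepareOutput is defined in
# vllm_ascend/ops/fused_moe/moe_runtime_args.py and may differ in detail.
from dataclasses import dataclass

import torch


@dataclass
class MoEPrepareOutput:
    """Typed result of the prepare stage, replacing the old 4-tuple return."""

    hidden_states: torch.Tensor                 # possibly padded/sliced/quantized activations
    router_logits: torch.Tensor                 # possibly recomputed or broadcast logits
    mc2_mask: torch.Tensor | None = None        # communication mask for the MC2 dispatch path
    padded_hidden_states_shape: torch.Size | None = None  # needed by finalize to gather
    pertoken_scale: torch.Tensor | None = None  # per-token quantization scale (e.g. W8A8)
```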
```diff
@@ -31,7 +31,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
-from vllm_ascend.quantization.methods.base import QuantType
+from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEPrepareOutput
+from vllm_ascend.quantization.quant_type import QuantType
 from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable


```
```diff
@@ -64,7 +65,7 @@ class PrepareAndFinalize(ABC):
         enable_shared_expert_dp: bool = False,
         replace_allreduce: bool = False,
         quant_type: QuantType = QuantType.NONE,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    ) -> MoEPrepareOutput:
         """
         Prepare tensors before MoE computation. May involve:
         - Padding to align communication boundaries
@@ -79,16 +80,20 @@ class PrepareAndFinalize(ABC):
             quant_type: none, w8a8, w4a8 or mxfp8

         Returns:
-            Tuple of:
+            MoEPrepareOutput:
             - processed hidden_states (may be padded/sliced/broadcasted)
             - processed router_logits (may be recomputed or broadcasted)
             - optional communication mask (e.g., mc2_mask for sparse ops)
-            - optional context metadata (e.g., saved split_hidden_states for finalization)
+            - optional padded hidden state shape for finalization
+            - optional per-token scale for quantized path
         """
         raise NotImplementedError("Prepare not implemented.")

     def finalize(
-        self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
+        self,
+        hidden_states: torch.Tensor,
+        reduce_results: bool,
+        padded_hidden_states_shape: torch.Size | None = None,
     ) -> torch.Tensor:
         """
         Finalize MoE output. May involve:
```
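To make the new stage boundary concrete, here is a rough sketch of how a caller could thread the prepare result into the expert MLP and finalize stages under this interface. `comm_impl` and `run_experts` are placeholders, not names from this PR.

```python
# Hypothetical call site, for illustration only. `comm_impl` stands in for a
# PrepareAndFinalize backend and `run_experts` for the expert MLP stage.
def moe_forward_sketch(comm_impl, run_experts, hidden_states, router_logits, quant_type):
    # prepare() now returns a typed MoEPrepareOutput instead of a positional
    # 4-tuple plus an untyped context_metadata dict.
    prep = comm_impl.prepare(hidden_states, router_logits, quant_type=quant_type)

    # The expert MLP stage reads explicit fields rather than kwargs.get(...) lookups.
    expert_out = run_experts(prep.hidden_states, prep.router_logits, prep.mc2_mask)

    # finalize() receives the one piece of prepare-stage state it needs, explicitly.
    return comm_impl.finalize(
        expert_out,
        reduce_results=True,
        padded_hidden_states_shape=prep.padded_hidden_states_shape,
    )
```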
```diff
@@ -130,7 +135,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
         enable_shared_expert_dp: bool = False,
         replace_allreduce: bool = False,
         quant_type=QuantType.NONE,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    ) -> MoEPrepareOutput:
         """
         Preparation steps:
         1. Pad hidden_states and router_logits to next multiple of TP size.
@@ -140,7 +145,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
         Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.

         Returns:
-            Tuple of (hidden_states, router_logits, None, context_metadata) — no mask used in All2All.
+            MoEPrepareOutput where `mc2_mask` is None for All2All path.
         """
         self.replace_allreduce = replace_allreduce
         self.enable_shared_expert_dp = enable_shared_expert_dp
@@ -162,12 +167,19 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
             hidden_states = split_hidden_states[self.tp_rank]
             router_logits = split_router_logits[self.tp_rank]

-            context_metadata = {"padded_hidden_states_shape": padded_hidden_states_shape}
-
-        return hidden_states, router_logits, None, context_metadata
+        return MoEPrepareOutput(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            mc2_mask=None,
+            padded_hidden_states_shape=padded_hidden_states_shape,
+            pertoken_scale=None,
+        )

     def finalize(
-        self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
+        self,
+        hidden_states: torch.Tensor,
+        reduce_results: bool,
+        padded_hidden_states_shape: torch.Size | None = None,
     ) -> torch.Tensor:
         """
         Finalization steps:
@@ -180,12 +192,11 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):

         if not (self.enable_shared_expert_dp or self.replace_allreduce):
             if self.tp_size > 1:
-                assert context_metadata is not None
+                assert padded_hidden_states_shape is not None
                 # Cannot reuse `split_hidden_states` from prepare phase as it
                 # may share memory with original hidden_states. Since shared
                 # experts may use the original tensor, reusing it would cause
                 # in-place modification during all_gather, corrupting the data.
-                padded_hidden_states_shape = context_metadata["padded_hidden_states_shape"]
                 gathered_hidden_states = torch.empty(
                     padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype
                 )
```
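The finalize change keeps the existing gather-into-a-fresh-buffer behavior, only keyed off the explicit `padded_hidden_states_shape` argument instead of a `context_metadata` dict. A simplified sketch, assuming a generic `torch.distributed` group rather than the project's own comm utilities:

```python
import torch
import torch.distributed as dist


def all2all_finalize_gather_sketch(hidden_states, padded_hidden_states_shape, tp_group):
    """Sketch of the gather step: allocate a fresh buffer of the padded shape so
    the original hidden_states (possibly still referenced by shared experts) is
    never modified in place during the gather."""
    gathered = torch.empty(
        padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype
    )
    dist.all_gather_into_tensor(gathered, hidden_states, group=tp_group)
    return gathered
```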
```diff
@@ -227,7 +238,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
         enable_shared_expert_dp: bool = False,
         replace_allreduce: bool = False,
         quant_type=QuantType.NONE,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    ) -> MoEPrepareOutput:
         """
         Preparation steps:
         1. Fetch `mc2_mask` and target padding length from forward context.
@@ -238,7 +249,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
         Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.

         Returns:
-            Tuple of (hidden_states, router_logits, mc2_mask, context_metadata), possibly sliced/padded.
+            MoEPrepareOutput, possibly sliced/padded.
         """
         self.replace_allreduce = replace_allreduce
         self.enable_shared_expert_dp = enable_shared_expert_dp
@@ -267,11 +278,13 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
             hidden_states = split_hidden_states[self.tp_rank]
             router_logits = split_router_logits[self.tp_rank]

-            context_metadata = {
-                "padded_hidden_states_shape": padded_hidden_states_shape,
-            }
-
-        return hidden_states, router_logits, mc2_mask, context_metadata
+        return MoEPrepareOutput(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            mc2_mask=mc2_mask,
+            padded_hidden_states_shape=padded_hidden_states_shape,
+            pertoken_scale=None,
+        )


 class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
```
```diff
@@ -303,13 +316,13 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         enable_shared_expert_dp: bool = False,
         replace_allreduce: bool = False,
         quant_type=QuantType.NONE,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    ) -> MoEPrepareOutput:
         """
         Preparation steps:
         AllGather hidden_states and router_logits to form global tensors.

         Returns:
-            Tuple of (global_hidden_states, global_router_logits, None)
+            MoEPrepareOutput with global tensors.
         """
         if enable_sp():
             return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
@@ -318,7 +331,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):

     def _prepare_with_ep_group(
         self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    ) -> MoEPrepareOutput:
         pertoken_scale = None
         if quant_type == QuantType.W8A8:
             hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
```
```diff
@@ -342,10 +355,13 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         if self.multistream_overlap_gate:
             torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream)

-        if pertoken_scale is not None:
-            return (hidden_states, pertoken_scale), router_logits, None, None
-
-        return hidden_states, router_logits, None, None
+        return MoEPrepareOutput(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            mc2_mask=None,
+            padded_hidden_states_shape=None,
+            pertoken_scale=pertoken_scale,
+        )

     def _prepare_with_dp_group(
         self,
```
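With `pertoken_scale` promoted to a first-class field, the quantized path no longer needs to return a nested `(hidden_states, pertoken_scale)` tuple. A hypothetical consumer (not part of this PR) might branch like this:

```python
def select_mlp_inputs(prep):
    """Hypothetical consumer: pick the MLP inputs for the quantized vs.
    unquantized path from a MoEPrepareOutput-style object."""
    if prep.pertoken_scale is not None:
        # W8A8-style path: activations were dynamically quantized in prepare(),
        # and the per-token scale travels as an explicit field instead of being
        # smuggled inside a nested (hidden_states, pertoken_scale) tuple.
        return prep.hidden_states, prep.pertoken_scale
    # Unquantized path: no scale to carry.
    return prep.hidden_states, None
```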
```diff
@@ -354,7 +370,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         enable_shared_expert_dp: bool = False,
         replace_allreduce: bool = False,
         quant_type=QuantType.NONE,
-    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
+    ) -> MoEPrepareOutput:
         """
         Preparation steps:
         1. Fetch max token count across DP group from forward context.
@@ -362,7 +378,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
         3. All-gather across DP group to form global input tensor.

         Returns:
-            Tuple of (global_hidden_states, global_router_logits, None, None)
+            MoEPrepareOutput with global tensors.
         """
         self.enable_shared_expert_dp = enable_shared_expert_dp
         if self.moe_config.dp_size > 1:
@@ -396,10 +412,19 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
                 dim=0,
             )

-        return hidden_states, router_logits, None, None
+        return MoEPrepareOutput(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            mc2_mask=None,
+            padded_hidden_states_shape=None,
+            pertoken_scale=None,
+        )

     def finalize(
-        self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
+        self,
+        hidden_states: torch.Tensor,
+        reduce_results: bool,
+        padded_hidden_states_shape: torch.Size | None = None,
     ) -> torch.Tensor:
         """
         Finalization steps:
```
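Finally, the commit message's point about unit tests asserting stage-level fields directly could look roughly like the following; the fixture names are made up for illustration.

```python
# Hypothetical unit-test sketch (fixture names are invented): with a typed
# MoEPrepareOutput, a test can assert stage-level fields directly instead of
# unpacking positional tuples or probing a context_metadata dict.
def test_all2all_prepare_returns_typed_output(all2all_impl, hidden_states, router_logits):
    prep = all2all_impl.prepare(hidden_states, router_logits)
    assert prep.mc2_mask is None                        # All2All path uses no MC2 mask
    assert prep.pertoken_scale is None                  # unquantized path
    assert prep.padded_hidden_states_shape is not None  # finalize needs this to gather
```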