[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?

Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
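To make the "typed request object" idea concrete, here is a minimal sketch of the dispatch-stage request, reconstructed only from the attributes the diff below reads (`hidden_states`, `topk_weights`, `topk_ids`, `routing.expert_map`, `routing.apply_router_weight_on_input`). The real definitions live in `vllm_ascend/ops/fused_moe/moe_runtime_args.py` and may carry additional fields; `MoERoutingInfo` is a hypothetical name for the nested `routing` object.

```python
# Sketch only: field names come from the attribute accesses in the diff,
# not from the actual moe_runtime_args.py definitions.
from dataclasses import dataclass

import torch


@dataclass
class MoERoutingInfo:
    """Hypothetical container for the nested `routing` attribute."""
    expert_map: torch.Tensor | None = None
    apply_router_weight_on_input: bool = False


@dataclass
class MoETokenDispatchInput:
    """Dispatch-stage request object replacing loose business kwargs."""
    hidden_states: torch.Tensor
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
    routing: MoERoutingInfo
```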
```diff
@@ -25,26 +25,27 @@
 import torch
 from vllm.distributed.parallel_state import get_ep_group

-from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather, TokenDispatchResult
+from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEAllGatherCombineMetadata, MoETokenDispatchInput
+from vllm_ascend.ops.fused_moe.token_dispatcher import MoETokenDispatchOutput, TokenDispatcherWithAllGather


 class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):

     def __init__(self, **kwargs):
         super().__init__(**kwargs)

-    def token_dispatch(  # type: ignore[override]
+    def token_dispatch(
         self,
-        hidden_states: torch.Tensor,
-        topk_weights: torch.Tensor,
-        topk_ids: torch.Tensor,
-        expert_map: torch.Tensor | None = None,
-        apply_router_weight_on_input: bool = False,
+        token_dispatch_input: MoETokenDispatchInput,
     ):
-        self.original_shape = hidden_states.shape
+        hidden_states = token_dispatch_input.hidden_states
+        topk_weights = token_dispatch_input.topk_weights
+        topk_ids = token_dispatch_input.topk_ids
+        expert_map = token_dispatch_input.routing.expert_map
+        apply_router_weight_on_input = token_dispatch_input.routing.apply_router_weight_on_input
+        restore_shape = hidden_states.shape

         num_tokens = hidden_states.shape[:-1].numel()
-        self.apply_router_weight_on_input = apply_router_weight_on_input
-        if self.apply_router_weight_on_input:
+        if apply_router_weight_on_input:
             assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
             _, topk = topk_weights.shape
             assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
@@ -66,13 +67,16 @@ class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
         )
         expert_tokens = expert_tokens.to(torch.int64)
         group_list_type = 1  # `count` mode
-        context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}

-        return TokenDispatchResult(
+        return MoETokenDispatchOutput(
             hidden_states=sorted_hidden_states,
             group_list=expert_tokens,
             group_list_type=group_list_type,
-            context_metadata=context_metadata,
+            combine_metadata=MoEAllGatherCombineMetadata(
+                topk_weights=topk_weights,
+                expanded_row_idx=expanded_row_idx,
+                restore_shape=restore_shape,
+            ),
         )

     def moe_init_routing(self, x, expert_idx, active_num, active_expert_range):
```
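The return-value side of the second hunk implies output containers roughly shaped like the sketch below. This is again inferred only from the constructor keyword arguments in the diff; the real `MoETokenDispatchOutput` and `MoEAllGatherCombineMetadata` definitions may differ, and other dispatchers likely attach different combine metadata.

```python
# Sketch only: inferred from the keyword arguments used in the diff.
from dataclasses import dataclass

import torch


@dataclass
class MoEAllGatherCombineMetadata:
    """State the all-gather combine stage needs to un-permute expert outputs."""
    topk_weights: torch.Tensor
    expanded_row_idx: torch.Tensor
    restore_shape: torch.Size


@dataclass
class MoETokenDispatchOutput:
    """Typed dispatch result replacing the previous TokenDispatchResult."""
    hidden_states: torch.Tensor
    group_list: torch.Tensor
    group_list_type: int
    combine_metadata: MoEAllGatherCombineMetadata
```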
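For orientation, the call-site change implied by this refactor looks roughly like the before/after below. It is a comparison sketch rather than a runnable script: the dispatcher constructor arguments, tensor shapes, and import paths are placeholders, the "before" call only matches the pre-refactor signature, and `MoERoutingInfo` is the hypothetical helper from the first sketch.

```python
import torch

# Imports of TokenDispatcherWithAllGather310 / MoETokenDispatchInput omitted;
# shapes and constructor kwargs below are illustrative placeholders.
hidden_states = torch.randn(4, 128)              # (num_tokens, hidden_size)
topk_weights = torch.ones(4, 1)                  # topk = 1
topk_ids = torch.zeros(4, 1, dtype=torch.int32)
dispatcher = TokenDispatcherWithAllGather310(top_k=1, num_experts=8)

# Before: business arguments arrive as loose positional/keyword arguments.
result = dispatcher.token_dispatch(
    hidden_states,
    topk_weights,
    topk_ids,
    expert_map=None,
    apply_router_weight_on_input=False,
)

# After: one typed request object owns the dispatch-stage inputs, and tests
# can assert stage-level fields on the typed output directly.
output = dispatcher.token_dispatch(
    MoETokenDispatchInput(
        hidden_states=hidden_states,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        routing=MoERoutingInfo(expert_map=None,
                               apply_router_weight_on_input=False),
    )
)
assert output.group_list_type == 1  # `count` mode, as set by the dispatcher
```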