[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)

### What this PR does / why we need it?
Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business
`**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- Main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path (the new call shape is sketched below).
- UTs can assert stage-level fields directly instead of inferring
behavior indirectly.
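
The new call shape on the main unquantized path looks roughly like the sketch below (lifted from the diff in this PR; surrounding names such as `x`, `layer`, `topk_weights`, `expert_map`, and `moe_comm_method` are assumed to exist at the call site):

```python
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.quant_type import QuantType

# Build one typed request object instead of passing scattered business kwargs.
fused_experts_input = build_fused_experts_input(
    hidden_states=x,
    topk_weights=topk_weights,
    topk_ids=topk_ids,
    w1=layer.w13_weight,
    w2=layer.w2_weight,
    quant_type=QuantType.NONE,
    dynamic_eplb=False,
    expert_map=expert_map,
    apply_router_weight_on_input=apply_router_weight_on_input,
)
# The comm method's main path is request-only.
final_hidden_states = moe_comm_method.fused_experts(
    fused_experts_input=fused_experts_input,
)
```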

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
Author: linfeng-yuan
Date: 2026-03-20 23:23:57 +08:00 (committed by GitHub)
Commit: 88d03a783f (parent c860535246)
33 changed files with 2146 additions and 947 deletions

View File

@@ -25,7 +25,8 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods
from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.quant_type import QuantType
from .experts_selector import select_experts
from .moe_comm_method import AllGatherCommImpl310
@@ -93,13 +94,17 @@ class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod):
moe_comm_method = _EXTRA_CTX.moe_comm_method
final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
fused_experts_input=build_fused_experts_input(
hidden_states=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1=layer.w13_weight,
w2=layer.w2_weight,
quant_type=QuantType.NONE,
dynamic_eplb=False,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
),
)
if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result
@@ -218,9 +223,13 @@ class AscendFusedMoE310(FusedMoE):
assert self.quant_method is not None
assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported."
hidden_states, router_logits, _, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
prepare_output = _EXTRA_CTX.moe_comm_method.prepare(
hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type
)
hidden_states = prepare_output.hidden_states
router_logits = prepare_output.router_logits
pertoken_scale = prepare_output.pertoken_scale
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
# Matrix multiply.
fused_experts_results: FusedExpertsResult = self.quant_method.apply(
@@ -238,12 +247,13 @@ class AscendFusedMoE310(FusedMoE):
global_num_experts=self.global_num_experts,
expert_map=self.local_expert_map,
apply_router_weight_on_input=self.apply_router_weight_on_input,
pertoken_scale=pertoken_scale,
)
routed_out = _EXTRA_CTX.moe_comm_method.finalize(
hidden_states=fused_experts_results.routed_out,
reduce_results=self.reduce_results,
context_metadata=context_metadata,
padded_hidden_states_shape=padded_hidden_states_shape,
)
return routed_out

View File

@@ -17,8 +17,8 @@ from __future__ import annotations
import torch
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
from .moe_mlp import unified_apply_mlp
from .token_dispatcher import TokenDispatcherWithAllGather310
@@ -35,52 +35,12 @@ class AllGatherCommImpl310(AllGatherCommImpl):
to handle the token-to-expert mapping and communication efficiently.
"""
def fused_experts( # type: ignore[override]
self,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
use_int8_w8a8: bool = False,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
) -> FusedExpertsResult:
# This method is overridden to use the 310p-specific unified_apply_mlp
# which provides optimized MLP computation for the 310p platform
moe_comm_method = _EXTRA_CTX.moe_comm_method
assert moe_comm_method is not None, "Missing communication context"
def __init__(self, moe_config):
super().__init__(moe_config)
self.use_fusion_ops = False
dispatch_results = self.token_dispatcher.token_dispatch(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
)
mlp_output = unified_apply_mlp(
hidden_states=dispatch_results.hidden_states,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
group_list=dispatch_results.group_list,
group_list_type=dispatch_results.group_list_type,
with_quant=use_int8_w8a8,
)
combine_results = self.token_dispatcher.token_combine(
hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
)
return FusedExpertsResult(
routed_out=combine_results.routed_out,
group_list_type=dispatch_results.group_list_type,
expert_tokens=dispatch_results.group_list,
)
def _apply_mlp(self, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
return unified_apply_mlp(mlp_compute_input=mlp_compute_input)
def _get_token_dispatcher(self):
return TokenDispatcherWithAllGather310(

View File

@@ -18,6 +18,8 @@
import torch
import torch_npu
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
def quant_apply_mlp(
hidden_states: torch.Tensor,
@@ -66,17 +68,20 @@ def unquant_apply_mlp(
return hidden_states
def unified_apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
w1_scale: torch.Tensor | None = None,
w2_scale: torch.Tensor | None = None,
group_list_type: int = 1,
with_quant: bool = False,
) -> torch.Tensor:
if with_quant:
def unified_apply_mlp(*, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
hidden_states = mlp_compute_input.hidden_states
w1 = mlp_compute_input.weights.w1
w2 = mlp_compute_input.weights.w2
w1_scale = mlp_compute_input.weights.w1_scale
w2_scale = mlp_compute_input.weights.w2_scale
group_list = mlp_compute_input.group_list
group_list_type = mlp_compute_input.group_list_type
assert isinstance(w1, torch.Tensor)
assert isinstance(w2, torch.Tensor)
if mlp_compute_input.quant.is_quant:
assert isinstance(w1_scale, torch.Tensor)
assert isinstance(w2_scale, torch.Tensor)
assert w1_scale is not None and w2_scale is not None
return quant_apply_mlp(
hidden_states=hidden_states,
@@ -87,7 +92,11 @@ def unified_apply_mlp(
group_list=group_list,
group_list_type=group_list_type,
)
else:
return unquant_apply_mlp(
hidden_states=hidden_states, w1=w1, w2=w2, group_list=group_list, group_list_type=group_list_type
)
return unquant_apply_mlp(
hidden_states=hidden_states,
w1=w1,
w2=w2,
group_list=group_list,
group_list_type=group_list_type,
)
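
With the keyword-only entry point above, the 310p MLP consumes a single `MoEMlpComputeInput`. A minimal sketch of the unquantized branch; the import path of the 310p `moe_mlp` module is an assumption (the diff only imports it relatively as `.moe_mlp`), and `sorted_hidden_states`, `expert_tokens`, and `layer` are placeholders for the dispatch outputs and layer weights:

```python
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
    MoEMlpComputeInput,
    MoEQuantParams,
    MoEWeights,
)
# Assumed module path; the diff imports this relatively inside the 310p package.
from vllm_ascend._310p.fused_moe.moe_mlp import unified_apply_mlp

mlp_input = MoEMlpComputeInput(
    hidden_states=sorted_hidden_states,   # permuted tokens from token_dispatch
    group_list=expert_tokens,             # per-expert token counts (int64)
    group_list_type=1,                    # `count` mode
    dynamic_scale=None,
    topk_scales=None,
    weights=MoEWeights(w1=layer.w13_weight, w2=layer.w2_weight),
    quant=MoEQuantParams(),               # QuantType.NONE -> unquantized branch
    fusion=False,
)
mlp_output = unified_apply_mlp(mlp_compute_input=mlp_input)
```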

View File

@@ -25,26 +25,27 @@
import torch
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather, TokenDispatchResult
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEAllGatherCombineMetadata, MoETokenDispatchInput
from vllm_ascend.ops.fused_moe.token_dispatcher import MoETokenDispatchOutput, TokenDispatcherWithAllGather
class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
def __init__(self, **kwargs):
super().__init__(**kwargs)
def token_dispatch( # type: ignore[override]
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
token_dispatch_input: MoETokenDispatchInput,
):
self.original_shape = hidden_states.shape
hidden_states = token_dispatch_input.hidden_states
topk_weights = token_dispatch_input.topk_weights
topk_ids = token_dispatch_input.topk_ids
expert_map = token_dispatch_input.routing.expert_map
apply_router_weight_on_input = token_dispatch_input.routing.apply_router_weight_on_input
restore_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel()
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
if apply_router_weight_on_input:
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
@@ -66,13 +67,16 @@ class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather):
)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode
context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
return TokenDispatchResult(
return MoETokenDispatchOutput(
hidden_states=sorted_hidden_states,
group_list=expert_tokens,
group_list_type=group_list_type,
context_metadata=context_metadata,
combine_metadata=MoEAllGatherCombineMetadata(
topk_weights=topk_weights,
expanded_row_idx=expanded_row_idx,
restore_shape=restore_shape,
),
)
def moe_init_routing(self, x, expert_idx, active_num, active_expert_range):

View File

@@ -25,6 +25,7 @@ from vllm.distributed import get_ep_group
from vllm_ascend._310p.fused_moe.experts_selector import select_experts
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType
from .registry import register_scheme
@@ -95,7 +96,9 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
pertoken_scale: Any | None = None,
**kwargs,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -128,15 +131,19 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme):
moe_comm_method = _EXTRA_CTX.moe_comm_method
final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
use_int8_w8a8=True,
fused_experts_input=build_fused_experts_input(
hidden_states=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1=layer.w13_weight,
w2=layer.w2_weight,
quant_type=self.quant_type,
dynamic_eplb=False,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
),
)
if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result

View File

@@ -41,7 +41,8 @@ from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.quant_type import QuantType
from vllm_ascend.utils import (
ACL_FORMAT_FRACTAL_NZ,
enable_sp,
@@ -113,7 +114,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
activation: str = "silu",
enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None,
**kwargs,
global_redundant_expert_num: int = 0,
pertoken_scale: torch.Tensor | None = None,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -167,7 +170,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
# (due to signature constraints), we are forced to use a placeholder empty tensor.
# This TODO tracks the requirement to update the C++ operator to accept Optional[Tensor]
# or None for scales in non-quantized scenarios.
if get_forward_context().moe_comm_type == MoECommType.FUSED_MC2:
if _EXTRA_CTX.moe_comm_type == MoECommType.FUSED_MC2:
w1 = [layer.w13_weight]
w1_scale = [torch.tensor([], dtype=torch.int64)]
w2 = [layer.w2_weight]
@@ -179,21 +182,26 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
w2_scale = None
final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x,
w1=w1,
w2=w2,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_bias=layer.w13_bias if self.moe.has_bias else None,
w2_bias=layer.w2_bias if self.moe.has_bias else None,
activation=activation,
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb,
log2phy=log2phy,
mc2_mask=kwargs.get("mc2_mask"),
fused_experts_input=build_fused_experts_input(
hidden_states=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1=w1,
w2=w2,
w1_bias=layer.w13_bias if self.moe.has_bias else None,
w2_bias=layer.w2_bias if self.moe.has_bias else None,
quant_type=QuantType.NONE,
dynamic_eplb=self.dynamic_eplb,
expert_map=expert_map,
global_redundant_expert_num=global_redundant_expert_num,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
log2phy=log2phy,
pertoken_scale=pertoken_scale,
activation=activation,
w1_scale=w1_scale,
w2_scale=w2_scale,
)
)
if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result
@@ -474,23 +482,23 @@ class AscendFusedMoE(FusedMoE):
set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)
hidden_states, router_logits, mc2_mask, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
prepare_output = _EXTRA_CTX.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
replace_allreduce=_EXTRA_CTX.flash_comm_v1_enabled,
enable_shared_expert_dp=self.enable_shared_expert_dp,
quant_type=self.quant_type,
)
hidden_states = prepare_output.hidden_states
router_logits = prepare_output.router_logits
mc2_mask = prepare_output.mc2_mask
padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
pertoken_scale = prepare_output.pertoken_scale
# Make sure the default stream waits for the gate stream to finish.
if self.multistream_overlap_gate:
torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream)
if isinstance(hidden_states, tuple):
hidden_states, pertoken_scale = hidden_states
else:
pertoken_scale = None
# Matrix multiply.
fused_experts_results: FusedExpertsResult = self.quant_method.apply(
layer=self,
@@ -538,7 +546,7 @@ class AscendFusedMoE(FusedMoE):
routed_out = _EXTRA_CTX.moe_comm_method.finalize(
hidden_states=fused_experts_results.routed_out,
reduce_results=self.reduce_results,
context_metadata=context_metadata,
padded_hidden_states_shape=padded_hidden_states_shape,
)
if return_with_event:

View File

@@ -24,6 +24,13 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEFusedExpertsInput,
MoEMlpComputeInput,
MoEPrepareOutput,
build_mlp_compute_input,
build_token_dispatch_input,
)
from vllm_ascend.ops.fused_moe.prepare_finalize import (
PrepareAndFinalize,
PrepareAndFinalizeWithAll2All,
@@ -36,8 +43,7 @@ from vllm_ascend.ops.fused_moe.token_dispatcher import (
TokenDispatcherWithAllGather,
TokenDispatcherWithMC2,
)
from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.quantization.quant_parser import parse_mxfp_quant_params
from vllm_ascend.quantization.quant_type import QuantType
_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
@@ -90,131 +96,70 @@ class MoECommMethod(ABC):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type: QuantType = QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce, quant_type
) -> MoEPrepareOutput:
return self.prepare_finalize.prepare(
hidden_states,
router_logits,
enable_shared_expert_dp,
replace_allreduce,
quant_type,
)
return hidden_states, router_logits, mc2_mask, context_metadata
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
self,
hidden_states: torch.Tensor,
reduce_results: bool,
padded_hidden_states_shape: torch.Size | None = None,
) -> torch.Tensor:
hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, context_metadata)
hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, padded_hidden_states_shape)
return hidden_states
def fused_experts(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
w1_bias: torch.Tensor = None,
w2_bias: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
expert_map: torch.Tensor | None = None,
w1_scale: list[torch.Tensor] | None = None,
w2_scale: list[torch.Tensor] | None = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None,
pertoken_scale: torch.Tensor | None = None,
**kwargs,
fused_experts_input: MoEFusedExpertsInput,
):
# Check constraints
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
assert fused_experts_input.hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
moe_comm_method = _EXTRA_CTX.moe_comm_method
assert moe_comm_method is not None, "Missing communication context"
before_dispatch_evt = torch.npu.current_stream().record_event()
# Apply log2phy if needed
if log2phy is not None:
topk_ids = log2phy[topk_ids]
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced
# by different quantization modes will be consolidated into a dataclass in a follow-up.
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
dispatch_with_quant = use_int8_w8a8 or use_int4_w4a8 or use_mxfp_quant
act_quant_type, weight_quant_type, scale_type, per_token_scale_type, round_mode = parse_mxfp_quant_params(
**kwargs
routed_topk_ids = fused_experts_input.topk_ids
if fused_experts_input.routing.log2phy is not None:
routed_topk_ids = fused_experts_input.routing.log2phy[routed_topk_ids]
token_dispatch_input = build_token_dispatch_input(
fused_experts_input=fused_experts_input,
topk_ids=routed_topk_ids,
)
token_dispatch_output = self.token_dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
mlp_compute_input = build_mlp_compute_input(
fused_experts_input=fused_experts_input,
token_dispatch_output=token_dispatch_output,
use_fusion_ops=self.use_fusion_ops,
)
dispatch_kwargs = {
"hidden_states": hidden_states,
"topk_weights": topk_weights,
"topk_ids": topk_ids,
"expert_map": expert_map,
"global_redundant_expert_num": self.moe_config.global_redundant_expert_num,
"mc2_mask": mc2_mask,
"apply_router_weight_on_input": apply_router_weight_on_input,
"dynamic_eplb": dynamic_eplb,
"pertoken_scale": pertoken_scale,
}
if isinstance(self.token_dispatcher, TokenDispatcherWithMC2):
dispatch_kwargs["with_quant"] = dispatch_with_quant
dispatch_kwargs["comm_quant_mode"] = kwargs.get("comm_quant_mode")
dispatch_kwargs["y_dtype"] = act_quant_type if use_mxfp_quant else None
dispatch_kwargs["use_mxfp_quant"] = use_mxfp_quant
else:
dispatch_kwargs["with_quant"] = use_int8_w8a8 or use_int4_w4a8
dispatch_results = self.token_dispatcher.token_dispatch(**dispatch_kwargs)
mlp_output = unified_apply_mlp(
hidden_states=dispatch_results.hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
w1_bias=w1_bias,
w2_bias=w2_bias,
activation=activation,
group_list=dispatch_results.group_list,
dynamic_scale=dispatch_results.dynamic_scale,
group_list_type=dispatch_results.group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
w1_offset=w1_offset,
w2_offset=w2_offset,
topk_scales=dispatch_results.topk_scales,
with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16 or use_mxfp_quant,
fusion=(use_int8_w8a8 or use_mxfp_quant) and self.use_fusion_ops,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb,
use_mxfp_quant=use_mxfp_quant,
act_quant_type=act_quant_type,
weight_quant_type=weight_quant_type,
scale_type=scale_type,
per_token_scale_type=per_token_scale_type,
round_mode=round_mode,
use_bf16=(hidden_states.dtype == torch.bfloat16),
rollback_quant_config=kwargs.get("rollback_quant_config"),
)
mlp_output = self._apply_mlp(mlp_compute_input)
before_combine_evt = torch.npu.current_stream().record_event()
combine_results = self.token_dispatcher.token_combine(
hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
routed_out = self.token_dispatcher.token_combine(
hidden_states=mlp_output,
combine_metadata=token_dispatch_output.combine_metadata,
)
return FusedExpertsResult(
routed_out=combine_results.routed_out,
routed_out=routed_out,
before_dispatch_evt=before_dispatch_evt,
before_combine_evt=before_combine_evt,
group_list_type=dispatch_results.group_list_type,
expert_tokens=dispatch_results.group_list,
group_list_type=token_dispatch_output.group_list_type,
expert_tokens=token_dispatch_output.group_list,
)
def _apply_mlp(self, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
return unified_apply_mlp(mlp_compute_input=mlp_compute_input)
@abstractmethod
def _get_token_dispatcher(self) -> MoETokenDispatcher:
raise NotImplementedError("_get_token_dispatcher function not implemented.")
@@ -317,54 +262,32 @@ class FusedMC2CommImpl(MoECommMethod):
def fused_experts(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
w1_bias: torch.Tensor = None,
w2_bias: torch.Tensor = None,
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
expert_map: torch.Tensor | None = None,
w1_scale: list[torch.Tensor] | None = None,
w2_scale: list[torch.Tensor] | None = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None,
pertoken_scale: torch.Tensor | None = None,
**kwargs,
fused_experts_input: MoEFusedExpertsInput,
):
assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
assert not (fused_experts_input.weights.w1_scale is None or fused_experts_input.weights.w2_scale is None), (
"w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
)
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), (
"token_dispatcher must be an instance of TokenDispatcherWithMC2."
)
# Apply log2phy if needed
if log2phy is not None:
topk_ids = log2phy[topk_ids]
topk_ids = fused_experts_input.topk_ids
if fused_experts_input.routing.log2phy is not None:
topk_ids = fused_experts_input.routing.log2phy[topk_ids]
expert_tokens = None
if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1:
out = torch.empty_like(hidden_states)
out = torch.empty_like(fused_experts_input.hidden_states)
torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore
x=hidden_states,
weight1=w1,
weight2=w2,
x=fused_experts_input.hidden_states,
weight1=fused_experts_input.weights.w1,
weight2=fused_experts_input.weights.w2,
expert_idx=topk_ids,
scale1=w1_scale,
scale2=w2_scale,
probs=topk_weights.to(torch.float32),
scale1=fused_experts_input.weights.w1_scale,
scale2=fused_experts_input.weights.w2_scale,
probs=fused_experts_input.topk_weights.to(torch.float32),
group=self.token_dispatcher.moe_all_to_all_group_name,
max_output_size=65536,
out=out,
@@ -372,16 +295,16 @@ class FusedMC2CommImpl(MoECommMethod):
)
expert_tokens = self.expert_token_nums
elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2:
assert expert_map is not None, "expert_map cannot be None."
assert fused_experts_input.routing.expert_map is not None, "expert_map cannot be None."
out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore
x=hidden_states,
x=fused_experts_input.hidden_states,
expert_ids=topk_ids,
gmm1_permuted_weight=w1,
gmm1_permuted_weight_scale=w1_scale,
gmm2_weight=w2,
gmm2_weight_scale=w2_scale,
gmm1_permuted_weight=fused_experts_input.weights.w1,
gmm1_permuted_weight_scale=fused_experts_input.weights.w1_scale,
gmm2_weight=fused_experts_input.weights.w2,
gmm2_weight_scale=fused_experts_input.weights.w2_scale,
expert_smooth_scales=None,
expert_scales=topk_weights.to(torch.float32),
expert_scales=fused_experts_input.topk_weights.to(torch.float32),
group_ep=self.token_dispatcher.moe_all_to_all_group_name,
ep_rank_size=self.token_dispatcher.ep_world_size,
ep_rank_id=self.token_dispatcher.ep_rank_id,

View File

@@ -27,6 +27,7 @@ from vllm_ascend.device.mxfp_compat import (
ensure_mxfp8_moe_available,
)
from vllm_ascend.ops.activation import AscendSwigluOAIAndMul
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput
from vllm_ascend.utils import (
dispose_tensor,
enable_custom_op,
@@ -95,27 +96,17 @@ def quant_apply_mlp(
w2_offset: torch.Tensor | None = None,
fusion: bool = False,
dynamic_eplb: bool = False,
**kwargs,
use_mxfp_quant: bool = False,
act_quant_type: torch.dtype = torch.float8_e4m3fn,
weight_quant_type: torch.dtype | None = None,
scale_type: torch.dtype | None = None,
per_token_scale_type: torch.dtype | None = None,
use_bf16: bool = True,
) -> torch.Tensor:
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced by different
# quantization modes will be consolidated into a dataclass in a follow-up.
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
act_quant_type = torch.float8_e4m3fn
weight_quant_type = None
scale_type = None
per_token_scale_type = None
use_bf16 = True
input_hidden_dtype = hidden_states.dtype
use_gmm_swiglu_quant_fusion = use_mxfp_quant or (fusion and not dynamic_eplb)
if use_mxfp_quant:
act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn)
weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn)
scale_type = kwargs.get("scale_type")
per_token_scale_type = kwargs.get("per_token_scale_type")
use_bf16 = kwargs.get("use_bf16", True)
ensure_mxfp8_moe_available("MXFP MoE MLP path")
if w1_scale_bias is not None or w2_scale_bias is not None:
@@ -393,34 +384,32 @@ def unquant_apply_mlp(
return hidden_states
def unified_apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
group_list: torch.Tensor,
w1_scale: list[torch.Tensor] | None = None,
w2_scale: list[torch.Tensor] | None = None,
activation: str | None = None,
w1_bias: torch.Tensor = None,
w2_bias: torch.Tensor = None,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
topk_scales: torch.Tensor | None = None,
with_quant: bool = False,
fusion: bool = False,
need_trans: bool = True,
dynamic_eplb: bool = False,
**kwargs,
) -> torch.Tensor:
def unified_apply_mlp(*, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor:
"""
Unified MoE MLP entry.
Quant path is dispatched by DeviceOperator with explicit quant-type flags.
Quant path is dispatched by DeviceOperator with explicit typed kernel flags.
"""
if not with_quant:
hidden_states = mlp_compute_input.hidden_states
group_list = mlp_compute_input.group_list
group_list_type = mlp_compute_input.group_list_type
dynamic_scale = mlp_compute_input.dynamic_scale
topk_scales = mlp_compute_input.topk_scales
w1 = mlp_compute_input.weights.w1
w2 = mlp_compute_input.weights.w2
w1_bias = mlp_compute_input.weights.w1_bias
w2_bias = mlp_compute_input.weights.w2_bias
w1_scale = mlp_compute_input.weights.w1_scale
w2_scale = mlp_compute_input.weights.w2_scale
w1_scale_bias = mlp_compute_input.weights.w1_scale_bias
w2_scale_bias = mlp_compute_input.weights.w2_scale_bias
w1_offset = mlp_compute_input.weights.w1_offset
w2_offset = mlp_compute_input.weights.w2_offset
activation = mlp_compute_input.activation
need_trans = mlp_compute_input.need_trans
dynamic_eplb = mlp_compute_input.dynamic_eplb
fusion = mlp_compute_input.fusion
if not mlp_compute_input.quant.is_quant:
return unquant_apply_mlp(
hidden_states=hidden_states,
w1=w1,
@@ -435,13 +424,22 @@ def unified_apply_mlp(
)
assert w1_scale is not None and w2_scale is not None
# TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced by different
# quantization modes will be consolidated into a dataclass in a follow-up.
act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn)
weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn)
scale_type = kwargs.get("scale_type")
per_token_scale_type = kwargs.get("per_token_scale_type")
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
act_quant_type = torch.float8_e4m3fn
weight_quant_type = torch.float8_e4m3fn
scale_type = None
per_token_scale_type = None
use_bf16 = hidden_states.dtype == torch.bfloat16
use_mxfp_quant = mlp_compute_input.quant.is_mxfp
if use_mxfp_quant:
mxfp = mlp_compute_input.quant.mxfp
assert mxfp is not None, "mlp_compute_input.quant.mxfp is required when quant_type is MXFP8."
act_quant_type = mxfp.act_quant_type or act_quant_type
weight_quant_type = mxfp.weight_quant_type or weight_quant_type
scale_type = mxfp.scale_dtype
per_token_scale_type = mxfp.per_token_scale_dtype
use_bf16 = mxfp.use_bf16
return quant_apply_mlp(
hidden_states=hidden_states,
w1=w1,
@@ -457,10 +455,10 @@ def unified_apply_mlp(
w2_offset=w2_offset,
fusion=fusion,
dynamic_eplb=dynamic_eplb,
use_mxfp_quant=use_mxfp_quant,
act_quant_type=act_quant_type,
weight_quant_type=weight_quant_type,
scale_type=scale_type,
per_token_scale_type=per_token_scale_type,
use_mxfp_quant=use_mxfp_quant,
use_bf16=kwargs.get("use_bf16", True),
use_bf16=use_bf16,
)

View File

@@ -0,0 +1,244 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Typed runtime contracts and builders for fused MoE execution.
This module is the single entry point for the runtime payloads used across the
fused MoE pipeline.
Relationship overview:

    stage params: reusable sub-payloads
        - MoERoutingParams
        - MoEQuantParams
            - internal MXFP leaf: MoEMxfpParams

    stage contracts: stage input/output payloads
        prepare
            -> MoEPrepareOutput
        fused_experts input
            -> MoEFusedExpertsInput
                |- weights: MoEWeights
                |- routing: MoERoutingParams
                |- quant: MoEQuantParams
        dispatch
            input  -> MoETokenDispatchInput
            output -> MoETokenDispatchOutput[TMoECombineMetadata]
                TMoECombineMetadata is one of:
                    - MoEAllGatherCombineMetadata
                    - MoEAllToAllCombineMetadata
                    - MoEMC2CombineMetadata
        mlp
            input -> MoEMlpComputeInput
        combine
            output -> torch.Tensor
The helper builders below adapt legacy call sites into these typed contracts.
Only the fused_moe package should need to know about the internal MXFP leaf
dataclass directly.
"""
from __future__ import annotations
import torch
import vllm_ascend.ops.fused_moe.moe_stage_params as _stage_params
from vllm_ascend.ops.fused_moe.moe_stage_contracts import (
MoEAllGatherCombineMetadata,
MoEAllToAllCombineMetadata,
MoEFusedExpertsInput,
MoEMC2CombineMetadata,
MoEMlpComputeInput,
MoEPrepareOutput,
MoETokenDispatchInput,
MoETokenDispatchOutput,
MoEWeights,
TMoECombineMetadata,
)
from vllm_ascend.ops.fused_moe.moe_stage_params import (
MoEQuantParams,
MoERoutingParams,
)
from vllm_ascend.quantization.quant_type import QuantType
def _build_mxfp_params(
*,
quant_type: QuantType,
mxfp_act_quant_type: torch.dtype | None = None,
mxfp_weight_quant_type: torch.dtype | None = None,
mxfp_scale_dtype: torch.dtype | None = None,
mxfp_per_token_scale_dtype: torch.dtype | None = None,
mxfp_use_bf16: bool | None = None,
) -> _stage_params.MoEMxfpParams | None:
if quant_type != QuantType.MXFP8:
return None
has_explicit_mxfp_args = any(
value is not None
for value in (
mxfp_act_quant_type,
mxfp_weight_quant_type,
mxfp_scale_dtype,
mxfp_per_token_scale_dtype,
mxfp_use_bf16,
)
)
if not has_explicit_mxfp_args:
raise ValueError("primitive MXFP params are required when quant_type is QuantType.MXFP8.")
return _stage_params.MoEMxfpParams(
act_quant_type=mxfp_act_quant_type,
weight_quant_type=mxfp_weight_quant_type,
scale_dtype=mxfp_scale_dtype,
per_token_scale_dtype=mxfp_per_token_scale_dtype,
use_bf16=True if mxfp_use_bf16 is None else mxfp_use_bf16,
)
def build_fused_experts_input(
*,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
quant_type: QuantType,
dynamic_eplb: bool,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
log2phy: torch.Tensor | None = None,
pertoken_scale: torch.Tensor | None = None,
activation: str = "silu",
need_trans: bool = False,
w1_bias: torch.Tensor | None = None,
w2_bias: torch.Tensor | None = None,
comm_quant_mode: int | None = None,
mxfp_act_quant_type: torch.dtype | None = None,
mxfp_weight_quant_type: torch.dtype | None = None,
mxfp_scale_dtype: torch.dtype | None = None,
mxfp_per_token_scale_dtype: torch.dtype | None = None,
mxfp_use_bf16: bool | None = None,
w1_scale: list[torch.Tensor] | torch.Tensor | None = None,
w2_scale: list[torch.Tensor] | torch.Tensor | None = None,
w1_scale_bias: torch.Tensor | None = None,
w2_scale_bias: torch.Tensor | None = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
) -> MoEFusedExpertsInput:
return MoEFusedExpertsInput(
hidden_states=hidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
weights=MoEWeights(
w1=w1,
w2=w2,
w1_bias=w1_bias,
w2_bias=w2_bias,
w1_scale=w1_scale,
w2_scale=w2_scale,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
w1_offset=w1_offset,
w2_offset=w2_offset,
),
routing=MoERoutingParams(
expert_map=expert_map,
global_redundant_expert_num=global_redundant_expert_num,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
log2phy=log2phy,
pertoken_scale=pertoken_scale,
),
activation=activation,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb,
quant=MoEQuantParams(
quant_type=quant_type,
comm_quant_mode=comm_quant_mode,
mxfp=_build_mxfp_params(
quant_type=quant_type,
mxfp_act_quant_type=mxfp_act_quant_type,
mxfp_weight_quant_type=mxfp_weight_quant_type,
mxfp_scale_dtype=mxfp_scale_dtype,
mxfp_per_token_scale_dtype=mxfp_per_token_scale_dtype,
mxfp_use_bf16=mxfp_use_bf16,
),
),
)
def build_token_dispatch_input(
*,
fused_experts_input: MoEFusedExpertsInput,
topk_ids: torch.Tensor | None = None,
) -> MoETokenDispatchInput:
return MoETokenDispatchInput(
hidden_states=fused_experts_input.hidden_states,
topk_weights=fused_experts_input.topk_weights,
topk_ids=fused_experts_input.topk_ids if topk_ids is None else topk_ids,
routing=fused_experts_input.routing,
quant=fused_experts_input.quant,
)
def build_mlp_compute_input(
*,
fused_experts_input: MoEFusedExpertsInput,
token_dispatch_output: MoETokenDispatchOutput[TMoECombineMetadata],
use_fusion_ops: bool,
) -> MoEMlpComputeInput:
if fused_experts_input.quant.is_mxfp and fused_experts_input.quant.mxfp is None:
raise ValueError("fused_experts_input.quant.mxfp is required when quant_type is QuantType.MXFP8.")
return MoEMlpComputeInput(
hidden_states=token_dispatch_output.hidden_states,
group_list=token_dispatch_output.group_list,
group_list_type=token_dispatch_output.group_list_type,
dynamic_scale=token_dispatch_output.dynamic_scale,
topk_scales=token_dispatch_output.topk_scales,
weights=fused_experts_input.weights,
quant=fused_experts_input.quant,
fusion=fused_experts_input.quant.quant_type in (QuantType.W8A8, QuantType.MXFP8) and use_fusion_ops,
activation=fused_experts_input.activation,
need_trans=fused_experts_input.need_trans,
dynamic_eplb=fused_experts_input.dynamic_eplb,
)
__all__ = [
"MoEAllGatherCombineMetadata",
"MoEAllToAllCombineMetadata",
"MoEFusedExpertsInput",
"MoEMC2CombineMetadata",
"MoEMlpComputeInput",
"MoEPrepareOutput",
"MoEQuantParams",
"MoERoutingParams",
"MoETokenDispatchInput",
"MoETokenDispatchOutput",
"MoEWeights",
"TMoECombineMetadata",
"build_fused_experts_input",
"build_token_dispatch_input",
"build_mlp_compute_input",
]
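
One stated goal of this refactor is that unit tests can assert stage-level fields directly. A minimal sketch of such a test, assuming pytest-style assertions and small placeholder tensors (not an excerpt from the repository's test suite):

```python
import torch

from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.quantization.quant_type import QuantType


def test_build_fused_experts_input_groups_fields_by_stage():
    hidden_states = torch.randn(4, 8)
    topk_weights = torch.rand(4, 2)
    topk_ids = torch.zeros(4, 2, dtype=torch.int32)
    w1, w2 = torch.randn(1, 16, 8), torch.randn(1, 8, 8)

    request = build_fused_experts_input(
        hidden_states=hidden_states,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        w1=w1,
        w2=w2,
        quant_type=QuantType.NONE,
        dynamic_eplb=False,
        apply_router_weight_on_input=True,
    )

    # Stage-level fields can be asserted directly instead of inspecting kwargs.
    assert request.routing.apply_router_weight_on_input is True
    assert request.quant.quant_type is QuantType.NONE
    assert not request.quant.is_quant
    assert request.quant.mxfp is None
    assert request.weights.w1 is w1 and request.weights.w2 is w2
```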

View File

@@ -0,0 +1,154 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
from dataclasses import dataclass
from typing import Generic, TypeVar
import numpy as np
import torch
from vllm_ascend.ops.fused_moe.moe_stage_params import MoEQuantParams, MoERoutingParams
TMoECombineMetadata = TypeVar("TMoECombineMetadata")
# prepare -> fused_experts
@dataclass(frozen=True, slots=True)
class MoEPrepareOutput:
"""Typed output from prepare stage."""
hidden_states: torch.Tensor
router_logits: torch.Tensor
mc2_mask: torch.Tensor | None
padded_hidden_states_shape: torch.Size | None
pertoken_scale: torch.Tensor | None = None
@dataclass(frozen=True, slots=True)
class MoEWeights:
"""Dense and quantized weight payloads consumed by MoE execution."""
w1: torch.Tensor | list[torch.Tensor]
w2: torch.Tensor | list[torch.Tensor]
w1_bias: torch.Tensor | None = None
w2_bias: torch.Tensor | None = None
w1_scale: torch.Tensor | list[torch.Tensor] | None = None
w2_scale: torch.Tensor | list[torch.Tensor] | None = None
w1_scale_bias: torch.Tensor | None = None
w2_scale_bias: torch.Tensor | None = None
w1_offset: torch.Tensor | None = None
w2_offset: torch.Tensor | None = None
@dataclass(frozen=True, slots=True)
class MoEFusedExpertsInput:
"""Top-level input for the routed experts pipeline."""
hidden_states: torch.Tensor
topk_weights: torch.Tensor
topk_ids: torch.Tensor
weights: MoEWeights
routing: MoERoutingParams
quant: MoEQuantParams
activation: str = "silu"
need_trans: bool = False
dynamic_eplb: bool = False
@dataclass(frozen=True, slots=True)
class MoETokenDispatchInput:
"""Input to token dispatch."""
hidden_states: torch.Tensor
topk_weights: torch.Tensor
topk_ids: torch.Tensor
routing: MoERoutingParams
quant: MoEQuantParams
# dispatch carry-over state consumed by combine
@dataclass(frozen=True, slots=True)
class MoEMC2CombineMetadata:
topk_ids: torch.Tensor
topk_weights: torch.Tensor
expert_map: torch.Tensor | None
ep_recv_counts: torch.Tensor
tp_recv_counts: torch.Tensor
assist_info_for_combine: torch.Tensor
expand_scales: torch.Tensor | None
dispatch_with_quant: bool
@dataclass(frozen=True, slots=True)
class MoEAllGatherCombineMetadata:
topk_weights: torch.Tensor
expanded_row_idx: torch.Tensor
restore_shape: torch.Size
@dataclass(frozen=True, slots=True)
class MoEAllToAllCombineMetadata:
input_splits: np.ndarray
output_splits: np.ndarray
topk_weights: torch.Tensor
reversed_local_input_permutation_mapping: torch.Tensor
reversed_global_input_permutation_mapping: torch.Tensor | None
hidden_shape: torch.Size
hidden_shape_before_permute: torch.Size
@dataclass(frozen=True, slots=True)
class MoETokenDispatchOutput(Generic[TMoECombineMetadata]):
hidden_states: torch.Tensor
group_list: torch.Tensor
group_list_type: int
combine_metadata: TMoECombineMetadata
dynamic_scale: torch.Tensor | None = None
topk_scales: torch.Tensor | None = None
# dispatch -> mlp -> combine
@dataclass(frozen=True, slots=True)
class MoEMlpComputeInput:
"""Input to MLP compute."""
hidden_states: torch.Tensor
group_list: torch.Tensor
group_list_type: int
dynamic_scale: torch.Tensor | None
topk_scales: torch.Tensor | None
weights: MoEWeights
quant: MoEQuantParams
fusion: bool
activation: str = "silu"
need_trans: bool = False
dynamic_eplb: bool = False
__all__ = [
"MoEPrepareOutput",
"MoEWeights",
"MoEFusedExpertsInput",
"MoETokenDispatchInput",
"MoEMC2CombineMetadata",
"MoEAllGatherCombineMetadata",
"MoEAllToAllCombineMetadata",
"MoETokenDispatchOutput",
"MoEMlpComputeInput",
"TMoECombineMetadata",
]
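
The `Generic[TMoECombineMetadata]` parameter ties each dispatcher's output to its own combine metadata flavor. A minimal sketch of how an AllGather-style dispatcher might package its result, using placeholder tensors (illustrative only):

```python
import torch

from vllm_ascend.ops.fused_moe.moe_runtime_args import (
    MoEAllGatherCombineMetadata,
    MoETokenDispatchOutput,
)

sorted_hidden_states = torch.randn(6, 8)                  # tokens permuted per expert
expert_tokens = torch.tensor([2, 4], dtype=torch.int64)   # per-expert token counts

dispatch_output: MoETokenDispatchOutput[MoEAllGatherCombineMetadata] = MoETokenDispatchOutput(
    hidden_states=sorted_hidden_states,
    group_list=expert_tokens,
    group_list_type=1,  # `count` mode
    combine_metadata=MoEAllGatherCombineMetadata(
        topk_weights=torch.rand(3, 2),
        expanded_row_idx=torch.arange(6, dtype=torch.int32),
        restore_shape=torch.Size([3, 8]),
    ),
)
# The combine stage later consumes `dispatch_output.combine_metadata`.
```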

View File

@@ -0,0 +1,86 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import annotations
from dataclasses import dataclass
import torch
from vllm_ascend.quantization.quant_type import QuantType
@dataclass(frozen=True, slots=True)
class MoERoutingParams:
"""Routing and dispatch side inputs for one MoE invocation.
`pertoken_scale` is intentionally kept here even though it is not a pure
routing concept. It is used by pre-quantized activation flows, currently
the AllGather + EP W8A8 prepare path, where prepare emits per-token
activation scales and dispatch needs to carry them forward so the MLP
quant path can reuse those scales instead of requantizing activations.
"""
expert_map: torch.Tensor | None
global_redundant_expert_num: int
mc2_mask: torch.Tensor | None
apply_router_weight_on_input: bool
log2phy: torch.Tensor | None = None
# Precomputed activation scales from prepare stage for quantized dispatch.
pertoken_scale: torch.Tensor | None = None
@dataclass(frozen=True, slots=True)
class MoEMxfpParams:
"""Internal MXFP-only precision settings used by fused_moe runtime."""
act_quant_type: torch.dtype | None = None
weight_quant_type: torch.dtype | None = None
scale_dtype: torch.dtype | None = None
per_token_scale_dtype: torch.dtype | None = None
use_bf16: bool = True
@dataclass(frozen=True, slots=True)
class MoEQuantParams:
"""Quant mode, backend override, and optional internal MXFP leaf config."""
quant_type: QuantType = QuantType.NONE
comm_quant_mode: int | None = None
mxfp: MoEMxfpParams | None = None
@property
def is_quant(self) -> bool:
return self.quant_type != QuantType.NONE
@property
def is_mxfp(self) -> bool:
return self.quant_type == QuantType.MXFP8
@property
def is_int_quant(self) -> bool:
return self.quant_type in (QuantType.W8A8, QuantType.W4A8)
@property
def dispatch_with_quant(self) -> bool:
return self.quant_type in (QuantType.W8A8, QuantType.W4A8, QuantType.MXFP8)
__all__ = [
"MoERoutingParams",
"MoEMxfpParams",
"MoEQuantParams",
]
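
The quant booleans that used to travel as scattered kwargs (`use_int8_w8a8`, `use_mxfp_quant`, ...) are now derived properties on `MoEQuantParams`. A small sketch of how they resolve for a few quant types; the values follow directly from the property definitions above:

```python
import torch

from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams, MoEQuantParams
from vllm_ascend.quantization.quant_type import QuantType

w8a8 = MoEQuantParams(quant_type=QuantType.W8A8)
assert w8a8.is_quant and w8a8.is_int_quant and w8a8.dispatch_with_quant
assert not w8a8.is_mxfp

mxfp8 = MoEQuantParams(
    quant_type=QuantType.MXFP8,
    mxfp=MoEMxfpParams(act_quant_type=torch.float8_e4m3fn, use_bf16=True),
)
assert mxfp8.is_mxfp and mxfp8.dispatch_with_quant and not mxfp8.is_int_quant

none = MoEQuantParams()  # QuantType.NONE by default
assert not none.is_quant and not none.dispatch_with_quant
```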

View File

@@ -31,7 +31,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
from vllm_ascend.quantization.methods.base import QuantType
from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEPrepareOutput
from vllm_ascend.quantization.quant_type import QuantType
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
@@ -64,7 +65,7 @@ class PrepareAndFinalize(ABC):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type: QuantType = QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
) -> MoEPrepareOutput:
"""
Prepare tensors before MoE computation. May involve:
- Padding to align communication boundaries
@@ -79,16 +80,20 @@ class PrepareAndFinalize(ABC):
quant_type: none, w8a8, w4a8 or mxfp8
Returns:
Tuple of:
MoEPrepareOutput:
- processed hidden_states (may be padded/sliced/broadcasted)
- processed router_logits (may be recomputed or broadcasted)
- optional communication mask (e.g., mc2_mask for sparse ops)
- optional context metadata (e.g., saved split_hidden_states for finalization)
- optional padded hidden state shape for finalization
- optional per-token scale for quantized path
"""
raise NotImplementedError("Prepare not implemented.")
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
self,
hidden_states: torch.Tensor,
reduce_results: bool,
padded_hidden_states_shape: torch.Size | None = None,
) -> torch.Tensor:
"""
Finalize MoE output. May involve:
@@ -130,7 +135,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
) -> MoEPrepareOutput:
"""
Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size.
@@ -140,7 +145,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
Skips if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, None, context_metadata) — no mask used in All2All.
MoEPrepareOutput where `mc2_mask` is None for All2All path.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
@@ -162,12 +167,19 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
context_metadata = {"padded_hidden_states_shape": padded_hidden_states_shape}
return hidden_states, router_logits, None, context_metadata
return MoEPrepareOutput(
hidden_states=hidden_states,
router_logits=router_logits,
mc2_mask=None,
padded_hidden_states_shape=padded_hidden_states_shape,
pertoken_scale=None,
)
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
self,
hidden_states: torch.Tensor,
reduce_results: bool,
padded_hidden_states_shape: torch.Size | None = None,
) -> torch.Tensor:
"""
Finalization steps:
@@ -180,12 +192,11 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
if not (self.enable_shared_expert_dp or self.replace_allreduce):
if self.tp_size > 1:
assert context_metadata is not None
assert padded_hidden_states_shape is not None
# Cannot reuse `split_hidden_states` from prepare phase as it
# may share memory with original hidden_states. Since shared
# experts may use the original tensor, reusing it would cause
# in-place modification during all_gather, corrupting the data.
padded_hidden_states_shape = context_metadata["padded_hidden_states_shape"]
gathered_hidden_states = torch.empty(
padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype
)
@@ -227,7 +238,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
) -> MoEPrepareOutput:
"""
Preparation steps:
1. Fetch `mc2_mask` and target padding length from forward context.
@@ -238,7 +249,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True.
Returns:
Tuple of (hidden_states, router_logits, mc2_mask, context_metadata), possibly sliced/padded.
MoEPrepareOutput, possibly sliced/padded.
"""
self.replace_allreduce = replace_allreduce
self.enable_shared_expert_dp = enable_shared_expert_dp
@@ -267,11 +278,13 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
context_metadata = {
"padded_hidden_states_shape": padded_hidden_states_shape,
}
return hidden_states, router_logits, mc2_mask, context_metadata
return MoEPrepareOutput(
hidden_states=hidden_states,
router_logits=router_logits,
mc2_mask=mc2_mask,
padded_hidden_states_shape=padded_hidden_states_shape,
pertoken_scale=None,
)
class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
@@ -303,13 +316,13 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
) -> MoEPrepareOutput:
"""
Preparation steps:
AllGather hidden_states and router_logits to form global tensors.
Returns:
Tuple of (global_hidden_states, global_router_logits, None)
MoEPrepareOutput with global tensors.
"""
if enable_sp():
return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
@@ -318,7 +331,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
def _prepare_with_ep_group(
self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
) -> MoEPrepareOutput:
pertoken_scale = None
if quant_type == QuantType.W8A8:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
@@ -342,10 +355,13 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
if self.multistream_overlap_gate:
torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream)
if pertoken_scale is not None:
return (hidden_states, pertoken_scale), router_logits, None, None
return hidden_states, router_logits, None, None
return MoEPrepareOutput(
hidden_states=hidden_states,
router_logits=router_logits,
mc2_mask=None,
padded_hidden_states_shape=None,
pertoken_scale=pertoken_scale,
)
def _prepare_with_dp_group(
self,
@@ -354,7 +370,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
) -> MoEPrepareOutput:
"""
Preparation steps:
1. Fetch max token count across DP group from forward context.
@@ -362,7 +378,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
3. All-gather across DP group to form global input tensor.
Returns:
Tuple of (global_hidden_states, global_router_logits, None, None)
MoEPrepareOutput with global tensors.
"""
self.enable_shared_expert_dp = enable_shared_expert_dp
if self.moe_config.dp_size > 1:
@@ -396,10 +412,19 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
dim=0,
)
return hidden_states, router_logits, None, None
return MoEPrepareOutput(
hidden_states=hidden_states,
router_logits=router_logits,
mc2_mask=None,
padded_hidden_states_shape=None,
pertoken_scale=None,
)
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
self,
hidden_states: torch.Tensor,
reduce_results: bool,
padded_hidden_states_shape: torch.Size | None = None,
) -> torch.Tensor:
"""
Finalization steps:

View File

@@ -21,7 +21,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Generic
import torch
import torch_npu
@@ -31,25 +31,18 @@ from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.device.device_op import DeviceOperator
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
MoEAllGatherCombineMetadata,
MoEAllToAllCombineMetadata,
MoEMC2CombineMetadata,
MoETokenDispatchInput,
MoETokenDispatchOutput,
TMoECombineMetadata,
)
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
@dataclass
class TokenDispatchResult:
hidden_states: torch.Tensor
group_list: torch.Tensor
group_list_type: int
dynamic_scale: torch.Tensor | None = field(default=None)
topk_scales: torch.Tensor | None = field(default=None)
context_metadata: dict = field(default_factory=dict)
@dataclass
class TokenCombineResult:
routed_out: torch.Tensor
class MoETokenDispatcher(ABC):
class MoETokenDispatcher(ABC, Generic[TMoECombineMetadata]):
def __init__(self, **kwargs) -> None:
"""
Initialize the MoE Token Dispatcher.
@@ -73,27 +66,21 @@ class MoETokenDispatcher(ABC):
@abstractmethod
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
) -> TokenDispatchResult:
token_dispatch_input: MoETokenDispatchInput,
) -> MoETokenDispatchOutput[TMoECombineMetadata]:
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
def token_combine(
self, hidden_states: torch.Tensor, context_metadata: dict, bias: torch.Tensor | None = None
) -> TokenCombineResult:
self,
hidden_states: torch.Tensor,
combine_metadata: TMoECombineMetadata,
bias: torch.Tensor | None = None,
) -> torch.Tensor:
raise NotImplementedError("Combine function not implemented.")
class TokenDispatcherWithMC2(MoETokenDispatcher):
class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
device_group = get_mc2_group().device_group
@@ -110,7 +97,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
# improve communication performance.
self.need_expert_scale = is_hierarchical_communication_enabled()
self.with_quant = False
# Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
# dispatch & combine operators with different input num_tokens per rank.
@@ -131,25 +117,23 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
def get_dispatch_mc2_kwargs(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor,
mc2_mask: torch.Tensor,
global_redundant_expert_num: int = 0,
**kwargs,
token_dispatch_input: MoETokenDispatchInput,
):
use_mxfp_quant = kwargs.get("use_mxfp_quant", False)
comm_quant_mode = kwargs.get("comm_quant_mode")
hidden_states = token_dispatch_input.hidden_states
topk_weights = token_dispatch_input.topk_weights
topk_ids = token_dispatch_input.topk_ids
expert_map = token_dispatch_input.routing.expert_map
global_redundant_expert_num = token_dispatch_input.routing.global_redundant_expert_num
comm_quant_mode = token_dispatch_input.quant.comm_quant_mode
assert expert_map is not None, "expert_map is required for MC2 token dispatch."
# NOTE: quant_mode differs by quant feature:
# - Legacy int communication quantization uses quant_mode=2.
# - A5 MXFP8 communication uses quant_mode=4.
# TODO(linfeng): The quantization-related parameters need to be consolidated into a single
# dataclass, and the FP8 MoE code path should be integrated into it going forward.
if comm_quant_mode is not None:
quant_mode = comm_quant_mode
elif self.with_quant:
quant_mode = 4 if self.a5_need_extra_args and use_mxfp_quant else 2
elif token_dispatch_input.quant.dispatch_with_quant:
quant_mode = 4 if self.a5_need_extra_args and token_dispatch_input.quant.is_mxfp else 2
else:
quant_mode = 0
self.moe_expert_num = len(expert_map) + global_redundant_expert_num
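The NOTE above compresses the quant_mode rules into a few branches. A minimal sketch of that resolution, mirroring the branch order shown in this hunk; the helper itself is illustrative and not part of the codebase:

```python
def resolve_quant_mode(
    comm_quant_mode: int | None,
    dispatch_with_quant: bool,
    is_mxfp: bool,
    a5_need_extra_args: bool,
) -> int:
    # An explicitly requested communication quant mode always wins.
    if comm_quant_mode is not None:
        return comm_quant_mode
    # Quantized dispatch: MXFP8 on A5 hardware uses mode 4, legacy int uses mode 2.
    if dispatch_with_quant:
        return 4 if (a5_need_extra_args and is_mxfp) else 2
    # Unquantized dispatch.
    return 0


assert resolve_quant_mode(None, True, True, True) == 4
assert resolve_quant_mode(None, True, False, True) == 2
assert resolve_quant_mode(None, False, False, False) == 0
assert resolve_quant_mode(2, False, False, False) == 2
```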
@@ -178,10 +162,13 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"tp_rank_id": 0,
}
)
if self.a5_need_extra_args and use_mxfp_quant:
y_dtype = kwargs.get("y_dtype")
if self.with_quant:
y_dtype = torch.float8_e4m3fn if y_dtype is None else y_dtype
if self.a5_need_extra_args and token_dispatch_input.quant.is_mxfp:
y_dtype = torch.float8_e4m3fn
if (
token_dispatch_input.quant.mxfp is not None
and token_dispatch_input.quant.mxfp.act_quant_type is not None
):
y_dtype = token_dispatch_input.quant.mxfp.act_quant_type
stage1_kwargs.update({"tp_world_size": 1, "tp_rank_id": 0, "y_dtype": y_dtype})
if self.need_expert_scale or self.a5_need_extra_args:
stage1_kwargs.update(
@@ -195,22 +182,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
**kwargs,
token_dispatch_input: MoETokenDispatchInput,
):
self.with_quant = with_quant
kwargs_mc2 = self.get_dispatch_mc2_kwargs(
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num, **kwargs
)
kwargs_mc2 = self.get_dispatch_mc2_kwargs(token_dispatch_input)
output = (
torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
if self.enable_dispatch_v2
@@ -227,33 +201,32 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
expand_scales,
) = output[0:7]
context_metadata = {
"topk_ids": topk_ids,
"topk_weights": topk_weights,
"expert_map": expert_map,
"ep_recv_counts": ep_recv_counts,
"tp_recv_counts": tp_recv_counts,
"assist_info_for_combine": assist_info_for_combine,
"expand_scales": expand_scales,
}
group_list_type = 0
return TokenDispatchResult(
return MoETokenDispatchOutput(
hidden_states=expand_x,
dynamic_scale=dynamic_scale,
group_list=expert_token_nums,
group_list_type=group_list_type,
context_metadata=context_metadata,
combine_metadata=MoEMC2CombineMetadata(
topk_ids=token_dispatch_input.topk_ids,
topk_weights=token_dispatch_input.topk_weights,
expert_map=token_dispatch_input.routing.expert_map,
ep_recv_counts=ep_recv_counts,
tp_recv_counts=tp_recv_counts,
assist_info_for_combine=assist_info_for_combine,
expand_scales=expand_scales,
dispatch_with_quant=token_dispatch_input.quant.dispatch_with_quant,
),
)
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, context_metadata: dict):
expert_map = context_metadata["expert_map"]
topk_ids = context_metadata["topk_ids"]
topk_weights = context_metadata["topk_weights"]
ep_recv_counts = context_metadata["ep_recv_counts"]
tp_recv_counts = context_metadata["tp_recv_counts"]
assist_info_for_combine = context_metadata["assist_info_for_combine"]
expand_scales = context_metadata["expand_scales"]
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, combine_metadata: MoEMC2CombineMetadata):
expert_map = combine_metadata.expert_map
topk_ids = combine_metadata.topk_ids
topk_weights = combine_metadata.topk_weights
ep_recv_counts = combine_metadata.ep_recv_counts
tp_recv_counts = combine_metadata.tp_recv_counts
assist_info_for_combine = combine_metadata.assist_info_for_combine
expand_scales = combine_metadata.expand_scales
assert expert_map is not None
@@ -267,7 +240,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"global_bs": self.global_bs,
}
if self.with_quant:
if combine_metadata.dispatch_with_quant:
tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device)
stage3_kwargs = {
@@ -296,52 +269,44 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2
def token_combine(self, hidden_states, context_metadata, bias=None):
def token_combine(self, hidden_states, combine_metadata, bias=None):
assert bias is None, "Bias is not supported in TokenDispatcherWithMC2."
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, context_metadata)
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, combine_metadata)
combined_output = (
torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2)
if self.enable_dispatch_v2
else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
)
return TokenCombineResult(
routed_out=combined_output,
)
return combined_output
class TokenDispatcherWithAllGather(MoETokenDispatcher):
class TokenDispatcherWithAllGather(MoETokenDispatcher[MoEAllGatherCombineMetadata]):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_router_weight_on_input = False
self.max_num_tokens = kwargs.get("max_num_tokens")
num_experts_local = kwargs.get("num_local_experts", 0)
self.num_experts_local = (
num_experts_local.item() if torch.is_tensor(num_experts_local) else int(num_experts_local)
)
self.original_shape = None
self.with_quant = False
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
token_dispatch_input: MoETokenDispatchInput,
):
self.with_quant = with_quant
self.original_shape = hidden_states.shape
with_quant = token_dispatch_input.quant.is_int_quant
hidden_states = token_dispatch_input.hidden_states
topk_weights = token_dispatch_input.topk_weights
topk_ids = token_dispatch_input.topk_ids
expert_map = token_dispatch_input.routing.expert_map
pertoken_scale = token_dispatch_input.routing.pertoken_scale
global_redundant_expert_num = token_dispatch_input.routing.global_redundant_expert_num
restore_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel()
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
apply_router_weight_on_input = token_dispatch_input.routing.apply_router_weight_on_input
if apply_router_weight_on_input:
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
@@ -365,35 +330,37 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[first_expert_idx, last_expert_idx],
quant_mode=1 if self.with_quant and pertoken_scale is None else -1,
quant_mode=1 if with_quant and pertoken_scale is None else -1,
)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode
context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
return TokenDispatchResult(
return MoETokenDispatchOutput(
hidden_states=sorted_hidden_states,
dynamic_scale=pertoken_scale if self.with_quant else None,
dynamic_scale=pertoken_scale if with_quant else None,
group_list=expert_tokens,
group_list_type=group_list_type,
context_metadata=context_metadata,
combine_metadata=MoEAllGatherCombineMetadata(
topk_weights=topk_weights,
expanded_row_idx=expanded_row_idx,
restore_shape=restore_shape,
),
)
def token_combine(self, hidden_states, context_metadata, bias=None):
assert self.original_shape is not None
def token_combine(self, hidden_states, combine_metadata, bias=None):
final_hidden_states = torch_npu.npu_moe_token_unpermute(
permuted_tokens=hidden_states,
sorted_indices=torch.abs(context_metadata["expanded_row_idx"]),
probs=context_metadata["topk_weights"],
sorted_indices=torch.abs(combine_metadata.expanded_row_idx),
probs=combine_metadata.topk_weights,
)
if len(self.original_shape) == 3:
final_hidden_states = final_hidden_states.view(self.original_shape)
if len(combine_metadata.restore_shape) == 3:
final_hidden_states = final_hidden_states.view(combine_metadata.restore_shape)
# these values are no longer used, so they need to be set to None for memory release.
return TokenCombineResult(routed_out=final_hidden_states)
return final_hidden_states
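One detail worth calling out: the restore shape now travels inside the combine metadata instead of being cached on the dispatcher (`self.original_shape` is gone), so the combine step is a pure function of its inputs. A hedged sketch of that shape-carrying metadata; the dataclass below is illustrative, not the real `MoEAllGatherCombineMetadata`:

```python
from dataclasses import dataclass

import torch


@dataclass
class AllGatherCombineMetadataSketch:
    topk_weights: torch.Tensor
    expanded_row_idx: torch.Tensor
    restore_shape: torch.Size  # original hidden_states shape captured at dispatch time


meta = AllGatherCombineMetadataSketch(
    topk_weights=torch.rand(4, 2),
    expanded_row_idx=torch.arange(8, dtype=torch.int32),
    restore_shape=torch.Size([1, 4, 8]),
)
# Combine can decide purely from the metadata whether a 3-D view is needed.
assert len(meta.restore_shape) == 3
```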
class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
class TokenDispatcherWithAll2AllV(MoETokenDispatcher[MoEAllToAllCombineMetadata]):
"""
The implementation of the AlltoAll-based token dispatcher, which handles token
dispatching on the sequence level instead of token level. The core of this implementation
@@ -402,12 +369,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.with_quant = False
self.num_local_experts = kwargs.get("num_local_experts", 0)
self.hidden_shape = None
self.hidden_shape_before_permute = None
assert self.num_local_experts > 0, "Expected at least one expert"
if self.num_local_experts > 1:
self.expert_ids_per_ep_rank = torch.tensor(
@@ -432,19 +395,12 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
token_dispatch_input: MoETokenDispatchInput,
):
self.with_quant = with_quant
self.hidden_shape = hidden_states.shape
with_quant = token_dispatch_input.quant.is_int_quant
hidden_states = token_dispatch_input.hidden_states
topk_weights = token_dispatch_input.topk_weights
topk_ids = token_dispatch_input.topk_ids
(
permutated_local_input_tokens,
@@ -452,12 +408,13 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
tokens_per_expert,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices,
hidden_shape,
hidden_shape_before_permute,
) = self._dispatch_preprocess(hidden_states, topk_ids)
dynamic_scale_after_all2all = None
if self.with_quant:
if with_quant:
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(permutated_local_input_tokens)
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
dynamic_scale, output_splits, input_splits, self.ep_group
@@ -474,64 +431,66 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
# Postprocess
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = (
self._dispatch_postprocess(
global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices
global_input_tokens,
dynamic_scale_after_all2all,
global_input_tokens_local_experts_indices,
with_quant,
)
)
context_metadata = {
"input_splits": input_splits,
"output_splits": output_splits,
"topk_weights": topk_weights,
"reversed_local_input_permutation_mapping": reversed_local_input_permutation_mapping,
"reversed_global_input_permutation_mapping": reversed_global_input_permutation_mapping,
}
return TokenDispatchResult(
return MoETokenDispatchOutput(
hidden_states=global_input_tokens,
dynamic_scale=dynamic_scale_final,
group_list=tokens_per_expert,
group_list_type=1,
context_metadata=context_metadata,
combine_metadata=MoEAllToAllCombineMetadata(
input_splits=input_splits,
output_splits=output_splits,
topk_weights=topk_weights,
reversed_local_input_permutation_mapping=reversed_local_input_permutation_mapping,
reversed_global_input_permutation_mapping=reversed_global_input_permutation_mapping,
hidden_shape=hidden_shape,
hidden_shape_before_permute=hidden_shape_before_permute,
),
)
def token_combine(self, hidden_states, context_metadata, bias=None):
def token_combine(self, hidden_states, combine_metadata, bias=None):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
# 1. Preprocess using metadata
hidden_states = self._combine_preprocess(hidden_states, context_metadata)
hidden_states = self._combine_preprocess(hidden_states, combine_metadata)
# 2. AllToAll
_, permutated_local_input_tokens, handle = async_all_to_all(
hidden_states,
context_metadata["input_splits"],
context_metadata["output_splits"],
combine_metadata.input_splits,
combine_metadata.output_splits,
self.ep_group,
)
handle.wait()
hidden_states.untyped_storage().resize_(0)
# 3. Postprocess using metadata
output = self._combine_postprocess(permutated_local_input_tokens, context_metadata)
output = self._combine_postprocess(permutated_local_input_tokens, combine_metadata)
return TokenCombineResult(routed_out=output)
return output
def _dispatch_preprocess(self, hidden_states, topk_ids):
assert self.hidden_shape is not None
hidden_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_states.size(-1))
(
tokens_per_expert,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices,
num_out_tokens,
) = self._preprocess(topk_ids)
self.hidden_shape_before_permute = hidden_states.shape
hidden_shape_before_permute = hidden_states.shape
permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute(
tokens=hidden_states,
indices=topk_ids,
num_out_tokens=self.num_out_tokens,
num_out_tokens=num_out_tokens,
)
return (
@@ -540,15 +499,16 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
tokens_per_expert,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices,
hidden_shape,
hidden_shape_before_permute,
)
def _preprocess(self, topk_ids: torch.Tensor):
num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts)
ep_size = self.ep_size
self.num_out_tokens = topk_ids.numel()
num_out_tokens = topk_ids.numel()
input_splits = (
num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
@@ -585,19 +545,19 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
num_tokens_per_local_expert,
input_splits,
output_splits,
num_global_tokens_per_local_expert,
global_input_tokens_local_experts_indices,
num_out_tokens,
)
def _dispatch_postprocess(
self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices
self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices, with_quant
):
# Early return if no local experts or no tokens
if self.num_local_experts <= 1:
return global_input_tokens, dynamic_scale_after_all2all, None
# Handle quantized case
if self.with_quant:
if with_quant:
assert global_input_tokens_local_experts_indices is not None, (
"global_input_tokens_local_experts_indices must be provided"
)
@@ -612,20 +572,26 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
)
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
def _combine_preprocess(self, hidden_states: torch.Tensor, context_metadata: dict) -> torch.Tensor:
def _combine_preprocess(
self, hidden_states: torch.Tensor, combine_metadata: MoEAllToAllCombineMetadata
) -> torch.Tensor:
# Unpermutation 2: expert output to AlltoAll input
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
rev_global = context_metadata["reversed_global_input_permutation_mapping"]
rev_global = combine_metadata.reversed_global_input_permutation_mapping
if hidden_states.shape[0] > 0 and self.num_local_experts > 1 and rev_global is not None:
hidden_states = torch_npu.npu_moe_token_unpermute(hidden_states, rev_global)
return hidden_states
def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor, context_metadata: dict) -> torch.Tensor:
def _combine_postprocess(
self,
permutated_local_input_tokens: torch.Tensor,
combine_metadata: MoEAllToAllCombineMetadata,
) -> torch.Tensor:
# Unpermutation 1: AlltoAll output to output
output = torch_npu.npu_moe_token_unpermute(
permuted_tokens=permutated_local_input_tokens,
sorted_indices=context_metadata["reversed_local_input_permutation_mapping"].to(torch.int32),
probs=context_metadata["topk_weights"],
restore_shape=self.hidden_shape_before_permute,
sorted_indices=combine_metadata.reversed_local_input_permutation_mapping.to(torch.int32),
probs=combine_metadata.topk_weights,
restore_shape=combine_metadata.hidden_shape_before_permute,
)
output = output.view(self.hidden_shape)
output = output.view(combine_metadata.hidden_shape)
return output

View File

@@ -16,24 +16,30 @@
#
"""Ascend quantization module.
This module provides quantization support for Ascend NPU.
Supported quantization tools:
- ModelSlim: Use AscendModelSlimConfig
- LLM-Compressor (compressed_tensors): Use AscendCompressedTensorsConfig
Public API:
- Config classes: AscendModelSlimConfig, AscendCompressedTensorsConfig
- For scheme implementations, import from vllm_ascend.quantization.methods
This module intentionally avoids eager imports so that importing lightweight
submodules (for example ``quant_type``) does not trigger heavy registration
paths and circular imports during startup.
"""
# LLM-Compressor (compressed_tensors) quantization config
from .compressed_tensors_config import AscendCompressedTensorsConfig
from typing import TYPE_CHECKING, Any
# ModelSlim quantization config
from .modelslim_config import AscendModelSlimConfig
if TYPE_CHECKING:
from .compressed_tensors_config import AscendCompressedTensorsConfig
from .modelslim_config import AscendModelSlimConfig
__all__ = [
"AscendModelSlimConfig",
"AscendCompressedTensorsConfig",
]
def __getattr__(name: str) -> Any:
if name == "AscendModelSlimConfig":
from .modelslim_config import AscendModelSlimConfig
return AscendModelSlimConfig
if name == "AscendCompressedTensorsConfig":
from .compressed_tensors_config import AscendCompressedTensorsConfig
return AscendCompressedTensorsConfig
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

View File

@@ -255,28 +255,34 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
enable_force_load_balance: bool = False,
log2phy: torch.Tensor | None = None,
global_redundant_expert_num=0,
**kwargs,
pertoken_scale: torch.Tensor | None = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor:
return self.quant_method.apply(
layer,
x,
router_logits,
top_k,
renormalize,
use_grouped_topk,
global_num_experts,
expert_map,
topk_group,
num_expert_group,
custom_routing_function,
scoring_func,
routed_scaling_factor,
e_score_correction_bias,
is_prefill,
enable_force_load_balance,
log2phy,
global_redundant_expert_num,
**kwargs,
layer=layer,
x=x,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
global_num_experts=global_num_experts,
expert_map=expert_map,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
is_prefill=is_prefill,
enable_force_load_balance=enable_force_load_balance,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
pertoken_scale=pertoken_scale,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
mc2_mask=mc2_mask,
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

View File

@@ -18,19 +18,11 @@
from abc import ABC, abstractmethod
from collections.abc import Callable
from enum import Enum
from typing import Any
import torch
class QuantType(Enum):
"""Quantization type enum for MoE schemes."""
NONE = 0
W8A8 = 1
W4A8 = 2
MXFP8 = 3
from vllm_ascend.quantization.quant_type import QuantType
class AscendLinearScheme(ABC):
@@ -245,7 +237,10 @@ class AscendMoEScheme(ABC):
enable_force_load_balance: bool = False,
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
**kwargs,
pertoken_scale: Any | None = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor:
"""Forward computation for MoE layer.
@@ -268,7 +263,10 @@ class AscendMoEScheme(ABC):
enable_force_load_balance: Whether to force load balancing.
log2phy: Logical to physical expert mapping.
global_redundant_expert_num: Number of redundant experts.
**kwargs: Additional keyword arguments.
pertoken_scale: Optional per-token activation scale from prepare stage.
activation: Expert MLP activation type.
apply_router_weight_on_input: Whether to pre-scale hidden states by router weights.
mc2_mask: Optional mask used by MC2 dispatch.
Returns:
Output tensor after MoE computation.
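Since the docstring now enumerates the former business kwargs as explicit parameters, an illustrative stub (not a real scheme) shows the call shape a test or caller can rely on; it only echoes parameter names from this diff and returns its input unchanged.

```python
import torch


class EchoScheme:
    """Toy stand-in for an AscendMoEScheme; not a real quantization scheme."""

    def apply(
        self,
        layer,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool,
        global_num_experts: int,
        expert_map=None,
        *,
        pertoken_scale=None,
        activation: str = "silu",
        apply_router_weight_on_input: bool = False,
        mc2_mask=None,
    ) -> torch.Tensor:
        # A real scheme would select experts and run fused_experts here.
        return x


out = EchoScheme().apply(
    layer=None,
    x=torch.randn(4, 8),
    router_logits=torch.randn(4, 16),
    top_k=2,
    renormalize=True,
    use_grouped_topk=False,
    global_num_experts=16,
    expert_map=None,
    pertoken_scale=None,
    activation="silu",
    apply_router_weight_on_input=False,
    mc2_mask=None,
)
assert out.shape == (4, 8)
```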

View File

@@ -25,8 +25,9 @@ from vllm.config import get_current_vllm_config
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from .base import AscendMoEScheme
from .base import AscendMoEScheme, QuantType
from .registry import register_scheme
@@ -103,6 +104,8 @@ def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
class AscendW4A16FusedMoEMethod(AscendMoEScheme):
"""FusedMoE method for Ascend W4A16."""
quant_type: QuantType = QuantType.W4A16
def __init__(self) -> None:
self.transpose_weight = True
self.num_bits = 4 # dtype = torch.int4
@@ -192,7 +195,10 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
enable_force_load_balance: bool = True,
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
**kwargs,
pertoken_scale: Any | None = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor:
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
"Number of global experts mismatch (excluding redundancy)"
@@ -217,20 +223,26 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
moe_comm_method = _EXTRA_CTX.moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight_packed,
w2=layer.w2_weight_packed,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_offset=layer.w13_weight_offset,
w2_offset=layer.w2_weight_offset,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int4_w4a16=True,
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask"),
fused_experts_input=build_fused_experts_input(
hidden_states=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1=layer.w13_weight_packed,
w2=layer.w2_weight_packed,
quant_type=self.quant_type,
dynamic_eplb=self.dynamic_eplb,
expert_map=expert_map,
global_redundant_expert_num=global_redundant_expert_num,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
log2phy=log2phy,
pertoken_scale=pertoken_scale,
activation=activation,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_offset=layer.w13_weight_offset,
w2_offset=layer.w2_weight_offset,
)
)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:

View File

@@ -28,6 +28,7 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
@@ -343,7 +344,10 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
enable_force_load_balance: bool = False,
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
**kwargs,
pertoken_scale: torch.Tensor | None = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor:
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
"Number of global experts mismatch (excluding redundancy)"
@@ -377,20 +381,26 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
moe_comm_method = _EXTRA_CTX.moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=[layer.w13_weight],
w2=[layer.w2_weight],
w1_scale=[layer.w13_weight_scale],
w2_scale=[layer.w2_weight_scale],
w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None,
w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int4_w4a8=True,
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask"),
fused_experts_input=build_fused_experts_input(
hidden_states=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1=[layer.w13_weight],
w2=[layer.w2_weight],
quant_type=self.quant_type,
dynamic_eplb=self.dynamic_eplb,
expert_map=expert_map,
global_redundant_expert_num=global_redundant_expert_num,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
log2phy=log2phy,
pertoken_scale=pertoken_scale,
activation=activation,
w1_scale=[layer.w13_weight_scale],
w2_scale=[layer.w2_weight_scale],
w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None,
w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None,
)
)
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):

View File

@@ -29,6 +29,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.flash_common3_context import get_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
@@ -182,7 +183,9 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
log2phy: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
pertoken_scale: Any | None = None,
**kwargs,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
@@ -249,19 +252,24 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale]
final_hidden_states = moe_comm_method.fused_experts(
hidden_states=x,
pertoken_scale=pertoken_scale,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int8_w8a8=True,
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask"),
fused_experts_input=build_fused_experts_input(
hidden_states=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1=w1,
w2=w2,
quant_type=self.quant_type,
dynamic_eplb=self.dynamic_eplb,
expert_map=expert_map,
global_redundant_expert_num=global_redundant_expert_num,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
log2phy=log2phy,
pertoken_scale=pertoken_scale,
activation=activation,
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
)
)
if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result

View File

@@ -31,6 +31,7 @@ from vllm_ascend.device.mxfp_compat import (
ensure_mxfp8_moe_available,
)
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
from .registry import register_scheme
@@ -170,7 +171,10 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
enable_force_load_balance: bool = True,
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
**kwargs,
pertoken_scale: Any | None = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
mc2_mask: torch.Tensor | None = None,
) -> torch.Tensor:
expected = global_num_experts - global_redundant_expert_num
assert router_logits.shape[1] == expected, "Number of global experts mismatch (excluding redundancy)"
@@ -198,23 +202,29 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
moe_comm_method = _EXTRA_CTX.moe_comm_method
return moe_comm_method.fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_int8_w8a8=False,
expert_map=expert_map,
log2phy=log2phy,
dynamic_eplb=self.dynamic_eplb,
mc2_mask=kwargs.get("mc2_mask"),
use_mxfp_quant=True,
act_quant_type=torch.float8_e4m3fn,
weight_quant_type=torch.float8_e4m3fn,
scale_type=FLOAT8_E8M0FNU_DTYPE,
per_token_scale_type=FLOAT8_E8M0FNU_DTYPE,
fused_experts_input=build_fused_experts_input(
hidden_states=x,
topk_weights=topk_weights,
topk_ids=topk_ids,
w1=layer.w13_weight,
w2=layer.w2_weight,
quant_type=self.quant_type,
dynamic_eplb=self.dynamic_eplb,
expert_map=expert_map,
global_redundant_expert_num=global_redundant_expert_num,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
log2phy=log2phy,
pertoken_scale=pertoken_scale,
activation=activation,
mxfp_act_quant_type=torch.float8_e4m3fn,
mxfp_weight_quant_type=torch.float8_e4m3fn,
mxfp_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
mxfp_per_token_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
mxfp_use_bf16=(x.dtype == torch.bfloat16),
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
)
)
def process_weights_after_loading(self, layer):

View File

@@ -0,0 +1,33 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Shared quantization enum definitions.
Keep this module lightweight and side-effect free so core runtime modules can
import QuantType without triggering heavy quantization package initialization.
"""
from enum import Enum
class QuantType(Enum):
"""Quantization type enum for MoE schemes."""
NONE = 0
W8A8 = 1
W4A8 = 2
MXFP8 = 3
W4A16 = 4
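A short example of the intended lightweight usage, assuming only this new module: a scheme can tag itself with the shared enum (mirroring the `quant_type` class attribute the MoE schemes now declare) without importing the heavy quantization package; `ToyScheme` below is hypothetical.

```python
from vllm_ascend.quantization.quant_type import QuantType


class ToyScheme:
    # Mirrors how the MoE schemes advertise their flavour in this PR,
    # e.g. `quant_type: QuantType = QuantType.W4A16` on the W4A16 method.
    quant_type: QuantType = QuantType.W4A16


assert ToyScheme.quant_type is not QuantType.NONE
assert QuantType.W4A16.value == 4
```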