[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it? Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries. - Prepare, dispatch, MLP, and quant stages now have clearer ownership. - Main MoE path no longer depends on business `kwargs.get(...)` lookups. - Comm and dispatcher interfaces are request-only on the main path. - UTs can assert stage-level fields directly instead of inferring behavior indirectly. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed. --------- Signed-off-by: linfeng-yuan <1102311262@qq.com>
This commit is contained in:
@@ -18,19 +18,11 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class QuantType(Enum):
|
||||
"""Quantization type enum for MoE schemes."""
|
||||
|
||||
NONE = 0
|
||||
W8A8 = 1
|
||||
W4A8 = 2
|
||||
MXFP8 = 3
|
||||
from vllm_ascend.quantization.quant_type import QuantType
|
||||
|
||||
|
||||
class AscendLinearScheme(ABC):
|
||||
@@ -245,7 +237,10 @@ class AscendMoEScheme(ABC):
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs,
|
||||
pertoken_scale: Any | None = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
"""Forward computation for MoE layer.
|
||||
|
||||
@@ -268,7 +263,10 @@ class AscendMoEScheme(ABC):
|
||||
enable_force_load_balance: Whether to force load balancing.
|
||||
log2phy: Logical to physical expert mapping.
|
||||
global_redundant_expert_num: Number of redundant experts.
|
||||
**kwargs: Additional keyword arguments.
|
||||
pertoken_scale: Optional per-token activation scale from prepare stage.
|
||||
activation: Expert MLP activation type.
|
||||
apply_router_weight_on_input: Whether to pre-scale hidden states by router weights.
|
||||
mc2_mask: Optional mask used by MC2 dispatch.
|
||||
|
||||
Returns:
|
||||
Output tensor after MoE computation.
|
||||
|
||||
@@ -25,8 +25,9 @@ from vllm.config import get_current_vllm_config
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||
|
||||
from .base import AscendMoEScheme
|
||||
from .base import AscendMoEScheme, QuantType
|
||||
from .registry import register_scheme
|
||||
|
||||
|
||||
@@ -103,6 +104,8 @@ def pack_to_int32(weight: torch.Tensor) -> torch.Tensor:
|
||||
class AscendW4A16FusedMoEMethod(AscendMoEScheme):
|
||||
"""FusedMoE method for Ascend W4A16."""
|
||||
|
||||
quant_type: QuantType = QuantType.W4A16
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.transpose_weight = True
|
||||
self.num_bits = 4 # dtype = torch.int4
|
||||
@@ -192,7 +195,10 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs,
|
||||
pertoken_scale: Any | None = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
|
||||
"Number of global experts mismatch (excluding redundancy)"
|
||||
@@ -217,20 +223,26 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme):
|
||||
|
||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight_packed,
|
||||
w2=layer.w2_weight_packed,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_offset=layer.w13_weight_offset,
|
||||
w2_offset=layer.w2_weight_offset,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int4_w4a16=True,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
mc2_mask=kwargs.get("mc2_mask"),
|
||||
fused_experts_input=build_fused_experts_input(
|
||||
hidden_states=x,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
w1=layer.w13_weight_packed,
|
||||
w2=layer.w2_weight_packed,
|
||||
quant_type=self.quant_type,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
expert_map=expert_map,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
mc2_mask=mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
log2phy=log2phy,
|
||||
pertoken_scale=pertoken_scale,
|
||||
activation=activation,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
w1_offset=layer.w13_weight_offset,
|
||||
w2_offset=layer.w2_weight_offset,
|
||||
)
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
|
||||
@@ -28,6 +28,7 @@ from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||
from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz
|
||||
|
||||
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
||||
@@ -343,7 +344,10 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, (
|
||||
"Number of global experts mismatch (excluding redundancy)"
|
||||
@@ -377,20 +381,26 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
|
||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=[layer.w13_weight],
|
||||
w2=[layer.w2_weight],
|
||||
w1_scale=[layer.w13_weight_scale],
|
||||
w2_scale=[layer.w2_weight_scale],
|
||||
w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None,
|
||||
w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int4_w4a8=True,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
mc2_mask=kwargs.get("mc2_mask"),
|
||||
fused_experts_input=build_fused_experts_input(
|
||||
hidden_states=x,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
w1=[layer.w13_weight],
|
||||
w2=[layer.w2_weight],
|
||||
quant_type=self.quant_type,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
expert_map=expert_map,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
mc2_mask=mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
log2phy=log2phy,
|
||||
pertoken_scale=pertoken_scale,
|
||||
activation=activation,
|
||||
w1_scale=[layer.w13_weight_scale],
|
||||
w2_scale=[layer.w2_weight_scale],
|
||||
w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None,
|
||||
w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None,
|
||||
)
|
||||
)
|
||||
|
||||
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||
|
||||
@@ -29,6 +29,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.flash_common3_context import get_flash_common3_context
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz
|
||||
|
||||
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
||||
@@ -182,7 +183,9 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
log2phy: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
pertoken_scale: Any | None = None,
|
||||
**kwargs,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
||||
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
||||
@@ -249,19 +252,24 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale]
|
||||
|
||||
final_hidden_states = moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
pertoken_scale=pertoken_scale,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int8_w8a8=True,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
mc2_mask=kwargs.get("mc2_mask"),
|
||||
fused_experts_input=build_fused_experts_input(
|
||||
hidden_states=x,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
quant_type=self.quant_type,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
expert_map=expert_map,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
mc2_mask=mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
log2phy=log2phy,
|
||||
pertoken_scale=pertoken_scale,
|
||||
activation=activation,
|
||||
w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale,
|
||||
w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale,
|
||||
)
|
||||
)
|
||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||
final_hidden_states += zero_expert_result
|
||||
|
||||
@@ -31,6 +31,7 @@ from vllm_ascend.device.mxfp_compat import (
|
||||
ensure_mxfp8_moe_available,
|
||||
)
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input
|
||||
|
||||
from .base import AscendLinearScheme, AscendMoEScheme, QuantType
|
||||
from .registry import register_scheme
|
||||
@@ -170,7 +171,10 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs,
|
||||
pertoken_scale: Any | None = None,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
) -> torch.Tensor:
|
||||
expected = global_num_experts - global_redundant_expert_num
|
||||
assert router_logits.shape[1] == expected, "Number of global experts mismatch (excluding redundancy)"
|
||||
@@ -198,23 +202,29 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme):
|
||||
|
||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
use_int8_w8a8=False,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
mc2_mask=kwargs.get("mc2_mask"),
|
||||
use_mxfp_quant=True,
|
||||
act_quant_type=torch.float8_e4m3fn,
|
||||
weight_quant_type=torch.float8_e4m3fn,
|
||||
scale_type=FLOAT8_E8M0FNU_DTYPE,
|
||||
per_token_scale_type=FLOAT8_E8M0FNU_DTYPE,
|
||||
fused_experts_input=build_fused_experts_input(
|
||||
hidden_states=x,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
quant_type=self.quant_type,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
expert_map=expert_map,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
mc2_mask=mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
log2phy=log2phy,
|
||||
pertoken_scale=pertoken_scale,
|
||||
activation=activation,
|
||||
mxfp_act_quant_type=torch.float8_e4m3fn,
|
||||
mxfp_weight_quant_type=torch.float8_e4m3fn,
|
||||
mxfp_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
|
||||
mxfp_per_token_scale_dtype=FLOAT8_E8M0FNU_DTYPE,
|
||||
mxfp_use_bf16=(x.dtype == torch.bfloat16),
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
)
|
||||
)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
|
||||
Reference in New Issue
Block a user