### What this PR does / why we need it?

Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- The main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly (a minimal sketch of the pattern follows below).

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
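For reviewers, a minimal sketch of the typed request-object pattern, distilled from the unit tests below (the builder and field names are the real ones from `moe_runtime_args`; tensor shapes and values are illustrative only):

```python
import torch

from vllm_ascend.ops.fused_moe.moe_runtime_args import (
    build_fused_experts_input,
    build_token_dispatch_input,
)
from vllm_ascend.quantization.quant_type import QuantType

# One typed request object replaces the old business **kwargs: every field
# has an explicit home (.weights, .routing, .quant) instead of a dict key.
fused_experts_input = build_fused_experts_input(
    hidden_states=torch.randn(4, 8),
    topk_weights=torch.randn(4, 2),
    topk_ids=torch.randint(0, 4, (4, 2), dtype=torch.int32),
    w1=torch.randn(2, 8, 16),
    w2=torch.randn(2, 16, 8),
    quant_type=QuantType.NONE,
    dynamic_eplb=False,
)

# Stage boundaries are explicit: the dispatch stage derives its request from
# the fused-experts one (optionally with remapped topk_ids) and shares the
# grouped sub-objects by identity, which is what the UTs assert directly.
token_dispatch_input = build_token_dispatch_input(
    fused_experts_input=fused_experts_input,
    topk_ids=fused_experts_input.topk_ids,
)
assert token_dispatch_input.routing is fused_experts_input.routing
assert token_dispatch_input.quant is fused_experts_input.quant
```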
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
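"""Unit tests for ``vllm_ascend.ops.fused_moe.moe_runtime_args``.

Covers the typed MoE request objects and the builder helpers
(``build_fused_experts_input``, ``build_token_dispatch_input``,
``build_mlp_compute_input``) that replace the old business ``**kwargs``.
"""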
import unittest

import torch

import vllm_ascend.ops.fused_moe.moe_runtime_args as runtime_args
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
    MoEAllGatherCombineMetadata,
    MoETokenDispatchOutput,
    MoEWeights,
    build_fused_experts_input,
    build_mlp_compute_input,
    build_token_dispatch_input,
)
from vllm_ascend.quantization.quant_type import QuantType

class TestMoERuntimeArgs(unittest.TestCase):

    def test_runtime_args_facade_exports_public_contracts_and_builders(self):
        expected_symbols = [
            "MoEAllGatherCombineMetadata",
            "MoEAllToAllCombineMetadata",
            "MoEFusedExpertsInput",
            "MoEMC2CombineMetadata",
            "MoEMlpComputeInput",
            "MoEPrepareOutput",
            "MoEQuantParams",
            "MoERoutingParams",
            "MoETokenDispatchInput",
            "MoETokenDispatchOutput",
            "MoEWeights",
            "TMoECombineMetadata",
            "build_fused_experts_input",
            "build_mlp_compute_input",
            "build_token_dispatch_input",
        ]

        for symbol in expected_symbols:
            with self.subTest(symbol=symbol):
                self.assertTrue(hasattr(runtime_args, symbol))
        self.assertFalse(hasattr(runtime_args, "MoEMxfpParams"))

    def test_build_fused_experts_input_preserves_runtime_semantics(self):
        for quant_type in (
                QuantType.NONE,
                QuantType.W4A16,
                QuantType.W4A8,
                QuantType.W8A8,
                QuantType.MXFP8,
        ):
            with self.subTest(quant_type=quant_type):
                hidden_states = torch.randn(4, 8)
                topk_weights = torch.randn(4, 2)
                topk_ids = torch.randint(0, 4, (4, 2), dtype=torch.int32)
                fused_experts_input = build_fused_experts_input(
                    hidden_states=hidden_states,
                    topk_weights=topk_weights,
                    topk_ids=topk_ids,
                    w1=torch.randn(2, 8, 16),
                    w2=torch.randn(2, 16, 8),
                    quant_type=quant_type,
                    dynamic_eplb=True,
                    expert_map=torch.tensor([0, 1, 2, 3], dtype=torch.int32),
                    global_redundant_expert_num=2,
                    mc2_mask=torch.tensor([True, False, True, False]),
                    apply_router_weight_on_input=True,
                    log2phy=torch.tensor([3, 2, 1, 0], dtype=torch.int32),
                    pertoken_scale=torch.randn(4),
                    activation="gelu",
                    mxfp_act_quant_type=torch.float8_e4m3fn
                    if quant_type == QuantType.MXFP8 else None,
                )

                self.assertIs(fused_experts_input.hidden_states, hidden_states)
                self.assertIs(fused_experts_input.topk_weights, topk_weights)
                self.assertIs(fused_experts_input.topk_ids, topk_ids)
                self.assertTrue(fused_experts_input.dynamic_eplb)
                self.assertTrue(
                    fused_experts_input.routing.apply_router_weight_on_input)
                self.assertEqual(
                    fused_experts_input.routing.global_redundant_expert_num, 2)
                self.assertEqual(fused_experts_input.activation, "gelu")
                self.assertEqual(fused_experts_input.quant.quant_type,
                                 quant_type)

    def test_build_fused_experts_input_merges_dense_and_quant_weights(self):
        w1 = torch.randn(2, 8, 16)
        w2 = torch.randn(2, 16, 8)
        w1_scale = [torch.randn(1)]
        w2_scale = [torch.randn(1)]
        w1_scale_bias = torch.randn(1)
        w2_scale_bias = torch.randn(1)
        w1_offset = torch.randn(1)
        w2_offset = torch.randn(1)

        fused_experts_input = build_fused_experts_input(
            hidden_states=torch.randn(4, 8),
            topk_weights=torch.randn(4, 2),
            topk_ids=torch.randint(0, 4, (4, 2), dtype=torch.int32),
            w1=w1,
            w2=w2,
            quant_type=QuantType.W8A8,
            dynamic_eplb=False,
            w1_scale=w1_scale,
            w2_scale=w2_scale,
            w1_scale_bias=w1_scale_bias,
            w2_scale_bias=w2_scale_bias,
            w1_offset=w1_offset,
            w2_offset=w2_offset,
        )

        self.assertIsInstance(fused_experts_input.weights, MoEWeights)
        self.assertIs(fused_experts_input.weights.w1, w1)
        self.assertIs(fused_experts_input.weights.w2, w2)
        self.assertIs(fused_experts_input.weights.w1_scale, w1_scale)
        self.assertIs(fused_experts_input.weights.w2_scale, w2_scale)
        self.assertIs(fused_experts_input.weights.w1_scale_bias, w1_scale_bias)
        self.assertIs(fused_experts_input.weights.w2_scale_bias, w2_scale_bias)
        self.assertIs(fused_experts_input.weights.w1_offset, w1_offset)
        self.assertIs(fused_experts_input.weights.w2_offset, w2_offset)

    def test_build_token_dispatch_input_supports_remapped_topk_ids(self):
        fused_experts_input = build_fused_experts_input(
            hidden_states=torch.randn(2, 4),
            topk_weights=torch.randn(2, 1),
            topk_ids=torch.tensor([[0], [1]], dtype=torch.int32),
            w1=torch.randn(1, 4, 8),
            w2=torch.randn(1, 8, 4),
            quant_type=QuantType.NONE,
            dynamic_eplb=False,
        )
        routed_topk_ids = torch.tensor([[3], [2]], dtype=torch.int32)

        token_dispatch_input = build_token_dispatch_input(
            fused_experts_input=fused_experts_input,
            topk_ids=routed_topk_ids,
        )

        self.assertIs(token_dispatch_input.hidden_states,
                      fused_experts_input.hidden_states)
        self.assertIs(token_dispatch_input.topk_weights,
                      fused_experts_input.topk_weights)
        self.assertIs(token_dispatch_input.routing, fused_experts_input.routing)
        self.assertIs(token_dispatch_input.quant, fused_experts_input.quant)
        self.assertIs(token_dispatch_input.topk_ids, routed_topk_ids)

    def test_build_fused_experts_input_requires_primitive_mxfp_params_for_mxfp_quant(self):
        with self.assertRaisesRegex(ValueError,
                                    "primitive MXFP params are required"):
            build_fused_experts_input(
                hidden_states=torch.randn(2, 8),
                topk_weights=torch.randn(2, 2),
                topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
                w1=torch.randn(2, 8, 16),
                w2=torch.randn(2, 16, 8),
                quant_type=QuantType.MXFP8,
                dynamic_eplb=False,
            )

    def test_build_mlp_compute_input_derives_fusion_and_preserves_mxfp_params(self):
        fused_experts_input = build_fused_experts_input(
            hidden_states=torch.randn(2, 8, dtype=torch.bfloat16),
            topk_weights=torch.randn(2, 2),
            topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
            w1=torch.randn(2, 8, 16),
            w2=torch.randn(2, 16, 8),
            quant_type=QuantType.MXFP8,
            dynamic_eplb=False,
            mxfp_act_quant_type=torch.float8_e4m3fn,
            mxfp_weight_quant_type=torch.float8_e4m3fn,
            mxfp_scale_dtype=torch.float32,
            mxfp_per_token_scale_dtype=torch.float16,
            mxfp_use_bf16=False,
            w1_scale=[torch.randn(1)],
            w2_scale=[torch.randn(1)],
        )
        token_dispatch_output = MoETokenDispatchOutput(
            hidden_states=torch.randn(4, 8, dtype=torch.bfloat16),
            group_list=torch.tensor([2, 2], dtype=torch.int64),
            group_list_type=1,
            dynamic_scale=torch.randn(4, 1),
            combine_metadata=MoEAllGatherCombineMetadata(
                topk_weights=fused_experts_input.topk_weights,
                expanded_row_idx=torch.arange(4, dtype=torch.int32),
                restore_shape=torch.Size([2, 8]),
            ),
        )

        mlp_compute_input = build_mlp_compute_input(
            fused_experts_input=fused_experts_input,
            token_dispatch_output=token_dispatch_output,
            use_fusion_ops=True,
        )

        self.assertIs(mlp_compute_input.hidden_states,
                      token_dispatch_output.hidden_states)
        self.assertIs(mlp_compute_input.weights, fused_experts_input.weights)
        self.assertIs(mlp_compute_input.weights.w1_scale,
                      fused_experts_input.weights.w1_scale)
        self.assertIs(mlp_compute_input.weights.w2_scale,
                      fused_experts_input.weights.w2_scale)
        self.assertTrue(mlp_compute_input.fusion)
        self.assertTrue(mlp_compute_input.quant.is_mxfp)
        assert mlp_compute_input.quant.mxfp is not None
        self.assertEqual(mlp_compute_input.quant.mxfp.scale_dtype,
                         torch.float32)
        self.assertEqual(mlp_compute_input.quant.mxfp.per_token_scale_dtype,
                         torch.float16)
        self.assertFalse(mlp_compute_input.quant.mxfp.use_bf16)

    def test_build_fused_experts_input_constructs_internal_mxfp_leaf_from_primitives(self):
        fused_experts_input = build_fused_experts_input(
            hidden_states=torch.randn(2, 8, dtype=torch.bfloat16),
            topk_weights=torch.randn(2, 2),
            topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
            w1=torch.randn(2, 8, 16),
            w2=torch.randn(2, 16, 8),
            quant_type=QuantType.MXFP8,
            dynamic_eplb=False,
            mxfp_act_quant_type=torch.float8_e4m3fn,
            mxfp_weight_quant_type=torch.float8_e4m3fn,
            mxfp_scale_dtype=torch.float32,
            mxfp_per_token_scale_dtype=torch.float16,
            mxfp_use_bf16=False,
        )

        self.assertTrue(fused_experts_input.quant.is_mxfp)
        assert fused_experts_input.quant.mxfp is not None
        self.assertEqual(fused_experts_input.quant.mxfp.act_quant_type,
                         torch.float8_e4m3fn)
        self.assertEqual(fused_experts_input.quant.mxfp.weight_quant_type,
                         torch.float8_e4m3fn)
        self.assertEqual(fused_experts_input.quant.mxfp.scale_dtype,
                         torch.float32)
        self.assertEqual(fused_experts_input.quant.mxfp.per_token_scale_dtype,
                         torch.float16)
        self.assertFalse(fused_experts_input.quant.mxfp.use_bf16)

if __name__ == "__main__":
    unittest.main(verbosity=2)