[refactor] replace scattered business kwargs with typed request objects and explicit stage boundaries (#7024)
### What this PR does / why we need it?

Refactor `vllm_ascend/ops/fused_moe` to replace scattered MoE business `**kwargs` with typed request objects and explicit stage boundaries.

- Prepare, dispatch, MLP, and quant stages now have clearer ownership.
- The main MoE path no longer depends on business `kwargs.get(...)` lookups.
- Comm and dispatcher interfaces are request-only on the main path.
- UTs can assert stage-level fields directly instead of inferring behavior indirectly.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

CI passed.

---------

Signed-off-by: linfeng-yuan <1102311262@qq.com>
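For reviewers who want the shape of the change at a glance, here is a minimal sketch of the new MLP-stage call pattern. Class, field, and function names are taken from the diff below; the tensor shapes are illustrative only, and the snippet assumes an Ascend environment where the `vllm_ascend` imports resolve.

```python
import torch

from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.moe_runtime_args import (
    MoEMlpComputeInput,
    MoEQuantParams,
    MoEWeights,
)
from vllm_ascend.quantization.quant_type import QuantType

# Before: stage inputs traveled as scattered business kwargs, e.g.
#   unified_apply_mlp(hidden_states=..., w1=..., w1_scale=None, w2=...,
#                     group_list=..., group_list_type=1, with_quant=False, ...)

# After: one typed request object owns all MLP-stage inputs.
hidden_states = torch.randn(10, 128, dtype=torch.float16)  # illustrative shapes
w1 = torch.randn(5, 128, 256, dtype=torch.float16)
w2 = torch.randn(5, 256, 128, dtype=torch.float16)
group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)

mlp_request = MoEMlpComputeInput(
    hidden_states=hidden_states,
    group_list=group_list,
    group_list_type=1,
    dynamic_scale=None,
    topk_scales=None,
    weights=MoEWeights(w1=w1, w2=w2, w1_scale=None, w2_scale=None),
    quant=MoEQuantParams(quant_type=QuantType.NONE),
    fusion=False,
    activation="silu",
    need_trans=False,
    dynamic_eplb=False,
)
result = unified_apply_mlp(mlp_compute_input=mlp_request)
```

The UTs below assert stage-level fields such as `mlp_compute_input.fusion` and `quant.is_mxfp` directly on these objects instead of inferring behavior from scattered kwargs.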
@@ -19,6 +19,38 @@ import torch
 
 from tests.ut.base import TestBase
 from vllm_ascend._310p.fused_moe.moe_mlp import unified_apply_mlp
+from vllm_ascend.ops.fused_moe.moe_runtime_args import (
+    MoEMlpComputeInput,
+    MoEQuantParams,
+    MoEWeights,
+)
+from vllm_ascend.quantization.quant_type import QuantType
+
+
+def build_mlp_compute_input_fixture(
+    *,
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    group_list: torch.Tensor,
+    with_quant: bool,
+    w1_scale: torch.Tensor | None = None,
+    w2_scale: torch.Tensor | None = None,
+    group_list_type: int = 1,
+) -> MoEMlpComputeInput:
+    return MoEMlpComputeInput(
+        hidden_states=hidden_states,
+        group_list=group_list,
+        group_list_type=group_list_type,
+        dynamic_scale=None,
+        topk_scales=None,
+        weights=MoEWeights(w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale),
+        quant=MoEQuantParams(quant_type=QuantType.W8A8 if with_quant else QuantType.NONE),
+        fusion=False,
+        activation="silu",
+        need_trans=False,
+        dynamic_eplb=False,
+    )
 
 
 class TestUnifiedApplyMLP310(TestBase):
@@ -38,14 +70,13 @@ class TestUnifiedApplyMLP310(TestBase):
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
 
         result = unified_apply_mlp(
-            hidden_states=hidden_states,
-            w1=w1,
-            w1_scale=None,
-            w2=w2,
-            w2_scale=None,
-            group_list=group_list,
-            group_list_type=1,
-            with_quant=False,
+            mlp_compute_input=build_mlp_compute_input_fixture(
+                hidden_states=hidden_states,
+                w1=w1,
+                w2=w2,
+                group_list=group_list,
+                with_quant=False,
+            )
         )
 
         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
@@ -94,14 +125,15 @@ class TestUnifiedApplyMLP310(TestBase):
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
 
         result = unified_apply_mlp(
-            hidden_states=hidden_states,
-            w1=w1,
-            w1_scale=w1_scale,
-            w2=w2,
-            w2_scale=w2_scale,
-            group_list=group_list,
-            group_list_type=1,
-            with_quant=True,
+            mlp_compute_input=build_mlp_compute_input_fixture(
+                hidden_states=hidden_states,
+                w1=w1,
+                w2=w2,
+                group_list=group_list,
+                with_quant=True,
+                w1_scale=w1_scale,
+                w2_scale=w2_scale,
+            )
         )
 
         mock_cumsum.assert_called_once()
@@ -95,4 +95,4 @@ def test_SiluAndMul_forward_310p(
     assert torch.allclose(actual_arg, expected_arg), "swiglu called with unexpected input"
 
     expected_out = (dummy_tensor[..., :h] + 1) * dummy_tensor[..., h:]
-    assert torch.allclose(out, expected_out)
+    assert torch.allclose(out, expected_out)
@@ -12,7 +12,7 @@
 # limitations under the License.
 # This file is a part of the vllm-ascend project.
 #
-from typing import List, TypedDict
+from typing import TypedDict
 from unittest.mock import MagicMock, patch
 
 import pytest
@@ -20,12 +20,19 @@ import torch
 import torch.nn as nn
 import torch_npu
 from pytest_mock import MockerFixture
 
 from tests.ut.base import TestBase
+from vllm_ascend.ascend_forward_context import MoECommType
 from vllm_ascend.ops.fused_moe.experts_selector import select_experts
 from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod
-from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list,
-                                               unified_apply_mlp)
+from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp
+from vllm_ascend.ops.fused_moe.moe_runtime_args import (
+    MoEMlpComputeInput,
+    MoEPrepareOutput,
+    MoEQuantParams,
+    MoEWeights,
+)
+from vllm_ascend.quantization.quant_type import QuantType
 from vllm_ascend.utils import AscendDeviceType, adapt_patch
 
 adapt_patch(True)
@@ -54,6 +61,51 @@ def mock_npu_format_cast(weight_data, format):
     return weight_data
 
 
+def build_mlp_compute_input_fixture(
+    *,
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor | list[torch.Tensor],
+    w2: torch.Tensor | list[torch.Tensor],
+    group_list: torch.Tensor,
+    with_quant: bool,
+    group_list_type: int = 1,
+    dynamic_scale: torch.Tensor | None = None,
+    topk_scales: torch.Tensor | None = None,
+    w1_scale: torch.Tensor | list[torch.Tensor] | None = None,
+    w2_scale: torch.Tensor | list[torch.Tensor] | None = None,
+    w1_scale_bias: torch.Tensor | None = None,
+    w2_scale_bias: torch.Tensor | None = None,
+    w1_offset: torch.Tensor | None = None,
+    w2_offset: torch.Tensor | None = None,
+    fusion: bool = False,
+    activation: str = "silu",
+    need_trans: bool = True,
+    dynamic_eplb: bool = False,
+) -> MoEMlpComputeInput:
+    return MoEMlpComputeInput(
+        hidden_states=hidden_states,
+        group_list=group_list,
+        group_list_type=group_list_type,
+        dynamic_scale=dynamic_scale,
+        topk_scales=topk_scales,
+        weights=MoEWeights(
+            w1=w1,
+            w2=w2,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+            w1_offset=w1_offset,
+            w2_offset=w2_offset,
+        ),
+        quant=MoEQuantParams(quant_type=QuantType.W8A8 if with_quant else QuantType.NONE),
+        fusion=fusion,
+        activation=activation,
+        need_trans=need_trans,
+        dynamic_eplb=dynamic_eplb,
+    )
+
+
 @pytest.fixture(autouse=True)
 def setup_vllm_config_mock(mocker: MockerFixture):
     mock_hf_config = MagicMock()
@@ -77,7 +129,13 @@ def mock_dist_env(mocker: MockerFixture):
     mock_moe_comm_method = MagicMock()
 
     def mock_prepare(hidden_states, router_logits, **kwargs):
-        return hidden_states, router_logits
+        return MoEPrepareOutput(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            mc2_mask=kwargs.get("mc2_mask"),
+            padded_hidden_states_shape=None,
+            pertoken_scale=None,
+        )
 
     mock_moe_comm_method.prepare.side_effect = mock_prepare
@@ -204,18 +262,18 @@ def moe_method(mock_dist_env):
 
 class Device(TypedDict):
     device_id: int
-    device_expert: List[int]
+    device_expert: list[int]
 
 
 class Layer(TypedDict):
     layer_id: int
     device_count: int
-    device_list: List[Device]
+    device_list: list[Device]
 
 
 class MockData(TypedDict):
     moe_layer_count: int
-    layer_list: List[Layer]
+    layer_list: list[Layer]
 
 
 class MockQuantMethod(nn.Module):
@@ -338,18 +396,15 @@ class TestUnifiedApplyMLP(TestBase):
         w2_scale = torch.randn(5, 20, dtype=torch.bfloat16)
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
 
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=w1_scale,
-                                   w2=w2,
-                                   w2_scale=w2_scale,
-                                   group_list=group_list,
-                                   dynamic_scale=None,
-                                   group_list_type=1,
-                                   w1_scale_bias=None,
-                                   w2_scale_bias=None,
-                                   topk_scales=None,
-                                   with_quant=True)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=True,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+        ))
 
         mock_get_forward_context.assert_called()
@@ -383,18 +438,14 @@ class TestUnifiedApplyMLP(TestBase):
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
         topk_scales = torch.randn(10, 1, dtype=torch.float16)
 
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=None,
-                                   w2=w2,
-                                   w2_scale=None,
-                                   group_list=group_list,
-                                   dynamic_scale=None,
-                                   group_list_type=1,
-                                   w1_scale_bias=None,
-                                   w2_scale_bias=None,
-                                   topk_scales=topk_scales,
-                                   with_quant=False)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=False,
+            topk_scales=topk_scales,
+        ))
 
         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
         mock_npu_swiglu.assert_called_once()
@@ -445,18 +496,18 @@ class TestUnifiedApplyMLP(TestBase):
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
         provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
 
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=w1_scale,
-                                   w2=w2,
-                                   w2_scale=w2_scale,
-                                   group_list=group_list,
-                                   dynamic_scale=provided_dynamic_scale,
-                                   group_list_type=1,
-                                   w1_scale_bias=w1_scale_bias,
-                                   w2_scale_bias=w2_scale_bias,
-                                   topk_scales=None,
-                                   with_quant=True)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=True,
+            dynamic_scale=provided_dynamic_scale,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+        ))
 
         mock_get_forward_context.assert_called()
@@ -490,18 +541,14 @@ class TestUnifiedApplyMLP(TestBase):
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
         topk_scales = torch.randn(10, 1, dtype=torch.float16)
 
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=None,
-                                   w2=w2,
-                                   w2_scale=None,
-                                   group_list=group_list,
-                                   dynamic_scale=None,
-                                   group_list_type=1,
-                                   w1_scale_bias=None,
-                                   w2_scale_bias=None,
-                                   topk_scales=topk_scales,
-                                   with_quant=False)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=False,
+            topk_scales=topk_scales,
+        ))
 
         self.assertEqual(mock_npu_grouped_matmul.call_count, 2)
         mock_npu_swiglu.assert_called_once()
@@ -556,19 +603,19 @@ class TestUnifiedApplyMLP(TestBase):
         group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64)
         provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32)
 
-        result = unified_apply_mlp(hidden_states=hidden_states,
-                                   w1=w1,
-                                   w1_scale=w1_scale,
-                                   w2=w2,
-                                   w2_scale=w2_scale,
-                                   group_list=group_list,
-                                   dynamic_scale=provided_dynamic_scale,
-                                   group_list_type=1,
-                                   w1_scale_bias=w1_scale_bias,
-                                   w2_scale_bias=w2_scale_bias,
-                                   topk_scales=None,
-                                   with_quant=True,
-                                   fusion=True)
+        result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture(
+            hidden_states=hidden_states,
+            w1=w1,
+            w2=w2,
+            group_list=group_list,
+            with_quant=True,
+            dynamic_scale=provided_dynamic_scale,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+            fusion=True,
+        ))
 
         mock_get_forward_context.assert_called()
         mock_npu_grouped_matmul.assert_called_once()
@@ -4,12 +4,21 @@ import torch
 from vllm.model_executor.layers.fused_moe import FusedMoEConfig
 
 from tests.ut.base import TestBase
-from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
-                                                       AlltoAllCommImpl,
-                                                       MC2CommImpl)
-from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult,
-                                                        TokenDispatchResult)
+from vllm_ascend.ops.fused_moe.moe_comm_method import (
+    AllGatherCommImpl,
+    AlltoAllCommImpl,
+    MC2CommImpl,
+)
+from vllm_ascend.ops.fused_moe.moe_runtime_args import (
+    MoEAllGatherCombineMetadata,
+    MoEFusedExpertsInput,
+    MoEPrepareOutput,
+    MoEQuantParams,
+    MoERoutingParams,
+    MoEWeights,
+)
+from vllm_ascend.ops.fused_moe.token_dispatcher import MoETokenDispatchOutput
+from vllm_ascend.quantization.methods.base import QuantType
 
 
 class TestMoECommMethod(TestBase):
@@ -45,8 +54,11 @@ class TestMoECommMethod(TestBase):
 
         # Mock prepare finalize
         mock_pf_instance = MagicMock()
-        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
-                                                 torch.randn(4, 2), None, None)
+        mock_pf_instance.prepare.return_value = MoEPrepareOutput(
+            hidden_states=torch.randn(4, 8),
+            router_logits=torch.randn(4, 2),
+            mc2_mask=None,
+            padded_hidden_states_shape=None)
         mock_pf_instance.finalize.return_value = torch.randn(4, 8)
         mock_prepare_finalize.return_value = mock_pf_instance
@@ -60,8 +72,9 @@ class TestMoECommMethod(TestBase):
         # Test prepare method
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
-        h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
-            hidden_states, router_logits)
+        prepare_output = comm_impl.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
@@ -70,7 +83,7 @@ class TestMoECommMethod(TestBase):
         # Test finalize method
         comm_impl.finalize(h_out,
                            reduce_results=True,
-                           context_metadata=context_metadata)
+                           padded_hidden_states_shape=padded_hidden_states_shape)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
 
     @patch('vllm_ascend.ascend_forward_context.get_forward_context')
@@ -86,10 +99,11 @@ class TestMoECommMethod(TestBase):
 
         # Mock prepare finalize
         mock_pf_instance = MagicMock()
-        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
-                                                 torch.randn(4, 2),
-                                                 torch.tensor([1, 0, 1,
-                                                               0]), None)
+        mock_pf_instance.prepare.return_value = MoEPrepareOutput(
+            hidden_states=torch.randn(4, 8),
+            router_logits=torch.randn(4, 2),
+            mc2_mask=torch.tensor([1, 0, 1, 0]),
+            padded_hidden_states_shape=None)
         mock_pf_instance.finalize.return_value = torch.randn(4, 8)
         mock_prepare_finalize.return_value = mock_pf_instance
@@ -103,8 +117,9 @@ class TestMoECommMethod(TestBase):
         # Test prepare method
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
-        h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
-            hidden_states, router_logits)
+        prepare_output = comm_impl.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
@@ -113,7 +128,7 @@ class TestMoECommMethod(TestBase):
         # Test finalize method
         comm_impl.finalize(h_out,
                            reduce_results=True,
-                           context_metadata=context_metadata)
+                           padded_hidden_states_shape=padded_hidden_states_shape)
         mock_pf_instance.finalize.assert_called_once_with(h_out, True, None)
 
     @patch('vllm_ascend.ascend_forward_context.get_forward_context')
@@ -133,8 +148,11 @@ class TestMoECommMethod(TestBase):
 
         # Mock prepare finalize
         mock_pf_instance = MagicMock()
-        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
-                                                 torch.randn(4, 2), None, None)
+        mock_pf_instance.prepare.return_value = MoEPrepareOutput(
+            hidden_states=torch.randn(4, 8),
+            router_logits=torch.randn(4, 2),
+            mc2_mask=None,
+            padded_hidden_states_shape=None)
         mock_pf_instance.finalize.return_value = torch.randn(4, 8)
         mock_prepare_finalize.return_value = mock_pf_instance
@@ -148,8 +166,7 @@ class TestMoECommMethod(TestBase):
         # Test prepare method
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
-        h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(
-            hidden_states, router_logits)
+        _ = comm_impl.prepare(hidden_states, router_logits)
 
         # Verify prepare was called with correct arguments
         mock_pf_instance.prepare.assert_called_once_with(
@@ -174,19 +191,27 @@ class TestMoECommMethod(TestBase):
 
         # Mock prepare finalize
         mock_pf_instance = MagicMock()
-        mock_pf_instance.prepare.return_value = (torch.randn(4, 8),
-                                                 torch.randn(4, 2), None)
+        mock_pf_instance.prepare.return_value = MoEPrepareOutput(
+            hidden_states=torch.randn(4, 8),
+            router_logits=torch.randn(4, 2),
+            mc2_mask=None,
+            padded_hidden_states_shape=None)
         mock_pf_instance.finalize.return_value = torch.randn(4, 8)
         mock_prepare_finalize.return_value = mock_pf_instance
 
         # Mock token dispatcher
         mock_td_instance = MagicMock()
-        mock_td_instance.token_dispatch.return_value = TokenDispatchResult(
-            hidden_states=torch.randn(6, 8),
-            group_list=torch.tensor([2, 2, 2]),
-            group_list_type=1)
-        mock_td_instance.token_combine.return_value = TokenCombineResult(
-            routed_out=torch.randn(4, 8))
+        dispatch_topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2], [0.6, 0.4]])
+        mock_td_instance.token_dispatch.return_value = MoETokenDispatchOutput(
+            hidden_states=torch.randn(6, 8),
+            group_list=torch.tensor([2, 2, 2]),
+            group_list_type=1,
+            combine_metadata=MoEAllGatherCombineMetadata(
+                topk_weights=dispatch_topk_weights,
+                expanded_row_idx=torch.arange(8, dtype=torch.int32),
+                restore_shape=torch.Size([4, 8]),
+            ))
+        mock_td_instance.token_combine.return_value = torch.randn(4, 8)
         mock_token_dispatcher.return_value = mock_td_instance
 
         # Mock unified_apply_mlp
@@ -199,8 +224,7 @@ class TestMoECommMethod(TestBase):
         hidden_states = torch.randn(4, 8).contiguous()
         w1 = torch.randn(16, 8).contiguous()
         w2 = torch.randn(16, 8).contiguous()
-        topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2],
-                                     [0.6, 0.4]])
+        topk_weights = dispatch_topk_weights
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]])
 
         # Make sure tensors are contiguous and have correct strides
@@ -208,12 +232,25 @@ class TestMoECommMethod(TestBase):
         w1 = w1.contiguous()
         w2 = w2.contiguous()
 
-        result = comm_impl.fused_experts(hidden_states=hidden_states,
-                                         w1=[w1],
-                                         w2=[w2],
-                                         topk_weights=topk_weights,
-                                         topk_ids=topk_ids,
-                                         activation="silu")
+        result = comm_impl.fused_experts(fused_experts_input=MoEFusedExpertsInput(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            weights=MoEWeights(
+                w1=[w1],
+                w2=[w2],
+            ),
+            routing=MoERoutingParams(
+                expert_map=None,
+                global_redundant_expert_num=0,
+                mc2_mask=None,
+                apply_router_weight_on_input=False,
+            ),
+            activation="silu",
+            need_trans=False,
+            dynamic_eplb=False,
+            quant=MoEQuantParams(),
+        ))
 
         # Verify result shape
         self.assertEqual(result.routed_out.shape, (4, 8))
@@ -223,6 +260,12 @@ class TestMoECommMethod(TestBase):
 
         # Verify unified_apply_mlp was called
         mock_unified_apply_mlp.assert_called_once()
+        mlp_compute_input = mock_unified_apply_mlp.call_args.kwargs["mlp_compute_input"]
+        self.assertFalse(mlp_compute_input.fusion)
+        self.assertFalse(mlp_compute_input.quant.is_mxfp)
 
         # Verify token_combine was called
-        mock_td_instance.token_combine.assert_called_once()
+        mock_td_instance.token_combine.assert_called_once_with(
+            hidden_states=mock_unified_apply_mlp.return_value,
+            combine_metadata=mock_td_instance.token_dispatch.return_value.combine_metadata,
+        )
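The prepare/finalize hunks in this file all follow the same contract change; condensed, it looks like the sketch below. Names come from the diff above, `comm_impl` stands for any of the comm implementations under test, and the surrounding setup is omitted.

```python
# Old contract: a positional tuple plus an opaque context_metadata.
#   h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare(hidden_states, router_logits)

# New contract: prepare returns a MoEPrepareOutput with named stage-level fields,
# and finalize takes the one field it actually needs.
prepare_output = comm_impl.prepare(hidden_states, router_logits)
out = comm_impl.finalize(
    prepare_output.hidden_states,
    reduce_results=True,
    padded_hidden_states_shape=prepare_output.padded_hidden_states_shape,
)
```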
@@ -1,9 +1,17 @@
 import unittest
 from typing import ClassVar
+from unittest.mock import patch
 
 import torch
 
-from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list
+from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp
+from vllm_ascend.ops.fused_moe.moe_runtime_args import (
+    MoEMlpComputeInput,
+    MoEQuantParams,
+    MoEWeights,
+)
+from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
+from vllm_ascend.quantization.quant_type import QuantType
 
 
 class TestCumsumGroupList(unittest.TestCase):
@@ -14,7 +22,7 @@ class TestCumsumGroupList(unittest.TestCase):
         cls.glist_dict = {
             0: torch.tensor([0, 2, 3, 3]),
             1: torch.tensor([0, 2, 1, 0]),
-            2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]])
+            2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]]),
         }
 
     support_combine = [(0, 0), (1, 0), (0, 1)]
@@ -23,29 +31,101 @@ class TestCumsumGroupList(unittest.TestCase):
     def test_cumsum_group_list_supported_conversion(self):
         for src_list_type, dst_list_type in self.support_combine:
             with self.subTest(src=src_list_type, dst=dst_list_type):
-                result = cumsum_group_list(self.glist_dict[src_list_type],
-                                           src_list_type,
-                                           dst_list_type,
-                                           expert_num=4)
-                self.assertTrue(
-                    torch.equal(result, self.glist_dict[dst_list_type]))
+                result = cumsum_group_list(self.glist_dict[src_list_type], src_list_type, dst_list_type, expert_num=4)
+                self.assertTrue(torch.equal(result, self.glist_dict[dst_list_type]))
 
     def test_cumsum_group_list_invalid_type_valueerror(self):
         with self.assertRaises(ValueError) as excinfo:
             cumsum_group_list(self.glist_dict[0], 4, 0)
-        self.assertIn("group_list_type should be in [0, 1, 2], but received",
-                      str(excinfo.exception))
+        self.assertIn("group_list_type should be in [0, 1, 2], but received", str(excinfo.exception))
 
-    def test_cumsum_group_list_unsupported_conversion_notimplementederror(
-            self):
+    def test_cumsum_group_list_unsupported_conversion_notimplementederror(self):
         for src_list_type, dst_list_type in self.unsupported_combine:
             with self.subTest(src=src_list_type, dst=dst_list_type):
                 with self.assertRaises(NotImplementedError) as excinfo:
-                    cumsum_group_list(self.glist_dict[0], src_list_type,
-                                      dst_list_type)
-                self.assertIn("This feature is under development.",
-                              str(excinfo.exception))
+                    cumsum_group_list(self.glist_dict[0], src_list_type, dst_list_type)
+                self.assertIn("This feature is under development.", str(excinfo.exception))
 
 
-if __name__ == '__main__':
+class TestUnifiedApplyMlpRequest(unittest.TestCase):
+    def test_request_unquant_path(self):
+        hidden_states = torch.randn(2, 8)
+        expected = torch.randn(2, 8)
+        mlp_compute_input = MoEMlpComputeInput(
+            hidden_states=hidden_states,
+            group_list=torch.tensor([2, 2], dtype=torch.int64),
+            group_list_type=1,
+            dynamic_scale=None,
+            topk_scales=None,
+            weights=MoEWeights(
+                w1=torch.randn(1, 16, 8),
+                w2=torch.randn(1, 8, 8),
+                w1_bias=torch.randn(1, 16),
+                w2_bias=torch.randn(1, 8),
+            ),
+            quant=MoEQuantParams(quant_type=QuantType.NONE),
+            fusion=False,
+            activation="silu",
+            need_trans=False,
+            dynamic_eplb=False,
+        )
+
+        with (
+            patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp", return_value=expected) as mock_unquant,
+            patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp") as mock_quant,
+        ):
+            output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
+
+        self.assertTrue(output is expected)
+        mock_unquant.assert_called_once()
+        self.assertEqual(mock_unquant.call_args.kwargs["activation"], "silu")
+        self.assertFalse(mock_unquant.call_args.kwargs["need_trans"])
+        mock_quant.assert_not_called()
+
+    def test_request_quant_path(self):
+        hidden_states = torch.randn(2, 8)
+        expected = torch.randn(2, 8)
+        mlp_compute_input = MoEMlpComputeInput(
+            hidden_states=hidden_states,
+            group_list=torch.tensor([2, 2], dtype=torch.int64),
+            group_list_type=1,
+            dynamic_scale=torch.randn(2, 1),
+            topk_scales=None,
+            weights=MoEWeights(
+                w1=torch.randn(1, 16, 8),
+                w2=torch.randn(1, 8, 8),
+                w1_scale=[torch.randn(1)],
+                w2_scale=[torch.randn(1)],
+            ),
+            quant=MoEQuantParams(
+                quant_type=QuantType.MXFP8,
+                mxfp=MoEMxfpParams(
+                    act_quant_type=torch.float8_e4m3fn,
+                    weight_quant_type=torch.float8_e4m3fn,
+                    use_bf16=False,
+                ),
+            ),
+            fusion=True,
+            activation="silu",
+            need_trans=False,
+            dynamic_eplb=True,
+        )
+
+        with (
+            patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp", return_value=expected) as mock_quant,
+            patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp") as mock_unquant,
+        ):
+            output = unified_apply_mlp(mlp_compute_input=mlp_compute_input)
+
+        self.assertTrue(output is expected)
+        mock_quant.assert_called_once()
+        quant_kwargs = mock_quant.call_args.kwargs
+        self.assertTrue(quant_kwargs["use_mxfp_quant"])
+        self.assertTrue(quant_kwargs["fusion"])
+        self.assertTrue(quant_kwargs["dynamic_eplb"])
+        self.assertFalse(quant_kwargs["use_bf16"])
+        mock_unquant.assert_not_called()
+
+
+if __name__ == "__main__":
     unittest.main(verbosity=2)
tests/ut/ops/test_moe_runtime_args.py (new file, 240 lines)
@@ -0,0 +1,240 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import unittest
+
+import torch
+
+import vllm_ascend.ops.fused_moe.moe_runtime_args as runtime_args
+from vllm_ascend.ops.fused_moe.moe_runtime_args import (
+    MoEAllGatherCombineMetadata,
+    MoETokenDispatchOutput,
+    MoEWeights,
+    build_fused_experts_input,
+    build_mlp_compute_input,
+    build_token_dispatch_input,
+)
+from vllm_ascend.quantization.quant_type import QuantType
+
+
+class TestMoERuntimeArgs(unittest.TestCase):
+    def test_runtime_args_facade_exports_public_contracts_and_builders(self):
+        expected_symbols = [
+            "MoEAllGatherCombineMetadata",
+            "MoEAllToAllCombineMetadata",
+            "MoEFusedExpertsInput",
+            "MoEMC2CombineMetadata",
+            "MoEMlpComputeInput",
+            "MoEPrepareOutput",
+            "MoEQuantParams",
+            "MoERoutingParams",
+            "MoETokenDispatchInput",
+            "MoETokenDispatchOutput",
+            "MoEWeights",
+            "TMoECombineMetadata",
+            "build_fused_experts_input",
+            "build_mlp_compute_input",
+            "build_token_dispatch_input",
+        ]
+
+        for symbol in expected_symbols:
+            with self.subTest(symbol=symbol):
+                self.assertTrue(hasattr(runtime_args, symbol))
+        self.assertFalse(hasattr(runtime_args, "MoEMxfpParams"))
+
+    def test_build_fused_experts_input_preserves_runtime_semantics(self):
+        for quant_type in (
+            QuantType.NONE,
+            QuantType.W4A16,
+            QuantType.W4A8,
+            QuantType.W8A8,
+            QuantType.MXFP8,
+        ):
+            with self.subTest(quant_type=quant_type):
+                hidden_states = torch.randn(4, 8)
+                topk_weights = torch.randn(4, 2)
+                topk_ids = torch.randint(0, 4, (4, 2), dtype=torch.int32)
+                fused_experts_input = build_fused_experts_input(
+                    hidden_states=hidden_states,
+                    topk_weights=topk_weights,
+                    topk_ids=topk_ids,
+                    w1=torch.randn(2, 8, 16),
+                    w2=torch.randn(2, 16, 8),
+                    quant_type=quant_type,
+                    dynamic_eplb=True,
+                    expert_map=torch.tensor([0, 1, 2, 3], dtype=torch.int32),
+                    global_redundant_expert_num=2,
+                    mc2_mask=torch.tensor([True, False, True, False]),
+                    apply_router_weight_on_input=True,
+                    log2phy=torch.tensor([3, 2, 1, 0], dtype=torch.int32),
+                    pertoken_scale=torch.randn(4),
+                    activation="gelu",
+                    mxfp_act_quant_type=torch.float8_e4m3fn if quant_type == QuantType.MXFP8 else None,
+                )
+
+                self.assertIs(fused_experts_input.hidden_states, hidden_states)
+                self.assertIs(fused_experts_input.topk_weights, topk_weights)
+                self.assertIs(fused_experts_input.topk_ids, topk_ids)
+                self.assertTrue(fused_experts_input.dynamic_eplb)
+                self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
+                self.assertEqual(fused_experts_input.routing.global_redundant_expert_num, 2)
+                self.assertEqual(fused_experts_input.activation, "gelu")
+                self.assertEqual(fused_experts_input.quant.quant_type, quant_type)
+
+    def test_build_fused_experts_input_merges_dense_and_quant_weights(self):
+        w1 = torch.randn(2, 8, 16)
+        w2 = torch.randn(2, 16, 8)
+        w1_scale = [torch.randn(1)]
+        w2_scale = [torch.randn(1)]
+        w1_scale_bias = torch.randn(1)
+        w2_scale_bias = torch.randn(1)
+        w1_offset = torch.randn(1)
+        w2_offset = torch.randn(1)
+
+        fused_experts_input = build_fused_experts_input(
+            hidden_states=torch.randn(4, 8),
+            topk_weights=torch.randn(4, 2),
+            topk_ids=torch.randint(0, 4, (4, 2), dtype=torch.int32),
+            w1=w1,
+            w2=w2,
+            quant_type=QuantType.W8A8,
+            dynamic_eplb=False,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
+            w1_scale_bias=w1_scale_bias,
+            w2_scale_bias=w2_scale_bias,
+            w1_offset=w1_offset,
+            w2_offset=w2_offset,
+        )
+
+        self.assertIsInstance(fused_experts_input.weights, MoEWeights)
+        self.assertIs(fused_experts_input.weights.w1, w1)
+        self.assertIs(fused_experts_input.weights.w2, w2)
+        self.assertIs(fused_experts_input.weights.w1_scale, w1_scale)
+        self.assertIs(fused_experts_input.weights.w2_scale, w2_scale)
+        self.assertIs(fused_experts_input.weights.w1_scale_bias, w1_scale_bias)
+        self.assertIs(fused_experts_input.weights.w2_scale_bias, w2_scale_bias)
+        self.assertIs(fused_experts_input.weights.w1_offset, w1_offset)
+        self.assertIs(fused_experts_input.weights.w2_offset, w2_offset)
+
+    def test_build_token_dispatch_input_supports_remapped_topk_ids(self):
+        fused_experts_input = build_fused_experts_input(
+            hidden_states=torch.randn(2, 4),
+            topk_weights=torch.randn(2, 1),
+            topk_ids=torch.tensor([[0], [1]], dtype=torch.int32),
+            w1=torch.randn(1, 4, 8),
+            w2=torch.randn(1, 8, 4),
+            quant_type=QuantType.NONE,
+            dynamic_eplb=False,
+        )
+        routed_topk_ids = torch.tensor([[3], [2]], dtype=torch.int32)
+
+        token_dispatch_input = build_token_dispatch_input(
+            fused_experts_input=fused_experts_input,
+            topk_ids=routed_topk_ids,
+        )
+
+        self.assertIs(token_dispatch_input.hidden_states, fused_experts_input.hidden_states)
+        self.assertIs(token_dispatch_input.topk_weights, fused_experts_input.topk_weights)
+        self.assertIs(token_dispatch_input.routing, fused_experts_input.routing)
+        self.assertIs(token_dispatch_input.quant, fused_experts_input.quant)
+        self.assertIs(token_dispatch_input.topk_ids, routed_topk_ids)
+
+    def test_build_fused_experts_input_requires_primitive_mxfp_params_for_mxfp_quant(self):
+        with self.assertRaisesRegex(ValueError, "primitive MXFP params are required"):
+            build_fused_experts_input(
+                hidden_states=torch.randn(2, 8),
+                topk_weights=torch.randn(2, 2),
+                topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
+                w1=torch.randn(2, 8, 16),
+                w2=torch.randn(2, 16, 8),
+                quant_type=QuantType.MXFP8,
+                dynamic_eplb=False,
+            )
+
+    def test_build_mlp_compute_input_derives_fusion_and_preserves_mxfp_params(self):
+        fused_experts_input = build_fused_experts_input(
+            hidden_states=torch.randn(2, 8, dtype=torch.bfloat16),
+            topk_weights=torch.randn(2, 2),
+            topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
+            w1=torch.randn(2, 8, 16),
+            w2=torch.randn(2, 16, 8),
+            quant_type=QuantType.MXFP8,
+            dynamic_eplb=False,
+            mxfp_act_quant_type=torch.float8_e4m3fn,
+            mxfp_weight_quant_type=torch.float8_e4m3fn,
+            mxfp_scale_dtype=torch.float32,
+            mxfp_per_token_scale_dtype=torch.float16,
+            mxfp_use_bf16=False,
+            w1_scale=[torch.randn(1)],
+            w2_scale=[torch.randn(1)],
+        )
+        token_dispatch_output = MoETokenDispatchOutput(
+            hidden_states=torch.randn(4, 8, dtype=torch.bfloat16),
+            group_list=torch.tensor([2, 2], dtype=torch.int64),
+            group_list_type=1,
+            dynamic_scale=torch.randn(4, 1),
+            combine_metadata=MoEAllGatherCombineMetadata(
+                topk_weights=fused_experts_input.topk_weights,
+                expanded_row_idx=torch.arange(4, dtype=torch.int32),
+                restore_shape=torch.Size([2, 8]),
+            ),
+        )
+
+        mlp_compute_input = build_mlp_compute_input(
+            fused_experts_input=fused_experts_input,
+            token_dispatch_output=token_dispatch_output,
+            use_fusion_ops=True,
+        )
+
+        self.assertIs(mlp_compute_input.hidden_states, token_dispatch_output.hidden_states)
+        self.assertIs(mlp_compute_input.weights, fused_experts_input.weights)
+        self.assertIs(mlp_compute_input.weights.w1_scale, fused_experts_input.weights.w1_scale)
+        self.assertIs(mlp_compute_input.weights.w2_scale, fused_experts_input.weights.w2_scale)
+        self.assertTrue(mlp_compute_input.fusion)
+        self.assertTrue(mlp_compute_input.quant.is_mxfp)
+        assert mlp_compute_input.quant.mxfp is not None
+        self.assertEqual(mlp_compute_input.quant.mxfp.scale_dtype, torch.float32)
+        self.assertEqual(mlp_compute_input.quant.mxfp.per_token_scale_dtype, torch.float16)
+        self.assertFalse(mlp_compute_input.quant.mxfp.use_bf16)
+
+    def test_build_fused_experts_input_constructs_internal_mxfp_leaf_from_primitives(self):
+        fused_experts_input = build_fused_experts_input(
+            hidden_states=torch.randn(2, 8, dtype=torch.bfloat16),
+            topk_weights=torch.randn(2, 2),
+            topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32),
+            w1=torch.randn(2, 8, 16),
+            w2=torch.randn(2, 16, 8),
+            quant_type=QuantType.MXFP8,
+            dynamic_eplb=False,
+            mxfp_act_quant_type=torch.float8_e4m3fn,
+            mxfp_weight_quant_type=torch.float8_e4m3fn,
+            mxfp_scale_dtype=torch.float32,
+            mxfp_per_token_scale_dtype=torch.float16,
+            mxfp_use_bf16=False,
+        )
+
+        self.assertTrue(fused_experts_input.quant.is_mxfp)
+        assert fused_experts_input.quant.mxfp is not None
+        self.assertEqual(fused_experts_input.quant.mxfp.act_quant_type, torch.float8_e4m3fn)
+        self.assertEqual(fused_experts_input.quant.mxfp.weight_quant_type, torch.float8_e4m3fn)
+        self.assertEqual(fused_experts_input.quant.mxfp.scale_dtype, torch.float32)
+        self.assertEqual(fused_experts_input.quant.mxfp.per_token_scale_dtype, torch.float16)
+        self.assertFalse(fused_experts_input.quant.mxfp.use_bf16)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
@@ -45,18 +45,22 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
 
-        h_out, r_out, mask, context_metadata = layer.prepare(
-            hidden_states, router_logits)
+        prepare_output = layer.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        mask = prepare_output.mc2_mask
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # Check padding and split
         self.assertEqual(h_out.shape[0], 4)
         self.assertEqual(r_out.shape[0], 4)
         self.assertEqual(mask.tolist(), [1, 0, 1])
+        self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8]))
 
         # Finalize
         result = layer.finalize(h_out,
                                 reduce_results=False,
-                                context_metadata=context_metadata)
+                                padded_hidden_states_shape=padded_hidden_states_shape)
         self.assertEqual(result.shape[0], 3)
 
     @patch(
@@ -79,14 +83,19 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(4, 8)
         router_logits = torch.randn(4, 2)
 
-        h_out, r_out, mask, context_metadata = layer.prepare(
+        prepare_output = layer.prepare(
             hidden_states,
             router_logits,
             enable_shared_expert_dp=False,
             replace_allreduce=False)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        mask = prepare_output.mc2_mask
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # With TP=2, should split into 2 parts
         self.assertEqual(h_out.shape[0], 2)
+        self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8]))
 
         # Mock all_gather behavior
         def mock_all_gather_func(tensor_list, tensor, group=None):
@@ -101,7 +110,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
         ]
         final_result = layer.finalize(h_out,
                                       reduce_results=False,
-                                      context_metadata=context_metadata)
+                                      padded_hidden_states_shape=padded_hidden_states_shape)
 
         # Should concat back to original size
         self.assertEqual(final_result.shape[0], 4)
@@ -117,15 +126,18 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
 
-        h_out, r_out, _, context_metadata = layer.prepare(
-            hidden_states, router_logits)
+        prepare_output = layer.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # Pad to tp_size=1, so no change
         self.assertEqual(h_out.shape[0], 3)
+        self.assertEqual(padded_hidden_states_shape, torch.Size([3, 8]))
 
         result = layer.finalize(h_out,
                                 reduce_results=False,
-                                context_metadata=context_metadata)
+                                padded_hidden_states_shape=padded_hidden_states_shape)
         self.assertEqual(result.shape[0], 3)
 
     @patch(
@@ -141,14 +153,18 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(2, 8)
         router_logits = torch.randn(2, 2)
 
-        h_out, r_out, _, context_metadata = layer.prepare(
+        prepare_output = layer.prepare(
            hidden_states,
            router_logits,
            enable_shared_expert_dp=False,
            replace_allreduce=False)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # Split due to TP=2
         self.assertEqual(h_out.shape[0], 1)
+        self.assertEqual(padded_hidden_states_shape, torch.Size([2, 8]))
 
         # Mock all_gather
         def mock_all_gather_func(tensor_list, tensor, group=None):
@@ -163,7 +179,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
         ]
         final_result = layer.finalize(h_out,
                                       reduce_results=False,
-                                      context_metadata=context_metadata)
+                                      padded_hidden_states_shape=padded_hidden_states_shape)
 
         # Should concat back
         self.assertEqual(final_result.shape[0], 2)
@@ -200,12 +216,15 @@ class TestPrepareAndFinalize(unittest.TestCase):
         hidden_states = torch.randn(3, 8)
         router_logits = torch.randn(3, 2)
 
-        h_out, r_out, _, context_metadata = layer.prepare(
-            hidden_states, router_logits)
+        prepare_output = layer.prepare(hidden_states, router_logits)
+        h_out = prepare_output.hidden_states
+        r_out = prepare_output.router_logits
+        padded_hidden_states_shape = prepare_output.padded_hidden_states_shape
 
         # After all-gather with DP=2, should double the batch size
         self.assertEqual(h_out.shape[0], 12)
         self.assertEqual(r_out.shape[0], 12)
+        self.assertIsNone(padded_hidden_states_shape)
 
         # Finalize with reduce_scatter
         def mock_reduce_scatter_func(tensor, dim):
@@ -215,7 +234,7 @@ class TestPrepareAndFinalize(unittest.TestCase):
         mock_dp_group.reduce_scatter = mock_reduce_scatter_func
         result = layer.finalize(h_out,
                                 reduce_results=False,
-                                context_metadata=context_metadata)
+                                padded_hidden_states_shape=padded_hidden_states_shape)
 
         self.assertEqual(result.shape[0], 3)
@@ -17,14 +17,62 @@
 
 from unittest.mock import MagicMock, PropertyMock, patch
 
+import numpy as np
 import pytest
 import torch
 
 from tests.ut.base import TestBase
+from vllm_ascend.ops.fused_moe.moe_runtime_args import (
+    MoEAllGatherCombineMetadata,
+    MoEAllToAllCombineMetadata,
+    MoEMC2CombineMetadata,
+    MoEQuantParams,
+    MoERoutingParams,
+    MoETokenDispatchInput,
+)
 from vllm_ascend.ops.fused_moe.token_dispatcher import (  # isort: skip
-    AscendDeviceType, TokenDispatcherWithAll2AllV,
-    TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
+    AscendDeviceType,
+    TokenDispatcherWithAll2AllV,
+    TokenDispatcherWithAllGather,
+    TokenDispatcherWithMC2,
+)
+from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams
+from vllm_ascend.quantization.quant_type import QuantType
+
+
+def build_token_dispatch_input_fixture(
+    *,
+    hidden_states: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    expert_map: torch.Tensor | None = None,
+    global_redundant_expert_num: int = 0,
+    apply_router_weight_on_input: bool = False,
+    pertoken_scale: torch.Tensor | None = None,
+    quant_type: QuantType = QuantType.NONE,
+    comm_quant_mode: int | None = None,
+    act_quant_type: torch.dtype | None = None,
+) -> MoETokenDispatchInput:
+    mxfp_spec = None
+    if quant_type == QuantType.MXFP8:
+        mxfp_spec = MoEMxfpParams(act_quant_type=act_quant_type)
+    return MoETokenDispatchInput(
+        hidden_states=hidden_states,
+        topk_weights=topk_weights,
+        topk_ids=topk_ids,
+        routing=MoERoutingParams(
+            expert_map=expert_map,
+            global_redundant_expert_num=global_redundant_expert_num,
+            mc2_mask=None,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            pertoken_scale=pertoken_scale,
+        ),
+        quant=MoEQuantParams(
+            quant_type=quant_type,
+            comm_quant_mode=comm_quant_mode,
+            mxfp=mxfp_spec,
+        ),
+    )
+
+
 class TestTokenDispatcherWithMC2(TestBase):
@@ -85,7 +133,6 @@ class TestTokenDispatcherWithMC2(TestBase):
     def test_init(self):
         self.assertEqual(self.dispatcher.ep_rank_id, 0)
         self.assertEqual(self.dispatcher.ep_world_size, 8)
-        self.assertFalse(self.dispatcher.with_quant)
         self.assertTrue(self.dispatcher.enable_dispatch_v2)
         self.assertTrue(self.dispatcher.need_extra_args)
@@ -94,10 +141,16 @@ class TestTokenDispatcherWithMC2(TestBase):
         topk_ids = torch.randint(0, 8, (10, 1))
         topk_weights = torch.randn(10, 1)
         expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-        mc2_mask = None
 
-        kwargs = self.dispatcher.get_dispatch_mc2_kwargs(
-            hidden_states, topk_weights, topk_ids, expert_map, mc2_mask)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            expert_map=expert_map,
+            global_redundant_expert_num=0,
+            apply_router_weight_on_input=False,
+            pertoken_scale=None,
+        )
+        kwargs = self.dispatcher.get_dispatch_mc2_kwargs(token_dispatch_input)
         self.assertIn("x", kwargs)
         self.assertIn("expert_ids", kwargs)
         self.assertEqual(kwargs["moe_expert_num"], 8)
@@ -111,39 +164,42 @@ class TestTokenDispatcherWithMC2(TestBase):
         with patch("torch_npu.npu_moe_distribute_dispatch_v2",
                    return_value=(torch.randn(10, 128), ) * 5 +
                    (None, None)) as mock_dispatch:
-            output = self.dispatcher.token_dispatch(hidden_states,
-                                                    topk_weights, topk_ids,
-                                                    expert_map)
+            token_dispatch_input = build_token_dispatch_input_fixture(
+                hidden_states=hidden_states,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                expert_map=expert_map,
+            )
+            output = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
             mock_dispatch.assert_called_once()
             self.assertEqual(output.group_list_type, 0)  # group_list_type == 0
+            self.assertIsInstance(output.combine_metadata, MoEMC2CombineMetadata)
 
     def test_get_combine_mc_kwargs_with_quant(self):
         self.dispatcher.with_quant = True
         hidden_states = torch.randn(10, 128)
         topk_ids = torch.randint(0, 8, (10, 1))
         topk_weights = torch.randn(10, 1)
         expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
         ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
         tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
-        mc2_mask = None
         assist_info_for_combine = torch.arange(10)
 
-        context_metadata = {
-            "topk_ids": topk_ids,
-            "topk_weights": topk_weights,
-            "expert_map": expert_map,
-            "ep_recv_counts": ep_recv_counts,
-            "mc2_mask": mc2_mask,
-            "assist_info_for_combine": assist_info_for_combine,
-            "expand_scales": None,
-            "tp_recv_counts": tp_recv_counts
-        }
+        combine_metadata = MoEMC2CombineMetadata(
+            topk_ids=topk_ids,
+            topk_weights=topk_weights,
+            expert_map=expert_map,
+            ep_recv_counts=ep_recv_counts,
+            tp_recv_counts=tp_recv_counts,
+            assist_info_for_combine=assist_info_for_combine,
+            expand_scales=None,
+            dispatch_with_quant=True,
+        )
 
         self.dispatcher.need_extra_args = True
         self.dispatcher.enable_dispatch_v2 = True
         self.dispatcher.moe_expert_num = len(expert_map)
         kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states,
-                                                       context_metadata)
+                                                       combine_metadata)
         self.assertIn("tp_send_counts", kwargs)
@@ -188,14 +244,19 @@ class TestTokenDispatcherWithAllGather(TestBase):
         topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
 
-        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                 topk_ids, None)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+        )
+        results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
 
         # Verify npu_moe_init_routing is called
         self.mock_npu_moe_init_routing_custom.assert_called_once()
         args, kwargs = self.mock_npu_moe_init_routing_custom.call_args
 
         self.assertEqual(results.group_list_type, 1)
+        self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)
 
     @pytest.mark.skip(
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -205,14 +266,19 @@ class TestTokenDispatcherWithAllGather(TestBase):
         topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
 
-        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                 topk_ids, None)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+        )
+        results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
 
         # Verify npu_moe_init_routing is called
         self.mock_npu_moe_init_routing_custom.assert_called_once()
         args, kwargs = self.mock_npu_moe_init_routing_custom.call_args
 
         self.assertEqual(results.group_list_type, 1)
+        self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)
 
     @pytest.mark.skip(
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -230,9 +296,12 @@ class TestTokenDispatcherWithAllGather(TestBase):
         topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
 
-        results = self.dispatcher_quant.token_dispatch(hidden_states,
-                                                       topk_weights, topk_ids,
-                                                       None)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+        )
+        results = self.dispatcher_quant.token_dispatch(token_dispatch_input=token_dispatch_input)
 
         self.assertEqual(results.group_list_type, 1)
@@ -252,11 +321,13 @@ class TestTokenDispatcherWithAllGather(TestBase):
         topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]])
         topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]])
 
-        results = self.dispatcher_quant.token_dispatch(hidden_states,
-                                                       topk_weights,
-                                                       topk_ids,
-                                                       None,
-                                                       with_quant=True)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            quant_type=QuantType.W8A8,
+        )
+        results = self.dispatcher_quant.token_dispatch(token_dispatch_input=token_dispatch_input)
 
         self.assertIsNotNone(results.hidden_states)
         self.assertIsNotNone(results.group_list)
@@ -267,40 +338,43 @@ class TestTokenDispatcherWithAllGather(TestBase):
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
     def test_token_combine_with_expert_map(self):
         hidden_states = torch.randn(6, 128)
-        context_metadata = {
-            "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
-            "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
-        }
-        self.dispatcher.original_shape = (6, 128)
-        final_hidden_states = self.dispatcher.token_combine(
-            hidden_states, context_metadata).routed_out
+        combine_metadata = MoEAllGatherCombineMetadata(
+            expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
+            topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
+            restore_shape=torch.Size([6, 128]),
+        )
+        final_hidden_states = self.dispatcher.token_combine(hidden_states, combine_metadata)
         self.assertEqual(final_hidden_states.shape, (6, 128))
 
     @pytest.mark.skip(
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
     def test_token_combine_without_expert_map(self):
         hidden_states = torch.randn(6, 128)
-        context_metadata = {
-            "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]),
-            "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
-        }
-        self.dispatcher.original_shape = (6, 128)
-        final_hidden_states = self.dispatcher.token_combine(
-            hidden_states, context_metadata).routed_out
+        combine_metadata = MoEAllGatherCombineMetadata(
+            expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]),
+            topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]),
+            restore_shape=torch.Size([6, 128]),
+        )
+        final_hidden_states = self.dispatcher.token_combine(hidden_states, combine_metadata)
         self.mock_npu_moe_token_unpermute.assert_called_once()
         self.assertEqual(final_hidden_states.shape, (6, 128))
 
     @pytest.mark.skip(
         "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
     def test_token_dispatch_with_router_weight(self):
-        self.dispatcher.apply_router_weight_on_input = True
         hidden_states = torch.randn(3, 128)
         topk_weights = torch.tensor([[0.7], [0.6], [0.5]])  # topk=1
         topk_ids = torch.tensor([[0], [1], [2]])
 
-        results = self.dispatcher.token_dispatch(hidden_states, topk_weights,
-                                                 topk_ids, None)
+        token_dispatch_input = build_token_dispatch_input_fixture(
+            hidden_states=hidden_states,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            apply_router_weight_on_input=True,
+        )
+        results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
         self.assertEqual(results.hidden_states.shape, (6, 128))
+        self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata)
 
 
 class TestTokenDispatcherWithAll2AllV(TestBase):
@@ -408,35 +482,39 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
|
||||
[0, 1], dtype=torch.int32)
|
||||
self.dispatcher.local_expert_indices = [0, 1]
|
||||
|
||||
result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map)
|
||||
token_dispatch_input = build_token_dispatch_input_fixture(
|
||||
hidden_states=hidden_states,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
)
|
||||
result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)
|
||||
|
||||
self.assertIsNotNone(result.hidden_states)
|
||||
self.assertIsNotNone(result.group_list)
|
||||
self.assertEqual(result.group_list_type, 1)
|
||||
self.assertIsInstance(result.combine_metadata, MoEAllToAllCombineMetadata)
|
||||
|
||||
    @pytest.mark.skip(
        "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
    def test_token_combine(self):
        hidden_states = torch.randn(16, 16)
        context_metadata = {
            "input_splits": [4, 4],
            "output_splits": [4, 4],
            "topk_weights": torch.rand(8, 4),
            "reversed_local_input_permutation_mapping": torch.arange(8),
            "reversed_global_input_permutation_mapping": torch.arange(16),
        }
        self.dispatcher.hidden_shape = (8, 16)
        self.dispatcher.hidden_shape_before_permute = (8, 16)
        combine_metadata = MoEAllToAllCombineMetadata(
            input_splits=np.array([4, 4]),
            output_splits=np.array([4, 4]),
            topk_weights=torch.rand(8, 4),
            reversed_local_input_permutation_mapping=torch.arange(8),
            reversed_global_input_permutation_mapping=torch.arange(16),
            hidden_shape=torch.Size([8, 16]),
            hidden_shape_before_permute=torch.Size([8, 16]),
        )
        self.dispatcher.expert_ids_per_ep_rank = torch.tensor(
            [0, 1], dtype=torch.int32)
        self.dispatcher.local_expert_indices = [0, 1]

        output = self.dispatcher.token_combine(hidden_states, context_metadata)
        output = self.dispatcher.token_combine(hidden_states, combine_metadata)
        self.assertIsNotNone(output)
        self.assertEqual(output.routed_out.shape, (8, 16))
        self.assertEqual(output.shape, (8, 16))

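As with the all-gather path, the all-to-all combine metadata is now a typed object rather than a dict. A sketch of `MoEAllToAllCombineMetadata` inferred from the constructor call above (the actual fields and types may differ):

# Hypothetical sketch; fields inferred from the test's constructor call only.
from dataclasses import dataclass

import numpy as np
import torch


@dataclass
class MoEAllToAllCombineMetadata:
    input_splits: np.ndarray
    output_splits: np.ndarray
    topk_weights: torch.Tensor
    reversed_local_input_permutation_mapping: torch.Tensor
    reversed_global_input_permutation_mapping: torch.Tensor
    hidden_shape: torch.Size
    hidden_shape_before_permute: torch.Size
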
    @pytest.mark.skip(
        "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -454,16 +532,20 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
            [0, 1], dtype=torch.int32)
        self.dispatcher.local_expert_indices = [0, 1]

        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                topk_weights=topk_weights,
                                                topk_ids=topk_ids,
                                                expert_map=expert_map,
                                                with_quant=True)
        token_dispatch_input = build_token_dispatch_input_fixture(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            quant_type=QuantType.W8A8,
        )
        result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)

        self.assertIsNotNone(result.hidden_states)
        self.assertIsNotNone(result.group_list)
        self.assertIsNotNone(result.dynamic_scale)
        self.assertEqual(result.group_list_type, 1)
        self.assertIsInstance(result.combine_metadata, MoEAllToAllCombineMetadata)

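Note that the quantized path is now requested through `quant_type=QuantType.W8A8` on the input object rather than a loose `with_quant=True` kwarg; a non-None `result.dynamic_scale` is its observable effect. Assuming the enum exposes a NONE member, a dispatcher can recover the legacy boolean in one line:

# Hypothetical: derive the legacy flag from the typed request.
with_quant = token_dispatch_input.quant_type != QuantType.NONE
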
    @pytest.mark.skip(
        "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.")
@@ -484,14 +566,16 @@ class TestTokenDispatcherWithAll2AllV(TestBase):
            [0, 1], dtype=torch.int32)
        self.dispatcher.local_expert_indices = [0, 1]

        result = self.dispatcher.token_dispatch(hidden_states=hidden_states,
                                                topk_weights=topk_weights,
                                                topk_ids=topk_ids,
                                                expert_map=expert_map,
                                                with_quant=True)
        token_dispatch_input = build_token_dispatch_input_fixture(
            hidden_states=hidden_states,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
            expert_map=expert_map,
            quant_type=QuantType.W8A8,
        )
        result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input)

        self.assertIsNotNone(result.hidden_states)
        self.assertIsNotNone(result.group_list)
        self.assertIsNotNone(result.dynamic_scale)
        self.assertEqual(result.group_list_type, 1)

@@ -3,9 +3,8 @@ from unittest.mock import Mock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.methods.w4a16 import (AscendW4A16FusedMoEMethod,
                                                    pack_to_int32,
                                                    unpack_from_int32)
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.quantization.methods.w4a16 import AscendW4A16FusedMoEMethod, pack_to_int32, unpack_from_int32


class TestUnpackFromInt32(TestBase):

@@ -268,3 +267,41 @@ class TestAscendW4A16FusedMoEMethod(TestBase):
            torch.equal(layer.w13_weight_packed.data, original_w13_data))
        self.assertTrue(
            torch.equal(layer.w2_weight_packed.data, original_w2_data))

@patch("vllm_ascend.quantization.methods.w4a16._EXTRA_CTX")
|
||||
@patch("vllm_ascend.quantization.methods.w4a16.select_experts")
|
||||
def test_apply_uses_explicit_dispatch_and_mlp_args(self, mock_select_experts, mock_extra_ctx):
|
||||
tokens = 3
|
||||
hidden_size = self.output_size
|
||||
layer = self.build_layer()
|
||||
x = torch.randn(tokens, hidden_size, dtype=torch.float32)
|
||||
router_logits = torch.randn(tokens, self.experts, dtype=torch.float32)
|
||||
topk_weights = torch.randn(tokens, 2, dtype=torch.float32)
|
||||
topk_ids = torch.randint(0, self.experts, (tokens, 2), dtype=torch.int64)
|
||||
mc2_mask = torch.tensor([1, 0, 1], dtype=torch.bool)
|
||||
pertoken_scale = torch.randn(tokens, dtype=torch.float32)
|
||||
|
||||
mock_select_experts.return_value = (topk_weights, topk_ids)
|
||||
mock_comm = Mock()
|
||||
mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32)
|
||||
mock_extra_ctx.moe_comm_method = mock_comm
|
||||
mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER
|
||||
|
||||
self.quant_method.apply(
|
||||
layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=2,
|
||||
renormalize=True,
|
||||
global_num_experts=self.experts,
|
||||
activation="gelu",
|
||||
apply_router_weight_on_input=True,
|
||||
mc2_mask=mc2_mask,
|
||||
pertoken_scale=pertoken_scale,
|
||||
)
|
||||
|
||||
fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"]
|
||||
self.assertEqual(fused_experts_input.activation, "gelu")
|
||||
self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
|
||||
self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask)
|
||||
self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale)
|
||||
|
||||
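The assertions above define the observable surface of the typed request handed to `fused_experts`. A sketch of that surface with assumed class names (`FusedExpertsInput`, `MoERoutingParams`), since the diff only shows attribute accesses, never the definitions:

# Hypothetical sketch; only attributes asserted by these tests are listed.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class MoERoutingParams:
    apply_router_weight_on_input: bool
    mc2_mask: Optional[torch.Tensor]
    pertoken_scale: Optional[torch.Tensor]


@dataclass
class FusedExpertsInput:
    activation: str
    routing: MoERoutingParams
    topk_weights: torch.Tensor
    topk_ids: torch.Tensor
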
@@ -3,8 +3,8 @@ from unittest.mock import Mock, patch

import torch

from tests.ut.base import TestBase
from vllm_ascend.quantization.methods.w8a8_dynamic import \
    AscendW8A8DynamicFusedMoEMethod
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.quantization.methods.w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod


class TestAscendW8A8FusedMoEMethod(TestBase):

@@ -32,8 +32,9 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
        mock_ep_group = Mock()
        mock_get_ep_group.return_value = mock_ep_group
        mock_ascend_config = Mock()
        mock_ascend_config.enable_chunked_prefill = False
        mock_ascend_config.multistream_overlap_gate = False
        mock_ascend_config.eplb_config = Mock(dynamic_eplb=False)
        mock_get_ascend_config.return_value = mock_ascend_config
        mock_mc2_group = Mock(device_group=0)
        mock_get_mc2_group.return_value = mock_mc2_group

@@ -104,3 +105,125 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
        new_layer = self.build_layer()
        self.quant_method.process_weights_after_loading(new_layer)
        mock_npu_format_cast.assert_called()

@patch("vllm_ascend.quantization.methods.w8a8_dynamic._EXTRA_CTX")
|
||||
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.select_experts")
|
||||
def test_apply_uses_explicit_dispatch_and_mlp_args(self, mock_select_experts, mock_extra_ctx):
|
||||
tokens = 4
|
||||
hidden_size = self.hidden_size
|
||||
layer = torch.nn.Module()
|
||||
layer.w13_weight = torch.randint(
|
||||
-8,
|
||||
8,
|
||||
(self.num_experts, 2 * self.intermediate_size, hidden_size),
|
||||
dtype=torch.int8,
|
||||
)
|
||||
layer.w2_weight = torch.randint(
|
||||
-8,
|
||||
8,
|
||||
(self.num_experts, hidden_size, self.intermediate_size),
|
||||
dtype=torch.int8,
|
||||
)
|
||||
layer.w13_weight_scale_fp32 = torch.ones(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32)
|
||||
layer.w2_weight_scale = torch.ones(self.num_experts, hidden_size, dtype=torch.float32)
|
||||
|
||||
x = torch.randn(tokens, hidden_size, dtype=torch.float32)
|
||||
router_logits = torch.randn(tokens, self.num_experts, dtype=torch.float32)
|
||||
topk_weights = torch.randn(tokens, 2, dtype=torch.float32)
|
||||
topk_ids = torch.randint(0, self.num_experts, (tokens, 2), dtype=torch.int64)
|
||||
mc2_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool)
|
||||
pertoken_scale = torch.randn(tokens, dtype=torch.float32)
|
||||
|
||||
mock_select_experts.return_value = (topk_weights, topk_ids)
|
||||
mock_comm = Mock()
|
||||
mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32)
|
||||
mock_extra_ctx.moe_comm_method = mock_comm
|
||||
mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER
|
||||
self.quant_method.multistream_overlap_gate = False
|
||||
self.quant_method.in_dtype = torch.float32
|
||||
|
||||
self.quant_method.apply(
|
||||
layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=2,
|
||||
renormalize=True,
|
||||
global_num_experts=self.num_experts,
|
||||
activation="gelu",
|
||||
apply_router_weight_on_input=True,
|
||||
mc2_mask=mc2_mask,
|
||||
pertoken_scale=pertoken_scale,
|
||||
)
|
||||
|
||||
fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"]
|
||||
self.assertEqual(fused_experts_input.activation, "gelu")
|
||||
self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
|
||||
self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask)
|
||||
self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale)
|
||||
self.assertIs(fused_experts_input.topk_weights, topk_weights)
|
||||
self.assertIs(fused_experts_input.topk_ids, topk_ids)
|
||||
|
||||
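The int8 layer construction in the test above repeats verbatim in the next test; a small helper in the spirit of this PR's fixture builders could fold it into one place (a sketch, not part of the commit):

# Hypothetical helper deduplicating the int8 layer setup shared by both
# w8a8_dynamic tests in this file.
import torch


def build_int8_moe_layer(num_experts, hidden_size, intermediate_size):
    layer = torch.nn.Module()
    layer.w13_weight = torch.randint(
        -8, 8, (num_experts, 2 * intermediate_size, hidden_size), dtype=torch.int8)
    layer.w2_weight = torch.randint(
        -8, 8, (num_experts, hidden_size, intermediate_size), dtype=torch.int8)
    layer.w13_weight_scale_fp32 = torch.ones(
        num_experts, 2 * intermediate_size, dtype=torch.float32)
    layer.w2_weight_scale = torch.ones(num_experts, hidden_size, dtype=torch.float32)
    return layer
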
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_flash_common3_context")
|
||||
@patch("vllm_ascend.quantization.methods.w8a8_dynamic._EXTRA_CTX")
|
||||
@patch("vllm_ascend.quantization.methods.w8a8_dynamic.select_experts")
|
||||
def test_apply_overlap_gate_uses_fc3_context(
|
||||
self,
|
||||
mock_select_experts,
|
||||
mock_extra_ctx,
|
||||
mock_get_flash_common3_context,
|
||||
):
|
||||
tokens = 4
|
||||
hidden_size = self.hidden_size
|
||||
layer = torch.nn.Module()
|
||||
layer.w13_weight = torch.randint(
|
||||
-8,
|
||||
8,
|
||||
(self.num_experts, 2 * self.intermediate_size, hidden_size),
|
||||
dtype=torch.int8,
|
||||
)
|
||||
layer.w2_weight = torch.randint(
|
||||
-8,
|
||||
8,
|
||||
(self.num_experts, hidden_size, self.intermediate_size),
|
||||
dtype=torch.int8,
|
||||
)
|
||||
layer.w13_weight_scale_fp32 = torch.ones(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32)
|
||||
layer.w2_weight_scale = torch.ones(self.num_experts, hidden_size, dtype=torch.float32)
|
||||
|
||||
x = torch.randn(tokens, hidden_size, dtype=torch.float32)
|
||||
router_logits = torch.randn(tokens, self.num_experts, dtype=torch.float32)
|
||||
topk_weights = torch.randn(tokens, 2, dtype=torch.float32)
|
||||
topk_ids = torch.randint(0, self.num_experts, (tokens, 2), dtype=torch.int64)
|
||||
mc2_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool)
|
||||
pertoken_scale = torch.randn(tokens, dtype=torch.float32)
|
||||
|
||||
self.quant_method.multistream_overlap_gate = True
|
||||
self.quant_method.in_dtype = torch.float32
|
||||
mock_get_flash_common3_context.return_value = Mock(topk_weights=topk_weights, topk_ids=topk_ids)
|
||||
|
||||
mock_comm = Mock()
|
||||
mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32)
|
||||
mock_extra_ctx.moe_comm_method = mock_comm
|
||||
mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER
|
||||
|
||||
self.quant_method.apply(
|
||||
layer=layer,
|
||||
x=x,
|
||||
router_logits=router_logits,
|
||||
top_k=2,
|
||||
renormalize=True,
|
||||
global_num_experts=self.num_experts,
|
||||
activation="gelu",
|
||||
apply_router_weight_on_input=True,
|
||||
mc2_mask=mc2_mask,
|
||||
pertoken_scale=pertoken_scale,
|
||||
)
|
||||
|
||||
mock_select_experts.assert_not_called()
|
||||
fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"]
|
||||
self.assertEqual(fused_experts_input.activation, "gelu")
|
||||
self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input)
|
||||
self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask)
|
||||
self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale)
|
||||
self.assertIs(fused_experts_input.topk_weights, topk_weights)
|
||||
self.assertIs(fused_experts_input.topk_ids, topk_ids)
|
||||
|
||||
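Taken together, the last two tests pin down the gating contract: with `multistream_overlap_gate` disabled, `apply` obtains top-k results from `select_experts`; with it enabled, `select_experts` must not be called and the results come from the flash-common3 context. A sketch of the branch these tests imply, using the module-level names the tests patch (the surrounding function shape and argument list are assumptions):

# Hypothetical control flow implied by the two tests above.
def _resolve_topk(self, hidden_states, router_logits, top_k):
    if self.multistream_overlap_gate:
        # The gate already ran on another stream; reuse its results.
        ctx = get_flash_common3_context()
        return ctx.topk_weights, ctx.topk_ids
    return select_experts(hidden_states, router_logits, top_k=top_k)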