diff --git a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py index bcbbecc0..11d9707f 100644 --- a/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py +++ b/tests/e2e/multicard/2-cards/test_offline_inference_distributed.py @@ -245,6 +245,7 @@ def test_qwen3_dense_prefetch_mlp_weight_tp2(model): @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) @patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"}) @patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"}) +@wait_until_npu_memory_free() def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep(): short_example_prompts = [ "Hello ", @@ -272,6 +273,7 @@ def test_deepseek3_2_w8a8_pruning_mtp_tp2_ep(): @patch.dict(os.environ, {"VLLM_ASCEND_ENABLE_FLASHCOMM1": "1"}) @patch.dict(os.environ, {"ASCEND_AGGREGATE_ENABLE": "1"}) @patch.dict(os.environ, {"HCCL_BUFFSIZE": "1024"}) +@wait_until_npu_memory_free() def test_deepseek3_2_w8a8c8_pruning_mtp_tp2_ep(): short_example_prompts = [ "Hello ", diff --git a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py index 5da3f021..845950a7 100644 --- a/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py +++ b/tests/e2e/nightly/single_node/ops/singlecard_ops/test_fused_moe.py @@ -28,11 +28,17 @@ import torch import torch_npu from vllm.model_executor.layers.activation import SiluAndMul -from vllm_ascend.ops.fused_moe.experts_selector import ( - check_npu_moe_gating_top_k, select_experts) +from vllm_ascend.ops.fused_moe.experts_selector import check_npu_moe_gating_top_k, select_experts from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp -from vllm_ascend.ops.fused_moe.token_dispatcher import \ - TokenDispatcherWithAllGather +from vllm_ascend.ops.fused_moe.moe_runtime_args import ( + build_fused_experts_input, + build_mlp_compute_input, + MoEQuantParams, + MoERoutingParams, + MoETokenDispatchInput, +) +from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather +from vllm_ascend.quantization.quant_type import QuantType NUM_EXPERTS = [8, 64] EP_SIZE = [1] @@ -83,10 +89,8 @@ def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map): for i in range(w1.shape[0]): mask = topk_ids == i if mask.sum(): - out[mask] = SiluAndMul()( - a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) - return (out.view(B, -1, w2.shape[1]) * - topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1) + out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) + return (out.view(B, -1, w2.shape[1]) * topk_weights.view(B, -1, 1).to(out.dtype)).sum(dim=1) @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) @@ -129,36 +133,41 @@ def test_token_dispatcher_with_all_gather( dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs) apply_router_weight_on_input = False - dispatch_output = dispatcher.token_dispatch( - hidden_states=a, - topk_weights=topk_weights, - topk_ids=topk_ids, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input) + token_dispatch_output = dispatcher.token_dispatch( + token_dispatch_input=MoETokenDispatchInput( + hidden_states=a, + topk_weights=topk_weights, + topk_ids=topk_ids, + routing=MoERoutingParams( + expert_map=expert_map, + global_redundant_expert_num=0, + mc2_mask=None, + apply_router_weight_on_input=apply_router_weight_on_input, + ), + quant=MoEQuantParams(quant_type=QuantType.NONE), + ) + ) - 
sorted_hidden_states = dispatch_output.hidden_states - group_list = dispatch_output.group_list - group_list_type = dispatch_output.group_list_type - context_metadata = dispatch_output.context_metadata + sorted_hidden_states = token_dispatch_output.hidden_states + group_list = token_dispatch_output.group_list + group_list_type = token_dispatch_output.group_list_type + combine_metadata = token_dispatch_output.combine_metadata - expert_output = apply_mlp(hidden_states=sorted_hidden_states, - w1=w1_local, - w2=w2_local, - group_list=group_list, - group_list_type=group_list_type) + expert_output = apply_mlp( + hidden_states=sorted_hidden_states, + w1=w1_local, + w2=w2_local, + group_list=group_list, + group_list_type=group_list_type, + ) combined_output = dispatcher.token_combine( - hidden_states=expert_output, - context_metadata=context_metadata, - bias=None) + hidden_states=expert_output, combine_metadata=combine_metadata, bias=None + ) - torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, - expert_map) + torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map) - torch.testing.assert_close(combined_output.routed_out, - torch_output, - atol=4e-2, - rtol=1) + torch.testing.assert_close(combined_output, torch_output, atol=4e-2, rtol=1) gc.collect() torch.npu.empty_cache() torch.npu.reset_peak_memory_stats() @@ -184,8 +193,7 @@ def test_token_dispatcher_with_all_gather_quant( ): context_mock = MagicMock() context_mock.fused_moe_state = 0 - with patch("vllm_ascend.ascend_forward_context.get_forward_context", - return_value=context_mock): + with patch("vllm_ascend.ops.fused_moe.moe_mlp.get_forward_context", return_value=context_mock): a = torch.randn((m, k), device=device, dtype=dtype) / 10 w1 = torch.randn((e, k, 2 * n), device=device, dtype=torch.int8) w1_scale = torch.empty((e, 2 * n), device=device, dtype=dtype) @@ -208,34 +216,44 @@ def test_token_dispatcher_with_all_gather_quant( dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs) apply_router_weight_on_input = False - dispatch_output = dispatcher.token_dispatch( - hidden_states=a, - topk_weights=topk_weights, - topk_ids=topk_ids, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - with_quant=True) + token_dispatch_output = dispatcher.token_dispatch( + token_dispatch_input=MoETokenDispatchInput( + hidden_states=a, + topk_weights=topk_weights, + topk_ids=topk_ids, + routing=MoERoutingParams( + expert_map=expert_map, + global_redundant_expert_num=0, + mc2_mask=None, + apply_router_weight_on_input=apply_router_weight_on_input, + ), + quant=MoEQuantParams(quant_type=QuantType.W8A8), + ) + ) - sorted_hidden_states = dispatch_output.hidden_states - group_list = dispatch_output.group_list - group_list_type = dispatch_output.group_list_type - dynamic_scale = dispatch_output.dynamic_scale - context_metadata = dispatch_output.context_metadata + combine_metadata = token_dispatch_output.combine_metadata - expert_output = unified_apply_mlp(hidden_states=sorted_hidden_states, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - group_list=group_list, - group_list_type=group_list_type, - dynamic_scale=dynamic_scale, - with_quant=True) + mlp_compute_input = build_mlp_compute_input( + fused_experts_input=build_fused_experts_input( + hidden_states=a, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1=w1, + w2=w2, + quant_type=QuantType.W8A8, + dynamic_eplb=False, + expert_map=expert_map, + w1_scale=w1_scale, + w2_scale=w2_scale, + ), + 
token_dispatch_output=token_dispatch_output, + use_fusion_ops=False, + ) + expert_output = unified_apply_mlp(mlp_compute_input=mlp_compute_input) combined_output = dispatcher.token_combine( - hidden_states=expert_output, - context_metadata=context_metadata, - bias=None) - assert combined_output.routed_out.shape == (m, k) + hidden_states=expert_output, combine_metadata=combine_metadata, bias=None + ) + assert combined_output.shape == (m, k) gc.collect() torch.npu.empty_cache() torch.npu.reset_peak_memory_stats() @@ -271,25 +289,20 @@ def test_select_experts( hidden_states = torch.randn(m, n, device=device, dtype=dtype) router_logits = torch.randn(m, e, device=device, dtype=dtype) - e_score_correction_bias = (torch.randn(e, device=device, dtype=dtype) - if with_e_correction else None) + e_score_correction_bias = torch.randn(e, device=device, dtype=dtype) if with_e_correction else None custom_routing_function = None if custom_routing: custom_routing_function = MagicMock() mock_weights = torch.randn(m, topk, device=device, dtype=dtype) - mock_ids = torch.randint(0, - e, (m, topk), - device=device, - dtype=torch.int32) + mock_ids = torch.randint(0, e, (m, topk), device=device, dtype=torch.int32) custom_routing_function.return_value = (mock_weights, mock_ids) - with patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk" - ) as mock_native_grouped_topk, \ - patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method', - return_value=MagicMock()): - mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like( - x) + with ( + patch("vllm_ascend.ops.fused_moe.experts_selector._native_grouped_topk") as mock_native_grouped_topk, + patch("vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method", return_value=MagicMock()), + ): + mock_native_grouped_topk.side_effect = lambda x, num_groups, k: torch.randn_like(x) topk_weights, topk_ids = select_experts( hidden_states=hidden_states, @@ -305,8 +318,8 @@ def test_select_experts( ) call_moe_gatingtopk = check_npu_moe_gating_top_k( - hidden_states, topk, renormalize, topk_group, num_expert_group, - scoring_func, custom_routing_function) + hidden_states, topk, renormalize, topk_group, num_expert_group, scoring_func, custom_routing_function + ) if not call_moe_gatingtopk and use_grouped_topk: mock_native_grouped_topk.assert_called_once() else: @@ -323,16 +336,18 @@ def test_select_experts( @pytest.mark.parametrize("device", DEVICE) def test_select_experts_invalid_scoring_func(device: str): - with patch('vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method', - return_value=MagicMock()), \ - pytest.raises(ValueError, - match="Unsupported scoring function: invalid"): - select_experts(hidden_states=torch.randn(1, 128, device=device), - router_logits=torch.randn(1, 8, device=device), - top_k=2, - use_grouped_topk=False, - renormalize=False, - scoring_func="invalid") + with ( + patch("vllm_ascend.ops.fused_moe.experts_selector.get_weight_prefetch_method", return_value=MagicMock()), + pytest.raises(ValueError, match="Unsupported scoring function: invalid"), + ): + select_experts( + hidden_states=torch.randn(1, 128, device=device), + router_logits=torch.randn(1, 8, device=device), + top_k=2, + use_grouped_topk=False, + renormalize=False, + scoring_func="invalid", + ) gc.collect() torch.npu.empty_cache() torch.npu.reset_peak_memory_stats() diff --git a/tests/ut/_310p/fused_moe/test_moe_mlp_310.py b/tests/ut/_310p/fused_moe/test_moe_mlp_310.py index 3e4f21a2..1eebe596 100644 --- 
a/tests/ut/_310p/fused_moe/test_moe_mlp_310.py +++ b/tests/ut/_310p/fused_moe/test_moe_mlp_310.py @@ -19,6 +19,38 @@ import torch from tests.ut.base import TestBase from vllm_ascend._310p.fused_moe.moe_mlp import unified_apply_mlp +from vllm_ascend.ops.fused_moe.moe_runtime_args import ( + MoEMlpComputeInput, + MoEQuantParams, + MoEWeights, +) +from vllm_ascend.quantization.quant_type import QuantType + + +def build_mlp_compute_input_fixture( + *, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + with_quant: bool, + w1_scale: torch.Tensor | None = None, + w2_scale: torch.Tensor | None = None, + group_list_type: int = 1, +) -> MoEMlpComputeInput: + return MoEMlpComputeInput( + hidden_states=hidden_states, + group_list=group_list, + group_list_type=group_list_type, + dynamic_scale=None, + topk_scales=None, + weights=MoEWeights(w1=w1, w2=w2, w1_scale=w1_scale, w2_scale=w2_scale), + quant=MoEQuantParams(quant_type=QuantType.W8A8 if with_quant else QuantType.NONE), + fusion=False, + activation="silu", + need_trans=False, + dynamic_eplb=False, + ) class TestUnifiedApplyMLP310(TestBase): @@ -38,14 +70,13 @@ class TestUnifiedApplyMLP310(TestBase): group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) result = unified_apply_mlp( - hidden_states=hidden_states, - w1=w1, - w1_scale=None, - w2=w2, - w2_scale=None, - group_list=group_list, - group_list_type=1, - with_quant=False, + mlp_compute_input=build_mlp_compute_input_fixture( + hidden_states=hidden_states, + w1=w1, + w2=w2, + group_list=group_list, + with_quant=False, + ) ) self.assertEqual(mock_npu_grouped_matmul.call_count, 2) @@ -94,14 +125,15 @@ class TestUnifiedApplyMLP310(TestBase): group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) result = unified_apply_mlp( - hidden_states=hidden_states, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - group_list=group_list, - group_list_type=1, - with_quant=True, + mlp_compute_input=build_mlp_compute_input_fixture( + hidden_states=hidden_states, + w1=w1, + w2=w2, + group_list=group_list, + with_quant=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) ) mock_cumsum.assert_called_once() diff --git a/tests/ut/ops/test_activation.py b/tests/ut/ops/test_activation.py index a8467440..a5bc47ac 100644 --- a/tests/ut/ops/test_activation.py +++ b/tests/ut/ops/test_activation.py @@ -95,4 +95,4 @@ def test_SiluAndMul_forward_310p( assert torch.allclose(actual_arg, expected_arg), "swiglu called with unexpected input" expected_out = (dummy_tensor[..., :h] + 1) * dummy_tensor[..., h:] - assert torch.allclose(out, expected_out) \ No newline at end of file + assert torch.allclose(out, expected_out) diff --git a/tests/ut/ops/test_fused_moe.py b/tests/ut/ops/test_fused_moe.py index f0bd0f3e..53d81577 100644 --- a/tests/ut/ops/test_fused_moe.py +++ b/tests/ut/ops/test_fused_moe.py @@ -12,7 +12,7 @@ # limitations under the License. # This file is a part of the vllm-ascend project. 
# -from typing import List, TypedDict +from typing import TypedDict from unittest.mock import MagicMock, patch import pytest @@ -20,12 +20,19 @@ import torch import torch.nn as nn import torch_npu from pytest_mock import MockerFixture + from tests.ut.base import TestBase from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ops.fused_moe.experts_selector import select_experts from vllm_ascend.ops.fused_moe.fused_moe import AscendUnquantizedFusedMoEMethod -from vllm_ascend.ops.fused_moe.moe_mlp import (cumsum_group_list, - unified_apply_mlp) +from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp +from vllm_ascend.ops.fused_moe.moe_runtime_args import ( + MoEMlpComputeInput, + MoEPrepareOutput, + MoEQuantParams, + MoEWeights, +) +from vllm_ascend.quantization.quant_type import QuantType from vllm_ascend.utils import AscendDeviceType, adapt_patch adapt_patch(True) @@ -54,6 +61,51 @@ def mock_npu_format_cast(weight_data, format): return weight_data +def build_mlp_compute_input_fixture( + *, + hidden_states: torch.Tensor, + w1: torch.Tensor | list[torch.Tensor], + w2: torch.Tensor | list[torch.Tensor], + group_list: torch.Tensor, + with_quant: bool, + group_list_type: int = 1, + dynamic_scale: torch.Tensor | None = None, + topk_scales: torch.Tensor | None = None, + w1_scale: torch.Tensor | list[torch.Tensor] | None = None, + w2_scale: torch.Tensor | list[torch.Tensor] | None = None, + w1_scale_bias: torch.Tensor | None = None, + w2_scale_bias: torch.Tensor | None = None, + w1_offset: torch.Tensor | None = None, + w2_offset: torch.Tensor | None = None, + fusion: bool = False, + activation: str = "silu", + need_trans: bool = True, + dynamic_eplb: bool = False, +) -> MoEMlpComputeInput: + return MoEMlpComputeInput( + hidden_states=hidden_states, + group_list=group_list, + group_list_type=group_list_type, + dynamic_scale=dynamic_scale, + topk_scales=topk_scales, + weights=MoEWeights( + w1=w1, + w2=w2, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + w1_offset=w1_offset, + w2_offset=w2_offset, + ), + quant=MoEQuantParams(quant_type=QuantType.W8A8 if with_quant else QuantType.NONE), + fusion=fusion, + activation=activation, + need_trans=need_trans, + dynamic_eplb=dynamic_eplb, + ) + + @pytest.fixture(autouse=True) def setup_vllm_config_mock(mocker: MockerFixture): mock_hf_config = MagicMock() @@ -77,7 +129,13 @@ def mock_dist_env(mocker: MockerFixture): mock_moe_comm_method = MagicMock() def mock_prepare(hidden_states, router_logits, **kwargs): - return hidden_states, router_logits + return MoEPrepareOutput( + hidden_states=hidden_states, + router_logits=router_logits, + mc2_mask=kwargs.get("mc2_mask"), + padded_hidden_states_shape=None, + pertoken_scale=None, + ) mock_moe_comm_method.prepare.side_effect = mock_prepare @@ -204,18 +262,18 @@ def moe_method(mock_dist_env): class Device(TypedDict): device_id: int - device_expert: List[int] + device_expert: list[int] class Layer(TypedDict): layer_id: int device_count: int - device_list: List[Device] + device_list: list[Device] class MockData(TypedDict): moe_layer_count: int - layer_list: List[Layer] + layer_list: list[Layer] class MockQuantMethod(nn.Module): @@ -338,18 +396,15 @@ class TestUnifiedApplyMLP(TestBase): w2_scale = torch.randn(5, 20, dtype=torch.bfloat16) group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) - result = unified_apply_mlp(hidden_states=hidden_states, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - 
group_list=group_list, - dynamic_scale=None, - group_list_type=1, - w1_scale_bias=None, - w2_scale_bias=None, - topk_scales=None, - with_quant=True) + result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture( + hidden_states=hidden_states, + w1=w1, + w2=w2, + group_list=group_list, + with_quant=True, + w1_scale=w1_scale, + w2_scale=w2_scale, + )) mock_get_forward_context.assert_called() @@ -383,18 +438,14 @@ class TestUnifiedApplyMLP(TestBase): group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) topk_scales = torch.randn(10, 1, dtype=torch.float16) - result = unified_apply_mlp(hidden_states=hidden_states, - w1=w1, - w1_scale=None, - w2=w2, - w2_scale=None, - group_list=group_list, - dynamic_scale=None, - group_list_type=1, - w1_scale_bias=None, - w2_scale_bias=None, - topk_scales=topk_scales, - with_quant=False) + result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture( + hidden_states=hidden_states, + w1=w1, + w2=w2, + group_list=group_list, + with_quant=False, + topk_scales=topk_scales, + )) self.assertEqual(mock_npu_grouped_matmul.call_count, 2) mock_npu_swiglu.assert_called_once() @@ -445,18 +496,18 @@ class TestUnifiedApplyMLP(TestBase): group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32) - result = unified_apply_mlp(hidden_states=hidden_states, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - group_list=group_list, - dynamic_scale=provided_dynamic_scale, - group_list_type=1, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias, - topk_scales=None, - with_quant=True) + result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture( + hidden_states=hidden_states, + w1=w1, + w2=w2, + group_list=group_list, + with_quant=True, + dynamic_scale=provided_dynamic_scale, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + )) mock_get_forward_context.assert_called() @@ -490,18 +541,14 @@ class TestUnifiedApplyMLP(TestBase): group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) topk_scales = torch.randn(10, 1, dtype=torch.float16) - result = unified_apply_mlp(hidden_states=hidden_states, - w1=w1, - w1_scale=None, - w2=w2, - w2_scale=None, - group_list=group_list, - dynamic_scale=None, - group_list_type=1, - w1_scale_bias=None, - w2_scale_bias=None, - topk_scales=topk_scales, - with_quant=False) + result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture( + hidden_states=hidden_states, + w1=w1, + w2=w2, + group_list=group_list, + with_quant=False, + topk_scales=topk_scales, + )) self.assertEqual(mock_npu_grouped_matmul.call_count, 2) mock_npu_swiglu.assert_called_once() @@ -556,19 +603,19 @@ class TestUnifiedApplyMLP(TestBase): group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32) - result = unified_apply_mlp(hidden_states=hidden_states, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - group_list=group_list, - dynamic_scale=provided_dynamic_scale, - group_list_type=1, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias, - topk_scales=None, - with_quant=True, - fusion=True) + result = unified_apply_mlp(mlp_compute_input=build_mlp_compute_input_fixture( + hidden_states=hidden_states, + w1=w1, + w2=w2, + group_list=group_list, + with_quant=True, + dynamic_scale=provided_dynamic_scale, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_scale_bias=w1_scale_bias, + 
w2_scale_bias=w2_scale_bias, + fusion=True, + )) mock_get_forward_context.assert_called() mock_npu_grouped_matmul.assert_called_once() diff --git a/tests/ut/ops/test_moe_comm_method.py b/tests/ut/ops/test_moe_comm_method.py index ed805dd7..c25d3716 100644 --- a/tests/ut/ops/test_moe_comm_method.py +++ b/tests/ut/ops/test_moe_comm_method.py @@ -4,12 +4,21 @@ import torch from vllm.model_executor.layers.fused_moe import FusedMoEConfig from tests.ut.base import TestBase -from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl, - AlltoAllCommImpl, - MC2CommImpl) +from vllm_ascend.ops.fused_moe.moe_comm_method import ( + AllGatherCommImpl, + AlltoAllCommImpl, + MC2CommImpl, +) +from vllm_ascend.ops.fused_moe.moe_runtime_args import ( + MoEAllGatherCombineMetadata, + MoEFusedExpertsInput, + MoEPrepareOutput, + MoEQuantParams, + MoERoutingParams, + MoEWeights, +) +from vllm_ascend.ops.fused_moe.token_dispatcher import MoETokenDispatchOutput from vllm_ascend.quantization.methods.base import QuantType -from vllm_ascend.ops.fused_moe.token_dispatcher import (TokenCombineResult, - TokenDispatchResult) class TestMoECommMethod(TestBase): @@ -45,8 +54,11 @@ class TestMoECommMethod(TestBase): # Mock prepare finalize mock_pf_instance = MagicMock() - mock_pf_instance.prepare.return_value = (torch.randn(4, 8), - torch.randn(4, 2), None, None) + mock_pf_instance.prepare.return_value = MoEPrepareOutput( + hidden_states=torch.randn(4, 8), + router_logits=torch.randn(4, 2), + mc2_mask=None, + padded_hidden_states_shape=None) mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_prepare_finalize.return_value = mock_pf_instance @@ -60,8 +72,9 @@ class TestMoECommMethod(TestBase): # Test prepare method hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare( - hidden_states, router_logits) + prepare_output = comm_impl.prepare(hidden_states, router_logits) + h_out = prepare_output.hidden_states + padded_hidden_states_shape = prepare_output.padded_hidden_states_shape # Verify prepare was called with correct arguments mock_pf_instance.prepare.assert_called_once_with( @@ -70,7 +83,7 @@ class TestMoECommMethod(TestBase): # Test finalize method comm_impl.finalize(h_out, reduce_results=True, - context_metadata=context_metadata) + padded_hidden_states_shape=padded_hidden_states_shape) mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) @patch('vllm_ascend.ascend_forward_context.get_forward_context') @@ -86,10 +99,11 @@ class TestMoECommMethod(TestBase): # Mock prepare finalize mock_pf_instance = MagicMock() - mock_pf_instance.prepare.return_value = (torch.randn(4, 8), - torch.randn(4, 2), - torch.tensor([1, 0, 1, - 0]), None) + mock_pf_instance.prepare.return_value = MoEPrepareOutput( + hidden_states=torch.randn(4, 8), + router_logits=torch.randn(4, 2), + mc2_mask=torch.tensor([1, 0, 1, 0]), + padded_hidden_states_shape=None) mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_prepare_finalize.return_value = mock_pf_instance @@ -103,8 +117,9 @@ class TestMoECommMethod(TestBase): # Test prepare method hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare( - hidden_states, router_logits) + prepare_output = comm_impl.prepare(hidden_states, router_logits) + h_out = prepare_output.hidden_states + padded_hidden_states_shape = prepare_output.padded_hidden_states_shape # Verify prepare was called with correct arguments 
mock_pf_instance.prepare.assert_called_once_with( @@ -113,7 +128,7 @@ class TestMoECommMethod(TestBase): # Test finalize method comm_impl.finalize(h_out, reduce_results=True, - context_metadata=context_metadata) + padded_hidden_states_shape=padded_hidden_states_shape) mock_pf_instance.finalize.assert_called_once_with(h_out, True, None) @patch('vllm_ascend.ascend_forward_context.get_forward_context') @@ -133,8 +148,11 @@ class TestMoECommMethod(TestBase): # Mock prepare finalize mock_pf_instance = MagicMock() - mock_pf_instance.prepare.return_value = (torch.randn(4, 8), - torch.randn(4, 2), None, None) + mock_pf_instance.prepare.return_value = MoEPrepareOutput( + hidden_states=torch.randn(4, 8), + router_logits=torch.randn(4, 2), + mc2_mask=None, + padded_hidden_states_shape=None) mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_prepare_finalize.return_value = mock_pf_instance @@ -148,8 +166,7 @@ class TestMoECommMethod(TestBase): # Test prepare method hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, mc2_mask, context_metadata = comm_impl.prepare( - hidden_states, router_logits) + _ = comm_impl.prepare(hidden_states, router_logits) # Verify prepare was called with correct arguments mock_pf_instance.prepare.assert_called_once_with( @@ -174,19 +191,27 @@ class TestMoECommMethod(TestBase): # Mock prepare finalize mock_pf_instance = MagicMock() - mock_pf_instance.prepare.return_value = (torch.randn(4, 8), - torch.randn(4, 2), None) + mock_pf_instance.prepare.return_value = MoEPrepareOutput( + hidden_states=torch.randn(4, 8), + router_logits=torch.randn(4, 2), + mc2_mask=None, + padded_hidden_states_shape=None) mock_pf_instance.finalize.return_value = torch.randn(4, 8) mock_prepare_finalize.return_value = mock_pf_instance # Mock token dispatcher mock_td_instance = MagicMock() - mock_td_instance.token_dispatch.return_value = TokenDispatchResult( - hidden_states=torch.randn(6, 8), - group_list=torch.tensor([2, 2, 2]), - group_list_type=1) - mock_td_instance.token_combine.return_value = TokenCombineResult( - routed_out=torch.randn(4, 8)) + dispatch_topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2], [0.6, 0.4]]) + mock_td_instance.token_dispatch.return_value = MoETokenDispatchOutput( + hidden_states=torch.randn(6, 8), + group_list=torch.tensor([2, 2, 2]), + group_list_type=1, + combine_metadata=MoEAllGatherCombineMetadata( + topk_weights=dispatch_topk_weights, + expanded_row_idx=torch.arange(8, dtype=torch.int32), + restore_shape=torch.Size([4, 8]), + )) + mock_td_instance.token_combine.return_value = torch.randn(4, 8) mock_token_dispatcher.return_value = mock_td_instance # Mock unified_apply_mlp @@ -199,8 +224,7 @@ class TestMoECommMethod(TestBase): hidden_states = torch.randn(4, 8).contiguous() w1 = torch.randn(16, 8).contiguous() w2 = torch.randn(16, 8).contiguous() - topk_weights = torch.tensor([[0.5, 0.5], [0.3, 0.7], [0.8, 0.2], - [0.6, 0.4]]) + topk_weights = dispatch_topk_weights topk_ids = torch.tensor([[0, 1], [1, 2], [2, 0], [1, 1]]) # Make sure tensors are contiguous and have correct strides @@ -208,12 +232,25 @@ class TestMoECommMethod(TestBase): w1 = w1.contiguous() w2 = w2.contiguous() - result = comm_impl.fused_experts(hidden_states=hidden_states, - w1=[w1], - w2=[w2], - topk_weights=topk_weights, - topk_ids=topk_ids, - activation="silu") + result = comm_impl.fused_experts(fused_experts_input=MoEFusedExpertsInput( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + weights=MoEWeights( + 
w1=[w1], + w2=[w2], + ), + routing=MoERoutingParams( + expert_map=None, + global_redundant_expert_num=0, + mc2_mask=None, + apply_router_weight_on_input=False, + ), + activation="silu", + need_trans=False, + dynamic_eplb=False, + quant=MoEQuantParams(), + )) # Verify result shape self.assertEqual(result.routed_out.shape, (4, 8)) @@ -223,6 +260,12 @@ class TestMoECommMethod(TestBase): # Verify unified_apply_mlp was called mock_unified_apply_mlp.assert_called_once() + mlp_compute_input = mock_unified_apply_mlp.call_args.kwargs["mlp_compute_input"] + self.assertFalse(mlp_compute_input.fusion) + self.assertFalse(mlp_compute_input.quant.is_mxfp) # Verify token_combine was called - mock_td_instance.token_combine.assert_called_once() + mock_td_instance.token_combine.assert_called_once_with( + hidden_states=mock_unified_apply_mlp.return_value, + combine_metadata=mock_td_instance.token_dispatch.return_value.combine_metadata, + ) diff --git a/tests/ut/ops/test_moe_mlp.py b/tests/ut/ops/test_moe_mlp.py index ff0bf5f5..ccf6058d 100644 --- a/tests/ut/ops/test_moe_mlp.py +++ b/tests/ut/ops/test_moe_mlp.py @@ -1,9 +1,17 @@ import unittest from typing import ClassVar +from unittest.mock import patch import torch -from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list +from vllm_ascend.ops.fused_moe.moe_mlp import cumsum_group_list, unified_apply_mlp +from vllm_ascend.ops.fused_moe.moe_runtime_args import ( + MoEMlpComputeInput, + MoEQuantParams, + MoEWeights, +) +from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams +from vllm_ascend.quantization.quant_type import QuantType class TestCumsumGroupList(unittest.TestCase): @@ -14,7 +22,7 @@ class TestCumsumGroupList(unittest.TestCase): cls.glist_dict = { 0: torch.tensor([0, 2, 3, 3]), 1: torch.tensor([0, 2, 1, 0]), - 2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]]) + 2: torch.tensor([[1, 2], [2, 1], [0, 0], [0, 0]]), } support_combine = [(0, 0), (1, 0), (0, 1)] @@ -23,29 +31,101 @@ class TestCumsumGroupList(unittest.TestCase): def test_cumsum_group_list_supported_conversion(self): for src_list_type, dst_list_type in self.support_combine: with self.subTest(src=src_list_type, dst=dst_list_type): - result = cumsum_group_list(self.glist_dict[src_list_type], - src_list_type, - dst_list_type, - expert_num=4) - self.assertTrue( - torch.equal(result, self.glist_dict[dst_list_type])) + result = cumsum_group_list(self.glist_dict[src_list_type], src_list_type, dst_list_type, expert_num=4) + self.assertTrue(torch.equal(result, self.glist_dict[dst_list_type])) def test_cumsum_group_list_invalid_type_valueerror(self): with self.assertRaises(ValueError) as excinfo: cumsum_group_list(self.glist_dict[0], 4, 0) - self.assertIn("group_list_type should be in [0, 1, 2], but received", - str(excinfo.exception)) + self.assertIn("group_list_type should be in [0, 1, 2], but received", str(excinfo.exception)) - def test_cumsum_group_list_unsupported_conversion_notimplementederror( - self): + def test_cumsum_group_list_unsupported_conversion_notimplementederror(self): for src_list_type, dst_list_type in self.unsupported_combine: with self.subTest(src=src_list_type, dst=dst_list_type): with self.assertRaises(NotImplementedError) as excinfo: - cumsum_group_list(self.glist_dict[0], src_list_type, - dst_list_type) - self.assertIn("This feature is under development.", - str(excinfo.exception)) + cumsum_group_list(self.glist_dict[0], src_list_type, dst_list_type) + self.assertIn("This feature is under development.", str(excinfo.exception)) -if __name__ == '__main__': 
+class TestUnifiedApplyMlpRequest(unittest.TestCase): + def test_request_unquant_path(self): + hidden_states = torch.randn(2, 8) + expected = torch.randn(2, 8) + mlp_compute_input = MoEMlpComputeInput( + hidden_states=hidden_states, + group_list=torch.tensor([2, 2], dtype=torch.int64), + group_list_type=1, + dynamic_scale=None, + topk_scales=None, + weights=MoEWeights( + w1=torch.randn(1, 16, 8), + w2=torch.randn(1, 8, 8), + w1_bias=torch.randn(1, 16), + w2_bias=torch.randn(1, 8), + ), + quant=MoEQuantParams(quant_type=QuantType.NONE), + fusion=False, + activation="silu", + need_trans=False, + dynamic_eplb=False, + ) + + with ( + patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp", return_value=expected) as mock_unquant, + patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp") as mock_quant, + ): + output = unified_apply_mlp(mlp_compute_input=mlp_compute_input) + + self.assertTrue(output is expected) + mock_unquant.assert_called_once() + self.assertEqual(mock_unquant.call_args.kwargs["activation"], "silu") + self.assertFalse(mock_unquant.call_args.kwargs["need_trans"]) + mock_quant.assert_not_called() + + def test_request_quant_path(self): + hidden_states = torch.randn(2, 8) + expected = torch.randn(2, 8) + mlp_compute_input = MoEMlpComputeInput( + hidden_states=hidden_states, + group_list=torch.tensor([2, 2], dtype=torch.int64), + group_list_type=1, + dynamic_scale=torch.randn(2, 1), + topk_scales=None, + weights=MoEWeights( + w1=torch.randn(1, 16, 8), + w2=torch.randn(1, 8, 8), + w1_scale=[torch.randn(1)], + w2_scale=[torch.randn(1)], + ), + quant=MoEQuantParams( + quant_type=QuantType.MXFP8, + mxfp=MoEMxfpParams( + act_quant_type=torch.float8_e4m3fn, + weight_quant_type=torch.float8_e4m3fn, + use_bf16=False, + ), + ), + fusion=True, + activation="silu", + need_trans=False, + dynamic_eplb=True, + ) + + with ( + patch("vllm_ascend.ops.fused_moe.moe_mlp.quant_apply_mlp", return_value=expected) as mock_quant, + patch("vllm_ascend.ops.fused_moe.moe_mlp.unquant_apply_mlp") as mock_unquant, + ): + output = unified_apply_mlp(mlp_compute_input=mlp_compute_input) + + self.assertTrue(output is expected) + mock_quant.assert_called_once() + quant_kwargs = mock_quant.call_args.kwargs + self.assertTrue(quant_kwargs["use_mxfp_quant"]) + self.assertTrue(quant_kwargs["fusion"]) + self.assertTrue(quant_kwargs["dynamic_eplb"]) + self.assertFalse(quant_kwargs["use_bf16"]) + mock_unquant.assert_not_called() + + +if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/tests/ut/ops/test_moe_runtime_args.py b/tests/ut/ops/test_moe_runtime_args.py new file mode 100644 index 00000000..ae8e9ea2 --- /dev/null +++ b/tests/ut/ops/test_moe_runtime_args.py @@ -0,0 +1,240 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import unittest + +import torch + +import vllm_ascend.ops.fused_moe.moe_runtime_args as runtime_args +from vllm_ascend.ops.fused_moe.moe_runtime_args import ( + MoEAllGatherCombineMetadata, + MoETokenDispatchOutput, + MoEWeights, + build_fused_experts_input, + build_mlp_compute_input, + build_token_dispatch_input, +) +from vllm_ascend.quantization.quant_type import QuantType + + +class TestMoERuntimeArgs(unittest.TestCase): + def test_runtime_args_facade_exports_public_contracts_and_builders(self): + expected_symbols = [ + "MoEAllGatherCombineMetadata", + "MoEAllToAllCombineMetadata", + "MoEFusedExpertsInput", + "MoEMC2CombineMetadata", + "MoEMlpComputeInput", + "MoEPrepareOutput", + "MoEQuantParams", + "MoERoutingParams", + "MoETokenDispatchInput", + "MoETokenDispatchOutput", + "MoEWeights", + "TMoECombineMetadata", + "build_fused_experts_input", + "build_mlp_compute_input", + "build_token_dispatch_input", + ] + + for symbol in expected_symbols: + with self.subTest(symbol=symbol): + self.assertTrue(hasattr(runtime_args, symbol)) + self.assertFalse(hasattr(runtime_args, "MoEMxfpParams")) + + def test_build_fused_experts_input_preserves_runtime_semantics(self): + for quant_type in ( + QuantType.NONE, + QuantType.W4A16, + QuantType.W4A8, + QuantType.W8A8, + QuantType.MXFP8, + ): + with self.subTest(quant_type=quant_type): + hidden_states = torch.randn(4, 8) + topk_weights = torch.randn(4, 2) + topk_ids = torch.randint(0, 4, (4, 2), dtype=torch.int32) + fused_experts_input = build_fused_experts_input( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1=torch.randn(2, 8, 16), + w2=torch.randn(2, 16, 8), + quant_type=quant_type, + dynamic_eplb=True, + expert_map=torch.tensor([0, 1, 2, 3], dtype=torch.int32), + global_redundant_expert_num=2, + mc2_mask=torch.tensor([True, False, True, False]), + apply_router_weight_on_input=True, + log2phy=torch.tensor([3, 2, 1, 0], dtype=torch.int32), + pertoken_scale=torch.randn(4), + activation="gelu", + mxfp_act_quant_type=torch.float8_e4m3fn if quant_type == QuantType.MXFP8 else None, + ) + + self.assertIs(fused_experts_input.hidden_states, hidden_states) + self.assertIs(fused_experts_input.topk_weights, topk_weights) + self.assertIs(fused_experts_input.topk_ids, topk_ids) + self.assertTrue(fused_experts_input.dynamic_eplb) + self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input) + self.assertEqual(fused_experts_input.routing.global_redundant_expert_num, 2) + self.assertEqual(fused_experts_input.activation, "gelu") + self.assertEqual(fused_experts_input.quant.quant_type, quant_type) + + def test_build_fused_experts_input_merges_dense_and_quant_weights(self): + w1 = torch.randn(2, 8, 16) + w2 = torch.randn(2, 16, 8) + w1_scale = [torch.randn(1)] + w2_scale = [torch.randn(1)] + w1_scale_bias = torch.randn(1) + w2_scale_bias = torch.randn(1) + w1_offset = torch.randn(1) + w2_offset = torch.randn(1) + + fused_experts_input = build_fused_experts_input( + hidden_states=torch.randn(4, 8), + topk_weights=torch.randn(4, 2), + topk_ids=torch.randint(0, 4, (4, 2), dtype=torch.int32), + w1=w1, + w2=w2, + quant_type=QuantType.W8A8, + dynamic_eplb=False, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + w1_offset=w1_offset, + w2_offset=w2_offset, + ) + + self.assertIsInstance(fused_experts_input.weights, MoEWeights) + self.assertIs(fused_experts_input.weights.w1, w1) + self.assertIs(fused_experts_input.weights.w2, w2) + 
self.assertIs(fused_experts_input.weights.w1_scale, w1_scale) + self.assertIs(fused_experts_input.weights.w2_scale, w2_scale) + self.assertIs(fused_experts_input.weights.w1_scale_bias, w1_scale_bias) + self.assertIs(fused_experts_input.weights.w2_scale_bias, w2_scale_bias) + self.assertIs(fused_experts_input.weights.w1_offset, w1_offset) + self.assertIs(fused_experts_input.weights.w2_offset, w2_offset) + + def test_build_token_dispatch_input_supports_remapped_topk_ids(self): + fused_experts_input = build_fused_experts_input( + hidden_states=torch.randn(2, 4), + topk_weights=torch.randn(2, 1), + topk_ids=torch.tensor([[0], [1]], dtype=torch.int32), + w1=torch.randn(1, 4, 8), + w2=torch.randn(1, 8, 4), + quant_type=QuantType.NONE, + dynamic_eplb=False, + ) + routed_topk_ids = torch.tensor([[3], [2]], dtype=torch.int32) + + token_dispatch_input = build_token_dispatch_input( + fused_experts_input=fused_experts_input, + topk_ids=routed_topk_ids, + ) + + self.assertIs(token_dispatch_input.hidden_states, fused_experts_input.hidden_states) + self.assertIs(token_dispatch_input.topk_weights, fused_experts_input.topk_weights) + self.assertIs(token_dispatch_input.routing, fused_experts_input.routing) + self.assertIs(token_dispatch_input.quant, fused_experts_input.quant) + self.assertIs(token_dispatch_input.topk_ids, routed_topk_ids) + + def test_build_fused_experts_input_requires_primitive_mxfp_params_for_mxfp_quant(self): + with self.assertRaisesRegex(ValueError, "primitive MXFP params are required"): + build_fused_experts_input( + hidden_states=torch.randn(2, 8), + topk_weights=torch.randn(2, 2), + topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + w1=torch.randn(2, 8, 16), + w2=torch.randn(2, 16, 8), + quant_type=QuantType.MXFP8, + dynamic_eplb=False, + ) + + def test_build_mlp_compute_input_derives_fusion_and_preserves_mxfp_params(self): + fused_experts_input = build_fused_experts_input( + hidden_states=torch.randn(2, 8, dtype=torch.bfloat16), + topk_weights=torch.randn(2, 2), + topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + w1=torch.randn(2, 8, 16), + w2=torch.randn(2, 16, 8), + quant_type=QuantType.MXFP8, + dynamic_eplb=False, + mxfp_act_quant_type=torch.float8_e4m3fn, + mxfp_weight_quant_type=torch.float8_e4m3fn, + mxfp_scale_dtype=torch.float32, + mxfp_per_token_scale_dtype=torch.float16, + mxfp_use_bf16=False, + w1_scale=[torch.randn(1)], + w2_scale=[torch.randn(1)], + ) + token_dispatch_output = MoETokenDispatchOutput( + hidden_states=torch.randn(4, 8, dtype=torch.bfloat16), + group_list=torch.tensor([2, 2], dtype=torch.int64), + group_list_type=1, + dynamic_scale=torch.randn(4, 1), + combine_metadata=MoEAllGatherCombineMetadata( + topk_weights=fused_experts_input.topk_weights, + expanded_row_idx=torch.arange(4, dtype=torch.int32), + restore_shape=torch.Size([2, 8]), + ), + ) + + mlp_compute_input = build_mlp_compute_input( + fused_experts_input=fused_experts_input, + token_dispatch_output=token_dispatch_output, + use_fusion_ops=True, + ) + + self.assertIs(mlp_compute_input.hidden_states, token_dispatch_output.hidden_states) + self.assertIs(mlp_compute_input.weights, fused_experts_input.weights) + self.assertIs(mlp_compute_input.weights.w1_scale, fused_experts_input.weights.w1_scale) + self.assertIs(mlp_compute_input.weights.w2_scale, fused_experts_input.weights.w2_scale) + self.assertTrue(mlp_compute_input.fusion) + self.assertTrue(mlp_compute_input.quant.is_mxfp) + assert mlp_compute_input.quant.mxfp is not None + 
self.assertEqual(mlp_compute_input.quant.mxfp.scale_dtype, torch.float32) + self.assertEqual(mlp_compute_input.quant.mxfp.per_token_scale_dtype, torch.float16) + self.assertFalse(mlp_compute_input.quant.mxfp.use_bf16) + + def test_build_fused_experts_input_constructs_internal_mxfp_leaf_from_primitives(self): + fused_experts_input = build_fused_experts_input( + hidden_states=torch.randn(2, 8, dtype=torch.bfloat16), + topk_weights=torch.randn(2, 2), + topk_ids=torch.tensor([[0, 1], [1, 0]], dtype=torch.int32), + w1=torch.randn(2, 8, 16), + w2=torch.randn(2, 16, 8), + quant_type=QuantType.MXFP8, + dynamic_eplb=False, + mxfp_act_quant_type=torch.float8_e4m3fn, + mxfp_weight_quant_type=torch.float8_e4m3fn, + mxfp_scale_dtype=torch.float32, + mxfp_per_token_scale_dtype=torch.float16, + mxfp_use_bf16=False, + ) + + self.assertTrue(fused_experts_input.quant.is_mxfp) + assert fused_experts_input.quant.mxfp is not None + self.assertEqual(fused_experts_input.quant.mxfp.act_quant_type, torch.float8_e4m3fn) + self.assertEqual(fused_experts_input.quant.mxfp.weight_quant_type, torch.float8_e4m3fn) + self.assertEqual(fused_experts_input.quant.mxfp.scale_dtype, torch.float32) + self.assertEqual(fused_experts_input.quant.mxfp.per_token_scale_dtype, torch.float16) + self.assertFalse(fused_experts_input.quant.mxfp.use_bf16) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/tests/ut/ops/test_prepare_finalize.py b/tests/ut/ops/test_prepare_finalize.py index f25b5ab2..51f5a514 100644 --- a/tests/ut/ops/test_prepare_finalize.py +++ b/tests/ut/ops/test_prepare_finalize.py @@ -45,18 +45,22 @@ class TestPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, mask, context_metadata = layer.prepare( - hidden_states, router_logits) + prepare_output = layer.prepare(hidden_states, router_logits) + h_out = prepare_output.hidden_states + r_out = prepare_output.router_logits + mask = prepare_output.mc2_mask + padded_hidden_states_shape = prepare_output.padded_hidden_states_shape # Check padding and split self.assertEqual(h_out.shape[0], 4) self.assertEqual(r_out.shape[0], 4) self.assertEqual(mask.tolist(), [1, 0, 1]) + self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8])) # Finalize result = layer.finalize(h_out, reduce_results=False, - context_metadata=context_metadata) + padded_hidden_states_shape=padded_hidden_states_shape) self.assertEqual(result.shape[0], 3) @patch( @@ -79,14 +83,19 @@ class TestPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(4, 8) router_logits = torch.randn(4, 2) - h_out, r_out, mask, context_metadata = layer.prepare( + prepare_output = layer.prepare( hidden_states, router_logits, enable_shared_expert_dp=False, replace_allreduce=False) + h_out = prepare_output.hidden_states + r_out = prepare_output.router_logits + mask = prepare_output.mc2_mask + padded_hidden_states_shape = prepare_output.padded_hidden_states_shape # With TP=2, should split into 2 parts self.assertEqual(h_out.shape[0], 2) + self.assertEqual(padded_hidden_states_shape, torch.Size([4, 8])) # Mock all_gather behavior def mock_all_gather_func(tensor_list, tensor, group=None): @@ -101,7 +110,7 @@ class TestPrepareAndFinalize(unittest.TestCase): ] final_result = layer.finalize(h_out, reduce_results=False, - context_metadata=context_metadata) + padded_hidden_states_shape=padded_hidden_states_shape) # Should concat back to original size self.assertEqual(final_result.shape[0], 4) @@ -117,15 +126,18 @@ class 
TestPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, _, context_metadata = layer.prepare( - hidden_states, router_logits) + prepare_output = layer.prepare(hidden_states, router_logits) + h_out = prepare_output.hidden_states + r_out = prepare_output.router_logits + padded_hidden_states_shape = prepare_output.padded_hidden_states_shape # Pad to tp_size=1, so no change self.assertEqual(h_out.shape[0], 3) + self.assertEqual(padded_hidden_states_shape, torch.Size([3, 8])) result = layer.finalize(h_out, reduce_results=False, - context_metadata=context_metadata) + padded_hidden_states_shape=padded_hidden_states_shape) self.assertEqual(result.shape[0], 3) @patch( @@ -141,14 +153,18 @@ class TestPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(2, 8) router_logits = torch.randn(2, 2) - h_out, r_out, _, context_metadata = layer.prepare( + prepare_output = layer.prepare( hidden_states, router_logits, enable_shared_expert_dp=False, replace_allreduce=False) + h_out = prepare_output.hidden_states + r_out = prepare_output.router_logits + padded_hidden_states_shape = prepare_output.padded_hidden_states_shape # Split due to TP=2 self.assertEqual(h_out.shape[0], 1) + self.assertEqual(padded_hidden_states_shape, torch.Size([2, 8])) # Mock all_gather def mock_all_gather_func(tensor_list, tensor, group=None): @@ -163,7 +179,7 @@ class TestPrepareAndFinalize(unittest.TestCase): ] final_result = layer.finalize(h_out, reduce_results=False, - context_metadata=context_metadata) + padded_hidden_states_shape=padded_hidden_states_shape) # Should concat back self.assertEqual(final_result.shape[0], 2) @@ -200,12 +216,15 @@ class TestPrepareAndFinalize(unittest.TestCase): hidden_states = torch.randn(3, 8) router_logits = torch.randn(3, 2) - h_out, r_out, _, context_metadata = layer.prepare( - hidden_states, router_logits) + prepare_output = layer.prepare(hidden_states, router_logits) + h_out = prepare_output.hidden_states + r_out = prepare_output.router_logits + padded_hidden_states_shape = prepare_output.padded_hidden_states_shape # After all-gather with DP=2, should double the batch size self.assertEqual(h_out.shape[0], 12) self.assertEqual(r_out.shape[0], 12) + self.assertIsNone(padded_hidden_states_shape) # Finalize with reduce_scatter def mock_reduce_scatter_func(tensor, dim): @@ -215,7 +234,7 @@ class TestPrepareAndFinalize(unittest.TestCase): mock_dp_group.reduce_scatter = mock_reduce_scatter_func result = layer.finalize(h_out, reduce_results=False, - context_metadata=context_metadata) + padded_hidden_states_shape=padded_hidden_states_shape) self.assertEqual(result.shape[0], 3) diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 4844013b..f10c6f5a 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -17,14 +17,62 @@ from unittest.mock import MagicMock, PropertyMock, patch +import numpy as np import pytest import torch from tests.ut.base import TestBase - +from vllm_ascend.ops.fused_moe.moe_runtime_args import ( + MoEAllGatherCombineMetadata, + MoEAllToAllCombineMetadata, + MoEMC2CombineMetadata, + MoEQuantParams, + MoERoutingParams, + MoETokenDispatchInput, +) from vllm_ascend.ops.fused_moe.token_dispatcher import ( # isort: skip - AscendDeviceType, TokenDispatcherWithAll2AllV, - TokenDispatcherWithAllGather, TokenDispatcherWithMC2) + AscendDeviceType, + TokenDispatcherWithAll2AllV, + TokenDispatcherWithAllGather, + 
TokenDispatcherWithMC2, +) +from vllm_ascend.ops.fused_moe.moe_stage_params import MoEMxfpParams +from vllm_ascend.quantization.quant_type import QuantType + + +def build_token_dispatch_input_fixture( + *, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: torch.Tensor | None = None, + global_redundant_expert_num: int = 0, + apply_router_weight_on_input: bool = False, + pertoken_scale: torch.Tensor | None = None, + quant_type: QuantType = QuantType.NONE, + comm_quant_mode: int | None = None, + act_quant_type: torch.dtype | None = None, +) -> MoETokenDispatchInput: + mxfp_spec = None + if quant_type == QuantType.MXFP8: + mxfp_spec = MoEMxfpParams(act_quant_type=act_quant_type) + return MoETokenDispatchInput( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + routing=MoERoutingParams( + expert_map=expert_map, + global_redundant_expert_num=global_redundant_expert_num, + mc2_mask=None, + apply_router_weight_on_input=apply_router_weight_on_input, + pertoken_scale=pertoken_scale, + ), + quant=MoEQuantParams( + quant_type=quant_type, + comm_quant_mode=comm_quant_mode, + mxfp=mxfp_spec, + ), + ) class TestTokenDispatcherWithMC2(TestBase): @@ -85,7 +133,6 @@ class TestTokenDispatcherWithMC2(TestBase): def test_init(self): self.assertEqual(self.dispatcher.ep_rank_id, 0) self.assertEqual(self.dispatcher.ep_world_size, 8) - self.assertFalse(self.dispatcher.with_quant) self.assertTrue(self.dispatcher.enable_dispatch_v2) self.assertTrue(self.dispatcher.need_extra_args) @@ -94,10 +141,16 @@ class TestTokenDispatcherWithMC2(TestBase): topk_ids = torch.randint(0, 8, (10, 1)) topk_weights = torch.randn(10, 1) expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - mc2_mask = None - - kwargs = self.dispatcher.get_dispatch_mc2_kwargs( - hidden_states, topk_weights, topk_ids, expert_map, mc2_mask) + token_dispatch_input = build_token_dispatch_input_fixture( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + global_redundant_expert_num=0, + apply_router_weight_on_input=False, + pertoken_scale=None, + ) + kwargs = self.dispatcher.get_dispatch_mc2_kwargs(token_dispatch_input) self.assertIn("x", kwargs) self.assertIn("expert_ids", kwargs) self.assertEqual(kwargs["moe_expert_num"], 8) @@ -111,39 +164,42 @@ class TestTokenDispatcherWithMC2(TestBase): with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(torch.randn(10, 128), ) * 5 + (None, None)) as mock_dispatch: - output = self.dispatcher.token_dispatch(hidden_states, - topk_weights, topk_ids, - expert_map) + token_dispatch_input = build_token_dispatch_input_fixture( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + ) + output = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input) mock_dispatch.assert_called_once() self.assertEqual(output.group_list_type, 0) # group_list_type == 0 + self.assertIsInstance(output.combine_metadata, MoEMC2CombineMetadata) def test_get_combine_mc_kwargs_with_quant(self): - self.dispatcher.with_quant = True hidden_states = torch.randn(10, 128) topk_ids = torch.randint(0, 8, (10, 1)) topk_weights = torch.randn(10, 1) expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) ep_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) tp_recv_counts = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) - mc2_mask = None assist_info_for_combine = torch.arange(10) - context_metadata = { - "topk_ids": topk_ids, - "topk_weights": 
topk_weights, - "expert_map": expert_map, - "ep_recv_counts": ep_recv_counts, - "mc2_mask": mc2_mask, - "assist_info_for_combine": assist_info_for_combine, - "expand_scales": None, - "tp_recv_counts": tp_recv_counts - } + combine_metadata = MoEMC2CombineMetadata( + topk_ids=topk_ids, + topk_weights=topk_weights, + expert_map=expert_map, + ep_recv_counts=ep_recv_counts, + tp_recv_counts=tp_recv_counts, + assist_info_for_combine=assist_info_for_combine, + expand_scales=None, + dispatch_with_quant=True, + ) self.dispatcher.need_extra_args = True self.dispatcher.enable_dispatch_v2 = True self.dispatcher.moe_expert_num = len(expert_map) kwargs = self.dispatcher.get_combine_mc_kwargs(hidden_states, - context_metadata) + combine_metadata) self.assertIn("tp_send_counts", kwargs) @@ -188,14 +244,19 @@ class TestTokenDispatcherWithAllGather(TestBase): topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) - results = self.dispatcher.token_dispatch(hidden_states, topk_weights, - topk_ids, None) + token_dispatch_input = build_token_dispatch_input_fixture( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) + results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input) # Verify npu_moe_init_routing is called self.mock_npu_moe_init_routing_custom.assert_called_once() args, kwargs = self.mock_npu_moe_init_routing_custom.call_args self.assertEqual(results.group_list_type, 1) + self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata) @pytest.mark.skip( "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") @@ -205,14 +266,19 @@ class TestTokenDispatcherWithAllGather(TestBase): topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) - results = self.dispatcher.token_dispatch(hidden_states, topk_weights, - topk_ids, None) + token_dispatch_input = build_token_dispatch_input_fixture( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) + results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input) # Verify npu_moe_init_routing is called self.mock_npu_moe_init_routing_custom.assert_called_once() args, kwargs = self.mock_npu_moe_init_routing_custom.call_args self.assertEqual(results.group_list_type, 1) + self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata) @pytest.mark.skip( "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") @@ -230,9 +296,12 @@ class TestTokenDispatcherWithAllGather(TestBase): topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) - results = self.dispatcher_quant.token_dispatch(hidden_states, - topk_weights, topk_ids, - None) + token_dispatch_input = build_token_dispatch_input_fixture( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + ) + results = self.dispatcher_quant.token_dispatch(token_dispatch_input=token_dispatch_input) self.assertEqual(results.group_list_type, 1) @@ -252,11 +321,13 @@ class TestTokenDispatcherWithAllGather(TestBase): topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) - results = self.dispatcher_quant.token_dispatch(hidden_states, - topk_weights, - topk_ids, - None, - with_quant=True) + token_dispatch_input = build_token_dispatch_input_fixture( + hidden_states=hidden_states, + topk_weights=topk_weights, + 
topk_ids=topk_ids, + quant_type=QuantType.W8A8, + ) + results = self.dispatcher_quant.token_dispatch(token_dispatch_input=token_dispatch_input) self.assertIsNotNone(results.hidden_states) self.assertIsNotNone(results.group_list) @@ -267,40 +338,43 @@ class TestTokenDispatcherWithAllGather(TestBase): "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") def test_token_combine_with_expert_map(self): hidden_states = torch.randn(6, 128) - context_metadata = { - "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]), - "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), - } - self.dispatcher.original_shape = (6, 128) - final_hidden_states = self.dispatcher.token_combine( - hidden_states, context_metadata).routed_out + combine_metadata = MoEAllGatherCombineMetadata( + expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]), + topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), + restore_shape=torch.Size([6, 128]), + ) + final_hidden_states = self.dispatcher.token_combine(hidden_states, combine_metadata) self.assertEqual(final_hidden_states.shape, (6, 128)) @pytest.mark.skip( "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") def test_token_combine_without_expert_map(self): hidden_states = torch.randn(6, 128) - context_metadata = { - "expanded_row_idx": torch.tensor([0, 1, 1, 1, 1, 1]), - "topk_weights": torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), - } - self.dispatcher.original_shape = (6, 128) - final_hidden_states = self.dispatcher.token_combine( - hidden_states, context_metadata).routed_out + combine_metadata = MoEAllGatherCombineMetadata( + expanded_row_idx=torch.tensor([0, 1, 1, 1, 1, 1]), + topk_weights=torch.tensor([0.5, 0.5, 0.5, 0.5, 0.5, 0.5]), + restore_shape=torch.Size([6, 128]), + ) + final_hidden_states = self.dispatcher.token_combine(hidden_states, combine_metadata) self.mock_npu_moe_token_unpermute.assert_called_once() self.assertEqual(final_hidden_states.shape, (6, 128)) @pytest.mark.skip( "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") def test_token_dispatch_with_router_weight(self): - self.dispatcher.apply_router_weight_on_input = True hidden_states = torch.randn(3, 128) topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1 topk_ids = torch.tensor([[0], [1], [2]]) - results = self.dispatcher.token_dispatch(hidden_states, topk_weights, - topk_ids, None) + token_dispatch_input = build_token_dispatch_input_fixture( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + apply_router_weight_on_input=True, + ) + results = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input) self.assertEqual(results.hidden_states.shape, (6, 128)) + self.assertIsInstance(results.combine_metadata, MoEAllGatherCombineMetadata) class TestTokenDispatcherWithAll2AllV(TestBase): @@ -408,35 +482,39 @@ class TestTokenDispatcherWithAll2AllV(TestBase): [0, 1], dtype=torch.int32) self.dispatcher.local_expert_indices = [0, 1] - result = self.dispatcher.token_dispatch(hidden_states=hidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - expert_map=expert_map) + token_dispatch_input = build_token_dispatch_input_fixture( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + ) + result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input) self.assertIsNotNone(result.hidden_states) self.assertIsNotNone(result.group_list) self.assertEqual(result.group_list_type, 1) + self.assertIsInstance(result.combine_metadata, 
MoEAllToAllCombineMetadata) @pytest.mark.skip( "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") def test_token_combine(self): hidden_states = torch.randn(16, 16) - context_metadata = { - "input_splits": [4, 4], - "output_splits": [4, 4], - "topk_weights": torch.rand(8, 4), - "reversed_local_input_permutation_mapping": torch.arange(8), - "reversed_global_input_permutation_mapping": torch.arange(16), - } - self.dispatcher.hidden_shape = (8, 16) - self.dispatcher.hidden_shape_before_permute = (8, 16) + combine_metadata = MoEAllToAllCombineMetadata( + input_splits=np.array([4, 4]), + output_splits=np.array([4, 4]), + topk_weights=torch.rand(8, 4), + reversed_local_input_permutation_mapping=torch.arange(8), + reversed_global_input_permutation_mapping=torch.arange(16), + hidden_shape=torch.Size([8, 16]), + hidden_shape_before_permute=torch.Size([8, 16]), + ) self.dispatcher.expert_ids_per_ep_rank = torch.tensor( [0, 1], dtype=torch.int32) self.dispatcher.local_expert_indices = [0, 1] - output = self.dispatcher.token_combine(hidden_states, context_metadata) + output = self.dispatcher.token_combine(hidden_states, combine_metadata) self.assertIsNotNone(output) - self.assertEqual(output.routed_out.shape, (8, 16)) + self.assertEqual(output.shape, (8, 16)) @pytest.mark.skip( "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") @@ -454,16 +532,20 @@ class TestTokenDispatcherWithAll2AllV(TestBase): [0, 1], dtype=torch.int32) self.dispatcher.local_expert_indices = [0, 1] - result = self.dispatcher.token_dispatch(hidden_states=hidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - expert_map=expert_map, - with_quant=True) + token_dispatch_input = build_token_dispatch_input_fixture( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + quant_type=QuantType.W8A8, + ) + result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input) self.assertIsNotNone(result.hidden_states) self.assertIsNotNone(result.group_list) self.assertIsNotNone(result.dynamic_scale) self.assertEqual(result.group_list_type, 1) + self.assertIsInstance(result.combine_metadata, MoEAllToAllCombineMetadata) @pytest.mark.skip( "Skip as register_kernels has NPU SocName checking in CANN 8.5.0.") @@ -484,14 +566,16 @@ class TestTokenDispatcherWithAll2AllV(TestBase): [0, 1], dtype=torch.int32) self.dispatcher.local_expert_indices = [0, 1] - result = self.dispatcher.token_dispatch(hidden_states=hidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - expert_map=expert_map, - with_quant=True) + token_dispatch_input = build_token_dispatch_input_fixture( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + expert_map=expert_map, + quant_type=QuantType.W8A8, + ) + result = self.dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input) self.assertIsNotNone(result.hidden_states) self.assertIsNotNone(result.group_list) self.assertIsNotNone(result.dynamic_scale) self.assertEqual(result.group_list_type, 1) - diff --git a/tests/ut/quantization/test_w4a16.py b/tests/ut/quantization/test_w4a16.py index adf4f706..87c2b79f 100644 --- a/tests/ut/quantization/test_w4a16.py +++ b/tests/ut/quantization/test_w4a16.py @@ -3,9 +3,8 @@ from unittest.mock import Mock, patch import torch from tests.ut.base import TestBase -from vllm_ascend.quantization.methods.w4a16 import (AscendW4A16FusedMoEMethod, - pack_to_int32, - unpack_from_int32) +from vllm_ascend.ascend_forward_context import MoECommType +from 
vllm_ascend.quantization.methods.w4a16 import AscendW4A16FusedMoEMethod, pack_to_int32, unpack_from_int32 class TestUnpackFromInt32(TestBase): @@ -268,3 +267,41 @@ class TestAscendW4A16FusedMoEMethod(TestBase): torch.equal(layer.w13_weight_packed.data, original_w13_data)) self.assertTrue( torch.equal(layer.w2_weight_packed.data, original_w2_data)) + + @patch("vllm_ascend.quantization.methods.w4a16._EXTRA_CTX") + @patch("vllm_ascend.quantization.methods.w4a16.select_experts") + def test_apply_uses_explicit_dispatch_and_mlp_args(self, mock_select_experts, mock_extra_ctx): + tokens = 3 + hidden_size = self.output_size + layer = self.build_layer() + x = torch.randn(tokens, hidden_size, dtype=torch.float32) + router_logits = torch.randn(tokens, self.experts, dtype=torch.float32) + topk_weights = torch.randn(tokens, 2, dtype=torch.float32) + topk_ids = torch.randint(0, self.experts, (tokens, 2), dtype=torch.int64) + mc2_mask = torch.tensor([1, 0, 1], dtype=torch.bool) + pertoken_scale = torch.randn(tokens, dtype=torch.float32) + + mock_select_experts.return_value = (topk_weights, topk_ids) + mock_comm = Mock() + mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32) + mock_extra_ctx.moe_comm_method = mock_comm + mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER + + self.quant_method.apply( + layer=layer, + x=x, + router_logits=router_logits, + top_k=2, + renormalize=True, + global_num_experts=self.experts, + activation="gelu", + apply_router_weight_on_input=True, + mc2_mask=mc2_mask, + pertoken_scale=pertoken_scale, + ) + + fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"] + self.assertEqual(fused_experts_input.activation, "gelu") + self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input) + self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask) + self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale) diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py index 00cc8f13..f01898f0 100644 --- a/tests/ut/quantization/test_w8a8_dynamic.py +++ b/tests/ut/quantization/test_w8a8_dynamic.py @@ -3,8 +3,8 @@ from unittest.mock import Mock, patch import torch from tests.ut.base import TestBase -from vllm_ascend.quantization.methods.w8a8_dynamic import \ - AscendW8A8DynamicFusedMoEMethod +from vllm_ascend.ascend_forward_context import MoECommType +from vllm_ascend.quantization.methods.w8a8_dynamic import AscendW8A8DynamicFusedMoEMethod class TestAscendW8A8FusedMoEMethod(TestBase): @@ -32,8 +32,9 @@ class TestAscendW8A8FusedMoEMethod(TestBase): mock_ep_group = Mock() mock_get_ep_group.return_value = mock_ep_group mock_ascend_config = Mock() - mock_ascend_config.enable_chunked_prefill = False + mock_ascend_config.multistream_overlap_gate = False + mock_ascend_config.eplb_config = Mock(dynamic_eplb=False) mock_get_ascend_config.return_value = mock_ascend_config mock_mc2_group = Mock(device_group=0) mock_get_mc2_group.return_value = mock_mc2_group @@ -104,3 +105,125 @@ class TestAscendW8A8FusedMoEMethod(TestBase): new_layer = self.build_layer() self.quant_method.process_weights_after_loading(new_layer) mock_npu_format_cast.assert_called() + + @patch("vllm_ascend.quantization.methods.w8a8_dynamic._EXTRA_CTX") + @patch("vllm_ascend.quantization.methods.w8a8_dynamic.select_experts") + def test_apply_uses_explicit_dispatch_and_mlp_args(self, mock_select_experts, mock_extra_ctx): + tokens = 4 + hidden_size = self.hidden_size + layer = torch.nn.Module() + 
layer.w13_weight = torch.randint( + -8, + 8, + (self.num_experts, 2 * self.intermediate_size, hidden_size), + dtype=torch.int8, + ) + layer.w2_weight = torch.randint( + -8, + 8, + (self.num_experts, hidden_size, self.intermediate_size), + dtype=torch.int8, + ) + layer.w13_weight_scale_fp32 = torch.ones(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32) + layer.w2_weight_scale = torch.ones(self.num_experts, hidden_size, dtype=torch.float32) + + x = torch.randn(tokens, hidden_size, dtype=torch.float32) + router_logits = torch.randn(tokens, self.num_experts, dtype=torch.float32) + topk_weights = torch.randn(tokens, 2, dtype=torch.float32) + topk_ids = torch.randint(0, self.num_experts, (tokens, 2), dtype=torch.int64) + mc2_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool) + pertoken_scale = torch.randn(tokens, dtype=torch.float32) + + mock_select_experts.return_value = (topk_weights, topk_ids) + mock_comm = Mock() + mock_comm.fused_experts.return_value = torch.randn(tokens, hidden_size, dtype=torch.float32) + mock_extra_ctx.moe_comm_method = mock_comm + mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER + self.quant_method.multistream_overlap_gate = False + self.quant_method.in_dtype = torch.float32 + + self.quant_method.apply( + layer=layer, + x=x, + router_logits=router_logits, + top_k=2, + renormalize=True, + global_num_experts=self.num_experts, + activation="gelu", + apply_router_weight_on_input=True, + mc2_mask=mc2_mask, + pertoken_scale=pertoken_scale, + ) + + fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"] + self.assertEqual(fused_experts_input.activation, "gelu") + self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input) + self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask) + self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale) + self.assertIs(fused_experts_input.topk_weights, topk_weights) + self.assertIs(fused_experts_input.topk_ids, topk_ids) + + @patch("vllm_ascend.quantization.methods.w8a8_dynamic.get_flash_common3_context") + @patch("vllm_ascend.quantization.methods.w8a8_dynamic._EXTRA_CTX") + @patch("vllm_ascend.quantization.methods.w8a8_dynamic.select_experts") + def test_apply_overlap_gate_uses_fc3_context( + self, + mock_select_experts, + mock_extra_ctx, + mock_get_flash_common3_context, + ): + tokens = 4 + hidden_size = self.hidden_size + layer = torch.nn.Module() + layer.w13_weight = torch.randint( + -8, + 8, + (self.num_experts, 2 * self.intermediate_size, hidden_size), + dtype=torch.int8, + ) + layer.w2_weight = torch.randint( + -8, + 8, + (self.num_experts, hidden_size, self.intermediate_size), + dtype=torch.int8, + ) + layer.w13_weight_scale_fp32 = torch.ones(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32) + layer.w2_weight_scale = torch.ones(self.num_experts, hidden_size, dtype=torch.float32) + + x = torch.randn(tokens, hidden_size, dtype=torch.float32) + router_logits = torch.randn(tokens, self.num_experts, dtype=torch.float32) + topk_weights = torch.randn(tokens, 2, dtype=torch.float32) + topk_ids = torch.randint(0, self.num_experts, (tokens, 2), dtype=torch.int64) + mc2_mask = torch.tensor([1, 0, 1, 0], dtype=torch.bool) + pertoken_scale = torch.randn(tokens, dtype=torch.float32) + + self.quant_method.multistream_overlap_gate = True + self.quant_method.in_dtype = torch.float32 + mock_get_flash_common3_context.return_value = Mock(topk_weights=topk_weights, topk_ids=topk_ids) + + mock_comm = Mock() + mock_comm.fused_experts.return_value = 
torch.randn(tokens, hidden_size, dtype=torch.float32) + mock_extra_ctx.moe_comm_method = mock_comm + mock_extra_ctx.moe_comm_type = MoECommType.ALLGATHER + + self.quant_method.apply( + layer=layer, + x=x, + router_logits=router_logits, + top_k=2, + renormalize=True, + global_num_experts=self.num_experts, + activation="gelu", + apply_router_weight_on_input=True, + mc2_mask=mc2_mask, + pertoken_scale=pertoken_scale, + ) + + mock_select_experts.assert_not_called() + fused_experts_input = mock_comm.fused_experts.call_args.kwargs["fused_experts_input"] + self.assertEqual(fused_experts_input.activation, "gelu") + self.assertTrue(fused_experts_input.routing.apply_router_weight_on_input) + self.assertIs(fused_experts_input.routing.mc2_mask, mc2_mask) + self.assertIs(fused_experts_input.routing.pertoken_scale, pertoken_scale) + self.assertIs(fused_experts_input.topk_weights, topk_weights) + self.assertIs(fused_experts_input.topk_ids, topk_ids) diff --git a/vllm_ascend/_310p/fused_moe/fused_moe.py b/vllm_ascend/_310p/fused_moe/fused_moe.py index 17b1765c..9de035cd 100644 --- a/vllm_ascend/_310p/fused_moe/fused_moe.py +++ b/vllm_ascend/_310p/fused_moe/fused_moe.py @@ -25,7 +25,8 @@ from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute from vllm_ascend.ops.fused_moe.moe_comm_method import FusedExpertsResult, _MoECommMethods -from vllm_ascend.quantization.methods.base import QuantType +from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input +from vllm_ascend.quantization.quant_type import QuantType from .experts_selector import select_experts from .moe_comm_method import AllGatherCommImpl310 @@ -93,13 +94,17 @@ class AscendUnquantizedFusedMoEMethod310(UnquantizedFusedMoEMethod): moe_comm_method = _EXTRA_CTX.moe_comm_method final_hidden_states = moe_comm_method.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, + fused_experts_input=build_fused_experts_input( + hidden_states=x, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1=layer.w13_weight, + w2=layer.w2_weight, + quant_type=QuantType.NONE, + dynamic_eplb=False, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + ), ) if zero_expert_num > 0 and zero_expert_type is not None: final_hidden_states += zero_expert_result @@ -218,9 +223,13 @@ class AscendFusedMoE310(FusedMoE): assert self.quant_method is not None assert self.routed_scaling_factor == 1.0, "routed_scaling_factor != 1.0 is not supported." - hidden_states, router_logits, _, context_metadata = _EXTRA_CTX.moe_comm_method.prepare( + prepare_output = _EXTRA_CTX.moe_comm_method.prepare( hidden_states=hidden_states, router_logits=router_logits, quant_type=self.quant_type ) + hidden_states = prepare_output.hidden_states + router_logits = prepare_output.router_logits + pertoken_scale = prepare_output.pertoken_scale + padded_hidden_states_shape = prepare_output.padded_hidden_states_shape # Matrix multiply. 
fused_experts_results: FusedExpertsResult = self.quant_method.apply( @@ -238,12 +247,13 @@ class AscendFusedMoE310(FusedMoE): global_num_experts=self.global_num_experts, expert_map=self.local_expert_map, apply_router_weight_on_input=self.apply_router_weight_on_input, + pertoken_scale=pertoken_scale, ) routed_out = _EXTRA_CTX.moe_comm_method.finalize( hidden_states=fused_experts_results.routed_out, reduce_results=self.reduce_results, - context_metadata=context_metadata, + padded_hidden_states_shape=padded_hidden_states_shape, ) return routed_out diff --git a/vllm_ascend/_310p/fused_moe/moe_comm_method.py b/vllm_ascend/_310p/fused_moe/moe_comm_method.py index 589566fc..efbed5bf 100644 --- a/vllm_ascend/_310p/fused_moe/moe_comm_method.py +++ b/vllm_ascend/_310p/fused_moe/moe_comm_method.py @@ -17,8 +17,8 @@ from __future__ import annotations import torch -from vllm_ascend.ascend_forward_context import _EXTRA_CTX -from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult +from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl +from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput from .moe_mlp import unified_apply_mlp from .token_dispatcher import TokenDispatcherWithAllGather310 @@ -35,52 +35,12 @@ class AllGatherCommImpl310(AllGatherCommImpl): to handle the token-to-expert mapping and communication efficiently. """ - def fused_experts( # type: ignore[override] - self, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor | None = None, - use_int8_w8a8: bool = False, - w1_scale: torch.Tensor | None = None, - w2_scale: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - ) -> FusedExpertsResult: - # This method is overridden to use the 310p-specific unified_apply_mlp - # which provides optimized MLP computation for the 310p platform - moe_comm_method = _EXTRA_CTX.moe_comm_method - assert moe_comm_method is not None, "Missing communication context" + def __init__(self, moe_config): + super().__init__(moe_config) + self.use_fusion_ops = False - dispatch_results = self.token_dispatcher.token_dispatch( - hidden_states=hidden_states, - topk_weights=topk_weights, - topk_ids=topk_ids, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - ) - - mlp_output = unified_apply_mlp( - hidden_states=dispatch_results.hidden_states, - w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, - group_list=dispatch_results.group_list, - group_list_type=dispatch_results.group_list_type, - with_quant=use_int8_w8a8, - ) - - combine_results = self.token_dispatcher.token_combine( - hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata - ) - - return FusedExpertsResult( - routed_out=combine_results.routed_out, - group_list_type=dispatch_results.group_list_type, - expert_tokens=dispatch_results.group_list, - ) + def _apply_mlp(self, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor: + return unified_apply_mlp(mlp_compute_input=mlp_compute_input) def _get_token_dispatcher(self): return TokenDispatcherWithAllGather310( diff --git a/vllm_ascend/_310p/fused_moe/moe_mlp.py b/vllm_ascend/_310p/fused_moe/moe_mlp.py index ff85ac44..26f8eec3 100644 --- a/vllm_ascend/_310p/fused_moe/moe_mlp.py +++ b/vllm_ascend/_310p/fused_moe/moe_mlp.py @@ -18,6 +18,8 @@ import torch import torch_npu +from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput + def 
quant_apply_mlp( hidden_states: torch.Tensor, @@ -66,17 +68,20 @@ def unquant_apply_mlp( return hidden_states -def unified_apply_mlp( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - group_list: torch.Tensor, - w1_scale: torch.Tensor | None = None, - w2_scale: torch.Tensor | None = None, - group_list_type: int = 1, - with_quant: bool = False, -) -> torch.Tensor: - if with_quant: +def unified_apply_mlp(*, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor: + hidden_states = mlp_compute_input.hidden_states + w1 = mlp_compute_input.weights.w1 + w2 = mlp_compute_input.weights.w2 + w1_scale = mlp_compute_input.weights.w1_scale + w2_scale = mlp_compute_input.weights.w2_scale + group_list = mlp_compute_input.group_list + group_list_type = mlp_compute_input.group_list_type + assert isinstance(w1, torch.Tensor) + assert isinstance(w2, torch.Tensor) + + if mlp_compute_input.quant.is_quant: + assert isinstance(w1_scale, torch.Tensor) + assert isinstance(w2_scale, torch.Tensor) assert w1_scale is not None and w2_scale is not None return quant_apply_mlp( hidden_states=hidden_states, @@ -87,7 +92,11 @@ def unified_apply_mlp( group_list=group_list, group_list_type=group_list_type, ) - else: - return unquant_apply_mlp( - hidden_states=hidden_states, w1=w1, w2=w2, group_list=group_list, group_list_type=group_list_type - ) + + return unquant_apply_mlp( + hidden_states=hidden_states, + w1=w1, + w2=w2, + group_list=group_list, + group_list_type=group_list_type, + ) diff --git a/vllm_ascend/_310p/fused_moe/token_dispatcher.py b/vllm_ascend/_310p/fused_moe/token_dispatcher.py index e9e6a48c..b724644a 100644 --- a/vllm_ascend/_310p/fused_moe/token_dispatcher.py +++ b/vllm_ascend/_310p/fused_moe/token_dispatcher.py @@ -25,26 +25,27 @@ import torch from vllm.distributed.parallel_state import get_ep_group -from vllm_ascend.ops.fused_moe.token_dispatcher import TokenDispatcherWithAllGather, TokenDispatchResult +from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEAllGatherCombineMetadata, MoETokenDispatchInput +from vllm_ascend.ops.fused_moe.token_dispatcher import MoETokenDispatchOutput, TokenDispatcherWithAllGather class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather): def __init__(self, **kwargs): super().__init__(**kwargs) - def token_dispatch( # type: ignore[override] + def token_dispatch( self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, + token_dispatch_input: MoETokenDispatchInput, ): - self.original_shape = hidden_states.shape + hidden_states = token_dispatch_input.hidden_states + topk_weights = token_dispatch_input.topk_weights + topk_ids = token_dispatch_input.topk_ids + expert_map = token_dispatch_input.routing.expert_map + apply_router_weight_on_input = token_dispatch_input.routing.apply_router_weight_on_input + restore_shape = hidden_states.shape num_tokens = hidden_states.shape[:-1].numel() - self.apply_router_weight_on_input = apply_router_weight_on_input - if self.apply_router_weight_on_input: + if apply_router_weight_on_input: assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)" _, topk = topk_weights.shape assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True" @@ -66,13 +67,16 @@ class TokenDispatcherWithAllGather310(TokenDispatcherWithAllGather): ) expert_tokens = expert_tokens.to(torch.int64) group_list_type = 1 # `count` mode - context_metadata = 
{"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx} - return TokenDispatchResult( + return MoETokenDispatchOutput( hidden_states=sorted_hidden_states, group_list=expert_tokens, group_list_type=group_list_type, - context_metadata=context_metadata, + combine_metadata=MoEAllGatherCombineMetadata( + topk_weights=topk_weights, + expanded_row_idx=expanded_row_idx, + restore_shape=restore_shape, + ), ) def moe_init_routing(self, x, expert_idx, active_num, active_expert_range): diff --git a/vllm_ascend/_310p/quantization/methods/w8a8_dynamic.py b/vllm_ascend/_310p/quantization/methods/w8a8_dynamic.py index 6a1a1303..be72dbaa 100644 --- a/vllm_ascend/_310p/quantization/methods/w8a8_dynamic.py +++ b/vllm_ascend/_310p/quantization/methods/w8a8_dynamic.py @@ -25,6 +25,7 @@ from vllm.distributed import get_ep_group from vllm_ascend._310p.fused_moe.experts_selector import select_experts from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ops.fused_moe.experts_selector import zero_experts_compute +from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input from vllm_ascend.quantization.methods.base import AscendMoEScheme, QuantType from .registry import register_scheme @@ -95,7 +96,9 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme): log2phy: torch.Tensor | None = None, global_redundant_expert_num: int = 0, pertoken_scale: Any | None = None, - **kwargs, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + mc2_mask: torch.Tensor | None = None, ) -> torch.Tensor: zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_type = getattr(layer, "zero_expert_type", None) @@ -128,15 +131,19 @@ class AscendW8A8DynamicFusedMoEMethod310(AscendMoEScheme): moe_comm_method = _EXTRA_CTX.moe_comm_method final_hidden_states = moe_comm_method.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - expert_map=expert_map, - use_int8_w8a8=True, + fused_experts_input=build_fused_experts_input( + hidden_states=x, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1=layer.w13_weight, + w2=layer.w2_weight, + quant_type=self.quant_type, + dynamic_eplb=False, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ), ) if zero_expert_num > 0 and zero_expert_type is not None: final_hidden_states += zero_expert_result diff --git a/vllm_ascend/ops/fused_moe/fused_moe.py b/vllm_ascend/ops/fused_moe/fused_moe.py index 3d858f7f..af56f65f 100644 --- a/vllm_ascend/ops/fused_moe/fused_moe.py +++ b/vllm_ascend/ops/fused_moe/fused_moe.py @@ -41,7 +41,8 @@ from vllm_ascend.eplb.core.eplb_utils import init_eplb_config from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method -from vllm_ascend.quantization.methods.base import QuantType +from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input +from vllm_ascend.quantization.quant_type import QuantType from vllm_ascend.utils import ( ACL_FORMAT_FRACTAL_NZ, enable_sp, @@ -113,7 +114,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): activation: str = "silu", 
enable_force_load_balance: bool = False, log2phy: torch.Tensor = None, - **kwargs, + global_redundant_expert_num: int = 0, + pertoken_scale: torch.Tensor | None = None, + mc2_mask: torch.Tensor | None = None, ) -> torch.Tensor: zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_type = getattr(layer, "zero_expert_type", None) @@ -167,7 +170,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): # (due to signature constraints), we are forced to use a placeholder empty tensor. # This TODO tracks the requirement to update the C++ operator to accept Optional[Tensor] # or None for scales in non-quantized scenarios. - if get_forward_context().moe_comm_type == MoECommType.FUSED_MC2: + if _EXTRA_CTX.moe_comm_type == MoECommType.FUSED_MC2: w1 = [layer.w13_weight] w1_scale = [torch.tensor([], dtype=torch.int64)] w2 = [layer.w2_weight] @@ -179,21 +182,26 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): w2_scale = None final_hidden_states = moe_comm_method.fused_experts( - hidden_states=x, - w1=w1, - w2=w2, - w1_scale=w1_scale, - w2_scale=w2_scale, - w1_bias=layer.w13_bias if self.moe.has_bias else None, - w2_bias=layer.w2_bias if self.moe.has_bias else None, - activation=activation, - topk_weights=topk_weights, - topk_ids=topk_ids, - expert_map=expert_map, - apply_router_weight_on_input=apply_router_weight_on_input, - dynamic_eplb=self.dynamic_eplb, - log2phy=log2phy, - mc2_mask=kwargs.get("mc2_mask"), + fused_experts_input=build_fused_experts_input( + hidden_states=x, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1=w1, + w2=w2, + w1_bias=layer.w13_bias if self.moe.has_bias else None, + w2_bias=layer.w2_bias if self.moe.has_bias else None, + quant_type=QuantType.NONE, + dynamic_eplb=self.dynamic_eplb, + expert_map=expert_map, + global_redundant_expert_num=global_redundant_expert_num, + mc2_mask=mc2_mask, + apply_router_weight_on_input=apply_router_weight_on_input, + log2phy=log2phy, + pertoken_scale=pertoken_scale, + activation=activation, + w1_scale=w1_scale, + w2_scale=w2_scale, + ) ) if zero_expert_num > 0 and zero_expert_type is not None: final_hidden_states += zero_expert_result @@ -474,23 +482,23 @@ class AscendFusedMoE(FusedMoE): set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids) - hidden_states, router_logits, mc2_mask, context_metadata = _EXTRA_CTX.moe_comm_method.prepare( + prepare_output = _EXTRA_CTX.moe_comm_method.prepare( hidden_states=hidden_states, router_logits=router_logits, replace_allreduce=_EXTRA_CTX.flash_comm_v1_enabled, enable_shared_expert_dp=self.enable_shared_expert_dp, quant_type=self.quant_type, ) + hidden_states = prepare_output.hidden_states + router_logits = prepare_output.router_logits + mc2_mask = prepare_output.mc2_mask + padded_hidden_states_shape = prepare_output.padded_hidden_states_shape + pertoken_scale = prepare_output.pertoken_scale # Make sure the default stream waits for the gate stream to finish. if self.multistream_overlap_gate: torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream) - if isinstance(hidden_states, tuple): - hidden_states, pertoken_scale = hidden_states - else: - pertoken_scale = None - # Matrix multiply. 
fused_experts_results: FusedExpertsResult = self.quant_method.apply( layer=self, @@ -538,7 +546,7 @@ class AscendFusedMoE(FusedMoE): routed_out = _EXTRA_CTX.moe_comm_method.finalize( hidden_states=fused_experts_results.routed_out, reduce_results=self.reduce_results, - context_metadata=context_metadata, + padded_hidden_states_shape=padded_hidden_states_shape, ) if return_with_event: diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py index 01f29ba8..b6d5a86c 100644 --- a/vllm_ascend/ops/fused_moe/moe_comm_method.py +++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py @@ -24,6 +24,13 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp +from vllm_ascend.ops.fused_moe.moe_runtime_args import ( + MoEFusedExpertsInput, + MoEMlpComputeInput, + MoEPrepareOutput, + build_mlp_compute_input, + build_token_dispatch_input, +) from vllm_ascend.ops.fused_moe.prepare_finalize import ( PrepareAndFinalize, PrepareAndFinalizeWithAll2All, @@ -36,8 +43,7 @@ from vllm_ascend.ops.fused_moe.token_dispatcher import ( TokenDispatcherWithAllGather, TokenDispatcherWithMC2, ) -from vllm_ascend.quantization.methods.base import QuantType -from vllm_ascend.quantization.quant_parser import parse_mxfp_quant_params +from vllm_ascend.quantization.quant_type import QuantType _MoECommMethods: dict[MoECommType | None, MoECommMethod] = {} @@ -90,131 +96,70 @@ class MoECommMethod(ABC): enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, quant_type: QuantType = QuantType.NONE, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: - hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare( - hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce, quant_type + ) -> MoEPrepareOutput: + return self.prepare_finalize.prepare( + hidden_states, + router_logits, + enable_shared_expert_dp, + replace_allreduce, + quant_type, ) - return hidden_states, router_logits, mc2_mask, context_metadata def finalize( - self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None + self, + hidden_states: torch.Tensor, + reduce_results: bool, + padded_hidden_states_shape: torch.Size | None = None, ) -> torch.Tensor: - hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, context_metadata) + hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, padded_hidden_states_shape) return hidden_states def fused_experts( self, - hidden_states: torch.Tensor, - w1: torch.Tensor | list[torch.Tensor], - w2: torch.Tensor | list[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - apply_router_weight_on_input: bool = False, - use_int8_w8a8: bool = False, - use_int4_w4a8: bool = False, - use_int4_w4a16: bool = False, - expert_map: torch.Tensor | None = None, - w1_scale: list[torch.Tensor] | None = None, - w2_scale: list[torch.Tensor] | None = None, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, - w1_offset: torch.Tensor | None = None, - w2_offset: torch.Tensor | None = None, - # For load balance - log2phy: torch.Tensor = None, - need_trans: bool = False, - dynamic_eplb: bool = False, - mc2_mask: torch.Tensor = None, - 
pertoken_scale: torch.Tensor | None = None, - **kwargs, + fused_experts_input: MoEFusedExpertsInput, ): # Check constraints - assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8] + assert fused_experts_input.hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8] moe_comm_method = _EXTRA_CTX.moe_comm_method assert moe_comm_method is not None, "Missing communication context" before_dispatch_evt = torch.npu.current_stream().record_event() - # Apply log2phy if needed - if log2phy is not None: - topk_ids = log2phy[topk_ids] - # TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced - # by different quantization modes will be consolidated into a dataclass in a follow-up. - use_mxfp_quant = kwargs.get("use_mxfp_quant", False) - dispatch_with_quant = use_int8_w8a8 or use_int4_w4a8 or use_mxfp_quant - act_quant_type, weight_quant_type, scale_type, per_token_scale_type, round_mode = parse_mxfp_quant_params( - **kwargs + routed_topk_ids = fused_experts_input.topk_ids + if fused_experts_input.routing.log2phy is not None: + routed_topk_ids = fused_experts_input.routing.log2phy[routed_topk_ids] + + token_dispatch_input = build_token_dispatch_input( + fused_experts_input=fused_experts_input, + topk_ids=routed_topk_ids, + ) + token_dispatch_output = self.token_dispatcher.token_dispatch(token_dispatch_input=token_dispatch_input) + + mlp_compute_input = build_mlp_compute_input( + fused_experts_input=fused_experts_input, + token_dispatch_output=token_dispatch_output, + use_fusion_ops=self.use_fusion_ops, ) - dispatch_kwargs = { - "hidden_states": hidden_states, - "topk_weights": topk_weights, - "topk_ids": topk_ids, - "expert_map": expert_map, - "global_redundant_expert_num": self.moe_config.global_redundant_expert_num, - "mc2_mask": mc2_mask, - "apply_router_weight_on_input": apply_router_weight_on_input, - "dynamic_eplb": dynamic_eplb, - "pertoken_scale": pertoken_scale, - } - - if isinstance(self.token_dispatcher, TokenDispatcherWithMC2): - dispatch_kwargs["with_quant"] = dispatch_with_quant - dispatch_kwargs["comm_quant_mode"] = kwargs.get("comm_quant_mode") - dispatch_kwargs["y_dtype"] = act_quant_type if use_mxfp_quant else None - dispatch_kwargs["use_mxfp_quant"] = use_mxfp_quant - else: - dispatch_kwargs["with_quant"] = use_int8_w8a8 or use_int4_w4a8 - - dispatch_results = self.token_dispatcher.token_dispatch(**dispatch_kwargs) - - mlp_output = unified_apply_mlp( - hidden_states=dispatch_results.hidden_states, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - w1_bias=w1_bias, - w2_bias=w2_bias, - activation=activation, - group_list=dispatch_results.group_list, - dynamic_scale=dispatch_results.dynamic_scale, - group_list_type=dispatch_results.group_list_type, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias, - w1_offset=w1_offset, - w2_offset=w2_offset, - topk_scales=dispatch_results.topk_scales, - with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16 or use_mxfp_quant, - fusion=(use_int8_w8a8 or use_mxfp_quant) and self.use_fusion_ops, - need_trans=need_trans, - dynamic_eplb=dynamic_eplb, - use_mxfp_quant=use_mxfp_quant, - act_quant_type=act_quant_type, - weight_quant_type=weight_quant_type, - scale_type=scale_type, - per_token_scale_type=per_token_scale_type, - round_mode=round_mode, - use_bf16=(hidden_states.dtype == torch.bfloat16), - rollback_quant_config=kwargs.get("rollback_quant_config"), - ) + mlp_output = self._apply_mlp(mlp_compute_input) 
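+        # Combine mirrors dispatch: the typed combine_metadata carried in
+        # token_dispatch_output (all-gather, all-to-all, or MC2 variant) holds all
+        # state token_combine needs, so no shapes or splits are stashed on the
+        # dispatcher between stages.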
before_combine_evt = torch.npu.current_stream().record_event() - combine_results = self.token_dispatcher.token_combine( - hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata + routed_out = self.token_dispatcher.token_combine( + hidden_states=mlp_output, + combine_metadata=token_dispatch_output.combine_metadata, ) return FusedExpertsResult( - routed_out=combine_results.routed_out, + routed_out=routed_out, before_dispatch_evt=before_dispatch_evt, before_combine_evt=before_combine_evt, - group_list_type=dispatch_results.group_list_type, - expert_tokens=dispatch_results.group_list, + group_list_type=token_dispatch_output.group_list_type, + expert_tokens=token_dispatch_output.group_list, ) + def _apply_mlp(self, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor: + return unified_apply_mlp(mlp_compute_input=mlp_compute_input) + @abstractmethod def _get_token_dispatcher(self) -> MoETokenDispatcher: raise NotImplementedError("_get_token_dispatcher function not implemented.") @@ -317,54 +262,32 @@ class FusedMC2CommImpl(MoECommMethod): def fused_experts( self, - hidden_states: torch.Tensor, - w1: torch.Tensor | list[torch.Tensor], - w2: torch.Tensor | list[torch.Tensor], - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - activation: str = "silu", - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - apply_router_weight_on_input: bool = False, - use_int8_w8a8: bool = False, - use_int4_w4a8: bool = False, - use_int4_w4a16: bool = False, - expert_map: torch.Tensor | None = None, - w1_scale: list[torch.Tensor] | None = None, - w2_scale: list[torch.Tensor] | None = None, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, - w1_offset: torch.Tensor | None = None, - w2_offset: torch.Tensor | None = None, - # For load balance - log2phy: torch.Tensor = None, - need_trans: bool = False, - dynamic_eplb: bool = False, - mc2_mask: torch.Tensor = None, - pertoken_scale: torch.Tensor | None = None, - **kwargs, + fused_experts_input: MoEFusedExpertsInput, ): - assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl." + assert not (fused_experts_input.weights.w1_scale is None or fused_experts_input.weights.w2_scale is None), ( + "w1_scale and w2_scale cannot be None for FusedMC2CommImpl." + ) assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), ( "token_dispatcher must be an instance of TokenDispatcherWithMC2." 
) # Apply log2phy if needed - if log2phy is not None: - topk_ids = log2phy[topk_ids] + topk_ids = fused_experts_input.topk_ids + if fused_experts_input.routing.log2phy is not None: + topk_ids = fused_experts_input.routing.log2phy[topk_ids] expert_tokens = None if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: - out = torch.empty_like(hidden_states) + out = torch.empty_like(fused_experts_input.hidden_states) torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore - x=hidden_states, - weight1=w1, - weight2=w2, + x=fused_experts_input.hidden_states, + weight1=fused_experts_input.weights.w1, + weight2=fused_experts_input.weights.w2, expert_idx=topk_ids, - scale1=w1_scale, - scale2=w2_scale, - probs=topk_weights.to(torch.float32), + scale1=fused_experts_input.weights.w1_scale, + scale2=fused_experts_input.weights.w2_scale, + probs=fused_experts_input.topk_weights.to(torch.float32), group=self.token_dispatcher.moe_all_to_all_group_name, max_output_size=65536, out=out, @@ -372,16 +295,16 @@ class FusedMC2CommImpl(MoECommMethod): ) expert_tokens = self.expert_token_nums elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: - assert expert_map is not None, "expert_map cannot be None." + assert fused_experts_input.routing.expert_map is not None, "expert_map cannot be None." out, expert_tokens = torch.ops._C_ascend.dispatch_gmm_combine_decode( # type: ignore - x=hidden_states, + x=fused_experts_input.hidden_states, expert_ids=topk_ids, - gmm1_permuted_weight=w1, - gmm1_permuted_weight_scale=w1_scale, - gmm2_weight=w2, - gmm2_weight_scale=w2_scale, + gmm1_permuted_weight=fused_experts_input.weights.w1, + gmm1_permuted_weight_scale=fused_experts_input.weights.w1_scale, + gmm2_weight=fused_experts_input.weights.w2, + gmm2_weight_scale=fused_experts_input.weights.w2_scale, expert_smooth_scales=None, - expert_scales=topk_weights.to(torch.float32), + expert_scales=fused_experts_input.topk_weights.to(torch.float32), group_ep=self.token_dispatcher.moe_all_to_all_group_name, ep_rank_size=self.token_dispatcher.ep_world_size, ep_rank_id=self.token_dispatcher.ep_rank_id, diff --git a/vllm_ascend/ops/fused_moe/moe_mlp.py b/vllm_ascend/ops/fused_moe/moe_mlp.py index 74b84f80..081b102c 100644 --- a/vllm_ascend/ops/fused_moe/moe_mlp.py +++ b/vllm_ascend/ops/fused_moe/moe_mlp.py @@ -27,6 +27,7 @@ from vllm_ascend.device.mxfp_compat import ( ensure_mxfp8_moe_available, ) from vllm_ascend.ops.activation import AscendSwigluOAIAndMul +from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEMlpComputeInput from vllm_ascend.utils import ( dispose_tensor, enable_custom_op, @@ -95,27 +96,17 @@ def quant_apply_mlp( w2_offset: torch.Tensor | None = None, fusion: bool = False, dynamic_eplb: bool = False, - **kwargs, + use_mxfp_quant: bool = False, + act_quant_type: torch.dtype = torch.float8_e4m3fn, + weight_quant_type: torch.dtype | None = None, + scale_type: torch.dtype | None = None, + per_token_scale_type: torch.dtype | None = None, + use_bf16: bool = True, ) -> torch.Tensor: - # TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced by different - # quantization modes will be consolidated into a dataclass in a follow-up. 
- use_mxfp_quant = kwargs.get("use_mxfp_quant", False) - act_quant_type = torch.float8_e4m3fn - weight_quant_type = None - scale_type = None - per_token_scale_type = None - use_bf16 = True - input_hidden_dtype = hidden_states.dtype use_gmm_swiglu_quant_fusion = use_mxfp_quant or (fusion and not dynamic_eplb) if use_mxfp_quant: - act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn) - weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn) - scale_type = kwargs.get("scale_type") - per_token_scale_type = kwargs.get("per_token_scale_type") - use_bf16 = kwargs.get("use_bf16", True) - ensure_mxfp8_moe_available("MXFP MoE MLP path") if w1_scale_bias is not None or w2_scale_bias is not None: @@ -393,34 +384,32 @@ def unquant_apply_mlp( return hidden_states -def unified_apply_mlp( - hidden_states: torch.Tensor, - w1: torch.Tensor | list[torch.Tensor], - w2: torch.Tensor | list[torch.Tensor], - group_list: torch.Tensor, - w1_scale: list[torch.Tensor] | None = None, - w2_scale: list[torch.Tensor] | None = None, - activation: str | None = None, - w1_bias: torch.Tensor = None, - w2_bias: torch.Tensor = None, - dynamic_scale: torch.Tensor = None, - group_list_type: int = 1, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, - w1_offset: torch.Tensor | None = None, - w2_offset: torch.Tensor | None = None, - topk_scales: torch.Tensor | None = None, - with_quant: bool = False, - fusion: bool = False, - need_trans: bool = True, - dynamic_eplb: bool = False, - **kwargs, -) -> torch.Tensor: +def unified_apply_mlp(*, mlp_compute_input: MoEMlpComputeInput) -> torch.Tensor: """ Unified MoE MLP entry. - Quant path is dispatched by DeviceOperator with explicit quant-type flags. + Quant path is dispatched by DeviceOperator with explicit typed kernel flags. """ - if not with_quant: + hidden_states = mlp_compute_input.hidden_states + group_list = mlp_compute_input.group_list + group_list_type = mlp_compute_input.group_list_type + dynamic_scale = mlp_compute_input.dynamic_scale + topk_scales = mlp_compute_input.topk_scales + w1 = mlp_compute_input.weights.w1 + w2 = mlp_compute_input.weights.w2 + w1_bias = mlp_compute_input.weights.w1_bias + w2_bias = mlp_compute_input.weights.w2_bias + w1_scale = mlp_compute_input.weights.w1_scale + w2_scale = mlp_compute_input.weights.w2_scale + w1_scale_bias = mlp_compute_input.weights.w1_scale_bias + w2_scale_bias = mlp_compute_input.weights.w2_scale_bias + w1_offset = mlp_compute_input.weights.w1_offset + w2_offset = mlp_compute_input.weights.w2_offset + activation = mlp_compute_input.activation + need_trans = mlp_compute_input.need_trans + dynamic_eplb = mlp_compute_input.dynamic_eplb + fusion = mlp_compute_input.fusion + + if not mlp_compute_input.quant.is_quant: return unquant_apply_mlp( hidden_states=hidden_states, w1=w1, @@ -435,13 +424,22 @@ def unified_apply_mlp( ) assert w1_scale is not None and w2_scale is not None - # TODO(linfeng): Current massive parameter passing is quite severe; parameter differences introduced by different - # quantization modes will be consolidated into a dataclass in a follow-up. 
- act_quant_type = kwargs.get("act_quant_type", torch.float8_e4m3fn) - weight_quant_type = kwargs.get("weight_quant_type", torch.float8_e4m3fn) - scale_type = kwargs.get("scale_type") - per_token_scale_type = kwargs.get("per_token_scale_type") - use_mxfp_quant = kwargs.get("use_mxfp_quant", False) + act_quant_type = torch.float8_e4m3fn + weight_quant_type = torch.float8_e4m3fn + scale_type = None + per_token_scale_type = None + use_bf16 = hidden_states.dtype == torch.bfloat16 + use_mxfp_quant = mlp_compute_input.quant.is_mxfp + + if use_mxfp_quant: + mxfp = mlp_compute_input.quant.mxfp + assert mxfp is not None, "mlp_compute_input.quant.mxfp is required when quant_type is MXFP8." + act_quant_type = mxfp.act_quant_type or act_quant_type + weight_quant_type = mxfp.weight_quant_type or weight_quant_type + scale_type = mxfp.scale_dtype + per_token_scale_type = mxfp.per_token_scale_dtype + use_bf16 = mxfp.use_bf16 + return quant_apply_mlp( hidden_states=hidden_states, w1=w1, @@ -457,10 +455,10 @@ def unified_apply_mlp( w2_offset=w2_offset, fusion=fusion, dynamic_eplb=dynamic_eplb, + use_mxfp_quant=use_mxfp_quant, act_quant_type=act_quant_type, weight_quant_type=weight_quant_type, scale_type=scale_type, per_token_scale_type=per_token_scale_type, - use_mxfp_quant=use_mxfp_quant, - use_bf16=kwargs.get("use_bf16", True), + use_bf16=use_bf16, ) diff --git a/vllm_ascend/ops/fused_moe/moe_runtime_args.py b/vllm_ascend/ops/fused_moe/moe_runtime_args.py new file mode 100644 index 00000000..573de7f1 --- /dev/null +++ b/vllm_ascend/ops/fused_moe/moe_runtime_args.py @@ -0,0 +1,244 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Typed runtime contracts and builders for fused MoE execution. + +This module is the single entry point for the runtime payloads used across the +fused MoE pipeline. + +Relationship overview: + + stage params: reusable sub-payloads + - MoERoutingParams + - MoEQuantParams + - internal MXFP leaf: MoEMxfpParams + + stage contracts: stage input/output payloads + prepare + -> MoEPrepareOutput + + fused_experts input + -> MoEFusedExpertsInput + |- weights: MoEWeights + |- routing: MoERoutingParams + |- quant: MoEQuantParams + + dispatch + input -> MoETokenDispatchInput + output -> MoETokenDispatchOutput[TMoECombineMetadata] + TMoECombineMetadata is one of: + - MoEAllGatherCombineMetadata + - MoEAllToAllCombineMetadata + - MoEMC2CombineMetadata + + mlp + input -> MoEMlpComputeInput + + combine + output -> torch.Tensor + +The helper builders below adapt legacy call sites into these typed contracts. +Only the fused_moe package should need to know about the internal MXFP leaf +dataclass directly. 
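+
+Illustrative sketch of wiring the builders together. Here `dispatcher` stands for
+whichever token dispatcher the active comm method owns and the tensors are
+placeholders; unified_apply_mlp lives in moe_mlp.py, and the real call sites are
+in moe_comm_method.py:
+
+    fused_experts_input = build_fused_experts_input(
+        hidden_states=hidden_states,   # (num_tokens, hidden_size)
+        topk_weights=topk_weights,     # (num_tokens, topk)
+        topk_ids=topk_ids,             # (num_tokens, topk)
+        w1=w1,
+        w2=w2,
+        quant_type=QuantType.NONE,
+        dynamic_eplb=False,
+        expert_map=expert_map,
+    )
+    dispatch_input = build_token_dispatch_input(fused_experts_input=fused_experts_input)
+    dispatch_output = dispatcher.token_dispatch(token_dispatch_input=dispatch_input)
+    mlp_input = build_mlp_compute_input(
+        fused_experts_input=fused_experts_input,
+        token_dispatch_output=dispatch_output,
+        use_fusion_ops=False,
+    )
+    routed_out = dispatcher.token_combine(
+        hidden_states=unified_apply_mlp(mlp_compute_input=mlp_input),
+        combine_metadata=dispatch_output.combine_metadata,
+    )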
+""" + +from __future__ import annotations + +import torch + +import vllm_ascend.ops.fused_moe.moe_stage_params as _stage_params +from vllm_ascend.ops.fused_moe.moe_stage_contracts import ( + MoEAllGatherCombineMetadata, + MoEAllToAllCombineMetadata, + MoEFusedExpertsInput, + MoEMC2CombineMetadata, + MoEMlpComputeInput, + MoEPrepareOutput, + MoETokenDispatchInput, + MoETokenDispatchOutput, + MoEWeights, + TMoECombineMetadata, +) +from vllm_ascend.ops.fused_moe.moe_stage_params import ( + MoEQuantParams, + MoERoutingParams, +) +from vllm_ascend.quantization.quant_type import QuantType + + +def _build_mxfp_params( + *, + quant_type: QuantType, + mxfp_act_quant_type: torch.dtype | None = None, + mxfp_weight_quant_type: torch.dtype | None = None, + mxfp_scale_dtype: torch.dtype | None = None, + mxfp_per_token_scale_dtype: torch.dtype | None = None, + mxfp_use_bf16: bool | None = None, +) -> _stage_params.MoEMxfpParams | None: + if quant_type != QuantType.MXFP8: + return None + + has_explicit_mxfp_args = any( + value is not None + for value in ( + mxfp_act_quant_type, + mxfp_weight_quant_type, + mxfp_scale_dtype, + mxfp_per_token_scale_dtype, + mxfp_use_bf16, + ) + ) + if not has_explicit_mxfp_args: + raise ValueError("primitive MXFP params are required when quant_type is QuantType.MXFP8.") + + return _stage_params.MoEMxfpParams( + act_quant_type=mxfp_act_quant_type, + weight_quant_type=mxfp_weight_quant_type, + scale_dtype=mxfp_scale_dtype, + per_token_scale_dtype=mxfp_per_token_scale_dtype, + use_bf16=True if mxfp_use_bf16 is None else mxfp_use_bf16, + ) + + +def build_fused_experts_input( + *, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1: torch.Tensor | list[torch.Tensor], + w2: torch.Tensor | list[torch.Tensor], + quant_type: QuantType, + dynamic_eplb: bool, + expert_map: torch.Tensor | None = None, + global_redundant_expert_num: int = 0, + mc2_mask: torch.Tensor | None = None, + apply_router_weight_on_input: bool = False, + log2phy: torch.Tensor | None = None, + pertoken_scale: torch.Tensor | None = None, + activation: str = "silu", + need_trans: bool = False, + w1_bias: torch.Tensor | None = None, + w2_bias: torch.Tensor | None = None, + comm_quant_mode: int | None = None, + mxfp_act_quant_type: torch.dtype | None = None, + mxfp_weight_quant_type: torch.dtype | None = None, + mxfp_scale_dtype: torch.dtype | None = None, + mxfp_per_token_scale_dtype: torch.dtype | None = None, + mxfp_use_bf16: bool | None = None, + w1_scale: list[torch.Tensor] | torch.Tensor | None = None, + w2_scale: list[torch.Tensor] | torch.Tensor | None = None, + w1_scale_bias: torch.Tensor | None = None, + w2_scale_bias: torch.Tensor | None = None, + w1_offset: torch.Tensor | None = None, + w2_offset: torch.Tensor | None = None, +) -> MoEFusedExpertsInput: + return MoEFusedExpertsInput( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + weights=MoEWeights( + w1=w1, + w2=w2, + w1_bias=w1_bias, + w2_bias=w2_bias, + w1_scale=w1_scale, + w2_scale=w2_scale, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + w1_offset=w1_offset, + w2_offset=w2_offset, + ), + routing=MoERoutingParams( + expert_map=expert_map, + global_redundant_expert_num=global_redundant_expert_num, + mc2_mask=mc2_mask, + apply_router_weight_on_input=apply_router_weight_on_input, + log2phy=log2phy, + pertoken_scale=pertoken_scale, + ), + activation=activation, + need_trans=need_trans, + dynamic_eplb=dynamic_eplb, + quant=MoEQuantParams( + quant_type=quant_type, + 
comm_quant_mode=comm_quant_mode, + mxfp=_build_mxfp_params( + quant_type=quant_type, + mxfp_act_quant_type=mxfp_act_quant_type, + mxfp_weight_quant_type=mxfp_weight_quant_type, + mxfp_scale_dtype=mxfp_scale_dtype, + mxfp_per_token_scale_dtype=mxfp_per_token_scale_dtype, + mxfp_use_bf16=mxfp_use_bf16, + ), + ), + ) + + +def build_token_dispatch_input( + *, + fused_experts_input: MoEFusedExpertsInput, + topk_ids: torch.Tensor | None = None, +) -> MoETokenDispatchInput: + return MoETokenDispatchInput( + hidden_states=fused_experts_input.hidden_states, + topk_weights=fused_experts_input.topk_weights, + topk_ids=fused_experts_input.topk_ids if topk_ids is None else topk_ids, + routing=fused_experts_input.routing, + quant=fused_experts_input.quant, + ) + + +def build_mlp_compute_input( + *, + fused_experts_input: MoEFusedExpertsInput, + token_dispatch_output: MoETokenDispatchOutput[TMoECombineMetadata], + use_fusion_ops: bool, +) -> MoEMlpComputeInput: + if fused_experts_input.quant.is_mxfp and fused_experts_input.quant.mxfp is None: + raise ValueError("fused_experts_input.quant.mxfp is required when quant_type is QuantType.MXFP8.") + + return MoEMlpComputeInput( + hidden_states=token_dispatch_output.hidden_states, + group_list=token_dispatch_output.group_list, + group_list_type=token_dispatch_output.group_list_type, + dynamic_scale=token_dispatch_output.dynamic_scale, + topk_scales=token_dispatch_output.topk_scales, + weights=fused_experts_input.weights, + quant=fused_experts_input.quant, + fusion=fused_experts_input.quant.quant_type in (QuantType.W8A8, QuantType.MXFP8) and use_fusion_ops, + activation=fused_experts_input.activation, + need_trans=fused_experts_input.need_trans, + dynamic_eplb=fused_experts_input.dynamic_eplb, + ) + + +__all__ = [ + "MoEAllGatherCombineMetadata", + "MoEAllToAllCombineMetadata", + "MoEFusedExpertsInput", + "MoEMC2CombineMetadata", + "MoEMlpComputeInput", + "MoEPrepareOutput", + "MoEQuantParams", + "MoERoutingParams", + "MoETokenDispatchInput", + "MoETokenDispatchOutput", + "MoEWeights", + "TMoECombineMetadata", + "build_fused_experts_input", + "build_token_dispatch_input", + "build_mlp_compute_input", +] diff --git a/vllm_ascend/ops/fused_moe/moe_stage_contracts.py b/vllm_ascend/ops/fused_moe/moe_stage_contracts.py new file mode 100644 index 00000000..1e137498 --- /dev/null +++ b/vllm_ascend/ops/fused_moe/moe_stage_contracts.py @@ -0,0 +1,154 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import annotations + +from dataclasses import dataclass +from typing import Generic, TypeVar + +import numpy as np +import torch + +from vllm_ascend.ops.fused_moe.moe_stage_params import MoEQuantParams, MoERoutingParams + +TMoECombineMetadata = TypeVar("TMoECombineMetadata") + + +# prepare -> fused_experts +@dataclass(frozen=True, slots=True) +class MoEPrepareOutput: + """Typed output from prepare stage.""" + + hidden_states: torch.Tensor + router_logits: torch.Tensor + mc2_mask: torch.Tensor | None + padded_hidden_states_shape: torch.Size | None + pertoken_scale: torch.Tensor | None = None + + +@dataclass(frozen=True, slots=True) +class MoEWeights: + """Dense and quantized weight payloads consumed by MoE execution.""" + + w1: torch.Tensor | list[torch.Tensor] + w2: torch.Tensor | list[torch.Tensor] + w1_bias: torch.Tensor | None = None + w2_bias: torch.Tensor | None = None + w1_scale: torch.Tensor | list[torch.Tensor] | None = None + w2_scale: torch.Tensor | list[torch.Tensor] | None = None + w1_scale_bias: torch.Tensor | None = None + w2_scale_bias: torch.Tensor | None = None + w1_offset: torch.Tensor | None = None + w2_offset: torch.Tensor | None = None + + +@dataclass(frozen=True, slots=True) +class MoEFusedExpertsInput: + """Top-level input for the routed experts pipeline.""" + + hidden_states: torch.Tensor + topk_weights: torch.Tensor + topk_ids: torch.Tensor + weights: MoEWeights + routing: MoERoutingParams + quant: MoEQuantParams + activation: str = "silu" + need_trans: bool = False + dynamic_eplb: bool = False + + +@dataclass(frozen=True, slots=True) +class MoETokenDispatchInput: + """Input to token dispatch.""" + + hidden_states: torch.Tensor + topk_weights: torch.Tensor + topk_ids: torch.Tensor + routing: MoERoutingParams + quant: MoEQuantParams + + +# dispatch carry-over state consumed by combine +@dataclass(frozen=True, slots=True) +class MoEMC2CombineMetadata: + topk_ids: torch.Tensor + topk_weights: torch.Tensor + expert_map: torch.Tensor | None + ep_recv_counts: torch.Tensor + tp_recv_counts: torch.Tensor + assist_info_for_combine: torch.Tensor + expand_scales: torch.Tensor | None + dispatch_with_quant: bool + + +@dataclass(frozen=True, slots=True) +class MoEAllGatherCombineMetadata: + topk_weights: torch.Tensor + expanded_row_idx: torch.Tensor + restore_shape: torch.Size + + +@dataclass(frozen=True, slots=True) +class MoEAllToAllCombineMetadata: + input_splits: np.ndarray + output_splits: np.ndarray + topk_weights: torch.Tensor + reversed_local_input_permutation_mapping: torch.Tensor + reversed_global_input_permutation_mapping: torch.Tensor | None + hidden_shape: torch.Size + hidden_shape_before_permute: torch.Size + + +@dataclass(frozen=True, slots=True) +class MoETokenDispatchOutput(Generic[TMoECombineMetadata]): + hidden_states: torch.Tensor + group_list: torch.Tensor + group_list_type: int + combine_metadata: TMoECombineMetadata + dynamic_scale: torch.Tensor | None = None + topk_scales: torch.Tensor | None = None + + +# dispatch -> mlp -> combine +@dataclass(frozen=True, slots=True) +class MoEMlpComputeInput: + """Input to MLP compute.""" + + hidden_states: torch.Tensor + group_list: torch.Tensor + group_list_type: int + dynamic_scale: torch.Tensor | None + topk_scales: torch.Tensor | None + weights: MoEWeights + quant: MoEQuantParams + fusion: bool + activation: str = "silu" + need_trans: bool = False + dynamic_eplb: bool = False + + +__all__ = [ + "MoEPrepareOutput", + "MoEWeights", + "MoEFusedExpertsInput", + "MoETokenDispatchInput", + 
"MoEMC2CombineMetadata", + "MoEAllGatherCombineMetadata", + "MoEAllToAllCombineMetadata", + "MoETokenDispatchOutput", + "MoEMlpComputeInput", + "TMoECombineMetadata", +] diff --git a/vllm_ascend/ops/fused_moe/moe_stage_params.py b/vllm_ascend/ops/fused_moe/moe_stage_params.py new file mode 100644 index 00000000..230a0b9e --- /dev/null +++ b/vllm_ascend/ops/fused_moe/moe_stage_params.py @@ -0,0 +1,86 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from __future__ import annotations + +from dataclasses import dataclass + +import torch + +from vllm_ascend.quantization.quant_type import QuantType + + +@dataclass(frozen=True, slots=True) +class MoERoutingParams: + """Routing and dispatch side inputs for one MoE invocation. + + `pertoken_scale` is intentionally kept here even though it is not a pure + routing concept. It is used by pre-quantized activation flows, currently + the AllGather + EP W8A8 prepare path, where prepare emits per-token + activation scales and dispatch needs to carry them forward so the MLP + quant path can reuse those scales instead of requantizing activations. + """ + + expert_map: torch.Tensor | None + global_redundant_expert_num: int + mc2_mask: torch.Tensor | None + apply_router_weight_on_input: bool + log2phy: torch.Tensor | None = None + # Precomputed activation scales from prepare stage for quantized dispatch. 
+ pertoken_scale: torch.Tensor | None = None + + +@dataclass(frozen=True, slots=True) +class MoEMxfpParams: + """Internal MXFP-only precision settings used by fused_moe runtime.""" + + act_quant_type: torch.dtype | None = None + weight_quant_type: torch.dtype | None = None + scale_dtype: torch.dtype | None = None + per_token_scale_dtype: torch.dtype | None = None + use_bf16: bool = True + + +@dataclass(frozen=True, slots=True) +class MoEQuantParams: + """Quant mode, backend override, and optional internal MXFP leaf config.""" + + quant_type: QuantType = QuantType.NONE + comm_quant_mode: int | None = None + mxfp: MoEMxfpParams | None = None + + @property + def is_quant(self) -> bool: + return self.quant_type != QuantType.NONE + + @property + def is_mxfp(self) -> bool: + return self.quant_type == QuantType.MXFP8 + + @property + def is_int_quant(self) -> bool: + return self.quant_type in (QuantType.W8A8, QuantType.W4A8) + + @property + def dispatch_with_quant(self) -> bool: + return self.quant_type in (QuantType.W8A8, QuantType.W4A8, QuantType.MXFP8) + + +__all__ = [ + "MoERoutingParams", + "MoEMxfpParams", + "MoEQuantParams", +] diff --git a/vllm_ascend/ops/fused_moe/prepare_finalize.py b/vllm_ascend/ops/fused_moe/prepare_finalize.py index 6c7358aa..3f4ec057 100644 --- a/vllm_ascend/ops/fused_moe/prepare_finalize.py +++ b/vllm_ascend/ops/fused_moe/prepare_finalize.py @@ -31,7 +31,8 @@ from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl -from vllm_ascend.quantization.methods.base import QuantType +from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEPrepareOutput +from vllm_ascend.quantization.quant_type import QuantType from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable @@ -64,7 +65,7 @@ class PrepareAndFinalize(ABC): enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, quant_type: QuantType = QuantType.NONE, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + ) -> MoEPrepareOutput: """ Prepare tensors before MoE computation. May involve: - Padding to align communication boundaries @@ -79,16 +80,20 @@ class PrepareAndFinalize(ABC): quant_type: none, w8a8, w4a8 or mxfp8 Returns: - Tuple of: + MoEPrepareOutput: - processed hidden_states (may be padded/sliced/broadcasted) - processed router_logits (may be recomputed or broadcasted) - optional communication mask (e.g., mc2_mask for sparse ops) - - optional context metadata (e.g., saved split_hidden_states for finalization) + - optional padded hidden state shape for finalization + - optional per-token scale for quantized path """ raise NotImplementedError("Prepare not implemented.") def finalize( - self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None + self, + hidden_states: torch.Tensor, + reduce_results: bool, + padded_hidden_states_shape: torch.Size | None = None, ) -> torch.Tensor: """ Finalize MoE output. May involve: @@ -130,7 +135,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize): enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, quant_type=QuantType.NONE, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + ) -> MoEPrepareOutput: """ Preparation steps: 1. Pad hidden_states and router_logits to next multiple of TP size. 
@@ -140,7 +145,7 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize): Skips if `enable_shared_expert_dp` or `replace_allreduce` is True. Returns: - Tuple of (hidden_states, router_logits, None, context_metadata) — no mask used in All2All. + MoEPrepareOutput where `mc2_mask` is None for All2All path. """ self.replace_allreduce = replace_allreduce self.enable_shared_expert_dp = enable_shared_expert_dp @@ -162,12 +167,19 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize): hidden_states = split_hidden_states[self.tp_rank] router_logits = split_router_logits[self.tp_rank] - context_metadata = {"padded_hidden_states_shape": padded_hidden_states_shape} - - return hidden_states, router_logits, None, context_metadata + return MoEPrepareOutput( + hidden_states=hidden_states, + router_logits=router_logits, + mc2_mask=None, + padded_hidden_states_shape=padded_hidden_states_shape, + pertoken_scale=None, + ) def finalize( - self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None + self, + hidden_states: torch.Tensor, + reduce_results: bool, + padded_hidden_states_shape: torch.Size | None = None, ) -> torch.Tensor: """ Finalization steps: @@ -180,12 +192,11 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize): if not (self.enable_shared_expert_dp or self.replace_allreduce): if self.tp_size > 1: - assert context_metadata is not None + assert padded_hidden_states_shape is not None # Cannot reuse `split_hidden_states` from prepare phase as it # may share memory with original hidden_states. Since shared # experts may use the original tensor, reusing it would cause # in-place modification during all_gather, corrupting the data. - padded_hidden_states_shape = context_metadata["padded_hidden_states_shape"] gathered_hidden_states = torch.empty( padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype ) @@ -227,7 +238,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All): enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, quant_type=QuantType.NONE, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + ) -> MoEPrepareOutput: """ Preparation steps: 1. Fetch `mc2_mask` and target padding length from forward context. @@ -238,7 +249,7 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All): Skips padding/slicing if `enable_shared_expert_dp` or `replace_allreduce` is True. Returns: - Tuple of (hidden_states, router_logits, mc2_mask, context_metadata), possibly sliced/padded. + MoEPrepareOutput, possibly sliced/padded. 
""" self.replace_allreduce = replace_allreduce self.enable_shared_expert_dp = enable_shared_expert_dp @@ -267,11 +278,13 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All): hidden_states = split_hidden_states[self.tp_rank] router_logits = split_router_logits[self.tp_rank] - context_metadata = { - "padded_hidden_states_shape": padded_hidden_states_shape, - } - - return hidden_states, router_logits, mc2_mask, context_metadata + return MoEPrepareOutput( + hidden_states=hidden_states, + router_logits=router_logits, + mc2_mask=mc2_mask, + padded_hidden_states_shape=padded_hidden_states_shape, + pertoken_scale=None, + ) class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): @@ -303,13 +316,13 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, quant_type=QuantType.NONE, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + ) -> MoEPrepareOutput: """ Preparation steps: AllGather hidden_states and router_logits to form global tensors. Returns: - Tuple of (global_hidden_states, global_router_logits, None) + MoEPrepareOutput with global tensors. """ if enable_sp(): return self._prepare_with_ep_group(hidden_states, router_logits, quant_type) @@ -318,7 +331,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): def _prepare_with_ep_group( self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + ) -> MoEPrepareOutput: pertoken_scale = None if quant_type == QuantType.W8A8: hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) @@ -342,10 +355,13 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): if self.multistream_overlap_gate: torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream) - if pertoken_scale is not None: - return (hidden_states, pertoken_scale), router_logits, None, None - - return hidden_states, router_logits, None, None + return MoEPrepareOutput( + hidden_states=hidden_states, + router_logits=router_logits, + mc2_mask=None, + padded_hidden_states_shape=None, + pertoken_scale=pertoken_scale, + ) def _prepare_with_dp_group( self, @@ -354,7 +370,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): enable_shared_expert_dp: bool = False, replace_allreduce: bool = False, quant_type=QuantType.NONE, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + ) -> MoEPrepareOutput: """ Preparation steps: 1. Fetch max token count across DP group from forward context. @@ -362,7 +378,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): 3. All-gather across DP group to form global input tensor. Returns: - Tuple of (global_hidden_states, global_router_logits, None, None) + MoEPrepareOutput with global tensors. 
""" self.enable_shared_expert_dp = enable_shared_expert_dp if self.moe_config.dp_size > 1: @@ -396,10 +412,19 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize): dim=0, ) - return hidden_states, router_logits, None, None + return MoEPrepareOutput( + hidden_states=hidden_states, + router_logits=router_logits, + mc2_mask=None, + padded_hidden_states_shape=None, + pertoken_scale=None, + ) def finalize( - self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None + self, + hidden_states: torch.Tensor, + reduce_results: bool, + padded_hidden_states_shape: torch.Size | None = None, ) -> torch.Tensor: """ Finalization steps: diff --git a/vllm_ascend/ops/fused_moe/token_dispatcher.py b/vllm_ascend/ops/fused_moe/token_dispatcher.py index bf4a5972..152f6e89 100644 --- a/vllm_ascend/ops/fused_moe/token_dispatcher.py +++ b/vllm_ascend/ops/fused_moe/token_dispatcher.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from abc import ABC, abstractmethod -from dataclasses import dataclass, field +from typing import Generic import torch import torch_npu @@ -31,25 +31,18 @@ from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.device.device_op import DeviceOperator from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region +from vllm_ascend.ops.fused_moe.moe_runtime_args import ( + MoEAllGatherCombineMetadata, + MoEAllToAllCombineMetadata, + MoEMC2CombineMetadata, + MoETokenDispatchInput, + MoETokenDispatchOutput, + TMoECombineMetadata, +) from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled -@dataclass -class TokenDispatchResult: - hidden_states: torch.Tensor - group_list: torch.Tensor - group_list_type: int - dynamic_scale: torch.Tensor | None = field(default=None) - topk_scales: torch.Tensor | None = field(default=None) - context_metadata: dict = field(default_factory=dict) - - -@dataclass -class TokenCombineResult: - routed_out: torch.Tensor - - -class MoETokenDispatcher(ABC): +class MoETokenDispatcher(ABC, Generic[TMoECombineMetadata]): def __init__(self, **kwargs) -> None: """ Initialize the MoE Token Dispatcher. 
@@ -73,27 +66,21 @@ class MoETokenDispatcher(ABC): @abstractmethod def token_dispatch( self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor | None = None, - global_redundant_expert_num: int = 0, - mc2_mask: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False, - dynamic_eplb: bool = False, - pertoken_scale: torch.Tensor | None = None, - ) -> TokenDispatchResult: + token_dispatch_input: MoETokenDispatchInput, + ) -> MoETokenDispatchOutput[TMoECombineMetadata]: raise NotImplementedError("Dispatch function not implemented.") @abstractmethod def token_combine( - self, hidden_states: torch.Tensor, context_metadata: dict, bias: torch.Tensor | None = None - ) -> TokenCombineResult: + self, + hidden_states: torch.Tensor, + combine_metadata: TMoECombineMetadata, + bias: torch.Tensor | None = None, + ) -> torch.Tensor: raise NotImplementedError("Combine function not implemented.") -class TokenDispatcherWithMC2(MoETokenDispatcher): +class TokenDispatcherWithMC2(MoETokenDispatcher[MoEMC2CombineMetadata]): def __init__(self, **kwargs): super().__init__(**kwargs) device_group = get_mc2_group().device_group @@ -110,7 +97,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly # improve communication performance. self.need_expert_scale = is_hierarchical_communication_enabled() - self.with_quant = False # Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute # dispatch & combine operators with different input num_tokens per rank. @@ -131,25 +117,23 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): def get_dispatch_mc2_kwargs( self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor, - mc2_mask: torch.Tensor, - global_redundant_expert_num: int = 0, - **kwargs, + token_dispatch_input: MoETokenDispatchInput, ): - use_mxfp_quant = kwargs.get("use_mxfp_quant", False) - comm_quant_mode = kwargs.get("comm_quant_mode") + hidden_states = token_dispatch_input.hidden_states + topk_weights = token_dispatch_input.topk_weights + topk_ids = token_dispatch_input.topk_ids + expert_map = token_dispatch_input.routing.expert_map + global_redundant_expert_num = token_dispatch_input.routing.global_redundant_expert_num + comm_quant_mode = token_dispatch_input.quant.comm_quant_mode + + assert expert_map is not None, "expert_map is required for MC2 token dispatch." # NOTE: quant_mode differs by quant feature: # - Legacy int communication quantization uses quant_mode=2. # - A5 MXFP8 communication uses quant_mode=4. - # TODO(linfeng): The quantization-related parameters need to be consolidated into a single - # dataclass, and the FP8 MoE code path should be integrated into it going forward. 
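Editor's note: the `quant_mode` selection that follows collapses to a small decision table. A standalone paraphrase, assuming `a5_mxfp8_dispatch` stands for `self.a5_need_extra_args and token_dispatch_input.quant.is_mxfp`; the helper name and its arguments are illustrative, not part of the patch.

def resolve_quant_mode(
    comm_quant_mode: int | None,
    dispatch_with_quant: bool,
    a5_mxfp8_dispatch: bool,
) -> int:
    """Editor's paraphrase of the dispatch quant_mode choice below."""
    if comm_quant_mode is not None:
        return comm_quant_mode               # explicit communication-quant override wins
    if dispatch_with_quant:                  # W8A8 / W4A8 / MXFP8 dispatch
        return 4 if a5_mxfp8_dispatch else 2  # 4: A5 MXFP8 comm quant, 2: legacy int8
    return 0                                 # unquantized dispatch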
if comm_quant_mode is not None: quant_mode = comm_quant_mode - elif self.with_quant: - quant_mode = 4 if self.a5_need_extra_args and use_mxfp_quant else 2 + elif token_dispatch_input.quant.dispatch_with_quant: + quant_mode = 4 if self.a5_need_extra_args and token_dispatch_input.quant.is_mxfp else 2 else: quant_mode = 0 self.moe_expert_num = len(expert_map) + global_redundant_expert_num @@ -178,10 +162,13 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "tp_rank_id": 0, } ) - if self.a5_need_extra_args and use_mxfp_quant: - y_dtype = kwargs.get("y_dtype") - if self.with_quant: - y_dtype = torch.float8_e4m3fn if y_dtype is None else y_dtype + if self.a5_need_extra_args and token_dispatch_input.quant.is_mxfp: + y_dtype = torch.float8_e4m3fn + if ( + token_dispatch_input.quant.mxfp is not None + and token_dispatch_input.quant.mxfp.act_quant_type is not None + ): + y_dtype = token_dispatch_input.quant.mxfp.act_quant_type stage1_kwargs.update({"tp_world_size": 1, "tp_rank_id": 0, "y_dtype": y_dtype}) if self.need_expert_scale or self.a5_need_extra_args: stage1_kwargs.update( @@ -195,22 +182,9 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): def token_dispatch( self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor | None = None, - global_redundant_expert_num: int = 0, - mc2_mask: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False, - dynamic_eplb: bool = False, - pertoken_scale: torch.Tensor | None = None, - **kwargs, + token_dispatch_input: MoETokenDispatchInput, ): - self.with_quant = with_quant - kwargs_mc2 = self.get_dispatch_mc2_kwargs( - hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num, **kwargs - ) + kwargs_mc2 = self.get_dispatch_mc2_kwargs(token_dispatch_input) output = ( torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2) if self.enable_dispatch_v2 @@ -227,33 +201,32 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): expand_scales, ) = output[0:7] - context_metadata = { - "topk_ids": topk_ids, - "topk_weights": topk_weights, - "expert_map": expert_map, - "ep_recv_counts": ep_recv_counts, - "tp_recv_counts": tp_recv_counts, - "assist_info_for_combine": assist_info_for_combine, - "expand_scales": expand_scales, - } - group_list_type = 0 - return TokenDispatchResult( + return MoETokenDispatchOutput( hidden_states=expand_x, dynamic_scale=dynamic_scale, group_list=expert_token_nums, group_list_type=group_list_type, - context_metadata=context_metadata, + combine_metadata=MoEMC2CombineMetadata( + topk_ids=token_dispatch_input.topk_ids, + topk_weights=token_dispatch_input.topk_weights, + expert_map=token_dispatch_input.routing.expert_map, + ep_recv_counts=ep_recv_counts, + tp_recv_counts=tp_recv_counts, + assist_info_for_combine=assist_info_for_combine, + expand_scales=expand_scales, + dispatch_with_quant=token_dispatch_input.quant.dispatch_with_quant, + ), ) - def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, context_metadata: dict): - expert_map = context_metadata["expert_map"] - topk_ids = context_metadata["topk_ids"] - topk_weights = context_metadata["topk_weights"] - ep_recv_counts = context_metadata["ep_recv_counts"] - tp_recv_counts = context_metadata["tp_recv_counts"] - assist_info_for_combine = context_metadata["assist_info_for_combine"] - expand_scales = context_metadata["expand_scales"] + def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, combine_metadata: MoEMC2CombineMetadata): + expert_map = 
combine_metadata.expert_map + topk_ids = combine_metadata.topk_ids + topk_weights = combine_metadata.topk_weights + ep_recv_counts = combine_metadata.ep_recv_counts + tp_recv_counts = combine_metadata.tp_recv_counts + assist_info_for_combine = combine_metadata.assist_info_for_combine + expand_scales = combine_metadata.expand_scales assert expert_map is not None @@ -267,7 +240,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "global_bs": self.global_bs, } - if self.with_quant: + if combine_metadata.dispatch_with_quant: tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device) stage3_kwargs = { @@ -296,52 +269,44 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): kwargs_mc2.update(stage3_kwargs) return kwargs_mc2 - def token_combine(self, hidden_states, context_metadata, bias=None): + def token_combine(self, hidden_states, combine_metadata, bias=None): assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." - kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, context_metadata) + kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, combine_metadata) combined_output = ( torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2) ) - return TokenCombineResult( - routed_out=combined_output, - ) + return combined_output -class TokenDispatcherWithAllGather(MoETokenDispatcher): +class TokenDispatcherWithAllGather(MoETokenDispatcher[MoEAllGatherCombineMetadata]): def __init__(self, **kwargs): super().__init__(**kwargs) - self.apply_router_weight_on_input = False self.max_num_tokens = kwargs.get("max_num_tokens") num_experts_local = kwargs.get("num_local_experts", 0) self.num_experts_local = ( num_experts_local.item() if torch.is_tensor(num_experts_local) else int(num_experts_local) ) - self.original_shape = None - self.with_quant = False def token_dispatch( self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor | None = None, - global_redundant_expert_num: int = 0, - mc2_mask: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False, - dynamic_eplb: bool = False, - pertoken_scale: torch.Tensor | None = None, + token_dispatch_input: MoETokenDispatchInput, ): - self.with_quant = with_quant - self.original_shape = hidden_states.shape + with_quant = token_dispatch_input.quant.is_int_quant + hidden_states = token_dispatch_input.hidden_states + topk_weights = token_dispatch_input.topk_weights + topk_ids = token_dispatch_input.topk_ids + expert_map = token_dispatch_input.routing.expert_map + pertoken_scale = token_dispatch_input.routing.pertoken_scale + global_redundant_expert_num = token_dispatch_input.routing.global_redundant_expert_num + restore_shape = hidden_states.shape num_tokens = hidden_states.shape[:-1].numel() - self.apply_router_weight_on_input = apply_router_weight_on_input - if self.apply_router_weight_on_input: + apply_router_weight_on_input = token_dispatch_input.routing.apply_router_weight_on_input + if apply_router_weight_on_input: assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)" _, topk = topk_weights.shape assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True" @@ -365,35 +330,37 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): expert_tokens_num_type=1, expert_tokens_num_flag=True, active_expert_range=[first_expert_idx, last_expert_idx], - quant_mode=1 if self.with_quant 
and pertoken_scale is None else -1, + quant_mode=1 if with_quant and pertoken_scale is None else -1, ) expert_tokens = expert_tokens.to(torch.int64) group_list_type = 1 # `count` mode - context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx} - return TokenDispatchResult( + return MoETokenDispatchOutput( hidden_states=sorted_hidden_states, - dynamic_scale=pertoken_scale if self.with_quant else None, + dynamic_scale=pertoken_scale if with_quant else None, group_list=expert_tokens, group_list_type=group_list_type, - context_metadata=context_metadata, + combine_metadata=MoEAllGatherCombineMetadata( + topk_weights=topk_weights, + expanded_row_idx=expanded_row_idx, + restore_shape=restore_shape, + ), ) - def token_combine(self, hidden_states, context_metadata, bias=None): - assert self.original_shape is not None + def token_combine(self, hidden_states, combine_metadata, bias=None): final_hidden_states = torch_npu.npu_moe_token_unpermute( permuted_tokens=hidden_states, - sorted_indices=torch.abs(context_metadata["expanded_row_idx"]), - probs=context_metadata["topk_weights"], + sorted_indices=torch.abs(combine_metadata.expanded_row_idx), + probs=combine_metadata.topk_weights, ) - if len(self.original_shape) == 3: - final_hidden_states = final_hidden_states.view(self.original_shape) + if len(combine_metadata.restore_shape) == 3: + final_hidden_states = final_hidden_states.view(combine_metadata.restore_shape) # these values are no longer used, so they need to be set to None for memory release. - return TokenCombineResult(routed_out=final_hidden_states) + return final_hidden_states -class TokenDispatcherWithAll2AllV(MoETokenDispatcher): +class TokenDispatcherWithAll2AllV(MoETokenDispatcher[MoEAllToAllCombineMetadata]): """ The implementation of the AlltoAll-based token dispatcher, which handles token dispatching on the sequence level instead of token level. 
The core of this implementation @@ -402,12 +369,8 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): def __init__(self, **kwargs): super().__init__(**kwargs) - self.with_quant = False self.num_local_experts = kwargs.get("num_local_experts", 0) - self.hidden_shape = None - self.hidden_shape_before_permute = None - assert self.num_local_experts > 0, "Expected at least one expert" if self.num_local_experts > 1: self.expert_ids_per_ep_rank = torch.tensor( @@ -432,19 +395,12 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): def token_dispatch( self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor | None = None, - global_redundant_expert_num: int = 0, - mc2_mask: torch.Tensor | None = None, - apply_router_weight_on_input: bool = False, - with_quant: bool = False, - dynamic_eplb: bool = False, - pertoken_scale: torch.Tensor | None = None, + token_dispatch_input: MoETokenDispatchInput, ): - self.with_quant = with_quant - self.hidden_shape = hidden_states.shape + with_quant = token_dispatch_input.quant.is_int_quant + hidden_states = token_dispatch_input.hidden_states + topk_weights = token_dispatch_input.topk_weights + topk_ids = token_dispatch_input.topk_ids ( permutated_local_input_tokens, @@ -452,12 +408,13 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): tokens_per_expert, input_splits, output_splits, - num_global_tokens_per_local_expert, global_input_tokens_local_experts_indices, + hidden_shape, + hidden_shape_before_permute, ) = self._dispatch_preprocess(hidden_states, topk_ids) dynamic_scale_after_all2all = None - if self.with_quant: + if with_quant: permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(permutated_local_input_tokens) _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all( dynamic_scale, output_splits, input_splits, self.ep_group @@ -474,64 +431,66 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): # Postprocess global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = ( self._dispatch_postprocess( - global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices + global_input_tokens, + dynamic_scale_after_all2all, + global_input_tokens_local_experts_indices, + with_quant, ) ) - context_metadata = { - "input_splits": input_splits, - "output_splits": output_splits, - "topk_weights": topk_weights, - "reversed_local_input_permutation_mapping": reversed_local_input_permutation_mapping, - "reversed_global_input_permutation_mapping": reversed_global_input_permutation_mapping, - } - - return TokenDispatchResult( + return MoETokenDispatchOutput( hidden_states=global_input_tokens, dynamic_scale=dynamic_scale_final, group_list=tokens_per_expert, group_list_type=1, - context_metadata=context_metadata, + combine_metadata=MoEAllToAllCombineMetadata( + input_splits=input_splits, + output_splits=output_splits, + topk_weights=topk_weights, + reversed_local_input_permutation_mapping=reversed_local_input_permutation_mapping, + reversed_global_input_permutation_mapping=reversed_global_input_permutation_mapping, + hidden_shape=hidden_shape, + hidden_shape_before_permute=hidden_shape_before_permute, + ), ) - def token_combine(self, hidden_states, context_metadata, bias=None): + def token_combine(self, hidden_states, combine_metadata, bias=None): assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." # 1. 
Preprocess using metadata - hidden_states = self._combine_preprocess(hidden_states, context_metadata) + hidden_states = self._combine_preprocess(hidden_states, combine_metadata) # 2. AllToAll _, permutated_local_input_tokens, handle = async_all_to_all( hidden_states, - context_metadata["input_splits"], - context_metadata["output_splits"], + combine_metadata.input_splits, + combine_metadata.output_splits, self.ep_group, ) handle.wait() hidden_states.untyped_storage().resize_(0) # 3. Postprocess using metadata - output = self._combine_postprocess(permutated_local_input_tokens, context_metadata) + output = self._combine_postprocess(permutated_local_input_tokens, combine_metadata) - return TokenCombineResult(routed_out=output) + return output def _dispatch_preprocess(self, hidden_states, topk_ids): - assert self.hidden_shape is not None + hidden_shape = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_states.size(-1)) ( tokens_per_expert, input_splits, output_splits, - num_global_tokens_per_local_expert, global_input_tokens_local_experts_indices, + num_out_tokens, ) = self._preprocess(topk_ids) - - self.hidden_shape_before_permute = hidden_states.shape + hidden_shape_before_permute = hidden_states.shape permutated_local_input_tokens, reversed_local_input_permutation_mapping = torch_npu.npu_moe_token_permute( tokens=hidden_states, indices=topk_ids, - num_out_tokens=self.num_out_tokens, + num_out_tokens=num_out_tokens, ) return ( @@ -540,15 +499,16 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): tokens_per_expert, input_splits, output_splits, - num_global_tokens_per_local_expert, global_input_tokens_local_experts_indices, + hidden_shape, + hidden_shape_before_permute, ) def _preprocess(self, topk_ids: torch.Tensor): num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts) ep_size = self.ep_size - self.num_out_tokens = topk_ids.numel() + num_out_tokens = topk_ids.numel() input_splits = ( num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts) @@ -585,19 +545,19 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): num_tokens_per_local_expert, input_splits, output_splits, - num_global_tokens_per_local_expert, global_input_tokens_local_experts_indices, + num_out_tokens, ) def _dispatch_postprocess( - self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices + self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices, with_quant ): # Early return if no local experts or no tokens if self.num_local_experts <= 1: return global_input_tokens, dynamic_scale_after_all2all, None # Handle quantized case - if self.with_quant: + if with_quant: assert global_input_tokens_local_experts_indices is not None, ( "global_input_tokens_local_experts_indices must be provided" ) @@ -612,20 +572,26 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): ) return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping - def _combine_preprocess(self, hidden_states: torch.Tensor, context_metadata: dict) -> torch.Tensor: + def _combine_preprocess( + self, hidden_states: torch.Tensor, combine_metadata: MoEAllToAllCombineMetadata + ) -> torch.Tensor: # Unpermutation 2: expert output to AlltoAll input - if hidden_states.shape[0] > 0 and self.num_local_experts > 1: - rev_global = context_metadata["reversed_global_input_permutation_mapping"] + rev_global = combine_metadata.reversed_global_input_permutation_mapping + if 
hidden_states.shape[0] > 0 and self.num_local_experts > 1 and rev_global is not None: hidden_states = torch_npu.npu_moe_token_unpermute(hidden_states, rev_global) return hidden_states - def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor, context_metadata: dict) -> torch.Tensor: + def _combine_postprocess( + self, + permutated_local_input_tokens: torch.Tensor, + combine_metadata: MoEAllToAllCombineMetadata, + ) -> torch.Tensor: # Unpermutation 1: AlltoAll output to output output = torch_npu.npu_moe_token_unpermute( permuted_tokens=permutated_local_input_tokens, - sorted_indices=context_metadata["reversed_local_input_permutation_mapping"].to(torch.int32), - probs=context_metadata["topk_weights"], - restore_shape=self.hidden_shape_before_permute, + sorted_indices=combine_metadata.reversed_local_input_permutation_mapping.to(torch.int32), + probs=combine_metadata.topk_weights, + restore_shape=combine_metadata.hidden_shape_before_permute, ) - output = output.view(self.hidden_shape) + output = output.view(combine_metadata.hidden_shape) return output diff --git a/vllm_ascend/quantization/__init__.py b/vllm_ascend/quantization/__init__.py index 1bf29125..575c3526 100644 --- a/vllm_ascend/quantization/__init__.py +++ b/vllm_ascend/quantization/__init__.py @@ -16,24 +16,30 @@ # """Ascend quantization module. -This module provides quantization support for Ascend NPU. - -Supported quantization tools: -- ModelSlim: Use AscendModelSlimConfig -- LLM-Compressor (compressed_tensors): Use AscendCompressedTensorsConfig - -Public API: -- Config classes: AscendModelSlimConfig, AscendCompressedTensorsConfig -- For scheme implementations, import from vllm_ascend.quantization.methods +This module intentionally avoids eager imports so that importing lightweight +submodules (for example ``quant_type``) does not trigger heavy registration +paths and circular imports during startup. 
""" -# LLM-Compressor (compressed_tensors) quantization config -from .compressed_tensors_config import AscendCompressedTensorsConfig +from typing import TYPE_CHECKING, Any -# ModelSlim quantization config -from .modelslim_config import AscendModelSlimConfig +if TYPE_CHECKING: + from .compressed_tensors_config import AscendCompressedTensorsConfig + from .modelslim_config import AscendModelSlimConfig __all__ = [ "AscendModelSlimConfig", "AscendCompressedTensorsConfig", ] + + +def __getattr__(name: str) -> Any: + if name == "AscendModelSlimConfig": + from .modelslim_config import AscendModelSlimConfig + + return AscendModelSlimConfig + if name == "AscendCompressedTensorsConfig": + from .compressed_tensors_config import AscendCompressedTensorsConfig + + return AscendCompressedTensorsConfig + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm_ascend/quantization/method_adapters.py b/vllm_ascend/quantization/method_adapters.py index 34764a1b..3c215860 100644 --- a/vllm_ascend/quantization/method_adapters.py +++ b/vllm_ascend/quantization/method_adapters.py @@ -255,28 +255,34 @@ class AscendFusedMoEMethod(FusedMoEMethodBase): enable_force_load_balance: bool = False, log2phy: torch.Tensor | None = None, global_redundant_expert_num=0, - **kwargs, + pertoken_scale: torch.Tensor | None = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + mc2_mask: torch.Tensor | None = None, ) -> torch.Tensor: return self.quant_method.apply( - layer, - x, - router_logits, - top_k, - renormalize, - use_grouped_topk, - global_num_experts, - expert_map, - topk_group, - num_expert_group, - custom_routing_function, - scoring_func, - routed_scaling_factor, - e_score_correction_bias, - is_prefill, - enable_force_load_balance, - log2phy, - global_redundant_expert_num, - **kwargs, + layer=layer, + x=x, + router_logits=router_logits, + top_k=top_k, + renormalize=renormalize, + use_grouped_topk=use_grouped_topk, + global_num_experts=global_num_experts, + expert_map=expert_map, + topk_group=topk_group, + num_expert_group=num_expert_group, + custom_routing_function=custom_routing_function, + scoring_func=scoring_func, + routed_scaling_factor=routed_scaling_factor, + e_score_correction_bias=e_score_correction_bias, + is_prefill=is_prefill, + enable_force_load_balance=enable_force_load_balance, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + pertoken_scale=pertoken_scale, + activation=activation, + apply_router_weight_on_input=apply_router_weight_on_input, + mc2_mask=mc2_mask, ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: diff --git a/vllm_ascend/quantization/methods/base.py b/vllm_ascend/quantization/methods/base.py index 9307eb92..4a7629b9 100644 --- a/vllm_ascend/quantization/methods/base.py +++ b/vllm_ascend/quantization/methods/base.py @@ -18,19 +18,11 @@ from abc import ABC, abstractmethod from collections.abc import Callable -from enum import Enum from typing import Any import torch - -class QuantType(Enum): - """Quantization type enum for MoE schemes.""" - - NONE = 0 - W8A8 = 1 - W4A8 = 2 - MXFP8 = 3 +from vllm_ascend.quantization.quant_type import QuantType class AscendLinearScheme(ABC): @@ -245,7 +237,10 @@ class AscendMoEScheme(ABC): enable_force_load_balance: bool = False, log2phy: torch.Tensor | None = None, global_redundant_expert_num: int = 0, - **kwargs, + pertoken_scale: Any | None = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + mc2_mask: torch.Tensor 
| None = None, ) -> torch.Tensor: """Forward computation for MoE layer. @@ -268,7 +263,10 @@ class AscendMoEScheme(ABC): enable_force_load_balance: Whether to force load balancing. log2phy: Logical to physical expert mapping. global_redundant_expert_num: Number of redundant experts. - **kwargs: Additional keyword arguments. + pertoken_scale: Optional per-token activation scale from prepare stage. + activation: Expert MLP activation type. + apply_router_weight_on_input: Whether to pre-scale hidden states by router weights. + mc2_mask: Optional mask used by MC2 dispatch. Returns: Output tensor after MoE computation. diff --git a/vllm_ascend/quantization/methods/w4a16.py b/vllm_ascend/quantization/methods/w4a16.py index bb3bc3da..3a0e3f5e 100644 --- a/vllm_ascend/quantization/methods/w4a16.py +++ b/vllm_ascend/quantization/methods/w4a16.py @@ -25,8 +25,9 @@ from vllm.config import get_current_vllm_config from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.ops.fused_moe.experts_selector import select_experts +from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input -from .base import AscendMoEScheme +from .base import AscendMoEScheme, QuantType from .registry import register_scheme @@ -103,6 +104,8 @@ def pack_to_int32(weight: torch.Tensor) -> torch.Tensor: class AscendW4A16FusedMoEMethod(AscendMoEScheme): """FusedMoE method for Ascend W4A16.""" + quant_type: QuantType = QuantType.W4A16 + def __init__(self) -> None: self.transpose_weight = True self.num_bits = 4 # dtype = torch.int4 @@ -192,7 +195,10 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme): enable_force_load_balance: bool = True, log2phy: torch.Tensor | None = None, global_redundant_expert_num: int = 0, - **kwargs, + pertoken_scale: Any | None = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + mc2_mask: torch.Tensor | None = None, ) -> torch.Tensor: assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, ( "Number of global experts mismatch (excluding redundancy)" @@ -217,20 +223,26 @@ class AscendW4A16FusedMoEMethod(AscendMoEScheme): moe_comm_method = _EXTRA_CTX.moe_comm_method return moe_comm_method.fused_experts( - hidden_states=x, - w1=layer.w13_weight_packed, - w2=layer.w2_weight_packed, - w1_scale=layer.w13_weight_scale, - w2_scale=layer.w2_weight_scale, - w1_offset=layer.w13_weight_offset, - w2_offset=layer.w2_weight_offset, - topk_weights=topk_weights, - topk_ids=topk_ids, - use_int4_w4a16=True, - expert_map=expert_map, - log2phy=log2phy, - dynamic_eplb=self.dynamic_eplb, - mc2_mask=kwargs.get("mc2_mask"), + fused_experts_input=build_fused_experts_input( + hidden_states=x, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1=layer.w13_weight_packed, + w2=layer.w2_weight_packed, + quant_type=self.quant_type, + dynamic_eplb=self.dynamic_eplb, + expert_map=expert_map, + global_redundant_expert_num=global_redundant_expert_num, + mc2_mask=mc2_mask, + apply_router_weight_on_input=apply_router_weight_on_input, + log2phy=log2phy, + pertoken_scale=pertoken_scale, + activation=activation, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + w1_offset=layer.w13_weight_offset, + w2_offset=layer.w2_weight_offset, + ) ) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: diff --git a/vllm_ascend/quantization/methods/w4a8.py b/vllm_ascend/quantization/methods/w4a8.py index 0ebeafc5..2c51944c 100644 --- 
a/vllm_ascend/quantization/methods/w4a8.py +++ b/vllm_ascend/quantization/methods/w4a8.py @@ -28,6 +28,7 @@ from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_forward_context import _EXTRA_CTX from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.fused_moe.experts_selector import select_experts +from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input from vllm_ascend.utils import COMPRESSED_TENSORS_METHOD, maybe_trans_nz from .base import AscendLinearScheme, AscendMoEScheme, QuantType @@ -343,7 +344,10 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): enable_force_load_balance: bool = False, log2phy: torch.Tensor | None = None, global_redundant_expert_num: int = 0, - **kwargs, + pertoken_scale: torch.Tensor | None = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + mc2_mask: torch.Tensor | None = None, ) -> torch.Tensor: assert router_logits.shape[1] == global_num_experts - global_redundant_expert_num, ( "Number of global experts mismatch (excluding redundancy)" @@ -377,20 +381,26 @@ class AscendW4A8DynamicFusedMoEMethod(AscendMoEScheme): moe_comm_method = _EXTRA_CTX.moe_comm_method return moe_comm_method.fused_experts( - hidden_states=x, - w1=[layer.w13_weight], - w2=[layer.w2_weight], - w1_scale=[layer.w13_weight_scale], - w2_scale=[layer.w2_weight_scale], - w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None, - w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None, - topk_weights=topk_weights, - topk_ids=topk_ids, - use_int4_w4a8=True, - expert_map=expert_map, - log2phy=log2phy, - dynamic_eplb=self.dynamic_eplb, - mc2_mask=kwargs.get("mc2_mask"), + fused_experts_input=build_fused_experts_input( + hidden_states=x, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1=[layer.w13_weight], + w2=[layer.w2_weight], + quant_type=self.quant_type, + dynamic_eplb=self.dynamic_eplb, + expert_map=expert_map, + global_redundant_expert_num=global_redundant_expert_num, + mc2_mask=mc2_mask, + apply_router_weight_on_input=apply_router_weight_on_input, + log2phy=log2phy, + pertoken_scale=pertoken_scale, + activation=activation, + w1_scale=[layer.w13_weight_scale], + w2_scale=[layer.w2_weight_scale], + w1_scale_bias=layer.w13_scale_bias if hasattr(layer, "w13_scale_bias") else None, + w2_scale_bias=layer.w2_scale_bias if hasattr(layer, "w2_scale_bias") else None, + ) ) def process_scale(self, weight: torch.Tensor, scale, per_group_scale): diff --git a/vllm_ascend/quantization/methods/w8a8_dynamic.py b/vllm_ascend/quantization/methods/w8a8_dynamic.py index 66596629..1b17ad30 100644 --- a/vllm_ascend/quantization/methods/w8a8_dynamic.py +++ b/vllm_ascend/quantization/methods/w8a8_dynamic.py @@ -29,6 +29,7 @@ from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.flash_common3_context import get_flash_common3_context from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute +from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, maybe_trans_nz from .base import AscendLinearScheme, AscendMoEScheme, QuantType @@ -182,7 +183,9 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme): log2phy: torch.Tensor | None = None, global_redundant_expert_num: int = 0, pertoken_scale: Any | None = None, - **kwargs, + activation: str = 
"silu", + apply_router_weight_on_input: bool = False, + mc2_mask: torch.Tensor | None = None, ) -> torch.Tensor: zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_type = getattr(layer, "zero_expert_type", None) @@ -249,19 +252,24 @@ class AscendW8A8DynamicFusedMoEMethod(AscendMoEScheme): w2_scale = [layer.fused_w2_scale] if fused_scale_flag else [layer.w2_weight_scale] final_hidden_states = moe_comm_method.fused_experts( - hidden_states=x, - pertoken_scale=pertoken_scale, - w1=w1, - w1_scale=w1_scale, - w2=w2, - w2_scale=w2_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - use_int8_w8a8=True, - expert_map=expert_map, - log2phy=log2phy, - dynamic_eplb=self.dynamic_eplb, - mc2_mask=kwargs.get("mc2_mask"), + fused_experts_input=build_fused_experts_input( + hidden_states=x, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1=w1, + w2=w2, + quant_type=self.quant_type, + dynamic_eplb=self.dynamic_eplb, + expert_map=expert_map, + global_redundant_expert_num=global_redundant_expert_num, + mc2_mask=mc2_mask, + apply_router_weight_on_input=apply_router_weight_on_input, + log2phy=log2phy, + pertoken_scale=pertoken_scale, + activation=activation, + w1_scale=[layer.fused_w1_scale] if fused_scale_flag else w1_scale, + w2_scale=[layer.fused_w2_scale] if fused_scale_flag else w2_scale, + ) ) if zero_expert_num > 0 and zero_expert_type is not None: final_hidden_states += zero_expert_result diff --git a/vllm_ascend/quantization/methods/w8a8_mxfp8.py b/vllm_ascend/quantization/methods/w8a8_mxfp8.py index 574c4d75..79fa9480 100644 --- a/vllm_ascend/quantization/methods/w8a8_mxfp8.py +++ b/vllm_ascend/quantization/methods/w8a8_mxfp8.py @@ -31,6 +31,7 @@ from vllm_ascend.device.mxfp_compat import ( ensure_mxfp8_moe_available, ) from vllm_ascend.ops.fused_moe.experts_selector import select_experts +from vllm_ascend.ops.fused_moe.moe_runtime_args import build_fused_experts_input from .base import AscendLinearScheme, AscendMoEScheme, QuantType from .registry import register_scheme @@ -170,7 +171,10 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme): enable_force_load_balance: bool = True, log2phy: torch.Tensor = None, global_redundant_expert_num: int = 0, - **kwargs, + pertoken_scale: Any | None = None, + activation: str = "silu", + apply_router_weight_on_input: bool = False, + mc2_mask: torch.Tensor | None = None, ) -> torch.Tensor: expected = global_num_experts - global_redundant_expert_num assert router_logits.shape[1] == expected, "Number of global experts mismatch (excluding redundancy)" @@ -198,23 +202,29 @@ class AscendW8A8MXFP8DynamicFusedMoEMethod(AscendMoEScheme): moe_comm_method = _EXTRA_CTX.moe_comm_method return moe_comm_method.fused_experts( - hidden_states=x, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - use_int8_w8a8=False, - expert_map=expert_map, - log2phy=log2phy, - dynamic_eplb=self.dynamic_eplb, - mc2_mask=kwargs.get("mc2_mask"), - use_mxfp_quant=True, - act_quant_type=torch.float8_e4m3fn, - weight_quant_type=torch.float8_e4m3fn, - scale_type=FLOAT8_E8M0FNU_DTYPE, - per_token_scale_type=FLOAT8_E8M0FNU_DTYPE, + fused_experts_input=build_fused_experts_input( + hidden_states=x, + topk_weights=topk_weights, + topk_ids=topk_ids, + w1=layer.w13_weight, + w2=layer.w2_weight, + quant_type=self.quant_type, + dynamic_eplb=self.dynamic_eplb, + expert_map=expert_map, + global_redundant_expert_num=global_redundant_expert_num, + mc2_mask=mc2_mask, + 
apply_router_weight_on_input=apply_router_weight_on_input, + log2phy=log2phy, + pertoken_scale=pertoken_scale, + activation=activation, + mxfp_act_quant_type=torch.float8_e4m3fn, + mxfp_weight_quant_type=torch.float8_e4m3fn, + mxfp_scale_dtype=FLOAT8_E8M0FNU_DTYPE, + mxfp_per_token_scale_dtype=FLOAT8_E8M0FNU_DTYPE, + mxfp_use_bf16=(x.dtype == torch.bfloat16), + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale, + ) ) def process_weights_after_loading(self, layer): diff --git a/vllm_ascend/quantization/quant_type.py b/vllm_ascend/quantization/quant_type.py new file mode 100644 index 00000000..10327103 --- /dev/null +++ b/vllm_ascend/quantization/quant_type.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# This file is a part of the vllm-ascend project. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Shared quantization enum definitions. + +Keep this module lightweight and side-effect free so core runtime modules can +import QuantType without triggering heavy quantization package initialization. +""" + +from enum import Enum + + +class QuantType(Enum): + """Quantization type enum for MoE schemes.""" + + NONE = 0 + W8A8 = 1 + W4A8 = 2 + MXFP8 = 3 + W4A16 = 4
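Editor's note: taken together, the shared `QuantType` enum and the `MoEQuantParams` properties give the dispatchers a single switchboard for quantization behavior; `W4A16` is added so the W4A16 MoE scheme can tag itself with a `quant_type` like the other schemes. A small sketch that prints the property table; the loop and output layout are illustrative, and the expected values follow directly from the property definitions in `moe_stage_params.py`.

from vllm_ascend.ops.fused_moe.moe_runtime_args import MoEQuantParams
from vllm_ascend.quantization.quant_type import QuantType

for quant_type in QuantType:
    params = MoEQuantParams(quant_type=quant_type)
    print(
        f"{quant_type.name:6s} "
        f"is_quant={params.is_quant!s:5s} "
        f"is_int_quant={params.is_int_quant!s:5s} "
        f"is_mxfp={params.is_mxfp!s:5s} "
        f"dispatch_with_quant={params.dispatch_with_quant}"
    )

# Expected, per the property definitions above:
# NONE   is_quant=False is_int_quant=False is_mxfp=False dispatch_with_quant=False
# W8A8   is_quant=True  is_int_quant=True  is_mxfp=False dispatch_with_quant=True
# W4A8   is_quant=True  is_int_quant=True  is_mxfp=False dispatch_with_quant=True
# MXFP8  is_quant=True  is_int_quant=False is_mxfp=True  dispatch_with_quant=True
# W4A16  is_quant=True  is_int_quant=False is_mxfp=False dispatch_with_quant=False

Note that W4A16 reports `dispatch_with_quant=False`: activations stay in the original dtype through dispatch, so only the W8A8/W4A8/MXFP8 paths switch the dispatcher into its quantized communication mode.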