diff --git a/tests/e2e/singlecard/ops/test_fused_moe.py b/tests/e2e/singlecard/ops/test_fused_moe.py index ab673a4..d6320a5 100644 --- a/tests/e2e/singlecard/ops/test_fused_moe.py +++ b/tests/e2e/singlecard/ops/test_fused_moe.py @@ -24,10 +24,12 @@ from unittest.mock import MagicMock, patch import pytest import torch +import torch_npu from vllm.model_executor.layers.activation import SiluAndMul -from vllm_ascend.ops.fused_moe import fused_experts from vllm_ascend.ops.layers.experts_selector import select_experts +from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ + TokenDispatcherWithAllGather NUM_EXPERTS = [8, 64] EP_SIZE = [1, 4] @@ -35,6 +37,38 @@ TOP_KS = [2, 6] DEVICE = ["npu"] +def apply_mlp( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1, +) -> torch.Tensor: + w1 = w1.transpose(1, 2) + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + )[0] + + hidden_states = torch_npu.npu_swiglu(hidden_states) + + w2 = w2.transpose(1, 2) + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + )[0] + + return hidden_states + + def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map): B, D = a.shape a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) @@ -60,7 +94,7 @@ def torch_moe(a, w1, w2, topk_weights, topk_ids, topk, expert_map): @pytest.mark.parametrize("ep_size", EP_SIZE) @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("device", DEVICE) -def test_fused_experts( +def test_token_dispatcher_with_all_gather( m: int, n: int, k: int, @@ -75,19 +109,23 @@ def test_fused_experts( w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 10 score = torch.randn((m, e), device=device, 
dtype=dtype) + expert_map = None + local_e = e + w1_local = w1 + w2_local = w2 if ep_size > 1: local_e = e // ep_size - e_ids = torch.randint(0, - e, (local_e, ), - device=device, - dtype=torch.int32) - e_map = torch.full((e, ), -1, device=device, dtype=torch.int32) - e_map[e_ids] = torch.arange(local_e, device=device, dtype=torch.int32) - w1 = w1[e_ids] - w2 = w2[e_ids] - else: - e_map = None + e_ids = torch.arange(local_e * 0, + local_e * (0 + 1), + device=device, + dtype=torch.int32) + expert_map = torch.full((e, ), -1, device=device, dtype=torch.int32) + expert_map[e_ids] = torch.arange(local_e, + device=device, + dtype=torch.int32) + w1_local = w1[e_ids] + w2_local = w2[e_ids] score = torch.softmax(score, dim=-1, dtype=dtype) topk_weights, topk_ids = torch.topk(score, topk) @@ -99,11 +137,42 @@ def test_fused_experts( dtype=torch.int32, ).view(topk, -1).permute(1, 0).contiguous()) - output = fused_experts(a, w1, w2, topk_weights, topk_ids, row_idx, topk, - e_map) - torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, e_map) - # TODO: The native params are: atol=2e-2, rtol=0, maybe related to the nan problem - torch.testing.assert_close(output, torch_output, atol=4e-2, rtol=1) + dispatcher_kwargs = { + "num_experts": e, + "top_k": topk, + "num_local_experts": local_e, + } + dispatcher = TokenDispatcherWithAllGather(**dispatcher_kwargs) + + apply_router_weight_on_input = False + dispatch_output = dispatcher.token_dispatch( + hidden_states=a, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + expert_map=expert_map, + apply_router_weight_on_input=apply_router_weight_on_input) + + sorted_hidden_states = dispatch_output["hidden_states"] + group_list = dispatch_output["group_list"] + group_list_type = dispatch_output.get("group_list_type", 1) + + expert_output = apply_mlp(hidden_states=sorted_hidden_states, + w1=w1_local, + w2=w2_local, + group_list=group_list, + group_list_type=group_list_type) + + combined_output = 
dispatcher.token_combine(hidden_states=expert_output, + bias=None) + + torch_output = torch_moe(a, w1, w2, topk_weights, topk_ids, topk, + expert_map) + + torch.testing.assert_close(combined_output, + torch_output, + atol=4e-2, + rtol=1) torch.npu.empty_cache() diff --git a/tests/ut/models/test_deepseek_v2.py b/tests/ut/models/test_deepseek_v2.py index 3902b5b..e0c50e8 100644 --- a/tests/ut/models/test_deepseek_v2.py +++ b/tests/ut/models/test_deepseek_v2.py @@ -22,7 +22,6 @@ from vllm.config import CacheConfig from vllm.distributed.parallel_state import GroupCoordinator from vllm_ascend.models.deepseek_v2 import ( - CustomDeepseekV2DecoderLayer, CustomDeepseekV2ForCausalLM, CustomDeepseekV2MergedReplicatedLinear, CustomDeepseekV2MLAAttention, CustomDeepseekV2MLP, CustomDeepseekV2MoE, CustomDeepseekV2RowParallelLinear, @@ -115,7 +114,8 @@ def mock_distributed(): patch("vllm_ascend.ops.fused_moe.get_current_vllm_config", return_value=mock_vllm_config), \ patch.dict("vllm.distributed.parallel_state.__dict__", _TP=tp_group, _EP=ep_group, _DP=dp_group, _PP=pp_group), \ - patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group): + patch.dict("vllm_ascend.distributed.parallel_state.__dict__", _MC2=ep_group), \ + patch("torch.npu.current_device", return_value=0): yield @@ -266,54 +266,3 @@ def test_custom_deepseek_v2_mla_attention(mock_rms_norm, mock_distributed, kv_lora_rank=16, prefix="layers.1.self_attn") assert hasattr(attn, "q_proj") - - -@patch("torch_npu.npu_add_rms_norm") -@patch("torch_npu.npu_rms_norm") -def test_custom_deepseek_v2_decoder_layer(mock_rms_norm, mock_add_norm, - mock_distributed, base_config, - vllm_config): - mock_rms_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128)) - mock_add_norm.return_value = (torch.randn(2, 128), torch.randn(2, 128), - torch.randn(2, 128)) - base_config.n_routed_experts = 4 - layer = CustomDeepseekV2DecoderLayer(config=base_config, - prefix="layers.0", - 
model_config=vllm_config.model_config, - cache_config=CacheConfig(), - quant_config=None) - assert isinstance(layer.mlp, CustomDeepseekV2MoE) - - x = torch.randn(2, 4, 128) - positions = torch.arange(4).repeat(2, 1) - - with patch.object(layer.self_attn, "forward", Mock(return_value=torch.randn(2, 4, 128))), \ - patch.object(layer.mlp, "forward", Mock(return_value=torch.randn(2, 4, 128))): - hidden_states, residual = layer(positions, x, None) - assert hidden_states.shape == (2, 4, 128) - - base_config.n_routed_experts = None - layer = CustomDeepseekV2DecoderLayer(config=base_config, - prefix="layers.0", - model_config=vllm_config.model_config, - quant_config=None) - assert isinstance(layer.mlp, CustomDeepseekV2MLP) - - -def test_custom_deepseek_v2_for_causal_lm(mock_distributed, vllm_config): - model = CustomDeepseekV2ForCausalLM(vllm_config=vllm_config) - - input_ids = torch.randint(0, 10000, (2, 4)) - positions = torch.arange(4).repeat(2, 1) - with patch.object(model.model, - "forward", - return_value=torch.randn(2, 4, 128)): - output = model(input_ids, positions) - assert output.shape == (2, 4, 128) - - weights = [("model.embed_tokens.weight", torch.randn(10000, 128))] - with patch( - "vllm.model_executor.model_loader.weight_utils.default_weight_loader" - ): - loaded = model.load_weights(weights) - assert loaded is not None diff --git a/tests/ut/ops/test_fused_ops.py b/tests/ut/ops/test_fused_ops.py index 49a9ed2..04a1659 100644 --- a/tests/ut/ops/test_fused_ops.py +++ b/tests/ut/ops/test_fused_ops.py @@ -22,11 +22,15 @@ import torch_npu from pytest_mock import MockerFixture from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase -from vllm_ascend.ascend_forward_context import _get_fused_moe_state +import vllm_ascend.ops.moe_dispatcher.token_dispatcher as token_dispatcher_module +from tests.ut.base import TestBase +from vllm_ascend.ascend_forward_context import (FusedMoEState, + _get_fused_moe_state) from vllm_ascend.ops.fused_moe import 
(AscendFusedMoE, - AscendUnquantizedFusedMoEMethod) + AscendUnquantizedFusedMoEMethod, + unified_apply_mlp) from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402 +from vllm_ascend.utils import AscendSocVersion, adapt_patch adapt_patch(True) @@ -56,7 +60,73 @@ def mock_npu_format_cast(weight_data, format): @pytest.fixture def mock_dist_env(mocker: MockerFixture): - # init dist env patch + mock_setup_token_dispatchers = MagicMock() + mock_token_dispatcher_with_allgather = MagicMock() + mock_token_dispatcher_with_all2allv = MagicMock() + mock_token_dispatcher_with_mc2 = MagicMock() + + mock_dispatch_result_allgather = { + "hidden_states": torch.randn(16, 2), + "group_list": torch.tensor([8, 16], dtype=torch.int64), + "group_list_type": 0, + } + mock_combine_result_allgather = torch.randn(16, 2) + + mock_token_dispatcher_with_allgather.token_dispatch.return_value = mock_dispatch_result_allgather + mock_token_dispatcher_with_allgather.token_combine.return_value = mock_combine_result_allgather + + mock_dispatch_result_all2allv = { + "hidden_states": torch.randn(16, 2), + "group_list": torch.tensor([4, 8, 12, 16], dtype=torch.int64), + "group_list_type": 1, + "dynamic_scale": None, + } + mock_combine_result_all2allv = torch.randn(16, 2) + mock_token_dispatcher_with_all2allv.token_dispatch.return_value = mock_dispatch_result_all2allv + mock_token_dispatcher_with_all2allv.token_combine.return_value = mock_combine_result_all2allv + + mock_dispatch_result_mc2 = { + "hidden_states": torch.randn(16, 2), + "group_list": torch.tensor([5, 10, 15, 16], dtype=torch.int64), + "group_list_type": 1, + "dynamic_scale": None, + "assist_info_for_combine": torch.randn(16, 2), + "ep_recv_counts": torch.tensor([4, 4, 4, 4], dtype=torch.int32), + } + mock_combine_result_mc2 = torch.randn(16, 2) + mock_token_dispatcher_with_mc2.token_dispatch.return_value = mock_dispatch_result_mc2 + 
mock_token_dispatcher_with_mc2.token_combine.return_value = mock_combine_result_mc2 + + captured_dispatchers = {} + + def capture_register(dispatcher_instance): + key = dispatcher_instance.__class__.__name__ + captured_dispatchers[key] = dispatcher_instance + if key == 'TokenDispatcherWithAllGather': + captured_dispatchers[key] = mock_token_dispatcher_with_allgather + elif key == 'TokenDispatcherWithAll2AllV': + captured_dispatchers[key] = mock_token_dispatcher_with_all2allv + elif key == 'TokenDispatcherWithMC2': + captured_dispatchers[key] = mock_token_dispatcher_with_mc2 + + mock_register_token_dispatcher_patcher = patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher', + side_effect=capture_register) + + mock_get_token_dispatcher_patcher = patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_token_dispatcher', + side_effect=lambda name: captured_dispatchers.get(name)) + + default_mock_token_dispatcher = mock_token_dispatcher_with_allgather + + # Start the patchers here so the matching .stop() calls after the fixture's + # yield do not raise RuntimeError("stop called on unstarted patcher"). + mock_register_token_dispatcher_patcher.start() + mock_get_token_dispatcher_patcher.start() + + mock_forward_context_obj = MagicMock( + fused_moe_state=FusedMoEState.AllGather, + token_dispatcher=default_mock_token_dispatcher, + max_tokens_across_dp=10, + dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]), + mc2_mask=torch.zeros(16, dtype=torch.bool), + padded_num_tokens=16, + with_quant=False) with patch('torch.distributed.get_rank', return_value=0), \ patch('torch.distributed.get_world_size', return_value=4), \ @@ -66,12 +136,10 @@ def mock_dist_env(mocker: MockerFixture): patch('vllm.distributed.parallel_state.get_tp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm.model_executor.layers.fused_moe.layer.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ - patch('torch.distributed.all_gather', return_value=MagicMock(return_value=torch.randn(10,32))), \ - patch('torch.distributed.all_to_all_single', return_value=torch.randn(8, 32)), \ - 
patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce', - return_value=torch.randn(5, 32)), \ - patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter', - return_value=torch.randn(5, 32)), \ + patch('torch.distributed.all_gather'), \ + patch('torch.distributed.all_to_all_single'), \ + patch('vllm_ascend.ops.fused_moe.tensor_model_parallel_all_reduce'), \ + patch('vllm_ascend.ops.fused_moe.data_parallel_reduce_scatter'), \ patch('vllm.model_executor.layers.fused_moe.config.get_dp_group', return_value=mock_dp_and_tp_group(mocker)), \ patch('vllm_ascend.ops.fused_moe.get_ascend_config', @@ -82,22 +150,31 @@ def mock_dist_env(mocker: MockerFixture): patch('vllm_ascend.ops.fused_moe.determine_expert_map', return_value=(3, torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]))), \ patch('vllm_ascend.ops.fused_moe.get_forward_context', - return_value=MagicMock( - max_tokens_across_dp=10, - dp_metadata=MagicMock(cu_tokens_across_dp_cpu=[5, 10]) - )), \ + return_value=mock_forward_context_obj), \ patch('vllm_ascend.ops.fused_moe.get_current_vllm_config', return_value=MagicMock( parallel_config=MagicMock(tensor_parallel_size=2), scheduler_config=MagicMock(max_num_seqs=4), model_config=MagicMock(max_model_len=2048) - )): - yield + )), \ + patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3), \ + patch.object(token_dispatcher_module, 'setup_token_dispatchers', mock_setup_token_dispatchers): + + yield { + 'mock_forward_context_obj': mock_forward_context_obj, + 'mock_token_dispatcher_with_allgather': + mock_token_dispatcher_with_allgather, + 'mock_token_dispatcher_with_all2allv': + mock_token_dispatcher_with_all2allv, + 'mock_token_dispatcher_with_mc2': mock_token_dispatcher_with_mc2, + } + + mock_register_token_dispatcher_patcher.stop() + mock_get_token_dispatcher_patcher.stop() @pytest.fixture def mock_moe_env(mocker: MockerFixture): - # init moe env patch with patch('torch_npu.npu_moe_gating_top_k', return_value=( torch.randn(8, 2), @@ 
-144,7 +221,6 @@ def mock_moe_env(mocker: MockerFixture): @pytest.fixture def default_moe_config(): - """default moe config""" return { 'num_experts': 8, 'top_k': 2, @@ -188,7 +264,6 @@ class MockQuantMethod(nn.Module): class MockFusedMoEMethod(FusedMoEMethodBase): - # TODO(bnell): also pass quant_config? moe = MagicMock() def __init__(self): @@ -223,13 +298,11 @@ class TestAscendFusedMoe: assert hasattr(layer, 'w13_weight') assert hasattr(layer, 'w2_weight') - # check group_topk with pytest.raises(AssertionError): error_config = default_moe_config.copy() error_config['use_grouped_topk'] = True layer = AscendFusedMoE(**error_config) - # check scoring_func with pytest.raises(ValueError): error_config = default_moe_config.copy() error_config['scoring_func'] = "random" @@ -254,14 +327,7 @@ class TestAscendFusedMoe: [None, None, False, 1, None], [None, None, True, 5, 1], [None, None, False, 5, 1]]) def test_forward(self, mock_dist_env, default_moe_config, others_param): - """ - 1 test has shared_experts - 2 test has top_k - 3 test is_prefill is true - 4 test single num_tokens(decode) - 5 test ep_size is 1 and is_prefill is true - 6 test ep_size is 1 and is_prefill is False - """ + top_k, shared_experts, is_prefill, num_tokens, ep_size = others_param inputs = torch.randn(num_tokens, 32) router_logits = torch.randn(num_tokens, 8) @@ -327,25 +393,42 @@ class TestAscendUnquantizedFusedMoEMethod: [[256, 4], [128, 1], [128, 1], [128, 4]]) def test_apply_without_expert_map(self, moe_method, mock_dist_env, mock_moe_env, others_param): - """ - 1 test is_deepseek_v3_r1=true and use fused_experts_with_all2all - 2 test use_select_experts and fused_experts - 3 test use select_gating_topk_softmax_experts and fused_experts - 4 test use select_experts and fused_experts_with_all2all_buffer - """ + global_num_experts, ep_size = others_param is_prefill = False is_deepseek_v3_r1 = global_num_experts == 256 + + if ep_size == 1: + selected_token_dispatcher = mock_dist_env[ + 
'mock_token_dispatcher_with_allgather'] + elif ep_size < 16: + selected_token_dispatcher = mock_dist_env[ + 'mock_token_dispatcher_with_all2allv'] + else: + selected_token_dispatcher = mock_dist_env[ + 'mock_token_dispatcher_with_mc2'] + forward_context = MagicMock(fused_moe_state=_get_fused_moe_state( - ep_size, is_prefill, is_deepseek_v3_r1)) + ep_size, is_prefill, is_deepseek_v3_r1), + with_quant=False, + token_dispatcher=selected_token_dispatcher) + with patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context): moe_method.ep_size = ep_size x = torch.randn(8, 2, 2) router_logits = torch.randn(8, 8) layer = MagicMock() - layer.w13_weight = torch.randn(8, 16, 1) - layer.w2_weight = torch.randn(16, 8, 1) + local_num_experts = 2 + hidden_size = 2 + intermediate_size_per_partition = 4 + + layer.w13_weight = torch.randn(local_num_experts, + intermediate_size_per_partition * 2, + hidden_size) + layer.w2_weight = torch.randn(local_num_experts, hidden_size, + intermediate_size_per_partition) + result = moe_method.apply(layer=layer, x=x, router_logits=router_logits, @@ -354,29 +437,38 @@ class TestAscendUnquantizedFusedMoEMethod: global_num_experts=global_num_experts, is_prefill=is_prefill) - if ep_size == 1: - assert result.shape == (16, 2) - else: - assert result.shape == x.shape + expected_shape = (16, 2) + + assert result.shape == expected_shape @pytest.mark.parametrize("others_param", [[16, False], [1, True], [1, False], [4, False]]) def test_apply_with_expert_map(self, moe_method, mock_dist_env, mock_moe_env, others_param): - """ - 1 test use_select_experts and use fused_expters_with_mc2 - 2 test use_select_experts and fused_experts_with_all2all_buffer - 3 test use_select_experts and fused_experts_with_all2all - 4 test use_select_experts and fused_experts - """ + ep_size, alltoall_buffer = others_param is_prefill = False - forward_context = MagicMock( - fused_moe_state=_get_fused_moe_state(ep_size, is_prefill, True)) + + if ep_size == 1: 
+ selected_token_dispatcher = mock_dist_env[ + 'mock_token_dispatcher_with_allgather'] + elif ep_size < 16: + selected_token_dispatcher = mock_dist_env[ + 'mock_token_dispatcher_with_all2allv'] + else: + selected_token_dispatcher = mock_dist_env[ + 'mock_token_dispatcher_with_mc2'] + + forward_context = MagicMock(fused_moe_state=_get_fused_moe_state( + ep_size, is_prefill, True), + with_quant=False, + token_dispatcher=selected_token_dispatcher) + with patch("vllm_ascend.ops.fused_moe.MOE_ALL2ALL_BUFFER", alltoall_buffer), \ patch("vllm_ascend.ops.fused_moe.get_forward_context", return_value=forward_context), \ - patch("vllm_ascend.ops.fused_moe.get_ascend_soc_version", return_value=AscendSocVersion.A3): + patch("vllm_ascend.utils.get_ascend_soc_version", return_value=AscendSocVersion.A3): + expert_map = torch.tensor([0, 1, 2, -1, -1, -1, -1, -1]) moe_method.ep_size = ep_size x = torch.randn(8, 2, 2) @@ -386,8 +478,16 @@ class TestAscendUnquantizedFusedMoEMethod: if alltoall_buffer: moe_method.max_model_len = 1 layer = MagicMock() - layer.w13_weight = torch.randn(8, 16, 1) - layer.w2_weight = torch.randn(16, 8, 1) + + local_num_experts = 2 + hidden_size = 2 + intermediate_size_per_partition = 4 + layer.w13_weight = torch.randn(local_num_experts, + intermediate_size_per_partition * 2, + hidden_size) + layer.w2_weight = torch.randn(local_num_experts, hidden_size, + intermediate_size_per_partition) + result = moe_method.apply(layer=layer, x=x, router_logits=router_logits, @@ -397,10 +497,9 @@ class TestAscendUnquantizedFusedMoEMethod: expert_map=expert_map, is_prefill=is_prefill) - if ep_size == 16 or ep_size == 1: - assert result.shape == (16, 2) - else: - assert result.shape == x.shape + expected_shape = (16, 2) + + assert result.shape == expected_shape class TestExpertsSelector: @@ -426,3 +525,239 @@ class TestExpertsSelector: assert topk_weights.shape == (8, 2) assert topk_ids.shape == (8, 2) + + +class TestUnifiedApplyMLP(TestBase): + + 
@patch('vllm_ascend.ops.fused_moe.get_forward_context') + @patch('vllm_ascend.ops.fused_moe.get_mc2_group') + @patch('vllm_ascend.ops.fused_moe.is_310p') + @patch('torch_npu.npu_grouped_matmul') + @patch('torch_npu.npu_dynamic_quant') + @patch('torch_npu.npu_dequant_swiglu_quant') + def test_unified_apply_mlp_with_quantization_mc2(self, mock_npu_dequant, + mock_npu_dynamic_quant, + mock_npu_grouped_matmul, + mock_is_310p, + mock_get_mc2_group, + mock_get_forward_context): + + mock_forward_context = MagicMock() + mock_forward_context.with_quant = True + mock_forward_context.fused_moe_state = FusedMoEState.MC2 + mock_get_forward_context.return_value = mock_forward_context + + mock_mc2_group = MagicMock() + mock_get_mc2_group.return_value = mock_mc2_group + + mock_is_310p.return_value = False + + mock_npu_dynamic_quant.return_value = (torch.randint(-128, + 127, (10, 20), + dtype=torch.int8), + torch.rand(10, + 1, + dtype=torch.float32)) + + mock_npu_grouped_matmul.side_effect = [[ + torch.randint(-2147483648, 2147483647, (10, 40), dtype=torch.int32) + ], [torch.randn(10, 20, dtype=torch.bfloat16)]] + + mock_npu_dequant.return_value = (torch.randn(10, + 40, + dtype=torch.bfloat16), + torch.randn(10, + 1, + dtype=torch.float32)) + + hidden_states = torch.randn(10, 20, dtype=torch.bfloat16) + w1 = torch.randint(-128, 127, (5, 20, 40), dtype=torch.int8) + w1_scale = torch.randn(5, 40, dtype=torch.float32) + w2 = torch.randint(-128, 127, (5, 40, 20), dtype=torch.int8) + w2_scale = torch.randn(5, 20, dtype=torch.bfloat16) + group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) + + result = unified_apply_mlp(hidden_states=hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=group_list, + dynamic_scale=None, + group_list_type=1, + w1_scale_bias=None, + w2_scale_bias=None, + topk_scales=None) + + mock_get_forward_context.assert_called() + self.assertTrue(mock_forward_context.with_quant) + 
self.assertEqual(mock_forward_context.fused_moe_state, + FusedMoEState.MC2) + + mock_npu_dynamic_quant.assert_called() + + self.assertEqual(mock_npu_grouped_matmul.call_count, 2) + + mock_npu_dequant.assert_called_once() + + self.assertEqual(result.dtype, torch.bfloat16) + + @patch('vllm_ascend.ops.fused_moe.get_forward_context') + @patch('vllm_ascend.ops.fused_moe.is_310p') + @patch('torch_npu.npu_grouped_matmul') + @patch('torch_npu.npu_swiglu') + @patch('torch_npu.npu_dynamic_quant') + def test_unified_apply_mlp_without_quantization( + self, mock_npu_dynamic_quant, mock_npu_swiglu, + mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context): + + mock_forward_context = MagicMock() + mock_forward_context.with_quant = False + mock_get_forward_context.return_value = mock_forward_context + + mock_is_310p.return_value = False + + mock_npu_grouped_matmul.side_effect = [[ + torch.randn(10, 40, dtype=torch.float16) + ], [torch.randn(10, 20, dtype=torch.float16)]] + mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16) + mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock()) + + hidden_states = torch.randn(10, 20, dtype=torch.float16) + w1 = torch.randn(5, 20, 40, dtype=torch.float16) + w2 = torch.randn(5, 40, 20, dtype=torch.float16) + group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) + topk_scales = torch.randn(10, 1, dtype=torch.float16) + + result = unified_apply_mlp(hidden_states=hidden_states, + w1=w1, + w1_scale=None, + w2=w2, + w2_scale=None, + group_list=group_list, + dynamic_scale=None, + group_list_type=1, + w1_scale_bias=None, + w2_scale_bias=None, + topk_scales=topk_scales) + + mock_get_forward_context.assert_called() + self.assertFalse(mock_forward_context.with_quant) + + self.assertEqual(mock_npu_grouped_matmul.call_count, 2) + mock_npu_swiglu.assert_called_once() + + self.assertEqual(result.shape, hidden_states.shape) + self.assertEqual(result.dtype, torch.float16) + + 
@patch('vllm_ascend.ops.fused_moe.get_forward_context') + @patch('torch_npu.npu_grouped_matmul') + @patch('torch_npu.npu_swiglu') + @patch('torch_npu.npu_dynamic_quant') + def test_unified_apply_mlp_with_quantization_and_dynamic_scale( + self, mock_npu_dynamic_quant, mock_npu_swiglu, + mock_npu_grouped_matmul, mock_get_forward_context): + + mock_forward_context = MagicMock() + mock_forward_context.with_quant = True + mock_forward_context.fused_moe_state = "NOT_MC2" + mock_get_forward_context.return_value = mock_forward_context + + mock_npu_grouped_matmul.side_effect = [[ + torch.randn(10, 40, dtype=torch.bfloat16) + ], [torch.randn(10, 20, dtype=torch.bfloat16)]] + + mock_npu_swiglu.return_value = torch.randn(10, + 40, + dtype=torch.bfloat16) + + mock_npu_dynamic_quant.return_value = (torch.randint(-128, + 127, (10, 40), + dtype=torch.int8), + torch.rand(10, + 1, + dtype=torch.float32)) + + hidden_states = torch.randn(10, 20, dtype=torch.bfloat16) + w1 = torch.randn(5, 20, 40, dtype=torch.bfloat16) + w1_scale = torch.randn(5, 40, dtype=torch.bfloat16) + w2 = torch.randn(5, 40, 20, dtype=torch.bfloat16) + w2_scale = torch.randn(5, 20, dtype=torch.bfloat16) + w1_scale_bias = torch.randn(5, 40, dtype=torch.bfloat16) + w2_scale_bias = torch.randn(5, 20, dtype=torch.bfloat16) + group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) + provided_dynamic_scale = torch.rand(10, 1, dtype=torch.float32) + + result = unified_apply_mlp(hidden_states=hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=group_list, + dynamic_scale=provided_dynamic_scale, + group_list_type=1, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + topk_scales=None) + + mock_get_forward_context.assert_called() + self.assertTrue(mock_forward_context.with_quant) + + self.assertEqual(mock_npu_grouped_matmul.call_count, 2) + mock_npu_swiglu.assert_called_once() + mock_npu_dynamic_quant.assert_called_once() + + self.assertEqual(result.shape, 
hidden_states.shape) + self.assertEqual(result.dtype, torch.bfloat16) + + @patch('vllm_ascend.ops.fused_moe.get_forward_context') + @patch('vllm_ascend.ops.fused_moe.is_310p') + @patch('torch_npu.npu_grouped_matmul') + @patch('torch_npu.npu_swiglu') + @patch('torch_npu.npu_dynamic_quant') + def test_unified_apply_mlp_without_quantization_310p( + self, mock_npu_dynamic_quant, mock_npu_swiglu, + mock_npu_grouped_matmul, mock_is_310p, mock_get_forward_context): + + mock_forward_context = MagicMock() + mock_forward_context.with_quant = False + mock_get_forward_context.return_value = mock_forward_context + + mock_is_310p.return_value = True + + mock_gmm1_out = torch.randn(10, 40, dtype=torch.float16) + mock_gmm2_out = torch.randn(10, 20, dtype=torch.float16) + mock_npu_grouped_matmul.side_effect = [[mock_gmm1_out], + [mock_gmm2_out]] + + mock_npu_swiglu.return_value = torch.randn(10, 40, dtype=torch.float16) + + mock_npu_dynamic_quant.return_value = (MagicMock(), MagicMock()) + + hidden_states = torch.randn(10, 20, dtype=torch.float16) + w1 = torch.randn(5, 20, 40, dtype=torch.float16) + w2 = torch.randn(5, 40, 20, dtype=torch.float16) + group_list = torch.tensor([2, 4, 6, 8, 10], dtype=torch.int64) + topk_scales = torch.randn(10, 1, dtype=torch.float16) + + result = unified_apply_mlp(hidden_states=hidden_states, + w1=w1, + w1_scale=None, + w2=w2, + w2_scale=None, + group_list=group_list, + dynamic_scale=None, + group_list_type=1, + w1_scale_bias=None, + w2_scale_bias=None, + topk_scales=topk_scales) + + mock_get_forward_context.assert_called() + self.assertFalse(mock_forward_context.with_quant) + mock_is_310p.assert_called_once() + + self.assertEqual(mock_npu_grouped_matmul.call_count, 2) + mock_npu_swiglu.assert_called_once() + + self.assertEqual(result.shape, hidden_states.shape) + self.assertEqual(result.dtype, torch.float16) diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index 724c1be..77f40fa 100644 --- 
a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -25,8 +25,8 @@ from tests.ut.base import PytestBase, TestBase from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( AscendSocVersion, MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig, TokenDispatcherWithAll2AllV, TokenDispatcherWithAllGather, - TokenDispatcherWithMC2) -from vllm_ascend.utils import adapt_patch # noqa E402 + TokenDispatcherWithMC2, _Dispatchers, _register_token_dispatcher, + get_token_dispatcher, setup_token_dispatchers) class TestMoEAlltoAllSeqOverLapDispatcher(PytestBase): @@ -90,7 +90,7 @@ class TestTokenDispatcherWithMC2(TestBase): self.forward_context = MagicMock() self.forward_context.mc2_mask = torch.tensor([1, 0, 1]) self.forward_context_patch = patch( - "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_forward_context", + "vllm.forward_context.get_forward_context", return_value=self.forward_context) self.forward_context_patch.start() @@ -100,28 +100,18 @@ class TestTokenDispatcherWithMC2(TestBase): return_value=AscendSocVersion.A3) self.ascend_soc_version_patch.start() - # Mock get_ascend_config() - self.ascend_config = MagicMock() - self.ascend_config.torchair_graph_config.enabled = False - self.ascend_config_patch = patch( - "vllm_ascend.ops.moe_dispatcher.token_dispatcher.get_ascend_config", - return_value=self.ascend_config) - self.ascend_config_patch.start() - kwargs = {"with_quant": False, "top_k": 8, "num_experts": 128} self.dispatcher = TokenDispatcherWithMC2(**kwargs) + self.row_idx = torch.arange(10, dtype=torch.int32) def tearDown(self): self.mc2_group_patch.stop() self.forward_context_patch.stop() self.ascend_soc_version_patch.stop() - self.ascend_config_patch.stop() def test_init(self): - # self.assertEqual(self.dispatcher.moe_all_to_all_group_name, "hccl_123") self.assertEqual(self.dispatcher.ep_rank_id, 0) self.assertEqual(self.dispatcher.ep_world_size, 8) - self.assertFalse(self.dispatcher.torchair_graph_enabled) 
self.assertFalse(self.dispatcher.with_quant) self.assertTrue(self.dispatcher.enable_dispatch_v2) self.assertTrue(self.dispatcher.need_extra_args) @@ -149,9 +139,10 @@ class TestTokenDispatcherWithMC2(TestBase): return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch: output = self.dispatcher.token_dispatch(hidden_states, topk_weights, topk_ids, - expert_map) + self.row_idx, expert_map) mock_dispatch.assert_called_once() - self.assertEqual(output[0], 1) # group_list_type == 1 + self.assertEqual(output["group_list_type"], + 1) # group_list_type == 1 def test_token_dispatch_with_shared_experts_and_quant(self): self.shared_experts = MagicMock() @@ -166,20 +157,13 @@ class TestTokenDispatcherWithMC2(TestBase): with patch("torch_npu.npu_moe_distribute_dispatch_v2", return_value=(torch.randn(10, 128), ) * 5): - with patch( - "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch", - autospec=True): - with patch( - "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor", - autospec=True) as mock_wait: - self.dispatcher.token_dispatch( - self.hidden_states, - self.topk_weights, - torch.randint(0, 8, (10, 1)), - torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]), - shared_experts=self.shared_experts) - mock_wait.assert_any_call(self.hidden_states, - self.topk_weights) + self.dispatcher.token_dispatch(self.hidden_states, + self.topk_weights, + torch.randint(0, 8, (10, 1)), + self.row_idx, + torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7]), + shared_experts=self.shared_experts) def test_get_combine_mc_kwargs_with_quant(self): self.dispatcher.with_quant = True @@ -213,13 +197,7 @@ class TestTokenDispatcherWithMC2(TestBase): with patch("torch_npu.npu_moe_distribute_combine_v2", return_value=torch.randn(10, 128)): - with patch( - "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_stream_switch", - autospec=True): - with patch( - "vllm_ascend.ops.moe_dispatcher.token_dispatcher.npu_wait_tensor", - autospec=True): - self.dispatcher.token_combine(self.hidden_states) + 
self.dispatcher.token_combine(self.hidden_states) class TestTokenDispatcherWithAllGather(TestBase): @@ -257,6 +235,7 @@ class TestTokenDispatcherWithAllGather(TestBase): self.mock_moe_finalize_routing = self.patcher_moe_finalize_routing.start( ) self.mock_moe_finalize_routing.return_value = torch.randn(3, 128) + self.row_idx = torch.arange(10, dtype=torch.int32) def tearDown(self): self.patcher_moe_init_routing.stop() @@ -268,14 +247,14 @@ class TestTokenDispatcherWithAllGather(TestBase): topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) - group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch( - hidden_states, topk_weights, topk_ids, None) + results = self.dispatcher.token_dispatch(hidden_states, topk_weights, + topk_ids, self.row_idx, None) # Verify npu_moe_init_routing is called self.mock_moe_init_routing.assert_called_once() args, kwargs = self.mock_moe_init_routing.call_args - self.assertEqual(group_list_type, 0) + self.assertEqual(results["group_list_type"], 0) def test_token_dispatch_with_quant(self): kwargs = { @@ -292,11 +271,11 @@ class TestTokenDispatcherWithAllGather(TestBase): topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) topk_ids = torch.tensor([[0, 1], [1, 2], [2, 3]]) - group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher_quant.token_dispatch( - hidden_states, topk_weights, topk_ids, None) + results = self.dispatcher_quant.token_dispatch(hidden_states, + topk_weights, topk_ids, + self.row_idx, None) - # Verify quant mode returns group_list_type=1 - self.assertEqual(group_list_type, 0) + self.assertEqual(results["group_list_type"], 0) def test_token_combine_with_expert_map(self): self.dispatcher.expert_map = torch.tensor([0, 1, 2, 3]) @@ -337,19 +316,9 @@ class TestTokenDispatcherWithAllGather(TestBase): topk_weights = torch.tensor([[0.7], [0.6], [0.5]]) # topk=1 topk_ids = torch.tensor([[0], [1], [2]]) - 
group_list_type, sorted_hidden_states, expert_tokens = self.dispatcher.token_dispatch( - hidden_states, topk_weights, topk_ids, None) - self.assertEqual(sorted_hidden_states.shape, (6, 128)) - - def test_token_dispatch_invalid_topk_when_router_weight(self): - self.dispatcher.apply_router_weight_on_input = True - hidden_states = torch.randn(3, 128) - topk_weights = torch.tensor([[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]]) - - with self.assertRaises(AssertionError): - self.dispatcher.token_dispatch( - hidden_states, topk_weights, - torch.tensor([[0, 1], [1, 2], [2, 3]]), None) + results = self.dispatcher.token_dispatch(hidden_states, topk_weights, + topk_ids, None) + self.assertEqual(results["hidden_states"].shape, (6, 128)) class TestTokenDispatcherWithAll2AllV(TestBase): @@ -443,6 +412,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase): num_experts=4, num_local_experts=2, with_quant=False) + self.row_idx = torch.arange(10, dtype=torch.int32) def test_token_dispatch(self): hidden_states = torch.randn(8, 16) @@ -457,6 +427,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase): result = self.dispatcher.token_dispatch(hidden_states=hidden_states, topk_weights=topk_weights, topk_ids=topk_ids, + row_idx=self.row_idx, expert_map=expert_map) self.assertIsNotNone(result["hidden_states"]) @@ -504,6 +475,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase): result = self.dispatcher.token_dispatch(hidden_states=hidden_states, topk_weights=topk_weights, topk_ids=topk_ids, + row_idx=self.row_idx, expert_map=expert_map) self.assertIsNotNone(result["hidden_states"]) @@ -532,6 +504,7 @@ class TestTokenDispatcherWithAll2AllV(TestBase): result = self.dispatcher.token_dispatch(hidden_states=hidden_states, topk_weights=topk_weights, topk_ids=topk_ids, + row_idx=self.row_idx, expert_map=expert_map) self.assertIsNotNone(result["hidden_states"]) @@ -553,9 +526,126 @@ class TestTokenDispatcherWithAll2AllV(TestBase): result = self.dispatcher.token_dispatch(hidden_states=hidden_states, 
topk_weights=topk_weights, topk_ids=topk_ids, + row_idx=self.row_idx, expert_map=expert_map, log2phy=log2phy) self.assertIsNotNone(result["hidden_states"]) self.assertIsNotNone(result["group_list"]) self.assertEqual(result["group_list_type"], 1) + + +class TestDispatcherRegistry(TestBase): + + def setUp(self): + _Dispatchers.clear() + + def tearDown(self): + _Dispatchers.clear() + + def test_register_and_get_token_dispatcher(self): + mock_dispatcher = MagicMock() + mock_dispatcher.__class__.__name__ = "MockDispatcher" + + _register_token_dispatcher(mock_dispatcher) + + self.assertIn("MockDispatcher", _Dispatchers) + self.assertIs(_Dispatchers["MockDispatcher"], mock_dispatcher) + + retrieved_dispatcher = get_token_dispatcher("MockDispatcher") + self.assertIs(retrieved_dispatcher, mock_dispatcher) + + self.assertIsNone(get_token_dispatcher("NonExistentDispatcher")) + + @patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAllGather' + ) + @patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher' + ) + def test_setup_token_dispatchers_ep_size_1_creates_allgather( + self, mock_register, mock_allgather_class): + kwargs = {"top_k": 2, "num_experts": 8} + mock_instance = MagicMock() + mock_allgather_class.return_value = mock_instance + + self.assertNotIn("TokenDispatcherWithAllGather", _Dispatchers) + + setup_token_dispatchers(ep_size=1, **kwargs) + + mock_allgather_class.assert_called_once_with(**kwargs) + mock_register.assert_called_once_with(mock_instance) + + @patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV' + ) + @patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher' + ) + def test_setup_token_dispatchers_ep_size_2_creates_all2allv( + self, mock_register, mock_all2allv_class): + kwargs = {"top_k": 2, "num_experts": 16, "num_local_experts": 2} + mock_instance = MagicMock() + mock_all2allv_class.return_value = mock_instance + + 
self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers) + + setup_token_dispatchers(ep_size=2, **kwargs) + + mock_all2allv_class.assert_called_once_with(**kwargs) + mock_register.assert_called_once_with(mock_instance) + + @patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV' + ) + @patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2' + ) + @patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher' + ) + def test_setup_token_dispatchers_ep_size_16_creates_all2allv_and_mc2( + self, mock_register, mock_mc2_class, mock_all2allv_class): + kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2} + mock_all2allv_instance = MagicMock() + mock_mc2_instance = MagicMock() + mock_all2allv_class.return_value = mock_all2allv_instance + mock_mc2_class.return_value = mock_mc2_instance + + self.assertNotIn("TokenDispatcherWithAll2AllV", _Dispatchers) + self.assertNotIn("TokenDispatcherWithMC2", _Dispatchers) + + setup_token_dispatchers(ep_size=16, **kwargs) + + mock_all2allv_class.assert_called_once_with(**kwargs) + mock_mc2_class.assert_called_once_with(**kwargs) + self.assertEqual(mock_register.call_count, 2) + mock_register.assert_any_call(mock_all2allv_instance) + mock_register.assert_any_call(mock_mc2_instance) + + @patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithAll2AllV' + ) + @patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher.TokenDispatcherWithMC2' + ) + @patch( + 'vllm_ascend.ops.moe_dispatcher.token_dispatcher._register_token_dispatcher' + ) + def test_setup_token_dispatchers_ep_size_16_skips_if_exist( + self, mock_register, mock_mc2_class, mock_all2allv_class): + kwargs = {"top_k": 2, "num_experts": 32, "num_local_experts": 2} + mock_existing_all2allv = MagicMock() + mock_existing_mc2 = MagicMock() + _Dispatchers["TokenDispatcherWithAll2AllV"] = mock_existing_all2allv + _Dispatchers["TokenDispatcherWithMC2"] = mock_existing_mc2 + 
+ setup_token_dispatchers(ep_size=16, **kwargs) + + mock_all2allv_class.assert_not_called() + mock_mc2_class.assert_not_called() + mock_register.assert_not_called() + self.assertIs(_Dispatchers["TokenDispatcherWithAll2AllV"], + mock_existing_all2allv) + self.assertIs(_Dispatchers["TokenDispatcherWithMC2"], + mock_existing_mc2) diff --git a/tests/ut/quantization/test_w8a8_dynamic.py b/tests/ut/quantization/test_w8a8_dynamic.py deleted file mode 100644 index 0e07eb1..0000000 --- a/tests/ut/quantization/test_w8a8_dynamic.py +++ /dev/null @@ -1,82 +0,0 @@ -from unittest.mock import MagicMock, patch - -import torch - -from tests.ut.base import TestBase -from vllm_ascend.quantization.w8a8_dynamic import fused_experts_with_all2all - - -class TestAscendW8A8FusedMoEMethod(TestBase): - - def setUp(self): - self.hidden_size = 128 - self.num_tokens = 128 - self.placeholder = torch.randn(self.num_tokens, - self.hidden_size, - dtype=torch.bfloat16) - - @patch("torch.distributed.all_to_all_single") - @patch("torch_npu.npu_moe_re_routing") - @patch("torch_npu.npu_grouped_matmul") - @patch("torch_npu.npu_swiglu") - @patch("torch_npu.npu_dynamic_quant") - @patch("torch_npu.npu_moe_finalize_routing") - @patch("torch_npu.npu_moe_init_routing") - def test_fused_experts_with_all2all(self, mock_moe_init_routing, - mock_moe_finalize_routing, - mock_dynamic_quant, mock_swiglu, - mock_grouped_matmul, - mock_moe_re_routing, - mock_all_to_all_single): - expert_map = MagicMock() - ep_group = MagicMock() - placeholder_int8 = torch.randint(0, - 100, - (self.num_tokens, self.hidden_size), - dtype=torch.int8) - placeholder_ones = torch.ones(self.num_tokens, dtype=torch.int32) - mock_all_to_all_single.side_effect = lambda output, input, *args, **kwargs: output.copy_( - input) - mock_moe_init_routing.return_value = ( - placeholder_int8, - placeholder_ones, - placeholder_ones, - ) - mock_moe_re_routing.return_value = (placeholder_int8, self.placeholder, - torch.randint(0, - 100, - (self.num_tokens, 
), - dtype=torch.int32), - self.placeholder) - mock_grouped_matmul.return_value = self.placeholder - mock_swiglu.return_value = self.placeholder - mock_dynamic_quant.return_value = ( - placeholder_int8, - torch.randn(self.num_tokens), - ) - mock_moe_finalize_routing.return_value = self.placeholder - row_idx_len = self.num_tokens * 8 - row_idx = (torch.arange( - 0, - row_idx_len, - dtype=torch.int32, - ).view(8, -1).permute(1, 0).contiguous()) - - result = fused_experts_with_all2all( - hidden_states=self.placeholder, - w1=self.placeholder, - w1_scale=self.placeholder, - w2=self.placeholder, - w2_scale=self.placeholder, - topk_weights=self.placeholder, - topk_ids=self.placeholder, - row_idx=row_idx, - top_k=8, - expert_map=expert_map, - ep_group=ep_group, - log2phy=None, - global_redundant_expert_num=256, - ) - self.assertIsNotNone(result) - self.assertEqual(result.dtype, torch.bfloat16) - self.assertEqual(result.shape, (128, 128)) diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py index def0d35..3e48cf7 100644 --- a/vllm_ascend/ascend_forward_context.py +++ b/vllm_ascend/ascend_forward_context.py @@ -46,6 +46,18 @@ def _get_fused_moe_state(ep_size: int, with_prefill: bool, return FusedMoEState.MC2 +def get_dispatcher_name(ep_size: int, with_prefill: bool) -> str: + if ep_size == 1: + return "TokenDispatcherWithAllGather" + + if ep_size < 16: + return "TokenDispatcherWithAll2AllV" + + if with_prefill: + return "TokenDispatcherWithAll2AllV" + return "TokenDispatcherWithMC2" + + @contextmanager def set_ascend_forward_context( attn_metadata: Any, @@ -87,6 +99,14 @@ def set_ascend_forward_context( forward_context.fused_moe_state = fused_moe_state forward_context.in_profile_run = in_profile_run + with_quant = vllm_config.quant_config is not None + forward_context.with_quant = with_quant + from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ + get_token_dispatcher + dispatcher_name = get_dispatcher_name(ep_size, 
with_prefill) + dispatcher = get_token_dispatcher(dispatcher_name) + forward_context.token_dispatcher = dispatcher + # NOTE: This cannot be set using set_forward_context # due to multiple warmups before actual capturing forward_context.capturing = False diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index b8509f0..24a1667 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -16,14 +16,14 @@ # Adapted from vllm/tests/kernels/test_moe.py import os -from typing import Any, Callable, Optional, Tuple, Union +from typing import Any, Callable, Optional import torch import torch.distributed as dist import torch_npu from torch import nn from vllm.config import get_current_vllm_config -from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, +from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.distributed.parallel_state import (get_dp_group, get_ep_group, @@ -49,9 +49,8 @@ from vllm_ascend.ops.layers.experts_selector import select_experts from vllm_ascend.ops.moe_dispatcher.token_dispatcher import ( MoEAlltoAllSeqOverLapDispatcher, MoEDispatcherConfig) from vllm_ascend.ops.sequence_parallel import MetadataForPadding -from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, - dispose_tensor, get_all_reduce_merge_state, - get_ascend_soc_version, +from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, dispose_tensor, + get_all_reduce_merge_state, get_rm_router_logits_state, is_310p) MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER @@ -122,149 +121,6 @@ def process_topk_ids(topk_ids: torch.Tensor, expert_num: int, ep_size: int, return topk_ids_pad, unpad_indices -def fused_experts_with_mc2( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - moe_parallel_config: FusedMoEParallelConfig, - expert_map: 
torch.Tensor = None, - moe_all_to_all_group_name: Optional[str] = None, - shared_experts: Optional[Any] = None, - mc2_mask: Optional[torch.Tensor] = None, -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - quant_mode = 0 - ep_rank_id = moe_parallel_config.ep_rank - ep_world_size = moe_parallel_config.ep_size - - # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine - need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3) - - # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine - a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 - - enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") - - moe_expert_num = len(expert_map) - kwargs_mc2 = { - "x": hidden_states, - "expert_ids": topk_ids, - "expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": moe_expert_num, - "global_bs": 0, - } - - stage1_kwargs = { - "scales": None, - "quant_mode": quant_mode, - "group_ep": moe_all_to_all_group_name, - "ep_world_size": ep_world_size, - "ep_rank_id": ep_rank_id, - } - if need_extra_args: - stage1_kwargs.update({ - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - if a3_need_extra_args and enable_dispatch_v2: - stage1_kwargs.update({ - "x_active_mask": mc2_mask, - }) - - kwargs_mc2.update(stage1_kwargs) - - output = torch_npu.npu_moe_distribute_dispatch_v2( - **kwargs_mc2 - ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( - **kwargs_mc2) - # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[ - 0:5] - - if shared_experts is not None: - shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) - shared_act = shared_experts.act_fn(shared_gate_up) - - w1 = w1.transpose(1, 2) - - group_list = expert_token_nums.to(torch.int64) - gate_up_out_list = 
torch_npu.npu_grouped_matmul( - x=[expand_x], - weight=[w1], - split_item=2, - # 1 means count mode, to avoid cumulative operation of the group list - group_list_type=1, - group_type=0, - group_list=group_list, - )[0] - - gate_up_out = torch_npu.npu_swiglu(gate_up_out_list) - - w2 = w2.transpose(1, 2) - down_out_list = torch_npu.npu_grouped_matmul( - x=[gate_up_out], - weight=[w2], - split_item=2, - group_list_type=1, - group_type=0, - group_list=group_list, - )[0] - - # moeCombine - kwargs_mc2 = { - "expand_x": down_out_list, - "expert_ids": topk_ids, - "expert_scales": topk_weights.to(torch.float32), - "expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": moe_expert_num, - "global_bs": 0, - } - tp_recv_counts = output[5] - stage3_kwargs = { - "ep_send_counts": ep_recv_counts, - "group_ep": moe_all_to_all_group_name, - "ep_world_size": ep_world_size, - "ep_rank_id": ep_rank_id, - } - if enable_dispatch_v2: - stage3_kwargs.update({ - "assist_info_for_combine": - assist_info_for_combine, - }) - else: - stage3_kwargs.update({ - "expand_idx": assist_info_for_combine, - }) - if need_extra_args: - stage3_kwargs.update({ - "tp_send_counts": tp_recv_counts, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - if a3_need_extra_args and enable_dispatch_v2: - stage3_kwargs.update({ - "x_active_mask": mc2_mask, - }) - kwargs_mc2.update(stage3_kwargs) - - hidden_states = torch_npu.npu_moe_distribute_combine_v2( - **kwargs_mc2 - ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( - **kwargs_mc2) - - if shared_experts is None: - return hidden_states - else: - shared_hidden_states, _ = shared_experts.down_proj(shared_act) - return hidden_states, shared_hidden_states - - def apply_mlp( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -318,248 +174,6 @@ def apply_mlp( return hidden_states -# currently expert parallelism implemented with all2all -# is under-optimized. 
-def fused_experts_with_all2all( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - row_idx: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, -): - original_shape = hidden_states.shape - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - num_tokens, _ = hidden_states.shape - num_experts = w1.shape[0] - - if expert_map is not None: - global_num_experts = len(expert_map) - local_num_experts = global_num_experts // ep_group.world_size - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - global_expert_tokens = torch.bincount(expanded_expert_idx, - minlength=global_num_experts) - scatter_sizes = global_expert_tokens.view(ep_group.world_size, - -1).sum(-1) - - gather_sizes = torch.empty_like(scatter_sizes) - dist.all_to_all_single(gather_sizes, - scatter_sizes, - group=ep_group.device_group) - scatter_size_list = scatter_sizes.cpu().tolist() - gather_size_list = gather_sizes.cpu().tolist() - - expanded_expert_idx = expanded_expert_idx % local_num_experts - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - scatter_size_list, - gather_size_list) - local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0, - scatter_size_list, - gather_size_list) - - sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx) - - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - sorted_local_expert_idx, local_num_experts).to(torch.int64) - - hidden_states = hidden_states[sorted_idx] - else: - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts) - 
expert_tokens = expert_tokens.to(torch.int64) - - w1 = w1.transpose(1, 2) - gate_up_out_list = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=2, - group_list_type=0, - group_type=0, - group_list=expert_tokens, - )[0] - - hidden_states = torch_npu.npu_swiglu(gate_up_out_list) - - w2 = w2.transpose(1, 2) - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w2], - split_item=2, - group_list_type=0, - group_type=0, - group_list=expert_tokens, - )[0] - - if expert_map is not None: - resorted_idx = torch.argsort(sorted_idx) - hidden_states = hidden_states[resorted_idx] - hidden_states = ep_group.all_to_all(hidden_states, 0, 0, - gather_size_list, - scatter_size_list) - - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) - else: - # TODO: Reorder device memory 2 times here, replace the current - # implementation here when suitable operators become available. - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) - if len(original_shape) == 3: - final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states - - -# currently expert parallelism implemented with all2all -# is under-optimized. 
-def fused_experts_with_all2all_buffer( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - row_idx: torch.Tensor, - top_k: int, - max_model_len: int, - global_batch_size: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, -): - original_shape = hidden_states.shape - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - num_tokens, _ = hidden_states.shape - - global_num_experts = len(expert_map) - local_num_experts = global_num_experts // ep_group.world_size - row_idx_len = num_tokens * top_k - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - max_row_per_ep_rank = (-(-global_batch_size // ep_group.world_size) * - max_model_len // ep_group.world_size + - 1) * top_k * 2 - expert_idx_buffer_scatter, unpad_indices = process_topk_ids( - expanded_expert_idx, global_num_experts, ep_group.world_size, - max_row_per_ep_rank, num_tokens, top_k) - hidden_states_pad_idx = torch.zeros( - expert_idx_buffer_scatter.shape, - dtype=expert_idx_buffer_scatter.dtype, - device=expert_idx_buffer_scatter.device) - non_pad_len = torch.sum((expert_idx_buffer_scatter - != global_num_experts).to(torch.int32)) - hidden_states_pad_idx[expert_idx_buffer_scatter != - global_num_experts] = torch.arange( - non_pad_len, - dtype=expert_idx_buffer_scatter.dtype, - device=hidden_states.device) - - hidden_states_buffer_scatter = hidden_states[hidden_states_pad_idx] - expert_idx_buffer_gather = torch.empty_like( - expert_idx_buffer_scatter, - dtype=expert_idx_buffer_scatter.dtype, - device=expert_idx_buffer_scatter.device) - hidden_states_buffer_gather = torch.empty_like( - hidden_states_buffer_scatter, - dtype=hidden_states_buffer_scatter.dtype, - device=hidden_states_buffer_scatter.device) - 
dist.all_to_all_single(expert_idx_buffer_gather, - expert_idx_buffer_scatter, - group=ep_group.device_group) - dist.all_to_all_single(hidden_states_buffer_gather, - hidden_states_buffer_scatter, - group=ep_group.device_group) - mask = expert_idx_buffer_gather != global_num_experts - local_expert_idx = expert_idx_buffer_gather[mask] - ep_group.rank * ( - global_num_experts // ep_group.world_size) - hidden_states = hidden_states_buffer_gather[mask] - idx_type = local_expert_idx.dtype - sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx.float()) - sorted_local_expert_idx = sorted_local_expert_idx.to(idx_type) - - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - sorted_local_expert_idx, local_num_experts).to(torch.int64) - hidden_states = hidden_states[sorted_idx] - group_list_type = 0 - - hidden_states = apply_mlp(hidden_states, - w1, - w2, - expert_tokens, - group_list_type=group_list_type) - - resorted_idx = torch.argsort(sorted_idx.float()).to(sorted_idx.dtype) - hidden_states = hidden_states[resorted_idx] - hidden_states_scatter = torch.zeros( - (mask.shape[0], hidden_states.shape[1]), - dtype=hidden_states.dtype, - device=hidden_states.device) - hidden_states_scatter[mask] = hidden_states - hidden_states_gatter = torch.empty_like( - hidden_states_scatter, - dtype=hidden_states_scatter.dtype, - device=hidden_states_scatter.device) - dist.all_to_all_single(hidden_states_gatter, - hidden_states_scatter, - group=ep_group.device_group) - hidden_states_gatter = hidden_states_gatter[expert_idx_buffer_scatter != - global_num_experts] - if hidden_states_gatter.shape[0] != row_idx_len: - hidden_states = torch.zeros((row_idx_len, hidden_states.shape[1]), - dtype=hidden_states.dtype, - device=hidden_states.device) - hidden_states[unpad_indices != -1] = hidden_states_gatter - else: - # TODO: Reorder device memory 2 times here, replace the current - hidden_states = hidden_states_gatter - final_hidden_states = torch_npu.npu_moe_finalize_routing( - 
hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) - - if len(original_shape) == 3: - final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states - - def fused_experts_moge( hidden_states: torch.Tensor, w1: torch.Tensor, @@ -651,188 +265,228 @@ def fused_experts_moge( return final_hidden_states -def fused_experts_with_all2allv( - token_dispatcher, - probs, - routing_map, - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, -): - # Enable moe alltoallv, it's a balanced policy for precision and efficiency. - (share_experts_output, dispatched_input, - tokens_per_expert) = (token_dispatcher.token_permutation( - hidden_states, probs, routing_map)) - - expert_output = apply_mlp(dispatched_input, w1, w2, tokens_per_expert) - output, mlp_bias = token_dispatcher.token_unpermutation(expert_output) - return output - - -def fused_experts( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - row_idx: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - apply_router_weight_on_input: bool = False, - max_num_tokens: Optional[int] = None, -) -> torch.Tensor: - """ - Fused experts with top-k routing. - - Args: - hidden_states: Hidden states of shape (num_tokens, hidden_size). - w1: Expert weights1 of shape (num_experts, intermediate_size * 2, hidden_size). - w2: Expert weights2 of shape (num_experts, hidden_size, intermediate_size). - topk_weights: Routing weights of shape (num_tokens, top_k). - topk_ids: Selected expert IDs of shape (num_tokens, top_k). - top_k: Number of experts to select. - expert_map: Expert mapping of shape (num_experts,). - - Returns: - hidden_states: Hidden states after routing. - """ - """ - # Check constraints. 
- assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" - assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" - assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" - assert w1.is_contiguous(), "Expert weights1 must be contiguous" - assert w2.is_contiguous(), "Expert weights2 must be contiguous" - """ - # if torch.distributed.get_rank() == 0: - # print(w1.shape) - # print(hidden_states.shape) - - original_shape = hidden_states.shape - # assert len(original_shape) == 2 - - num_tokens = hidden_states.shape[:-1].numel() - num_experts = w1.shape[0] - dtype = hidden_states.dtype - device = hidden_states.device - # assert dtype in [torch.float32, torch.float16, torch.bfloat16 - # ], "Only float32, float16, and bfloat16 are supported" - - if apply_router_weight_on_input: - assert (topk_weights.dim() == 2 - ), "`topk_weights` should be in shape (num_tokens, topk)" - _, topk = topk_weights.shape - assert ( - topk == 1 - ), "Only support topk=1 when `apply_router_weight_on_input` is True" - hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) - - if expert_map is not None: - # Generate token indices and flatten - token_indices = (torch.arange(num_tokens, - device=device, - dtype=torch.int64).unsqueeze(1).expand( - -1, top_k).reshape(-1)) - - # Flatten token-to-expert mappings and map to local experts - weights_flat = topk_weights.view(-1) - experts_flat = topk_ids.view(-1) - local_experts_flat = expert_map[experts_flat] - - # Filter valid token-expert pairs - mask = local_experts_flat != -1 - filtered_weights = torch.where( - mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype) - filtered_experts = torch.where( - mask, local_experts_flat, - torch.full_like(local_experts_flat, - num_experts)).to(topk_ids.dtype) - - # Sort by local expert IDs - sort_indices = torch.argsort(filtered_experts.view(torch.float32)) - sorted_token_indices = token_indices[sort_indices] - sorted_weights = 
filtered_weights[sort_indices] - - # Compute token counts with minlength of num_experts - # This is equivalent to but faster than: - # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] - token_counts = torch.zeros(num_experts + 1, - device=device, - dtype=torch.int64) - ones = torch.ones_like(filtered_experts, dtype=torch.int64) - token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) - token_counts = token_counts[:num_experts] - expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64) - - # Rearrange hidden_states - sorted_hidden_states = hidden_states[sorted_token_indices] +def quant_apply_mlp(hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None) -> torch.Tensor: + if dynamic_scale is None: + unquantized_hidden_states = hidden_states + hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( + hidden_states) + # Dispose the original unquantized hidden states + # to save npu memory because they're no longer used. 
+ dispose_tensor(unquantized_hidden_states) else: - active_num = max_num_tokens if max_num_tokens is not None else num_tokens - sorted_hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=active_num) + pertoken_scale = dynamic_scale - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts) - expert_tokens = expert_tokens.to(torch.int64) + bias1, bias2 = None, None + _output_dtype = w2_scale.dtype + is_mc2 = get_forward_context().fused_moe_state == FusedMoEState.MC2 + if w1_scale_bias is None and is_mc2: + w1_scale = w1_scale.to(torch.float32) + + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + split_item=3, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=torch.int32)[0] + + # act_fn: swiglu + hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant( + x=hidden_states, + weight_scale=w1_scale, + activation_scale=pertoken_scale, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=group_list, + activate_left=True, + quant_mode=1, + ) + + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=w2_scale.dtype)[0] + else: + if w1_scale_bias is not None: + if group_list_type == 0: + group_list = torch.cat( + [group_list[:1], + torch.diff(group_list, dim=0)]) + group_list_type = 1 + bias1 = [w1_scale_bias] + bias2 = [w2_scale_bias] + # TODO w4a8 scene: dynamic acquisition of dtype in the future + _output_dtype = torch.bfloat16 + + # gmm1: gate_up_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w1], + scale=[w1_scale], + bias=bias1, + 
per_token_scale=[pertoken_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + + # act_fn: swiglu + hidden_states = torch_npu.npu_swiglu(hidden_states) + hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( + hidden_states) + + # gmm2: down_proj + hidden_states = torch_npu.npu_grouped_matmul( + x=[hidden_states], + weight=[w2], + scale=[w2_scale], + bias=bias2, + per_token_scale=[swiglu_out_scale], + split_item=2, + group_list_type=group_list_type, + group_type=0, + group_list=group_list, + output_dtype=_output_dtype)[0] + return hidden_states + + +def unquant_apply_mlp( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + group_list: torch.Tensor, + group_list_type: int = 1, + topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor: w1 = w1.transpose(1, 2) - gate_up_out_list = torch_npu.npu_grouped_matmul( - x=[sorted_hidden_states], + gate_up_out = torch_npu.npu_grouped_matmul( + x=[hidden_states], weight=[w1], split_item=2, - group_list_type=0, + group_list_type=group_list_type, group_type=0, - group_list=expert_tokens, + group_list=group_list, )[0] + if is_310p(): + gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to( + torch.float16) + else: + gate_up_out = torch_npu.npu_swiglu(gate_up_out) - gate_up_out = torch_npu.npu_swiglu(gate_up_out_list) + if topk_scales is not None: + gate_up_out *= topk_scales w2 = w2.transpose(1, 2) - down_out_list = torch_npu.npu_grouped_matmul( + hidden_states = torch_npu.npu_grouped_matmul( x=[gate_up_out], weight=[w2], split_item=2, - group_list_type=0, + group_list_type=group_list_type, group_type=0, - group_list=expert_tokens, + group_list=group_list, )[0] + return hidden_states - if expert_map is not None: - weighted_down_out = down_out_list * sorted_weights.unsqueeze(1) - final_hidden_states = torch.zeros(*original_shape, - device=hidden_states.device, - dtype=dtype) - - # TODO: 
npu_grouped_matmul output random values at [num_valid_tokens:, ...] - # This created multiple NaN and index_add_ will mix them up which harms accuracy - # remove this mask and filter after it being fixed - num_valid_tokens = mask.sum() - valid_token_mask = torch.arange( - 0, sorted_token_indices.shape[0], - device=device).unsqueeze(1) < num_valid_tokens - valid_output = torch.where( - valid_token_mask, weighted_down_out, - torch.zeros_like(weighted_down_out)).to(dtype) - final_hidden_states.index_add_(0, sorted_token_indices, valid_output) +def unified_apply_mlp( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w1_scale: torch.Tensor, + w2: torch.Tensor, + w2_scale: torch.Tensor, + group_list: torch.Tensor, + dynamic_scale: torch.Tensor = None, + group_list_type: int = 1, + w1_scale_bias: torch.Tensor = None, + w2_scale_bias: torch.Tensor = None, + topk_scales: Optional[torch.Tensor] = None) -> torch.Tensor: + if get_forward_context().with_quant: + return quant_apply_mlp(hidden_states=hidden_states, + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=group_list, + dynamic_scale=dynamic_scale, + group_list_type=group_list_type, + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias) else: - scales = torch.ones_like( - topk_weights) if apply_router_weight_on_input else topk_weights - # TODO: Reorder device memory 2 times here, replace the current - # implementation here when suitable operators become available. 
- final_hidden_states = torch_npu.npu_moe_finalize_routing( - down_out_list, - skip1=None, - skip2=None, - bias=None, - scales=scales, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) + return unquant_apply_mlp(hidden_states=hidden_states, + w1=w1, + w2=w2, + group_list=group_list, + group_list_type=group_list_type, + topk_scales=topk_scales) + +def unified_fused_experts_eager(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + row_idx: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + w1_scale: Optional[torch.Tensor] = None, + w1_scale_bias: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + w2_scale_bias: Optional[torch.Tensor] = None, + shared_experts: Optional[torch.Tensor] = None, + shared_gate_up: Optional[Any] = None, + shared_dequant_scale: Optional[Any] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False): + token_dispatcher = get_forward_context().token_dispatcher + + results = token_dispatcher.token_dispatch( + hidden_states=hidden_states, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + expert_map=expert_map, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + shared_gate_up=shared_gate_up, + shared_dequant_scale=shared_dequant_scale, + mc2_mask=mc2_mask, + apply_router_weight_on_input=apply_router_weight_on_input) + + expert_output = unified_apply_mlp( + hidden_states=results["hidden_states"], + w1=w1, + w1_scale=w1_scale, + w2=w2, + w2_scale=w2_scale, + group_list=results["group_list"], + dynamic_scale=results.get("dynamic_scale"), + group_list_type=results.get("group_list_type"), + w1_scale_bias=w1_scale_bias, + w2_scale_bias=w2_scale_bias, + topk_scales=results.get("topk_scales")) + final_hidden_states 
= token_dispatcher.token_combine(expert_output) return final_hidden_states @@ -914,65 +568,16 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): if enable_force_load_balance and not self.use_aclgraph: topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) - fused_moe_state = get_forward_context().fused_moe_state - - if fused_moe_state == FusedMoEState.MC2: - return fused_experts_with_mc2( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - moe_parallel_config=self.moe.moe_parallel_config, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - moe_all_to_all_group_name=self.moe_all_to_all_group_name, - shared_experts=shared_experts, - mc2_mask=kwargs.get("mc2_mask", None)) - elif fused_moe_state in [ - FusedMoEState.AllGather, FusedMoEState.NaiveMulticast - ]: - return fused_experts(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - row_idx=row_idx, - top_k=top_k, - expert_map=expert_map) - elif MOE_ALL2ALL_BUFFER: - return fused_experts_with_all2all_buffer( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - row_idx=row_idx, - top_k=top_k, - max_model_len=self.max_model_len, - global_batch_size=self.global_batch_size, - expert_map=expert_map, - ep_group=get_ep_group()) - elif fused_moe_state == FusedMoEState.All2AllSeq: - token_dispatcher = kwargs.get("token_dispatcher") - return fused_experts_with_all2allv( - token_dispatcher=token_dispatcher, - probs=topk_weights, - routing_map=topk_ids, - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - ) - else: - return fused_experts_with_all2all(hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - topk_weights=topk_weights, - topk_ids=topk_ids, - row_idx=row_idx, - top_k=top_k, - expert_map=expert_map, - ep_group=get_ep_group()) + return unified_fused_experts_eager(hidden_states=x, + 
w1=layer.w13_weight, + w2=layer.w2_weight, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + expert_map=expert_map, + shared_experts=shared_experts, + mc2_mask=kwargs.get( + "mc2_mask", None)) class AscendFusedMoE(FusedMoE): @@ -1154,6 +759,19 @@ class AscendFusedMoE(FusedMoE): self.token_dispatcher, token_dispatcher1 ] + ep_size = (get_ep_group().world_size if + vllm_config.parallel_config.enable_expert_parallel else 1) + with_quant = quant_config is not None + from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \ + setup_token_dispatchers + setup_token_dispatchers( + ep_size, + top_k=self.top_k, + num_experts=self.global_num_experts, + num_global_redundant_experts=self.global_redundant_expert_num, + num_local_experts=self.local_num_experts, + with_quant=with_quant) + def naive_multicast(self, x: torch.Tensor, cu_tokens_across_dp_cpu: torch.Tensor): assert (len(x.shape) == 2) diff --git a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py index e02652d..c0d85bb 100644 --- a/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py +++ b/vllm_ascend/ops/moe_dispatcher/token_dispatcher.py @@ -22,21 +22,18 @@ # limitations under the License. 
from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Dict, Optional import torch import torch_npu from vllm.distributed.parallel_state import get_ep_group -from vllm.forward_context import get_forward_context -from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.tensor_parallel import ( all_gather_last_dim_from_tensor_parallel_region, all_to_all_hp2sp, all_to_all_sp2hp, gather_from_sequence_parallel_region, reduce_scatter_last_dim_to_tensor_parallel_region) from vllm_ascend.ops.comm_utils import async_all_to_all -from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version @@ -460,6 +457,31 @@ class MoEAlltoAllSeqOverLapDispatcher(MoEDispatcher): return output, None +_Dispatchers: Dict[str, Any] = {} + + +def _register_token_dispatcher(dispatcher: Any): + _Dispatchers[dispatcher.__class__.__name__] = dispatcher + + +def get_token_dispatcher(name: str): + return _Dispatchers.get(name) + + +def setup_token_dispatchers(ep_size: int, **kwargs): + existing_dispatchers = set(_Dispatchers.keys()) + + if ep_size == 1 and "TokenDispatcherWithAllGather" not in existing_dispatchers: + _register_token_dispatcher(TokenDispatcherWithAllGather(**kwargs)) + elif ep_size < 16 and "TokenDispatcherWithAll2AllV" not in existing_dispatchers: + _register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs)) + elif ep_size >= 16: + if "TokenDispatcherWithAll2AllV" not in existing_dispatchers: + _register_token_dispatcher(TokenDispatcherWithAll2AllV(**kwargs)) + if "TokenDispatcherWithMC2" not in existing_dispatchers: + _register_token_dispatcher(TokenDispatcherWithMC2(**kwargs)) + + class MoETokenDispatcher(ABC): def __init__(self, **kwargs) -> None: @@ -484,18 +506,19 @@ class MoETokenDispatcher(ABC): return get_ep_group().world_size @abstractmethod - def token_dispatch( 
- self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, - shared_experts: Optional[torch.Tensor] = None, - ): + def token_dispatch(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + row_idx: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[torch.Tensor] = None, + shared_gate_up: Optional[torch.Tensor] = None, + shared_dequant_scale: Optional[torch.Tensor] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False): raise NotImplementedError("Dispatch function not implemented.") @abstractmethod @@ -516,40 +539,39 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) self.ep_rank_id = get_mc2_group().rank_in_group self.ep_world_size = get_mc2_group().world_size - ascend_config = get_ascend_config() - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") - self.need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3 - or self.torchair_graph_enabled) + self.need_extra_args = ( + get_ascend_soc_version() == AscendSocVersion.A3) # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine self.a3_need_extra_args = \ get_ascend_soc_version() == AscendSocVersion.A3 self.output = None - self.dynamic_scale = None self.assist_info_for_combine = None self.ep_recv_counts = None self.shared_act = None self.topk_ids = None self.topk_weights = None self.shared_experts = None + self.mc2_mask = None - def 
get_dispatch_mc2_kwargs(self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor, - global_redundant_expert_num: int = 0): - quant_mode = 0 - forward_context = get_forward_context() - mc2_mask = forward_context.mc2_mask + def get_dispatch_mc2_kwargs( + self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + expert_map: torch.Tensor, + global_redundant_expert_num: int = 0, + ): if self.with_quant: + quant_mode = 2 if (expert_map is not None): moe_expert_num = len(expert_map) + global_redundant_expert_num else: moe_expert_num = global_redundant_expert_num else: + quant_mode = 0 moe_expert_num = len(expert_map) kwargs_mc2 = { "x": hidden_states, @@ -575,28 +597,30 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): }) if self.a3_need_extra_args and self.enable_dispatch_v2: stage1_kwargs.update({ - "x_active_mask": mc2_mask, + "x_active_mask": self.mc2_mask, }) kwargs_mc2.update(stage1_kwargs) return kwargs_mc2 - def token_dispatch( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, - shared_experts: Optional[torch.Tensor] = None, - ): + def token_dispatch(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + row_idx: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[torch.Tensor] = None, + shared_gate_up: Optional[torch.Tensor] = None, + shared_dequant_scale: Optional[torch.Tensor] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False): self.expert_map = expert_map self.topk_ids = topk_ids self.topk_weights = 
topk_weights self.shared_experts = shared_experts + self.mc2_mask = mc2_mask kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights, topk_ids, expert_map, @@ -606,28 +630,27 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( **kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, self.dynamic_scale, self.assist_info_for_combine, \ + expand_x, dynamic_scale, self.assist_info_for_combine, \ expert_token_nums, self.ep_recv_counts = self.output[0:5] if self.with_quant: if shared_experts is not None: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(shared_gate_up, expand_x) - shared_act_out = shared_experts.act_fn( - (shared_gate_up, shared_dequant_scale)) - self.shared_act, self.swiglu_out_scale = \ - shared_act_out[0], shared_act_out[1] + shared_act_out = shared_experts.act_fn( + (shared_gate_up, shared_dequant_scale)) + self.shared_act, self.swiglu_out_scale = \ + shared_act_out[0], shared_act_out[1] else: if shared_experts is not None: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(hidden_states, topk_weights) - shared_gate_up, _ = shared_experts.gate_up_proj( - hidden_states) - npu_wait_tensor(shared_gate_up, expand_x) - self.shared_act = shared_experts.act_fn(shared_gate_up) + shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states) + self.shared_act = shared_experts.act_fn(shared_gate_up) group_list_type = 1 - return group_list_type, expand_x, expert_token_nums + return { + "group_list_type": group_list_type, + "hidden_states": expand_x, + "group_list": expert_token_nums, + "dynamic_scale": dynamic_scale, + } def get_combine_mc_kwargs(self, hidden_states: torch.Tensor): assert self.expert_map is not None @@ -635,8 +658,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): assert self.topk_ids is not None assert self.output is not None moe_expert_num = len(self.expert_map) - forward_context = get_forward_context() 
- mc2_mask = forward_context.mc2_mask # moeCombine kwargs_mc2 = { "expand_x": hidden_states, @@ -677,7 +698,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): }) if self.a3_need_extra_args and self.enable_dispatch_v2: stage3_kwargs.update({ - "x_active_mask": mc2_mask, + "x_active_mask": self.mc2_mask, }) kwargs_mc2.update(stage3_kwargs) return kwargs_mc2 @@ -685,7 +706,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): def token_combine(self, hidden_states: torch.Tensor, bias: torch.Tensor = None): - kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states) hidden_states = torch_npu.npu_moe_distribute_combine_v2( **kwargs_mc2 @@ -695,15 +715,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): return hidden_states else: if self.with_quant: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(self.shared_act, hidden_states) - shared_hidden_states, _ = self.shared_experts.down_proj( - (self.shared_act, self.swiglu_out_scale)) + shared_hidden_states, _ = self.shared_experts.down_proj( + (self.shared_act, self.swiglu_out_scale)) else: - with npu_stream_switch("moe_secondary", 0): - npu_wait_tensor(self.shared_act, hidden_states) - shared_hidden_states, _ = self.shared_experts.down_proj( - self.shared_act) + shared_hidden_states, _ = self.shared_experts.down_proj( + self.shared_act) return hidden_states, shared_hidden_states @@ -711,13 +727,9 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): def __init__(self, **kwargs): super().__init__(**kwargs) - self.apply_router_weight_on_input = kwargs.get( - "apply_router_weight_on_input") - self.top_k = kwargs.get("top_k") + self.apply_router_weight_on_input = False self.max_num_tokens = kwargs.get("max_num_tokens") - ep_size = kwargs.get("ep_size") - if ep_size is not None: - self.num_experts_local = self.num_experts // ep_size + self.num_experts_local = kwargs.get("num_local_experts", 0) self.sorted_weights = None self.expanded_row_idx = None self.sorted_token_indices = None @@ -727,20 +739,20 @@ 
class TokenDispatcherWithAllGather(MoETokenDispatcher): self.topk_weights = None self.topk_ids = None - def token_dispatch( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, - shared_experts: Optional[torch.Tensor] = None, - ): + def token_dispatch(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + row_idx: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[torch.Tensor] = None, + shared_gate_up: Optional[torch.Tensor] = None, + shared_dequant_scale: Optional[torch.Tensor] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False): self.original_shape = hidden_states.shape - # assert len(original_shape) == 2 num_tokens = hidden_states.shape[:-1].numel() dtype = hidden_states.dtype @@ -748,9 +760,7 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): self.expert_map = expert_map self.topk_weights = topk_weights self.topk_ids = topk_ids - # assert dtype in [torch.float32, torch.float16, torch.bfloat16 - # ], "Only float32, float16, and bfsloat16 are supported" - + self.apply_router_weight_on_input = apply_router_weight_on_input if self.apply_router_weight_on_input: assert (topk_weights.dim() == 2 ), "`topk_weights` should be in shape (num_tokens, topk)" @@ -803,19 +813,13 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): sorted_hidden_states = hidden_states[self.sorted_token_indices] if self.with_quant: group_list_type = 1 + expert_tokens = token_counts else: expert_tokens = torch.cumsum(token_counts, dim=0, dtype=torch.int64) group_list_type = 0 else: - row_idx_len = num_tokens * self.top_k - row_idx = 
(torch.arange(0, - row_idx_len, - dtype=torch.int32, - device=device).view(self.top_k, - -1).permute( - 1, 0).contiguous()) active_num = self.max_num_tokens if self.max_num_tokens is not None else num_tokens sorted_hidden_states, self.expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( hidden_states, @@ -827,18 +831,23 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): expanded_expert_idx, self.num_experts_local) expert_tokens = expert_tokens.to(torch.int64) group_list_type = 0 - return group_list_type, sorted_hidden_states, expert_tokens + return { + "group_list_type": group_list_type, + "hidden_states": sorted_hidden_states, + "group_list": expert_tokens, + } def token_combine(self, hidden_states: torch.Tensor, bias: torch.Tensor = None): - assert self.mask is not None - assert self.sorted_token_indices is not None - assert self.sorted_weights is not None assert self.original_shape is not None dtype = hidden_states.dtype device = hidden_states.device if self.expert_map is not None: + assert self.mask is not None + assert self.sorted_token_indices is not None + assert self.sorted_weights is not None + weighted_down_out = hidden_states * \ self.sorted_weights.unsqueeze(1) @@ -887,7 +896,6 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): expanded_src_to_dst_row=self.expanded_row_idx, export_for_source_row=self.topk_ids, ) - return final_hidden_states @@ -895,29 +903,27 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher): class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): def __init__(self, **kwargs): - super(MoETokenDispatcher, self).__init__(**kwargs) - self.apply_router_weight_on_input = kwargs.get( - "apply_router_weight_on_input") - ep_size = kwargs.get("ep_size") - self.local_ep = ep_size - assert self.local_ep is not None + super().__init__(**kwargs) + self.apply_router_weight_on_input = False + self.local_ep = 1 self.local_num_experts = self.num_experts // self.local_ep self.local_num_group = 
self.top_k // self.local_ep self.bsz = None - def token_dispatch( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, - shared_experts: Optional[torch.Tensor] = None, - ): - + def token_dispatch(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + row_idx: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[torch.Tensor] = None, + shared_gate_up: Optional[torch.Tensor] = None, + shared_dequant_scale: Optional[torch.Tensor] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False): + self.apply_router_weight_on_input = apply_router_weight_on_input if self.apply_router_weight_on_input: assert (topk_weights.dim() == 2 ), "`topk_weights` should be in shape (num_tokens, topk)" @@ -932,7 +938,7 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): flatten_topk_ids = topk_ids.view(-1) self.sorted_topk_ids = torch.argsort(flatten_topk_ids.float()) self.sorted_topk_ids = self.sorted_topk_ids.to(torch.int32) - self.sorted_hidden_states = hidden_states.index_select( + sorted_hidden_states = hidden_states.index_select( 0, self.sorted_topk_ids // self.local_num_group) experts_id = torch.arange(0, @@ -942,15 +948,20 @@ class UnquantizedTokenDispatcherWithFusedExpertsMoge(MoETokenDispatcher): num_tokens_per_expert = ( flatten_topk_ids.unsqueeze(-1) == experts_id).to( torch.float32).sum(0) - self.topk_scales = topk_weights.view(-1).index_select( + topk_scales = topk_weights.view(-1).index_select( 0, self.sorted_topk_ids).unsqueeze(-1) group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64) - return 
hidden_states, group_list + group_list_type = 0 + return { + "group_list_type": group_list_type, + "hidden_states": sorted_hidden_states, + "group_list": group_list, + "topk_scales": topk_scales, + } def token_combine(self, hidden_states: torch.Tensor, bias: torch.Tensor = None): - assert self.local_ep is not None unsorted_topk_ids = torch.argsort(self.sorted_topk_ids.float()).to( torch.int32) unsorted_hidden_states = hidden_states.index_select( @@ -1009,18 +1020,19 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher): self.local_expert_indices[i + 1] - 1), "local_expert_indices must be continuous" - def token_dispatch( - self, - hidden_states: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - expert_map: torch.Tensor, - log2phy: Optional[torch.Tensor] = None, - global_redundant_expert_num: int = 0, - shared_gate_up: Optional[torch.Tensor] = None, - shared_dequant_scale: Optional[torch.Tensor] = None, - shared_experts: Optional[torch.Tensor] = None, - ): + def token_dispatch(self, + hidden_states: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + row_idx: torch.Tensor, + expert_map: Optional[torch.Tensor] = None, + log2phy: Optional[torch.Tensor] = None, + global_redundant_expert_num: int = 0, + shared_experts: Optional[torch.Tensor] = None, + shared_gate_up: Optional[torch.Tensor] = None, + shared_dequant_scale: Optional[torch.Tensor] = None, + mc2_mask: Optional[torch.Tensor] = None, + apply_router_weight_on_input: bool = False): self.hidden_shape = hidden_states.shape self.topk_weights = topk_weights assert topk_weights.dim() == 2, "Expected 2D tensor for topk_weights" diff --git a/vllm_ascend/quantization/w4a8_dynamic.py b/vllm_ascend/quantization/w4a8_dynamic.py index 329b3eb..207c5e1 100644 --- a/vllm_ascend/quantization/w4a8_dynamic.py +++ b/vllm_ascend/quantization/w4a8_dynamic.py @@ -26,9 +26,8 @@ from vllm.forward_context import get_forward_context from vllm_ascend.ascend_forward_context import FusedMoEState 
from vllm_ascend.distributed.parallel_state import get_mc2_group +from vllm_ascend.ops.fused_moe import unified_fused_experts_eager from vllm_ascend.ops.layers.experts_selector import select_experts -from vllm_ascend.quantization.w8a8_dynamic import (fused_experts_with_all2all, - fused_experts_with_mc2) class AscendW4A8DynamicLinearMethod: @@ -291,48 +290,25 @@ class AscendW4A8DynamicFusedMoEMethod: topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) topk_weights = topk_weights.to(x.dtype) - if fused_moe_state == FusedMoEState.MC2: - return fused_experts_with_mc2( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale_second, - w2_scale=layer.w2_weight_scale_second, - w1_scale_bias=layer.w13_scale_bias, - w2_scale_bias=layer.w2_scale_bias, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - moe_all_to_all_group_name=self.moe_all_to_all_group_name, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - shared_experts=shared_experts, - quantized_x_for_share=shared_gate_up, - dynamic_scale_for_share=shared_dequant_scale, - mc2_mask=kwargs.get("mc2_mask", None)) - else: - # The current implementation of deepseek moe splits hidden_states - # according to tp_size before they are feed into layers module. - # Therefore, all2all is needed no matter how dp/tp is set so as to - # dispatch/combine tokens. 
- return fused_experts_with_all2all( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale_second, - w2_scale=layer.w2_weight_scale_second, - w1_scale_bias=layer.w13_scale_bias, - w2_scale_bias=layer.w2_scale_bias, - topk_weights=topk_weights, - topk_ids=topk_ids, - row_idx=row_idx, - top_k=top_k, - expert_map=expert_map, - ep_group=self.ep_group, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - ) + + return unified_fused_experts_eager( + hidden_states=x, + w1=layer.w13_weight, + w2=layer.w2_weight, + w1_scale=layer.w13_weight_scale_second, + w2_scale=layer.w2_weight_scale_second, + w1_scale_bias=layer.w13_scale_bias, + w2_scale_bias=layer.w2_scale_bias, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + expert_map=expert_map, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + shared_gate_up=shared_gate_up, + shared_dequant_scale=shared_dequant_scale, + mc2_mask=kwargs.get("mc2_mask", None)) def process_scale(self, weight: torch.Tensor, scale, per_group_scale): group_num, k, n = weight.shape diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 1d6a61b..1177af8 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -18,17 +18,16 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union import torch -import torch.distributed as dist import torch_npu -from vllm.distributed import GroupCoordinator, get_ep_group +from vllm.distributed import get_ep_group from vllm.forward_context import get_forward_context import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_forward_context import FusedMoEState from vllm_ascend.distributed.parallel_state import get_mc2_group +from vllm_ascend.ops.fused_moe import unified_fused_experts_eager from vllm_ascend.ops.layers.experts_selector import select_experts -from 
vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, - dispose_tensor, get_ascend_soc_version) +from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, dispose_tensor def apply_mlp_decode(hidden_states: torch.Tensor, @@ -197,520 +196,6 @@ def apply_mlp(hidden_states: torch.Tensor, return hidden_states -def fused_experts_with_mc2( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w2: torch.Tensor, - w1_scale: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - moe_all_to_all_group_name: str = "", - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, - shared_experts: Optional[Any] = None, - quantized_x_for_share: Optional[Any] = None, - dynamic_scale_for_share: Optional[Any] = None, - mc2_mask: Optional[torch.Tensor] = None, - shared_gate_up: Optional[Any] = None, - shared_dequant_scale: Optional[Any] = None, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, -) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: - assert mc2_mask is not None - if log2phy is not None: - topk_ids = log2phy[topk_ids] - - quant_mode = 2 - ep_group = get_mc2_group() - ep_rank_id = ep_group.rank_in_group - ep_world_size = ep_group.world_size - - # NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine - need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3) - - # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine - a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 - - enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") - - if (expert_map is not None): - moe_expert_num = len(expert_map) + global_redundant_expert_num - else: - moe_expert_num = global_redundant_expert_num - # hidden_states = hidden_states.bfloat16() - kwargs_mc2 = { - "x": hidden_states, - "expert_ids": topk_ids, - 
"expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": moe_expert_num, - "global_bs": 0, - } - - stage1_kwargs = { - "scales": None, - "quant_mode": quant_mode, - "group_ep": moe_all_to_all_group_name, - "ep_world_size": ep_world_size, - "ep_rank_id": ep_rank_id, - } - if need_extra_args: - stage1_kwargs.update({ - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - if a3_need_extra_args and enable_dispatch_v2: - stage1_kwargs.update({ - "x_active_mask": mc2_mask, - }) - kwargs_mc2.update(stage1_kwargs) - - output = torch_npu.npu_moe_distribute_dispatch_v2( - **kwargs_mc2 - ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( - **kwargs_mc2) - # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[ - 0:5] - - if shared_experts is not None: - shared_act_out = shared_experts.act_fn( - (shared_gate_up, shared_dequant_scale)) - shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1] - - # `expand_x` will be disposed in the `apply_mlp` function - if w1_scale_bias is None: - down_out_list = apply_mlp_decode(expand_x, - w1, - w1_scale, - w2, - w2_scale, - expert_token_nums, - dynamic_scale=dynamic_scale) - else: - # w4a8 scene, cannot use apply_mlp_decode because the operator is not supported - down_out_list = apply_mlp(expand_x, - w1, - w1_scale, - w2, - w2_scale, - expert_token_nums, - dynamic_scale=dynamic_scale, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias) - - # moeCombine - kwargs_mc2 = { - "expand_x": down_out_list, - "expert_ids": topk_ids, - "expert_scales": topk_weights.to(torch.float32), - "expert_shard_type": 0, - "shared_expert_rank_num": 0, - "moe_expert_num": moe_expert_num, - "global_bs": 0, - } - tp_recv_counts = torch.empty(1, - dtype=torch.int32, - device=hidden_states.device) - stage3_kwargs = { - "ep_send_counts": ep_recv_counts, - "group_ep": 
moe_all_to_all_group_name, - "ep_world_size": ep_world_size, - "ep_rank_id": ep_rank_id, - } - if enable_dispatch_v2: - stage3_kwargs.update({ - "assist_info_for_combine": - assist_info_for_combine, - }) - else: - stage3_kwargs.update({ - "expand_idx": assist_info_for_combine, - }) - if need_extra_args: - stage3_kwargs.update({ - "tp_send_counts": tp_recv_counts, - "group_tp": moe_all_to_all_group_name, - "tp_world_size": 1, - "tp_rank_id": 0, - }) - if a3_need_extra_args and enable_dispatch_v2: - stage3_kwargs.update({ - "x_active_mask": mc2_mask, - }) - kwargs_mc2.update(stage3_kwargs) - - hidden_states = torch_npu.npu_moe_distribute_combine_v2( - **kwargs_mc2 - ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine( - **kwargs_mc2) - - if shared_experts is None: - return hidden_states - else: - shared_output, _ = shared_experts.down_proj( - (shared_act, swiglu_out_scale)) - return hidden_states, shared_output - - -def init_routing_quant(hidden_states, top_k, topk_ids, row_idx, - global_num_experts): - num_tokens, _ = hidden_states.shape - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - expanded_row_idx = (expanded_row_idx.view(top_k, -1).permute( - 1, 0).contiguous().view(-1)) - global_expert_tokens = torch.bincount(expanded_expert_idx, - minlength=global_num_experts) - global_expert_tokens = global_expert_tokens.to(torch.int32) - quantized_tokens, token_scales = torch_npu.npu_dynamic_quant(hidden_states) - return quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales - - -# currently expert parallelism implemented with all2all -# is under-optimized. 
-def fused_experts_with_all2all( - hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - row_idx: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None, - ep_group: GroupCoordinator = None, - log2phy: torch.Tensor = None, - global_redundant_expert_num: int = 0, - w1_scale_bias: torch.Tensor = None, - w2_scale_bias: torch.Tensor = None, -): - if log2phy is not None: - topk_ids = log2phy[topk_ids] - original_shape = hidden_states.shape - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - num_tokens, _ = hidden_states.shape - num_experts = w1.shape[0] - - if expert_map is not None: - global_num_experts = len(expert_map) + global_redundant_expert_num - if hasattr(torch_npu, "npu_moe_init_routing_quant"): - quantized_tokens, expanded_row_idx, global_expert_tokens, _, token_scales = torch_npu.npu_moe_init_routing_quant( - hidden_states, - expert_idx=topk_ids.to(torch.int32), - active_num=0, - expert_capacity=0, - expert_num=global_num_experts, - drop_pad_mode=0, - expert_tokens_num_mode=2, - expert_tokens_before_capacity_flag=False, - quant_mode=1, - ) - else: - quantized_tokens, expanded_row_idx, global_expert_tokens, token_scales = init_routing_quant( - hidden_states, top_k, topk_ids, row_idx, global_num_experts) - - gather_sizes = global_expert_tokens.new_empty( - global_expert_tokens.shape[0]) - dist.all_to_all_single(gather_sizes, global_expert_tokens) - - token_counts_combined = torch.stack( - [gather_sizes, global_expert_tokens], dim=0) - token_counts_combined = token_counts_combined.view( - 2, ep_group.world_size, -1).sum(dim=2) - token_counts_combined_cpu = token_counts_combined.to( - torch.device("cpu"), non_blocking=True).numpy() - all_tokens = gather_sizes.sum() - - gathered_tokens = quantized_tokens.new_empty(all_tokens.item(), - quantized_tokens.shape[1]) - dynamic_scale = 
token_scales.new_empty(gathered_tokens.shape[0]) - gather_size_list = token_counts_combined_cpu[1] - scatter_size_list = token_counts_combined_cpu[0] - - dist.all_to_all_single(gathered_tokens, quantized_tokens, - scatter_size_list, gather_size_list) - dist.all_to_all_single(dynamic_scale, token_scales, scatter_size_list, - gather_size_list) - - hidden_states, dynamic_scale, inverse_indices, expert_tokens = torch_npu.npu_moe_re_routing( - gathered_tokens, - gather_sizes.view(ep_group.world_size, -1), - per_token_scales=dynamic_scale) - expert_tokens = expert_tokens.to(torch.int64) - group_list_type = 1 - else: - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts) - expert_tokens = expert_tokens.to(torch.int64) - group_list_type = 0 - dynamic_scale = None - - # `hidden_states` will be disposed in the `apply_mlp` function - hidden_states = apply_mlp( - hidden_states, - w1, - w1_scale, #17 - w2, - w2_scale, - expert_tokens, #16 - dynamic_scale=dynamic_scale, - group_list_type=group_list_type, - w1_scale_bias=w1_scale_bias, - w2_scale_bias=w2_scale_bias) - - if expert_map is not None: - reordered_outputs = torch.index_select( - hidden_states, - dim=0, - # Workaround: Convert to float so that argsort runs on AI Core instead of slower AICPU - index=inverse_indices.to(torch.float32).argsort().to(torch.int32)) - - hidden_states = reordered_outputs.new_empty(*quantized_tokens.shape) - dist.all_to_all_single(hidden_states, reordered_outputs, - gather_size_list, scatter_size_list) - - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=None, - drop_pad_mode=2) - else: - # TODO: Reorder device memory 2 times 
here, replace the current - # implementation here when suitable operators become available. - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) - if len(original_shape) == 3: - final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states - - -def fused_experts_with_allgather(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None): - original_shape = hidden_states.shape - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - num_tokens = hidden_states.shape[0] - batch_size, hidden_size = hidden_states.shape - topk_weights = topk_weights.to(hidden_states.dtype) - - ep_group = get_ep_group().device_group - ep_rank = torch.distributed.get_rank(group=ep_group) - ep_size = torch.distributed.get_world_size(ep_group) - - global_num_experts = len(expert_map) - local_num_experts = global_num_experts // ep_size - - hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states) - - hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2( - hidden_states, - topk_ids, - scale=pertoken_scale, - offset=None, - active_num=num_tokens * top_k, - expert_num=global_num_experts, - expert_tokens_num_type=1, - expert_tokens_num_flag=True, - active_expert_range=[ - ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts - ], - quant_mode=-1, - row_idx_type=1) - group_list_type = 1 - - sorted_topk_weight = torch.index_select(topk_weights.view(-1), 0, - expanded_x_idx) - row_index = expanded_x_idx // topk_ids.shape[-1] - row_index = row_index.to(torch.int64) - share_input = torch.zeros((batch_size, hidden_size), - 
dtype=torch.bfloat16, - device="npu") - - hidden_states = torch_npu.npu_grouped_matmul( - x=[hidden_states], - weight=[w1], - split_item=3, - group_list_type=group_list_type, - group_type=0, - group_list=expert_tokens, - output_dtype=torch.int32)[0] - - # act_fn: swiglu - hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant( - x=hidden_states, - weight_scale=w1_scale.to(torch.float32), - activation_scale=pertoken_scale, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=expert_tokens, - activate_left=True, - quant_mode=1, - ) - - final_hidden_states = torch_npu.npu_grouped_matmul_finalize_routing( - hidden_states, - w2, - scale=w2_scale.to(torch.float32), - bias=None, - pertoken_scale=pertoken_scale.view(-1), - group_list=expert_tokens, - shared_input=share_input, - logit=sorted_topk_weight.to(torch.float32), - row_index=row_index, - output_bs=batch_size).to(torch.bfloat16) - - if len(original_shape) == 3: - final_hidden_states = final_hidden_states.view(original_shape) - - return final_hidden_states - - -def fused_experts(hidden_states: torch.Tensor, - w1: torch.Tensor, - w1_scale: torch.Tensor, - w2: torch.Tensor, - w2_scale: torch.Tensor, - topk_weights: torch.Tensor, - topk_ids: torch.Tensor, - row_idx: torch.Tensor, - top_k: int, - expert_map: torch.Tensor = None): - original_shape = hidden_states.shape - if len(original_shape) == 3: - hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - - num_tokens, _ = hidden_states.shape - num_experts = w1.shape[0] - dtype = hidden_states.dtype - device = hidden_states.device - - if expert_map is not None: - # Generate token indices and flatten - token_indices = (torch.arange(num_tokens, - device=device, - dtype=torch.int64).unsqueeze(1).expand( - -1, top_k).reshape(-1)) - - # Flatten token-to-expert mappings and map to local experts - weights_flat = topk_weights.view(-1) - experts_flat = topk_ids.view(-1) - local_experts_flat = expert_map[experts_flat] - - # Filter valid 
token-expert pairs - mask = local_experts_flat != -1 - filtered_weights = torch.where( - mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype) - filtered_experts = torch.where( - mask, local_experts_flat, - torch.full_like(local_experts_flat, - num_experts)).to(topk_ids.dtype) - - # Sort by local expert IDs - sort_indices = torch.argsort(filtered_experts) - sorted_token_indices = token_indices[sort_indices] - sorted_weights = filtered_weights[sort_indices] - - # Compute token counts with minlength of num_experts - # This is equivalent to but faster than: - # >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1] - token_counts = torch.zeros(num_experts + 1, - device=device, - dtype=torch.int64) - ones = torch.ones_like(filtered_experts, dtype=torch.int64) - token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones) - expert_tokens = token_counts[:num_experts] - # Rearrange hidden_states - hidden_states = hidden_states[sorted_token_indices] - group_list_type = 1 - else: - hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing( - hidden_states, - row_idx=row_idx, - expert_idx=topk_ids, - active_num=num_tokens) - - expert_tokens = torch_npu.npu_moe_compute_expert_tokens( - expanded_expert_idx, num_experts) - expert_tokens = expert_tokens.to(torch.int64) - group_list_type = 0 - - # `hidden_states` will be disposed in the `apply_mlp` function - hidden_states = apply_mlp(hidden_states, - w1, - w1_scale, - w2, - w2_scale, - expert_tokens, - group_list_type=group_list_type) - - if expert_map is not None: - hidden_states.mul_(sorted_weights.unsqueeze(1)) - final_hidden_states = torch.zeros(*original_shape, - device=device, - dtype=dtype) - - num_valid_tokens = mask.sum() - valid_token_mask = torch.arange( - 0, sorted_token_indices.shape[0], - device=device).unsqueeze(1) < num_valid_tokens - hidden_states = hidden_states.masked_fill_(~valid_token_mask, - 0).to(dtype) - final_hidden_states.index_add_(0, 
sorted_token_indices, hidden_states) - else: - # TODO: Reorder device memory 2 times here, replace the current - # implementation here when suitable operators become available. - final_hidden_states = torch_npu.npu_moe_finalize_routing( - hidden_states, - skip1=None, - skip2=None, - bias=None, - scales=topk_weights, - expanded_src_to_dst_row=expanded_row_idx, - export_for_source_row=topk_ids, - ) - - if len(original_shape) == 3: - final_hidden_states = final_hidden_states.view(original_shape) - return final_hidden_states - - class AscendW8A8DynamicLinearMethod: """Linear method for Ascend W8A8_DYNAMIC. """ @@ -905,68 +390,23 @@ class AscendW8A8DynamicFusedMoEMethod: topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) topk_weights = topk_weights.to(x.dtype) - if fused_moe_state == FusedMoEState.AllGatherEP: - return fused_experts_with_allgather( - hidden_states=x, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map) - elif fused_moe_state == FusedMoEState.MC2: - return fused_experts_with_mc2( - hidden_states=x, - w1=layer.w13_weight, - w2=layer.w2_weight, - w1_scale=layer.w13_weight_scale_fp32, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - top_k=top_k, - expert_map=expert_map, - moe_all_to_all_group_name=self.moe_all_to_all_group_name, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - shared_experts=shared_experts, - mc2_mask=kwargs.get("mc2_mask", None), - shared_gate_up=shared_gate_up, - shared_dequant_scale=shared_dequant_scale) - elif fused_moe_state in [ - FusedMoEState.AllGather, FusedMoEState.NaiveMulticast - ]: - return fused_experts(hidden_states=x, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - row_idx=row_idx, 
- top_k=top_k, - expert_map=expert_map) - else: - # The current implementation of deepseek moe splits hidden_states - # according to tp_size before they are feed into layers module. - # Therefore, all2all is needed no matter how dp/tp is set so as to - # dispatch/combine tokens. - return fused_experts_with_all2all( - hidden_states=x, - w1=layer.w13_weight, - w1_scale=layer.w13_weight_scale, - w2=layer.w2_weight, - w2_scale=layer.w2_weight_scale, - topk_weights=topk_weights, - topk_ids=topk_ids, - row_idx=row_idx, - top_k=top_k, - expert_map=expert_map, - ep_group=self.ep_group, - log2phy=log2phy, - global_redundant_expert_num=global_redundant_expert_num, - ) + + return unified_fused_experts_eager( + hidden_states=x, + w1=layer.w13_weight, + w1_scale=layer.w13_weight_scale, + w2=layer.w2_weight, + w2_scale=layer.w2_weight_scale, + topk_weights=topk_weights, + topk_ids=topk_ids, + row_idx=row_idx, + expert_map=expert_map, + log2phy=log2phy, + global_redundant_expert_num=global_redundant_expert_num, + shared_experts=shared_experts, + shared_gate_up=shared_gate_up, + shared_dequant_scale=shared_dequant_scale, + mc2_mask=kwargs.get("mc2_mask", None)) def process_weights_after_loading(self, layer): if self.transpose_weight: