diff --git a/tests/ut/ops/test_token_dispatcher.py b/tests/ut/ops/test_token_dispatcher.py index ed32b93..aed2b7d 100644 --- a/tests/ut/ops/test_token_dispatcher.py +++ b/tests/ut/ops/test_token_dispatcher.py @@ -92,7 +92,8 @@ class TestTokenDispatcherWithMC2(TestBase): expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) with patch("torch_npu.npu_moe_distribute_dispatch_v2", - return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch: + return_value=(torch.randn(10, 128), ) * 5 + + (None, None)) as mock_dispatch: output = self.dispatcher.token_dispatch(hidden_states, topk_weights, topk_ids, self.row_idx, expert_map) @@ -112,7 +113,7 @@ class TestTokenDispatcherWithMC2(TestBase): self.topk_weights = torch.randn(10, 1) with patch("torch_npu.npu_moe_distribute_dispatch_v2", - return_value=(torch.randn(10, 128), ) * 5): + return_value=(torch.randn(10, 128), ) * 5 + (None, None)): self.dispatcher.token_dispatch(self.hidden_states, self.topk_weights, torch.randint(0, 8, (10, 1)), diff --git a/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py b/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py index 520155d..09b5aa3 100644 --- a/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py +++ b/tests/ut/torchair/quantization/test_torchair_w8a8_dynamic.py @@ -3,8 +3,9 @@ from unittest.mock import MagicMock, patch import torch from tests.ut.base import TestBase -from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \ - torchair_fused_experts_with_all2all +from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import ( + torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2) +from vllm_ascend.utils import AscendSocVersion class TestAscendW8A8FusedMoEMethod(TestBase): @@ -73,3 +74,57 @@ class TestAscendW8A8FusedMoEMethod(TestBase): self.assertIsNotNone(result) self.assertEqual(result.dtype, torch.bfloat16) self.assertEqual(result.shape, (128, 128)) + + @patch.dict('os.environ', { + 'HCCL_INTRA_ROCE_ENABLE': '0', + 
'HCCL_INTRA_PCIE_ENABLE': '1' + }) + @patch( + "vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_ascend_soc_version" + ) + @patch( + 'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_mc2_group' + ) + @patch('torch_npu.npu_moe_distribute_combine_v2') + @patch('torch_npu.npu_moe_distribute_dispatch_v2') + @patch( + 'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.torchair_apply_mlp_decode' + ) + def test_torchair_fused_experts_with_mc2_a2_optimization( + self, mock_mlp_decode, mock_dispatch, mock_combine, mock_get_group, + mock_ascend_soc_version): + """Test expert_scales is passed in A2 SOC version with mc2 optimization""" + # Setup mocks + mock_ascend_soc_version.return_value = AscendSocVersion.A2 + + mock_group = MagicMock() + mock_group.rank_in_group = 0 + mock_group.world_size = 4 + mock_get_group.return_value = mock_group + + mock_combine.return_value = self.placeholder + + mock_dispatch.return_value = (torch.randn(32, 1024), torch.randn(1), + torch.randint(0, 32, (32, )), + torch.randint(1, 5, (8, )), + torch.randint(1, 5, (4, )), None, + torch.randn(32)) + mock_mlp_decode.return_value = self.placeholder + + result = torchair_fused_experts_with_mc2( + hidden_states=self.placeholder, + w1=self.placeholder, + w2=self.placeholder, + w1_scale=self.placeholder, + w2_scale=self.placeholder, + topk_weights=self.placeholder, + topk_ids=self.placeholder, + top_k=2, + mc2_mask=self.placeholder) + + # Check that expert_scales was passed to dispatch + call_args = mock_dispatch.call_args[1] + self.assertIn('expert_scales', call_args) + + self.assertIsInstance(result, torch.Tensor) + self.assertEqual(result.shape, self.placeholder.shape) diff --git a/vllm_ascend/ops/moe/token_dispatcher.py b/vllm_ascend/ops/moe/token_dispatcher.py index e7bb9a3..6cb97c3 100644 --- a/vllm_ascend/ops/moe/token_dispatcher.py +++ b/vllm_ascend/ops/moe/token_dispatcher.py @@ -20,7 +20,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. - from abc import ABC, abstractmethod from typing import Any, Optional @@ -31,7 +30,8 @@ from vllm.distributed.parallel_state import get_ep_group from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.ops.moe.comm_utils import ( async_all_to_all, gather_from_sequence_parallel_region) -from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version +from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version, + is_hierarchical_communication_enabled) class MoETokenDispatcher(ABC): @@ -99,6 +99,10 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine self.a3_need_extra_args = \ get_ascend_soc_version() == AscendSocVersion.A3 + # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and + # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly + # improve communication performance. 
+ self.need_expert_scale = is_hierarchical_communication_enabled() self.output = None self.assist_info_for_combine = None self.ep_recv_counts = None @@ -108,6 +112,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): self.shared_experts = None self.mc2_mask = None self.with_quant = False + self.expand_scales = None def get_dispatch_mc2_kwargs( self, @@ -153,6 +158,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): stage1_kwargs.update({ "x_active_mask": self.mc2_mask, }) + if self.need_expert_scale: + stage1_kwargs.update({ + "expert_scales": + topk_weights.to(torch.float32), + }) kwargs_mc2.update(stage1_kwargs) return kwargs_mc2 @@ -186,8 +196,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): ) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( **kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, self.assist_info_for_combine, \ - expert_token_nums, self.ep_recv_counts = self.output[0:5] + expand_x, dynamic_scale, self.assist_info_for_combine, expert_token_nums, \ + self.ep_recv_counts, _, self.expand_scales = self.output[0:7] if self.with_quant: if shared_experts is not None: @@ -240,6 +250,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): "group_ep": self.moe_all_to_all_group_name, "ep_world_size": self.ep_world_size, "ep_rank_id": self.ep_rank_id, + "expand_scales": self.expand_scales, } if self.enable_dispatch_v2: stage3_kwargs.update({ @@ -281,6 +292,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher): self.topk_weights = None self.mc2_mask = None self.expert_map = None + self.expand_scales = None if self.shared_experts is None: return hidden_states diff --git a/vllm_ascend/torchair/ops/torchair_fused_moe.py b/vllm_ascend/torchair/ops/torchair_fused_moe.py index 56843fb..e7d0093 100644 --- a/vllm_ascend/torchair/ops/torchair_fused_moe.py +++ b/vllm_ascend/torchair/ops/torchair_fused_moe.py @@ -51,7 +51,8 @@ from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor 
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor, get_all_reduce_merge_state, get_ascend_soc_version, - get_rm_router_logits_state, is_310p) + get_rm_router_logits_state, is_310p, + is_hierarchical_communication_enabled) def torchair_fused_experts_with_mc2( @@ -78,6 +79,10 @@ def torchair_fused_experts_with_mc2( # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 + # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and + # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly + # improve communication performance. + need_expert_scale = is_hierarchical_communication_enabled() enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") @@ -108,6 +113,10 @@ def torchair_fused_experts_with_mc2( stage1_kwargs.update({ "x_active_mask": mc2_mask, }) + if need_expert_scale: + stage1_kwargs.update({ + "expert_scales": topk_weights.to(torch.float32), + }) kwargs_mc2.update(stage1_kwargs) @@ -116,8 +125,8 @@ def torchair_fused_experts_with_mc2( ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( **kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[ - 0:5] + expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \ + ep_recv_counts, _, expand_scales = output[0:7] if shared_experts is not None: with npu_stream_switch("moe_secondary", 0): @@ -167,6 +176,7 @@ def torchair_fused_experts_with_mc2( "group_ep": moe_all_to_all_group_name, "ep_world_size": ep_world_size, "ep_rank_id": ep_rank_id, + "expand_scales": expand_scales, } if enable_dispatch_v2: stage3_kwargs.update({ diff --git a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py index 1825b2b..b027b2f 100644 --- 
a/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py +++ b/vllm_ascend/torchair/quantization/torchair_w8a8_dynamic.py @@ -14,7 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # - from typing import Any, Callable, Dict, Optional, Tuple, Union import torch @@ -29,7 +28,8 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion, - dispose_tensor, get_ascend_soc_version) + dispose_tensor, get_ascend_soc_version, + is_hierarchical_communication_enabled) def torchair_apply_mlp_decode(hidden_states: torch.Tensor, @@ -237,6 +237,10 @@ def torchair_fused_experts_with_mc2( # NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3 + # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and + # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly + # improve communication performance. 
+ need_expert_scale = is_hierarchical_communication_enabled() enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2") @@ -271,6 +275,10 @@ def torchair_fused_experts_with_mc2( stage1_kwargs.update({ "x_active_mask": mc2_mask, }) + if need_expert_scale: + stage1_kwargs.update({ + "expert_scales": topk_weights.to(torch.float32), + }) kwargs_mc2.update(stage1_kwargs) output = torch_npu.npu_moe_distribute_dispatch_v2( @@ -278,8 +286,8 @@ def torchair_fused_experts_with_mc2( ) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( **kwargs_mc2) # comm_stream.wait_stream(torch.npu.current_stream()) - expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[ - 0:5] + expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \ + ep_recv_counts, _, expand_scales = output[0:7] if shared_experts is not None: with npu_stream_switch("moe_secondary", 0): @@ -327,6 +335,7 @@ def torchair_fused_experts_with_mc2( "group_ep": moe_all_to_all_group_name, "ep_world_size": ep_world_size, "ep_rank_id": ep_rank_id, + "expand_scales": expand_scales, } if enable_dispatch_v2: stage3_kwargs.update({ diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index 8382cf7..e6076b9 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -666,7 +666,7 @@ def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]: Args: group_name: Name of the communication group - + Returns: HCCL pg_options or None for mc2 group """ @@ -689,7 +689,7 @@ def get_default_buffer_config() -> dict: def calculate_dp_buffer_size() -> int: """ - formula of dp buffer size: + formula of dp buffer size: dp_size + 2 (flags: with_prefill and enable_dbo) """ from vllm.config import get_current_vllm_config @@ -698,3 +698,11 @@ def calculate_dp_buffer_size() -> int: int32_size = torch.iinfo(torch.int32).bits // 8 dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024)) return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE) + + +# 
Currently, when in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 +# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly improve +# communication performance of MC2 ops dispatch/combine. Returns True only when both variables are set to exactly these values. +def is_hierarchical_communication_enabled() -> bool: + return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0" + and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")