[Feat] enable hierarchical communication for mc2 ops on A2 (#3015)
Currently, when in A2, setting the environment variables `HCCL_INTRA_PCIE_ENABLE=1` and `HCCL_INTRA_ROCE_ENABLE=0` can reduce cross-machine communication traffic and significantly improve communication performance. For more details, please refer to [document](https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_moe_distribute_dispatch_v2.md) - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -92,7 +92,8 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
expert_map = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7])
|
||||
|
||||
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
||||
return_value=(torch.randn(10, 128), ) * 5) as mock_dispatch:
|
||||
return_value=(torch.randn(10, 128), ) * 5 +
|
||||
(None, None)) as mock_dispatch:
|
||||
output = self.dispatcher.token_dispatch(hidden_states,
|
||||
topk_weights, topk_ids,
|
||||
self.row_idx, expert_map)
|
||||
@@ -112,7 +113,7 @@ class TestTokenDispatcherWithMC2(TestBase):
|
||||
self.topk_weights = torch.randn(10, 1)
|
||||
|
||||
with patch("torch_npu.npu_moe_distribute_dispatch_v2",
|
||||
return_value=(torch.randn(10, 128), ) * 5):
|
||||
return_value=(torch.randn(10, 128), ) * 5 + (None, None)):
|
||||
self.dispatcher.token_dispatch(self.hidden_states,
|
||||
self.topk_weights,
|
||||
torch.randint(0, 8, (10, 1)),
|
||||
|
||||
@@ -3,8 +3,9 @@ from unittest.mock import MagicMock, patch
|
||||
import torch
|
||||
|
||||
from tests.ut.base import TestBase
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import \
|
||||
torchair_fused_experts_with_all2all
|
||||
from vllm_ascend.torchair.quantization.torchair_w8a8_dynamic import (
|
||||
torchair_fused_experts_with_all2all, torchair_fused_experts_with_mc2)
|
||||
from vllm_ascend.utils import AscendSocVersion
|
||||
|
||||
|
||||
class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
@@ -73,3 +74,57 @@ class TestAscendW8A8FusedMoEMethod(TestBase):
|
||||
self.assertIsNotNone(result)
|
||||
self.assertEqual(result.dtype, torch.bfloat16)
|
||||
self.assertEqual(result.shape, (128, 128))
|
||||
|
||||
@patch.dict('os.environ', {
|
||||
'HCCL_INTRA_ROCE_ENABLE': '0',
|
||||
'HCCL_INTRA_PCIE_ENABLE': '1'
|
||||
})
|
||||
@patch(
|
||||
"vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_ascend_soc_version"
|
||||
)
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.get_mc2_group'
|
||||
)
|
||||
@patch('torch_npu.npu_moe_distribute_combine_v2')
|
||||
@patch('torch_npu.npu_moe_distribute_dispatch_v2')
|
||||
@patch(
|
||||
'vllm_ascend.torchair.quantization.torchair_w8a8_dynamic.torchair_apply_mlp_decode'
|
||||
)
|
||||
def test_torchair_fused_experts_with_mc2_a2_optimization(
|
||||
self, mock_mlp_decode, mock_dispatch, mock_combine, mock_get_group,
|
||||
mock_ascend_soc_version):
|
||||
"""Test expert_scales is passed in A2 SOC version with mc2 optimization"""
|
||||
# Setup mocks
|
||||
mock_ascend_soc_version.return_value = AscendSocVersion.A2
|
||||
|
||||
mock_group = MagicMock()
|
||||
mock_group.rank_in_group = 0
|
||||
mock_group.world_size = 4
|
||||
mock_get_group.return_value = mock_group
|
||||
|
||||
mock_combine.return_value = self.placeholder
|
||||
|
||||
mock_dispatch.return_value = (torch.randn(32, 1024), torch.randn(1),
|
||||
torch.randint(0, 32, (32, )),
|
||||
torch.randint(1, 5, (8, )),
|
||||
torch.randint(1, 5, (4, )), None,
|
||||
torch.randn(32))
|
||||
mock_mlp_decode.return_value = self.placeholder
|
||||
|
||||
result = torchair_fused_experts_with_mc2(
|
||||
hidden_states=self.placeholder,
|
||||
w1=self.placeholder,
|
||||
w2=self.placeholder,
|
||||
w1_scale=self.placeholder,
|
||||
w2_scale=self.placeholder,
|
||||
topk_weights=self.placeholder,
|
||||
topk_ids=self.placeholder,
|
||||
top_k=2,
|
||||
mc2_mask=self.placeholder)
|
||||
|
||||
# Check that expert_scales was passed to dispatch
|
||||
call_args = mock_dispatch.call_args[1]
|
||||
self.assertIn('expert_scales', call_args)
|
||||
|
||||
self.assertIsInstance(result, torch.Tensor)
|
||||
self.assertEqual(result.shape, self.placeholder.shape)
|
||||
|
||||
@@ -20,7 +20,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -31,7 +30,8 @@ from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.moe.comm_utils import (
|
||||
async_all_to_all, gather_from_sequence_parallel_region)
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
|
||||
is_hierarchical_communication_enabled)
|
||||
|
||||
|
||||
class MoETokenDispatcher(ABC):
|
||||
@@ -99,6 +99,10 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
self.a3_need_extra_args = \
|
||||
get_ascend_soc_version() == AscendSocVersion.A3
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
self.need_expert_scale = is_hierarchical_communication_enabled()
|
||||
self.output = None
|
||||
self.assist_info_for_combine = None
|
||||
self.ep_recv_counts = None
|
||||
@@ -108,6 +112,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
self.shared_experts = None
|
||||
self.mc2_mask = None
|
||||
self.with_quant = False
|
||||
self.expand_scales = None
|
||||
|
||||
def get_dispatch_mc2_kwargs(
|
||||
self,
|
||||
@@ -153,6 +158,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
if self.need_expert_scale:
|
||||
stage1_kwargs.update({
|
||||
"expert_scales":
|
||||
topk_weights.to(torch.float32),
|
||||
})
|
||||
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
return kwargs_mc2
|
||||
@@ -186,8 +196,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, self.assist_info_for_combine, \
|
||||
expert_token_nums, self.ep_recv_counts = self.output[0:5]
|
||||
expand_x, dynamic_scale, self.assist_info_for_combine, expert_token_nums, \
|
||||
self.ep_recv_counts, _, self.expand_scales = self.output[0:7]
|
||||
|
||||
if self.with_quant:
|
||||
if shared_experts is not None:
|
||||
@@ -240,6 +250,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"group_ep": self.moe_all_to_all_group_name,
|
||||
"ep_world_size": self.ep_world_size,
|
||||
"ep_rank_id": self.ep_rank_id,
|
||||
"expand_scales": self.expand_scales,
|
||||
}
|
||||
if self.enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
@@ -281,6 +292,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
self.topk_weights = None
|
||||
self.mc2_mask = None
|
||||
self.expert_map = None
|
||||
self.expand_scales = None
|
||||
|
||||
if self.shared_experts is None:
|
||||
return hidden_states
|
||||
|
||||
@@ -51,7 +51,8 @@ from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||
get_all_reduce_merge_state,
|
||||
get_ascend_soc_version,
|
||||
get_rm_router_logits_state, is_310p)
|
||||
get_rm_router_logits_state, is_310p,
|
||||
is_hierarchical_communication_enabled)
|
||||
|
||||
|
||||
def torchair_fused_experts_with_mc2(
|
||||
@@ -78,6 +79,10 @@ def torchair_fused_experts_with_mc2(
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
need_expert_scale = is_hierarchical_communication_enabled()
|
||||
|
||||
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
|
||||
|
||||
@@ -108,6 +113,10 @@ def torchair_fused_experts_with_mc2(
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
})
|
||||
if need_expert_scale:
|
||||
stage1_kwargs.update({
|
||||
"expert_scales": topk_weights.to(torch.float32),
|
||||
})
|
||||
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
|
||||
@@ -116,8 +125,8 @@ def torchair_fused_experts_with_mc2(
|
||||
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
|
||||
0:5]
|
||||
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
|
||||
ep_recv_counts, _, expand_scales = output[0:7]
|
||||
|
||||
if shared_experts is not None:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
@@ -167,6 +176,7 @@ def torchair_fused_experts_with_mc2(
|
||||
"group_ep": moe_all_to_all_group_name,
|
||||
"ep_world_size": ep_world_size,
|
||||
"ep_rank_id": ep_rank_id,
|
||||
"expand_scales": expand_scales,
|
||||
}
|
||||
if enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@@ -29,7 +28,8 @@ from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.torchair.ops.torchair_fused_moe import torchair_select_experts
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
|
||||
dispose_tensor, get_ascend_soc_version)
|
||||
dispose_tensor, get_ascend_soc_version,
|
||||
is_hierarchical_communication_enabled)
|
||||
|
||||
|
||||
def torchair_apply_mlp_decode(hidden_states: torch.Tensor,
|
||||
@@ -237,6 +237,10 @@ def torchair_fused_experts_with_mc2(
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
need_expert_scale = is_hierarchical_communication_enabled()
|
||||
|
||||
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
|
||||
|
||||
@@ -271,6 +275,10 @@ def torchair_fused_experts_with_mc2(
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
})
|
||||
if need_expert_scale:
|
||||
stage1_kwargs.update({
|
||||
"expert_scales": topk_weights.to(torch.float32),
|
||||
})
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
|
||||
output = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||
@@ -278,8 +286,8 @@ def torchair_fused_experts_with_mc2(
|
||||
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
|
||||
0:5]
|
||||
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
|
||||
ep_recv_counts, _, expand_scales = output[0:7]
|
||||
|
||||
if shared_experts is not None:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
@@ -327,6 +335,7 @@ def torchair_fused_experts_with_mc2(
|
||||
"group_ep": moe_all_to_all_group_name,
|
||||
"ep_world_size": ep_world_size,
|
||||
"ep_rank_id": ep_rank_id,
|
||||
"expand_scales": expand_scales,
|
||||
}
|
||||
if enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
|
||||
@@ -666,7 +666,7 @@ def get_hccl_config_for_pg_options(group_name: str) -> Optional[dict]:
|
||||
|
||||
Args:
|
||||
group_name: Name of the communication group
|
||||
|
||||
|
||||
Returns:
|
||||
HCCL pg_options or None for mc2 group
|
||||
"""
|
||||
@@ -689,7 +689,7 @@ def get_default_buffer_config() -> dict:
|
||||
|
||||
def calculate_dp_buffer_size() -> int:
|
||||
"""
|
||||
formula of dp buffer size:
|
||||
formula of dp buffer size:
|
||||
dp_size + 2 (flags: with_prefill and enable_dbo)
|
||||
"""
|
||||
from vllm.config import get_current_vllm_config
|
||||
@@ -698,3 +698,11 @@ def calculate_dp_buffer_size() -> int:
|
||||
int32_size = torch.iinfo(torch.int32).bits // 8
|
||||
dp_buffer_size = math.ceil((dp_size + 2) * int32_size / (1024 * 1024))
|
||||
return max(dp_buffer_size, _MIN_DP_BUFFER_SIZE)
|
||||
|
||||
|
||||
# Currently, when in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1
|
||||
# and HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and
|
||||
# significantly improve communication performance of MC2 ops dispatch/combine.
|
||||
def is_hierarchical_communication_enabled():
|
||||
return (os.getenv("HCCL_INTRA_ROCE_ENABLE", "") == "0"
|
||||
and os.getenv("HCCL_INTRA_PCIE_ENABLE", "") == "1")
|
||||
|
||||
Reference in New Issue
Block a user