[feat] enable hierarchical mc2 ops on A2 by default (#5300)
### What this PR does / why we need it?
Previously, users had to manually set the environment variables
HCCL_INTRA_PCIE_ENABLE=1 and HCCL_INTRA_ROCE_ENABLE=0 to use hierarchical
MC2 operations. This PR enables hierarchical MC2 operations on A2 by default.
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: release/v0.13.0
- vLLM main:
ad32e3e19c
Signed-off-by: hwhaokun <haokun0405@163.com>
Co-authored-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -31,8 +31,7 @@ from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe.comm_utils import (
|
||||
async_all_to_all, gather_from_sequence_parallel_region)
|
||||
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
|
||||
is_hierarchical_communication_enabled)
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||
|
||||
|
||||
class MoETokenDispatcher(ABC):
|
||||
@@ -101,10 +100,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
self.need_extra_args = (
|
||||
get_ascend_device_type() == AscendDeviceType.A3)
|
||||
|
||||
# NOTE: On A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
self.need_expert_scale = is_hierarchical_communication_enabled()
|
||||
self.with_quant = False
|
||||
|
||||
# Here we need to calculate the global_bs = max_bs_per_rank * ep_world_size to execute
|
||||
@@ -142,6 +137,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
else:
|
||||
quant_mode = 0
|
||||
moe_expert_num = len(expert_map)
|
||||
|
||||
kwargs_mc2 = {
|
||||
"x": hidden_states,
|
||||
"expert_ids": topk_ids,
|
||||
@@ -150,8 +146,12 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": self.global_bs,
|
||||
"expert_token_nums_type": 0,
|
||||
"expert_scales": topk_weights.to(torch.float32),
|
||||
}
|
||||
|
||||
if get_ascend_device_type() == AscendDeviceType.A2:
|
||||
kwargs_mc2["comm_alg"] = "hierarchy"
|
||||
|
||||
stage1_kwargs = {
|
||||
"scales": None,
|
||||
"quant_mode": quant_mode,
|
||||
@@ -165,11 +165,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if self.need_expert_scale:
|
||||
stage1_kwargs.update({
|
||||
"expert_scales":
|
||||
topk_weights.to(torch.float32),
|
||||
})
|
||||
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
return kwargs_mc2
|
||||
@@ -269,8 +264,12 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": self.global_bs,
|
||||
"expand_scales": expand_scales,
|
||||
}
|
||||
|
||||
if get_ascend_device_type() == AscendDeviceType.A2:
|
||||
kwargs_mc2["comm_alg"] = "hierarchy"
|
||||
|
||||
if self.with_quant:
|
||||
tp_recv_counts = torch.empty(1,
|
||||
dtype=torch.int32,
|
||||
@@ -281,7 +280,6 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"group_ep": self.moe_all_to_all_group_name,
|
||||
"ep_world_size": self.ep_world_size,
|
||||
"ep_rank_id": self.ep_rank_id,
|
||||
"expand_scales": expand_scales,
|
||||
}
|
||||
|
||||
if self.enable_dispatch_v2:
|
||||
|
||||
Reference in New Issue
Block a user