[Feat] enable hierarchical communication for mc2 ops on A2 (#3015)
Currently, when in A2, setting the environment variables `HCCL_INTRA_PCIE_ENABLE=1` and `HCCL_INTRA_ROCE_ENABLE=0` can reduce cross-machine communication traffic and significantly improve communication performance. For more details, please refer to [document](https://www.hiascend.com/document/detail/zh/Pytorch/710/apiref/torchnpuCustomsapi/context/torch_npu-npu_moe_distribute_dispatch_v2.md) - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: realliujiaxu <realliujiaxu@163.com>
This commit is contained in:
@@ -20,7 +20,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Optional
|
||||
|
||||
@@ -31,7 +30,8 @@ from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.moe.comm_utils import (
|
||||
async_all_to_all, gather_from_sequence_parallel_region)
|
||||
from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version
|
||||
from vllm_ascend.utils import (AscendSocVersion, get_ascend_soc_version,
|
||||
is_hierarchical_communication_enabled)
|
||||
|
||||
|
||||
class MoETokenDispatcher(ABC):
|
||||
@@ -99,6 +99,10 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
self.a3_need_extra_args = \
|
||||
get_ascend_soc_version() == AscendSocVersion.A3
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
# improve communication performance.
|
||||
self.need_expert_scale = is_hierarchical_communication_enabled()
|
||||
self.output = None
|
||||
self.assist_info_for_combine = None
|
||||
self.ep_recv_counts = None
|
||||
@@ -108,6 +112,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
self.shared_experts = None
|
||||
self.mc2_mask = None
|
||||
self.with_quant = False
|
||||
self.expand_scales = None
|
||||
|
||||
def get_dispatch_mc2_kwargs(
|
||||
self,
|
||||
@@ -153,6 +158,11 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": self.mc2_mask,
|
||||
})
|
||||
if self.need_expert_scale:
|
||||
stage1_kwargs.update({
|
||||
"expert_scales":
|
||||
topk_weights.to(torch.float32),
|
||||
})
|
||||
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
return kwargs_mc2
|
||||
@@ -186,8 +196,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, self.assist_info_for_combine, \
|
||||
expert_token_nums, self.ep_recv_counts = self.output[0:5]
|
||||
expand_x, dynamic_scale, self.assist_info_for_combine, expert_token_nums, \
|
||||
self.ep_recv_counts, _, self.expand_scales = self.output[0:7]
|
||||
|
||||
if self.with_quant:
|
||||
if shared_experts is not None:
|
||||
@@ -240,6 +250,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"group_ep": self.moe_all_to_all_group_name,
|
||||
"ep_world_size": self.ep_world_size,
|
||||
"ep_rank_id": self.ep_rank_id,
|
||||
"expand_scales": self.expand_scales,
|
||||
}
|
||||
if self.enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
@@ -281,6 +292,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
self.topk_weights = None
|
||||
self.mc2_mask = None
|
||||
self.expert_map = None
|
||||
self.expand_scales = None
|
||||
|
||||
if self.shared_experts is None:
|
||||
return hidden_states
|
||||
|
||||
Reference in New Issue
Block a user