diff --git a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
index fed0232c..7051349b 100644
--- a/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
+++ b/csrc/dispatch_ffn_combine/op_kernel/dispatch_ffn_combine_kernel.hpp
@@ -779,10 +779,10 @@ private:
         AscendC::GlobalTensor<int32_t> ExpertTokenNums;
         ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
-        AscendC::GlobalTensor<int32_t> LcalCumsumMM;
-        LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
-        CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
-        AscendC::SyncAll();
+        if(coreIdx == 0)
+        {
+            CopyGMToGM(ExpertTokenNums, cumsumMM[(params.EP - 1) * params.expertPerRank], params.expertPerRank, params.ubMoveNum);
+        }
 
         uint16_t syncgmm1Idx = 0;
         AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(syncgmm1Idx / CROSS_CORE_FLAG_MAX_SET_COUNT);
         syncgmm1Idx++;
@@ -921,11 +921,11 @@ private:
         AscendC::LocalTensor statusTensor = resource.ubBuf.template GetBufferByByte(uboffset);
         uboffset += sendRankNum_ * UB_ALIGN;
         AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(uboffset);
-        uboffset += params.EP * sizeof(float);
+        uboffset += AlignUp(params.EP * sizeof(float), 32);
         AscendC::LocalTensor gatherTmpTensor = resource.ubBuf.template GetBufferByByte(uboffset);
-        uboffset += sizeof(uint32_t);
+        uboffset += AlignUp(sizeof(uint32_t), 32);
         AscendC::LocalTensor statusSumOutTensor = resource.ubBuf.template GetBufferByByte(uboffset);
-        uboffset += sizeof(float);
+        uboffset += AlignUp(sizeof(float), 32);
         shmem.CrossRankSyncV2Wait(statusTensor, gatherMaskOutTensor, gatherTmpTensor, statusSumOutTensor);
         MoeTokenUnpermuteTilingData tilingData;
         MoeTokenUnpermuteTiling(params.problemShape.m() * params.topK, n2, params.topK, tilingData, coreNum / 2);
diff --git a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp
index 231e9f72..51e939be 100644
--- a/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp
+++ b/csrc/dispatch_ffn_combine_bf16/op_kernel/dispatch_ffn_combine_bf16_kernel.hpp
@@ -756,8 +756,9 @@ CATLASS_DEVICE
         ExpertTokenNums.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(params.ptrExpertTokenNums));
         AscendC::GlobalTensor<int32_t> LcalCumsumMM;
         LcalCumsumMM.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t*>(workspaceInfo.ptrcumsumMM + (params.EP - 1) * params.expertPerRank * sizeof(int32_t)));
-        CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
-        AscendC::SyncAll();
+        if (coreIdx == 0) {
+            CopyGMToGM(ExpertTokenNums, LcalCumsumMM, params.expertPerRank, params.ubMoveNum);
+        }
 
         uint32_t curGroupOffset = 0;
         int32_t prevSumBeforeRank = 0;
diff --git a/vllm_ascend/ops/fused_moe/moe_comm_method.py b/vllm_ascend/ops/fused_moe/moe_comm_method.py
index 7ca34e74..9b27ed14 100644
--- a/vllm_ascend/ops/fused_moe/moe_comm_method.py
+++ b/vllm_ascend/ops/fused_moe/moe_comm_method.py
@@ -277,6 +277,13 @@ class FusedMC2CommImpl(MoECommMethod):
     Communication and Computation parallelism on Ascend devices.
""" + def __init__(self, moe_config): + super().__init__(moe_config) + if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: + self.expert_token_nums = torch.zeros([self.moe_config.num_local_experts], dtype=torch.int32, device="npu") + else: + self.expert_token_nums = None + def _get_token_dispatcher(self): return TokenDispatcherWithMC2() @@ -325,7 +332,6 @@ class FusedMC2CommImpl(MoECommMethod): expert_tokens = None if envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 1: out = torch.empty_like(hidden_states) - expert_token_nums = torch.zeros([self.moe_config.num_local_experts], dtype=torch.int32) torch.ops._C_ascend.dispatch_ffn_combine( # type: ignore x=hidden_states, weight1=w1, @@ -337,9 +343,9 @@ class FusedMC2CommImpl(MoECommMethod): group=self.token_dispatcher.moe_all_to_all_group_name, max_output_size=65536, out=out, - expert_token_nums=expert_token_nums, + expert_token_nums=self.expert_token_nums, ) - expert_tokens = expert_token_nums + expert_tokens = self.expert_token_nums elif envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2: assert expert_map is not None, "expert_map cannot be None." group_list_type = 1