[0.18.0][cherry-pick][BugFix]Fix compilation errors for operators dispatch_gmm_combine_decode/moe_combine_normal/moe_dispatch_normal (#7844)
**What this PR does / why we need it?** Cherry-picked from https://github.com/vllm-project/vllm-ascend/pull/7114. Fixes compilation errors encountered when building versions later than b020 for the following operators: dispatch_gmm_combine_decode, moe_combine_normal, moe_dispatch_normal. **Root Cause** After the b020 version update, the original moe_distribute_base.h file was updated and its definitions changed, which caused compilation failures for the above three operators, all of which depend on this file. **Solution** A dedicated copy of moe_distribute_base.h has been added to the implementation of these three operators, so that they compile stably regardless of framework version updates. **Does this PR introduce any user-facing change?** No. There are no user-facing changes; this fix only resolves compilation issues without affecting functionality or user behavior. **How was this patch tested?** vLLM version: releases/v0.18.0 Signed-off-by: Wangyibo1005 <2633333316@qq.com>
This commit is contained in:
@@ -39,12 +39,6 @@ elif [[ "$SOC_VERSION" =~ ^ascend910_93 ]]; then
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
# dependency: cann-toolkit file moe_distribute_base.h
|
||||
HCCL_STRUCT_FILE_PATH=$(find -L "${ASCEND_TOOLKIT_HOME}" -name "moe_distribute_base.h" 2>/dev/null | head -n1)
|
||||
if [ -z "$HCCL_STRUCT_FILE_PATH" ]; then
|
||||
echo "cannot find moe_distribute_base.h file in CANN env"
|
||||
exit 1
|
||||
fi
|
||||
# for dispatch_gmm_combine_decode
|
||||
yes | cp "${HCCL_STRUCT_FILE_PATH}" "${ROOT_DIR}/csrc/utils/inc/kernel"
|
||||
|
||||
|
||||
@@ -47,6 +47,7 @@ target_sources(optiling PRIVATE
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_SOURCE_DIR}/utils/inc/kernel
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
#ifndef DISPATCH_GMM_COMBINE_DECODE_BASE_H
|
||||
#define DISPATCH_GMM_COMBINE_DECODE_BASE_H
|
||||
|
||||
#include "../common/moe_distribute_base.h"
|
||||
#include "moe_distribute_base.h"
|
||||
|
||||
#define TemplateMC2TypeClass typename ExpandXType, typename W1ScaleType, typename W2ScaleType, typename WType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
|
||||
#define TemplateMC2TypeFunc ExpandXType, W1ScaleType, W2ScaleType, WType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
|
||||
|
||||
@@ -38,6 +38,7 @@ target_sources(optiling PRIVATE
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_SOURCE_DIR}/utils/inc/kernel
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "../common/moe_distribute_base.h"
|
||||
#include "moe_distribute_base.h"
|
||||
#include "moe_combine_normal_tiling.h"
|
||||
|
||||
namespace MoeCombineNormalImpl {
|
||||
|
||||
@@ -38,6 +38,7 @@ target_sources(optiling PRIVATE
|
||||
|
||||
target_include_directories(optiling PRIVATE
|
||||
${CMAKE_CURRENT_SOURCE_DIR}
|
||||
${CMAKE_SOURCE_DIR}/utils/inc/kernel
|
||||
)
|
||||
|
||||
target_sources(opsproto PRIVATE)
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "kernel_operator.h"
|
||||
#include "kernel_tiling/kernel_tiling.h"
|
||||
#include "../common/moe_distribute_base.h"
|
||||
#include "moe_distribute_base.h"
|
||||
#include "moe_dispatch_normal_tiling.h"
|
||||
|
||||
namespace MoeDispatchNormalImpl {
|
||||
|
||||
288
csrc/utils/inc/kernel/moe_distribute_base.h
Normal file
288
csrc/utils/inc/kernel/moe_distribute_base.h
Normal file
@@ -0,0 +1,288 @@
|
||||
/**
|
||||
* Copyright (c) 2025 Huawei Technologies Co., Ltd.
|
||||
* This program is free software, you can redistribute it and/or modify it under the terms and conditions of
|
||||
* CANN Open Software License Agreement Version 2.0 (the "License").
|
||||
* Please refer to the License for details. You may not use this file except in compliance with the License.
|
||||
* THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
|
||||
* INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
|
||||
* See LICENSE in the root of the software repository for the full text of the License.
|
||||
*/
|
||||
|
||||
/*!
|
||||
* \file moe_distribute_base.h
|
||||
* \brief
|
||||
*/
|
||||
|
||||
#ifndef MOE_DISTRIBUTE_BASE_H
|
||||
#define MOE_DISTRIBUTE_BASE_H
|
||||
|
||||
// Capacity constants used to size the fixed arrays in the resource structs of this header.
// Changing any of them changes the size/layout of the structs that use them.
constexpr uint32_t LOCAL_NOTIFY_MAX_NUM = 64;         // length of LocalResInfoV2::localSignals
constexpr uint32_t LOCAL_STREAM_MAX_NUM = 19;         // length of LocalResInfoV2::streamInfo
constexpr uint32_t AICPU_OP_NOTIFY_MAX_NUM = 2;       // length of LocalResInfoV2::aicpuOpNotify
constexpr uint32_t AICPU_MAX_RANK_NUM = 128 * 1024;   // length of HcclOpResParam::remoteRes
|
||||
|
||||
// Plain-data descriptor of one signal/notify resource (two 64-bit fields followed by four 32-bit fields).
// NOTE(review): presumably mirrors the HCCL-side layout of the same name — keep field order and
// widths unchanged; confirm against the CANN header this file was copied from.
struct HcclSignalInfo {
    uint64_t resId;
    uint64_t addr;
    uint32_t devId;
    uint32_t tsId;
    uint32_t rankId;
    uint32_t flag;
};
|
||||
|
||||
// Four 64-bit link fields (next/previous, each in a "Host" and a "Device" variant).
// NOTE(review): the names suggest doubly-linked-list links stored as raw addresses — confirm
// against the CANN header this file was copied from.
struct ListCommon {
    uint64_t nextHost;
    uint64_t preHost;
    uint64_t nextDevice;
    uint64_t preDevice;
};
|
||||
|
||||
// Identifiers associated with one stream: a signed stream id followed by three unsigned queue ids.
// Note that streamIds is int32_t while the remaining fields are uint32_t.
struct HcclStreamInfo {
    int32_t streamIds;
    uint32_t sqIds;
    uint32_t cqIds;
    uint32_t logicCqids;
};
|
||||
|
||||
struct LocalResInfoV2 {
|
||||
uint32_t streamNum;
|
||||
uint32_t signalNum;
|
||||
HcclSignalInfo localSignals[LOCAL_NOTIFY_MAX_NUM];
|
||||
HcclStreamInfo streamInfo[LOCAL_STREAM_MAX_NUM];
|
||||
HcclStreamInfo mainStreamInfo;
|
||||
HcclSignalInfo aicpuOpNotify[AICPU_OP_NOTIFY_MAX_NUM];
|
||||
ListCommon nextTagRes; // HccltagLocalResV2
|
||||
};
|
||||
|
||||
// Floating-point overflow mode selector. Scoped enum with the default (int) underlying type;
// enumerators take the values 0, 1, 2 in declaration order.
enum class rtFloatOverflowMode_t {
    RT_OVERFLOW_MODE_SATURATION = 0,
    RT_OVERFLOW_MODE_INFNAN,
    RT_OVERFLOW_MODE_UNDEF,
};
|
||||
|
||||
// Topology description block: rank/device identifiers, aggregation and super-pod counts, followed by
// (count, 64-bit value) pairs for RDMA rank pairs, pair-link counters, NICs and rank lists.
// The struct mixes bool, 32-bit and 64-bit members, so it contains compiler-inserted padding;
// do not reorder fields.
// NOTE(review): the 64-bit members that follow a "...Num"/"...Length" field look like addresses of
// external arrays — confirm against the CANN header this file was copied from.
struct AlgoTopoInfo {
    uint32_t userRank;
    uint32_t userRankSize;
    int32_t deviceLogicId;
    bool isSingleMeshAggregation;
    uint32_t deviceNumPerAggregation;
    uint32_t superPodNum;
    uint32_t devicePhyId;
    uint32_t topoType;  // TopoType
    uint32_t deviceType;
    uint32_t serverNum;
    uint32_t meshAggregationRankSize;
    uint32_t multiModuleDiffDeviceNumMode;
    uint32_t multiSuperPodDiffServerNumMode;
    uint32_t realUserRank;
    bool isDiffDeviceModule;
    bool isDiffDeviceType;
    uint32_t gcdDeviceNumPerAggregation;
    uint32_t moduleNum;
    uint32_t isUsedRdmaRankPairNum;
    uint64_t isUsedRdmaRankPair;
    uint32_t pairLinkCounterNum;
    uint64_t pairLinkCounter;
    uint32_t nicNum;
    uint64_t nicList;
    uint64_t complanRankLength;
    uint64_t complanRank;
    uint64_t bridgeRankNum;
    uint64_t bridgeRank;
    uint64_t serverAndsuperPodRankLength;
    uint64_t serverAndsuperPodRank;
};
|
||||
|
||||
struct HcclOpConfig {
|
||||
uint8_t deterministic;
|
||||
uint8_t retryEnable;
|
||||
uint8_t highPerfEnable;
|
||||
uint8_t padding[5];
|
||||
uint8_t linkTimeOut[8];
|
||||
uint64_t notifyWaitTime;
|
||||
uint32_t retryHoldTime;
|
||||
uint32_t retryIntervalTime;
|
||||
bool interHccsDisable = false;
|
||||
rtFloatOverflowMode_t floatOverflowMode = rtFloatOverflowMode_t::RT_OVERFLOW_MODE_UNDEF;
|
||||
uint32_t multiQpThreshold = 512;
|
||||
};
|
||||
|
||||
// Workspace descriptor: a 64-bit workSpace value followed by its 64-bit size.
// NOTE(review): workSpace is presumably a raw address stored as an integer — confirm with the
// code that fills this struct.
struct HcclMC2WorkSpace {
    uint64_t workSpace;
    uint64_t workSpaceSize;
};
|
||||
|
||||
// Pair of 64-bit pointer values (host variant, then device variant) for one remote resource entry.
// HcclOpResParam::remoteRes holds AICPU_MAX_RANK_NUM of these.
struct RemoteResPtr {
    uint64_t nextHostPtr;
    uint64_t nextDevicePtr;
};
|
||||
|
||||
// Parameter block with three 64-bit address fields followed by three 32-bit fields.
// Every member is zero-initialized by its default member initializer.
struct HDCommunicateParams {
    uint64_t hostAddr { 0 };
    uint64_t deviceAddr { 0 };
    uint64_t readCacheAddr { 0 };
    uint32_t devMemSize{ 0 };
    uint32_t buffLen{ 0 };
    uint32_t flag{ 0 };
};
|
||||
|
||||
struct HcclRankRelationResV2 {
|
||||
uint32_t remoteUsrRankId;
|
||||
uint32_t remoteWorldRank;
|
||||
uint64_t windowsIn;
|
||||
uint64_t windowsOut;
|
||||
uint64_t windowsExp;
|
||||
ListCommon nextTagRes;
|
||||
};
|
||||
|
||||
struct HcclOpResParam {
|
||||
HcclMC2WorkSpace mc2WorkSpace;
|
||||
uint32_t localUsrRankId; // usrrankid
|
||||
uint32_t rankSize;
|
||||
uint64_t winSize;
|
||||
uint64_t localWindowsIn;
|
||||
uint64_t localWindowsOut;
|
||||
char hcomId[128];
|
||||
uint64_t winExpSize;
|
||||
uint64_t localWindowsExp;
|
||||
uint32_t rWinStart;
|
||||
uint32_t rWinOffset;
|
||||
uint64_t version;
|
||||
LocalResInfoV2 localRes;
|
||||
AlgoTopoInfo topoInfo;
|
||||
|
||||
HcclOpConfig config;
|
||||
uint64_t hostStateInfo;
|
||||
uint64_t aicpuStateInfo;
|
||||
uint64_t lockAddr;
|
||||
uint32_t rsv[16];
|
||||
uint32_t notifysize;
|
||||
uint32_t remoteResNum;
|
||||
RemoteResPtr remoteRes[AICPU_MAX_RANK_NUM];
|
||||
|
||||
// communicate retry
|
||||
HDCommunicateParams kfcControlTransferH2DParams;
|
||||
HDCommunicateParams kfcStatusTransferD2HParams;
|
||||
uint64_t tinyMem; // for all2all
|
||||
uint64_t tinyMemSize;
|
||||
uint64_t zeroCopyHeadPtr;
|
||||
uint64_t zeroCopyTailPtr;
|
||||
uint64_t zeroCopyRingBuffer;
|
||||
uint64_t zeroCopyIpcPtrs[16];
|
||||
uint32_t zeroCopyDevicePhyId[16];
|
||||
|
||||
bool utraceStatusFlag;
|
||||
};
|
||||
|
||||
// Transport
|
||||
// Memory-region kind, stored as uint32_t. Values run 0..3 in declaration order;
// MAX_NUM (4) is the number of real kinds, not a kind itself.
enum class HcclAiRMAMemType : uint32_t {
    LOCAL_INPUT = 0,
    REMOTE_INPUT,
    LOCAL_OUTPUT,
    REMOTE_OUTPUT,
    MAX_NUM
};
|
||||
|
||||
// Header for a table of memory details: entry count, per-entry size and a 64-bit pointer value.
// All members default to zero.
// NOTE(review): memDetailPtr presumably addresses an array of memMaxNum entries, each
// sizeOfMemDetails bytes — confirm with the code that fills this struct.
struct HcclAiRMAMemInfo {
    uint32_t memMaxNum{0};
    uint32_t sizeOfMemDetails{0};
    uint64_t memDetailPtr{0};
};
|
||||
|
||||
// Transport QP/Mem
|
||||
// Descriptor block: rank/QP counts, the element sizes of the WQ/CQ/Mem records, and 64-bit
// pointer values for the send queue, send CQ, receive queue, receive CQ and memory info.
// The sizeOf* fields carry the record sizes so that a reader can index the tables without
// relying on its own sizeof. All members default to zero.
struct HcclAiRMAInfo {
    uint32_t curRankId{0};
    uint32_t rankNum{0};
    uint32_t qpNum{0};
    uint32_t sizeOfAiRMAWQ{0};   // sizeof(HcclAiRMAWQ)
    uint32_t sizeOfAiRMACQ{0};   // sizeof(HcclAiRMACQ)
    uint32_t sizeOfAiRMAMem{0};  // sizeof(HcclAiRMAMemInfo)
    uint64_t sqPtr{0};
    uint64_t scqPtr{0};
    uint64_t rqPtr{0};
    uint64_t rcqPtr{0};
    uint64_t memPtr{0};
};
|
||||
|
||||
struct HcclA2CombineOpParam {
|
||||
uint64_t workSpace; // Address for communication between client and server,
|
||||
// hccl requests and clears
|
||||
uint64_t workSpaceSize; // Space for communication between client and server
|
||||
uint32_t rankId; // id of this rank
|
||||
uint32_t rankNum; // num of ranks in this comm group
|
||||
uint64_t winSize; // size of each windows memory
|
||||
uint64_t windowsIn[AscendC::HCCL_MAX_RANK_NUM]; // windows address for input, windowsIn[rankId] corresponds
|
||||
// to the local card address,
|
||||
// and others are cross-card mapping addresses.
|
||||
uint64_t windowsOut[AscendC::HCCL_MAX_RANK_NUM]; // windows address for output, windowsOut[rankId] corresponds
|
||||
// to the local card address,
|
||||
// and others are cross-card mapping addresses.
|
||||
uint8_t res[8328];
|
||||
uint8_t multiFlag;
|
||||
__gm__ AscendC::IbVerbsData *data;
|
||||
uint64_t dataSize;
|
||||
|
||||
uint64_t sizeOfAiRMAInfo;
|
||||
uint64_t aiRMAInfo;
|
||||
};
|
||||
|
||||
// Doorbell mode, stored as int32_t: -1 = invalid, 0 = HW_DB, 1 = SW_DB
// (the queue structs annotate this as "0-hw/1-sw").
enum class DBMode : int32_t {
    INVALID_DB = -1,
    HW_DB = 0,
    SW_DB
};
|
||||
|
||||
struct HcclAiRMAWQ {
|
||||
uint32_t wqn{0};
|
||||
uint64_t bufAddr{0};
|
||||
uint32_t wqeSize{0};
|
||||
uint32_t depth{0};
|
||||
uint64_t headAddr{0};
|
||||
uint64_t tailAddr{0};
|
||||
DBMode dbMode{DBMode::INVALID_DB}; // 0-hw/1-sw
|
||||
uint64_t dbAddr{0};
|
||||
uint32_t sl{0};
|
||||
};
|
||||
|
||||
struct HcclAiRMACQ {
|
||||
uint32_t cqn{0};
|
||||
uint64_t bufAddr{0};
|
||||
uint32_t cqeSize{0};
|
||||
uint32_t depth{0};
|
||||
uint64_t headAddr{0};
|
||||
uint64_t tailAddr{0};
|
||||
DBMode dbMode{DBMode::INVALID_DB}; // 0-hw/1-sw
|
||||
uint64_t dbAddr{0};
|
||||
};
|
||||
|
||||
// 32-byte send-queue work-queue-entry record: six 32-bit words followed by a 64-bit remote address.
// NOTE(review): byte_4 / byte_16 / byte_20 look like packed bit-field words named after their
// byte offset in the hardware WQE — confirm against the hns RoCE definition before interpreting them.
struct hns_roce_rc_sq_wqe {
    uint32_t byte_4;
    uint32_t msg_len;
    uint32_t immtdata;
    uint32_t byte_16;
    uint32_t byte_20;
    uint32_t rkey;
    uint64_t remoteVA;
};
|
||||
|
||||
|
||||
// 16-byte data-segment record: length, local key and a 64-bit local address.
struct hns_roce_lite_wqe_data_seg {
    uint32_t len;
    uint32_t lkey;
    uint64_t localVA;
};
|
||||
|
||||
/**
 * Calls AscendC::DataCacheCleanAndInvalid (SINGLE_CACHE_LINE, CACHELINE_OUT) once per cache line
 * for the global-memory range that starts at sourceAddr and spans `length` bytes.
 *
 * @param sourceAddr  first byte of the range; it does not have to be cache-line aligned
 * @param length      size of the range in bytes
 *
 * NOTE(review): the upper bound is the line containing (sourceAddr + length), not
 * (sourceAddr + length - 1), and the loop bound is inclusive. When the range ends exactly on a
 * cache-line boundary, the line immediately after the range is processed as well, and a zero
 * length still processes one line. Kept as-is to match the header this file was copied from —
 * confirm upstream intent before tightening.
 */
__aicore__ inline void cacheWriteThrough(__gm__ uint8_t* sourceAddr, uint64_t length) {
    // Round the first address down to the start of its cache line.
    __gm__ uint8_t* start =
        (__gm__ uint8_t*)((uint64_t)sourceAddr / AscendC::CACHE_LINE_SIZE * AscendC::CACHE_LINE_SIZE);
    // Start of the cache line that contains the one-past-the-end address.
    __gm__ uint8_t* end =
        (__gm__ uint8_t*)(((uint64_t)sourceAddr + length) / AscendC::CACHE_LINE_SIZE * AscendC::CACHE_LINE_SIZE);
    AscendC::GlobalTensor<uint8_t> global;
    global.SetGlobalBuffer(start);
    // Step one cache line at a time from `start` through `end` inclusive.
    for (uint32_t i = 0; i <= end - start; i += AscendC::CACHE_LINE_SIZE) {
        AscendC::DataCacheCleanAndInvalid<uint8_t, AscendC::CacheLine::SINGLE_CACHE_LINE,
                                          AscendC::DcciDst::CACHELINE_OUT>(global[i]);
    }
}
|
||||
|
||||
#endif // MOE_DISTRIBUTE_BASE_H
|
||||
Reference in New Issue
Block a user