Update comment doc (#4731)
### What this PR does / why we need it?
Translate remaining Chinese comments in the `dispatch_ffn_combine` code
to English and update the installation guide to remind users to
initialize submodules when building from source.
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: mojave2 <chenchen145@huawei.com>
Signed-off-by: Chen Chen <0109chenchen@gmail.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -22,19 +22,22 @@ extern "C" {
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 算子功能:实现分布式MoE从InitRouting到Unpermute全部算子的融合
|
* Operator function: fuse all distributed MoE ops from InitRouting through Unpermute.
|
||||||
* @brief aclnnDispatchFFNCombine的第一段接口,根据具体的计算流程,计算workspace大小。
|
* @brief First-stage interface of aclnnDispatchFFNCombine that calculates workspace size based on the specific compute flow.
|
||||||
* @domain aclnn_ops_infer
|
* @domain aclnn_ops_infer
|
||||||
* @param [in] a: matmul左矩阵,数据类型支持:float16, bf16。
|
* @param [in] x: The input tensor.
|
||||||
* @param [in] b: matmul右矩阵,数据类型支持:float16, bf16。
|
* @param [in] weight1: The first weight tensor.
|
||||||
* @param [in] bias: 偏置,数据类型支持:float16, bf16。
|
* @param [in] weight2: The second weight tensor.
|
||||||
* @param [in] group: 标识通信域名称的字符串。
|
* @param [in] expertId: The expert ID tensor.
|
||||||
* @param [in] worldsize: 通信域size,支持2/4/8卡。
|
* @param [in] scale1: The first scale tensor.
|
||||||
* @param [in] epRankId: ep本卡Id。取值范围[0, worldSize),各卡的rankId不能重复
|
* @param [in] scale2: The second scale tensor.
|
||||||
* @param [out] c: 计算+通信的结果,数据类型:同输入。
|
* @param [in] probs: The probabilities tensor.
|
||||||
* @param [out] workspaceSize: 返回需要在npu device侧申请的workspace大小。
|
* @param [in] group: string identifying the communication domain name.
|
||||||
* @param [out] executor: 返回op执行器,包含了算子计算流程。
|
* @param [in] maxOutputSize: The maximum output size.
|
||||||
* @return aclnnStatus: 返回状态码
|
* @param [out] out: result of computation + communication; same dtype as input.
|
||||||
|
* @param [out] workspaceSize: workspace size to allocate on the NPU device side.
|
||||||
|
* @param [out] executor: op executor containing the operator compute flow.
|
||||||
|
* @return aclnnStatus: status code.
|
||||||
*/
|
*/
|
||||||
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
|
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWorkspaceSize(const aclTensor* x, const aclTensor* weight1, const aclTensor* weight2,
|
||||||
const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
|
const aclTensor* expertId, const aclTensor* scale1, const aclTensor* scale2,
|
||||||
@@ -44,12 +47,12 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombineGetWor
|
|||||||
uint64_t* workspaceSize, aclOpExecutor** executor);
|
uint64_t* workspaceSize, aclOpExecutor** executor);
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief aclnnDispatchGmmCombine的第二段接口,用于执行计算。
|
* @brief Second-stage interface of aclnnDispatchFFNCombine to execute computation.
|
||||||
* @param [in] workspace: 在npu device侧申请的workspace内存起址。
|
* @param [in] workspace: workspace memory address allocated on the NPU device side.
|
||||||
* @param [in] workspace_size: 在npu device侧申请的workspace大小,由第一段接口aclnnDispatchFFNCombineGetWorkspaceSize获取。
|
* @param [in] workspace_size: workspace size allocated on the NPU device side, obtained from aclnnDispatchFFNCombineGetWorkspaceSize.
|
||||||
* @param [in] exector: op执行器,包含了算子计算流程。
|
* @param [in] executor: op executor containing the operator compute flow.
|
||||||
* @param [in] stream: acl stream流。
|
* @param [in] stream: acl stream.
|
||||||
* @return aclnnStatus: 返回状态码
|
* @return aclnnStatus: status code.
|
||||||
*/
|
*/
|
||||||
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombine(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
|
__attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombine(void* workspace, uint64_t workspaceSize, aclOpExecutor* executor,
|
||||||
aclrtStream stream);
|
aclrtStream stream);
|
||||||
@@ -58,4 +61,4 @@ __attribute__((visibility("default"))) aclnnStatus aclnnDispatchFFNCombine(void*
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#endif // OP_API_INC_GMM_ALLTOALLV_
|
#endif // OP_API_INC_DISPATCH_FFN_COMBINE_
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ class DispatchFFNCombine : public OpDef {
|
|||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
|
|
||||||
// 输出
|
// Output
|
||||||
this->Output("out")
|
this->Output("out")
|
||||||
.ParamType(REQUIRED)
|
.ParamType(REQUIRED)
|
||||||
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
|
.DataType({ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16})
|
||||||
|
|||||||
@@ -27,7 +27,7 @@ using namespace AscendC;
|
|||||||
using namespace ge;
|
using namespace ge;
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// 1. 常量定义
|
// 1. Constant definitions
|
||||||
const char *K_INNER_DEBUG = "DispatchFFNCombine Tiling Debug";
|
const char *K_INNER_DEBUG = "DispatchFFNCombine Tiling Debug";
|
||||||
constexpr uint32_t ATTR_GROUP_INDEX = 0;
|
constexpr uint32_t ATTR_GROUP_INDEX = 0;
|
||||||
constexpr uint32_t ATTR_MAX_OUTPUT_SIZE_INDEX = 1;
|
constexpr uint32_t ATTR_MAX_OUTPUT_SIZE_INDEX = 1;
|
||||||
@@ -54,13 +54,13 @@ static int32_t CeilDev(int32_t num, int32_t div)
|
|||||||
return (num + div - 1) / div;
|
return (num + div - 1) / div;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解析并校验 rankId, group, worldSize, isTransB 属性值
|
// Parse and validate rankId, group, worldSize, and isTransB attributes
|
||||||
static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo& info)
|
static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo& info)
|
||||||
{
|
{
|
||||||
auto attrs = context->GetAttrs();
|
auto attrs = context->GetAttrs();
|
||||||
OP_TILING_CHECK(attrs == nullptr, OP_LOGE(K_INNER_DEBUG, "attrs is null."), return ge::GRAPH_FAILED);
|
OP_TILING_CHECK(attrs == nullptr, OP_LOGE(K_INNER_DEBUG, "attrs is null."), return ge::GRAPH_FAILED);
|
||||||
|
|
||||||
// todo:Attr相关tilingdata的设置、校验、打印
|
// TODO: set, validate, and print tiling data related to attributes
|
||||||
auto groupPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_INDEX));
|
auto groupPtr = attrs->GetAttrPointer<char>(static_cast<int>(ATTR_GROUP_INDEX));
|
||||||
auto maxOutputSizePtr = attrs->GetAttrPointer<int>(ATTR_MAX_OUTPUT_SIZE_INDEX);
|
auto maxOutputSizePtr = attrs->GetAttrPointer<int>(ATTR_MAX_OUTPUT_SIZE_INDEX);
|
||||||
auto is_trans_b = attrs->GetAttrPointer<bool>(ATTR_IS_TRANS_B);
|
auto is_trans_b = attrs->GetAttrPointer<bool>(ATTR_IS_TRANS_B);
|
||||||
@@ -87,7 +87,7 @@ static ge::graphStatus DispatchFFNCombineCheckAttrAndSetTiling(gert::TilingConte
|
|||||||
return ge::GRAPH_SUCCESS;
|
return ge::GRAPH_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 提取输入张量 A 和 B 的形状,计算出 M、K、N 值
|
// Extract shapes of input tensors A and B to compute M, K, N
|
||||||
static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo &info)
|
static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo &info)
|
||||||
{
|
{
|
||||||
const char *nodeName = context->GetNodeName();
|
const char *nodeName = context->GetNodeName();
|
||||||
@@ -116,7 +116,7 @@ static ge::graphStatus DispatchFFNCombineCheckShapeAndSetTiling(gert::TilingCont
|
|||||||
return ge::GRAPH_SUCCESS;
|
return ge::GRAPH_SUCCESS;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 获取当前芯片平台的 AI Core 数目、UB 容量等硬件信息。
|
// Get hardware info such as AI Core count and UB capacity for the current chip platform.
|
||||||
static ge::graphStatus DispatchFFNCombineGetPlatformInfoAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo& info)
|
static ge::graphStatus DispatchFFNCombineGetPlatformInfoAndSetTiling(gert::TilingContext *context, DispatchFFNCombineInfo& info)
|
||||||
{
|
{
|
||||||
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
auto ascendcPlatform = platform_ascendc::PlatformAscendC(context->GetPlatformInfo());
|
||||||
@@ -146,9 +146,9 @@ void SetTilingData(CoCTiling &cocTilingData, DispatchFFNCombineInfo &info)
|
|||||||
cocTilingData.lenPerLoop = cocTilingData.m0 * cocTilingData.n0 / 2;
|
cocTilingData.lenPerLoop = cocTilingData.m0 * cocTilingData.n0 / 2;
|
||||||
}
|
}
|
||||||
|
|
||||||
// 主调度函数:
|
// Main scheduling function:
|
||||||
// 获取 tilingData ➝ 检查 Attr ➝ 检查 Shape ➝ 获取平台信息
|
// Get tilingData ➝ check Attr ➝ check Shape ➝ get platform info
|
||||||
// ➝ 调用 SetTilingData(根据rank数目) ➝ 设置 blockDim ➝ 设置 tilingKey ➝ 设置 workspace ➝ 配置通信参数
|
// ➝ call SetTilingData (based on rank count) ➝ set blockDim ➝ set tilingKey ➝ set workspace ➝ configure communication parameters
|
||||||
|
|
||||||
static ge::graphStatus DispatchFFNCombineTilingFuncImpl(gert::TilingContext *context)
|
static ge::graphStatus DispatchFFNCombineTilingFuncImpl(gert::TilingContext *context)
|
||||||
{
|
{
|
||||||
@@ -262,4 +262,4 @@ ge::graphStatus TilingParseForDispatchFFNCombine(gert::TilingParseContext *conte
|
|||||||
IMPL_OP_OPTILING(DispatchFFNCombine)
|
IMPL_OP_OPTILING(DispatchFFNCombine)
|
||||||
.Tiling(DispatchFFNCombineTilingFunc)
|
.Tiling(DispatchFFNCombineTilingFunc)
|
||||||
.TilingParse<DispatchFFNCombineCompileInfo>(TilingParseForDispatchFFNCombine);
|
.TilingParse<DispatchFFNCombineCompileInfo>(TilingParseForDispatchFFNCombine);
|
||||||
} // namespace optiling
|
} // namespace optiling
|
||||||
|
|||||||
@@ -64,8 +64,8 @@ class HcomTopoInfo {
|
|||||||
~HcomTopoInfo() = default;
|
~HcomTopoInfo() = default;
|
||||||
std::unordered_map<std::string, TopoInfo> rank_info_;
|
std::unordered_map<std::string, TopoInfo> rank_info_;
|
||||||
std::mutex mutex_;
|
std::mutex mutex_;
|
||||||
std::unordered_map<std::string, void*> group_to_ordered_stream_; // 通信域保序流
|
std::unordered_map<std::string, void*> group_to_ordered_stream_; // Ordered stream for the communication domain
|
||||||
std::unordered_map<int32_t, std::unordered_map<std::string, void*>> device_id_to_group_to_ordered_stream_; // 通信域保序流
|
std::unordered_map<int32_t, std::unordered_map<std::string, void*>> device_id_to_group_to_ordered_stream_; // Ordered stream for the communication domain
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -8,8 +8,8 @@
|
|||||||
* See LICENSE in the root of the software repository for the full text of the License.
|
* See LICENSE in the root of the software repository for the full text of the License.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#ifndef DISPATH_FFN_COMBINE_KERNEL_HPP
|
#ifndef DISPATCH_FFN_COMBINE_KERNEL_HPP
|
||||||
#define DISPATH_FFN_COMBINE_KERNEL_HPP
|
#define DISPATCH_FFN_COMBINE_KERNEL_HPP
|
||||||
|
|
||||||
#include "kernel_operator.h"
|
#include "kernel_operator.h"
|
||||||
|
|
||||||
@@ -324,7 +324,7 @@ private:
|
|||||||
int64_t gmGroupOffsetC = 0;
|
int64_t gmGroupOffsetC = 0;
|
||||||
uint32_t startCoreIdx = 0;
|
uint32_t startCoreIdx = 0;
|
||||||
uint32_t syncGroupIdx = 0;
|
uint32_t syncGroupIdx = 0;
|
||||||
AscendC::CrossCoreWaitFlag<0x2>(0); // 等待aiv计算cumsumformm
|
AscendC::CrossCoreWaitFlag<0x2>(0); // Wait for AIV to finish cumsum for matmul
|
||||||
int64_t preCurrentmSum = 0;
|
int64_t preCurrentmSum = 0;
|
||||||
int32_t syncLoopIdx = -1;
|
int32_t syncLoopIdx = -1;
|
||||||
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (uint32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
@@ -364,7 +364,7 @@ private:
|
|||||||
int64_t gmOffsetA = layoutA.GetOffset(offsetA);
|
int64_t gmOffsetA = layoutA.GetOffset(offsetA);
|
||||||
int64_t gmOffsetB = layoutB1.GetOffset(offsetB);
|
int64_t gmOffsetB = layoutB1.GetOffset(offsetB);
|
||||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||||
int64_t gmOffsetS = groupIdx * params.problemShape.n() + blockCoord.n() * L1TileShape::N; // 每个expert一组scale
|
int64_t gmOffsetS = groupIdx * params.problemShape.n() + blockCoord.n() * L1TileShape::N; // One scale group per expert
|
||||||
if (currentM > 0) {
|
if (currentM > 0) {
|
||||||
blockMmad(
|
blockMmad(
|
||||||
gmA[gmGroupOffsetA + gmOffsetA], layoutA,
|
gmA[gmGroupOffsetA + gmOffsetA], layoutA,
|
||||||
@@ -465,7 +465,7 @@ private:
|
|||||||
int64_t gmOffsetA = layoutA.GetOffset(offsetA);
|
int64_t gmOffsetA = layoutA.GetOffset(offsetA);
|
||||||
int64_t gmOffsetB = layoutB2.GetOffset(offsetB);
|
int64_t gmOffsetB = layoutB2.GetOffset(offsetB);
|
||||||
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
int64_t gmOffsetC = layoutC.GetOffset(offsetC);
|
||||||
int64_t gmOffsetS = groupIdx * n2 + blockCoord.n() * L1TileShape::N; // 每个expert一组scale
|
int64_t gmOffsetS = groupIdx * n2 + blockCoord.n() * L1TileShape::N; // One scale group per expert
|
||||||
if (currentM > 0) {
|
if (currentM > 0) {
|
||||||
blockMmad(
|
blockMmad(
|
||||||
gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA,
|
gmPermutedToken[gmGroupOffsetA + gmOffsetA], layoutA,
|
||||||
@@ -537,7 +537,7 @@ private:
|
|||||||
void Dispatch(Params const ¶ms) {
|
void Dispatch(Params const ¶ms) {
|
||||||
icache_preload(8);
|
icache_preload(8);
|
||||||
int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t);
|
int64_t localTokenPerExpertOffset = peermemInfo.offsetPeerTokenPerExpert + tokenPerExpertLayout(params.rank, 0, 0) * sizeof(int32_t);
|
||||||
GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset; // 把通信矩阵全部放到peermem
|
GM_ADDR localTokenPerExpert = shmem() + localTokenPerExpertOffset; // Place the entire communication matrix in peermem
|
||||||
uint32_t expandedRowIdxOffset = AlignUp(params.problemShape.m(), 256) * params.topK * sizeof(int32_t);
|
uint32_t expandedRowIdxOffset = AlignUp(params.problemShape.m(), 256) * params.topK * sizeof(int32_t);
|
||||||
|
|
||||||
//---initRouting------
|
//---initRouting------
|
||||||
@@ -571,7 +571,7 @@ private:
|
|||||||
int32_t syncLoopIdx = -1;
|
int32_t syncLoopIdx = -1;
|
||||||
BlockEpilogue1 blockEpilogue(resource);
|
BlockEpilogue1 blockEpilogue(resource);
|
||||||
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
for (int32_t groupIdx = 0; groupIdx < params.expertPerRank; ++groupIdx) {
|
||||||
// 第i个core从第i个rank的peermem读数据
|
// The ith core reads data from the ith rank's peermem
|
||||||
groupIdxDeq = groupIdx - 2;
|
groupIdxDeq = groupIdx - 2;
|
||||||
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
for(int32_t dstEpIdx = coreIdx; dstEpIdx < params.EP; dstEpIdx += coreNum) {
|
||||||
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
|
uint32_t rowStart = (dstEpIdx == 0 ? 0 : cumsumMM((dstEpIdx - 1) * params.expertPerRank + groupIdx)) + prevGroupSum1;
|
||||||
@@ -592,9 +592,9 @@ private:
|
|||||||
MatrixCoord offsetPeer{rowSrc, 0};
|
MatrixCoord offsetPeer{rowSrc, 0};
|
||||||
int64_t gmOffsetA = params.layoutA.GetOffset(offsetA);
|
int64_t gmOffsetA = params.layoutA.GetOffset(offsetA);
|
||||||
int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer);
|
int64_t gmOffsetPeer = params.layoutA.GetOffset(offsetPeer);
|
||||||
// 通信Data
|
// Communication data
|
||||||
CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum);
|
CopyGMToGM(gmA[gmOffsetA], gmRemoteA[gmOffsetPeer], rows * params.problemShape.k(), params.ubMoveNum);
|
||||||
// 通信scale
|
// Communication scale
|
||||||
CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows);
|
CopyGMToGM(gmPerTokenScale1[rowStart], gmRemotePerTokenScale[rowSrc], rows, rows);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -604,7 +604,7 @@ private:
|
|||||||
AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1);
|
AscendC::CrossCoreWaitFlag<0x2>(syncLoopIdx / 8 + 1);
|
||||||
}
|
}
|
||||||
AscendC::SyncAll<true>();
|
AscendC::SyncAll<true>();
|
||||||
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); // V通知C当前轮的通信已完成
|
AscendC::CrossCoreSetFlag<0x2, PIPE_MTE3>(0); // V notifies C that the current communication round is complete
|
||||||
|
|
||||||
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
|
if ((params.epilogueGranularity < params.expertPerRank && params.epilogueGranularity > 0) && groupIdx == params.expertPerRank - 1 && prevGroupSum1 > 0) {
|
||||||
uint32_t rowStartThisCore = 0;
|
uint32_t rowStartThisCore = 0;
|
||||||
@@ -664,7 +664,7 @@ private:
|
|||||||
uint32_t n2 = params.problemShape.k();
|
uint32_t n2 = params.problemShape.k();
|
||||||
uint32_t k2 = params.problemShape.n() / 2;
|
uint32_t k2 = params.problemShape.n() / 2;
|
||||||
|
|
||||||
// TODO 计算tokenperexpert的cumsum
|
// TODO compute the cumsum of tokenPerExpert
|
||||||
typename BlockEpilogue2::Params epilogueParams{
|
typename BlockEpilogue2::Params epilogueParams{
|
||||||
static_cast<int32_t>(params.EP),
|
static_cast<int32_t>(params.EP),
|
||||||
static_cast<int32_t>(params.expertPerRank),
|
static_cast<int32_t>(params.expertPerRank),
|
||||||
@@ -774,10 +774,10 @@ private:
|
|||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
PeermemInfo(const Params & params, const HcclShmem & shmem) {
|
PeermemInfo(const Params & params, const HcclShmem & shmem) {
|
||||||
offsetA = 0; // 占用1/3的BUFFSIZE
|
offsetA = 0; // Occupies one third of BUFFSIZE
|
||||||
offsetPeerPerTokenScale = offsetA + AlignUp(shmem.SegmentSize() / 3, 512); // 占用1MB
|
offsetPeerPerTokenScale = offsetA + AlignUp(shmem.SegmentSize() / 3, 512); // Occupies 1 MB
|
||||||
offsetD = offsetPeerPerTokenScale + MB_SIZE; // 占用剩下的
|
offsetD = offsetPeerPerTokenScale + MB_SIZE; // Occupies the remaining space
|
||||||
offsetPeerTokenPerExpert = shmem.SegmentSize() - 2 * MB_SIZE; // 占用最后2MB
|
offsetPeerTokenPerExpert = shmem.SegmentSize() - 2 * MB_SIZE; // Occupies the final 2 MB
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -811,4 +811,4 @@ private:
|
|||||||
|
|
||||||
} // namespace Catlass::Gemm::Kernel
|
} // namespace Catlass::Gemm::Kernel
|
||||||
|
|
||||||
#endif // DISPATH_FFN_COMBINE_KERNEL_HPP
|
#endif // DISPATH_FFN_COMBINE_KERNEL_HPP
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ __aicore__ inline void moe_init_routing_quant_v2(
|
|||||||
sortPipe.Destroy();
|
sortPipe.Destroy();
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tilingKey == 10000 || tilingKey == 10010 || tilingKey ==11000 || tilingKey ==11010) { //没有drop的情况
|
if (tilingKey == 10000 || tilingKey == 10010 || tilingKey ==11000 || tilingKey ==11010) { // No drop scenario
|
||||||
if (tilingData->expertTokensCountOrCumsumFlag != EXERPT_TOKENS_NONE) {
|
if (tilingData->expertTokensCountOrCumsumFlag != EXERPT_TOKENS_NONE) {
|
||||||
TPipe expertTokenOutPipe;
|
TPipe expertTokenOutPipe;
|
||||||
MoeV2ExpertTokenOut expertTokenOutOp;
|
MoeV2ExpertTokenOut expertTokenOutOp;
|
||||||
@@ -94,7 +94,7 @@ __aicore__ inline void moe_init_routing_quant_v2(
|
|||||||
srcToDstOp.Init<MoeInitRoutingQuantV2TilingData>(expandedRowIdx, workspace, tilingData, &srcToDstPipe);
|
srcToDstOp.Init<MoeInitRoutingQuantV2TilingData>(expandedRowIdx, workspace, tilingData, &srcToDstPipe);
|
||||||
srcToDstOp.Process();
|
srcToDstOp.Process();
|
||||||
srcToDstPipe.Destroy();
|
srcToDstPipe.Destroy();
|
||||||
} else if (tilingKey ==10100 || tilingKey ==10110 || tilingKey ==11100 || tilingKey ==11110) { //有drop的情况
|
} else if (tilingKey ==10100 || tilingKey ==10110 || tilingKey ==11100 || tilingKey ==11110) { // Drop scenario
|
||||||
TPipe expertTokenOutPipe;
|
TPipe expertTokenOutPipe;
|
||||||
MoeV2ExpertTokenOut expertTokenOutOp;
|
MoeV2ExpertTokenOut expertTokenOutOp;
|
||||||
expertTokenOutOp.Init<MoeInitRoutingQuantV2TilingData>(expertTokensCountOrCumsum, expertTokensBeforeCapacity,
|
expertTokenOutOp.Init<MoeInitRoutingQuantV2TilingData>(expertTokensCountOrCumsum, expertTokensBeforeCapacity,
|
||||||
|
|||||||
@@ -178,7 +178,7 @@ uint64_t InnerMoeInitRoutingV2TilingBase::GetTilingKey() const {
|
|||||||
return TILING_KEY_HIGH_PERFORMANCE;
|
return TILING_KEY_HIGH_PERFORMANCE;
|
||||||
}
|
}
|
||||||
if (dropPadMode == 0) {
|
if (dropPadMode == 0) {
|
||||||
if (totalLength <= sortLoopMaxElement) { // 排序只用到一个核排序
|
if (totalLength <= sortLoopMaxElement) { // Sorting uses only one core
|
||||||
return TILING_KEY_DROPLESS_SORT_ONE_CORE;
|
return TILING_KEY_DROPLESS_SORT_ONE_CORE;
|
||||||
} else {
|
} else {
|
||||||
return TILING_KEY_DROPLESS_SORT_MULTI_CORE;
|
return TILING_KEY_DROPLESS_SORT_MULTI_CORE;
|
||||||
@@ -206,10 +206,10 @@ bool InnerMoeInitRoutingV2TilingBase::GetShapeAttrsInfo(int64_t m, int64_t cols,
|
|||||||
this->expertTokensCountOrCumsumFlag = expertTokensCountOrCumsumFlag;
|
this->expertTokensCountOrCumsumFlag = expertTokensCountOrCumsumFlag;
|
||||||
this->expertTokensBeforeCapacityFlag = expertTokensBeforeCapacityFlag;
|
this->expertTokensBeforeCapacityFlag = expertTokensBeforeCapacityFlag;
|
||||||
if (dropPadMode == 1) {
|
if (dropPadMode == 1) {
|
||||||
// droppad场景下不输出expertTokensCountOrCumsum
|
// Do not output expertTokensCountOrCumsum in drop-pad mode
|
||||||
expertTokensCountOrCumsumFlag = 0;
|
expertTokensCountOrCumsumFlag = 0;
|
||||||
} else {
|
} else {
|
||||||
// dropless场景下不输出expertTokensBeforeCapacity
|
// Do not output expertTokensBeforeCapacity in dropless mode
|
||||||
expertTokensBeforeCapacityFlag = false;
|
expertTokensBeforeCapacityFlag = false;
|
||||||
}
|
}
|
||||||
moeInitRoutingTilingData.cols = cols;
|
moeInitRoutingTilingData.cols = cols;
|
||||||
@@ -235,8 +235,8 @@ bool InnerMoeInitRoutingV2TilingBase::GetPlatformInfo(int64_t aivCoreNum, int64_
|
|||||||
|
|
||||||
|
|
||||||
bool InnerMoeInitRoutingV2TilingBase::GetWorkspaceSize() {
|
bool InnerMoeInitRoutingV2TilingBase::GetWorkspaceSize() {
|
||||||
// 计算workspace大小
|
// Calculate workspace size
|
||||||
size_t sortWorkspaceSize = totalLength * sizeof(float) * NUM_TWO * NUM_THREE; // 排序需要的空间
|
size_t sortWorkspaceSize = totalLength * sizeof(float) * NUM_TWO * NUM_THREE; // Space needed for sorting
|
||||||
size_t scatterWorkspaceSize = totalLength * sizeof(int32_t) * NUM_TWO;
|
size_t scatterWorkspaceSize = totalLength * sizeof(int32_t) * NUM_TWO;
|
||||||
size_t expertTokenFlagSize = aivNum * 2 * sizeof(int32_t);
|
size_t expertTokenFlagSize = aivNum * 2 * sizeof(int32_t);
|
||||||
workspaceSize_ = sortWorkspaceSize + scatterWorkspaceSize + expertTokenFlagSize + SIZE_16 * LENGTH_1024 * LENGTH_1024;
|
workspaceSize_ = sortWorkspaceSize + scatterWorkspaceSize + expertTokenFlagSize + SIZE_16 * LENGTH_1024 * LENGTH_1024;
|
||||||
@@ -257,11 +257,11 @@ void InnerMoeInitRoutingV2TilingBase::Tiling4VBSOneCoreCompute(InnerMoeV2VBSComp
|
|||||||
|
|
||||||
void InnerMoeInitRoutingV2TilingBase::Tiling4VBSMultiCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData) {
|
void InnerMoeInitRoutingV2TilingBase::Tiling4VBSMultiCoreCompute(InnerMoeV2VBSComputeTilingData* tilingData) {
|
||||||
//Tiling4VBSMultiCoreCompute
|
//Tiling4VBSMultiCoreCompute
|
||||||
int64_t needCoreNum = CeilDiv(totalLength, sortLoopMaxElement); // 向上取整
|
int64_t needCoreNum = CeilDiv(totalLength, sortLoopMaxElement); // Round up
|
||||||
needCoreNum = static_cast<int64_t>(std::pow(4, CeilLog4(needCoreNum)));
|
needCoreNum = static_cast<int64_t>(std::pow(4, CeilLog4(needCoreNum)));
|
||||||
needCoreNum = std::min(needCoreNum, aivNum); // 不能超过物理核数
|
needCoreNum = std::min(needCoreNum, aivNum); // Cannot exceed physical core count
|
||||||
if (needCoreNum > 0) {
|
if (needCoreNum > 0) {
|
||||||
int64_t perCoreElements = totalLength / needCoreNum; // 每个核处理的元素数
|
int64_t perCoreElements = totalLength / needCoreNum; // Elements handled per core
|
||||||
int64_t alineFloorPerCoreElements = perCoreElements - perCoreElements % SORT32_ALIGN_ELEMENT;
|
int64_t alineFloorPerCoreElements = perCoreElements - perCoreElements % SORT32_ALIGN_ELEMENT;
|
||||||
int64_t lastCoreElement = totalLength - (needCoreNum - 1) * alineFloorPerCoreElements;
|
int64_t lastCoreElement = totalLength - (needCoreNum - 1) * alineFloorPerCoreElements;
|
||||||
int64_t alineCeilPerCoreElements = perCoreElements + SORT32_ALIGN_ELEMENT - perCoreElements % SORT32_ALIGN_ELEMENT;
|
int64_t alineCeilPerCoreElements = perCoreElements + SORT32_ALIGN_ELEMENT - perCoreElements % SORT32_ALIGN_ELEMENT;
|
||||||
@@ -274,7 +274,7 @@ void InnerMoeInitRoutingV2TilingBase::Tiling4VBSMultiCoreCompute(InnerMoeV2VBSCo
|
|||||||
tilingData->needCoreNum = needCoreNum;
|
tilingData->needCoreNum = needCoreNum;
|
||||||
do {
|
do {
|
||||||
tilingData->perCoreElements = perCoreElements;
|
tilingData->perCoreElements = perCoreElements;
|
||||||
tilingData->perCoreLoops = CeilDiv(tilingData->perCoreElements, sortLoopMaxElement); // 每个核处理的loop数
|
tilingData->perCoreLoops = CeilDiv(tilingData->perCoreElements, sortLoopMaxElement); // Loops handled per core
|
||||||
tilingData->perCorePerLoopElements = std::min(tilingData->perCoreElements, sortLoopMaxElement);
|
tilingData->perCorePerLoopElements = std::min(tilingData->perCoreElements, sortLoopMaxElement);
|
||||||
tilingData->perCoreLastLoopElements = tilingData->perCoreElements - (tilingData->perCoreLoops - 1) * tilingData->perCorePerLoopElements;
|
tilingData->perCoreLastLoopElements = tilingData->perCoreElements - (tilingData->perCoreLoops - 1) * tilingData->perCorePerLoopElements;
|
||||||
tilingData->lastCoreElements = totalLength - (tilingData->needCoreNum - 1) * tilingData->perCoreElements;
|
tilingData->lastCoreElements = totalLength - (tilingData->needCoreNum - 1) * tilingData->perCoreElements;
|
||||||
@@ -294,7 +294,7 @@ void InnerMoeInitRoutingV2TilingBase::Tiling4VBSMultiCoreCompute(InnerMoeV2VBSCo
|
|||||||
void InnerMoeInitRoutingV2TilingBase::Tiling4VBSCompute() {
|
void InnerMoeInitRoutingV2TilingBase::Tiling4VBSCompute() {
|
||||||
auto tilingData = &moeInitRoutingTilingData.vbsComputeParamsOp;
|
auto tilingData = &moeInitRoutingTilingData.vbsComputeParamsOp;
|
||||||
tilingData->oneLoopMaxElements = sortLoopMaxElement;
|
tilingData->oneLoopMaxElements = sortLoopMaxElement;
|
||||||
if (totalLength <= sortLoopMaxElement) { // 只用到一个核
|
if (totalLength <= sortLoopMaxElement) { // Only one core is used
|
||||||
Tiling4VBSOneCoreCompute(tilingData);
|
Tiling4VBSOneCoreCompute(tilingData);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -304,11 +304,11 @@ void InnerMoeInitRoutingV2TilingBase::Tiling4VBSCompute() {
|
|||||||
void InnerMoeInitRoutingV2TilingBase::Tiling4VMSMiddleCompute() {
|
void InnerMoeInitRoutingV2TilingBase::Tiling4VMSMiddleCompute() {
|
||||||
auto vbsComputeTilingData = &moeInitRoutingTilingData.vbsComputeParamsOp;
|
auto vbsComputeTilingData = &moeInitRoutingTilingData.vbsComputeParamsOp;
|
||||||
auto tilingData = &moeInitRoutingTilingData.vmsMiddleComputeParamsOp;
|
auto tilingData = &moeInitRoutingTilingData.vmsMiddleComputeParamsOp;
|
||||||
if (vbsComputeTilingData->needCoreNum <= MRG_LIST_NUM) { // 队列数小于一次vms则没有中间归并
|
if (vbsComputeTilingData->needCoreNum <= MRG_LIST_NUM) { // No intermediate merge if queue count fits one VMS
|
||||||
tilingData->needCoreNum = 0; // 需要的核数
|
tilingData->needCoreNum = 0; // Required core count
|
||||||
} else {
|
} else {
|
||||||
int64_t needCoreNum = CeilDiv(vbsComputeTilingData->needCoreNum, MRG_LIST_NUM);
|
int64_t needCoreNum = CeilDiv(vbsComputeTilingData->needCoreNum, MRG_LIST_NUM);
|
||||||
tilingData->needCoreNum = needCoreNum; // 需要的核数
|
tilingData->needCoreNum = needCoreNum; // Required core count
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -333,7 +333,7 @@ void InnerMoeInitRoutingV2TilingBase::Tiling4SrcToDstCompute() {
|
|||||||
tilingData->needCoreNum = needCoreNum;
|
tilingData->needCoreNum = needCoreNum;
|
||||||
int64_t lastCoreNum = totalLength - perCoreRows * (tilingData->needCoreNum - 1);
|
int64_t lastCoreNum = totalLength - perCoreRows * (tilingData->needCoreNum - 1);
|
||||||
tilingData->perCoreRows = perCoreRows;
|
tilingData->perCoreRows = perCoreRows;
|
||||||
if (perLoopMaxRows >= tilingData->perCoreRows) { // 一个loop结束
|
if (perLoopMaxRows >= tilingData->perCoreRows) { // One loop completes
|
||||||
tilingData->perCorePerLoopRows = tilingData->perCoreRows;
|
tilingData->perCorePerLoopRows = tilingData->perCoreRows;
|
||||||
tilingData->perCoreLastLoopRows = tilingData->perCoreRows;
|
tilingData->perCoreLastLoopRows = tilingData->perCoreRows;
|
||||||
} else {
|
} else {
|
||||||
@@ -407,4 +407,4 @@ void InnerMoeInitRoutingV2TilingBase::Tiling4SrcToDstCapacityCompute() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -151,7 +151,7 @@ __aicore__ inline void MoeV2ExpertTokenOut::CopyOutExpertTokensCumsum(bool isTai
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
#ifdef __CCE_KT_TEST__
|
#ifdef __CCE_KT_TEST__
|
||||||
// CPU孪生调试无法使用多核同步,可能导致index为未初始化的脏数据,因此需要特殊处理
|
// CPU twin debugging cannot use multi-core sync, so index may contain uninitialized dirty data; handle specially
|
||||||
if (this->firstExpertId > expertTokensCountOrCumsumGm.GetSize()) {
|
if (this->firstExpertId > expertTokensCountOrCumsumGm.GetSize()) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
@@ -202,7 +202,7 @@ __aicore__ inline void MoeV2ExpertTokenOut::CopyOutExpertTokensCount(bool isTail
|
|||||||
int64_t copyLength = isTail ? this->lastExpertId - this->firstExpertId + 1 : this->expertNumUbAlign;
|
int64_t copyLength = isTail ? this->lastExpertId - this->firstExpertId + 1 : this->expertNumUbAlign;
|
||||||
DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(copyLength * sizeof(int32_t)), 0, 0, 0};
|
DataCopyExtParams copyParams{static_cast<uint16_t>(1), static_cast<uint32_t>(copyLength * sizeof(int32_t)), 0, 0, 0};
|
||||||
#ifdef __CCE_KT_TEST__
|
#ifdef __CCE_KT_TEST__
|
||||||
// CPU孪生调试不进行输出拷贝
|
// CPU twin debugging skips output copies
|
||||||
return;
|
return;
|
||||||
#endif
|
#endif
|
||||||
SetAtomicAdd<int32_t>();
|
SetAtomicAdd<int32_t>();
|
||||||
@@ -307,4 +307,4 @@ __aicore__ inline void MoeV2ExpertTokenOut::Process() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
} // namespace MoeInitRoutingQuantV2
|
} // namespace MoeInitRoutingQuantV2
|
||||||
#endif // INNER_MOE_V2_EXPERT_TOKEN_OUT_H
|
#endif // INNER_MOE_V2_EXPERT_TOKEN_OUT_H
|
||||||
|
|||||||
@@ -184,7 +184,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::CopyOutXQuant1H(int64_t progr
|
|||||||
|
|
||||||
inputXInQueue.EnQue<T>(inLocal);
|
inputXInQueue.EnQue<T>(inLocal);
|
||||||
|
|
||||||
// 计算quant
|
// Compute quantization
|
||||||
Compute(smoothLocal);
|
Compute(smoothLocal);
|
||||||
|
|
||||||
LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
|
LocalTensor<float> quantScaleLocal = scaleOutQueue.DeQue<float>();
|
||||||
@@ -525,7 +525,7 @@ template <typename T>
|
|||||||
__aicore__ inline void MoeV2GatherDynamicQuant<T>::Process() {
|
__aicore__ inline void MoeV2GatherDynamicQuant<T>::Process() {
|
||||||
if (this->blockIdx < this->needCoreNum) {
|
if (this->blockIdx < this->needCoreNum) {
|
||||||
currentLoopRows = perLoopRows;
|
currentLoopRows = perLoopRows;
|
||||||
if (colLoops > 1) { // 一行无法全载,需要workspace
|
if (colLoops > 1) { // A single row cannot be fully loaded; workspace is required
|
||||||
if (smoothType == 2) {
|
if (smoothType == 2) {
|
||||||
for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
|
for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
|
||||||
CopyInExpandedExpertIdx(loop);
|
CopyInExpandedExpertIdx(loop);
|
||||||
@@ -543,7 +543,7 @@ __aicore__ inline void MoeV2GatherDynamicQuant<T>::Process() {
|
|||||||
CopyInExpandedRowIdx(this->rowLoops - 1);
|
CopyInExpandedRowIdx(this->rowLoops - 1);
|
||||||
CopyOutPartialXQuant1H(this->rowLoops - 1);
|
CopyOutPartialXQuant1H(this->rowLoops - 1);
|
||||||
}
|
}
|
||||||
} else { // 一行可以全载
|
} else { // A single row can be fully loaded
|
||||||
if (smoothType == 2) {
|
if (smoothType == 2) {
|
||||||
for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
|
for (int64_t loop = 0; loop < this->rowLoops - 1; loop++) {
|
||||||
CopyInExpandedExpertIdx(loop);
|
CopyInExpandedExpertIdx(loop);
|
||||||
|
|||||||
@@ -111,7 +111,7 @@ __aicore__ inline void MoeV2GatherOut<T>::CopyOut(int64_t progress) {
|
|||||||
}
|
}
|
||||||
outOffset = outIndex * cols + colsLoop * this->perLoopCols;
|
outOffset = outIndex * cols + colsLoop * this->perLoopCols;
|
||||||
#ifdef __CCE_KT_TEST__
|
#ifdef __CCE_KT_TEST__
|
||||||
// CPU孪生调试无法使用多核同步,可能导致index为未初始化的脏数据,因此需要特殊处理
|
// CPU twin debugging cannot use multi-core sync, so index may contain uninitialized dirty data; handle specially
|
||||||
if (outOffset > expandedXGm.GetSize()) {
|
if (outOffset > expandedXGm.GetSize()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ __aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::CopyOut(int64_t
|
|||||||
col = this->lastLoopCols;
|
col = this->lastLoopCols;
|
||||||
}
|
}
|
||||||
#ifdef __CCE_KT_TEST__
|
#ifdef __CCE_KT_TEST__
|
||||||
// CPU孪生调试无法使用多核同步,可能导致index为未初始化的脏数据,因此需要特殊处理
|
// CPU twin debugging cannot use multi-core sync, so index may contain uninitialized dirty data; handle specially
|
||||||
if (index * this->cols + i * this->perLoopCols + col * sizeof(T) > expandedXGm.GetSize()) {
|
if (index * this->cols + i * this->perLoopCols + col * sizeof(T) > expandedXGm.GetSize()) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -266,4 +266,4 @@ __aicore__ inline void MoeV2SrcToDstWithCapacity<T, TilingData>::Process() {
|
|||||||
this->SyncAll();
|
this->SyncAll();
|
||||||
}
|
}
|
||||||
} // namespace MoeInitRoutingQuantV2
|
} // namespace MoeInitRoutingQuantV2
|
||||||
#endif // INNER_MOE_V2_SRC_TO_DST_WITH_CAPACITY_H
|
#endif // INNER_MOE_V2_SRC_TO_DST_WITH_CAPACITY_H
|
||||||
|
|||||||
@@ -107,7 +107,7 @@ KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::Init(GM_ADDR permuted_tokens, GM_ADD
|
|||||||
this->tokens_splited_num = tiling_data->tokens_splited_num;
|
this->tokens_splited_num = tiling_data->tokens_splited_num;
|
||||||
this->tokens_splited_remain = tiling_data->tokens_splited_remain;
|
this->tokens_splited_remain = tiling_data->tokens_splited_remain;
|
||||||
|
|
||||||
// 处理token_by_core尾块
|
// Handle the tail block for token_by_core
|
||||||
if (this->tokens_core_remain > 0 && blockIdx < this->tokens_core_remain) {
|
if (this->tokens_core_remain > 0 && blockIdx < this->tokens_core_remain) {
|
||||||
this->tokens_core_length += 1;
|
this->tokens_core_length += 1;
|
||||||
this->tokens_splited_remain += 1;
|
this->tokens_splited_remain += 1;
|
||||||
@@ -181,7 +181,7 @@ __aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::Process()
|
|||||||
for (int64_t i = 0; i < this->tokens_splited_num; ++i) {
|
for (int64_t i = 0; i < this->tokens_splited_num; ++i) {
|
||||||
CalMultiOutToken(i * this->tokens_splited_length, this->tokens_splited_length);
|
CalMultiOutToken(i * this->tokens_splited_length, this->tokens_splited_length);
|
||||||
}
|
}
|
||||||
// 处理tokens_num不能均匀分核数的尾块
|
// Handle the tail block when tokens_num is not evenly divisible by core count
|
||||||
if (this->tokens_splited_remain > 0) {
|
if (this->tokens_splited_remain > 0) {
|
||||||
CalMultiOutToken(this->tokens_splited_num * this->tokens_splited_length, this->tokens_splited_remain);
|
CalMultiOutToken(this->tokens_splited_num * this->tokens_splited_length, this->tokens_splited_remain);
|
||||||
}
|
}
|
||||||
@@ -231,7 +231,7 @@ __aicore__ inline void KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalSingleOutT
|
|||||||
for (int64_t h_index = 0; h_index < this->hidden_splited_num; ++h_index) {
|
for (int64_t h_index = 0; h_index < this->hidden_splited_num; ++h_index) {
|
||||||
CalPartOutToken(start_token, h_index, this->hidden_splited_length, out_token_idx);
|
CalPartOutToken(start_token, h_index, this->hidden_splited_length, out_token_idx);
|
||||||
}
|
}
|
||||||
// 一次不能完整容纳完整的hidden_size, 处理尾块
|
// Handle the tail block when a full hidden_size does not fit in one pass
|
||||||
if (this->hidden_splited_remain > 0) {
|
if (this->hidden_splited_remain > 0) {
|
||||||
CalPartOutToken(start_token, this->hidden_splited_num, this->hidden_splited_remain, out_token_idx);
|
CalPartOutToken(start_token, this->hidden_splited_num, this->hidden_splited_remain, out_token_idx);
|
||||||
}
|
}
|
||||||
@@ -248,7 +248,7 @@ KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalPartOutToken(const int64_t start_
|
|||||||
int64_t end_token = start_token + this->top_k;
|
int64_t end_token = start_token + this->top_k;
|
||||||
T2 cal_token_idx = this->indicesLocal.GetValue(start_token);
|
T2 cal_token_idx = this->indicesLocal.GetValue(start_token);
|
||||||
|
|
||||||
// 处理第一个Token数据
|
// Handle the first token
|
||||||
if (cal_token_idx < this->num_out_tokens) {
|
if (cal_token_idx < this->num_out_tokens) {
|
||||||
float probsValue = 0;
|
float probsValue = 0;
|
||||||
if constexpr (PROBS) {
|
if constexpr (PROBS) {
|
||||||
@@ -263,7 +263,7 @@ KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalPartOutToken(const int64_t start_
|
|||||||
Duplicate(this->token_tensor0, static_cast<float>(0), h_length);
|
Duplicate(this->token_tensor0, static_cast<float>(0), h_length);
|
||||||
}
|
}
|
||||||
|
|
||||||
// 处理剩余的Token数据
|
// Handle the remaining tokens
|
||||||
for (int64_t token_index = start_token + 1; token_index < end_token; ++token_index) {
|
for (int64_t token_index = start_token + 1; token_index < end_token; ++token_index) {
|
||||||
cal_token_idx = this->indicesLocal.GetValue(token_index);
|
cal_token_idx = this->indicesLocal.GetValue(token_index);
|
||||||
if (cal_token_idx < this->num_out_tokens) {
|
if (cal_token_idx < this->num_out_tokens) {
|
||||||
@@ -278,7 +278,7 @@ KernelMoeTokenUnpermute<T1, T2, T3, PROBS>::CalPartOutToken(const int64_t start_
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 输出计算结果
|
// Write out the computed result
|
||||||
CopyOut(out_token_index, h_index, h_length);
|
CopyOut(out_token_index, h_index, h_length);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -146,27 +146,27 @@ public:
|
|||||||
auto gmTileD = gmD[loopIdx * blockN];
|
auto gmTileD = gmD[loopIdx * blockN];
|
||||||
LayoutC layoutUbC{1, blockN};
|
LayoutC layoutUbC{1, blockN};
|
||||||
|
|
||||||
// 把C从GM workspace搬到UB
|
// Move C from GM workspace to UB
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
||||||
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
|
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
||||||
|
|
||||||
//在UB上做把C cast成FP32
|
// Cast C to FP32 in UB
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
||||||
AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
|
AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
||||||
|
|
||||||
// 获取pertoken scale值,gmPerTokenScale的第loopIdx行
|
// Get per-token scale from row loopIdx of gmPerTokenScale
|
||||||
ElementPerTokenScale perTokenScale = gmPerTokenScale(loopIdx);
|
ElementPerTokenScale perTokenScale = gmPerTokenScale(loopIdx);
|
||||||
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
|
AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
|
AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
|
||||||
// pertoken scale值与FP32的C做Muls乘法
|
// Multiply FP32 C by the per-token scale
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN);
|
AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN);
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
|
|
||||||
// 将muls结果转回fp16/bf16
|
// Cast the muls result back to fp16/bf16
|
||||||
LayoutD layoutUbD{1, blockN};
|
LayoutD layoutUbD{1, blockN};
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
|
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
|
||||||
|
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ public:
|
|||||||
{
|
{
|
||||||
params = params_;
|
params = params_;
|
||||||
}
|
}
|
||||||
// 每个tile就是1*7168,每个block是一个expert的所有token=[group[i], 7168]
|
// Each tile is 1x7168, and each block covers all tokens for one expert = [group[i], 7168]
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
void operator() (
|
void operator() (
|
||||||
AscendC::GlobalTensor<ElementC> const &gmC,
|
AscendC::GlobalTensor<ElementC> const &gmC,
|
||||||
@@ -200,39 +200,39 @@ public:
|
|||||||
auto gmTileD = gmD[loopIdx * ChunkTileLen];
|
auto gmTileD = gmD[loopIdx * ChunkTileLen];
|
||||||
LayoutC layoutUbC{1, blockN};
|
LayoutC layoutUbC{1, blockN};
|
||||||
|
|
||||||
// 把C从GM workspace搬到UB
|
// Move C from GM workspace to UB
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
||||||
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
|
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutUbC);
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
||||||
|
|
||||||
// 在UB上做把C cast成FP32
|
// Cast C to FP32 in UB
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
||||||
AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
|
AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_NONE, blockN);
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
||||||
|
|
||||||
// 获取pertoken scale值,gmPerTokenScale的第loopIdx行
|
// Get per-token scale from row loopIdx of gmPerTokenScale
|
||||||
ElementPerTokenScale perTokenScale = gmPerTokenScale1(loopIdx);
|
ElementPerTokenScale perTokenScale = gmPerTokenScale1(loopIdx);
|
||||||
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
|
AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
|
AscendC::WaitFlag<AscendC::HardEvent::S_V>(0);
|
||||||
// pertoken scale值与FP32的C做Muls乘法
|
// Multiply FP32 C by the per-token scale
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN);
|
AscendC::Muls(ubCFp32, ubCFp32, perTokenScale, blockN);
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
|
|
||||||
//swiglue计算过程
|
// Swiglu computation process
|
||||||
AscendC::Muls(ubCFp32ChunkN, ubCFp32, -1.0f, ChunkTileLen);
|
AscendC::Muls(ubCFp32ChunkN, ubCFp32, -1.0f, ChunkTileLen);
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
AscendC::Exp(ubCFp32ChunkN, ubCFp32ChunkN, ChunkTileLen);
|
AscendC::Exp(ubCFp32ChunkN, ubCFp32ChunkN, ChunkTileLen);
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
AscendC::Adds(ubCFp32ChunkN, ubCFp32ChunkN, 1.0f, ChunkTileLen);
|
AscendC::Adds(ubCFp32ChunkN, ubCFp32ChunkN, 1.0f, ChunkTileLen);
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
//TODO除的时候是否会对之后的数据有影响;
|
// TODO: confirm whether the division impacts subsequent data
|
||||||
AscendC::Div(ubCFp32ChunkN, ubCFp32, ubCFp32ChunkN, ChunkTileLen);
|
AscendC::Div(ubCFp32ChunkN, ubCFp32, ubCFp32ChunkN, ChunkTileLen);
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
AscendC::Mul(ubCFp32ChunkN, ubCFp32ChunkN, ubCFp32[ChunkTileLen], ChunkTileLen);
|
AscendC::Mul(ubCFp32ChunkN, ubCFp32ChunkN, ubCFp32[ChunkTileLen], ChunkTileLen);
|
||||||
|
|
||||||
//quant过程,两种方式区别;
|
// Quantization process; difference between the two approaches
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
AscendC::Abs(ubAbs, ubCFp32ChunkN, ChunkTileLen);
|
AscendC::Abs(ubAbs, ubCFp32ChunkN, ChunkTileLen);
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
@@ -243,7 +243,7 @@ public:
|
|||||||
AscendC::SetFlag<AscendC::HardEvent::V_S>(0);
|
AscendC::SetFlag<AscendC::HardEvent::V_S>(0);
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_S>(0);
|
AscendC::WaitFlag<AscendC::HardEvent::V_S>(0);
|
||||||
|
|
||||||
//TODO两种计算方法的效率比较
|
// TODO: compare the efficiency of the two calculation methods
|
||||||
ElementPerTokenScale GMubDequantScale = ubReduceMax.GetValue(0);
|
ElementPerTokenScale GMubDequantScale = ubReduceMax.GetValue(0);
|
||||||
AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
|
AscendC::SetFlag<AscendC::HardEvent::S_V>(0);
|
||||||
|
|
||||||
|
|||||||
@@ -56,7 +56,7 @@ FORCE_INLINE_AICORE int32_t gm_signal_wait_until_eq_for_barrier(__gm__ int32_t *
|
|||||||
constexpr int32_t MAX_RANK_SIZE = 32;
|
constexpr int32_t MAX_RANK_SIZE = 32;
|
||||||
class HcclShmem {
|
class HcclShmem {
|
||||||
public:
|
public:
|
||||||
#ifdef HCCL_COMM // hccl需要初始化hccl context
|
#ifdef HCCL_COMM // HCCL needs to initialize the HCCL context
|
||||||
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
|
__gm__ HcclOpResParamCustom *WinContext_{nullptr};
|
||||||
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
Hccl<HCCL_SERVER_TYPE_AICPU> hccl_;
|
||||||
GM_ADDR m_ptrArray[MAX_RANK_SIZE];
|
GM_ADDR m_ptrArray[MAX_RANK_SIZE];
|
||||||
@@ -92,7 +92,7 @@ public:
|
|||||||
#endif
|
#endif
|
||||||
|
|
||||||
FORCE_INLINE_AICORE
|
FORCE_INLINE_AICORE
|
||||||
GM_ADDR operator() () const { // 无参数,返回本地peermem
|
GM_ADDR operator() () const { // No argument: return local peermem
|
||||||
#ifdef HCCL_COMM
|
#ifdef HCCL_COMM
|
||||||
return m_ptrArray[m_rank];
|
return m_ptrArray[m_rank];
|
||||||
#else
|
#else
|
||||||
@@ -101,7 +101,7 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
FORCE_INLINE_AICORE
|
FORCE_INLINE_AICORE
|
||||||
GM_ADDR operator() (int32_t index) const { // 带index参数,返回远端peermem首地址
|
GM_ADDR operator() (int32_t index) const { // With index: return remote peermem base address
|
||||||
#ifdef HCCL_COMM
|
#ifdef HCCL_COMM
|
||||||
return m_ptrArray[index];
|
return m_ptrArray[index];
|
||||||
#else
|
#else
|
||||||
@@ -126,22 +126,6 @@ public:
|
|||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
// FORCE_INLINE_AICORE
|
|
||||||
// GM_ADDR operator () (GM_ADDR ptr, int32_t index) const { // shmem_ptr相同用法
|
|
||||||
// #ifdef HCCL_COMM
|
|
||||||
// size_t offset = ptr - m_ptrArray[m_rank];
|
|
||||||
// if (offset < 0 || offset >= m_segmentSize) {
|
|
||||||
// return nullptr;
|
|
||||||
// }
|
|
||||||
// if (index < 0 || index >= m_rankSize) {
|
|
||||||
// return nullptr;
|
|
||||||
// }
|
|
||||||
// return m_ptrArray[index] + offset;
|
|
||||||
// #else
|
|
||||||
// return shmem_ptr(ptr, index);
|
|
||||||
// #endif
|
|
||||||
// }
|
|
||||||
|
|
||||||
|
|
||||||
FORCE_INLINE_AICORE
|
FORCE_INLINE_AICORE
|
||||||
~HcclShmem() {
|
~HcclShmem() {
|
||||||
|
|||||||
@@ -157,6 +157,7 @@ cd ..
|
|||||||
# Install vLLM Ascend.
|
# Install vLLM Ascend.
|
||||||
git clone --depth 1 --branch |vllm_ascend_version| https://github.com/vllm-project/vllm-ascend.git
|
git clone --depth 1 --branch |vllm_ascend_version| https://github.com/vllm-project/vllm-ascend.git
|
||||||
cd vllm-ascend
|
cd vllm-ascend
|
||||||
|
git submodule update --init --recursive
|
||||||
pip install -v -e .
|
pip install -v -e .
|
||||||
cd ..
|
cd ..
|
||||||
```
|
```
|
||||||
|
|||||||
Reference in New Issue
Block a user