diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp index 511f653c..1f991815 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp @@ -17,59 +17,94 @@ public: { this->Input("x") .ParamType(REQUIRED) - .DataType({ge::DT_BF16, ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("expert_ids") .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, + ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm1_permuted_weight") .ParamType(DYNAMIC) - .DataType({ge::DT_INT8, ge::DT_INT8}) - .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) - .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, + ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) + .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) + .UnknownShapeFormat( + {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); this->Input("gmm1_permuted_weight_scale") .ParamType(DYNAMIC) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_BF16, ge::DT_BF16, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm2_weight") .ParamType(DYNAMIC) - .DataType({ge::DT_INT8, ge::DT_INT8}) - .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) - .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, + ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) + .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) + .UnknownShapeFormat( + {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); this->Input("gmm2_weight_scale") .ParamType(DYNAMIC) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT, ge::DT_BF16, + ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("expert_scales") .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("expert_smooth_scales") .ParamType(OPTIONAL) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("x_active_mask") .ParamType(OPTIONAL) - .DataType({ge::DT_BOOL, ge::DT_BOOL}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, + ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("output") .ParamType(REQUIRED) - .DataType({ge::DT_BF16, ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("expert_token_nums") .ParamType(REQUIRED) - .DataType({ge::DT_INT64, ge::DT_INT64}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, + ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Attr("group_ep").String(); this->Attr("ep_rank_size").Int(); this->Attr("ep_rank_id").Int(); diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp index 02f8ada4..aae344af 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp @@ -27,7 +27,8 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( GET_TILING_DATA(tiling_data, tiling); if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) || TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7)) { - DispatchGmmCombineDecode op; + DispatchGmmCombineDecode< + DTYPE_X, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, int32_t, false, TILING_KEY_VAR> op; op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data); op.Process(); diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h index bedbc52b..97aa44ea 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h @@ -54,7 +54,7 @@ using Gmm2DispatchPolicy = GMM2_L0A_STAGES, GMM2_L0B_STAGES, CUSTOM_L0C_STAGES, CUSTOM_ENABLE_UNIT_FLAG, CUSTOM_ENABLE_SHUFFLE_K>; -template CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale, @@ -72,7 +72,6 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun using L1TileShape = L1TileShape_; using L0TileShape = L0TileShape_; - using XType = XType_; using AType = Gemm::GemmType; using BType = Gemm::GemmType; using CType = Gemm::GemmType; @@ -81,7 +80,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun constexpr uint32_t ubStages = 1; using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantSwiglu; - using ScaleType = Gemm::GemmType; + using ScaleType = Gemm::GemmType; using PerTokenScaleType = Gemm::GemmType; using DType = Gemm::GemmType; @@ -110,9 +109,9 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun using GemmKernel = typename std::conditional< (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE), Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace< - EXEC_FLAG, XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< - BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type; + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type; if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { typename GemmKernel::Params params{problemShape, @@ -197,7 +196,7 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR constexpr uint32_t ubStages = 1; using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantCombine; - using ScaleType = Gemm::GemmType; + using ScaleType = Gemm::GemmType; using PerTokenScaleType = Gemm::GemmType; using DType = Gemm::GemmType; @@ -411,7 +410,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() Arch::CrossCoreWaitFlag(gmm1AivFinished); } } - GmmDeqSwigluQuant( gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1, gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale, diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant.hpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant.hpp index cf7d16e3..dc64875a 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant.hpp +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant.hpp @@ -19,304 +19,16 @@ #include "catlass/layout/layout.hpp" #include "catlass/matrix_coord.hpp" -#define ENABLE_EP_SEND_COUNT_HASH 0 - namespace Catlass::Epilogue::Block { -template -class BlockEpilogue, CType_, ScaleType_, PerTokenScaleType_, - DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_, - EpilogueTileSwizzle_> -{ -public: - using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine; - using ArchTag = typename DispatchPolicy::ArchTag; - static constexpr uint32_t UB_STAGES = UB_STAGES_; - - // Data infos - using ElementC = typename CType_::Element; - using LayoutC = typename CType_::Layout; - using ElementScale = typename ScaleType_::Element; - using LayoutScale = typename ScaleType_::Layout; - using ElementPerTokenScale = typename PerTokenScaleType_::Element; - using LayoutPerTokenScale = typename PerTokenScaleType_::Layout; - using ElementD = typename DType_::Element; - using LayoutD = typename DType_::Layout; - - // Check data infos - static_assert(std::is_same_v && - (std::is_same_v || std::is_same_v) && - std::is_same_v && std::is_same_v, - "The element type template parameters of BlockEpilogue are wrong"); - static_assert(std::is_same_v && std::is_same_v && - std::is_same_v && - std::is_same_v, - "The layout template parameters of BlockEpilogue are wrong"); - - // Tile compute ops - using TileRowBroadcastMul = TileRowBroadcastMul_; - using TileBroadcastOneBlk = TileBroadcastOneBlk_; - using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; - - // Tile copy - using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; - using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; - using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; - using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; - - using EpilogueTileSwizzle = EpilogueTileSwizzle_; - - using TileShape = typename TileRowBroadcastMul::TileShape; - - static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && - std::is_same_v, - "TileShape must be consistent for all tile compute ops"); - - static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + - TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + - (TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + TileShape::ROW) * sizeof(float) + - TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, - "TileShape is too large to fit in UB"); - - struct Params { - __gm__ ElementScale *ptrScale{nullptr}; - LayoutScale layoutScale{}; - __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; - LayoutPerTokenScale layoutPerTokenScale{}; - __gm__ ElementD *ptrD{nullptr}; - LayoutD layoutD{}; - - CATLASS_DEVICE - Params() {}; - - CATLASS_DEVICE - Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, - __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, - __gm__ ElementD *ptrD_, LayoutD const &layoutD_) - : ptrScale(ptrScale_), - layoutScale(layoutScale_), - ptrPerTokenScale(ptrPerTokenScale_), - layoutPerTokenScale(layoutPerTokenScale_), - ptrD(ptrD_), - layoutD(layoutD_) - {} - }; - - CATLASS_DEVICE - BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) - { - size_t ubOffset = 0; - int32_t eventVMTE2 = 0; - int32_t eventMTE2V = 0; - int32_t eventMTE3V = 0; - int32_t eventVMTE3 = 0; - for (uint32_t i = 0; i < UB_STAGES; ++i) { - ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(ElementC); - ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(ElementScale); - ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); - ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(ElementD); - - eventUbCVMTE2List[i] = eventVMTE2++; - eventUbCMTE2VList[i] = eventMTE2V++; - eventUbScaleVMTE2List[i] = eventVMTE2++; - eventUbScaleMTE2VList[i] = eventMTE2V++; - eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; - eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; - eventUbDMTE3VList[i] = eventMTE3V++; - eventUbDVMTE3List[i] = eventVMTE3++; - - AscendC::SetFlag(eventUbCVMTE2List[i]); - AscendC::SetFlag(eventUbScaleVMTE2List[i]); - AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); - AscendC::SetFlag(eventUbDMTE3VList[i]); - } - ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(float); - ubScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(float); - ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(float); - ubPerTokenScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * sizeof(float); - ubPerTokenScaleFp32Brcb = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * BYTE_PER_BLK; - ubPerTokenMul = ubMul; - } - - CATLASS_DEVICE - ~BlockEpilogue() - { - for (uint32_t i = 0; i < UB_STAGES; ++i) { - AscendC::WaitFlag(eventUbCVMTE2List[i]); - AscendC::WaitFlag(eventUbScaleVMTE2List[i]); - AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); - AscendC::WaitFlag(eventUbDMTE3VList[i]); - } - } - - CATLASS_DEVICE - void UpdateParams(Params const ¶ms_) - { - params = params_; - } - - CATLASS_DEVICE - void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, - GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, - LayoutC const &layoutBlockC, Callback &&callback = Callback{}) - { - if (actualBlockShapeMNK.k() == 0) { - return; - } - callback(); - - // Calculate the offset of the current block - MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); - MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); - MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); - MatrixCoord blockOffset = blockCoord * blockShape; - - AscendC::GlobalTensor gmScale; - gmScale.SetGlobalBuffer(params.ptrScale); - AscendC::GlobalTensor gmPerTokenScale; - gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); - AscendC::GlobalTensor gmD; - gmD.SetGlobalBuffer(params.ptrD); - - auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); - auto tileShape = TileShape::ToCoord(); - EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); - uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); - uint32_t subblockIdx = AscendC::GetSubBlockIdx(); - uint32_t subblockNum = AscendC::GetSubBlockNum(); - for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { - auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); - auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); - auto tileOffsetInBlock = tileCoord * tileShape; - auto tileOffset = blockOffset + tileOffsetInBlock; - - auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; - auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); - - auto &ubC = ubCList[ubListId]; - LayoutC layoutUbC{actualTileShape, ubTileStride}; - - AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); - copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); - AscendC::SetFlag(eventUbCMTE2VList[ubListId]); - - auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); - auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); - - auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; - auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); - - auto &ubScale = ubScaleList[ubListId]; - auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); - - AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); - copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); - AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); - - auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); - auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); - - auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; - auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); - - auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; - auto layoutUbPerTokenScale = - LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); - - AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); - copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, - layoutGmTilePerTokenScale); - AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); - - AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); - AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); - AscendC::SetFlag(eventUbCVMTE2List[ubListId]); - - AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); - AscendC::Cast(ubScaleFp32, ubScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN); - AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); - - AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); - AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, AscendC::RoundMode::CAST_NONE, TileShape::ROW); - AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); - - tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32); - tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32); - AscendC::PipeBarrier(); - tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleFp32Brcb); - AscendC::PipeBarrier(); - - auto &ubD = ubDList[ubListId]; - LayoutD layoutUbD{actualTileShape, ubTileStride}; - - AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); - AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); - AscendC::SetFlag(eventUbDVMTE3List[ubListId]); - - auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)]; - auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); - - AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); - copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); - AscendC::SetFlag(eventUbDMTE3VList[ubListId]); - - ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; - } - } - -private: - Params params; - - AscendC::LocalTensor ubCList[UB_STAGES]; - AscendC::LocalTensor ubScaleList[UB_STAGES]; - AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; - AscendC::LocalTensor ubDList[UB_STAGES]; - - int32_t eventUbCVMTE2List[UB_STAGES]; - int32_t eventUbCMTE2VList[UB_STAGES]; - int32_t eventUbScaleVMTE2List[UB_STAGES]; - int32_t eventUbScaleMTE2VList[UB_STAGES]; - int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; - int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; - int32_t eventUbDMTE3VList[UB_STAGES]; - int32_t eventUbDVMTE3List[UB_STAGES]; - - uint32_t ubListId{0}; - - AscendC::LocalTensor ubCFp32; - AscendC::LocalTensor ubScaleFp32; - AscendC::LocalTensor ubMul; - AscendC::LocalTensor ubPerTokenScaleFp32; - AscendC::LocalTensor ubPerTokenScaleFp32Brcb; - AscendC::LocalTensor ubPerTokenMul; - - TileRowBroadcastMul tileRowBroadcastMul; - TileBroadcastOneBlk tileBroadcastOneBlk; - TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; - - CopyGmToUbC copyGmToUbC; - CopyGmToUbScale copyGmToUbScale; - CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; - CopyUbToGmD copyUbToGmD; -}; - -template -class BlockEpilogue, CType_, Gemm::GemmType, - Gemm::GemmType, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, - TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_> +class BlockEpilogue, + CType_, Gemm::GemmType, Gemm::GemmType, DType_, + TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, + TileCopy_, EpilogueTileSwizzle_> { public: using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine; @@ -327,7 +39,8 @@ public: // Data infos using ElementC = typename CType_::Element; using LayoutC = typename CType_::Layout; - using ElementScale = float; + using ElementRawScale = ScaleType_; + using ElementFp32Scale = float; using LayoutScale = LayoutScale_; using ElementPerTokenScale = float; using LayoutPerTokenScale = LayoutPerTokenScale_; @@ -362,14 +75,16 @@ public: std::is_same_v, "TileShape must be consistent for all tile compute ops"); - static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + + (std::is_same_v ? + 0 : TileShape::COLUMN * sizeof(ElementRawScale)) + + TileShape::COLUMN * sizeof(ElementFp32Scale) + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, "TileShape is too large to fit in UB"); - struct Params { - __gm__ ElementScale *ptrScale{nullptr}; + __gm__ ElementRawScale *ptrScale{nullptr}; LayoutScale layoutScale{}; __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; LayoutPerTokenScale layoutPerTokenScale{}; @@ -380,7 +95,7 @@ public: Params() {}; CATLASS_DEVICE - Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_, __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, __gm__ ElementD *ptrD_, LayoutD const &layoutD_) : ptrScale(ptrScale_), @@ -408,8 +123,12 @@ public: for (uint32_t i = 0; i < UB_STAGES; ++i) { ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); ubOffset += TileShape::COUNT * sizeof(ElementC); - ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(ElementScale); + if constexpr (!std::is_same_v) { + ubRawScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementRawScale); + } + ubFp32ScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementFp32Scale); ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); @@ -451,22 +170,6 @@ public: AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams); AscendC::SetFlag(eventMTE2S); AscendC::WaitFlag(eventMTE2S); -#if ENABLE_EP_SEND_COUNT_HASH - tokenToEpRankHashLocal_ = resource.ubBuf.template GetBufferByByte(ubOffset); - uint32_t maxGroupSendCount = 0; - uint32_t groupSendCount = 0; - for (uint32_t expertIdx = 0; expertIdx < calcInfo.moeExpertPerRankNum_; ++expertIdx) { - uint32_t prevGroupSendCount = groupSendCount; - groupSendCount = epSendCountLocal_.GetValue((expertIdx + 1) * calcInfo.epWorldSize_ - 1); - if (maxGroupSendCount < groupSendCount - prevGroupSendCount) { - maxGroupSendCount = groupSendCount - prevGroupSendCount; - } - } - ubOffset += maxGroupSendCount * sizeof(int32_t); - AlignUbOffset(); - // assert: ubOffset <= AscendC::TOTAL_UB_SIZE or - // AscendC::TOTAL_VEC_LOCAL_SIZE -#endif } } @@ -495,28 +198,6 @@ public: ->windowsIn) + calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET; } -#if ENABLE_EP_SEND_COUNT_HASH - CATLASS_DEVICE void InitTokenToEpRankHashLocalForEpRank(uint32_t &hashOffset, uint32_t epRank, uint32_t copyLen) - { - constexpr uint32_t DUPLICATE_MASK_COUNT = 8; - uint32_t hashOffsetMask = (((uint32_t)hashOffset) & (DUPLICATE_MASK_COUNT - 1)); - if (hashOffsetMask != 0) { - uint32_t remainMaskCount = DUPLICATE_MASK_COUNT - hashOffsetMask; - if (copyLen < remainMaskCount) { - remainMaskCount = copyLen; - } - uint64_t copyMask = ((1UL << remainMaskCount) - 1) << hashOffsetMask; - AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset - hashOffsetMask], epRank, ©Mask, 1, 1, - DUPLICATE_MASK_COUNT); - hashOffset += remainMaskCount; - copyLen -= remainMaskCount; - } - if (copyLen > 0) { - AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset], epRank, copyLen); - hashOffset += copyLen; - } - } -#endif CATLASS_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank) { @@ -542,11 +223,7 @@ public: uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_; uint32_t itToken = startToken; uint32_t endToken = startToken + layoutGmTileD.shape(0); -#if ENABLE_EP_SEND_COUNT_HASH - uint32_t epRankStart = tokenToEpRankHashLocal_(itToken - startToken); -#else constexpr uint32_t epRankStart = 0; -#endif uint32_t sendCount = expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1); for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) { @@ -582,20 +259,6 @@ public: if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { expertOffset = expertIdx * calcInfo.epWorldSize_; -#if ENABLE_EP_SEND_COUNT_HASH - if (currentExpertIdx_ != expertIdx) { - uint32_t hashOffset = 0; - uint32_t sendCount = expertIdx == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset - 1); - for (uint32_t epRank = 0; epRank < calcInfo.epWorldSize_; ++epRank) { - uint32_t prevSendCount = sendCount; - sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); - InitTokenToEpRankHashLocalForEpRank(hashOffset, epRank, sendCount - prevSendCount); - } - AscendC::SetFlag(eventVS); - AscendC::WaitFlag(eventVS); - currentExpertIdx_ = expertIdx; - } -#endif } callback(); @@ -605,7 +268,7 @@ public: MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); MatrixCoord blockOffset = blockCoord * blockShape; - AscendC::GlobalTensor gmScale; + AscendC::GlobalTensor gmScale; gmScale.SetGlobalBuffer(params.ptrScale); AscendC::GlobalTensor gmPerTokenScale; gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); @@ -640,11 +303,16 @@ public: auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); - auto &ubScale = ubScaleList[ubListId]; - auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); - + auto &ubFp32Scale = ubFp32ScaleList[ubListId]; + auto layoutFp32UbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + auto &ubRawScale = ubRawScaleList[ubListId]; + auto layoutRawUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); - copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + if constexpr (!std::is_same_v) { + copyGmToUbScale(ubRawScale, gmTileScale, layoutRawUbScale, layoutGmTileScale); + } else { + copyGmToUbScale(ubFp32Scale, gmTileScale, layoutFp32UbScale, layoutGmTileScale); + } AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); @@ -667,7 +335,11 @@ public: AscendC::SetFlag(eventUbCVMTE2List[ubListId]); AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); - tileRowBroadcastMul(ubMul, ubCFp32, ubScale); + if constexpr (!std::is_same_v) { + AscendC::Cast(ubFp32Scale, ubRawScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN); + AscendC::PipeBarrier(); + } + tileRowBroadcastMul(ubMul, ubCFp32, ubFp32Scale); AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); @@ -709,7 +381,8 @@ private: MoeDistributeCombineImpl::CombineCalcInfo calcInfo; AscendC::LocalTensor ubCList[UB_STAGES]; - AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubRawScaleList[UB_STAGES]; + AscendC::LocalTensor ubFp32ScaleList[UB_STAGES]; AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; AscendC::LocalTensor ubDList[UB_STAGES]; @@ -723,10 +396,6 @@ private: int32_t eventUbDVMTE3List[UB_STAGES]; AscendC::LocalTensor epSendCountLocal_; -#if ENABLE_EP_SEND_COUNT_HASH - AscendC::LocalTensor tokenToEpRankHashLocal_; - uint32_t currentExpertIdx_{static_cast(-1)}; -#endif size_t ubOffset{0}; int32_t eventVMTE2{0}; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h index f203dbfe..a14b0cc2 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h @@ -21,13 +21,14 @@ namespace Catlass::Epilogue::Block { -template -class BlockEpilogue, CType_, - Gemm::GemmType, Gemm::GemmType, DType_, - TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_, - EpilogueTileSwizzle_> +class BlockEpilogue, + CType_, Gemm::GemmType, Gemm::GemmType, + DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, + TileCopy_, EpilogueTileSwizzle_> { public: using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwiglu; @@ -37,7 +38,8 @@ public: // Data infos using ElementC = typename CType_::Element; using LayoutC = typename CType_::Layout; - using ElementScale = float; + using ElementRawScale = ScaleType_; + using ElementFp32Scale = float; using LayoutScale = LayoutScale_; using ElementPerTokenScale = float; using LayoutPerTokenScale = LayoutPerTokenScale_; @@ -76,14 +78,17 @@ public: static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough."); - static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + + (std::is_same_v ? + 0 : TileShape::COLUMN * sizeof(ElementRawScale)) + + TileShape::COLUMN * sizeof(ElementFp32Scale) + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, "TileShape is too large to fit in UB"); struct Params { - __gm__ ElementScale *ptrScale{nullptr}; + __gm__ ElementRawScale *ptrScale{nullptr}; LayoutScale layoutScale{}; __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; LayoutPerTokenScale layoutPerTokenScale{}; @@ -94,7 +99,7 @@ public: Params() {}; CATLASS_DEVICE - Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_, __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, __gm__ ElementD *ptrD_, LayoutD const &layoutD_) : ptrScale(ptrScale_), @@ -117,8 +122,12 @@ public: for (uint32_t i = 0; i < UB_STAGES; ++i) { ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); ubOffset += TileShape::COUNT * sizeof(ElementC); - ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(ElementScale); + if constexpr (!std::is_same_v) { + ubRawScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementRawScale); + } + ubFp32ScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementFp32Scale); ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); @@ -177,7 +186,7 @@ public: MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); MatrixCoord blockOffset = blockCoord * blockShape; bool isLeft = blockOffset.column() < (params.layoutD.shape(1) >> 1); - AscendC::GlobalTensor gmScale; + AscendC::GlobalTensor gmScale; gmScale.SetGlobalBuffer(params.ptrScale); AscendC::GlobalTensor gmPerTokenScale; gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); @@ -212,11 +221,16 @@ public: auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); - auto &ubScale = ubScaleList[ubListId]; - auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); - + auto &ubFp32Scale = ubFp32ScaleList[ubListId]; + auto layoutFp32UbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + auto &ubRawScale = ubRawScaleList[ubListId]; + auto layoutRawUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); - copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + if constexpr (!std::is_same_v) { + copyGmToUbScale(ubRawScale, gmTileScale, layoutRawUbScale, layoutGmTileScale); + } else { + copyGmToUbScale(ubFp32Scale, gmTileScale, layoutFp32UbScale, layoutGmTileScale); + } AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); @@ -238,7 +252,11 @@ public: AscendC::Cast(ubTmpMxN, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); AscendC::SetFlag(eventUbCVMTE2List[ubListId]); AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); - tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubScale); + if constexpr (!std::is_same_v) { + AscendC::Cast(ubFp32Scale, ubRawScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN); + AscendC::PipeBarrier(); + } + tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubFp32Scale); AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale); @@ -279,7 +297,8 @@ private: Params params; AscendC::LocalTensor ubCList[UB_STAGES]; - AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubRawScaleList[UB_STAGES]; + AscendC::LocalTensor ubFp32ScaleList[UB_STAGES]; AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; AscendC::LocalTensor ubDList[UB_STAGES]; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h index 35b3512c..28eb10d0 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h @@ -39,7 +39,7 @@ public: using ElementAccumulator = typename BlockMmad::ElementAccumulator; using BlockEpilogue = BlockEpilogue_; - using ElementScale = typename BlockEpilogue::ElementScale; + using ElementScale = typename BlockEpilogue::ElementRawScale; using LayoutScale = typename BlockEpilogue::LayoutScale; using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h index cf4956e0..967e5869 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h @@ -354,7 +354,7 @@ __aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row) row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE; } -template class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace { @@ -371,7 +371,7 @@ public: using ElementAccumulator = typename BlockMmad::ElementAccumulator; using BlockEpilogue = BlockEpilogue_; - using ElementScale = typename BlockEpilogue::ElementScale; + using ElementScale = typename BlockEpilogue::ElementRawScale; using LayoutScale = typename BlockEpilogue::LayoutScale; using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; @@ -388,7 +388,7 @@ public: static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; using ElementGroupList = ElementGroupList_; - using XType = XType_; + using XType = ExpandXType; // Parameters structure struct Params { @@ -1715,7 +1715,7 @@ private: namespace Catlass::Gemm::Kernel { -template class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch { @@ -1732,7 +1732,7 @@ public: using ElementAccumulator = typename BlockMmad::ElementAccumulator; using BlockEpilogue = BlockEpilogue_; - using ElementScale = typename BlockEpilogue::ElementScale; + using ElementScale = typename BlockEpilogue::ElementRawScale; using LayoutScale = typename BlockEpilogue::LayoutScale; using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; @@ -2017,7 +2017,7 @@ private: struct AicWaitFunc { using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< - BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; CATLASS_DEVICE AicWaitFunc() = default; @@ -2034,7 +2034,7 @@ private: struct AicSetFunc { using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< - BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; CATLASS_DEVICE AicSetFunc() = default; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h index 0c57896c..b9ac8932 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h @@ -12,8 +12,8 @@ #include "../common/moe_distribute_base.h" -#define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG -#define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG +#define TemplateMC2TypeClass typename ExpandXType, typename W1ScaleType, typename W2ScaleType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG +#define TemplateMC2TypeFunc ExpandXType, W1ScaleType, W2ScaleType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG #define TemplateDispatchTypeClass \ typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \ bool IsNeedAllgater, uint32_t EXEC_FLAG diff --git a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py index 547e5c1d..b928ad7e 100644 --- a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py +++ b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py @@ -275,8 +275,6 @@ class FusionOp(DecodeMoeOps): torch_npu.Format.FRACTAL_NZ) gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, torch_npu.Format.FRACTAL_NZ) - gmm1_weight_scale = gmm1_weight_scale.float() - gmm2_weight_scale = gmm2_weight_scale.float() if self.dynamic_eplb: self.gmm1_weight = [ diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py index 6b5cce8d..500c9d4b 100644 --- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py +++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py @@ -46,15 +46,12 @@ class VllmEplbAdaptor(EplbAdaptor): self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \ self.model.model.layers[i].mlp.experts.w2_weight_scale_list - self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_fp32_list"] = \ - self.model.model.layers[i].mlp.experts.w2_weight_scale_fp32_list # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here if self.model.quant_config is not None: self.expert_weight_names = [ "w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", "w13_weight_offset", - "w2_weight_scale_list", "w2_weight_offset", - "w2_weight_scale_fp32_list" + "w2_weight_scale_list", "w2_weight_offset" ] else: self.expert_weight_names = ["w13_weight", "w2_weight"] @@ -83,8 +80,7 @@ class VllmEplbAdaptor(EplbAdaptor): self.num_dense_layers) + ".mlp.experts." + name if name in [ "w13_weight_list", "w2_weight_list", - "w13_weight_scale_fp32_list", "w2_weight_scale_list", - "w2_weight_scale_fp32_list" + "w13_weight_scale_fp32_list", "w2_weight_scale_list" ]: expert_tensor = self.param_dict[complete_name][0] expert_tensor = expert_tensor.clone() @@ -105,7 +101,7 @@ class VllmEplbAdaptor(EplbAdaptor): if name in [ "w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", - "w2_weight_scale_list", "w2_weight_scale_fp32_list" + "w2_weight_scale_list" ]: per_expert_param.append( self.param_dict["model.layers." + str(layer_idx) + diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 0c6def33..f3037073 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -243,24 +243,16 @@ class AscendW8A8DynamicFusedMoEMethod: topk_weights = topk_weights.to(self.in_dtype) moe_comm_method = get_forward_context().moe_comm_method - # When VLLM_ASCEND_ENABLE_FUSED_MC2 == 2, use dispatch_gmm_combine_decode, need fp32 scale - w2_weight_scale_fp32_flag = ( - get_forward_context().moe_comm_type == MoECommType.FUSED_MC2 - and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2) if self.dynamic_eplb: w1 = layer.w13_weight_list w1_scale = layer.w13_weight_scale_fp32_list w2 = layer.w2_weight_list - w2_scale = layer.w2_weight_scale_fp32_list \ - if w2_weight_scale_fp32_flag else layer.w2_weight_scale_list + w2_scale = layer.w2_weight_scale_list else: w1 = [layer.w13_weight] w1_scale = [layer.w13_weight_scale_fp32] w2 = [layer.w2_weight] - w2_scale = [ - layer.w2_weight_scale_fp32 - if w2_weight_scale_fp32_flag else layer.w2_weight_scale - ] + w2_scale = [layer.w2_weight_scale] fused_scale_flag = (get_forward_context().moe_comm_type == MoECommType.FUSED_MC2 @@ -302,8 +294,6 @@ class AscendW8A8DynamicFusedMoEMethod: layer.w13_weight_offset.data.shape[0], -1) layer.w2_weight_scale.data = layer.w2_weight_scale.data.view( layer.w2_weight_scale.data.shape[0], -1) - layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to( - torch.float32) layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( layer.w2_weight_offset.data.shape[0], -1) @@ -328,16 +318,11 @@ class AscendW8A8DynamicFusedMoEMethod: weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0) ] - layer.w2_weight_scale_fp32_list = [ - weight.clone() - for weight in layer.w2_weight_scale_fp32.data.unbind(dim=0) - ] del layer.w13_weight del layer.w2_weight del layer.w13_weight_scale del layer.w13_weight_scale_fp32 del layer.w2_weight_scale - del layer.w2_weight_scale_fp32 torch.npu.empty_cache()