From ebb940691fcc1b6e48bc14e8d180934da4d1542f Mon Sep 17 00:00:00 2001 From: wangqiankun13 Date: Mon, 19 Jan 2026 16:10:43 +0800 Subject: [PATCH] [Feature] Adapt DispatchGmmCombineDecode operator to align with weight scale dtype of small operators. [RFC: issue 5476] (#5755) ### What this PR does / why we need it? [Feature] Adapt DispatchGmmCombineDecode operator to align with weight scale dtype of small operators. - **Before**: weight scale must be float32 - **After**: weight scale can be float32/float16 when x is float16, float32/bfloat16 when x is bfloat16. And w1 scale can use a different dtype from w2 scale. For more info about this operator, please refer to RFC issue https://github.com/vllm-project/vllm-ascend/issues/5476 ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? #### Perf > When scale is of type fp16 or bf16, it will be cast to fp32 internally within the operator, while the subsequent computations remain unchanged. Therefore, this PR will introduce an additional cast operation but halve the memory copy operations for scale. Furthermore, since the scale data is only a few KB in size and participates in relatively few computations, its impact is almost negligible compared to major operations like matrix multiplication. Thus, the theoretical performance change should be minimal. test single operator cases from qwen3-235b, - single A3 node(ep16), 64 moe experts, 4 experts / die (like qwen3-235b ep32) - batch=18/32, token_hidden_size 4096, moe_intermediate_size 1536 The test was conducted for 100 rounds, and the average of the last 95 rounds was taken. | | bs18(us)| bs32(us)| | -----| -----| -----| |Without this PR|96.28|108.83| |With this PR|96.06|107.90| Note: Single-operator benchmarks represent an ideal scenario. They are usually only useful for referencing relative changes and may not fully align with performance data observed within the full model. 
#### Acc test qwen3-235b eplb on a single A3 node(ep16), with dispatch_gmm_combine_decode | dataset | version | metric | mode | vllm-api-stream-chat | |----- | ----- | ----- | ----- | -----| | aime2024 | 604a78 | accuracy | gen | 83.33 | - vLLM version: v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/2f4e6548efec402b913ffddc8726230d9311948d Signed-off-by: wangqiankun --- .../dispatch_gmm_combine_decode_def.cpp | 101 +++-- .../op_kernel/dispatch_gmm_combine_decode.cpp | 3 +- .../op_kernel/dispatch_gmm_combine_decode.h | 13 +- .../block_epilogue_per_token_dequant.hpp | 407 ++---------------- .../block_epilogue_per_token_dequant_swiglu.h | 55 ++- ...m_per_token_dequant_multistage_workspace.h | 2 +- ...equant_swiglu_quant_multistage_workspace.h | 14 +- .../dispatch_gmm_combine_decode_base.h | 4 +- .../test_dispatch_gmm_combine_decode.py | 2 - vllm_ascend/eplb/adaptor/vllm_adaptor.py | 10 +- vllm_ascend/quantization/w8a8_dynamic.py | 19 +- 11 files changed, 166 insertions(+), 464 deletions(-) diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp index 511f653c..1f991815 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp @@ -17,59 +17,94 @@ public: { this->Input("x") .ParamType(REQUIRED) - .DataType({ge::DT_BF16, ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, 
ge::FORMAT_ND}); this->Input("expert_ids") .ParamType(REQUIRED) - .DataType({ge::DT_INT32, ge::DT_INT32}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, + ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm1_permuted_weight") .ParamType(DYNAMIC) - .DataType({ge::DT_INT8, ge::DT_INT8}) - .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) - .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, + ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) + .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) + .UnknownShapeFormat( + {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); this->Input("gmm1_permuted_weight_scale") .ParamType(DYNAMIC) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_BF16, ge::DT_BF16, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); 
this->Input("gmm2_weight") .ParamType(DYNAMIC) - .DataType({ge::DT_INT8, ge::DT_INT8}) - .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) - .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, + ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) + .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) + .UnknownShapeFormat( + {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); this->Input("gmm2_weight_scale") .ParamType(DYNAMIC) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT, ge::DT_BF16, + ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("expert_scales") .ParamType(REQUIRED) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("expert_smooth_scales") 
.ParamType(OPTIONAL) - .DataType({ge::DT_FLOAT, ge::DT_FLOAT}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("x_active_mask") .ParamType(OPTIONAL) - .DataType({ge::DT_BOOL, ge::DT_BOOL}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, + ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("output") .ParamType(REQUIRED) - .DataType({ge::DT_BF16, ge::DT_FLOAT16}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("expert_token_nums") .ParamType(REQUIRED) - .DataType({ge::DT_INT64, ge::DT_INT64}) - .Format({ge::FORMAT_ND, ge::FORMAT_ND}) - .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND}); + .DataType({ge::DT_INT64, ge::DT_INT64, 
ge::DT_INT64, ge::DT_INT64, + ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Attr("group_ep").String(); this->Attr("ep_rank_size").Int(); this->Attr("ep_rank_id").Int(); diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp index 02f8ada4..aae344af 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp @@ -27,7 +27,8 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( GET_TILING_DATA(tiling_data, tiling); if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) || TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7)) { - DispatchGmmCombineDecode op; + DispatchGmmCombineDecode< + DTYPE_X, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, int32_t, false, TILING_KEY_VAR> op; op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data); op.Process(); diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h index bedbc52b..97aa44ea 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h @@ -54,7 +54,7 @@ using Gmm2DispatchPolicy = GMM2_L0A_STAGES, GMM2_L0B_STAGES, CUSTOM_L0C_STAGES, CUSTOM_ENABLE_UNIT_FLAG, 
CUSTOM_ENABLE_SHUFFLE_K>; -template CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale, @@ -72,7 +72,6 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun using L1TileShape = L1TileShape_; using L0TileShape = L0TileShape_; - using XType = XType_; using AType = Gemm::GemmType; using BType = Gemm::GemmType; using CType = Gemm::GemmType; @@ -81,7 +80,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun constexpr uint32_t ubStages = 1; using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantSwiglu; - using ScaleType = Gemm::GemmType; + using ScaleType = Gemm::GemmType; using PerTokenScaleType = Gemm::GemmType; using DType = Gemm::GemmType; @@ -110,9 +109,9 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun using GemmKernel = typename std::conditional< (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE), Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace< - EXEC_FLAG, XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< - BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type; + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type; if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { typename GemmKernel::Params params{problemShape, @@ -197,7 +196,7 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR constexpr uint32_t ubStages = 1; using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantCombine; - using ScaleType = Gemm::GemmType; + using ScaleType = 
Gemm::GemmType; using PerTokenScaleType = Gemm::GemmType; using DType = Gemm::GemmType; @@ -411,7 +410,7 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() Arch::CrossCoreWaitFlag(gmm1AivFinished); } } - GmmDeqSwigluQuant( gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1, gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale, diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant.hpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant.hpp index cf7d16e3..dc64875a 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant.hpp +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant.hpp @@ -19,304 +19,16 @@ #include "catlass/layout/layout.hpp" #include "catlass/matrix_coord.hpp" -#define ENABLE_EP_SEND_COUNT_HASH 0 - namespace Catlass::Epilogue::Block { -template -class BlockEpilogue, CType_, ScaleType_, PerTokenScaleType_, - DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_, - EpilogueTileSwizzle_> -{ -public: - using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine; - using ArchTag = typename DispatchPolicy::ArchTag; - static constexpr uint32_t UB_STAGES = UB_STAGES_; - - // Data infos - using ElementC = typename CType_::Element; - using LayoutC = typename CType_::Layout; - using ElementScale = typename ScaleType_::Element; - using LayoutScale = typename ScaleType_::Layout; - using ElementPerTokenScale = typename PerTokenScaleType_::Element; - using LayoutPerTokenScale = typename PerTokenScaleType_::Layout; - using ElementD = typename DType_::Element; - using LayoutD = typename DType_::Layout; - - // Check data infos - static_assert(std::is_same_v && 
- (std::is_same_v || std::is_same_v) && - std::is_same_v && std::is_same_v, - "The element type template parameters of BlockEpilogue are wrong"); - static_assert(std::is_same_v && std::is_same_v && - std::is_same_v && - std::is_same_v, - "The layout template parameters of BlockEpilogue are wrong"); - - // Tile compute ops - using TileRowBroadcastMul = TileRowBroadcastMul_; - using TileBroadcastOneBlk = TileBroadcastOneBlk_; - using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; - - // Tile copy - using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; - using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; - using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; - using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; - - using EpilogueTileSwizzle = EpilogueTileSwizzle_; - - using TileShape = typename TileRowBroadcastMul::TileShape; - - static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && - std::is_same_v, - "TileShape must be consistent for all tile compute ops"); - - static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + - TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + - (TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + TileShape::ROW) * sizeof(float) + - TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, - "TileShape is too large to fit in UB"); - - struct Params { - __gm__ ElementScale *ptrScale{nullptr}; - LayoutScale layoutScale{}; - __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; - LayoutPerTokenScale layoutPerTokenScale{}; - __gm__ ElementD *ptrD{nullptr}; - LayoutD layoutD{}; - - CATLASS_DEVICE - Params() {}; - - CATLASS_DEVICE - Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, - __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, - __gm__ ElementD *ptrD_, LayoutD const &layoutD_) - : ptrScale(ptrScale_), - layoutScale(layoutScale_), - 
ptrPerTokenScale(ptrPerTokenScale_), - layoutPerTokenScale(layoutPerTokenScale_), - ptrD(ptrD_), - layoutD(layoutD_) - {} - }; - - CATLASS_DEVICE - BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) - { - size_t ubOffset = 0; - int32_t eventVMTE2 = 0; - int32_t eventMTE2V = 0; - int32_t eventMTE3V = 0; - int32_t eventVMTE3 = 0; - for (uint32_t i = 0; i < UB_STAGES; ++i) { - ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(ElementC); - ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(ElementScale); - ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); - ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(ElementD); - - eventUbCVMTE2List[i] = eventVMTE2++; - eventUbCMTE2VList[i] = eventMTE2V++; - eventUbScaleVMTE2List[i] = eventVMTE2++; - eventUbScaleMTE2VList[i] = eventMTE2V++; - eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; - eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; - eventUbDMTE3VList[i] = eventMTE3V++; - eventUbDVMTE3List[i] = eventVMTE3++; - - AscendC::SetFlag(eventUbCVMTE2List[i]); - AscendC::SetFlag(eventUbScaleVMTE2List[i]); - AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); - AscendC::SetFlag(eventUbDMTE3VList[i]); - } - ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(float); - ubScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(float); - ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(float); - ubPerTokenScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * sizeof(float); - ubPerTokenScaleFp32Brcb = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset 
+= TileShape::ROW * BYTE_PER_BLK; - ubPerTokenMul = ubMul; - } - - CATLASS_DEVICE - ~BlockEpilogue() - { - for (uint32_t i = 0; i < UB_STAGES; ++i) { - AscendC::WaitFlag(eventUbCVMTE2List[i]); - AscendC::WaitFlag(eventUbScaleVMTE2List[i]); - AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); - AscendC::WaitFlag(eventUbDMTE3VList[i]); - } - } - - CATLASS_DEVICE - void UpdateParams(Params const ¶ms_) - { - params = params_; - } - - CATLASS_DEVICE - void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, - GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, - LayoutC const &layoutBlockC, Callback &&callback = Callback{}) - { - if (actualBlockShapeMNK.k() == 0) { - return; - } - callback(); - - // Calculate the offset of the current block - MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); - MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); - MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); - MatrixCoord blockOffset = blockCoord * blockShape; - - AscendC::GlobalTensor gmScale; - gmScale.SetGlobalBuffer(params.ptrScale); - AscendC::GlobalTensor gmPerTokenScale; - gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); - AscendC::GlobalTensor gmD; - gmD.SetGlobalBuffer(params.ptrD); - - auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); - auto tileShape = TileShape::ToCoord(); - EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); - uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); - uint32_t subblockIdx = AscendC::GetSubBlockIdx(); - uint32_t subblockNum = AscendC::GetSubBlockNum(); - for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { - auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); - auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); - auto tileOffsetInBlock = tileCoord * tileShape; - auto tileOffset = blockOffset + tileOffsetInBlock; - - auto gmTileC = 
gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; - auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); - - auto &ubC = ubCList[ubListId]; - LayoutC layoutUbC{actualTileShape, ubTileStride}; - - AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); - copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); - AscendC::SetFlag(eventUbCMTE2VList[ubListId]); - - auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); - auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); - - auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; - auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); - - auto &ubScale = ubScaleList[ubListId]; - auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); - - AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); - copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); - AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); - - auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); - auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); - - auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; - auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); - - auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; - auto layoutUbPerTokenScale = - LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); - - AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); - copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, - layoutGmTilePerTokenScale); - AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); - - AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); - AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); - AscendC::SetFlag(eventUbCVMTE2List[ubListId]); - - AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); - AscendC::Cast(ubScaleFp32, ubScale, 
AscendC::RoundMode::CAST_NONE, TileShape::COLUMN); - AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); - - AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); - AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, AscendC::RoundMode::CAST_NONE, TileShape::ROW); - AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); - - tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32); - tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32); - AscendC::PipeBarrier(); - tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleFp32Brcb); - AscendC::PipeBarrier(); - - auto &ubD = ubDList[ubListId]; - LayoutD layoutUbD{actualTileShape, ubTileStride}; - - AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); - AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); - AscendC::SetFlag(eventUbDVMTE3List[ubListId]); - - auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)]; - auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); - - AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); - copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); - AscendC::SetFlag(eventUbDMTE3VList[ubListId]); - - ubListId = (ubListId + 1 < UB_STAGES) ? 
(ubListId + 1) : 0; - } - } - -private: - Params params; - - AscendC::LocalTensor ubCList[UB_STAGES]; - AscendC::LocalTensor ubScaleList[UB_STAGES]; - AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; - AscendC::LocalTensor ubDList[UB_STAGES]; - - int32_t eventUbCVMTE2List[UB_STAGES]; - int32_t eventUbCMTE2VList[UB_STAGES]; - int32_t eventUbScaleVMTE2List[UB_STAGES]; - int32_t eventUbScaleMTE2VList[UB_STAGES]; - int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; - int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; - int32_t eventUbDMTE3VList[UB_STAGES]; - int32_t eventUbDVMTE3List[UB_STAGES]; - - uint32_t ubListId{0}; - - AscendC::LocalTensor ubCFp32; - AscendC::LocalTensor ubScaleFp32; - AscendC::LocalTensor ubMul; - AscendC::LocalTensor ubPerTokenScaleFp32; - AscendC::LocalTensor ubPerTokenScaleFp32Brcb; - AscendC::LocalTensor ubPerTokenMul; - - TileRowBroadcastMul tileRowBroadcastMul; - TileBroadcastOneBlk tileBroadcastOneBlk; - TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; - - CopyGmToUbC copyGmToUbC; - CopyGmToUbScale copyGmToUbScale; - CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; - CopyUbToGmD copyUbToGmD; -}; - -template -class BlockEpilogue, CType_, Gemm::GemmType, - Gemm::GemmType, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, - TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_> +class BlockEpilogue, + CType_, Gemm::GemmType, Gemm::GemmType, DType_, + TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, + TileCopy_, EpilogueTileSwizzle_> { public: using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine; @@ -327,7 +39,8 @@ public: // Data infos using ElementC = typename CType_::Element; using LayoutC = typename CType_::Layout; - using ElementScale = float; + using ElementRawScale = ScaleType_; + using ElementFp32Scale = float; using LayoutScale = LayoutScale_; using ElementPerTokenScale = float; using LayoutPerTokenScale = LayoutPerTokenScale_; @@ -362,14 +75,16 @@ public: std::is_same_v, 
"TileShape must be consistent for all tile compute ops"); - static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + + (std::is_same_v ? + 0 : TileShape::COLUMN * sizeof(ElementRawScale)) + + TileShape::COLUMN * sizeof(ElementFp32Scale) + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, "TileShape is too large to fit in UB"); - struct Params { - __gm__ ElementScale *ptrScale{nullptr}; + __gm__ ElementRawScale *ptrScale{nullptr}; LayoutScale layoutScale{}; __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; LayoutPerTokenScale layoutPerTokenScale{}; @@ -380,7 +95,7 @@ public: Params() {}; CATLASS_DEVICE - Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_, __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, __gm__ ElementD *ptrD_, LayoutD const &layoutD_) : ptrScale(ptrScale_), @@ -408,8 +123,12 @@ public: for (uint32_t i = 0; i < UB_STAGES; ++i) { ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); ubOffset += TileShape::COUNT * sizeof(ElementC); - ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(ElementScale); + if constexpr (!std::is_same_v) { + ubRawScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementRawScale); + } + ubFp32ScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementFp32Scale); ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); ubDList[i] = resource.ubBuf.template 
GetBufferByByte(ubOffset); @@ -451,22 +170,6 @@ public: AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams); AscendC::SetFlag(eventMTE2S); AscendC::WaitFlag(eventMTE2S); -#if ENABLE_EP_SEND_COUNT_HASH - tokenToEpRankHashLocal_ = resource.ubBuf.template GetBufferByByte(ubOffset); - uint32_t maxGroupSendCount = 0; - uint32_t groupSendCount = 0; - for (uint32_t expertIdx = 0; expertIdx < calcInfo.moeExpertPerRankNum_; ++expertIdx) { - uint32_t prevGroupSendCount = groupSendCount; - groupSendCount = epSendCountLocal_.GetValue((expertIdx + 1) * calcInfo.epWorldSize_ - 1); - if (maxGroupSendCount < groupSendCount - prevGroupSendCount) { - maxGroupSendCount = groupSendCount - prevGroupSendCount; - } - } - ubOffset += maxGroupSendCount * sizeof(int32_t); - AlignUbOffset(); - // assert: ubOffset <= AscendC::TOTAL_UB_SIZE or - // AscendC::TOTAL_VEC_LOCAL_SIZE -#endif } } @@ -495,28 +198,6 @@ public: ->windowsIn) + calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET; } -#if ENABLE_EP_SEND_COUNT_HASH - CATLASS_DEVICE void InitTokenToEpRankHashLocalForEpRank(uint32_t &hashOffset, uint32_t epRank, uint32_t copyLen) - { - constexpr uint32_t DUPLICATE_MASK_COUNT = 8; - uint32_t hashOffsetMask = (((uint32_t)hashOffset) & (DUPLICATE_MASK_COUNT - 1)); - if (hashOffsetMask != 0) { - uint32_t remainMaskCount = DUPLICATE_MASK_COUNT - hashOffsetMask; - if (copyLen < remainMaskCount) { - remainMaskCount = copyLen; - } - uint64_t copyMask = ((1UL << remainMaskCount) - 1) << hashOffsetMask; - AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset - hashOffsetMask], epRank, ©Mask, 1, 1, - DUPLICATE_MASK_COUNT); - hashOffset += remainMaskCount; - copyLen -= remainMaskCount; - } - if (copyLen > 0) { - AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset], epRank, copyLen); - hashOffset += copyLen; - } - } -#endif CATLASS_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t 
&localEpRank) { @@ -542,11 +223,7 @@ public: uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_; uint32_t itToken = startToken; uint32_t endToken = startToken + layoutGmTileD.shape(0); -#if ENABLE_EP_SEND_COUNT_HASH - uint32_t epRankStart = tokenToEpRankHashLocal_(itToken - startToken); -#else constexpr uint32_t epRankStart = 0; -#endif uint32_t sendCount = expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1); for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) { @@ -582,20 +259,6 @@ public: if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { expertOffset = expertIdx * calcInfo.epWorldSize_; -#if ENABLE_EP_SEND_COUNT_HASH - if (currentExpertIdx_ != expertIdx) { - uint32_t hashOffset = 0; - uint32_t sendCount = expertIdx == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset - 1); - for (uint32_t epRank = 0; epRank < calcInfo.epWorldSize_; ++epRank) { - uint32_t prevSendCount = sendCount; - sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); - InitTokenToEpRankHashLocalForEpRank(hashOffset, epRank, sendCount - prevSendCount); - } - AscendC::SetFlag(eventVS); - AscendC::WaitFlag(eventVS); - currentExpertIdx_ = expertIdx; - } -#endif } callback(); @@ -605,7 +268,7 @@ public: MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); MatrixCoord blockOffset = blockCoord * blockShape; - AscendC::GlobalTensor gmScale; + AscendC::GlobalTensor gmScale; gmScale.SetGlobalBuffer(params.ptrScale); AscendC::GlobalTensor gmPerTokenScale; gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); @@ -640,11 +303,16 @@ public: auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); - auto &ubScale = ubScaleList[ubListId]; - auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); - + auto &ubFp32Scale = ubFp32ScaleList[ubListId]; + auto layoutFp32UbScale = 
LayoutScale::template MakeLayoutInUb(scaleTileShape); + auto &ubRawScale = ubRawScaleList[ubListId]; + auto layoutRawUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); - copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + if constexpr (!std::is_same_v) { + copyGmToUbScale(ubRawScale, gmTileScale, layoutRawUbScale, layoutGmTileScale); + } else { + copyGmToUbScale(ubFp32Scale, gmTileScale, layoutFp32UbScale, layoutGmTileScale); + } AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); @@ -667,7 +335,11 @@ public: AscendC::SetFlag(eventUbCVMTE2List[ubListId]); AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); - tileRowBroadcastMul(ubMul, ubCFp32, ubScale); + if constexpr (!std::is_same_v) { + AscendC::Cast(ubFp32Scale, ubRawScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN); + AscendC::PipeBarrier(); + } + tileRowBroadcastMul(ubMul, ubCFp32, ubFp32Scale); AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); @@ -709,7 +381,8 @@ private: MoeDistributeCombineImpl::CombineCalcInfo calcInfo; AscendC::LocalTensor ubCList[UB_STAGES]; - AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubRawScaleList[UB_STAGES]; + AscendC::LocalTensor ubFp32ScaleList[UB_STAGES]; AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; AscendC::LocalTensor ubDList[UB_STAGES]; @@ -723,10 +396,6 @@ private: int32_t eventUbDVMTE3List[UB_STAGES]; AscendC::LocalTensor epSendCountLocal_; -#if ENABLE_EP_SEND_COUNT_HASH - AscendC::LocalTensor tokenToEpRankHashLocal_; - uint32_t currentExpertIdx_{static_cast(-1)}; -#endif size_t ubOffset{0}; int32_t eventVMTE2{0}; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h 
b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h index f203dbfe..a14b0cc2 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_per_token_dequant_swiglu.h @@ -21,13 +21,14 @@ namespace Catlass::Epilogue::Block { -template -class BlockEpilogue, CType_, - Gemm::GemmType, Gemm::GemmType, DType_, - TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_, - EpilogueTileSwizzle_> +class BlockEpilogue, + CType_, Gemm::GemmType, Gemm::GemmType, + DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, + TileCopy_, EpilogueTileSwizzle_> { public: using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwiglu; @@ -37,7 +38,8 @@ public: // Data infos using ElementC = typename CType_::Element; using LayoutC = typename CType_::Layout; - using ElementScale = float; + using ElementRawScale = ScaleType_; + using ElementFp32Scale = float; using LayoutScale = LayoutScale_; using ElementPerTokenScale = float; using LayoutPerTokenScale = LayoutPerTokenScale_; @@ -76,14 +78,17 @@ public: static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough."); - static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + + (std::is_same_v ? 
+ 0 : TileShape::COLUMN * sizeof(ElementRawScale)) + + TileShape::COLUMN * sizeof(ElementFp32Scale) + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, "TileShape is too large to fit in UB"); struct Params { - __gm__ ElementScale *ptrScale{nullptr}; + __gm__ ElementRawScale *ptrScale{nullptr}; LayoutScale layoutScale{}; __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; LayoutPerTokenScale layoutPerTokenScale{}; @@ -94,7 +99,7 @@ public: Params() {}; CATLASS_DEVICE - Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_, __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, __gm__ ElementD *ptrD_, LayoutD const &layoutD_) : ptrScale(ptrScale_), @@ -117,8 +122,12 @@ public: for (uint32_t i = 0; i < UB_STAGES; ++i) { ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); ubOffset += TileShape::COUNT * sizeof(ElementC); - ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(ElementScale); + if constexpr (!std::is_same_v) { + ubRawScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementRawScale); + } + ubFp32ScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementFp32Scale); ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); @@ -177,7 +186,7 @@ public: MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); MatrixCoord blockOffset = blockCoord * blockShape; bool isLeft = blockOffset.column() < (params.layoutD.shape(1) >> 1); - AscendC::GlobalTensor 
gmScale; + AscendC::GlobalTensor gmScale; gmScale.SetGlobalBuffer(params.ptrScale); AscendC::GlobalTensor gmPerTokenScale; gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); @@ -212,11 +221,16 @@ public: auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); - auto &ubScale = ubScaleList[ubListId]; - auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); - + auto &ubFp32Scale = ubFp32ScaleList[ubListId]; + auto layoutFp32UbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + auto &ubRawScale = ubRawScaleList[ubListId]; + auto layoutRawUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); - copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + if constexpr (!std::is_same_v) { + copyGmToUbScale(ubRawScale, gmTileScale, layoutRawUbScale, layoutGmTileScale); + } else { + copyGmToUbScale(ubFp32Scale, gmTileScale, layoutFp32UbScale, layoutGmTileScale); + } AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); @@ -238,7 +252,11 @@ public: AscendC::Cast(ubTmpMxN, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); AscendC::SetFlag(eventUbCVMTE2List[ubListId]); AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); - tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubScale); + if constexpr (!std::is_same_v) { + AscendC::Cast(ubFp32Scale, ubRawScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN); + AscendC::PipeBarrier(); + } + tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubFp32Scale); AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale); @@ -279,7 +297,8 @@ private: Params params; AscendC::LocalTensor ubCList[UB_STAGES]; - AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor 
ubRawScaleList[UB_STAGES]; + AscendC::LocalTensor ubFp32ScaleList[UB_STAGES]; AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; AscendC::LocalTensor ubDList[UB_STAGES]; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h index 35b3512c..28eb10d0 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.h @@ -39,7 +39,7 @@ public: using ElementAccumulator = typename BlockMmad::ElementAccumulator; using BlockEpilogue = BlockEpilogue_; - using ElementScale = typename BlockEpilogue::ElementScale; + using ElementScale = typename BlockEpilogue::ElementRawScale; using LayoutScale = typename BlockEpilogue::LayoutScale; using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h index cf4956e0..967e5869 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h @@ -354,7 +354,7 @@ 
__aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row) row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE; } -template class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace { @@ -371,7 +371,7 @@ public: using ElementAccumulator = typename BlockMmad::ElementAccumulator; using BlockEpilogue = BlockEpilogue_; - using ElementScale = typename BlockEpilogue::ElementScale; + using ElementScale = typename BlockEpilogue::ElementRawScale; using LayoutScale = typename BlockEpilogue::LayoutScale; using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; @@ -388,7 +388,7 @@ public: static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; using ElementGroupList = ElementGroupList_; - using XType = XType_; + using XType = ExpandXType; // Parameters structure struct Params { @@ -1715,7 +1715,7 @@ private: namespace Catlass::Gemm::Kernel { -template class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch { @@ -1732,7 +1732,7 @@ public: using ElementAccumulator = typename BlockMmad::ElementAccumulator; using BlockEpilogue = BlockEpilogue_; - using ElementScale = typename BlockEpilogue::ElementScale; + using ElementScale = typename BlockEpilogue::ElementRawScale; using LayoutScale = typename BlockEpilogue::LayoutScale; using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; @@ -2017,7 +2017,7 @@ private: struct AicWaitFunc { using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< - BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; CATLASS_DEVICE AicWaitFunc() = default; @@ -2034,7 +2034,7 @@ private: struct AicSetFunc { using MatmulKernel 
= GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< - BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; CATLASS_DEVICE AicSetFunc() = default; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h index 0c57896c..b9ac8932 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h @@ -12,8 +12,8 @@ #include "../common/moe_distribute_base.h" -#define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG -#define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG +#define TemplateMC2TypeClass typename ExpandXType, typename W1ScaleType, typename W2ScaleType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG +#define TemplateMC2TypeFunc ExpandXType, W1ScaleType, W2ScaleType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG #define TemplateDispatchTypeClass \ typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \ bool IsNeedAllgater, uint32_t EXEC_FLAG diff --git a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py index 547e5c1d..b928ad7e 100644 --- a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py +++ b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py @@ -275,8 +275,6 @@ class FusionOp(DecodeMoeOps): torch_npu.Format.FRACTAL_NZ) gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, torch_npu.Format.FRACTAL_NZ) - gmm1_weight_scale = 
gmm1_weight_scale.float() - gmm2_weight_scale = gmm2_weight_scale.float() if self.dynamic_eplb: self.gmm1_weight = [ diff --git a/vllm_ascend/eplb/adaptor/vllm_adaptor.py b/vllm_ascend/eplb/adaptor/vllm_adaptor.py index 6b5cce8d..500c9d4b 100644 --- a/vllm_ascend/eplb/adaptor/vllm_adaptor.py +++ b/vllm_ascend/eplb/adaptor/vllm_adaptor.py @@ -46,15 +46,12 @@ class VllmEplbAdaptor(EplbAdaptor): self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \ self.model.model.layers[i].mlp.experts.w2_weight_scale_list - self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_fp32_list"] = \ - self.model.model.layers[i].mlp.experts.w2_weight_scale_fp32_list # TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here if self.model.quant_config is not None: self.expert_weight_names = [ "w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", "w13_weight_offset", - "w2_weight_scale_list", "w2_weight_offset", - "w2_weight_scale_fp32_list" + "w2_weight_scale_list", "w2_weight_offset" ] else: self.expert_weight_names = ["w13_weight", "w2_weight"] @@ -83,8 +80,7 @@ class VllmEplbAdaptor(EplbAdaptor): self.num_dense_layers) + ".mlp.experts." + name if name in [ "w13_weight_list", "w2_weight_list", - "w13_weight_scale_fp32_list", "w2_weight_scale_list", - "w2_weight_scale_fp32_list" + "w13_weight_scale_fp32_list", "w2_weight_scale_list" ]: expert_tensor = self.param_dict[complete_name][0] expert_tensor = expert_tensor.clone() @@ -105,7 +101,7 @@ class VllmEplbAdaptor(EplbAdaptor): if name in [ "w13_weight_list", "w2_weight_list", "w13_weight_scale_fp32_list", - "w2_weight_scale_list", "w2_weight_scale_fp32_list" + "w2_weight_scale_list" ]: per_expert_param.append( self.param_dict["model.layers." 
+ str(layer_idx) + diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 0c6def33..f3037073 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -243,24 +243,16 @@ class AscendW8A8DynamicFusedMoEMethod: topk_weights = topk_weights.to(self.in_dtype) moe_comm_method = get_forward_context().moe_comm_method - # When VLLM_ASCEND_ENABLE_FUSED_MC2 == 2, use dispatch_gmm_combine_decode, need fp32 scale - w2_weight_scale_fp32_flag = ( - get_forward_context().moe_comm_type == MoECommType.FUSED_MC2 - and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2) if self.dynamic_eplb: w1 = layer.w13_weight_list w1_scale = layer.w13_weight_scale_fp32_list w2 = layer.w2_weight_list - w2_scale = layer.w2_weight_scale_fp32_list \ - if w2_weight_scale_fp32_flag else layer.w2_weight_scale_list + w2_scale = layer.w2_weight_scale_list else: w1 = [layer.w13_weight] w1_scale = [layer.w13_weight_scale_fp32] w2 = [layer.w2_weight] - w2_scale = [ - layer.w2_weight_scale_fp32 - if w2_weight_scale_fp32_flag else layer.w2_weight_scale - ] + w2_scale = [layer.w2_weight_scale] fused_scale_flag = (get_forward_context().moe_comm_type == MoECommType.FUSED_MC2 @@ -302,8 +294,6 @@ class AscendW8A8DynamicFusedMoEMethod: layer.w13_weight_offset.data.shape[0], -1) layer.w2_weight_scale.data = layer.w2_weight_scale.data.view( layer.w2_weight_scale.data.shape[0], -1) - layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to( - torch.float32) layer.w2_weight_offset.data = layer.w2_weight_offset.data.view( layer.w2_weight_offset.data.shape[0], -1) @@ -328,16 +318,11 @@ class AscendW8A8DynamicFusedMoEMethod: weight.clone() for weight in layer.w2_weight_scale.data.unbind(dim=0) ] - layer.w2_weight_scale_fp32_list = [ - weight.clone() - for weight in layer.w2_weight_scale_fp32.data.unbind(dim=0) - ] del layer.w13_weight del layer.w2_weight del layer.w13_weight_scale del layer.w13_weight_scale_fp32 del 
layer.w2_weight_scale - del layer.w2_weight_scale_fp32 torch.npu.empty_cache()