[Feature] Adapt DispatchGmmCombineDecode operator to align with weight scale dtype of small operators. [RFC: issue 5476] (#5755)

### What this PR does / why we need it?

[Feature] Adapt DispatchGmmCombineDecode operator to align with the weight
scale dtype of small operators.
- **Before**: weight scale must be float32
- **After**: weight scale can be float32/float16 when x is float16, and
float32/bfloat16 when x is bfloat16. Additionally, the w1 scale can use a
different dtype than the w2 scale.

More info about this operator, please refer to RFC: issue
https://github.com/vllm-project/vllm-ascend/issues/5476

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
#### Perf

> When scale is of type fp16 or bf16, it will be cast to fp32 internally
within the operator, while the subsequent computations remain unchanged.
Therefore, this PR will introduce an additional cast operation but halve
the memory copy operations for the scale. Furthermore, since the scale
data is only a few KB in size and participates in relatively few
computations, its impact is almost negligible compared to major
operations like matrix multiplication. Thus, the theoretical performance
change should be minimal.

Tested single-operator cases from qwen3-235b:
- single A3 node (ep16), 64 MoE experts, 4 experts / die (like qwen3-235b
ep32)
- batch=18/32, token_hidden_size 4096, moe_intermediate_size 1536

The test was conducted for 100 rounds, and the average of the last 95
rounds was taken.
| | bs18(us)| bs32(us)|
| -----| -----| -----|
|Without this PR|96.28|108.83|
|With this PR|96.06|107.90|

Note: Single-operator benchmarks represent an ideal scenario. They are
usually only useful for referencing relative changes and may not fully
align with performance data observed within the full model.

#### Acc
Tested qwen3-235b EPLB on a single A3 node (ep16),
with dispatch_gmm_combine_decode enabled.
| dataset | version | metric | mode | vllm-api-stream-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 83.33 |

- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef

Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
wangqiankun13
2026-01-19 16:10:43 +08:00
committed by GitHub
parent 687df88151
commit ebb940691f
11 changed files with 166 additions and 464 deletions

View File

@@ -17,59 +17,94 @@ public:
{
this->Input("x")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
.DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16,
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_ids")
.ParamType(REQUIRED)
.DataType({ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32,
ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm1_permuted_weight")
.ParamType(DYNAMIC)
.DataType({ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat(
{ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
this->Input("gmm1_permuted_weight_scale")
.ParamType(DYNAMIC)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_BF16, ge::DT_BF16,
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("gmm2_weight")
.ParamType(DYNAMIC)
.DataType({ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
.UnknownShapeFormat(
{ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
this->Input("gmm2_weight_scale")
.ParamType(DYNAMIC)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
.DataType({ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT, ge::DT_BF16,
ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_scales")
.ParamType(REQUIRED)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("expert_smooth_scales")
.ParamType(OPTIONAL)
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Input("x_active_mask")
.ParamType(OPTIONAL)
.DataType({ge::DT_BOOL, ge::DT_BOOL})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
.DataType({ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL,
ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("output")
.ParamType(REQUIRED)
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
.DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16,
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Output("expert_token_nums")
.ParamType(REQUIRED)
.DataType({ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,
ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
this->Attr("group_ep").String();
this->Attr("ep_rank_size").Int();
this->Attr("ep_rank_id").Int();

View File

@@ -27,7 +27,8 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
GET_TILING_DATA(tiling_data, tiling);
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) ||
TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7)) {
DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
DispatchGmmCombineDecode<
DTYPE_X, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, int32_t, false, TILING_KEY_VAR> op;
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data);
op.Process();

View File

@@ -54,7 +54,7 @@ using Gmm2DispatchPolicy =
GMM2_L0A_STAGES, GMM2_L0B_STAGES, CUSTOM_L0C_STAGES,
CUSTOM_ENABLE_UNIT_FLAG, CUSTOM_ENABLE_SHUFFLE_K>;
template <uint32_t EXEC_FLAG, typename XType_, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_,
template <TemplateMC2TypeClass, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_,
class BlockScheduler_, class DispatchPolicy_ = MmadAtlasA2Custom>
CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
@@ -72,7 +72,6 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
using L1TileShape = L1TileShape_;
using L0TileShape = L0TileShape_;
using XType = XType_;
using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
using BType = Gemm::GemmType<int8_t, layout::zN>;
using CType = Gemm::GemmType<int32_t, layout::RowMajor>;
@@ -81,7 +80,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
constexpr uint32_t ubStages = 1;
using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantSwiglu<ubStages, 0>;
using ScaleType = Gemm::GemmType<float, layout::VectorLayout>;
using ScaleType = Gemm::GemmType<W1ScaleType, layout::VectorLayout>;
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
using DType = Gemm::GemmType<float, layout::RowMajor>;
@@ -110,9 +109,9 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
using GemmKernel = typename std::conditional<
(EXEC_FLAG & EXEC_FLAG_DEEP_FUSE),
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace<
EXEC_FLAG, XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
typename GemmKernel::Params params{problemShape,
@@ -197,7 +196,7 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR
constexpr uint32_t ubStages = 1;
using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantCombine<ubStages, EXEC_FLAG>;
using ScaleType = Gemm::GemmType<float, layout::VectorLayout>;
using ScaleType = Gemm::GemmType<W2ScaleType, layout::VectorLayout>;
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
using DType = Gemm::GemmType<ExpandXType, layout::RowMajor>;
@@ -411,7 +410,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
Arch::CrossCoreWaitFlag(gmm1AivFinished);
}
}
GmmDeqSwigluQuant<EXEC_FLAG, ExpandXType, Gmm1L1TileShape, Gmm1L0TileShape, Gmm1EpilogueTileShape,
GmmDeqSwigluQuant<TemplateMC2TypeFunc, Gmm1L1TileShape, Gmm1L0TileShape, Gmm1EpilogueTileShape,
Gmm1BlockScheduler>(
gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,

View File

@@ -19,304 +19,16 @@
#include "catlass/layout/layout.hpp"
#include "catlass/matrix_coord.hpp"
#define ENABLE_EP_SEND_COUNT_HASH 0
namespace Catlass::Epilogue::Block {
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_, class CType_, class ScaleType_, class PerTokenScaleType_,
class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_,
class CType_, class ScaleType_, class LayoutScale_, class LayoutPerTokenScale_, class DType_,
class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
class TileCopy_, class EpilogueTileSwizzle_>
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>, CType_, ScaleType_, PerTokenScaleType_,
DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_,
EpilogueTileSwizzle_>
{
public:
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>;
using ArchTag = typename DispatchPolicy::ArchTag;
static constexpr uint32_t UB_STAGES = UB_STAGES_;
// Data infos
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using ElementScale = typename ScaleType_::Element;
using LayoutScale = typename ScaleType_::Layout;
using ElementPerTokenScale = typename PerTokenScaleType_::Element;
using LayoutPerTokenScale = typename PerTokenScaleType_::Layout;
using ElementD = typename DType_::Element;
using LayoutD = typename DType_::Layout;
// Check data infos
static_assert(std::is_same_v<ElementC, int32_t> &&
(std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>) &&
std::is_same_v<ElementScale, ElementD> && std::is_same_v<ElementPerTokenScale, ElementD>,
"The element type template parameters of BlockEpilogue are wrong");
static_assert(std::is_same_v<LayoutC, layout::RowMajor> && std::is_same_v<LayoutScale, layout::VectorLayout> &&
std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> &&
std::is_same_v<LayoutD, layout::RowMajor>,
"The layout template parameters of BlockEpilogue are wrong");
// Tile compute ops
using TileRowBroadcastMul = TileRowBroadcastMul_;
using TileBroadcastOneBlk = TileBroadcastOneBlk_;
using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_;
// Tile copy
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX;
using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY;
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
using EpilogueTileSwizzle = EpilogueTileSwizzle_;
using TileShape = typename TileRowBroadcastMul::TileShape;
static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH &&
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
"TileShape must be consistent for all tile compute ops");
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
(TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + TileShape::ROW) * sizeof(float) +
TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE,
"TileShape is too large to fit in UB");
struct Params {
__gm__ ElementScale *ptrScale{nullptr};
LayoutScale layoutScale{};
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
LayoutPerTokenScale layoutPerTokenScale{};
__gm__ ElementD *ptrD{nullptr};
LayoutD layoutD{};
CATLASS_DEVICE
Params() {};
CATLASS_DEVICE
Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_,
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
: ptrScale(ptrScale_),
layoutScale(layoutScale_),
ptrPerTokenScale(ptrPerTokenScale_),
layoutPerTokenScale(layoutPerTokenScale_),
ptrD(ptrD_),
layoutD(layoutD_)
{}
};
CATLASS_DEVICE
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const &params = Params{}) : params(params)
{
size_t ubOffset = 0;
int32_t eventVMTE2 = 0;
int32_t eventMTE2V = 0;
int32_t eventMTE3V = 0;
int32_t eventVMTE3 = 0;
for (uint32_t i = 0; i < UB_STAGES; ++i) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementC);
ubScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementScale>(ubOffset);
ubOffset += TileShape::COLUMN * sizeof(ElementScale);
ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementD);
eventUbCVMTE2List[i] = eventVMTE2++;
eventUbCMTE2VList[i] = eventMTE2V++;
eventUbScaleVMTE2List[i] = eventVMTE2++;
eventUbScaleMTE2VList[i] = eventMTE2V++;
eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++;
eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++;
eventUbDMTE3VList[i] = eventMTE3V++;
eventUbDVMTE3List[i] = eventVMTE3++;
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
}
ubCFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(float);
ubScaleFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::COLUMN * sizeof(float);
ubMul = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(float);
ubPerTokenScaleFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::ROW * sizeof(float);
ubPerTokenScaleFp32Brcb = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
ubOffset += TileShape::ROW * BYTE_PER_BLK;
ubPerTokenMul = ubMul;
}
CATLASS_DEVICE
~BlockEpilogue()
{
for (uint32_t i = 0; i < UB_STAGES; ++i) {
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
}
}
CATLASS_DEVICE
void UpdateParams(Params const &params_)
{
params = params_;
}
CATLASS_DEVICE
void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK,
GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor<ElementC> const &gmBlockC,
LayoutC const &layoutBlockC, Callback &&callback = Callback{})
{
if (actualBlockShapeMNK.k() == 0) {
return;
}
callback();
// Calculate the offset of the current block
MatrixCoord blockShape = blockShapeMNK.GetCoordMN();
MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
MatrixCoord blockOffset = blockCoord * blockShape;
AscendC::GlobalTensor<ElementScale> gmScale;
gmScale.SetGlobalBuffer(params.ptrScale);
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
AscendC::GlobalTensor<ElementD> gmD;
gmD.SetGlobalBuffer(params.ptrD);
auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L);
auto tileShape = TileShape::ToCoord();
EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
uint32_t subblockIdx = AscendC::GetSubBlockIdx();
uint32_t subblockNum = AscendC::GetSubBlockNum();
for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) {
auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx);
auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord);
auto tileOffsetInBlock = tileCoord * tileShape;
auto tileOffset = blockOffset + tileOffsetInBlock;
auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);
auto &ubC = ubCList[ubListId];
LayoutC layoutUbC{actualTileShape, ubTileStride};
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>();
auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>();
auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
auto &ubScale = ubScaleList[ubListId];
auto layoutUbScale = LayoutScale::template MakeLayoutInUb<ElementScale>(scaleTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>();
auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)];
auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape);
auto &ubPerTokenScale = ubPerTokenScaleList[ubListId];
auto layoutUbPerTokenScale =
LayoutScale::template MakeLayoutInUb<ElementPerTokenScale>(perTokenScaleTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale,
layoutGmTilePerTokenScale);
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
AscendC::Cast(ubScaleFp32, ubScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, AscendC::RoundMode::CAST_NONE, TileShape::ROW);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32);
tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32);
AscendC::PipeBarrier<PIPE_V>();
tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleFp32Brcb);
AscendC::PipeBarrier<PIPE_V>();
auto &ubD = ubDList[ubListId];
LayoutD layoutUbD{actualTileShape, ubTileStride};
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)];
auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
}
}
private:
Params params;
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
int32_t eventUbCVMTE2List[UB_STAGES];
int32_t eventUbCMTE2VList[UB_STAGES];
int32_t eventUbScaleVMTE2List[UB_STAGES];
int32_t eventUbScaleMTE2VList[UB_STAGES];
int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES];
int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES];
int32_t eventUbDMTE3VList[UB_STAGES];
int32_t eventUbDVMTE3List[UB_STAGES];
uint32_t ubListId{0};
AscendC::LocalTensor<float> ubCFp32;
AscendC::LocalTensor<float> ubScaleFp32;
AscendC::LocalTensor<float> ubMul;
AscendC::LocalTensor<float> ubPerTokenScaleFp32;
AscendC::LocalTensor<float> ubPerTokenScaleFp32Brcb;
AscendC::LocalTensor<float> ubPerTokenMul;
TileRowBroadcastMul tileRowBroadcastMul;
TileBroadcastOneBlk tileBroadcastOneBlk;
TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;
CopyGmToUbC copyGmToUbC;
CopyGmToUbScale copyGmToUbScale;
CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;
CopyUbToGmD copyUbToGmD;
};
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_, class CType_, class LayoutScale_, class LayoutPerTokenScale_,
class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
class TileCopy_, class EpilogueTileSwizzle_>
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>, CType_, Gemm::GemmType<float, LayoutScale_>,
Gemm::GemmType<float, LayoutPerTokenScale_>, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_,
TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_>
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>,
CType_, Gemm::GemmType<ScaleType_, LayoutScale_>, Gemm::GemmType<float, LayoutPerTokenScale_>, DType_,
TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_,
TileCopy_, EpilogueTileSwizzle_>
{
public:
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>;
@@ -327,7 +39,8 @@ public:
// Data infos
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using ElementScale = float;
using ElementRawScale = ScaleType_;
using ElementFp32Scale = float;
using LayoutScale = LayoutScale_;
using ElementPerTokenScale = float;
using LayoutPerTokenScale = LayoutPerTokenScale_;
@@ -362,14 +75,16 @@ public:
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
"TileShape must be consistent for all tile compute ops");
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) +
(std::is_same_v<ElementRawScale, ElementFp32Scale> ?
0 : TileShape::COLUMN * sizeof(ElementRawScale)) +
TileShape::COLUMN * sizeof(ElementFp32Scale) +
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
(TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
ArchTag::UB_SIZE,
"TileShape is too large to fit in UB");
struct Params {
__gm__ ElementScale *ptrScale{nullptr};
__gm__ ElementRawScale *ptrScale{nullptr};
LayoutScale layoutScale{};
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
LayoutPerTokenScale layoutPerTokenScale{};
@@ -380,7 +95,7 @@ public:
Params() {};
CATLASS_DEVICE
Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_,
Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_,
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
: ptrScale(ptrScale_),
@@ -408,8 +123,12 @@ public:
for (uint32_t i = 0; i < UB_STAGES; ++i) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementC);
ubScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementScale>(ubOffset);
ubOffset += TileShape::COLUMN * sizeof(ElementScale);
if constexpr (!std::is_same_v<ElementRawScale, ElementFp32Scale>) {
ubRawScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementRawScale>(ubOffset);
ubOffset += TileShape::COLUMN * sizeof(ElementRawScale);
}
ubFp32ScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementFp32Scale>(ubOffset);
ubOffset += TileShape::COLUMN * sizeof(ElementFp32Scale);
ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
@@ -451,22 +170,6 @@ public:
AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams);
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(eventMTE2S);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(eventMTE2S);
#if ENABLE_EP_SEND_COUNT_HASH
tokenToEpRankHashLocal_ = resource.ubBuf.template GetBufferByByte<int32_t>(ubOffset);
uint32_t maxGroupSendCount = 0;
uint32_t groupSendCount = 0;
for (uint32_t expertIdx = 0; expertIdx < calcInfo.moeExpertPerRankNum_; ++expertIdx) {
uint32_t prevGroupSendCount = groupSendCount;
groupSendCount = epSendCountLocal_.GetValue((expertIdx + 1) * calcInfo.epWorldSize_ - 1);
if (maxGroupSendCount < groupSendCount - prevGroupSendCount) {
maxGroupSendCount = groupSendCount - prevGroupSendCount;
}
}
ubOffset += maxGroupSendCount * sizeof(int32_t);
AlignUbOffset();
// assert: ubOffset <= AscendC::TOTAL_UB_SIZE or
// AscendC::TOTAL_VEC_LOCAL_SIZE
#endif
}
}
@@ -495,28 +198,6 @@ public:
->windowsIn) +
calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET;
}
#if ENABLE_EP_SEND_COUNT_HASH
CATLASS_DEVICE void InitTokenToEpRankHashLocalForEpRank(uint32_t &hashOffset, uint32_t epRank, uint32_t copyLen)
{
constexpr uint32_t DUPLICATE_MASK_COUNT = 8;
uint32_t hashOffsetMask = (((uint32_t)hashOffset) & (DUPLICATE_MASK_COUNT - 1));
if (hashOffsetMask != 0) {
uint32_t remainMaskCount = DUPLICATE_MASK_COUNT - hashOffsetMask;
if (copyLen < remainMaskCount) {
remainMaskCount = copyLen;
}
uint64_t copyMask = ((1UL << remainMaskCount) - 1) << hashOffsetMask;
AscendC::Duplicate<int32_t>(tokenToEpRankHashLocal_[hashOffset - hashOffsetMask], epRank, &copyMask, 1, 1,
DUPLICATE_MASK_COUNT);
hashOffset += remainMaskCount;
copyLen -= remainMaskCount;
}
if (copyLen > 0) {
AscendC::Duplicate<int32_t>(tokenToEpRankHashLocal_[hashOffset], epRank, copyLen);
hashOffset += copyLen;
}
}
#endif
CATLASS_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank)
{
@@ -542,11 +223,7 @@ public:
uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_;
uint32_t itToken = startToken;
uint32_t endToken = startToken + layoutGmTileD.shape(0);
#if ENABLE_EP_SEND_COUNT_HASH
uint32_t epRankStart = tokenToEpRankHashLocal_(itToken - startToken);
#else
constexpr uint32_t epRankStart = 0;
#endif
uint32_t sendCount =
expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1);
for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) {
@@ -582,20 +259,6 @@ public:
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
expertOffset = expertIdx * calcInfo.epWorldSize_;
#if ENABLE_EP_SEND_COUNT_HASH
if (currentExpertIdx_ != expertIdx) {
uint32_t hashOffset = 0;
uint32_t sendCount = expertIdx == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset - 1);
for (uint32_t epRank = 0; epRank < calcInfo.epWorldSize_; ++epRank) {
uint32_t prevSendCount = sendCount;
sendCount = epSendCountLocal_.GetValue(expertOffset + epRank);
InitTokenToEpRankHashLocalForEpRank(hashOffset, epRank, sendCount - prevSendCount);
}
AscendC::SetFlag<AscendC::HardEvent::V_S>(eventVS);
AscendC::WaitFlag<AscendC::HardEvent::V_S>(eventVS);
currentExpertIdx_ = expertIdx;
}
#endif
}
callback();
@@ -605,7 +268,7 @@ public:
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
MatrixCoord blockOffset = blockCoord * blockShape;
AscendC::GlobalTensor<ElementScale> gmScale;
AscendC::GlobalTensor<ElementRawScale> gmScale;
gmScale.SetGlobalBuffer(params.ptrScale);
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
@@ -640,11 +303,16 @@ public:
auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
auto &ubScale = ubScaleList[ubListId];
auto layoutUbScale = LayoutScale::template MakeLayoutInUb<ElementScale>(scaleTileShape);
auto &ubFp32Scale = ubFp32ScaleList[ubListId];
auto layoutFp32UbScale = LayoutScale::template MakeLayoutInUb<ElementFp32Scale>(scaleTileShape);
auto &ubRawScale = ubRawScaleList[ubListId];
auto layoutRawUbScale = LayoutScale::template MakeLayoutInUb<ElementRawScale>(scaleTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
if constexpr (!std::is_same_v<ElementRawScale, ElementFp32Scale>) {
copyGmToUbScale(ubRawScale, gmTileScale, layoutRawUbScale, layoutGmTileScale);
} else {
copyGmToUbScale(ubFp32Scale, gmTileScale, layoutFp32UbScale, layoutGmTileScale);
}
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
@@ -667,7 +335,11 @@ public:
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
tileRowBroadcastMul(ubMul, ubCFp32, ubScale);
if constexpr (!std::is_same_v<ElementRawScale, ElementFp32Scale>) {
AscendC::Cast(ubFp32Scale, ubRawScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN);
AscendC::PipeBarrier<PIPE_V>();
}
tileRowBroadcastMul(ubMul, ubCFp32, ubFp32Scale);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
@@ -709,7 +381,8 @@ private:
MoeDistributeCombineImpl::CombineCalcInfo calcInfo;
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
AscendC::LocalTensor<ElementRawScale> ubRawScaleList[UB_STAGES];
AscendC::LocalTensor<ElementFp32Scale> ubFp32ScaleList[UB_STAGES];
AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
@@ -723,10 +396,6 @@ private:
int32_t eventUbDVMTE3List[UB_STAGES];
AscendC::LocalTensor<int32_t> epSendCountLocal_;
#if ENABLE_EP_SEND_COUNT_HASH
AscendC::LocalTensor<int32_t> tokenToEpRankHashLocal_;
uint32_t currentExpertIdx_{static_cast<uint32_t>(-1)};
#endif
size_t ubOffset{0};
int32_t eventVMTE2{0};

View File

@@ -21,13 +21,14 @@
namespace Catlass::Epilogue::Block {
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_, class CType_, class LayoutScale_, class LayoutPerTokenScale_,
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_,
class CType_, class ScaleType_, class LayoutScale_, class LayoutPerTokenScale_,
class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
class TileCopy_, class EpilogueTileSwizzle_>
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantSwiglu<UB_STAGES_, EXEC_FLAG_>, CType_,
Gemm::GemmType<float, LayoutScale_>, Gemm::GemmType<float, LayoutPerTokenScale_>, DType_,
TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_,
EpilogueTileSwizzle_>
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantSwiglu<UB_STAGES_, EXEC_FLAG_>,
CType_, Gemm::GemmType<ScaleType_, LayoutScale_>, Gemm::GemmType<float, LayoutPerTokenScale_>,
DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_,
TileCopy_, EpilogueTileSwizzle_>
{
public:
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwiglu<UB_STAGES_, EXEC_FLAG_>;
@@ -37,7 +38,8 @@ public:
// Data infos
using ElementC = typename CType_::Element;
using LayoutC = typename CType_::Layout;
using ElementScale = float;
using ElementRawScale = ScaleType_;
using ElementFp32Scale = float;
using LayoutScale = LayoutScale_;
using ElementPerTokenScale = float;
using LayoutPerTokenScale = LayoutPerTokenScale_;
@@ -76,14 +78,17 @@ public:
static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough.");
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) +
(std::is_same_v<ElementRawScale, ElementFp32Scale> ?
0 : TileShape::COLUMN * sizeof(ElementRawScale)) +
TileShape::COLUMN * sizeof(ElementFp32Scale) +
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
(TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
ArchTag::UB_SIZE,
"TileShape is too large to fit in UB");
struct Params {
__gm__ ElementScale *ptrScale{nullptr};
__gm__ ElementRawScale *ptrScale{nullptr};
LayoutScale layoutScale{};
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
LayoutPerTokenScale layoutPerTokenScale{};
@@ -94,7 +99,7 @@ public:
Params() {};
CATLASS_DEVICE
Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_,
Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_,
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
: ptrScale(ptrScale_),
@@ -117,8 +122,12 @@ public:
for (uint32_t i = 0; i < UB_STAGES; ++i) {
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
ubOffset += TileShape::COUNT * sizeof(ElementC);
ubScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementScale>(ubOffset);
ubOffset += TileShape::COLUMN * sizeof(ElementScale);
if constexpr (!std::is_same_v<ElementRawScale, ElementFp32Scale>) {
ubRawScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementRawScale>(ubOffset);
ubOffset += TileShape::COLUMN * sizeof(ElementRawScale);
}
ubFp32ScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementFp32Scale>(ubOffset);
ubOffset += TileShape::COLUMN * sizeof(ElementFp32Scale);
ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
@@ -177,7 +186,7 @@ public:
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
MatrixCoord blockOffset = blockCoord * blockShape;
bool isLeft = blockOffset.column() < (params.layoutD.shape(1) >> 1);
AscendC::GlobalTensor<ElementScale> gmScale;
AscendC::GlobalTensor<ElementRawScale> gmScale;
gmScale.SetGlobalBuffer(params.ptrScale);
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
@@ -212,11 +221,16 @@ public:
auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
auto &ubScale = ubScaleList[ubListId];
auto layoutUbScale = LayoutScale::template MakeLayoutInUb<ElementScale>(scaleTileShape);
auto &ubFp32Scale = ubFp32ScaleList[ubListId];
auto layoutFp32UbScale = LayoutScale::template MakeLayoutInUb<ElementFp32Scale>(scaleTileShape);
auto &ubRawScale = ubRawScaleList[ubListId];
auto layoutRawUbScale = LayoutScale::template MakeLayoutInUb<ElementRawScale>(scaleTileShape);
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
if constexpr (!std::is_same_v<ElementRawScale, ElementFp32Scale>) {
copyGmToUbScale(ubRawScale, gmTileScale, layoutRawUbScale, layoutGmTileScale);
} else {
copyGmToUbScale(ubFp32Scale, gmTileScale, layoutFp32UbScale, layoutGmTileScale);
}
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
@@ -238,7 +252,11 @@ public:
AscendC::Cast(ubTmpMxN, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubScale);
if constexpr (!std::is_same_v<ElementRawScale, ElementFp32Scale>) {
AscendC::Cast(ubFp32Scale, ubRawScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN);
AscendC::PipeBarrier<PIPE_V>();
}
tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubFp32Scale);
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale);
@@ -279,7 +297,8 @@ private:
Params params;
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
AscendC::LocalTensor<ElementRawScale> ubRawScaleList[UB_STAGES];
AscendC::LocalTensor<ElementFp32Scale> ubFp32ScaleList[UB_STAGES];
AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];

View File

@@ -39,7 +39,7 @@ public:
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
using BlockEpilogue = BlockEpilogue_;
using ElementScale = typename BlockEpilogue::ElementScale;
using ElementScale = typename BlockEpilogue::ElementRawScale;
using LayoutScale = typename BlockEpilogue::LayoutScale;
using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;

View File

@@ -354,7 +354,7 @@ __aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row)
row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE;
}
template <uint32_t EXEC_FLAG, typename XType_, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
template <TemplateMC2TypeClass, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
class ElementGroupList_>
class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace
{
@@ -371,7 +371,7 @@ public:
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
using BlockEpilogue = BlockEpilogue_;
using ElementScale = typename BlockEpilogue::ElementScale;
using ElementScale = typename BlockEpilogue::ElementRawScale;
using LayoutScale = typename BlockEpilogue::LayoutScale;
using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
@@ -388,7 +388,7 @@ public:
static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_;
using ElementGroupList = ElementGroupList_;
using XType = XType_;
using XType = ExpandXType;
// Parameters structure
struct Params {
@@ -1715,7 +1715,7 @@ private:
namespace Catlass::Gemm::Kernel {
template <class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
template <TemplateMC2TypeClass, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
class ElementGroupList_>
class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch
{
@@ -1732,7 +1732,7 @@ public:
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
using BlockEpilogue = BlockEpilogue_;
using ElementScale = typename BlockEpilogue::ElementScale;
using ElementScale = typename BlockEpilogue::ElementRawScale;
using LayoutScale = typename BlockEpilogue::LayoutScale;
using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
@@ -2017,7 +2017,7 @@ private:
struct AicWaitFunc {
using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
CATLASS_DEVICE
AicWaitFunc() = default;
@@ -2034,7 +2034,7 @@ private:
struct AicSetFunc {
using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
CATLASS_DEVICE
AicSetFunc() = default;

View File

@@ -12,8 +12,8 @@
#include "../common/moe_distribute_base.h"
#define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
#define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
#define TemplateMC2TypeClass typename ExpandXType, typename W1ScaleType, typename W2ScaleType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
#define TemplateMC2TypeFunc ExpandXType, W1ScaleType, W2ScaleType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
#define TemplateDispatchTypeClass \
typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
bool IsNeedAllgater, uint32_t EXEC_FLAG

View File

@@ -275,8 +275,6 @@ class FusionOp(DecodeMoeOps):
torch_npu.Format.FRACTAL_NZ)
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
torch_npu.Format.FRACTAL_NZ)
gmm1_weight_scale = gmm1_weight_scale.float()
gmm2_weight_scale = gmm2_weight_scale.float()
if self.dynamic_eplb:
self.gmm1_weight = [

View File

@@ -46,15 +46,12 @@ class VllmEplbAdaptor(EplbAdaptor):
self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \
self.model.model.layers[i].mlp.experts.w2_weight_scale_list
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_fp32_list"] = \
self.model.model.layers[i].mlp.experts.w2_weight_scale_fp32_list
# TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
if self.model.quant_config is not None:
self.expert_weight_names = [
"w13_weight_list", "w2_weight_list",
"w13_weight_scale_fp32_list", "w13_weight_offset",
"w2_weight_scale_list", "w2_weight_offset",
"w2_weight_scale_fp32_list"
"w2_weight_scale_list", "w2_weight_offset"
]
else:
self.expert_weight_names = ["w13_weight", "w2_weight"]
@@ -83,8 +80,7 @@ class VllmEplbAdaptor(EplbAdaptor):
self.num_dense_layers) + ".mlp.experts." + name
if name in [
"w13_weight_list", "w2_weight_list",
"w13_weight_scale_fp32_list", "w2_weight_scale_list",
"w2_weight_scale_fp32_list"
"w13_weight_scale_fp32_list", "w2_weight_scale_list"
]:
expert_tensor = self.param_dict[complete_name][0]
expert_tensor = expert_tensor.clone()
@@ -105,7 +101,7 @@ class VllmEplbAdaptor(EplbAdaptor):
if name in [
"w13_weight_list", "w2_weight_list",
"w13_weight_scale_fp32_list",
"w2_weight_scale_list", "w2_weight_scale_fp32_list"
"w2_weight_scale_list"
]:
per_expert_param.append(
self.param_dict["model.layers." + str(layer_idx) +

View File

@@ -243,24 +243,16 @@ class AscendW8A8DynamicFusedMoEMethod:
topk_weights = topk_weights.to(self.in_dtype)
moe_comm_method = get_forward_context().moe_comm_method
# When VLLM_ASCEND_ENABLE_FUSED_MC2 == 2, use dispatch_gmm_combine_decode, need fp32 scale
w2_weight_scale_fp32_flag = (
get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2)
if self.dynamic_eplb:
w1 = layer.w13_weight_list
w1_scale = layer.w13_weight_scale_fp32_list
w2 = layer.w2_weight_list
w2_scale = layer.w2_weight_scale_fp32_list \
if w2_weight_scale_fp32_flag else layer.w2_weight_scale_list
w2_scale = layer.w2_weight_scale_list
else:
w1 = [layer.w13_weight]
w1_scale = [layer.w13_weight_scale_fp32]
w2 = [layer.w2_weight]
w2_scale = [
layer.w2_weight_scale_fp32
if w2_weight_scale_fp32_flag else layer.w2_weight_scale
]
w2_scale = [layer.w2_weight_scale]
fused_scale_flag = (get_forward_context().moe_comm_type
== MoECommType.FUSED_MC2
@@ -302,8 +294,6 @@ class AscendW8A8DynamicFusedMoEMethod:
layer.w13_weight_offset.data.shape[0], -1)
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
layer.w2_weight_scale.data.shape[0], -1)
layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to(
torch.float32)
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
layer.w2_weight_offset.data.shape[0], -1)
@@ -328,16 +318,11 @@ class AscendW8A8DynamicFusedMoEMethod:
weight.clone()
for weight in layer.w2_weight_scale.data.unbind(dim=0)
]
layer.w2_weight_scale_fp32_list = [
weight.clone()
for weight in layer.w2_weight_scale_fp32.data.unbind(dim=0)
]
del layer.w13_weight
del layer.w2_weight
del layer.w13_weight_scale
del layer.w13_weight_scale_fp32
del layer.w2_weight_scale
del layer.w2_weight_scale_fp32
torch.npu.empty_cache()