[Feature] Adapt DispatchGmmCombineDecode operator to align with weight scale dtype of small operators. [RFC: issue 5476] (#5755)
### What this PR does / why we need it?
[Feature] Adapt the DispatchGmmCombineDecode operator to align with the
weight scale dtype of the small operators.
- **Before**: weight scale must be float32
- **After**: weight scale can be float32/float16 when x is float16, or
float32/bfloat16 when x is bfloat16. The w1 scale can also use a
different dtype from the w2 scale.
For more information about this operator, please refer to the RFC: issue
https://github.com/vllm-project/vllm-ascend/issues/5476
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
#### Perf
> When the scale is of type fp16 or bf16, it is cast to fp32 internally
within the operator, while the subsequent computations remain unchanged.
Therefore, this PR introduces an additional cast operation but halves
the amount of scale data that has to be copied. Furthermore, since the scale
data is only a few KB in size and participates in relatively few
computations, its impact is almost negligible compared to major
operations like matrix multiplication. Thus, the theoretical performance
change should be minimal.
Tested single-operator cases taken from qwen3-235b:
- single A3 node (ep16), 64 MoE experts, 4 experts / die (like qwen3-235b
ep32)
- batch=18/32, token_hidden_size 4096, moe_intermediate_size 1536
The test was run for 100 rounds, and the average of the last 95
rounds was taken.
| | bs18(us)| bs32(us)|
| -----| -----| -----|
|Without this PR|96.28|108.83|
|With this PR|96.06|107.90|
Note: Single-operator benchmarks represent an ideal scenario. They are
usually only useful for referencing relative changes and may not fully
align with performance data observed within the full model.
#### Acc
Tested qwen3-235b with EPLB on a single A3 node (ep16),
with dispatch_gmm_combine_decode
| dataset | version | metric | mode | vllm-api-stream-chat |
|----- | ----- | ----- | ----- | -----|
| aime2024 | 604a78 | accuracy | gen | 83.33 |
- vLLM version: v0.13.0
- vLLM main:
2f4e6548ef
Signed-off-by: wangqiankun <wangqiankun13@huawei.com>
This commit is contained in:
@@ -17,59 +17,94 @@ public:
|
|||||||
{
|
{
|
||||||
this->Input("x")
|
this->Input("x")
|
||||||
.ParamType(REQUIRED)
|
.ParamType(REQUIRED)
|
||||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
|
.DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16,
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
this->Input("expert_ids")
|
this->Input("expert_ids")
|
||||||
.ParamType(REQUIRED)
|
.ParamType(REQUIRED)
|
||||||
.DataType({ge::DT_INT32, ge::DT_INT32})
|
.DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32,
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
this->Input("gmm1_permuted_weight")
|
this->Input("gmm1_permuted_weight")
|
||||||
.ParamType(DYNAMIC)
|
.ParamType(DYNAMIC)
|
||||||
.DataType({ge::DT_INT8, ge::DT_INT8})
|
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
|
||||||
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
|
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
|
||||||
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
|
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||||
|
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
|
||||||
|
.UnknownShapeFormat(
|
||||||
|
{ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||||
|
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
|
||||||
this->Input("gmm1_permuted_weight_scale")
|
this->Input("gmm1_permuted_weight_scale")
|
||||||
.ParamType(DYNAMIC)
|
.ParamType(DYNAMIC)
|
||||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_BF16, ge::DT_BF16,
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT16})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
this->Input("gmm2_weight")
|
this->Input("gmm2_weight")
|
||||||
.ParamType(DYNAMIC)
|
.ParamType(DYNAMIC)
|
||||||
.DataType({ge::DT_INT8, ge::DT_INT8})
|
.DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8,
|
||||||
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
|
ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8})
|
||||||
.UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
|
.Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||||
|
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ})
|
||||||
|
.UnknownShapeFormat(
|
||||||
|
{ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ,
|
||||||
|
ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ});
|
||||||
this->Input("gmm2_weight_scale")
|
this->Input("gmm2_weight_scale")
|
||||||
.ParamType(DYNAMIC)
|
.ParamType(DYNAMIC)
|
||||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
.DataType({ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT, ge::DT_BF16,
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT16})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
this->Input("expert_scales")
|
this->Input("expert_scales")
|
||||||
.ParamType(REQUIRED)
|
.ParamType(REQUIRED)
|
||||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
this->Input("expert_smooth_scales")
|
this->Input("expert_smooth_scales")
|
||||||
.ParamType(OPTIONAL)
|
.ParamType(OPTIONAL)
|
||||||
.DataType({ge::DT_FLOAT, ge::DT_FLOAT})
|
.DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT,
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
this->Input("x_active_mask")
|
this->Input("x_active_mask")
|
||||||
.ParamType(OPTIONAL)
|
.ParamType(OPTIONAL)
|
||||||
.DataType({ge::DT_BOOL, ge::DT_BOOL})
|
.DataType({ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL,
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
this->Output("output")
|
this->Output("output")
|
||||||
.ParamType(REQUIRED)
|
.ParamType(REQUIRED)
|
||||||
.DataType({ge::DT_BF16, ge::DT_FLOAT16})
|
.DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16,
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
this->Output("expert_token_nums")
|
this->Output("expert_token_nums")
|
||||||
.ParamType(REQUIRED)
|
.ParamType(REQUIRED)
|
||||||
.DataType({ge::DT_INT64, ge::DT_INT64})
|
.DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64,
|
||||||
.Format({ge::FORMAT_ND, ge::FORMAT_ND})
|
ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64})
|
||||||
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND});
|
.Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND})
|
||||||
|
.UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND,
|
||||||
|
ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND});
|
||||||
this->Attr("group_ep").String();
|
this->Attr("group_ep").String();
|
||||||
this->Attr("ep_rank_size").Int();
|
this->Attr("ep_rank_size").Int();
|
||||||
this->Attr("ep_rank_id").Int();
|
this->Attr("ep_rank_id").Int();
|
||||||
|
|||||||
@@ -27,7 +27,8 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode(
|
|||||||
GET_TILING_DATA(tiling_data, tiling);
|
GET_TILING_DATA(tiling_data, tiling);
|
||||||
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) ||
|
if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) ||
|
||||||
TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7)) {
|
TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7)) {
|
||||||
DispatchGmmCombineDecode<DTYPE_X, int32_t, false, TILING_KEY_VAR> op;
|
DispatchGmmCombineDecode<
|
||||||
|
DTYPE_X, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, int32_t, false, TILING_KEY_VAR> op;
|
||||||
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
|
op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
|
||||||
expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data);
|
expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data);
|
||||||
op.Process();
|
op.Process();
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ using Gmm2DispatchPolicy =
|
|||||||
GMM2_L0A_STAGES, GMM2_L0B_STAGES, CUSTOM_L0C_STAGES,
|
GMM2_L0A_STAGES, GMM2_L0B_STAGES, CUSTOM_L0C_STAGES,
|
||||||
CUSTOM_ENABLE_UNIT_FLAG, CUSTOM_ENABLE_SHUFFLE_K>;
|
CUSTOM_ENABLE_UNIT_FLAG, CUSTOM_ENABLE_SHUFFLE_K>;
|
||||||
|
|
||||||
template <uint32_t EXEC_FLAG, typename XType_, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_,
|
template <TemplateMC2TypeClass, class L1TileShape_, class L0TileShape_, class EpilogueTileShape_,
|
||||||
class BlockScheduler_, class DispatchPolicy_ = MmadAtlasA2Custom>
|
class BlockScheduler_, class DispatchPolicy_ = MmadAtlasA2Custom>
|
||||||
CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
|
CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA,
|
||||||
layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
|
layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale,
|
||||||
@@ -72,7 +72,6 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
|
|||||||
using L1TileShape = L1TileShape_;
|
using L1TileShape = L1TileShape_;
|
||||||
using L0TileShape = L0TileShape_;
|
using L0TileShape = L0TileShape_;
|
||||||
|
|
||||||
using XType = XType_;
|
|
||||||
using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
|
using AType = Gemm::GemmType<int8_t, layout::RowMajor>;
|
||||||
using BType = Gemm::GemmType<int8_t, layout::zN>;
|
using BType = Gemm::GemmType<int8_t, layout::zN>;
|
||||||
using CType = Gemm::GemmType<int32_t, layout::RowMajor>;
|
using CType = Gemm::GemmType<int32_t, layout::RowMajor>;
|
||||||
@@ -81,7 +80,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
|
|||||||
|
|
||||||
constexpr uint32_t ubStages = 1;
|
constexpr uint32_t ubStages = 1;
|
||||||
using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantSwiglu<ubStages, 0>;
|
using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantSwiglu<ubStages, 0>;
|
||||||
using ScaleType = Gemm::GemmType<float, layout::VectorLayout>;
|
using ScaleType = Gemm::GemmType<W1ScaleType, layout::VectorLayout>;
|
||||||
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
|
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
|
||||||
using DType = Gemm::GemmType<float, layout::RowMajor>;
|
using DType = Gemm::GemmType<float, layout::RowMajor>;
|
||||||
|
|
||||||
@@ -110,9 +109,9 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun
|
|||||||
using GemmKernel = typename std::conditional<
|
using GemmKernel = typename std::conditional<
|
||||||
(EXEC_FLAG & EXEC_FLAG_DEEP_FUSE),
|
(EXEC_FLAG & EXEC_FLAG_DEEP_FUSE),
|
||||||
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace<
|
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace<
|
||||||
EXEC_FLAG, XType, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
|
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>,
|
||||||
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
|
Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
|
||||||
BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;
|
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type;
|
||||||
|
|
||||||
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
|
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
|
||||||
typename GemmKernel::Params params{problemShape,
|
typename GemmKernel::Params params{problemShape,
|
||||||
@@ -197,7 +196,7 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR
|
|||||||
|
|
||||||
constexpr uint32_t ubStages = 1;
|
constexpr uint32_t ubStages = 1;
|
||||||
using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantCombine<ubStages, EXEC_FLAG>;
|
using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2PerTokenDequantCombine<ubStages, EXEC_FLAG>;
|
||||||
using ScaleType = Gemm::GemmType<float, layout::VectorLayout>;
|
using ScaleType = Gemm::GemmType<W2ScaleType, layout::VectorLayout>;
|
||||||
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
|
using PerTokenScaleType = Gemm::GemmType<float, layout::VectorLayout>;
|
||||||
using DType = Gemm::GemmType<ExpandXType, layout::RowMajor>;
|
using DType = Gemm::GemmType<ExpandXType, layout::RowMajor>;
|
||||||
|
|
||||||
@@ -411,7 +410,7 @@ __aicore__ inline void DispatchGmmCombineDecode<TemplateMC2TypeFunc>::Process()
|
|||||||
Arch::CrossCoreWaitFlag(gmm1AivFinished);
|
Arch::CrossCoreWaitFlag(gmm1AivFinished);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
GmmDeqSwigluQuant<EXEC_FLAG, ExpandXType, Gmm1L1TileShape, Gmm1L0TileShape, Gmm1EpilogueTileShape,
|
GmmDeqSwigluQuant<TemplateMC2TypeFunc, Gmm1L1TileShape, Gmm1L0TileShape, Gmm1EpilogueTileShape,
|
||||||
Gmm1BlockScheduler>(
|
Gmm1BlockScheduler>(
|
||||||
gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
|
gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1,
|
||||||
gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
|
gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale,
|
||||||
|
|||||||
@@ -19,304 +19,16 @@
|
|||||||
#include "catlass/layout/layout.hpp"
|
#include "catlass/layout/layout.hpp"
|
||||||
#include "catlass/matrix_coord.hpp"
|
#include "catlass/matrix_coord.hpp"
|
||||||
|
|
||||||
#define ENABLE_EP_SEND_COUNT_HASH 0
|
|
||||||
|
|
||||||
namespace Catlass::Epilogue::Block {
|
namespace Catlass::Epilogue::Block {
|
||||||
|
|
||||||
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_, class CType_, class ScaleType_, class PerTokenScaleType_,
|
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_,
|
||||||
class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
|
class CType_, class ScaleType_, class LayoutScale_, class LayoutPerTokenScale_, class DType_,
|
||||||
|
class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
|
||||||
class TileCopy_, class EpilogueTileSwizzle_>
|
class TileCopy_, class EpilogueTileSwizzle_>
|
||||||
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>, CType_, ScaleType_, PerTokenScaleType_,
|
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>,
|
||||||
DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_,
|
CType_, Gemm::GemmType<ScaleType_, LayoutScale_>, Gemm::GemmType<float, LayoutPerTokenScale_>, DType_,
|
||||||
EpilogueTileSwizzle_>
|
TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_,
|
||||||
{
|
TileCopy_, EpilogueTileSwizzle_>
|
||||||
public:
|
|
||||||
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>;
|
|
||||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
|
||||||
static constexpr uint32_t UB_STAGES = UB_STAGES_;
|
|
||||||
|
|
||||||
// Data infos
|
|
||||||
using ElementC = typename CType_::Element;
|
|
||||||
using LayoutC = typename CType_::Layout;
|
|
||||||
using ElementScale = typename ScaleType_::Element;
|
|
||||||
using LayoutScale = typename ScaleType_::Layout;
|
|
||||||
using ElementPerTokenScale = typename PerTokenScaleType_::Element;
|
|
||||||
using LayoutPerTokenScale = typename PerTokenScaleType_::Layout;
|
|
||||||
using ElementD = typename DType_::Element;
|
|
||||||
using LayoutD = typename DType_::Layout;
|
|
||||||
|
|
||||||
// Check data infos
|
|
||||||
static_assert(std::is_same_v<ElementC, int32_t> &&
|
|
||||||
(std::is_same_v<ElementD, half> || std::is_same_v<ElementD, bfloat16_t>) &&
|
|
||||||
std::is_same_v<ElementScale, ElementD> && std::is_same_v<ElementPerTokenScale, ElementD>,
|
|
||||||
"The element type template parameters of BlockEpilogue are wrong");
|
|
||||||
static_assert(std::is_same_v<LayoutC, layout::RowMajor> && std::is_same_v<LayoutScale, layout::VectorLayout> &&
|
|
||||||
std::is_same_v<LayoutPerTokenScale, layout::VectorLayout> &&
|
|
||||||
std::is_same_v<LayoutD, layout::RowMajor>,
|
|
||||||
"The layout template parameters of BlockEpilogue are wrong");
|
|
||||||
|
|
||||||
// Tile compute ops
|
|
||||||
using TileRowBroadcastMul = TileRowBroadcastMul_;
|
|
||||||
using TileBroadcastOneBlk = TileBroadcastOneBlk_;
|
|
||||||
using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_;
|
|
||||||
|
|
||||||
// Tile copy
|
|
||||||
using CopyGmToUbC = typename TileCopy_::CopyGmToUbC;
|
|
||||||
using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX;
|
|
||||||
using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY;
|
|
||||||
using CopyUbToGmD = typename TileCopy_::CopyUbToGmD;
|
|
||||||
|
|
||||||
using EpilogueTileSwizzle = EpilogueTileSwizzle_;
|
|
||||||
|
|
||||||
using TileShape = typename TileRowBroadcastMul::TileShape;
|
|
||||||
|
|
||||||
static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH &&
|
|
||||||
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
|
|
||||||
"TileShape must be consistent for all tile compute ops");
|
|
||||||
|
|
||||||
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
|
|
||||||
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
|
|
||||||
(TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + TileShape::ROW) * sizeof(float) +
|
|
||||||
TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE,
|
|
||||||
"TileShape is too large to fit in UB");
|
|
||||||
|
|
||||||
struct Params {
|
|
||||||
__gm__ ElementScale *ptrScale{nullptr};
|
|
||||||
LayoutScale layoutScale{};
|
|
||||||
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
|
|
||||||
LayoutPerTokenScale layoutPerTokenScale{};
|
|
||||||
__gm__ ElementD *ptrD{nullptr};
|
|
||||||
LayoutD layoutD{};
|
|
||||||
|
|
||||||
CATLASS_DEVICE
|
|
||||||
Params() {};
|
|
||||||
|
|
||||||
CATLASS_DEVICE
|
|
||||||
Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_,
|
|
||||||
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
|
|
||||||
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
|
|
||||||
: ptrScale(ptrScale_),
|
|
||||||
layoutScale(layoutScale_),
|
|
||||||
ptrPerTokenScale(ptrPerTokenScale_),
|
|
||||||
layoutPerTokenScale(layoutPerTokenScale_),
|
|
||||||
ptrD(ptrD_),
|
|
||||||
layoutD(layoutD_)
|
|
||||||
{}
|
|
||||||
};
|
|
||||||
|
|
||||||
CATLASS_DEVICE
|
|
||||||
BlockEpilogue(Arch::Resource<ArchTag> const &resource, Params const ¶ms = Params{}) : params(params)
|
|
||||||
{
|
|
||||||
size_t ubOffset = 0;
|
|
||||||
int32_t eventVMTE2 = 0;
|
|
||||||
int32_t eventMTE2V = 0;
|
|
||||||
int32_t eventMTE3V = 0;
|
|
||||||
int32_t eventVMTE3 = 0;
|
|
||||||
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
|
||||||
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
|
|
||||||
ubOffset += TileShape::COUNT * sizeof(ElementC);
|
|
||||||
ubScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementScale>(ubOffset);
|
|
||||||
ubOffset += TileShape::COLUMN * sizeof(ElementScale);
|
|
||||||
ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
|
|
||||||
ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
|
|
||||||
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
|
|
||||||
ubOffset += TileShape::COUNT * sizeof(ElementD);
|
|
||||||
|
|
||||||
eventUbCVMTE2List[i] = eventVMTE2++;
|
|
||||||
eventUbCMTE2VList[i] = eventMTE2V++;
|
|
||||||
eventUbScaleVMTE2List[i] = eventVMTE2++;
|
|
||||||
eventUbScaleMTE2VList[i] = eventMTE2V++;
|
|
||||||
eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++;
|
|
||||||
eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++;
|
|
||||||
eventUbDMTE3VList[i] = eventMTE3V++;
|
|
||||||
eventUbDVMTE3List[i] = eventVMTE3++;
|
|
||||||
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
|
|
||||||
}
|
|
||||||
ubCFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
|
||||||
ubOffset += TileShape::COUNT * sizeof(float);
|
|
||||||
ubScaleFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
|
||||||
ubOffset += TileShape::COLUMN * sizeof(float);
|
|
||||||
ubMul = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
|
||||||
ubOffset += TileShape::COUNT * sizeof(float);
|
|
||||||
ubPerTokenScaleFp32 = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
|
||||||
ubOffset += TileShape::ROW * sizeof(float);
|
|
||||||
ubPerTokenScaleFp32Brcb = resource.ubBuf.template GetBufferByByte<float>(ubOffset);
|
|
||||||
ubOffset += TileShape::ROW * BYTE_PER_BLK;
|
|
||||||
ubPerTokenMul = ubMul;
|
|
||||||
}
|
|
||||||
|
|
||||||
CATLASS_DEVICE
|
|
||||||
~BlockEpilogue()
|
|
||||||
{
|
|
||||||
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[i]);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[i]);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[i]);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
CATLASS_DEVICE
|
|
||||||
void UpdateParams(Params const ¶ms_)
|
|
||||||
{
|
|
||||||
params = params_;
|
|
||||||
}
|
|
||||||
|
|
||||||
CATLASS_DEVICE
|
|
||||||
void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK,
|
|
||||||
GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor<ElementC> const &gmBlockC,
|
|
||||||
LayoutC const &layoutBlockC, Callback &&callback = Callback{})
|
|
||||||
{
|
|
||||||
if (actualBlockShapeMNK.k() == 0) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
callback();
|
|
||||||
|
|
||||||
// Calculate the offset of the current block
|
|
||||||
MatrixCoord blockShape = blockShapeMNK.GetCoordMN();
|
|
||||||
MatrixCoord blockCoord = blockCoordMNK.GetCoordMN();
|
|
||||||
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
|
|
||||||
MatrixCoord blockOffset = blockCoord * blockShape;
|
|
||||||
|
|
||||||
AscendC::GlobalTensor<ElementScale> gmScale;
|
|
||||||
gmScale.SetGlobalBuffer(params.ptrScale);
|
|
||||||
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
|
|
||||||
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
|
|
||||||
AscendC::GlobalTensor<ElementD> gmD;
|
|
||||||
gmD.SetGlobalBuffer(params.ptrD);
|
|
||||||
|
|
||||||
auto ubTileStride = MakeCoord(static_cast<int64_t>(TileShape::COLUMN), 1L);
|
|
||||||
auto tileShape = TileShape::ToCoord();
|
|
||||||
EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape);
|
|
||||||
uint32_t tileLoops = epilogueTileSwizzle.GetLoops();
|
|
||||||
uint32_t subblockIdx = AscendC::GetSubBlockIdx();
|
|
||||||
uint32_t subblockNum = AscendC::GetSubBlockNum();
|
|
||||||
for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) {
|
|
||||||
auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx);
|
|
||||||
auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord);
|
|
||||||
auto tileOffsetInBlock = tileCoord * tileShape;
|
|
||||||
auto tileOffset = blockOffset + tileOffsetInBlock;
|
|
||||||
|
|
||||||
auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)];
|
|
||||||
auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape);
|
|
||||||
|
|
||||||
auto &ubC = ubCList[ubListId];
|
|
||||||
LayoutC layoutUbC{actualTileShape, ubTileStride};
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
|
||||||
copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
|
||||||
|
|
||||||
auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>();
|
|
||||||
auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>();
|
|
||||||
|
|
||||||
auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
|
|
||||||
auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
|
|
||||||
|
|
||||||
auto &ubScale = ubScaleList[ubListId];
|
|
||||||
auto layoutUbScale = LayoutScale::template MakeLayoutInUb<ElementScale>(scaleTileShape);
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
|
|
||||||
copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
|
|
||||||
|
|
||||||
auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
|
|
||||||
auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>();
|
|
||||||
|
|
||||||
auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)];
|
|
||||||
auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape);
|
|
||||||
|
|
||||||
auto &ubPerTokenScale = ubPerTokenScaleList[ubListId];
|
|
||||||
auto layoutUbPerTokenScale =
|
|
||||||
LayoutScale::template MakeLayoutInUb<ElementPerTokenScale>(perTokenScaleTileShape);
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
|
|
||||||
copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale,
|
|
||||||
layoutGmTilePerTokenScale);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbCMTE2VList[ubListId]);
|
|
||||||
AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
|
|
||||||
AscendC::Cast(ubScaleFp32, ubScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
|
|
||||||
AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, AscendC::RoundMode::CAST_NONE, TileShape::ROW);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbPerTokenScaleVMTE2List[ubListId]);
|
|
||||||
|
|
||||||
tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32);
|
|
||||||
tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32);
|
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
|
||||||
tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleFp32Brcb);
|
|
||||||
AscendC::PipeBarrier<PIPE_V>();
|
|
||||||
|
|
||||||
auto &ubD = ubDList[ubListId];
|
|
||||||
LayoutD layoutUbD{actualTileShape, ubTileStride};
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
|
|
||||||
AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
|
|
||||||
|
|
||||||
auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)];
|
|
||||||
auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape);
|
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE3>(eventUbDVMTE3List[ubListId]);
|
|
||||||
copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD);
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE3_V>(eventUbDMTE3VList[ubListId]);
|
|
||||||
|
|
||||||
ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
Params params;
|
|
||||||
|
|
||||||
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
|
|
||||||
AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
|
|
||||||
AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
|
|
||||||
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
|
|
||||||
|
|
||||||
int32_t eventUbCVMTE2List[UB_STAGES];
|
|
||||||
int32_t eventUbCMTE2VList[UB_STAGES];
|
|
||||||
int32_t eventUbScaleVMTE2List[UB_STAGES];
|
|
||||||
int32_t eventUbScaleMTE2VList[UB_STAGES];
|
|
||||||
int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES];
|
|
||||||
int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES];
|
|
||||||
int32_t eventUbDMTE3VList[UB_STAGES];
|
|
||||||
int32_t eventUbDVMTE3List[UB_STAGES];
|
|
||||||
|
|
||||||
uint32_t ubListId{0};
|
|
||||||
|
|
||||||
AscendC::LocalTensor<float> ubCFp32;
|
|
||||||
AscendC::LocalTensor<float> ubScaleFp32;
|
|
||||||
AscendC::LocalTensor<float> ubMul;
|
|
||||||
AscendC::LocalTensor<float> ubPerTokenScaleFp32;
|
|
||||||
AscendC::LocalTensor<float> ubPerTokenScaleFp32Brcb;
|
|
||||||
AscendC::LocalTensor<float> ubPerTokenMul;
|
|
||||||
|
|
||||||
TileRowBroadcastMul tileRowBroadcastMul;
|
|
||||||
TileBroadcastOneBlk tileBroadcastOneBlk;
|
|
||||||
TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;
|
|
||||||
|
|
||||||
CopyGmToUbC copyGmToUbC;
|
|
||||||
CopyGmToUbScale copyGmToUbScale;
|
|
||||||
CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;
|
|
||||||
CopyUbToGmD copyUbToGmD;
|
|
||||||
};
|
|
||||||
|
|
||||||
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_, class CType_, class LayoutScale_, class LayoutPerTokenScale_,
|
|
||||||
class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
|
|
||||||
class TileCopy_, class EpilogueTileSwizzle_>
|
|
||||||
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>, CType_, Gemm::GemmType<float, LayoutScale_>,
|
|
||||||
Gemm::GemmType<float, LayoutPerTokenScale_>, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_,
|
|
||||||
TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_>
|
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>;
|
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantCombine<UB_STAGES_, EXEC_FLAG_>;
|
||||||
@@ -327,7 +39,8 @@ public:
|
|||||||
// Data infos
|
// Data infos
|
||||||
using ElementC = typename CType_::Element;
|
using ElementC = typename CType_::Element;
|
||||||
using LayoutC = typename CType_::Layout;
|
using LayoutC = typename CType_::Layout;
|
||||||
using ElementScale = float;
|
using ElementRawScale = ScaleType_;
|
||||||
|
using ElementFp32Scale = float;
|
||||||
using LayoutScale = LayoutScale_;
|
using LayoutScale = LayoutScale_;
|
||||||
using ElementPerTokenScale = float;
|
using ElementPerTokenScale = float;
|
||||||
using LayoutPerTokenScale = LayoutPerTokenScale_;
|
using LayoutPerTokenScale = LayoutPerTokenScale_;
|
||||||
@@ -362,14 +75,16 @@ public:
|
|||||||
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
|
std::is_same_v<TileShape, typename TileOneBlkColumnBroadcastMul::TileShape>,
|
||||||
"TileShape must be consistent for all tile compute ops");
|
"TileShape must be consistent for all tile compute ops");
|
||||||
|
|
||||||
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
|
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) +
|
||||||
|
(std::is_same_v<ElementRawScale, ElementFp32Scale> ?
|
||||||
|
0 : TileShape::COLUMN * sizeof(ElementRawScale)) +
|
||||||
|
TileShape::COLUMN * sizeof(ElementFp32Scale) +
|
||||||
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
|
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
|
||||||
(TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
|
(TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
|
||||||
ArchTag::UB_SIZE,
|
ArchTag::UB_SIZE,
|
||||||
"TileShape is too large to fit in UB");
|
"TileShape is too large to fit in UB");
|
||||||
|
|
||||||
struct Params {
|
struct Params {
|
||||||
__gm__ ElementScale *ptrScale{nullptr};
|
__gm__ ElementRawScale *ptrScale{nullptr};
|
||||||
LayoutScale layoutScale{};
|
LayoutScale layoutScale{};
|
||||||
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
|
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
|
||||||
LayoutPerTokenScale layoutPerTokenScale{};
|
LayoutPerTokenScale layoutPerTokenScale{};
|
||||||
@@ -380,7 +95,7 @@ public:
|
|||||||
Params() {};
|
Params() {};
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_,
|
Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_,
|
||||||
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
|
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
|
||||||
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
|
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
|
||||||
: ptrScale(ptrScale_),
|
: ptrScale(ptrScale_),
|
||||||
@@ -408,8 +123,12 @@ public:
|
|||||||
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
||||||
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
|
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
|
||||||
ubOffset += TileShape::COUNT * sizeof(ElementC);
|
ubOffset += TileShape::COUNT * sizeof(ElementC);
|
||||||
ubScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementScale>(ubOffset);
|
if constexpr (!std::is_same_v<ElementRawScale, ElementFp32Scale>) {
|
||||||
ubOffset += TileShape::COLUMN * sizeof(ElementScale);
|
ubRawScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementRawScale>(ubOffset);
|
||||||
|
ubOffset += TileShape::COLUMN * sizeof(ElementRawScale);
|
||||||
|
}
|
||||||
|
ubFp32ScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementFp32Scale>(ubOffset);
|
||||||
|
ubOffset += TileShape::COLUMN * sizeof(ElementFp32Scale);
|
||||||
ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
|
ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
|
||||||
ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
|
ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
|
||||||
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
|
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
|
||||||
@@ -451,22 +170,6 @@ public:
|
|||||||
AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams);
|
AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams);
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(eventMTE2S);
|
AscendC::SetFlag<AscendC::HardEvent::MTE2_S>(eventMTE2S);
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(eventMTE2S);
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_S>(eventMTE2S);
|
||||||
#if ENABLE_EP_SEND_COUNT_HASH
|
|
||||||
tokenToEpRankHashLocal_ = resource.ubBuf.template GetBufferByByte<int32_t>(ubOffset);
|
|
||||||
uint32_t maxGroupSendCount = 0;
|
|
||||||
uint32_t groupSendCount = 0;
|
|
||||||
for (uint32_t expertIdx = 0; expertIdx < calcInfo.moeExpertPerRankNum_; ++expertIdx) {
|
|
||||||
uint32_t prevGroupSendCount = groupSendCount;
|
|
||||||
groupSendCount = epSendCountLocal_.GetValue((expertIdx + 1) * calcInfo.epWorldSize_ - 1);
|
|
||||||
if (maxGroupSendCount < groupSendCount - prevGroupSendCount) {
|
|
||||||
maxGroupSendCount = groupSendCount - prevGroupSendCount;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
ubOffset += maxGroupSendCount * sizeof(int32_t);
|
|
||||||
AlignUbOffset();
|
|
||||||
// assert: ubOffset <= AscendC::TOTAL_UB_SIZE or
|
|
||||||
// AscendC::TOTAL_VEC_LOCAL_SIZE
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -495,28 +198,6 @@ public:
|
|||||||
->windowsIn) +
|
->windowsIn) +
|
||||||
calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET;
|
calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET;
|
||||||
}
|
}
|
||||||
#if ENABLE_EP_SEND_COUNT_HASH
|
|
||||||
CATLASS_DEVICE void InitTokenToEpRankHashLocalForEpRank(uint32_t &hashOffset, uint32_t epRank, uint32_t copyLen)
|
|
||||||
{
|
|
||||||
constexpr uint32_t DUPLICATE_MASK_COUNT = 8;
|
|
||||||
uint32_t hashOffsetMask = (((uint32_t)hashOffset) & (DUPLICATE_MASK_COUNT - 1));
|
|
||||||
if (hashOffsetMask != 0) {
|
|
||||||
uint32_t remainMaskCount = DUPLICATE_MASK_COUNT - hashOffsetMask;
|
|
||||||
if (copyLen < remainMaskCount) {
|
|
||||||
remainMaskCount = copyLen;
|
|
||||||
}
|
|
||||||
uint64_t copyMask = ((1UL << remainMaskCount) - 1) << hashOffsetMask;
|
|
||||||
AscendC::Duplicate<int32_t>(tokenToEpRankHashLocal_[hashOffset - hashOffsetMask], epRank, &copyMask, 1, 1,
|
|
||||||
DUPLICATE_MASK_COUNT);
|
|
||||||
hashOffset += remainMaskCount;
|
|
||||||
copyLen -= remainMaskCount;
|
|
||||||
}
|
|
||||||
if (copyLen > 0) {
|
|
||||||
AscendC::Duplicate<int32_t>(tokenToEpRankHashLocal_[hashOffset], epRank, copyLen);
|
|
||||||
hashOffset += copyLen;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
CATLASS_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank)
|
CATLASS_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank)
|
||||||
{
|
{
|
||||||
@@ -542,11 +223,7 @@ public:
|
|||||||
uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_;
|
uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_;
|
||||||
uint32_t itToken = startToken;
|
uint32_t itToken = startToken;
|
||||||
uint32_t endToken = startToken + layoutGmTileD.shape(0);
|
uint32_t endToken = startToken + layoutGmTileD.shape(0);
|
||||||
#if ENABLE_EP_SEND_COUNT_HASH
|
|
||||||
uint32_t epRankStart = tokenToEpRankHashLocal_(itToken - startToken);
|
|
||||||
#else
|
|
||||||
constexpr uint32_t epRankStart = 0;
|
constexpr uint32_t epRankStart = 0;
|
||||||
#endif
|
|
||||||
uint32_t sendCount =
|
uint32_t sendCount =
|
||||||
expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1);
|
expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1);
|
||||||
for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) {
|
for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) {
|
||||||
@@ -582,20 +259,6 @@ public:
|
|||||||
|
|
||||||
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
|
if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) {
|
||||||
expertOffset = expertIdx * calcInfo.epWorldSize_;
|
expertOffset = expertIdx * calcInfo.epWorldSize_;
|
||||||
#if ENABLE_EP_SEND_COUNT_HASH
|
|
||||||
if (currentExpertIdx_ != expertIdx) {
|
|
||||||
uint32_t hashOffset = 0;
|
|
||||||
uint32_t sendCount = expertIdx == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset - 1);
|
|
||||||
for (uint32_t epRank = 0; epRank < calcInfo.epWorldSize_; ++epRank) {
|
|
||||||
uint32_t prevSendCount = sendCount;
|
|
||||||
sendCount = epSendCountLocal_.GetValue(expertOffset + epRank);
|
|
||||||
InitTokenToEpRankHashLocalForEpRank(hashOffset, epRank, sendCount - prevSendCount);
|
|
||||||
}
|
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_S>(eventVS);
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_S>(eventVS);
|
|
||||||
currentExpertIdx_ = expertIdx;
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
callback();
|
callback();
|
||||||
@@ -605,7 +268,7 @@ public:
|
|||||||
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
|
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
|
||||||
MatrixCoord blockOffset = blockCoord * blockShape;
|
MatrixCoord blockOffset = blockCoord * blockShape;
|
||||||
|
|
||||||
AscendC::GlobalTensor<ElementScale> gmScale;
|
AscendC::GlobalTensor<ElementRawScale> gmScale;
|
||||||
gmScale.SetGlobalBuffer(params.ptrScale);
|
gmScale.SetGlobalBuffer(params.ptrScale);
|
||||||
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
|
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
|
||||||
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
|
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
|
||||||
@@ -640,11 +303,16 @@ public:
|
|||||||
auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
|
auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
|
||||||
auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
|
auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
|
||||||
|
|
||||||
auto &ubScale = ubScaleList[ubListId];
|
auto &ubFp32Scale = ubFp32ScaleList[ubListId];
|
||||||
auto layoutUbScale = LayoutScale::template MakeLayoutInUb<ElementScale>(scaleTileShape);
|
auto layoutFp32UbScale = LayoutScale::template MakeLayoutInUb<ElementFp32Scale>(scaleTileShape);
|
||||||
|
auto &ubRawScale = ubRawScaleList[ubListId];
|
||||||
|
auto layoutRawUbScale = LayoutScale::template MakeLayoutInUb<ElementRawScale>(scaleTileShape);
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
|
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
|
||||||
copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
|
if constexpr (!std::is_same_v<ElementRawScale, ElementFp32Scale>) {
|
||||||
|
copyGmToUbScale(ubRawScale, gmTileScale, layoutRawUbScale, layoutGmTileScale);
|
||||||
|
} else {
|
||||||
|
copyGmToUbScale(ubFp32Scale, gmTileScale, layoutFp32UbScale, layoutGmTileScale);
|
||||||
|
}
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
|
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
|
||||||
|
|
||||||
auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
|
auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
|
||||||
@@ -667,7 +335,11 @@ public:
|
|||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
|
||||||
tileRowBroadcastMul(ubMul, ubCFp32, ubScale);
|
if constexpr (!std::is_same_v<ElementRawScale, ElementFp32Scale>) {
|
||||||
|
AscendC::Cast(ubFp32Scale, ubRawScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN);
|
||||||
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
|
}
|
||||||
|
tileRowBroadcastMul(ubMul, ubCFp32, ubFp32Scale);
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
|
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
|
||||||
|
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
|
||||||
@@ -709,7 +381,8 @@ private:
|
|||||||
MoeDistributeCombineImpl::CombineCalcInfo calcInfo;
|
MoeDistributeCombineImpl::CombineCalcInfo calcInfo;
|
||||||
|
|
||||||
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
|
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
|
||||||
AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
|
AscendC::LocalTensor<ElementRawScale> ubRawScaleList[UB_STAGES];
|
||||||
|
AscendC::LocalTensor<ElementFp32Scale> ubFp32ScaleList[UB_STAGES];
|
||||||
AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
|
AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
|
||||||
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
|
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
|
||||||
|
|
||||||
@@ -723,10 +396,6 @@ private:
|
|||||||
int32_t eventUbDVMTE3List[UB_STAGES];
|
int32_t eventUbDVMTE3List[UB_STAGES];
|
||||||
|
|
||||||
AscendC::LocalTensor<int32_t> epSendCountLocal_;
|
AscendC::LocalTensor<int32_t> epSendCountLocal_;
|
||||||
#if ENABLE_EP_SEND_COUNT_HASH
|
|
||||||
AscendC::LocalTensor<int32_t> tokenToEpRankHashLocal_;
|
|
||||||
uint32_t currentExpertIdx_{static_cast<uint32_t>(-1)};
|
|
||||||
#endif
|
|
||||||
|
|
||||||
size_t ubOffset{0};
|
size_t ubOffset{0};
|
||||||
int32_t eventVMTE2{0};
|
int32_t eventVMTE2{0};
|
||||||
|
|||||||
@@ -21,13 +21,14 @@
|
|||||||
|
|
||||||
namespace Catlass::Epilogue::Block {
|
namespace Catlass::Epilogue::Block {
|
||||||
|
|
||||||
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_, class CType_, class LayoutScale_, class LayoutPerTokenScale_,
|
template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_,
|
||||||
|
class CType_, class ScaleType_, class LayoutScale_, class LayoutPerTokenScale_,
|
||||||
class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
|
class DType_, class TileRowBroadcastMul_, class TileBroadcastOneBlk_, class TileOneBlkColumnBroadcastMul_,
|
||||||
class TileCopy_, class EpilogueTileSwizzle_>
|
class TileCopy_, class EpilogueTileSwizzle_>
|
||||||
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantSwiglu<UB_STAGES_, EXEC_FLAG_>, CType_,
|
class BlockEpilogue<EpilogueAtlasA2PerTokenDequantSwiglu<UB_STAGES_, EXEC_FLAG_>,
|
||||||
Gemm::GemmType<float, LayoutScale_>, Gemm::GemmType<float, LayoutPerTokenScale_>, DType_,
|
CType_, Gemm::GemmType<ScaleType_, LayoutScale_>, Gemm::GemmType<float, LayoutPerTokenScale_>,
|
||||||
TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_,
|
DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_,
|
||||||
EpilogueTileSwizzle_>
|
TileCopy_, EpilogueTileSwizzle_>
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwiglu<UB_STAGES_, EXEC_FLAG_>;
|
using DispatchPolicy = EpilogueAtlasA2PerTokenDequantSwiglu<UB_STAGES_, EXEC_FLAG_>;
|
||||||
@@ -37,7 +38,8 @@ public:
|
|||||||
// Data infos
|
// Data infos
|
||||||
using ElementC = typename CType_::Element;
|
using ElementC = typename CType_::Element;
|
||||||
using LayoutC = typename CType_::Layout;
|
using LayoutC = typename CType_::Layout;
|
||||||
using ElementScale = float;
|
using ElementRawScale = ScaleType_;
|
||||||
|
using ElementFp32Scale = float;
|
||||||
using LayoutScale = LayoutScale_;
|
using LayoutScale = LayoutScale_;
|
||||||
using ElementPerTokenScale = float;
|
using ElementPerTokenScale = float;
|
||||||
using LayoutPerTokenScale = LayoutPerTokenScale_;
|
using LayoutPerTokenScale = LayoutPerTokenScale_;
|
||||||
@@ -76,14 +78,17 @@ public:
|
|||||||
|
|
||||||
static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough.");
|
static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough.");
|
||||||
|
|
||||||
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) +
|
static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) +
|
||||||
|
(std::is_same_v<ElementRawScale, ElementFp32Scale> ?
|
||||||
|
0 : TileShape::COLUMN * sizeof(ElementRawScale)) +
|
||||||
|
TileShape::COLUMN * sizeof(ElementFp32Scale) +
|
||||||
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
|
TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) +
|
||||||
(TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
|
(TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <=
|
||||||
ArchTag::UB_SIZE,
|
ArchTag::UB_SIZE,
|
||||||
"TileShape is too large to fit in UB");
|
"TileShape is too large to fit in UB");
|
||||||
|
|
||||||
struct Params {
|
struct Params {
|
||||||
__gm__ ElementScale *ptrScale{nullptr};
|
__gm__ ElementRawScale *ptrScale{nullptr};
|
||||||
LayoutScale layoutScale{};
|
LayoutScale layoutScale{};
|
||||||
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
|
__gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr};
|
||||||
LayoutPerTokenScale layoutPerTokenScale{};
|
LayoutPerTokenScale layoutPerTokenScale{};
|
||||||
@@ -94,7 +99,7 @@ public:
|
|||||||
Params() {};
|
Params() {};
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_,
|
Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_,
|
||||||
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
|
__gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_,
|
||||||
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
|
__gm__ ElementD *ptrD_, LayoutD const &layoutD_)
|
||||||
: ptrScale(ptrScale_),
|
: ptrScale(ptrScale_),
|
||||||
@@ -117,8 +122,12 @@ public:
|
|||||||
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
for (uint32_t i = 0; i < UB_STAGES; ++i) {
|
||||||
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
|
ubCList[i] = resource.ubBuf.template GetBufferByByte<ElementC>(ubOffset);
|
||||||
ubOffset += TileShape::COUNT * sizeof(ElementC);
|
ubOffset += TileShape::COUNT * sizeof(ElementC);
|
||||||
ubScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementScale>(ubOffset);
|
if constexpr (!std::is_same_v<ElementRawScale, ElementFp32Scale>) {
|
||||||
ubOffset += TileShape::COLUMN * sizeof(ElementScale);
|
ubRawScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementRawScale>(ubOffset);
|
||||||
|
ubOffset += TileShape::COLUMN * sizeof(ElementRawScale);
|
||||||
|
}
|
||||||
|
ubFp32ScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementFp32Scale>(ubOffset);
|
||||||
|
ubOffset += TileShape::COLUMN * sizeof(ElementFp32Scale);
|
||||||
ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
|
ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte<ElementPerTokenScale>(ubOffset);
|
||||||
ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
|
ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale);
|
||||||
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
|
ubDList[i] = resource.ubBuf.template GetBufferByByte<ElementD>(ubOffset);
|
||||||
@@ -177,7 +186,7 @@ public:
|
|||||||
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
|
MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN();
|
||||||
MatrixCoord blockOffset = blockCoord * blockShape;
|
MatrixCoord blockOffset = blockCoord * blockShape;
|
||||||
bool isLeft = blockOffset.column() < (params.layoutD.shape(1) >> 1);
|
bool isLeft = blockOffset.column() < (params.layoutD.shape(1) >> 1);
|
||||||
AscendC::GlobalTensor<ElementScale> gmScale;
|
AscendC::GlobalTensor<ElementRawScale> gmScale;
|
||||||
gmScale.SetGlobalBuffer(params.ptrScale);
|
gmScale.SetGlobalBuffer(params.ptrScale);
|
||||||
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
|
AscendC::GlobalTensor<ElementPerTokenScale> gmPerTokenScale;
|
||||||
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
|
gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale);
|
||||||
@@ -212,11 +221,16 @@ public:
|
|||||||
auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
|
auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)];
|
||||||
auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
|
auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape);
|
||||||
|
|
||||||
auto &ubScale = ubScaleList[ubListId];
|
auto &ubFp32Scale = ubFp32ScaleList[ubListId];
|
||||||
auto layoutUbScale = LayoutScale::template MakeLayoutInUb<ElementScale>(scaleTileShape);
|
auto layoutFp32UbScale = LayoutScale::template MakeLayoutInUb<ElementFp32Scale>(scaleTileShape);
|
||||||
|
auto &ubRawScale = ubRawScaleList[ubListId];
|
||||||
|
auto layoutRawUbScale = LayoutScale::template MakeLayoutInUb<ElementRawScale>(scaleTileShape);
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
|
AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
|
||||||
copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale);
|
if constexpr (!std::is_same_v<ElementRawScale, ElementFp32Scale>) {
|
||||||
|
copyGmToUbScale(ubRawScale, gmTileScale, layoutRawUbScale, layoutGmTileScale);
|
||||||
|
} else {
|
||||||
|
copyGmToUbScale(ubFp32Scale, gmTileScale, layoutFp32UbScale, layoutGmTileScale);
|
||||||
|
}
|
||||||
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
|
AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
|
||||||
|
|
||||||
auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
|
auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>();
|
||||||
@@ -238,7 +252,11 @@ public:
|
|||||||
AscendC::Cast(ubTmpMxN, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
|
AscendC::Cast(ubTmpMxN, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT);
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbCVMTE2List[ubListId]);
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbScaleMTE2VList[ubListId]);
|
||||||
tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubScale);
|
if constexpr (!std::is_same_v<ElementRawScale, ElementFp32Scale>) {
|
||||||
|
AscendC::Cast(ubFp32Scale, ubRawScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN);
|
||||||
|
AscendC::PipeBarrier<PIPE_V>();
|
||||||
|
}
|
||||||
|
tileRowBroadcastMul(ubTmpMxN, ubTmpMxN, ubFp32Scale);
|
||||||
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
|
AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventUbScaleVMTE2List[ubListId]);
|
||||||
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
|
AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventUbPerTokenScaleMTE2VList[ubListId]);
|
||||||
tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale);
|
tileBroadcastOneBlk(ubTmpMx32B, ubPerTokenScale);
|
||||||
@@ -279,7 +297,8 @@ private:
|
|||||||
Params params;
|
Params params;
|
||||||
|
|
||||||
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
|
AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
|
||||||
AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
|
AscendC::LocalTensor<ElementRawScale> ubRawScaleList[UB_STAGES];
|
||||||
|
AscendC::LocalTensor<ElementFp32Scale> ubFp32ScaleList[UB_STAGES];
|
||||||
AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
|
AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
|
||||||
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
|
AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
|
||||||
|
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ public:
|
|||||||
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
|
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
|
||||||
|
|
||||||
using BlockEpilogue = BlockEpilogue_;
|
using BlockEpilogue = BlockEpilogue_;
|
||||||
using ElementScale = typename BlockEpilogue::ElementScale;
|
using ElementScale = typename BlockEpilogue::ElementRawScale;
|
||||||
using LayoutScale = typename BlockEpilogue::LayoutScale;
|
using LayoutScale = typename BlockEpilogue::LayoutScale;
|
||||||
using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
|
using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
|
||||||
using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
|
using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
|
||||||
|
|||||||
@@ -354,7 +354,7 @@ __aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row)
|
|||||||
row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE;
|
row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <uint32_t EXEC_FLAG, typename XType_, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
|
template <TemplateMC2TypeClass, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
|
||||||
class ElementGroupList_>
|
class ElementGroupList_>
|
||||||
class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace
|
class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace
|
||||||
{
|
{
|
||||||
@@ -371,7 +371,7 @@ public:
|
|||||||
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
|
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
|
||||||
|
|
||||||
using BlockEpilogue = BlockEpilogue_;
|
using BlockEpilogue = BlockEpilogue_;
|
||||||
using ElementScale = typename BlockEpilogue::ElementScale;
|
using ElementScale = typename BlockEpilogue::ElementRawScale;
|
||||||
using LayoutScale = typename BlockEpilogue::LayoutScale;
|
using LayoutScale = typename BlockEpilogue::LayoutScale;
|
||||||
using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
|
using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
|
||||||
using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
|
using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
|
||||||
@@ -388,7 +388,7 @@ public:
|
|||||||
static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_;
|
static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_;
|
||||||
using ElementGroupList = ElementGroupList_;
|
using ElementGroupList = ElementGroupList_;
|
||||||
|
|
||||||
using XType = XType_;
|
using XType = ExpandXType;
|
||||||
|
|
||||||
// Parameters structure
|
// Parameters structure
|
||||||
struct Params {
|
struct Params {
|
||||||
@@ -1715,7 +1715,7 @@ private:
|
|||||||
|
|
||||||
namespace Catlass::Gemm::Kernel {
|
namespace Catlass::Gemm::Kernel {
|
||||||
|
|
||||||
template <class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
|
template <TemplateMC2TypeClass, class BlockMmad_, class BlockEpilogue_, class BlockScheduler_, uint32_t WORKSPACE_STAGES_,
|
||||||
class ElementGroupList_>
|
class ElementGroupList_>
|
||||||
class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch
|
class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch
|
||||||
{
|
{
|
||||||
@@ -1732,7 +1732,7 @@ public:
|
|||||||
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
|
using ElementAccumulator = typename BlockMmad::ElementAccumulator;
|
||||||
|
|
||||||
using BlockEpilogue = BlockEpilogue_;
|
using BlockEpilogue = BlockEpilogue_;
|
||||||
using ElementScale = typename BlockEpilogue::ElementScale;
|
using ElementScale = typename BlockEpilogue::ElementRawScale;
|
||||||
using LayoutScale = typename BlockEpilogue::LayoutScale;
|
using LayoutScale = typename BlockEpilogue::LayoutScale;
|
||||||
using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
|
using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale;
|
||||||
using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
|
using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale;
|
||||||
@@ -2017,7 +2017,7 @@ private:
|
|||||||
|
|
||||||
struct AicWaitFunc {
|
struct AicWaitFunc {
|
||||||
using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
|
using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
|
||||||
BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
|
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
AicWaitFunc() = default;
|
AicWaitFunc() = default;
|
||||||
@@ -2034,7 +2034,7 @@ private:
|
|||||||
|
|
||||||
struct AicSetFunc {
|
struct AicSetFunc {
|
||||||
using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
|
using MatmulKernel = GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch<
|
||||||
BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
|
TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>;
|
||||||
|
|
||||||
CATLASS_DEVICE
|
CATLASS_DEVICE
|
||||||
AicSetFunc() = default;
|
AicSetFunc() = default;
|
||||||
|
|||||||
@@ -12,8 +12,8 @@
|
|||||||
|
|
||||||
#include "../common/moe_distribute_base.h"
|
#include "../common/moe_distribute_base.h"
|
||||||
|
|
||||||
#define TemplateMC2TypeClass typename ExpandXType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
|
#define TemplateMC2TypeClass typename ExpandXType, typename W1ScaleType, typename W2ScaleType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG
|
||||||
#define TemplateMC2TypeFunc ExpandXType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
|
#define TemplateMC2TypeFunc ExpandXType, W1ScaleType, W2ScaleType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG
|
||||||
#define TemplateDispatchTypeClass \
|
#define TemplateDispatchTypeClass \
|
||||||
typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
|
typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \
|
||||||
bool IsNeedAllgater, uint32_t EXEC_FLAG
|
bool IsNeedAllgater, uint32_t EXEC_FLAG
|
||||||
|
|||||||
@@ -275,8 +275,6 @@ class FusionOp(DecodeMoeOps):
|
|||||||
torch_npu.Format.FRACTAL_NZ)
|
torch_npu.Format.FRACTAL_NZ)
|
||||||
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
|
gmm2_weight = torch_npu.npu_format_cast(gmm2_weight,
|
||||||
torch_npu.Format.FRACTAL_NZ)
|
torch_npu.Format.FRACTAL_NZ)
|
||||||
gmm1_weight_scale = gmm1_weight_scale.float()
|
|
||||||
gmm2_weight_scale = gmm2_weight_scale.float()
|
|
||||||
|
|
||||||
if self.dynamic_eplb:
|
if self.dynamic_eplb:
|
||||||
self.gmm1_weight = [
|
self.gmm1_weight = [
|
||||||
|
|||||||
@@ -46,15 +46,12 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
|
self.model.model.layers[i].mlp.experts.w13_weight_scale_fp32_list
|
||||||
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \
|
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_list"] = \
|
||||||
self.model.model.layers[i].mlp.experts.w2_weight_scale_list
|
self.model.model.layers[i].mlp.experts.w2_weight_scale_list
|
||||||
self.param_dict["model.layers." + str(i) + ".mlp.experts." + "w2_weight_scale_fp32_list"] = \
|
|
||||||
self.model.model.layers[i].mlp.experts.w2_weight_scale_fp32_list
|
|
||||||
# TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
|
# TODO: init self.expert_weight_names depending on different model types, only deepseek v3 w8a8 and qwen3-moe is supported here
|
||||||
if self.model.quant_config is not None:
|
if self.model.quant_config is not None:
|
||||||
self.expert_weight_names = [
|
self.expert_weight_names = [
|
||||||
"w13_weight_list", "w2_weight_list",
|
"w13_weight_list", "w2_weight_list",
|
||||||
"w13_weight_scale_fp32_list", "w13_weight_offset",
|
"w13_weight_scale_fp32_list", "w13_weight_offset",
|
||||||
"w2_weight_scale_list", "w2_weight_offset",
|
"w2_weight_scale_list", "w2_weight_offset"
|
||||||
"w2_weight_scale_fp32_list"
|
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
self.expert_weight_names = ["w13_weight", "w2_weight"]
|
||||||
@@ -83,8 +80,7 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
self.num_dense_layers) + ".mlp.experts." + name
|
self.num_dense_layers) + ".mlp.experts." + name
|
||||||
if name in [
|
if name in [
|
||||||
"w13_weight_list", "w2_weight_list",
|
"w13_weight_list", "w2_weight_list",
|
||||||
"w13_weight_scale_fp32_list", "w2_weight_scale_list",
|
"w13_weight_scale_fp32_list", "w2_weight_scale_list"
|
||||||
"w2_weight_scale_fp32_list"
|
|
||||||
]:
|
]:
|
||||||
expert_tensor = self.param_dict[complete_name][0]
|
expert_tensor = self.param_dict[complete_name][0]
|
||||||
expert_tensor = expert_tensor.clone()
|
expert_tensor = expert_tensor.clone()
|
||||||
@@ -105,7 +101,7 @@ class VllmEplbAdaptor(EplbAdaptor):
|
|||||||
if name in [
|
if name in [
|
||||||
"w13_weight_list", "w2_weight_list",
|
"w13_weight_list", "w2_weight_list",
|
||||||
"w13_weight_scale_fp32_list",
|
"w13_weight_scale_fp32_list",
|
||||||
"w2_weight_scale_list", "w2_weight_scale_fp32_list"
|
"w2_weight_scale_list"
|
||||||
]:
|
]:
|
||||||
per_expert_param.append(
|
per_expert_param.append(
|
||||||
self.param_dict["model.layers." + str(layer_idx) +
|
self.param_dict["model.layers." + str(layer_idx) +
|
||||||
|
|||||||
@@ -243,24 +243,16 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
topk_weights = topk_weights.to(self.in_dtype)
|
topk_weights = topk_weights.to(self.in_dtype)
|
||||||
|
|
||||||
moe_comm_method = get_forward_context().moe_comm_method
|
moe_comm_method = get_forward_context().moe_comm_method
|
||||||
# When VLLM_ASCEND_ENABLE_FUSED_MC2 == 2, use dispatch_gmm_combine_decode, need fp32 scale
|
|
||||||
w2_weight_scale_fp32_flag = (
|
|
||||||
get_forward_context().moe_comm_type == MoECommType.FUSED_MC2
|
|
||||||
and envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2 == 2)
|
|
||||||
if self.dynamic_eplb:
|
if self.dynamic_eplb:
|
||||||
w1 = layer.w13_weight_list
|
w1 = layer.w13_weight_list
|
||||||
w1_scale = layer.w13_weight_scale_fp32_list
|
w1_scale = layer.w13_weight_scale_fp32_list
|
||||||
w2 = layer.w2_weight_list
|
w2 = layer.w2_weight_list
|
||||||
w2_scale = layer.w2_weight_scale_fp32_list \
|
w2_scale = layer.w2_weight_scale_list
|
||||||
if w2_weight_scale_fp32_flag else layer.w2_weight_scale_list
|
|
||||||
else:
|
else:
|
||||||
w1 = [layer.w13_weight]
|
w1 = [layer.w13_weight]
|
||||||
w1_scale = [layer.w13_weight_scale_fp32]
|
w1_scale = [layer.w13_weight_scale_fp32]
|
||||||
w2 = [layer.w2_weight]
|
w2 = [layer.w2_weight]
|
||||||
w2_scale = [
|
w2_scale = [layer.w2_weight_scale]
|
||||||
layer.w2_weight_scale_fp32
|
|
||||||
if w2_weight_scale_fp32_flag else layer.w2_weight_scale
|
|
||||||
]
|
|
||||||
|
|
||||||
fused_scale_flag = (get_forward_context().moe_comm_type
|
fused_scale_flag = (get_forward_context().moe_comm_type
|
||||||
== MoECommType.FUSED_MC2
|
== MoECommType.FUSED_MC2
|
||||||
@@ -302,8 +294,6 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
layer.w13_weight_offset.data.shape[0], -1)
|
layer.w13_weight_offset.data.shape[0], -1)
|
||||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
||||||
layer.w2_weight_scale.data.shape[0], -1)
|
layer.w2_weight_scale.data.shape[0], -1)
|
||||||
layer.w2_weight_scale_fp32 = layer.w2_weight_scale.data.to(
|
|
||||||
torch.float32)
|
|
||||||
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
||||||
layer.w2_weight_offset.data.shape[0], -1)
|
layer.w2_weight_offset.data.shape[0], -1)
|
||||||
|
|
||||||
@@ -328,16 +318,11 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
weight.clone()
|
weight.clone()
|
||||||
for weight in layer.w2_weight_scale.data.unbind(dim=0)
|
for weight in layer.w2_weight_scale.data.unbind(dim=0)
|
||||||
]
|
]
|
||||||
layer.w2_weight_scale_fp32_list = [
|
|
||||||
weight.clone()
|
|
||||||
for weight in layer.w2_weight_scale_fp32.data.unbind(dim=0)
|
|
||||||
]
|
|
||||||
del layer.w13_weight
|
del layer.w13_weight
|
||||||
del layer.w2_weight
|
del layer.w2_weight
|
||||||
del layer.w13_weight_scale
|
del layer.w13_weight_scale
|
||||||
del layer.w13_weight_scale_fp32
|
del layer.w13_weight_scale_fp32
|
||||||
del layer.w2_weight_scale
|
del layer.w2_weight_scale
|
||||||
del layer.w2_weight_scale_fp32
|
|
||||||
torch.npu.empty_cache()
|
torch.npu.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user