From f71812011d49013ca077cc2441412593b7356052 Mon Sep 17 00:00:00 2001 From: lih827 <383084552@qq.com> Date: Thu, 12 Feb 2026 10:37:41 +0800 Subject: [PATCH] [Feature] DispatchGmmCombineDecode support bf16/float16 gmm1/gmm2 weight and support gmm weight with ND format (#6393) ### What this PR does / why we need it? 1. support ND format gmm weight input. Before this pr, gmm1_weight and gmm2_weight could only be passed as input to the DispatchGmmCombineDecode operator in NZ data format. After the modification, they are allowed to be passed in ND data format. 2. support bf16/float16 gmm weight The current PR modification enables the DispatchGmmCombineDecode operator to support non-W8A8 scenarios, allowing gmm1_weight and gmm2_weight to be passed as float16/bfloat16 which is correspond with input token data type. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - vLLM version: v0.14.1 - vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd Signed-off-by: lih827 <383084552@qq.com> --- .../op_host/CMakeLists.txt | 1 + .../dispatch_gmm_combine_decode_def.cpp | 171 +- .../dispatch_gmm_combine_decode_tiling.cpp | 44 +- .../op_kernel/dispatch_gmm_combine_decode.cpp | 23 +- .../op_kernel/dispatch_gmm_combine_decode.h | 33 +- .../epilogue/block/block_epilogue.h | 3 + .../block/block_epilogue_bf16_fp16.hpp | 337 +++ .../block/block_epilogue_swiglu_bf16_fp16.h | 259 +++ .../epilogue/dispatch_policy.h | 13 + ...l_slice_m_multistage_workspace_bf16_fp16.h | 383 ++++ ...equant_swiglu_quant_multistage_workspace.h | 93 - ..._m_swiglu_multistage_workspace_bf16_fp16.h | 1805 +++++++++++++++++ .../cam_moe_distribute_combine.h | 2 + .../cam_moe_distribute_dispatch.h | 2 + .../dispatch_gmm_combine_decode_base.h | 101 +- .../dispatch_gmm_combine_decode_bf16_fp16.h | 457 +++++ .../dispatch_gmm_combine_decode_tiling.h | 7 + .../test_dispatch_gmm_combine_decode.py | 269 ++- 18 files changed, 3766 insertions(+), 237 deletions(-) create mode 100644 csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_bf16_fp16.hpp create mode 100644 csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_swiglu_bf16_fp16.h create mode 100644 csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_multistage_workspace_bf16_fp16.h create mode 100644 csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_swiglu_multistage_workspace_bf16_fp16.h create mode 100644 csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_bf16_fp16.h diff --git a/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt b/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt index 1c7abb24..7039b61f 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt +++ b/csrc/dispatch_gmm_combine_decode/op_host/CMakeLists.txt @@ -19,6 +19,7 @@ add_ops_compile_options( OPTIONS --cce-auto-sync=off -Wno-deprecated-declarations -Werror + -DASCENDC_DUMP=0 ${_DISPATCH_GMM_INC_OPTS} ) diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp index 1f991815..838b05b8 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_def.cpp @@ -18,93 +18,190 @@ public: this->Input("x") .ParamType(REQUIRED) .DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, - ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16}) + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, + ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16, + ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16, + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("expert_ids") .ParamType(REQUIRED) .DataType({ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, - ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) + ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, + ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, + ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, + ge::DT_INT32, ge::DT_INT32, ge::DT_INT32, ge::DT_INT32}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm1_permuted_weight") .ParamType(DYNAMIC) .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, - ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) + ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, + ge::DT_BF16, ge::DT_FLOAT16, ge::DT_INT8, ge::DT_INT8, + ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, + ge::DT_INT8, ge::DT_INT8, ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, - ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) - .UnknownShapeFormat( - {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, - ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm1_permuted_weight_scale") .ParamType(DYNAMIC) .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_BF16, ge::DT_BF16, - ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT16}) + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT16, + ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm2_weight") .ParamType(DYNAMIC) .DataType({ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, - ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8}) + ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, + ge::DT_BF16, ge::DT_FLOAT16, ge::DT_INT8, ge::DT_INT8, + ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, ge::DT_INT8, + ge::DT_INT8, ge::DT_INT8, ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, - ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}) - .UnknownShapeFormat( - {ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, - ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ}); + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + .UnknownShapeFormat({ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, + ge::FORMAT_FRACTAL_NZ, ge::FORMAT_FRACTAL_NZ, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("gmm2_weight_scale") .ParamType(DYNAMIC) .DataType({ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT, ge::DT_BF16, - ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT16}) + ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_FLOAT16, + ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT, ge::DT_BF16, + ge::DT_FLOAT, ge::DT_BF16, ge::DT_FLOAT, ge::DT_FLOAT16, + ge::DT_FLOAT, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("expert_scales") .ParamType(REQUIRED) .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, - ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("expert_smooth_scales") .ParamType(OPTIONAL) .DataType({ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, - ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, + ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT, ge::DT_FLOAT}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Input("x_active_mask") .ParamType(OPTIONAL) .DataType({ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, - ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL}) + ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, + ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, + ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, + ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL, ge::DT_BOOL}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("output") .ParamType(REQUIRED) .DataType({ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, ge::DT_BF16, - ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16}) + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_FLOAT16, + ge::DT_BF16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_BF16, + ge::DT_BF16, ge::DT_BF16, ge::DT_FLOAT16, ge::DT_FLOAT16, + ge::DT_FLOAT16, ge::DT_FLOAT16, ge::DT_BF16, ge::DT_FLOAT16}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Output("expert_token_nums") .ParamType(REQUIRED) .DataType({ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, - ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) + ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, + ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, + ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, + ge::DT_INT64, ge::DT_INT64, ge::DT_INT64, ge::DT_INT64}) .Format({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}) .UnknownShapeFormat({ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, - ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, + ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND, ge::FORMAT_ND}); this->Attr("group_ep").String(); this->Attr("ep_rank_size").Int(); this->Attr("ep_rank_id").Int(); diff --git a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp index f4d3f430..847a8381 100644 --- a/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_host/dispatch_gmm_combine_decode_tiling.cpp @@ -91,6 +91,16 @@ static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmC auto gmm1FirstTensorElement = context->GetDynamicInputTensor(INPUT_GMM1_WEIGHT_INDEX, 0); auto gmm1FirstTensorElementShape = gmm1FirstTensorElement->GetOriginShape(); uint32_t elementDims = gmm1FirstTensorElementShape.GetDimNum(); + ge::DataType gmm1DataType = gmm1FirstTensorElement->GetDataType(); + if (gmm1DataType == ge::DT_BF16 || gmm1DataType == ge::DT_FLOAT16) { + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W = true; + } else { + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W = false; + } + auto gmm1WeightDesc = context->GetDynamicInputDesc(INPUT_GMM1_WEIGHT_INDEX, 0); + if (GetPrimaryFormat(gmm1WeightDesc->GetStorageFormat()) == ge::FORMAT_ND) { + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isNDFormat = true; + } OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm1Weight shape is invalid."), return ge::GRAPH_FAILED); @@ -129,6 +139,9 @@ static ge::graphStatus CheckGmm1Shape(gert::TilingContext *context, DispatchGmmC static ge::graphStatus CheckGmm1ScaleShape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData) { + if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W) { + return ge::GRAPH_SUCCESS; + } const char *nodeName = context->GetNodeName(); uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; uint32_t n = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; @@ -170,6 +183,10 @@ static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmC uint32_t elementDims = gmm2FirstTensorElementShape.GetDimNum(); OPS_ERR_IF(elementDims != 2 && elementDims != 3, OPS_LOG_E(nodeName, "gmm2Weight shape is invalid."), return ge::GRAPH_FAILED); + auto gmm2WeightDesc = context->GetDynamicInputDesc(INPUT_GMM2_WEIGHT_INDEX, 0); + if (GetPrimaryFormat(gmm2WeightDesc->GetStorageFormat()) == ge::FORMAT_ND) { + tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isNDFormat = true; + } if (gmm2ListLen > 1) { // List OPS_ERR_IF(gmm2ListLen != moeExpertNumPerRank, OPS_LOG_E(nodeName, "gmm2 does not match local expert number perRank."), return ge::GRAPH_FAILED); @@ -198,6 +215,9 @@ static ge::graphStatus CheckGmm2Shape(gert::TilingContext *context, DispatchGmmC static ge::graphStatus CheckGmm2ScaleShape(gert::TilingContext *context, DispatchGmmCombineDecodeTilingData *tilingData) { + if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W) { + return ge::GRAPH_SUCCESS; + } const char *nodeName = context->GetNodeName(); uint32_t moeExpertNumPerRank = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; uint32_t h = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; @@ -383,16 +403,26 @@ static ge::graphStatus SetWorkSpace(gert::TilingContext *context, const char *no } else { maxTokenNum = maxBatchSize * epRankSize * std::min(topK, moeExpertNumPerRank); } + uint32_t wTypeSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W ? TOKEN_DTYPE_BYTE_SIZE : sizeof(int8_t); - size_t x1TokenSize = maxTokenNum * h * sizeof(int8_t); - size_t x2TokenSize = maxTokenNum * gmm2HLen * sizeof(int8_t); + // hbm input = x: float16 or bf16 + // buf1 dispatch (Only AIV) => x1: float16 or bf16 + // buf2 gmm1 (Only AIC) => y1: float + // sync + // buf3 swiglu (Only AIV) => x2: float16 or bf16 + // sync ? + // buf4 gmm2 (AIC & AIV) => y2: float16 or bf16 + // hbm combine (Only AIV) => output: float16 or bf16 + + size_t x1TokenSize = maxTokenNum * h * wTypeSize; // x1: float16 or bf16 + size_t x2TokenSize = maxTokenNum * gmm2HLen * wTypeSize; // x2: float16 or bf16 size_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize; maxTokenSize = CeilUp(maxTokenSize, GM_ALIGN_SIZE); - size_t tokenScaleSize = CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE); + size_t tokenScaleSize = tilingData.disGmmDeqSwigluQuantGmmDeqComInfo.isBf16Fp16W ? 0 : CeilUp(maxTokenNum * sizeof(float), GM_ALIGN_SIZE); size_t CVSwapBufferSize = CeilUp(aicNum * L1_TILE_BYTE_SIZE * CUBE_WORKSPACE_STAGE * sizeof(int32_t), GM_ALIGN_SIZE); - size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float); - size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE; + size_t swigluOutSize = maxTokenNum * gmm1HLen * sizeof(float); // y1: float + size_t gmm2DepOutSize = maxTokenNum * h * TOKEN_DTYPE_BYTE_SIZE; // y2: float size_t maxSwigluGmm2Size = swigluOutSize < gmm2DepOutSize ? gmm2DepOutSize : swigluOutSize; maxSwigluGmm2Size = CeilUp(maxSwigluGmm2Size, GM_ALIGN_SIZE); size_t groupListSize = CeilUp(moeExpertNumPerRank * sizeof(int64_t), GM_ALIGN_SIZE); @@ -462,6 +492,10 @@ static ge::graphStatus DispatchGmmCombineDecodeTilingFuncImpl(gert::TilingContex if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isTensorList) { tilingKey |= EXEC_FLAG_TENSOR_LIST; } + if (tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.isNDFormat) { + tilingKey |= EXEC_FLAG_ND_FORMAT; + } + context->SetTilingKey(tilingKey); context->SetBlockDim(aicNum); return ge::GRAPH_SUCCESS; diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp index aae344af..3fa6575c 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.cpp @@ -8,6 +8,7 @@ * See LICENSE in the root of the software repository for the full text of the License. */ #include "dispatch_gmm_combine_decode.h" +#include "dispatch_gmm_combine_decode_bf16_fp16.h" #include #include "lib/matmul_intf.h" @@ -25,12 +26,28 @@ extern "C" __global__ __aicore__ void dispatch_gmm_combine_decode( REGISTER_TILING_DEFAULT(DispatchGmmCombineDecodeTilingData); KERNEL_TASK_TYPE_DEFAULT(KERNEL_TYPE_MIX_AIC_1_2); // 1C2V GET_TILING_DATA(tiling_data, tiling); + +#if (ORIG_DTYPE_GMM1_PERMUTED_WEIGHT == DT_INT8) if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) || - TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7)) { - DispatchGmmCombineDecode< - DTYPE_X, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, int32_t, false, TILING_KEY_VAR> op; + TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7) || + TILING_KEY_IS(8) || TILING_KEY_IS(9) || TILING_KEY_IS(10) || TILING_KEY_IS(11) || + TILING_KEY_IS(12) || TILING_KEY_IS(13) || TILING_KEY_IS(14) || TILING_KEY_IS(15)) { + DispatchGmmCombineDecodeImpl::DispatchGmmCombineDecode< + DTYPE_X, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, int8_t, int32_t, false, TILING_KEY_VAR> op; op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data); op.Process(); } +#elif (ORIG_DTYPE_GMM1_PERMUTED_WEIGHT == DT_BF16 || ORIG_DTYPE_GMM1_PERMUTED_WEIGHT == DT_FLOAT16) + if constexpr (TILING_KEY_IS(0) || TILING_KEY_IS(1) || TILING_KEY_IS(2) || TILING_KEY_IS(3) || + TILING_KEY_IS(4) || TILING_KEY_IS(5) || TILING_KEY_IS(6) || TILING_KEY_IS(7) || + TILING_KEY_IS(8) || TILING_KEY_IS(9) || TILING_KEY_IS(10) || TILING_KEY_IS(11) || + TILING_KEY_IS(12) || TILING_KEY_IS(13) || TILING_KEY_IS(14) || TILING_KEY_IS(15)) { + DispatchGmmCombineDecodeBf16Fp16Impl::DispatchGmmCombineDecodeBf16Fp16< + DTYPE_GMM1_PERMUTED_WEIGHT, DTYPE_GMM1_PERMUTED_WEIGHT_SCALE, DTYPE_GMM2_WEIGHT_SCALE, DTYPE_GMM1_PERMUTED_WEIGHT, int32_t, false, TILING_KEY_VAR> op; + op.Init(x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale, + expert_scales, expert_smooth_scales, x_active_mask, output, expertTokenNums, workspace, nullptr, &tiling_data); + op.Process(); + } +#endif } diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h index 97aa44ea..264483c5 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode.h @@ -34,7 +34,7 @@ #include "dispatch_gmm_combine_decode_base.h" using namespace Catlass; - +namespace DispatchGmmCombineDecodeImpl { using MmadAtlasA2Custom = Gemm::MmadAtlasA2PreloadAsyncWithCallback CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, - layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale, + layout::RowMajor layoutA, GM_ADDR gmB, + typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type layoutB, + GM_ADDR gmScale, layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace, @@ -73,7 +75,8 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun using L0TileShape = L0TileShape_; using AType = Gemm::GemmType; - using BType = Gemm::GemmType; + using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type; + using BType = Gemm::GemmType; using CType = Gemm::GemmType; using BlockMmad = Gemm::Block::BlockMmad; @@ -107,7 +110,7 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun using ElementGroupList = int64_t; using GemmKernel = typename std::conditional< - (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE), + (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) != 0, Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace< TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, Gemm::Kernel::GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspaceWithShallowDispatch< @@ -178,7 +181,9 @@ CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCoun template CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, - layout::RowMajor layoutA, GM_ADDR gmB, layout::zN layoutB, GM_ADDR gmScale, + layout::RowMajor layoutA, GM_ADDR gmB, + typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type layoutB, + GM_ADDR gmScale, layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, GM_ADDR gmWorkspace, void *combiner) @@ -189,7 +194,8 @@ CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR using L0TileShape = L0TileShape_; using AType = Gemm::GemmType; - using BType = Gemm::GemmType; + using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type; + using BType = Gemm::GemmType; using CType = Gemm::GemmType; using BlockMmad = Gemm::Block::BlockMmad; @@ -342,6 +348,16 @@ __aicore__ inline void DispatchGmmCombineDecode::Init( gmm2InputDim_ = gmm1OutputDim_ / 2; } +template +__aicore__ inline auto CreateWeightLayout(uint32_t k, uint32_t n) { + if constexpr ((EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0) { + MatrixCoord mc{k, n}; + return layout::RowMajor::template MakeLayoutInUb(mc); + } else { + return layout::zN::template MakeLayout(k, n); + } +} + template __aicore__ inline void DispatchGmmCombineDecode::Process() { @@ -349,11 +365,11 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() GemmCoord gmm2ProblemShape{maxTokenNum_, gmm2OutputDim_, gmm2InputDim_}; layout::RowMajor layoutX1{maxTokenNum_, tokenHiddenSize_}; - layout::zN layoutWeight1 = layout::zN::template MakeLayout(tokenHiddenSize_, gmm1OutputDim_); + auto layoutWeight1 = CreateWeightLayout(tokenHiddenSize_, gmm1OutputDim_); layout::VectorLayout layoutW1Scale{gmm1OutputDim_}; layout::VectorLayout layoutX1Scale{maxTokenNum_}; layout::RowMajor layoutX2{maxTokenNum_, gmm2InputDim_}; - layout::zN layoutWeight2 = layout::zN::template MakeLayout(gmm2InputDim_, gmm2OutputDim_); + auto layoutWeight2 = CreateWeightLayout(gmm2InputDim_, gmm2OutputDim_); layout::VectorLayout layoutW2Scale{gmm2OutputDim_}; layout::VectorLayout layoutX2Scale{maxTokenNum_}; layout::RowMajor layoutOutput{maxTokenNum_, gmm2OutputDim_}; @@ -436,4 +452,5 @@ __aicore__ inline void DispatchGmmCombineDecode::Process() gmScale2_, layoutW2Scale, gmX2Scale, layoutX2Scale, gmGmm2DepOut, layoutOutput, gmWorkspace, &combiner); } +} // namespace DispatchGmmCombineDecodeImpl #endif // DISPATCH_GMM_COMBINE_DECODE_H diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h index 4bbbc792..1d7f6aab 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h @@ -12,3 +12,6 @@ #include "block_epilogue_per_token_dequant_swiglu.h" #include "block_epilogue_per_token_dequant.hpp" + +#include "block_epilogue_swiglu_bf16_fp16.h" +#include "block_epilogue_bf16_fp16.hpp" \ No newline at end of file diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_bf16_fp16.hpp b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_bf16_fp16.hpp new file mode 100644 index 00000000..83c8300c --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_bf16_fp16.hpp @@ -0,0 +1,337 @@ +/* + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef ACT_EPILOGUE_BLOCK_EPILOGUE_BF16_FP16_HPP +#define ACT_EPILOGUE_BLOCK_EPILOGUE_BF16_FP16_HPP + +#include "../../raw_distributed/cam_moe_distribute_combine.h" +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue, + CType_, Gemm::GemmType, Gemm::GemmType, DType_, + TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, + TileCopy_, EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2Combine; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementRawScale = ScaleType_; + using ElementFp32Scale = float; + using LayoutScale = LayoutScale_; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && + (std::is_same_v || std::is_same_v), + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COUNT * sizeof(ElementD)) + + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementRawScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE void AlignUbOffset() + { + size_t ubMask = ubOffset & (MoeDistributeCombineImpl::UB_ALIGN - 1); + if (ubMask != 0) { + ubOffset += MoeDistributeCombineImpl::UB_ALIGN - ubMask; + } + } + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource &resource, MoeDistributeCombineImpl::CombineCalcInfo &calcInfo, + Params const ¶ms = Params{}) + : resource(resource), calcInfo(calcInfo), params(params) + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AlignUbOffset(); + epSendCountLocal_ = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += calcInfo.moeSendNum_ * sizeof(int32_t); + AlignUbOffset(); + AscendC::GlobalTensor epSendCountGM; + epSendCountGM.SetGlobalBuffer((__gm__ int32_t *)calcInfo.epSendCount_); + uint32_t epSendCountSize = calcInfo.isShardExpert_ ? calcInfo.epWorldSize_ : calcInfo.moeSendNum_; + AscendC::DataCopyExtParams epSendCntParams = {1U, static_cast(epSendCountSize * sizeof(uint32_t)), + 0U, 0U, 0U}; + AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams); + AscendC::SetFlag(eventMTE2S); + AscendC::WaitFlag(eventMTE2S); + } + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + CATLASS_DEVICE GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t expertLocalId = 0U) + { + return (GM_ADDR)((calcInfo.epRankId_ == rankId) + ? calcInfo.epWinContext_->localWindowsIn + : ((HcclRankRelationResV2 *)(calcInfo.epWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsIn) + + calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET; + } + + CATLASS_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank) + { + if ((calcInfo.isShardExpert_) && (epRank < calcInfo.sharedExpertRankNum_)) { + remoteEpRank = calcInfo.epRankId_; + localEpRank = epRank; + } else { + remoteEpRank = epRank; + localEpRank = calcInfo.epRankId_; + } + } + + CATLASS_DEVICE void DoCombineSend(AscendC::LocalTensor &ubD, layout::RowMajor &layoutGmTileD, + LayoutD &layoutUbD, int64_t groupOffsetD, uint32_t expertIdx, uint32_t tileOffsetD) + { + const uint32_t copyTokenLen = layoutGmTileD.shape(1) * sizeof(ElementD); + const uint32_t copyTokenSrcStride = + (layoutUbD.stride(0) - layoutUbD.shape(1)) / (BYTE_PER_C0 / sizeof(ElementD)); + const uint32_t copyTokenDstStride = (layoutGmTileD.stride(0) - layoutGmTileD.shape(1)) * sizeof(ElementD); + + int64_t offsetD = groupOffsetD + tileOffsetD; + uint32_t startToken = offsetD / calcInfo.axisH_; + uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_; + uint32_t itToken = startToken; + uint32_t endToken = startToken + layoutGmTileD.shape(0); + constexpr uint32_t epRankStart = 0; + uint32_t sendCount = + expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1); + for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) { + uint32_t prevSendCount = sendCount; + sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); + if (prevSendCount <= itToken && itToken < sendCount) { + uint32_t copyTokenCount = (sendCount < endToken ? sendCount : endToken) - itToken; + AscendC::DataCopyExtParams dataCopyParams(copyTokenCount, copyTokenLen, copyTokenSrcStride, + copyTokenDstStride, 0); + uint32_t remoteEpRank; + uint32_t localEpRank; + SetCombineSendEpRank(epRank, remoteEpRank, localEpRank); + GM_ADDR rankGM = GetWinAddrByRankId(remoteEpRank, expertIdx) + + localEpRank * calcInfo.moeExpertPerRankNum_ * calcInfo.expertPerSizeOnWin_; + AscendC::GlobalTensor rankWindow; + rankWindow.SetGlobalBuffer((__gm__ ElementD *)rankGM); + AscendC::DataCopyPad(rankWindow[(itToken - prevSendCount) * calcInfo.axisH_ + tokenOffset], + ubD[(itToken - startToken) * layoutUbD.stride(0)], dataCopyParams); + itToken += copyTokenCount; + } + } + } + + CATLASS_DEVICE + void operator()(int64_t groupOffsetD, uint32_t expertIdx, GemmCoord const &blockShapeMNK, + GemmCoord const &blockCoordMNK, GemmCoord const &actualBlockShapeMNK, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutBlockC, + Callback &&callback = Callback{}) + { + if (actualBlockShapeMNK.k() == 0) { + return; + } + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + expertOffset = expertIdx * calcInfo.epWorldSize_; + } + + callback(); + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + AscendC::Cast(ubD, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + auto tileOffsetD = params.layoutD.GetOffset(tileOffset); + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + DoCombineSend(ubD, layoutGmTileD, layoutUbD, groupOffsetD, expertIdx, tileOffsetD); + } else { + auto gmTileD = gmD[tileOffsetD]; + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + } + + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + } + +private: + Params params; + Arch::Resource &resource; + MoeDistributeCombineImpl::CombineCalcInfo calcInfo; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + AscendC::LocalTensor epSendCountLocal_; + + size_t ubOffset{0}; + int32_t eventVMTE2{0}; + int32_t eventMTE2V{0}; + int32_t eventMTE3V{0}; + int32_t eventVMTE3{0}; + int32_t eventVS{0}; + int32_t eventMTE2S{0}; + + uint32_t expertOffset; + + uint32_t ubListId{0}; + + CopyGmToUbC copyGmToUbC; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Catlass::Epilogue::Block + +#endif // ACT_EPILOGUE_BLOCK_EPILOGUE_BF16_FP16_HPP diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_swiglu_bf16_fp16.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_swiglu_bf16_fp16.h new file mode 100644 index 00000000..63ad0e66 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/block/block_epilogue_swiglu_bf16_fp16.h @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once +#include "catlass/catlass.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/epilogue/dispatch_policy.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/detail/callback.hpp" + +#include "../tile/tile_stride_muls.h" +#include "../tile/tile_stride_binary.h" + +namespace Catlass::Epilogue::Block { + +template +class BlockEpilogue, + CType_, Gemm::GemmType, Gemm::GemmType, + DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, + TileCopy_, EpilogueTileSwizzle_> +{ +public: + using DispatchPolicy = EpilogueAtlasA2Swiglu; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementRawScale = ScaleType_; + using ElementFp32Scale = float; + using LayoutScale = LayoutScale_; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && std::is_same_v, + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + static_assert(TileShape::ROW * sizeof(float) % BYTE_PER_BLK == 0, + "The per token scale granularity for word calculation must be 32 bytes aligned."); + static_assert(TileShape::COLUMN % 2 == 0, "The n-axis needs to be divided into two parts."); + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert(UB_STAGES <= 2, "UB stages too large, event id is not enough."); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COUNT * sizeof(ElementD)) + + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementRawScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementRawScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + CATLASS_DEVICE + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + int32_t eventMTE3MTE2 = 0; + int32_t eventMTE2MTE3 = 0; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + eventUbMTE3MTE2List[i] = eventMTE3MTE2++; + eventUbMTE2MTE3List[i] = eventMTE2MTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + AscendC::SetFlag(eventUbMTE3MTE2List[i]); + } + ubDenominatorMxN = resource.ubBuf.template GetBufferByByte(ubOffset); + } + + CATLASS_DEVICE + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + AscendC::WaitFlag(eventUbMTE3MTE2List[i]); + } + } + + CATLASS_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; + } + + CATLASS_DEVICE + void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutBlockC, Callback &&callback = Callback{}) + { + if (0 == actualBlockShapeMNK.k()) { + return; + } + callback(); + ubListId = 0; + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + bool isLeft = blockOffset.column() < (params.layoutD.shape(1) >> 1); + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = 0; // for 1C1V + uint32_t subblockNum = 1; // for 1C1V + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)]; + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + if (isLeft) { + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Muls(ubDenominatorMxN, ubC, -1.0f, TileShape::COUNT); + AscendC::PipeBarrier(); + AscendC::Exp(ubDenominatorMxN, ubDenominatorMxN, TileShape::COUNT); + AscendC::PipeBarrier(); + AscendC::Adds(ubDenominatorMxN, ubDenominatorMxN, 1.0f, TileShape::COUNT); + AscendC::PipeBarrier(); + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + AscendC::Div(ubD, ubC, ubDenominatorMxN, TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + } else { + AscendC::WaitFlag(eventUbMTE3MTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbMTE2MTE3List[ubListId]); + AscendC::WaitFlag(eventUbMTE2MTE3List[ubListId]); + copyUbToGmD(gmTileD, ubC, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbMTE3MTE2List[ubListId]); + } + + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + int32_t eventUbMTE3MTE2List[UB_STAGES]; + int32_t eventUbMTE2MTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubDenominatorMxN; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; +}; + +} // namespace Catlass::Epilogue::Block diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/dispatch_policy.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/dispatch_policy.h index df70c101..567ea210 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/dispatch_policy.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/epilogue/dispatch_policy.h @@ -26,4 +26,17 @@ struct EpilogueAtlasA2PerTokenDequantCombine { static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; }; +template +struct EpilogueAtlasA2Swiglu { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; +}; + +template +struct EpilogueAtlasA2Combine { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; +}; } // namespace Catlass::Epilogue diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_multistage_workspace_bf16_fp16.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_multistage_workspace_bf16_fp16.h new file mode 100644 index 00000000..ddb26352 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_multistage_workspace_bf16_fp16.h @@ -0,0 +1,383 @@ +/* + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_MULTISTAGE_WORKSPACE_BF16_FP16_HPP +#define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_MULTISTAGE_WORKSPACE_BF16_FP16_HPP + +#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h" +#include "../../raw_distributed/cam_moe_distribute_combine.h" +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" + +namespace Catlass::Gemm::Kernel { + +template +class GroupedMatmulSliceMMultiStageWorkspace +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementRawScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementD *ptrD; + LayoutD layoutD; + GM_ADDR ptrWorkspace; + void *combiner; + + // Methods + CATLASS_DEVICE + Params() {} + + CATLASS_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_, + GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_, + void *combiner_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), + layoutD(layoutD_), + ptrWorkspace(ptrWorkspace_), + combiner(combiner_) + {} + }; + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceMMultiStageWorkspace() + { + Arch::FlagID flagId = 0; + for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { + flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); + flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); + aicWaitFuncList[stageId] = {this, stageId}; + aicSetFuncList[stageId] = {this, stageId}; + } + } + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB)); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr(0))); + } + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>( + gmBlistTensorDesc.GetDataPtr(groupIdx))); + } + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current + // groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + } + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); + --stageUsed; + } + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + auto *combiner = (MoeDistributeCombineImpl::CamMoeDistributeCombine *)params.combiner; + { + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + if (get_subblockid() == 0) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(MoeDistributeCombineImpl::RECV_SYNC_EVENT_ID); + } + } + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource, combiner->GetCalcInfo()); + + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t startCoreIdx = 0; + AscendC::ListTensorDesc gmScaleListTensor; + gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrScale)); + __gm__ ElementScale* gmScalePtr; + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr(0)); + } + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); + EpilogueParams epilogueParams; + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>( + gmScaleListTensor.GetDataPtr(groupIdx)); + epilogueParams = EpilogueParams { + gmScalePtr, layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, layoutD}; + } else { + epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, + layoutD}; + } + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); + blockEpilogue(gmGroupOffsetD, groupIdx, blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, + layoutBlockC); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]); + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetScale += inGroupProblemShape.n(); + } + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + } + + icache_preload(4); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + if (get_subblockid() == 0) { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->AllToAllSend(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } else { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->ReducePermute(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } + } else { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->Process(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } + } + +private: + friend struct AicWaitFunc; + friend struct AicSetFunc; + + struct AicWaitFunc { + using MatmulKernel = + GroupedMatmulSliceMMultiStageWorkspace; + + CATLASS_DEVICE + AicWaitFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + struct AicSetFunc { + using MatmulKernel = + GroupedMatmulSliceMMultiStageWorkspace; + + CATLASS_DEVICE + AicSetFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; + + AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; + AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel + +#endif // ACT_GEMM_KERNEL_GROUPED_MATMUL_M_MULTISTAGE_WORKSPACE_BF16_FP16_HPP diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h index 967e5869..b7fb7623 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_swiglu_quant_multistage_workspace.h @@ -22,51 +22,6 @@ #include "../../../dispatch_gmm_combine_decode_base.h" -constexpr uint32_t STATE_OFFSET = 512; -constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; -constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024; -constexpr uint64_t GROUP_TOKEN_NUM_OFFSET = 932 * 1024; -constexpr uint64_t SOFT_SYNC_OFFSET = 964 * 1024; -constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; -constexpr uint32_t SUM_TMP_TENSOR_SIZE = 1024; -constexpr uint32_t UB_ALIGN = 32; -constexpr uint32_t TOKEN_EXTRA_SPACE = 512; -constexpr uint32_t INT32_COUNT_PER_BLOCK = 8; -constexpr uint32_t SOFT_SYNC_SPACE_SIZE = 512; -constexpr int64_t LOOP_TMP_SIZE = 4096; -constexpr int32_t SUB_AIV_NUM = 2; -constexpr int32_t ODD_EVEN_BASE = 2; -constexpr int32_t BUFFER_NUM = 2; -constexpr int32_t GATHER_SECOND_NUM = 2; -constexpr uint32_t MAX_QUANT_ROW_ONCE = 8; -constexpr uint32_t QUANT_SPACE_FACTOR = 176 * 1024 / 11; // up to 176KB for quant -#define OPT_RANK_OFFSET 512 - -#define CEIL_UP(x) ((x + UB_ALIGN - 1) / UB_ALIGN * UB_ALIGN) -#define CEIL(x, y) (((x) + (y - 1)) / (y)) -#define UB_BLOCK_SIZE (32) -#define GET_WIND_STATE_ADDR_BY_RANK_ID(rankId) \ - (((epRankId == rankId) \ - ? ((GM_ADDR)(winContext_->localWindowsExp)) \ - : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsExp))) + \ - dataState * WIN_STATE_OFFSET) -#define GET_WIND_ADDR_BY_RANK_ID(rankId) \ - (((epRankId == rankId) \ - ? ((GM_ADDR)(winContext_->localWindowsIn)) \ - : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsIn))) + \ - winDataSizeOffset + rankId * OPT_RANK_OFFSET) -#define TOKEN_FLAG_1 (0x55555555) -#define TOKEN_FLAG_2 (0x33333333) -#define V_TO_C_FLAG_1 (0x03030303) -#define V_TO_C_FLAG_2 (0x05050505) -#define CV_FLAG_INDEX 0 -#define GROUP_ID_INDEX 1 -#define PRE_COUNT_INDEX 2 -#define SELF_COUNT_INDEX 3 -#define TOTAL_COUNT_INDEX 4 -#define GROUP_TOKEN_COUNT 3 // equal to SELF_COUNT_INDEX -#define GROUP_INFO_SIZE 32 - namespace Catlass::Gemm::Kernel { template @@ -306,54 +261,6 @@ private: Epilogue::Tile::CopyUb2Gm copyUbToGmOutput; }; -__aicore__ inline static void EncreaseSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx) -{ - // flag++, like set flag - AscendC::PipeBarrier(); - AscendC::GlobalTensor global; - global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE); - __asm__ __volatile__(""); - AscendC::DataCacheCleanAndInvalid( - global); - __asm__ __volatile__(""); - uint8_t value = global.GetValue(0); - global.SetValue(0, value + 1); - __asm__ __volatile__(""); - AscendC::DataCacheCleanAndInvalid( - global); - __asm__ __volatile__(""); - AscendC::PipeBarrier(); -} - -__aicore__ inline static void CheckSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx, uint32_t target) -{ - // check flag, like wait flag - AscendC::PipeBarrier(); - AscendC::GlobalTensor global; - global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE); - while (true) { - __asm__ __volatile__(""); - AscendC::DataCacheCleanAndInvalid(global); - __asm__ __volatile__(""); - uint8_t value = global.GetValue(0); - if (value >= target) { - __asm__ __volatile__(""); - AscendC::DataCacheCleanAndInvalid(global); - __asm__ __volatile__(""); - break; - } - } - AscendC::PipeBarrier(); -} - -__aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row) -{ - row = QUANT_SPACE_FACTOR / column; - row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE; -} - template class GroupedMatmulSliceMPerTokenDequantSwigluQuantMultiStageWorkspace diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_swiglu_multistage_workspace_bf16_fp16.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_swiglu_multistage_workspace_bf16_fp16.h new file mode 100644 index 00000000..72702d5a --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_swiglu_multistage_workspace_bf16_fp16.h @@ -0,0 +1,1805 @@ +/* + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#pragma once + +#include "ascendc/basic_api/interface/kernel_operator_list_tensor_intf.h" +#include "catlass/catlass.hpp" +#include "catlass/arch/cross_core_sync.hpp" +#include "catlass/arch/resource.hpp" +#include "catlass/coord.hpp" +#include "catlass/detail/callback.hpp" +#include "catlass/gemm_coord.hpp" +#include "catlass/matrix_coord.hpp" +#include "catlass/epilogue/tile/tile_swizzle.hpp" +#include "catlass/epilogue/tile/tile_copy.hpp" + +#include "../../../dispatch_gmm_combine_decode_base.h" + +namespace Catlass::Gemm::Kernel { + +template +class SwigluPost +{ +public: + using ElementInput = float; + using LayoutInput = layout::RowMajor; + using ElementSwigluScale = float; + using LayoutSwigluScale = layout::VectorLayout; + using ElementOutput = ElementOutput_; + using LayoutOutput = layout::RowMajor; + + using InputType = GemmType; + using OutputType = GemmType; + + using EpilogueTileSwizzle = Epilogue::Tile::EpilogueHorizontalTileSwizzle; + + struct Params { + __gm__ ElementInput *ptrInput{nullptr}; + LayoutInput layoutInput; + __gm__ ElementSwigluScale *ptrSwigluScale{nullptr}; + LayoutSwigluScale layoutSwigluScale; + __gm__ ElementOutput *ptrOutput{nullptr}; + LayoutOutput layoutOutput; + uint32_t tileRow; + uint32_t tileColumn; + + CATLASS_DEVICE + Params() {}; + + CATLASS_DEVICE + Params(__gm__ ElementInput *ptrInput_, LayoutInput const &layoutInput_, + __gm__ ElementSwigluScale *ptrSwigluScale_, LayoutSwigluScale const &layoutSwigluScale_, + __gm__ ElementOutput *ptrOutput_, LayoutOutput const layoutOutput_, const uint32_t tileRow_, + const uint32_t tileColumn_) + : ptrInput(ptrInput_), + layoutInput(layoutInput_), + ptrSwigluScale(ptrSwigluScale_), + layoutSwigluScale(layoutSwigluScale_), + ptrOutput(ptrOutput_), + layoutOutput(layoutOutput_), + tileRow(tileRow_), + tileColumn(tileColumn_) + {} + }; + + CATLASS_DEVICE + SwigluPost(Arch::Resource const &resource, Params const ¶ms_) : params(params_) + { + int64_t ubOffset = 0; + tileRow = params_.tileRow; + tileColumn = params_.tileColumn; + tileCount = tileRow * tileColumn; + halfTileColumn = tileColumn / 2; + halfTileCount = tileRow * halfTileColumn; + + ubInput = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += tileCount * sizeof(ElementInput); + ubOutput = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += tileCount * sizeof(ElementOutput); + + ubInputRightHalf = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += tileCount * sizeof(float); + + AscendC::SetFlag(0); + AscendC::SetFlag(0); + AscendC::SetFlag(1); + } + + CATLASS_DEVICE + ~SwigluPost() + { + AscendC::WaitFlag(0); + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + } + + CATLASS_DEVICE + void operator()(MatrixCoord const &blockShape, MatrixCoord const &blockCoord, MatrixCoord const &actualBlockShape) + { + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmInput; + gmInput.SetGlobalBuffer(params.ptrInput); + AscendC::GlobalTensor gmOutput; + gmOutput.SetGlobalBuffer(params.ptrOutput); + + auto ubTileStride = MakeCoord(static_cast(tileColumn), 1L); + auto ubHalfTileStride = MakeCoord(static_cast(halfTileColumn), 1L); + auto tileShape = MakeCoord(tileRow, tileColumn); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileInput = gmInput[params.layoutInput.GetOffset(tileOffset)]; + auto layoutGmTileInput = params.layoutInput.GetTileLayout(actualTileShape); + + layout::RowMajor layoutUbInput{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(0); + // continue swiglu computing here + copyGmToUbInput(ubInput, gmTileInput, layoutUbInput, layoutGmTileInput); + copyGmToUbInput(ubInputRightHalf, gmTileInput[params.layoutInput.shape(1) >> 1], layoutUbInput, layoutGmTileInput); + AscendC::SetFlag(0); + + AscendC::WaitFlag(0); + AscendC::Mul(ubInput, ubInput, ubInputRightHalf, tileCount); + AscendC::PipeBarrier(); + AscendC::WaitFlag(1); + AscendC::Cast(ubOutput, ubInput, AscendC::RoundMode::CAST_RINT, tileCount); + AscendC::SetFlag(1); + AscendC::SetFlag(0); + + auto gmTileOutput = gmOutput[params.layoutOutput.GetOffset(tileOffset)]; + auto layoutGmTileOutput = params.layoutOutput.GetTileLayout(actualTileShape); + + LayoutOutput layoutUbOutput{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(1); + copyUbToGmOutput(gmTileOutput, ubOutput, layoutGmTileOutput, layoutUbOutput); + AscendC::SetFlag(1); + } + } + +private: + Params params; + uint32_t tileRow; + uint32_t tileColumn; + uint32_t tileCount; + uint32_t halfTileColumn; + uint32_t halfTileCount; + + AscendC::LocalTensor ubInput; + AscendC::LocalTensor ubOutput; + + AscendC::LocalTensor ubInputRightHalf; + + Epilogue::Tile::CopyGm2Ub copyGmToUbInput; + Epilogue::Tile::CopyUb2Gm copyUbToGmOutput; +}; + +template +class GroupedMatmulSliceMSwigluMultiStageWorkspace +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementRawScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using XType = ExpandXType; + using ElementSwigluScale = typename SwigluPost::ElementSwigluScale; + using LayoutSwigluScale = typename SwigluPost::LayoutSwigluScale; + using ElementOutput = typename SwigluPost::ElementOutput; + using LayoutOutput = typename SwigluPost::LayoutOutput; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + + // Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementOutput *ptrOutput; + LayoutOutput layoutOutput; + __gm__ ElementSwigluScale *ptrSwigluScale; + LayoutSwigluScale layoutSwigluScale; + GM_ADDR ptrWorkspace; + GM_ADDR gmX; + GM_ADDR debugGm; + GM_ADDR gmexpertIds; + GM_ADDR gmXActiveMask; + + GM_ADDR gmExpandIdx; + GM_ADDR gmEpSendCount; + GM_ADDR gmResvered; + GM_ADDR gmExpertTokenNums; + + uint32_t epRankSize; + uint32_t epRankId; + uint32_t moeExpertNum; + uint32_t moeExpertNumPerRank; + uint32_t sharedExpertNum; + uint32_t sharedExpertRankNum; + uint32_t quantMode; + uint32_t globalBs; + uint32_t bs; + uint32_t topK; + uint32_t tokenLen; + // Methods + CATLASS_DEVICE + Params() {} + + CATLASS_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, + LayoutA const &layoutA_, GM_ADDR ptrB_, LayoutB const &layoutB_, GM_ADDR ptrScale_, + LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, + GM_ADDR ptrSwigluScale_, LayoutSwigluScale const &layoutSwigluScale_, GM_ADDR ptrWorkspace_, + GM_ADDR gmX_, GM_ADDR debugGm_, GM_ADDR gmexpertIds_, GM_ADDR gmExpandIdx_, GM_ADDR gmEpSendCount_, GM_ADDR gmXActiveMask_, + GM_ADDR gmResvered_, GM_ADDR gmExpertTokenNums_, uint32_t epRankSize_, uint32_t epRankId_, + uint32_t moeExpertNum_, uint32_t moeExpertNumPerRank_, uint32_t sharedExpertNum_, + uint32_t sharedExpertRankNum_, uint32_t quantMode_, uint32_t globalBs_, uint32_t bs_, uint32_t topK_, + uint32_t h) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrOutput(reinterpret_cast<__gm__ ElementOutput *>(ptrOutput_)), + layoutOutput(layoutOutput_), + ptrSwigluScale(reinterpret_cast<__gm__ ElementSwigluScale *>(ptrSwigluScale_)), + layoutSwigluScale(layoutSwigluScale_), + ptrWorkspace(ptrWorkspace_), + gmX(gmX_), + debugGm(debugGm_), + gmexpertIds(gmexpertIds_), + gmExpandIdx(gmExpandIdx_), + gmEpSendCount(gmEpSendCount_), + gmExpertTokenNums(gmExpertTokenNums_), + gmXActiveMask(gmXActiveMask_), + gmResvered(gmResvered_), + epRankSize(epRankSize_), + epRankId(epRankId_), + moeExpertNum(moeExpertNum_), + moeExpertNumPerRank(moeExpertNumPerRank_), + sharedExpertNum(sharedExpertNum_), + sharedExpertRankNum(sharedExpertRankNum_), + quantMode(quantMode_), + globalBs(globalBs_), + bs(bs_), + topK(topK_), + tokenLen(h) + {} + }; + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceMSwigluMultiStageWorkspace() {} + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + aicIdx = AscendC::GetBlockIdx(); + subBlockNum = AscendC::GetSubBlockNum(); + aiCoreGroupNum = AscendC::GetBlockNum(); + aicNum = aiCoreGroupNum; + aivNum = aiCoreGroupNum * SUB_AIV_NUM; + aicStateGlobalCoreIdx = aivNum + aicIdx; + moeExpertNumPerRank = params.moeExpertNumPerRank; + isShareExpert = (params.epRankId < params.sharedExpertRankNum); + localExpertNum = isShareExpert ? 1 : moeExpertNumPerRank; + // when localExpertNum=1, all cores send token and recv token in sequence + recvCoreNum = aivNum; + // when localExpertNum>1, half of cores send token and another half recv token in parallel + if (localExpertNum > 1) { + recvCoreNum = aiCoreGroupNum; + } + uint32_t coreNumPerGroup = recvCoreNum / localExpertNum; + winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + + // state of cv flag + statusDataSpaceGm = (GM_ADDR)(winContext_->localWindowsExp); + AscendC::GlobalTensor selfDataStatusTensor; + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aicStateGlobalCoreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + cvDataState = selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN); + if (cvDataState == 0) { + selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN) = 1; + vToCFlag = V_TO_C_FLAG_1; + } else { + selfDataStatusTensor(aicStateGlobalCoreIdx * UB_ALIGN) = 0; + vToCFlag = V_TO_C_FLAG_2; + } + + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + AscendC::ListTensorDesc gmBlistTensorDesc(reinterpret_cast<__gm__ void *>(params.ptrB)); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>(gmBlistTensorDesc.GetDataPtr(0))); + } + + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * aicNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + AscendC::GlobalTensor groupTokenNumStateTensor; + aicSetFunc1 = {statusDataSpaceGm + SOFT_SYNC_OFFSET, + static_cast(aicNum + AscendC::GetBlockIdx())}; // AIV wait for flags in latter part + uint32_t target = 1; + for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmB.SetGlobalBuffer(reinterpret_cast<__gm__ ElementB *>( + gmBlistTensorDesc.GetDataPtr(groupIdx))); + } + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + + groupIdx * GROUP_INFO_SIZE); + // wait AIV recv needed tokens + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(groupTokenNumStateTensor); + __asm__ __volatile__(""); + if (groupTokenNumStateTensor.GetValue(0) == coreNumPerGroup * vToCFlag) { + break; + } + } + + uint32_t currentM = groupTokenNumStateTensor.GetValue(GROUP_TOKEN_COUNT); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((aicIdx < startCoreIdx) ? (aicIdx + aicNum) : aicIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aicNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + aicWaitFunc1 = {statusDataSpaceGm + SOFT_SYNC_OFFSET, static_cast(AscendC::GetBlockIdx()), + target}; // AIC wait for flags in former part + target += 1; + callbackBeforeFixpipe = MakeCallback(&aicWaitFunc1); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFunc1); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * aicNum + aicIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + } + + startCoreIdx = (startCoreIdx + coreLoops) % aicNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + target += 1; + --stageUsed; + } + AscendC::SyncAll(); + } + + CATLASS_DEVICE + void TokenActiveMaskCal(GM_ADDR gmXActiveMask, int64_t ubOffset) + { + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor maskInputTensor = (resource.ubBuf.template + GetBufferByByte(subUbOffset)); + AscendC::LocalTensor maskInputInt8Tensor = maskInputTensor.template ReinterpretCast(); + subUbOffset += CEIL_UP(axisBS * sizeof(bool)); + AscendC::LocalTensor maskTmpTensor = (resource.ubBuf.template + GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(axisBS * sizeof(half)); + AscendC::LocalTensor sumOutTensor = (resource.ubBuf.template + GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(SUM_TMP_TENSOR_SIZE); + + AscendC::GlobalTensor xActiveMaskGMTensor; + xActiveMaskGMTensor.SetGlobalBuffer((__gm__ bool *)gmXActiveMask); + uint32_t axisBsAlignSize = CEIL_UP(axisBS * sizeof(bool)); + + AscendC::DataCopyExtParams maskParams = {1U, static_cast(axisBS * sizeof(bool)), 0U, 0U, 0U}; + AscendC::DataCopyPadExtParams maskCopyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(maskInputTensor, xActiveMaskGMTensor, maskParams, maskCopyPadParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::Cast(maskTmpTensor, maskInputInt8Tensor, AscendC::RoundMode::CAST_NONE, axisBS); + AscendC::PipeBarrier(); + AscendC::SumParams params{1, axisBsAlignSize, axisBS}; + AscendC::Sum(sumOutTensor, maskTmpTensor, params); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + activeMaskBsCnt = static_cast(sumOutTensor.GetValue(0)); + } + + CATLASS_DEVICE + void CalExpandxIdx(int32_t dstExpertId, uint32_t tokenIndex, int32_t &curExpertCnt, int64_t ubOffset) + { + // calculate index in remote + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor dstExpIdTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::LocalTensor subExpIdTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::LocalTensor workLocalTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + subUbOffset += LOOP_TMP_SIZE; + AscendC::Duplicate(dstExpIdTensor_, dstExpertId, tokenIndex); + AscendC::PipeBarrier(); + AscendC::Sub(subExpIdTensor_, expertIdsTensor_, dstExpIdTensor_, tokenIndex); + AscendC::PipeBarrier(); + AscendC::LocalTensor tmpFp32 = subExpIdTensor_.ReinterpretCast(); + AscendC::LocalTensor tmpoutFp32 = dstExpIdTensor_.ReinterpretCast(); + AscendC::Abs(tmpoutFp32, tmpFp32, tokenIndex); + AscendC::PipeBarrier(); + AscendC::Mins(subExpIdTensor_, dstExpIdTensor_, 1, tokenIndex); + AscendC::PipeBarrier(); + AscendC::ReduceSum(tmpoutFp32, tmpFp32, workLocalTensor_, tokenIndex); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + int32_t curOtherExpertCnt = dstExpIdTensor_(0); + if (tokenIndex > curOtherExpertCnt) { + curExpertCnt = tokenIndex - curOtherExpertCnt; + } + } + + CATLASS_DEVICE + void CalAndSendTokenCount() + { + uint32_t totalExpertNum = sharedExpertRankNum + moeExpertNum; + uint32_t sendCountExpertNum = totalExpertNum / sendCoreNum; + uint32_t remainderRankNum = totalExpertNum % sendCoreNum; + uint32_t startExpertId = sendCountExpertNum * sendCoreIdx; + if (sendCoreIdx < remainderRankNum) { + sendCountExpertNum += 1; + startExpertId += sendCoreIdx; + } else { + startExpertId += remainderRankNum; + } + uint32_t endExpertId = startExpertId + sendCountExpertNum; + if (startExpertId >= totalExpertNum) { + return; + } + + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += CEIL_UP(CEIL(expertCntUp, INT32_COUNT_PER_BLOCK) * INT32_COUNT_PER_BLOCK * UB_BLOCK_SIZE); + AscendC::Duplicate(statusTensor_, (int32_t)0, + expertCntUp * INT32_COUNT_PER_BLOCK); + if (state == 0) { + // set the first number of every 8 numbers as 0x3F800000(float 1.0) + uint64_t mask[2] = {0x101010101010101, 0}; + AscendC::PipeBarrier(); + AscendC::Duplicate(statusTensor_, 0x3F800000, mask, CEIL(expertCntUp, 8), 1, 8); + } + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + if (!isShareExpert) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId) * axisBS / sharedExpertRankNum - + (curSatatusExpId + epRankId) * axisBS / sharedExpertRankNum; + statusTensor_((curSatatusExpId)*INT32_COUNT_PER_BLOCK + 1) = curExpertCnt; + } + } + + for (uint32_t curExpertId = startExpertId; curExpertId < endExpertId; ++curExpertId) { + if (curExpertId < sharedExpertRankNum) { + continue; + } + int32_t curExpertCnt = 0; + int32_t dstExpertId = curExpertId - sharedExpertRankNum; + CalExpandxIdx(dstExpertId, expertIdsCnt, curExpertCnt, ubOffset); + int32_t cntPosIndex = curExpertId * INT32_COUNT_PER_BLOCK + 1; + statusTensor_(cntPosIndex) = curExpertCnt; + } + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + AscendC::GlobalTensor rankGMTensor; + uint32_t offset = stateOffset * epRankId; + for (uint32_t rankIndex = startExpertId; rankIndex < endExpertId; ++rankIndex) { + uint32_t dstRankId = rankIndex; + if (moeExpertNumPerRank > 1 && (rankIndex >= sharedExpertRankNum)) { + dstRankId = ((rankIndex - sharedExpertRankNum) / moeExpertNumPerRank + sharedExpertRankNum); + offset = + (epRankId + (rankIndex - sharedExpertRankNum) % moeExpertNumPerRank * epRankSize) * stateOffset; + } + GM_ADDR rankGM = (__gm__ uint8_t *)(GET_WIND_STATE_ADDR_BY_RANK_ID(dstRankId) + offset); + rankGMTensor.SetGlobalBuffer((__gm__ int32_t *)rankGM); + AscendC::DataCopy(rankGMTensor, statusTensor_[rankIndex * 8], 8UL); + } + } + + CATLASS_DEVICE + void SendToShareExprt(GM_ADDR gmX, GM_ADDR gmX1, GM_ADDR gmX1Scale) + { + uint32_t newAivId = sendCoreIdx - sendToMoeAivNum; + uint32_t sendTokenNum = activeMaskBsCnt / sendToShareAivNum; + uint32_t remainderTokenNum = activeMaskBsCnt % sendToShareAivNum; + uint32_t startTokenId = sendTokenNum * newAivId; + if (newAivId < remainderTokenNum) { + sendTokenNum += 1; + startTokenId += newAivId; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + if (startTokenId >= activeMaskBsCnt) { + return; + } + + AscendC::LocalTensor xInTensor[BUFFER_NUM]; + AscendC::LocalTensor xInt32Tensor[BUFFER_NUM]; + + AscendC::GlobalTensor srcWinGMTensor; + srcWinGMTensor.SetGlobalBuffer((__gm__ XType *)gmX); + + xInTensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + xInt32Tensor[0] = xInTensor[0].template ReinterpretCast(); + ubOffset += CEIL_UP(axisHCommuBf16Fp16 * sizeof(XType)); + xInTensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + xInt32Tensor[1] = xInTensor[1].template ReinterpretCast(); + ubOffset += CEIL_UP(axisHCommuBf16Fp16 * sizeof(XType)); + AscendC::GlobalTensor dstWinGMTensor; + AscendC::GlobalTensor expandXOutGlobal; + expandXOutGlobal.SetGlobalBuffer((__gm__ XType *)(gmX1)); + + // double buffer + AscendC::SetFlag(0); + AscendC::SetFlag(1); + + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { + uint32_t index = (tokenIndex & 1) ? 0 : 1; + int32_t eventId = (tokenIndex & 1) ? 0 : 1; + uint32_t temp = (epRankId * axisBS) / sharedExpertRankNum; + uint32_t moeOnShareRank = CEIL((tokenIndex + 1 + temp) * sharedExpertRankNum, axisBS) - 1 - epRankId; + uint32_t preCnt = (moeOnShareRank + epRankId) * axisBS / sharedExpertRankNum - + epRankId * axisBS / sharedExpertRankNum; + dstWinGMTensor.SetGlobalBuffer( + (__gm__ XType *)(GET_WIND_ADDR_BY_RANK_ID(moeOnShareRank) + expertPerSizeOnWin * epRankId)); + + AscendC::WaitFlag(eventId); + AscendC::DataCopy(xInTensor[index], srcWinGMTensor[tokenIndex * tokenLength], tokenLength); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + xInt32Tensor[index](hOutSize / sizeof(int32_t)) = tokenFlag; + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + + if (isShareExpert) { + AscendC::DataCopy(expandXOutGlobal[tokenIndex * tokenLength], xInTensor[index], tokenLength); + } else { + AscendC::DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommuBf16Fp16], xInTensor[index], + tokenLength); + AscendC::PipeBarrier(); + AscendC::DataCopy(dstWinGMTensor[(tokenIndex - preCnt) * axisHCommuBf16Fp16 + tokenLength], + xInTensor[index][hOutSize / sizeof(XType)], 16); + } + AscendC::SetFlag(eventId); + } + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + } + + CATLASS_DEVICE + void SendToMoeExprt(GM_ADDR gmX, GM_ADDR gmExpandIdx) + { + uint32_t sendTokenNum = expertIdsCnt / sendToMoeAivNum; + uint32_t remainderTokenNum = expertIdsCnt % sendToMoeAivNum; + uint32_t startTokenId = sendTokenNum * sendCoreIdx; + if (sendCoreIdx < remainderTokenNum) { + sendTokenNum += 1; + startTokenId += sendCoreIdx; + } else { + startTokenId += remainderTokenNum; + } + uint32_t endTokenId = startTokenId + sendTokenNum; + if (startTokenId >= expertIdsCnt) { + return; + } + AscendC::LocalTensor expertCountTensor = (resource.ubBuf.template GetBufferByByte(ubOffset)); + ubOffset += CEIL_UP(expertIdsCnt * sizeof(int32_t)); + AscendC::Duplicate(expertCountTensor, (int32_t)0, expertIdsCnt); + AscendC::SetFlag(1); + AscendC::WaitFlag(1); + + AscendC::LocalTensor xInTensor[BUFFER_NUM]; + AscendC::LocalTensor xInt32Tensor[BUFFER_NUM]; + + AscendC::GlobalTensor srcWinGMTensor; + srcWinGMTensor.SetGlobalBuffer((__gm__ XType *)gmX); + + xInTensor[0] = resource.ubBuf.template GetBufferByByte(ubOffset); + xInt32Tensor[0] = xInTensor[0].template ReinterpretCast(); + ubOffset += CEIL_UP(axisHCommuBf16Fp16 * sizeof(XType)); + xInTensor[1] = resource.ubBuf.template GetBufferByByte(ubOffset); + xInt32Tensor[1] = xInTensor[1].template ReinterpretCast(); + ubOffset += CEIL_UP(axisHCommuBf16Fp16 * sizeof(XType)); + AscendC::GlobalTensor dstWinGMTensor; + AscendC::SetFlag(0); + AscendC::SetFlag(1); + uint32_t sendValidTokenIndex = 0; + for (uint32_t sendGroupIndex = 0; sendGroupIndex < moeExpertNumPerRank; ++sendGroupIndex) { + for (uint32_t tokenIndex = startTokenId; tokenIndex < endTokenId; ++tokenIndex) { + int32_t dstExpertId = expertIdsTensor_(tokenIndex); + if (dstExpertId < 0) { + continue; + } + // Send to preferentically to the specicied expert + if ((dstExpertId % moeExpertNumPerRank) != sendGroupIndex) { + continue; + } + uint32_t index = (sendValidTokenIndex & 1) ? 0 : 1; + int32_t eventId = (sendValidTokenIndex & 1) ? 0 : 1; + sendValidTokenIndex += 1; + int32_t curExpertCnt = 0; + CalExpandxIdx(dstExpertId, tokenIndex, curExpertCnt, ubOffset); + expertCountTensor(tokenIndex - startTokenId) = curExpertCnt; + uint32_t tempRankId = dstExpertId / moeExpertNumPerRank + sharedExpertRankNum; + GM_ADDR rankGM = (__gm__ uint8_t *)(GET_WIND_ADDR_BY_RANK_ID(tempRankId) + + (expertPerSizeOnWin * (epRankId * moeExpertNumPerRank + + dstExpertId % moeExpertNumPerRank)) + + hCommuSize * curExpertCnt); + dstWinGMTensor.SetGlobalBuffer((__gm__ XType *)rankGM); + + AscendC::WaitFlag(eventId); + AscendC::DataCopy(xInTensor[index], srcWinGMTensor[tokenIndex / axisK * tokenLength], tokenLength); + AscendC::SetFlag(eventId); + AscendC::WaitFlag(eventId); + xInt32Tensor[index](hOutSize / sizeof(int32_t)) = tokenFlag; + AscendC::SetFlag(eventId); + + AscendC::WaitFlag(eventId); + + AscendC::DataCopy(dstWinGMTensor, xInTensor[index], tokenLength); + AscendC::PipeBarrier(); + AscendC::DataCopy(dstWinGMTensor[tokenLength], xInTensor[index][hOutSize / sizeof(XType)], 16); + AscendC::SetFlag(eventId); + } + } + AscendC::WaitFlag(0); + AscendC::WaitFlag(1); + + AscendC::GlobalTensor expandIdxGMTensor; + expandIdxGMTensor.SetGlobalBuffer((__gm__ int32_t *)gmExpandIdx + startTokenId); + AscendC::DataCopyExtParams expertIdsCntParams = {1U, static_cast(sendTokenNum * sizeof(uint32_t)), 0U, + 0U, 0U}; + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopyPad(expandIdxGMTensor, expertCountTensor, expertIdsCntParams); + } + + CATLASS_DEVICE void + SendCoreFunc(GM_ADDR gmX, GM_ADDR gmExpertIds, GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmExpandIdx, GM_ADDR gmXActiveMask) + { + ubOffset = 0; + if constexpr (EXEC_FLAG & EXEC_FLAG_X_ACTIVE_MASK) { + TokenActiveMaskCal(gmXActiveMask, ubOffset); + } + expertIdsCnt = activeMaskBsCnt * axisK; + + AscendC::GlobalTensor expertIdsGMTensor_; + expertIdsGMTensor_.SetGlobalBuffer((__gm__ int32_t *)gmExpertIds); + expertIdsTensor_ = (resource.ubBuf.template GetBufferByByte(ubOffset)); + ubOffset += CEIL_UP(expertIdsCnt * sizeof(int32_t)); + + AscendC::DataCopyExtParams expertIdsCntParams = {1U, static_cast(expertIdsCnt * sizeof(uint32_t)), 0U, 0U, + 0U}; + AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(expertIdsTensor_, expertIdsGMTensor_, expertIdsCntParams, copyPadParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + CalAndSendTokenCount(); + AscendC::PipeBarrier(); + if (hasShareExpert) { + sendToShareAivNum = sendCoreNum / (axisK + 1); + if (sendToShareAivNum == 0) { + sendToShareAivNum = 1; + } + } + sendToMoeAivNum = sendCoreNum - sendToShareAivNum; + + AscendC::SetDeqScale((half)1.000000e+00f); + if (hasShareExpert && sendCoreIdx >= sendToMoeAivNum) { + SendToShareExprt(gmX, gmX1, gmX1Scale); + } else { + SendToMoeExprt(gmX, gmExpandIdx); + } + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void RecvCount(int64_t ubOffset) + { + uint32_t recStatusNumPerCore = isShareExpert ? epRankSize : expertCntUp; + uint32_t startStatusIndex = 0; // every wait for all token counts + + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE); + AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * sizeof(float)); + AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + + AscendC::LocalTensor statusSumOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor sumTmpTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(SUM_TMP_TENSOR_SIZE); + gatherTmpTensor.SetValue(0, 1); + + uint32_t mask = 1; + uint64_t rsvdCnt = 0; + AscendC::SumParams sumParams{1, recStatusNumPerCore, recStatusNumPerCore}; + float sumOfFlag = static_cast(-1.0); + float minTarget = (sumTarget * recStatusNumPerCore) - (float)0.5; + float maxTarget = (sumTarget * recStatusNumPerCore) + (float)0.5; + AscendC::DataCopyParams intriParams{static_cast(recStatusNumPerCore), 1, static_cast(15), + 0}; + AscendC::GlobalTensor windowInstatusFp32Tensor_; + windowInstatusFp32Tensor_.SetGlobalBuffer((__gm__ float *)GET_WIND_STATE_ADDR_BY_RANK_ID(epRankId)); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + + uint32_t preRecvTokenCount = 0; + while ((sumOfFlag < minTarget) || (sumOfFlag > maxTarget)) { + AscendC::DataCopy(statusFp32Tensor_, windowInstatusFp32Tensor_[startStatusIndex * stateOffset / sizeof(float)], + intriParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, mask, + {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt); + AscendC::PipeBarrier(); + AscendC::Sum(statusSumOutTensor, gatherMaskOutTensor, sumTmpTensor, sumParams); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + sumOfFlag = statusSumOutTensor.GetValue(0); + } + } + + CATLASS_DEVICE + void GetCumSum(int32_t startRankId, int32_t recvExpertNum, int64_t ubOffset) + { + // calculate token index in output tensor + int64_t subUbOffset = ubOffset; + uint32_t recStatusNumPerCore = isShareExpert ? epRankSize : expertCntUp; + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE); + AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * sizeof(float)); + AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + if (isShareExpert) { + for (uint32_t curSatatusExpId = 0; curSatatusExpId < sharedExpertRankNum; ++curSatatusExpId) { + int32_t curExpertCnt = (curSatatusExpId + 1 + epRankId) * axisBS / sharedExpertRankNum - + (curSatatusExpId + epRankId) * axisBS / sharedExpertRankNum; + statusTensor_((curSatatusExpId)*INT32_COUNT_PER_BLOCK + 1) = curExpertCnt; + } + } + + uint64_t rsvdCnt = 0; + gatherTmpTensor.SetValue(0, GATHER_SECOND_NUM); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::GatherMask(gatherMaskOutTensor, statusFp32Tensor_, gatherTmpTensor, true, GATHER_SECOND_NUM, + {1, (uint16_t)recStatusNumPerCore, 1, 0}, rsvdCnt); + AscendC::LocalTensor workLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + AscendC::PipeBarrier(); + AscendC::ReduceSum(gatherMaskOutTensor, gatherMaskOutTensor, workLocalTensor, + (startRankId + 1) <= recvExpertNum ? (startRankId + 1) : recvExpertNum); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + } + + CATLASS_DEVICE + void RecvToken(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount, uint32_t &coreTokenCount, uint32_t startRankId, + uint32_t endRankId, uint32_t recvRankNumPerCore, int64_t ubOffset) + { + int64_t subUbOffset = ubOffset; + AscendC::LocalTensor statusTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * UB_BLOCK_SIZE); + AscendC::LocalTensor gatherTmpTensor = (resource.ubBuf.template GetBufferByByte(subUbOffset)); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(expertCntUp * sizeof(float)); + AscendC::LocalTensor statusFp32Tensor_ = statusTensor_.ReinterpretCast(); + + AscendC::DataCopyExtParams dataCopyParamsFloat = {1U, sizeof(float), 0U, 0U, 0U}; + AscendC::LocalTensor xTmpTensor_ = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(axisHCommu * sizeof(XType)); + AscendC::LocalTensor tmpLocalTensor = resource.ubBuf.template GetBufferByByte(subUbOffset); + subUbOffset += CEIL_UP(UB_BLOCK_SIZE); + AscendC::LocalTensor gatherMaskOutCountTensor = (gatherMaskOutTensor.template ReinterpretCast()); + AscendC::GlobalTensor tokGlobal; + AscendC::GlobalTensor tokGlobalInt32; + AscendC::GlobalTensor expandXOutGlobal; + uint32_t beginIdx = 0; + for (uint32_t index = startRankId; index < endRankId; index++) { + uint32_t i = index - startRankId; + if (i > 0) { + gatherMaskOutCountTensor.SetValue( + i, gatherMaskOutCountTensor.GetValue(i - 1) + gatherMaskOutCountTensor.GetValue(index)); + } + uint32_t count = statusTensor_.GetValue(index * INT32_COUNT_PER_BLOCK + 1); + coreTokenCount += count; + beginIdx = gatherMaskOutCountTensor.GetValue(i) - count; + if (isShareExpert && index < sharedExpertRankNum) { + beginIdx += count; + continue; + } + uint32_t winOffset = index; + if (!isShareExpert && moeExpertNumPerRank > 1) { + // srcRank: index % epRankSize + // localExpertId: index / epRankSize + // Addr: (srcRank * moeExpertNumPerRank + localExpertId) * expertPerSizeOnWin + winOffset = (index % epRankSize) * moeExpertNumPerRank + index / epRankSize; + } + GM_ADDR wAddr = (__gm__ uint8_t *)(GET_WIND_ADDR_BY_RANK_ID(epRankId)) + winOffset * expertPerSizeOnWin; + AscendC::SetFlag(0); + for (uint32_t j = 0; j < count; j++) { + tokGlobal.SetGlobalBuffer((__gm__ XType *)(wAddr + j * hCommuSize)); + tokGlobalInt32.SetGlobalBuffer((__gm__ int32_t *)(wAddr + j * hCommuSize + hOutSize)); + expandXOutGlobal.SetGlobalBuffer((__gm__ XType *)(gmX1) + (beginIdx + j) * tokenLength, tokenLength); + + while (true) { + AscendC::DataCopy(tmpLocalTensor, tokGlobalInt32, INT32_COUNT_PER_BLOCK); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + if (tmpLocalTensor.GetValue(0) == tokenFlag) { + tokGlobalInt32.SetValue(0, 0); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(tokGlobalInt32[1]); + __asm__ __volatile__(""); + break; + } + } + AscendC::PipeBarrier(); + + AscendC::WaitFlag(0); + AscendC::DataCopy(xTmpTensor_, tokGlobal, tokenLength); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopy(expandXOutGlobal, xTmpTensor_, tokenLength); + AscendC::SetFlag(0); + } + AscendC::WaitFlag(0); + beginIdx += count; + } + AscendC::PipeBarrier(); + + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopyExtParams dataCopyOutParams = {1U, static_cast(recvRankNumPerCore * sizeof(int32_t)), 0U, + 0U, 0U}; + AscendC::GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount)); + AscendC::DataCopyPad(sendCountsGlobal[startRankId], gatherMaskOutCountTensor, dataCopyOutParams); + } + + CATLASS_DEVICE + void RecvCoreFunc(GM_ADDR gmX1, GM_ADDR gmX1Scale, GM_ADDR gmEpSendCount) + { + ubOffset = 0; + RecvCount(ubOffset); + + uint32_t recvExpertNum = isShareExpert ? epRankSize : expertCntUp; + uint32_t recvCoreNumPerGroup = recvCoreNum / localExpertNum; + uint32_t recvRankNumPerCore = epRankSize / recvCoreNumPerGroup; + uint32_t remainderRankNum = epRankSize % recvCoreNumPerGroup; + + uint32_t groupId = recvCoreIdx / recvCoreNumPerGroup; + uint32_t recvCoreIdxInGroup = recvCoreIdx % recvCoreNumPerGroup; + uint32_t startRankIdInGroup = recvRankNumPerCore * recvCoreIdxInGroup; + if (recvCoreIdxInGroup < remainderRankNum) { + recvRankNumPerCore += 1; + startRankIdInGroup += recvCoreIdxInGroup; + } else { + startRankIdInGroup += remainderRankNum; + } + uint32_t endRankIdInGroup = startRankIdInGroup + recvRankNumPerCore; + uint32_t startRankId = epRankSize * groupId + startRankIdInGroup; + uint32_t endRankId = epRankSize * groupId + endRankIdInGroup; + + uint32_t coreTokenCount = 0; + + if (startRankId < recvExpertNum) { + // RecvCount, GetCumSum, RecvToken must use the same ubOffset to get right info + GetCumSum(startRankId, recvExpertNum, ubOffset); + RecvToken(gmX1, gmX1Scale, gmEpSendCount, coreTokenCount, startRankId, endRankId, recvRankNumPerCore, ubOffset); + } + + // recv finish, inform AIC + AscendC::PipeBarrier(); + AscendC::LocalTensor tmpLocalTensor = resource.ubBuf.template GetBufferByByte(0); + ubOffset += CEIL_UP(UB_BLOCK_SIZE); + tmpLocalTensor.SetValue(CV_FLAG_INDEX, vToCFlag); + tmpLocalTensor.SetValue(GROUP_ID_INDEX, groupId); + tmpLocalTensor.SetValue(SELF_COUNT_INDEX, coreTokenCount); + AscendC::SetFlag(0); + + AscendC::GlobalTensor groupTokenNumStateTensor; + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET)); + AscendC::WaitFlag(0); + AscendC::SetAtomicAdd(); + AscendC::DataCopy(groupTokenNumStateTensor[groupId * GROUP_INFO_SIZE], tmpLocalTensor, INT32_COUNT_PER_BLOCK); + AscendC::SetAtomicNone(); + AscendC::PipeBarrier(); + } + + CATLASS_DEVICE + void CompCoreFunc(GM_ADDR gmCVSwapBuff, __gm__ ElementScale *gmScale, __gm__ ElementPerTokenScale *gmTokenScale, + __gm__ float *gmSwigluOutput, uint32_t n, uint32_t k, LayoutScale layoutScale, + LayoutPerTokenScale wholeLayoutPerTokenScale, LayoutOutput layoutOutput) + { + uint32_t coreNumPerGroup = recvCoreNum / localExpertNum; + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(gmCVSwapBuff)); + auto layoutC = layout::RowMajor{L1TileShape::M * aiCoreGroupNum * WORKSPACE_STAGES, L1TileShape::N}; + { + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t stageId = 0; + uint32_t target = 1; + uint32_t startCoreIdx = 0; + AscendC::ListTensorDesc gmScaleListTensor; + AscendC::GlobalTensor groupTokenNumStateTensor; + gmScaleListTensor = AscendC::ListTensorDesc(reinterpret_cast<__gm__ void *>(gmScale)); + __gm__ ElementScale* gmScalePtr; + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>(gmScaleListTensor.GetDataPtr(0)); + } + for (uint32_t groupIdx = 0; groupIdx < localExpertNum; ++groupIdx) { + // just like AIC + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET) + + groupIdx * GROUP_INFO_SIZE); + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(groupTokenNumStateTensor); + __asm__ __volatile__(""); + if (groupTokenNumStateTensor.GetValue(0) == coreNumPerGroup * vToCFlag) { + break; + } + } + uint32_t currentM = groupTokenNumStateTensor.GetValue(GROUP_TOKEN_COUNT); + GemmCoord inGroupProblemShape{currentM, n, k}; + LayoutPerTokenScale layoutPerTokenScale = + wholeLayoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = layout::RowMajor{currentM, n}; + EpilogueParams epilogueParams; + if constexpr (EXEC_FLAG & EXEC_FLAG_TENSOR_LIST) { + gmScalePtr = reinterpret_cast<__gm__ ElementScale*>( + gmScaleListTensor.GetDataPtr(groupIdx)); + epilogueParams = EpilogueParams { + gmScalePtr, layoutScale, + gmTokenScale + gmGroupOffsetPerTokenScale, layoutPerTokenScale, + gmSwigluOutput + gmGroupOffsetD, layoutD}; + } else { + epilogueParams = EpilogueParams{gmScalePtr + gmGroupOffsetScale, + layoutScale, + gmTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + gmSwigluOutput + gmGroupOffsetD, + layoutD}; + } + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = + ((compCoreIdx < startCoreIdx) ? (compCoreIdx + aiCoreGroupNum) : compCoreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += aiCoreGroupNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * aiCoreGroupNum + aiCoreGroupIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + CheckSyncFlag(statusDataSpaceGm + SOFT_SYNC_OFFSET, + static_cast(compCoreNum + compCoreIdx), target); + target += 1; + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC); + EncreaseSyncFlag(statusDataSpaceGm + SOFT_SYNC_OFFSET, static_cast(compCoreIdx)); + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + if constexpr (!(EXEC_FLAG & EXEC_FLAG_TENSOR_LIST)) { + gmGroupOffsetScale += inGroupProblemShape.n(); + } + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += currentM * n; + + startCoreIdx = (startCoreIdx + coreLoops) % aiCoreGroupNum; + } + } + // clean + AscendC::PipeBarrier(); + AscendC::GlobalTensor softSyncTensor; + softSyncTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + SOFT_SYNC_OFFSET)); + AscendC::LocalTensor tmpZeroLocalTensor = resource.ubBuf.template GetBufferByByte(0); + AscendC::Duplicate(tmpZeroLocalTensor, (int32_t)0, INT32_COUNT_PER_BLOCK); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopy(softSyncTensor[compCoreIdx * SOFT_SYNC_SPACE_SIZE / sizeof(int32_t)], tmpZeroLocalTensor, + INT32_COUNT_PER_BLOCK); + AscendC::DataCopy(softSyncTensor[(compCoreIdx + compCoreNum) * SOFT_SYNC_SPACE_SIZE / sizeof(int32_t)], + tmpZeroLocalTensor, INT32_COUNT_PER_BLOCK); + } + + CATLASS_DEVICE + void AivInitParams(Params const ¶ms) + { + aiCoreGroupNum = AscendC::GetBlockNum(); + subBlockNum = AscendC::GetSubBlockNum(); // 1C2V + aicNum = aiCoreGroupNum; + aivNum = aiCoreGroupNum * subBlockNum; + aivIdx = AscendC::GetBlockIdx(); + aiCoreGroupIdx = aivIdx / subBlockNum; + aivStateGlobalCoreIdx = aivNum + aicNum + aivIdx; + + isCompCore = (aivIdx % subBlockNum) == 0; + compCoreNum = aiCoreGroupNum; + compCoreIdx = aiCoreGroupIdx; + // when localExpertNum=1, all cores send token and recv token in sequence + isRecvCore = true; + isSendCore = true; + recvCoreIdx = aivIdx; + sendCoreIdx = aivIdx; + sendCoreNum = aivNum; + recvCoreNum = aivNum; + + moeExpertNumPerRank = params.moeExpertNumPerRank; + + epRankSize = params.epRankSize; + epRankId = params.epRankId; + expertCntUp = epRankSize * moeExpertNumPerRank; + sharedExpertRankNum = params.sharedExpertRankNum; + hasShareExpert = (sharedExpertRankNum > 0); + isShareExpert = (epRankId < sharedExpertRankNum); + localExpertNum = isShareExpert ? 1 : moeExpertNumPerRank; + moeExpertNum = params.moeExpertNum; + tokenLength = params.tokenLen; + + // when localExpertNum>1, half of cores send token and another half recv token in parallel + if (localExpertNum > 1) { + isRecvCore = ((aivIdx % ODD_EVEN_BASE) == 0); + isSendCore = ((aivIdx % ODD_EVEN_BASE) == 1); + recvCoreIdx = aivIdx / subBlockNum; + sendCoreIdx = aivIdx / subBlockNum; + sendCoreNum = aiCoreGroupNum; + recvCoreNum = aiCoreGroupNum; + } + + hOutSize = tokenLength * sizeof(XType); + scaleParamPad = TOKEN_EXTRA_SPACE; + hCommuSize = hOutSize + scaleParamPad; + axisHCommu = hCommuSize / sizeof(int8_t); + axisHCommuBf16Fp16 = hCommuSize / sizeof(XType); + axisBS = params.bs; + activeMaskBsCnt = axisBS; + axisK = params.topK; + uint32_t maxAxisBs = params.globalBs / epRankSize; + + stateOffset = STATE_OFFSET; + expertPerSizeOnWin = maxAxisBs * hCommuSize; + winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + statusDataSpaceGm = (GM_ADDR)(winContext_->localWindowsExp); + } + + CATLASS_DEVICE + void AivInitState() + { + // state of data sapce + AscendC::GlobalTensor selfDataStatusTensor; + selfDataStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + STATE_WIN_OFFSET)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aivIdx * UB_ALIGN]); + __asm__ __volatile__(""); + dataState = selfDataStatusTensor(aivIdx * UB_ALIGN); + if (dataState == 0) { + selfDataStatusTensor(aivIdx * UB_ALIGN) = 1; + } else { + selfDataStatusTensor(aivIdx * UB_ALIGN) = 0; + } + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aivIdx * UB_ALIGN]); + __asm__ __volatile__(""); + + // state of cv flag + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aivStateGlobalCoreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + cvDataState = selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN); + if (cvDataState == 0) { + selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN) = 1; + vToCFlag = V_TO_C_FLAG_1; + } else { + selfDataStatusTensor(aivStateGlobalCoreIdx * UB_ALIGN) = 0; + vToCFlag = V_TO_C_FLAG_2; + } + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfDataStatusTensor[aivStateGlobalCoreIdx * UB_ALIGN]); + __asm__ __volatile__(""); + + AscendC::PipeBarrier(); + winDataSizeOffset = dataState * epRankSize * expertPerSizeOnWin * moeExpertNumPerRank; + GM_ADDR statusSpaceGm_ = GET_WIND_STATE_ADDR_BY_RANK_ID(epRankId); + AscendC::GlobalTensor selfStatusTensor; + selfStatusTensor.SetGlobalBuffer((__gm__ int32_t *)(statusSpaceGm_ + SELF_STATE_OFFSET)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfStatusTensor[aivIdx * UB_ALIGN]); + __asm__ __volatile__(""); + state = selfStatusTensor(aivIdx * UB_ALIGN); + if (state == 0) { + sumTarget = (float)1.0; + tokenFlag = TOKEN_FLAG_1; + selfStatusTensor(aivIdx * UB_ALIGN) = 0x3F800000; + } else { + sumTarget = 0.0; + tokenFlag = TOKEN_FLAG_2; + selfStatusTensor(aivIdx * UB_ALIGN) = 0; + } + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + selfStatusTensor[aivIdx * UB_ALIGN]); + __asm__ __volatile__(""); + } + + CATLASS_DEVICE + void UpdateAndCleanInfo(__gm__ ElementGroupList_ *ptrGroupList, GM_ADDR gmEpSendCount, GM_ADDR gmExpertTokenNums) + { + if (aivIdx == aiCoreGroupNum * subBlockNum - 1) { + // clean + AscendC::GlobalTensor groupTokenNumStateTensor; + groupTokenNumStateTensor.SetGlobalBuffer((__gm__ int32_t *)(statusDataSpaceGm + GROUP_TOKEN_NUM_OFFSET)); + AscendC::LocalTensor tmpZeroLocalTensor = resource.ubBuf.template GetBufferByByte(0); + AscendC::Duplicate(tmpZeroLocalTensor, (int32_t)0, GROUP_INFO_SIZE * localExpertNum); + AscendC::SetFlag(0); + AscendC::WaitFlag(0); + AscendC::DataCopy(groupTokenNumStateTensor, tmpZeroLocalTensor, GROUP_INFO_SIZE * localExpertNum); + } + + if (isRecvCore && recvCoreIdx == (recvCoreNum - 1)) { + // record token count for each local expert + AscendC::GlobalTensor expertTokenNumsOutGMTensor_; + expertTokenNumsOutGMTensor_.SetGlobalBuffer((__gm__ int64_t *)(ptrGroupList)); + AscendC::GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(gmEpSendCount)); + AscendC::GlobalTensor nonCumSumExpertTokenNumsTensor; + nonCumSumExpertTokenNumsTensor.SetGlobalBuffer((__gm__ int64_t *)gmExpertTokenNums); + uint32_t tmpTokenNum = 0; + for (uint32_t localMoeIndex = 0; localMoeIndex < localExpertNum; ++localMoeIndex) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + sendCountsGlobal[localMoeIndex * epRankSize + epRankSize - 1]); + __asm__ __volatile__(""); + + uint32_t tokenNum = sendCountsGlobal.GetValue(localMoeIndex * epRankSize + epRankSize - 1); + expertTokenNumsOutGMTensor_.SetValue(localMoeIndex, tokenNum); + uint32_t nonCumSumTokenNum = tokenNum - tmpTokenNum; + nonCumSumExpertTokenNumsTensor.SetValue(localMoeIndex, nonCumSumTokenNum); + tmpTokenNum = tokenNum; + + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + expertTokenNumsOutGMTensor_[localMoeIndex]); + __asm__ __volatile__(""); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + nonCumSumExpertTokenNumsTensor[localMoeIndex]); + __asm__ __volatile__(""); + } + } + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + AivInitParams(params); + AivInitState(); + if (isSendCore) { + SendCoreFunc((GM_ADDR)params.gmX, (GM_ADDR)params.gmexpertIds, (GM_ADDR)params.ptrA, + (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmExpandIdx, (GM_ADDR)params.gmXActiveMask); + } + if (isRecvCore) { + RecvCoreFunc((GM_ADDR)params.ptrA, (GM_ADDR)params.ptrPerTokenScale, (GM_ADDR)params.gmEpSendCount); + } + + auto gmSwigluOutput = reinterpret_cast<__gm__ float *>( + params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * aiCoreGroupNum * WORKSPACE_STAGES * L1TileShape::N)); + if (isCompCore) { + CompCoreFunc(params.ptrWorkspace, params.ptrScale, params.ptrPerTokenScale, gmSwigluOutput, + params.problemShape.n(), params.problemShape.k(), params.layoutScale, params.layoutPerTokenScale, + params.layoutOutput); + } + + icache_preload(8); + AscendC::SyncAll(); + AscendC::PipeBarrier(); + + UpdateAndCleanInfo(params.ptrGroupList, params.gmEpSendCount, params.gmExpertTokenNums); + { + AscendC::GlobalTensor sendCountsGlobal; + sendCountsGlobal.SetGlobalBuffer(reinterpret_cast<__gm__ int32_t *>(params.gmEpSendCount)); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(sendCountsGlobal); + __asm__ __volatile__(""); + totalTokenCount = sendCountsGlobal.GetValue(localExpertNum * epRankSize - 1); + AscendC::PipeBarrier(); + uint32_t n = params.problemShape.n(); + uint32_t nOut = params.problemShape.n() / 2; + uint32_t swigluRowOnce = 0; + CalQuantRow(nOut, swigluRowOnce); + auto swigluLayout = layout::RowMajor{totalTokenCount, n}; + typename SwigluPost::Params swigluParams{ + gmSwigluOutput, swigluLayout, params.ptrSwigluScale, params.layoutSwigluScale, + params.ptrOutput, params.layoutOutput, swigluRowOnce, nOut}; + + SwigluPost blockSwiglu(resource, swigluParams); + MatrixCoord swigluShape(totalTokenCount, nOut); + MatrixCoord swigluBlockShape((uint16_t)(subBlockNum * swigluRowOnce), nOut); + Epilogue::Tile::EpilogueHorizontalTileSwizzle swigluSwizzle(swigluShape, swigluBlockShape); + for (uint32_t loopIdx = aiCoreGroupIdx; loopIdx < swigluSwizzle.GetLoops(); loopIdx += aiCoreGroupNum) { + auto blockCoord = swigluSwizzle.GetTileCoord(loopIdx); + auto actualBlockShape = swigluSwizzle.GetActualTileShape(blockCoord); + blockSwiglu(swigluBlockShape, blockCoord, actualBlockShape); + } + } + } + +private: + friend struct AicWaitFunc1; + friend struct AicSetFunc1; + + struct AicWaitFunc1 { + CATLASS_DEVICE + AicWaitFunc1() = default; + + CATLASS_DEVICE + void operator()() const + { + CheckSyncFlag(flagAddr, idx, target); + } + + __gm__ uint8_t *flagAddr; + uint8_t idx; + uint32_t target; + }; + + struct AicSetFunc1 { + CATLASS_DEVICE + AicSetFunc1() = default; + + CATLASS_DEVICE + void operator()() const + { + EncreaseSyncFlag(flagAddr, idx); + } + + __gm__ uint8_t *flagAddr; + uint8_t idx; + }; + + AicWaitFunc1 aicWaitFunc1; + AicSetFunc1 aicSetFunc1; + Arch::Resource resource; + + AscendC::LocalTensor expertIdsTensor_; + + // rank and expert info + uint32_t epRankSize{0}; + uint32_t epRankId{0}; + bool hasShareExpert{false}; + bool isShareExpert{false}; + uint32_t expertCntUp{0}; + uint32_t localExpertNum{0}; + uint32_t sharedExpertRankNum{0}; + uint32_t moeExpertNumPerRank{0}; + uint32_t moeExpertNum{0}; + + // token info + uint32_t hOutSize{0}; + uint32_t scaleParamPad{0}; + uint32_t hCommuSize{0}; + uint32_t axisHCommu{0}; + uint32_t axisHCommuBf16Fp16{0}; + uint32_t axisBS{0}; + uint32_t activeMaskBsCnt{0}; + uint32_t axisK{0}; + uint32_t totalTokenCount{0}; + uint32_t expertIdsCnt{0}; + uint32_t tokenLength{0}; + + // state info + int32_t tokenFlag{0}; // token flag + int32_t vToCFlag{0}; // cv flag, decided by cvDataState + int32_t dataState{0}; // data space state + int32_t cvDataState{0}; // cv flag state + int32_t state{0}; // count flag state + float sumTarget{0.0}; + + // memory info + __gm__ HcclOpResParam *winContext_; + GM_ADDR statusDataSpaceGm; + uint32_t stateOffset{0}; + uint64_t expertPerSizeOnWin{0}; + uint64_t winDataSizeOffset{0}; + + int64_t ubOffset; + + // core info + bool isSendCore{false}; + bool isRecvCore{false}; + bool isCompCore{false}; // calculate deq_swiglu + uint32_t aiCoreGroupNum{0}; + uint32_t aiCoreGroupIdx{0}; + uint32_t subBlockNum{0}; + uint32_t aicNum{0}; + uint32_t aivNum{0}; + uint32_t sendCoreNum{0}; + uint32_t recvCoreNum{0}; + uint32_t compCoreNum{0}; + uint32_t aivIdx{0}; + uint32_t aicIdx{0}; + uint32_t sendCoreIdx{0}; + uint32_t recvCoreIdx{0}; + uint32_t compCoreIdx{0}; + uint32_t aivStateGlobalCoreIdx{0}; + uint32_t aicStateGlobalCoreIdx{0}; + uint32_t sendToMoeAivNum{0}; + uint32_t sendToShareAivNum{0}; +}; + +} // namespace Catlass::Gemm::Kernel + +namespace Catlass::Gemm::Kernel { + +template +class GroupedMatmulSliceMSwigluMultiStageWorkspaceWithShallowDispatch +{ +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementRawScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using XType = ExpandXType; + using ElementSwigluScale = typename SwigluPost::ElementSwigluScale; + using LayoutSwigluScale = typename SwigluPost::LayoutSwigluScale; + using ElementOutput = typename SwigluPost::ElementOutput; + using LayoutOutput = typename SwigluPost::LayoutOutput; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementOutput *ptrOutput; + LayoutOutput layoutOutput; + __gm__ ElementSwigluScale *ptrSwigluScale; + LayoutSwigluScale layoutSwigluScale; + GM_ADDR ptrWorkspace; + + // Methods + CATLASS_DEVICE + Params() {} + + CATLASS_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, + LayoutA const &layoutA_, GM_ADDR ptrB_, LayoutB const &layoutB_, GM_ADDR ptrScale_, + LayoutScale const &layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale const &layoutPerTokenScale_, GM_ADDR ptrOutput_, LayoutOutput const &layoutOutput_, + GM_ADDR ptrSwigluScale_, LayoutSwigluScale const &layoutSwigluScale_, GM_ADDR ptrWorkspace_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrOutput(reinterpret_cast<__gm__ ElementOutput *>(ptrOutput_)), + layoutOutput(layoutOutput_), + ptrSwigluScale(reinterpret_cast<__gm__ ElementSwigluScale *>(ptrSwigluScale_)), + layoutSwigluScale(layoutSwigluScale_), + ptrWorkspace(ptrWorkspace_) + {} + }; + + // Methods + CATLASS_DEVICE + GroupedMatmulSliceMSwigluMultiStageWorkspaceWithShallowDispatch() + { + Arch::FlagID flagId = 0; + for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { + flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); + flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); + aicWaitFuncList[stageId] = {this, stageId}; + aicSetFuncList[stageId] = {this, stageId}; + } + } + + template + CATLASS_DEVICE void operator()(Params const ¶ms); + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); + --stageUsed; + } + } + + template <> + CATLASS_DEVICE void operator()(Params const ¶ms) + { + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + auto ptrD = reinterpret_cast<__gm__ float *>( + params.ptrWorkspace + sizeof(int32_t) * (L1TileShape::M * coreNum * WORKSPACE_STAGES * L1TileShape::N)); + + uint32_t mActual = groupList.GetValue(params.problemCount - 1); + uint32_t n = params.problemShape.n(); + uint32_t nOut = params.problemShape.n() / 2; + + { + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource); + + uint32_t stageId = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = layout::RowMajor{currentM, n}; + + EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + ptrD + gmGroupOffsetD, + layoutD}; + + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); + blockEpilogue(blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, layoutBlockC); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]); + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += currentM * n; + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + } + + Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + + { + uint32_t swigluRowOnce = 0; + CalQuantRow(nOut, swigluRowOnce); + auto swigluLayout = layout::RowMajor{mActual, n}; + typename SwigluPost::Params swigluParams{ptrD, + swigluLayout, + params.ptrSwigluScale, + params.layoutSwigluScale, + params.ptrOutput, + params.layoutOutput, + swigluRowOnce, + nOut}; + + SwigluPost blockSwiglu(resource, swigluParams); + MatrixCoord swigluShape(mActual, nOut); + MatrixCoord swigluBlockShape((uint16_t)(AscendC::GetSubBlockNum() * swigluRowOnce), nOut); + Epilogue::Tile::EpilogueHorizontalTileSwizzle swigluSwizzle(swigluShape, swigluBlockShape); + for (uint32_t loopIdx = coreIdx; loopIdx < swigluSwizzle.GetLoops(); loopIdx += coreNum) { + auto blockCoord = swigluSwizzle.GetTileCoord(loopIdx); + auto actualBlockShape = swigluSwizzle.GetActualTileShape(blockCoord); + + blockSwiglu(swigluBlockShape, blockCoord, actualBlockShape); + } + } + } + +private: + friend struct AicWaitFunc; + friend struct AicSetFunc; + + struct AicWaitFunc { + using MatmulKernel = GroupedMatmulSliceMSwigluMultiStageWorkspaceWithShallowDispatch< + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + + CATLASS_DEVICE + AicWaitFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + struct AicSetFunc { + using MatmulKernel = GroupedMatmulSliceMSwigluMultiStageWorkspaceWithShallowDispatch< + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + + CATLASS_DEVICE + AicSetFunc() = default; + + CATLASS_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; + + AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; + AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; + Arch::Resource resource; +}; + +} // namespace Catlass::Gemm::Kernel diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h index c8fa2e1f..c80678e4 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_combine.h @@ -9,7 +9,9 @@ */ #ifndef CAM_MOE_DISTRIBUTE_COMBINE_H #define CAM_MOE_DISTRIBUTE_COMBINE_H +#ifndef OPT_RANK_OFFSET #define OPT_RANK_OFFSET 512 +#endif #include "kernel_operator.h" #include "kernel_tiling/kernel_tiling.h" diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h index f73e2c60..1cc430bd 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h @@ -10,7 +10,9 @@ #ifndef CAM_MOE_DISTRIBUTE_DISPATCH_H #define CAM_MOE_DISTRIBUTE_DISPATCH_H +#ifndef OPT_RANK_OFFSET #define OPT_RANK_OFFSET 512 +#endif #include "kernel_operator.h" #include "kernel_tiling/kernel_tiling.h" diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h index b9ac8932..cd4dd6b9 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_base.h @@ -12,10 +12,107 @@ #include "../common/moe_distribute_base.h" -#define TemplateMC2TypeClass typename ExpandXType, typename W1ScaleType, typename W2ScaleType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG -#define TemplateMC2TypeFunc ExpandXType, W1ScaleType, W2ScaleType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG +#define TemplateMC2TypeClass typename ExpandXType, typename W1ScaleType, typename W2ScaleType, typename WType, typename ExpandIdxType, bool IsNeedReduceScatter, uint32_t EXEC_FLAG +#define TemplateMC2TypeFunc ExpandXType, W1ScaleType, W2ScaleType, WType, ExpandIdxType, IsNeedReduceScatter, EXEC_FLAG #define TemplateDispatchTypeClass \ typename XType, typename ExpandXOutType, bool StaticQuant, bool DynamicQuant, bool IsSmoothScaleExist, \ bool IsNeedAllgater, uint32_t EXEC_FLAG #define TemplateDispatchTypeFunc XType, ExpandXOutType, StaticQuant, DynamicQuant, IsSmoothScaleExist, IsNeedAllgater, EXEC_FLAG + +constexpr uint32_t STATE_OFFSET = 512; +constexpr uint64_t WIN_STATE_OFFSET = 512 * 1024; +constexpr uint64_t STATE_WIN_OFFSET = 900 * 1024; +constexpr uint64_t GROUP_TOKEN_NUM_OFFSET = 932 * 1024; +constexpr uint64_t SOFT_SYNC_OFFSET = 964 * 1024; +constexpr uint32_t SELF_STATE_OFFSET = 256 * 1024; +constexpr uint32_t SUM_TMP_TENSOR_SIZE = 1024; +constexpr uint32_t UB_ALIGN = 32; +constexpr uint32_t TOKEN_EXTRA_SPACE = 512; +constexpr uint32_t INT32_COUNT_PER_BLOCK = 8; +constexpr uint32_t SOFT_SYNC_SPACE_SIZE = 512; +constexpr int64_t LOOP_TMP_SIZE = 4096; +constexpr int32_t SUB_AIV_NUM = 2; +constexpr int32_t ODD_EVEN_BASE = 2; +constexpr int32_t BUFFER_NUM = 2; +constexpr int32_t GATHER_SECOND_NUM = 2; +constexpr uint32_t MAX_QUANT_ROW_ONCE = 8; +constexpr uint32_t QUANT_SPACE_FACTOR = 176 * 1024 / 11; // up to 176KB for quant +#ifndef OPT_RANK_OFFSET +#define OPT_RANK_OFFSET 512 +#endif + +#define CEIL_UP(x) ((x + UB_ALIGN - 1) / UB_ALIGN * UB_ALIGN) +#define CEIL(x, y) (((x) + (y - 1)) / (y)) +#define UB_BLOCK_SIZE (32) +#define GET_WIND_STATE_ADDR_BY_RANK_ID(rankId) \ + (((epRankId == rankId) \ + ? ((GM_ADDR)(winContext_->localWindowsExp)) \ + : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsExp))) + \ + dataState * WIN_STATE_OFFSET) +#define GET_WIND_ADDR_BY_RANK_ID(rankId) \ + (((epRankId == rankId) \ + ? ((GM_ADDR)(winContext_->localWindowsIn)) \ + : ((GM_ADDR)(((HcclRankRelationResV2 *)(winContext_->remoteRes[rankId].nextDevicePtr))->windowsIn))) + \ + winDataSizeOffset + rankId * OPT_RANK_OFFSET) +#define TOKEN_FLAG_1 (0x55555555) +#define TOKEN_FLAG_2 (0x33333333) +#define V_TO_C_FLAG_1 (0x03030303) +#define V_TO_C_FLAG_2 (0x05050505) +#define CV_FLAG_INDEX 0 +#define GROUP_ID_INDEX 1 +#define PRE_COUNT_INDEX 2 +#define SELF_COUNT_INDEX 3 +#define TOTAL_COUNT_INDEX 4 +#define GROUP_TOKEN_COUNT 3 // equal to SELF_COUNT_INDEX +#define GROUP_INFO_SIZE 32 + +__aicore__ inline static void EncreaseSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx) +{ + // flag++, like set flag + AscendC::PipeBarrier(); + AscendC::GlobalTensor global; + global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + global); + __asm__ __volatile__(""); + uint8_t value = global.GetValue(0); + global.SetValue(0, value + 1); + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid( + global); + __asm__ __volatile__(""); + AscendC::PipeBarrier(); +} + +__aicore__ inline static void CheckSyncFlag(__gm__ uint8_t *flagAddr, uint8_t idx, uint32_t target) +{ + // check flag, like wait flag + AscendC::PipeBarrier(); + AscendC::GlobalTensor global; + global.SetGlobalBuffer(flagAddr + idx * SOFT_SYNC_SPACE_SIZE); + while (true) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(global); + __asm__ __volatile__(""); + uint8_t value = global.GetValue(0); + if (value >= target) { + __asm__ __volatile__(""); + AscendC::DataCacheCleanAndInvalid(global); + __asm__ __volatile__(""); + break; + } + } + AscendC::PipeBarrier(); +} + +__aicore__ inline static void CalQuantRow(const uint32_t column, uint32_t &row) +{ + row = QUANT_SPACE_FACTOR / column; + row = row < MAX_QUANT_ROW_ONCE ? row : MAX_QUANT_ROW_ONCE; +} + + #endif // DISPATCH_GMM_COMBINE_DECODE_BASE_H diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_bf16_fp16.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_bf16_fp16.h new file mode 100644 index 00000000..28d54759 --- /dev/null +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_bf16_fp16.h @@ -0,0 +1,457 @@ +/* + * Copyright (c) 2026 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + */ +#ifndef DISPATCH_GMM_COMBINE_DECODE_BF16_FP16_H +#define DISPATCH_GMM_COMBINE_DECODE_BF16_FP16_H + +#include "lib/matmul_intf.h" +#include + +#include "catlass/catlass.hpp" +#include "catlass/arch/arch.hpp" +#include "catlass/layout/layout.hpp" +#include "catlass/epilogue/tile/tile_broadcast_mul.hpp" +#include "catlass/epilogue/tile/tile_broadcast_one_blk.hpp" +#include "catlass/epilogue/tile/tile_swizzle.hpp" +#include "catlass/gemm/block/block_swizzle.hpp" +#include "dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_multistage_workspace_bf16_fp16.h" +#include "catlass/gemm/gemm_type.hpp" +#include "dispatch_gmm_combine_decode/epilogue/dispatch_policy.h" +#include "dispatch_gmm_combine_decode/gemm/dispatch_policy.h" +#include "dispatch_gmm_combine_decode/epilogue/block/block_epilogue.h" +#include "dispatch_gmm_combine_decode/gemm/block/block_mmad.h" +#include "dispatch_gmm_combine_decode/gemm/kernel/grouped_matmul_slice_m_swiglu_multistage_workspace_bf16_fp16.h" + +#include "dispatch_gmm_combine_decode/raw_distributed/cam_moe_distribute_dispatch.h" + +#include "dispatch_gmm_combine_decode_tiling.h" +#include "dispatch_gmm_combine_decode_base.h" + +using namespace Catlass; + +namespace DispatchGmmCombineDecodeBf16Fp16Impl { + +using MmadAtlasA2Custom = + Gemm::MmadAtlasA2PreloadAsyncWithCallback; + +using Gmm1L1TileShape = GemmShape; +using Gmm1L0TileShape = GemmShape; +using Gmm1EpilogueTileShape = MatrixShape; +using Gmm1BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle; + +using Gmm2L1TileShape = GemmShape; +using Gmm2L0TileShape = GemmShape; +using Gmm2EpilogueTileShape = MatrixShape; +using Gmm2BlockScheduler = typename Gemm::Block::GemmIdentityBlockSwizzle; +using Gmm2DispatchPolicy = + Gemm::MmadAtlasA2PreloadAsyncWithCallbackResidentA; + +template +CATLASS_DEVICE void GmmDeqSwigluQuant(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, + layout::RowMajor layoutA, GM_ADDR gmB, + typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type layoutB, + GM_ADDR gmScale, + layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, + layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, + GM_ADDR gmDequantScale, layout::VectorLayout layoutDequantScale, GM_ADDR gmWorkspace, + GM_ADDR gmX, GM_ADDR debugGm, GM_ADDR gmexpertIds, GM_ADDR gmExpandIdx, + GM_ADDR gmEpSendCount, GM_ADDR xActiveMask, GM_ADDR gmResvered, GM_ADDR gmExpertTokenNums, + uint32_t epRankSize, uint32_t epRankId, uint32_t moeExpertNum, + uint32_t moeExpertNumPerRank, uint32_t sharedExpertNum, uint32_t sharedExpertRankNum, + uint32_t quantMode, uint32_t globalBs, uint32_t bs, uint32_t topK, uint32_t tokenLen) +{ + using ArchTag = Arch::AtlasA2; + using DispatchPolicy = DispatchPolicy_; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + + using AType = Gemm::GemmType; + using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type; + using BType = Gemm::GemmType; + using CType = Gemm::GemmType; + + using BlockMmad = Gemm::Block::BlockMmad; + + constexpr uint32_t ubStages = 1; + using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2Swiglu; + using ScaleType = Gemm::GemmType; + using PerTokenScaleType = Gemm::GemmType; + using DType = Gemm::GemmType; + + using RowBroadcastMulType = Gemm::GemmType; + using BroadcastOneBlkType = Gemm::GemmType; + using OneBlkColumnBroadcastMulType = Gemm::GemmType; + + using EpilogueTileShape = EpilogueTileShape_; + using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul; + using TileBroadcastOneBlk = + Epilogue::Tile::TileBroadcastOneBlk; + using TileOneBlkColumnBroadcastMul = + Epilogue::Tile::TileOneBlkColumnBroadcastMul; + using TileCopy = Epilogue::Tile::TileCopy; + using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle; + + using BlockEpilogue = Epilogue::Block::BlockEpilogue; + + using BlockScheduler = BlockScheduler_; + + // kernel level + using ElementGroupList = int64_t; + + using GemmKernel = typename std::conditional< + (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) != 0, + Gemm::Kernel::GroupedMatmulSliceMSwigluMultiStageWorkspace< + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>, + Gemm::Kernel::GroupedMatmulSliceMSwigluMultiStageWorkspaceWithShallowDispatch< + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>>::type; + + if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) != 0) { + typename GemmKernel::Params params{problemShape, + groupCount, + gmGroupList, + gmA, + layoutA, + gmB, + layoutB, + gmScale, + layoutScale, + gmPerTokenScale, + layoutPerTokenScale, + gmD, + layoutD, + gmDequantScale, + layoutDequantScale, + gmWorkspace, + gmX, + debugGm, + gmexpertIds, + gmExpandIdx, + gmEpSendCount, + xActiveMask, + gmResvered, + gmExpertTokenNums, + epRankSize, + epRankId, + moeExpertNum, + moeExpertNumPerRank, + sharedExpertNum, + sharedExpertRankNum, + quantMode, + globalBs, + bs, + topK, + tokenLen}; + // call a kernel + GemmKernel gemm; + gemm(params); + } else { + typename GemmKernel::Params params{problemShape, + groupCount, + gmGroupList, + gmA, + layoutA, + gmB, + layoutB, + gmScale, + layoutScale, + gmPerTokenScale, + layoutPerTokenScale, + gmD, + layoutD, + gmDequantScale, + layoutDequantScale, + gmWorkspace}; + // call a kernel + GemmKernel gemm; + gemm(params); + } +} + +template +CATLASS_DEVICE void GmmDeq(GemmCoord problemShape, uint32_t groupCount, GM_ADDR gmGroupList, GM_ADDR gmA, + layout::RowMajor layoutA, GM_ADDR gmB, + typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type layoutB, + GM_ADDR gmScale, + layout::VectorLayout layoutScale, GM_ADDR gmPerTokenScale, + layout::VectorLayout layoutPerTokenScale, GM_ADDR gmD, layout::RowMajor layoutD, + GM_ADDR gmWorkspace, void *combiner) +{ + using ArchTag = Arch::AtlasA2; + using DispatchPolicy = DispatchPolicy_; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + + using AType = Gemm::GemmType; + using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type; + using BType = Gemm::GemmType; + using CType = Gemm::GemmType; + + using BlockMmad = Gemm::Block::BlockMmad; + + constexpr uint32_t ubStages = 1; + using EpilogueDispatchPolicy = Epilogue::EpilogueAtlasA2Combine; + using ScaleType = Gemm::GemmType; + using PerTokenScaleType = Gemm::GemmType; + using DType = Gemm::GemmType; + + using RowBroadcastMulType = Gemm::GemmType; + using BroadcastOneBlkType = Gemm::GemmType; + using OneBlkColumnBroadcastMulType = Gemm::GemmType; + + using EpilogueTileShape = EpilogueTileShape_; + using TileRowBroadcastMul = Epilogue::Tile::TileRowBroadcastMul; + using TileBroadcastOneBlk = + Epilogue::Tile::TileBroadcastOneBlk; + using TileOneBlkColumnBroadcastMul = + Epilogue::Tile::TileOneBlkColumnBroadcastMul; + using TileCopy = Epilogue::Tile::TileCopy; + using TileScheduler = Epilogue::Tile::EpilogueHorizontalTileSwizzle; + + using BlockEpilogue = Epilogue::Block::BlockEpilogue; + + using BlockScheduler = BlockScheduler_; + + // kernel level + using ElementGroupList = int64_t; + using GemmKernel = Gemm::Kernel::GroupedMatmulSliceMMultiStageWorkspace< + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, WORKSPACE_STAGES, ElementGroupList>; + + typename GemmKernel::Params params{ + problemShape, groupCount, gmGroupList, gmA, layoutA, gmB, layoutB, gmScale, + layoutScale, gmPerTokenScale, layoutPerTokenScale, gmD, layoutD, gmWorkspace, combiner}; + + // call a kernel + GemmKernel gemm; + gemm(params); +} + +template +class DispatchGmmCombineDecodeBf16Fp16 +{ +public: + __aicore__ inline DispatchGmmCombineDecodeBf16Fp16(){}; + __aicore__ inline void Init( + // input + GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, + GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, GM_ADDR x_active_mask, + // output + GM_ADDR output, GM_ADDR expertTokenNums, + // system + GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData); + __aicore__ inline void Process(); + +private: + GM_ADDR gmX_; + GM_ADDR gmexpertIds_; + GM_ADDR gmPermuteWeight1_; + GM_ADDR gmPermuteScale1_; + GM_ADDR gmWeight2_; + GM_ADDR gmScale2_; + GM_ADDR gmOutput_; + GM_ADDR gmExpertTokenNums_; + GM_ADDR workspaceGM_; + GM_ADDR gmSmoothScales_; + GM_ADDR gmexpertScales_; + GM_ADDR xActiveMask_; + + uint32_t maxTokenNum_{0}; + uint32_t gmm1OutputDim_{0}; + uint32_t tokenHiddenSize_{0}; + uint32_t groupCount_{0}; + uint32_t gmm2OutputDim_{0}; + uint32_t gmm2InputDim_{0}; + uint32_t globalRankId_{0}; + uint32_t winSizePerRank_{0}; + uint32_t blockDim_{0}; + uint32_t epRankSize_{0}; + uint32_t epRankId_{0}; + uint32_t moeExpertNum_{0}; + uint32_t moeExpertNumPerRank_{0}; + uint32_t sharedExpertNum_{0}; + uint32_t sharedExpertRankNum_{0}; + uint32_t quantMode_{0}; + uint32_t globalBs_{0}; + uint32_t bs_{0}; + uint32_t maxBs_{0}; + uint32_t topK_{0}; + + AscendC::TPipe *tpipe_{nullptr}; + __gm__ HcclOpResParam *winContext_{nullptr}; + const DispatchGmmCombineDecodeTilingData *tilingData_; +}; + +template +__aicore__ inline void DispatchGmmCombineDecodeBf16Fp16::Init( + // input + GM_ADDR x, GM_ADDR expert_ids, GM_ADDR gmm1_permuted_weight, GM_ADDR gmm1_permuted_weight_scale, + GM_ADDR gmm2_weight, GM_ADDR gmm2_weight_scale, GM_ADDR expert_scales, GM_ADDR expert_smooth_scales, + GM_ADDR x_active_mask, + // output + GM_ADDR output, GM_ADDR expertTokenNums, + // system + GM_ADDR workspaceGM, AscendC::TPipe *pipe, const DispatchGmmCombineDecodeTilingData *tilingData) +{ + tpipe_ = pipe; + blockDim_ = AscendC::GetBlockNum(); + winContext_ = (__gm__ HcclOpResParam *)AscendC::GetHcclContext(); + + gmSmoothScales_ = expert_smooth_scales; // not used now + gmX_ = x; // input token + gmexpertIds_ = expert_ids; + gmPermuteWeight1_ = gmm1_permuted_weight; + gmPermuteScale1_ = nullptr; + gmWeight2_ = gmm2_weight; + gmScale2_ = nullptr; + gmOutput_ = output; + gmExpertTokenNums_ = expertTokenNums; + workspaceGM_ = workspaceGM; + gmexpertScales_ = expert_scales; + xActiveMask_ = x_active_mask; + tilingData_ = tilingData; + epRankSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankSize; + epRankId_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.epRankId; + moeExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNum; + moeExpertNumPerRank_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + sharedExpertNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertNum; + sharedExpertRankNum_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.sharedExpertRankNum; + quantMode_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.quantMode; + globalBs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.globalBs; + bs_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.bs; + topK_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.k; + maxBs_ = globalBs_ / epRankSize_; + + bool isShareExpert = (epRankId_ < sharedExpertRankNum_); + if (isShareExpert) { + maxTokenNum_ = maxBs_ * epRankSize_ / sharedExpertRankNum_; + } else { + maxTokenNum_ = maxBs_ * epRankSize_ * (topK_ < moeExpertNumPerRank_ ? topK_ : moeExpertNumPerRank_); + } + + gmm1OutputDim_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.gmm1HLen; + tokenHiddenSize_ = tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.h; + groupCount_ = isShareExpert ? 1 : tilingData->disGmmDeqSwigluQuantGmmDeqComInfo.moeExpertNumPerRank; + gmm2OutputDim_ = tokenHiddenSize_; + gmm2InputDim_ = gmm1OutputDim_ / 2; +} + +template +__aicore__ inline auto CreateWeightLayout(uint32_t k, uint32_t n) { + if constexpr ((EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0) { + MatrixCoord mc{k, n}; + return layout::RowMajor::template MakeLayoutInUb(mc); + } else { + return layout::zN::template MakeLayout(k, n); + } +} + +template +__aicore__ inline void DispatchGmmCombineDecodeBf16Fp16::Process() +{ + using LayoutB = typename std::conditional<(EXEC_FLAG & EXEC_FLAG_ND_FORMAT) != 0, layout::RowMajor, layout::zN>::type; + GemmCoord gmm1ProblemShape{maxTokenNum_, gmm1OutputDim_, tokenHiddenSize_}; + GemmCoord gmm2ProblemShape{maxTokenNum_, gmm2OutputDim_, gmm2InputDim_}; + + layout::RowMajor layoutX1{maxTokenNum_, tokenHiddenSize_}; + auto layoutWeight1 = CreateWeightLayout(tokenHiddenSize_, gmm1OutputDim_); + layout::VectorLayout layoutW1Scale{gmm1OutputDim_}; + layout::VectorLayout layoutX1Scale{maxTokenNum_}; + layout::RowMajor layoutX2{maxTokenNum_, gmm2InputDim_}; + auto layoutWeight2 = CreateWeightLayout(gmm2InputDim_, gmm2OutputDim_); + layout::VectorLayout layoutW2Scale{gmm2OutputDim_}; + layout::VectorLayout layoutX2Scale{maxTokenNum_}; + layout::RowMajor layoutOutput{maxTokenNum_, gmm2OutputDim_}; + + size_t workspaceOffset = 0; + constexpr int32_t resveredWorkSpaceSize = 256 * 1024; + int64_t x1TokenSize = maxTokenNum_ * tokenHiddenSize_ * sizeof(ExpandXType); + int64_t x2TokenSize = maxTokenNum_ * gmm2InputDim_ * sizeof(ExpandXType); + int64_t maxTokenSize = x1TokenSize < x2TokenSize ? x2TokenSize : x1TokenSize; + GM_ADDR gmX1 = workspaceGM_ + workspaceOffset; + GM_ADDR gmX2 = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(maxTokenSize); + GM_ADDR gmX1Scale = nullptr; + GM_ADDR gmX2Scale = nullptr; + GM_ADDR gmWorkspace = workspaceGM_ + workspaceOffset; + GM_ADDR gmCVSwap = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(blockDim_) * (FP16_BF16_L1M * FP16_BF16_L1N) * + WORKSPACE_STAGES * sizeof(float)); + int64_t swigluOutSize = maxTokenNum_ * gmm1OutputDim_ * sizeof(float); + int64_t gmm2OutSize = maxTokenNum_ * tokenHiddenSize_ * sizeof(ExpandXType); + int64_t maxSwigluGmm2Size = swigluOutSize < gmm2OutSize ? gmm2OutSize : swigluOutSize; + GM_ADDR gmSwigluOut = workspaceGM_ + workspaceOffset; + GM_ADDR gmGmm2DepOut = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(maxSwigluGmm2Size); + GM_ADDR gmGroupList = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(groupCount_) * sizeof(int64_t)); + GM_ADDR gmExpandIdx = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(bs_) * topK_ * sizeof(int32_t)); + GM_ADDR gmEpSendCount = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(static_cast(epRankSize_) * groupCount_ * sizeof(int32_t)); + GM_ADDR gmResvered = workspaceGM_ + workspaceOffset; + workspaceOffset += RoundUp(resveredWorkSpaceSize); + + if constexpr ((EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) == 0) { + if constexpr (g_coreType == AscendC::AIV) { + AscendC::TPipe tpipe; + MoeDistributeDispatchImpl::CamMoeDistributeDispatch + dispatcher; + dispatcher.Init(gmX_, gmexpertIds_, gmSmoothScales_, xActiveMask_, gmX1, gmX1Scale, gmExpandIdx, gmGroupList, + gmEpSendCount, gmExpertTokenNums_, nullptr, gmWorkspace, &tpipe, tilingData_); + dispatcher.Process(); + tpipe.Destroy(); + icache_preload(8); + } + + AscendC::PipeBarrier(); + Arch::CrossCoreFlag gmm1AivFinished{0}; + if constexpr (g_coreType == AscendC::AIV) { + Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished); + } else { + Arch::CrossCoreWaitFlag(gmm1AivFinished); + } + } + GmmDeqSwigluQuant( + gmm1ProblemShape, groupCount_, gmGroupList, gmX1, layoutX1, gmPermuteWeight1_, layoutWeight1, + gmPermuteScale1_, layoutW1Scale, gmX1Scale, layoutX1Scale, gmX2, layoutX2, gmX2Scale, + layoutX2Scale, gmWorkspace, gmX_, gmSmoothScales_, gmexpertIds_, gmExpandIdx, gmEpSendCount, xActiveMask_, gmResvered, + gmExpertTokenNums_, epRankSize_, epRankId_, moeExpertNum_, moeExpertNumPerRank_, sharedExpertNum_, + sharedExpertRankNum_, quantMode_, globalBs_, bs_, topK_, tokenHiddenSize_); + AscendC::PipeBarrier(); + Arch::CrossCoreFlag gmm1AivFinished{0}; + if constexpr (g_coreType == AscendC::AIV) { + Arch::CrossCoreBarrier<0x0, PIPE_MTE3>(); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(gmm1AivFinished); + } else { + Arch::CrossCoreWaitFlag(gmm1AivFinished); + } + + MoeDistributeCombineImpl::CamMoeDistributeCombine combiner; + if (g_coreType == AscendC::AIV) { + combiner.Init(gmGmm2DepOut, gmexpertIds_, gmExpandIdx, gmEpSendCount, nullptr, gmexpertScales_, xActiveMask_, gmOutput_, + workspaceGM_, nullptr, tilingData_); + } + GmmDeq(gmm2ProblemShape, groupCount_, gmGroupList, gmX2, layoutX2, gmWeight2_, layoutWeight2, + gmScale2_, layoutW2Scale, gmX2Scale, layoutX2Scale, gmGmm2DepOut, + layoutOutput, gmWorkspace, &combiner); +} +} // namespace DispatchGmmCombineDecodeBf16Fp16Impl +#endif // DISPATCH_GMM_COMBINE_DECODE_BF16_FP16_H diff --git a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h index 3874ffa2..b006d9d7 100644 --- a/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h +++ b/csrc/dispatch_gmm_combine_decode/op_kernel/dispatch_gmm_combine_decode_tiling.h @@ -31,6 +31,8 @@ struct DispatchGmmCombineDecodeInfo { uint64_t totalWinSize; uint64_t gmm1HLen; bool isTensorList; + bool isBf16Fp16W; + bool isNDFormat; }; struct DispatchGmmCombineDecodeTilingData { @@ -48,6 +50,8 @@ constexpr uint32_t CUSTOM_L0C_STAGES = 1; constexpr bool CUSTOM_ENABLE_UNIT_FLAG = true; constexpr bool CUSTOM_ENABLE_SHUFFLE_K = true; +constexpr uint32_t FP16_BF16_L1M = 128; +constexpr uint32_t FP16_BF16_L1N = 128; constexpr uint32_t GMM1_L1M = 256; constexpr uint32_t GMM1_L1N = 128; constexpr uint32_t GMM1_L1K = 512; @@ -56,6 +60,8 @@ constexpr uint32_t GMM1_EPIM = 64; constexpr uint32_t GMM1_SWIZZLE_OFFSET = 3; constexpr uint32_t GMM1_SWIZZLE_DIRECTION = 0; +constexpr uint32_t FP16_BF16_GMM2_L1M = 64; +constexpr uint32_t FP16_BF16_GMM2_L1N = 128; constexpr uint32_t GMM2_L1A_STAGES = 4; constexpr uint32_t GMM2_L1B_STAGES = 2; constexpr uint32_t GMM2_L0A_STAGES = 4; @@ -73,5 +79,6 @@ constexpr uint32_t WORKSPACE_STAGES = 4; constexpr uint32_t EXEC_FLAG_DEEP_FUSE = (1U << 0); constexpr uint32_t EXEC_FLAG_TENSOR_LIST = (1U << 1); constexpr uint32_t EXEC_FLAG_X_ACTIVE_MASK = (1U << 2); +constexpr uint32_t EXEC_FLAG_ND_FORMAT = (1U << 3); #endif // DISPATCH_GMM_COMBINE_DECODE_TILING_H diff --git a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py index b928ad7e..278e58c0 100644 --- a/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py +++ b/tests/e2e/nightly/single_node/ops/multicard_ops_a3/test_dispatch_gmm_combine_decode.py @@ -4,6 +4,7 @@ import sys from pathlib import Path import numpy as np +import time import torch import torch.distributed as dist import torch.multiprocessing as mp @@ -28,7 +29,9 @@ BASE_KWARGS = { "enable_dynamic_bs": False, "test_graph": False, "with_mc2_mask": False, - "dynamic_eplb": False + "dynamic_eplb": False, + "w8a8_dynamic": True, + "is_nz": True } @@ -50,7 +53,7 @@ def permute_weight(w: torch.Tensor, tile_n): def output_to_file(rank_id): - return False + return rank_id > 0 class DecodeMoeOps(torch.nn.Module): @@ -68,8 +71,14 @@ class DecodeMoeOps(torch.nn.Module): moe_expert_num, global_rank_id, shared_expert_rank_num=0, - dynamic_eplb=False): + dynamic_eplb=False, + w8a8_dynamic=True, + is_nz=True): super().__init__() + if w8a8_dynamic: + assert (gmm1_weight_scale is not None and gmm2_weight_scale is not None), "gmm1_weight_scale and gmm2_weight_scale must be provided for w8a8_dynamic" + else: + assert (gmm1_weight_scale is None and gmm2_weight_scale is None), "gmm1_weight_scale and gmm2_weight_scale must be None for w8a8_dynamic" self.ep_hcomm_info = ep_hcomm_info self.batch_size = batch_size self.token_hidden_size = token_hidden_size @@ -84,38 +93,47 @@ class DecodeMoeOps(torch.nn.Module): self.local_expert_num = 1 if is_shared_expert else moe_expert_num_per_rank self.ep_recv_count_size = self.local_expert_num * ep_world_size self.dynamic_eplb = dynamic_eplb + self.w8a8_dynamic = w8a8_dynamic + self.is_nz = is_nz self.gmm1_weight = torch.empty([ self.local_expert_num, self.token_hidden_size, self.moe_intermediate_size * 2 ]) - self.gmm1_weight_scale = torch.empty( - [self.local_expert_num, self.moe_intermediate_size * 2]) self.gmm2_weight = torch.empty([ self.local_expert_num, self.moe_intermediate_size, self.token_hidden_size ]) - self.gmm2_weight_scale = torch.empty( - [self.local_expert_num, self.token_hidden_size]) + if self.w8a8_dynamic: + self.gmm1_weight_scale = torch.empty( + [self.local_expert_num, self.moe_intermediate_size * 2]) + self.gmm2_weight_scale = torch.empty( + [self.local_expert_num, self.token_hidden_size]) + else: + self.gmm1_weight_scale = None + self.gmm2_weight_scale = None + self.gmm1_weight_scale_fp32 = None + self.gmm2_weight_scale_fp32 = None self._process_weights_after_loading(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale) def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale): - gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, - torch_npu.Format.FRACTAL_NZ) - gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, - torch_npu.Format.FRACTAL_NZ) + if self.w8a8_dynamic: + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, + torch_npu.Format.FRACTAL_NZ) + gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, + torch_npu.Format.FRACTAL_NZ) self.gmm1_weight = torch.nn.Parameter(gmm1_weight, requires_grad=False) - self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale, - requires_grad=False) self.gmm2_weight = torch.nn.Parameter(gmm2_weight, requires_grad=False) - self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale, - requires_grad=False) - - self.gmm1_weight_scale_fp32 = torch.nn.Parameter( - gmm1_weight_scale.float(), requires_grad=False) - self.gmm2_weight_scale_fp32 = torch.nn.Parameter( - gmm2_weight_scale.float(), requires_grad=False) + if self.w8a8_dynamic: + self.gmm1_weight_scale = torch.nn.Parameter(gmm1_weight_scale, + requires_grad=False) + self.gmm2_weight_scale = torch.nn.Parameter(gmm2_weight_scale, + requires_grad=False) + self.gmm1_weight_scale_fp32 = torch.nn.Parameter( + gmm1_weight_scale.float(), requires_grad=False) + self.gmm2_weight_scale_fp32 = torch.nn.Parameter( + gmm2_weight_scale.float(), requires_grad=False) def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales, x_active_mask): @@ -142,12 +160,15 @@ class SmallOps(DecodeMoeOps): moe_expert_num, global_rank_id, shared_expert_rank_num=0, - dynamic_eplb=False): + dynamic_eplb=False, + w8a8_dynamic=True, + is_nz=True): super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, ep_hcomm_info, batch_size, token_hidden_size, moe_intermediate_size, ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num, dynamic_eplb) + shared_expert_rank_num, dynamic_eplb, w8a8_dynamic, + is_nz) self.tp_hcomm_info = "" def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales, @@ -167,7 +188,7 @@ class SmallOps(DecodeMoeOps): expert_shard_type=0, shared_expert_num=1, shared_expert_rank_num=self.shared_expert_rank_num, - quant_mode=2, + quant_mode=2 if self.w8a8_dynamic else 0, global_bs=self.batch_size * self.ep_world_size, expert_token_nums_type=1, # 0代表前缀和,1代表各自数量 ) @@ -181,22 +202,26 @@ class SmallOps(DecodeMoeOps): group_list_type=1, # 默认为0,代表前缀和形式 group_type=0, # 0代表m轴分组 group_list=expert_token_nums, - output_dtype=torch.int32)[0] - y1, y1_scale = torch_npu.npu_dequant_swiglu_quant( - x=y1_int32, - weight_scale=self.gmm1_weight_scale.to(torch.float32), - activation_scale=dynamic_scales, - bias=None, - quant_scale=None, - quant_offset=None, - group_index=expert_token_nums, - activate_left=True, - quant_mode=1, - ) + output_dtype=torch.int32 if self.w8a8_dynamic else output_dtype)[0] + y1_scale = None + if self.w8a8_dynamic: + y1, y1_scale = torch_npu.npu_dequant_swiglu_quant( + x=y1_int32, + weight_scale=self.gmm1_weight_scale.to(torch.float32), + activation_scale=dynamic_scales, + bias=None, + quant_scale=None, + quant_offset=None, + group_index=expert_token_nums, + activate_left=True, + quant_mode=1, + ) + else: + y1 = torch_npu.npu_swiglu(y1_int32) y2 = torch_npu.npu_grouped_matmul(x=[y1], weight=[self.gmm2_weight], - scale=[self.gmm2_weight_scale], - per_token_scale=[y1_scale], + scale=[self.gmm2_weight_scale] if self.w8a8_dynamic else None, + per_token_scale=[y1_scale] if self.w8a8_dynamic else None, split_item=2, group_list_type=1, group_type=0, @@ -240,15 +265,19 @@ class FusionOp(DecodeMoeOps): moe_expert_num, global_rank_id, shared_expert_rank_num=0, - dynamic_eplb=False): + dynamic_eplb=False, + w8a8_dynamic=True, + is_nz=True): super().__init__(gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale, ep_hcomm_info, batch_size, token_hidden_size, moe_intermediate_size, ep_world_size, moe_expert_num, global_rank_id, - shared_expert_rank_num, dynamic_eplb) + shared_expert_rank_num, dynamic_eplb, w8a8_dynamic, + is_nz) def _apply_ops(self, x, expert_ids, smooth_scales, expert_scales, x_active_mask): + smooth_scales = torch.zeros(128 * 1024 * 1024).npu() output = torch.ops._C_ascend.dispatch_gmm_combine_decode( x=x, expert_ids=expert_ids, @@ -271,29 +300,35 @@ class FusionOp(DecodeMoeOps): def _process_weights_after_loading(self, gmm1_weight, gmm1_weight_scale, gmm2_weight, gmm2_weight_scale): - gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, - torch_npu.Format.FRACTAL_NZ) - gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, - torch_npu.Format.FRACTAL_NZ) + if self.is_nz: + gmm1_weight = torch_npu.npu_format_cast(gmm1_weight, + torch_npu.Format.FRACTAL_NZ) + gmm2_weight = torch_npu.npu_format_cast(gmm2_weight, + torch_npu.Format.FRACTAL_NZ) if self.dynamic_eplb: self.gmm1_weight = [ weight.clone() for weight in gmm1_weight.unbind(dim=0) ] - self.gmm1_weight_scale_fp32 = [ - weight.clone() for weight in gmm1_weight_scale.unbind(dim=0) - ] self.gmm2_weight = [ weight.clone() for weight in gmm2_weight.unbind(dim=0) ] - self.gmm2_weight_scale_fp32 = [ - weight.clone() for weight in gmm2_weight_scale.unbind(dim=0) - ] + if self.w8a8_dynamic: + self.gmm1_weight_scale_fp32 = [ + weight.clone() for weight in gmm1_weight_scale.unbind(dim=0) + ] + self.gmm2_weight_scale_fp32 = [ + weight.clone() for weight in gmm2_weight_scale.unbind(dim=0) + ] else: self.gmm1_weight = [gmm1_weight.clone()] - self.gmm1_weight_scale_fp32 = [gmm1_weight_scale.clone()] self.gmm2_weight = [gmm2_weight.clone()] - self.gmm2_weight_scale_fp32 = [gmm2_weight_scale.clone()] + if self.w8a8_dynamic: + self.gmm1_weight_scale_fp32 = [gmm1_weight_scale.clone()] + self.gmm2_weight_scale_fp32 = [gmm2_weight_scale.clone()] + else: + self.gmm1_weight_scale_fp32 = [torch.ones(1).npu().to(gmm1_weight.dtype)] + self.gmm2_weight_scale_fp32 = [torch.ones(1).npu().to(gmm2_weight.dtype)] def generate_datas(batch_size, @@ -306,7 +341,8 @@ def generate_datas(batch_size, top_k=8, test_bfloat16=True, enable_dynamic_bs=False, - with_mc2_mask=False): + with_mc2_mask=False, + w8a8_dynamic=True): is_shared_expert = global_rank_id < shared_expert_rank_num moe_expert_num_per_rank = moe_expert_num // (ep_world_size - shared_expert_rank_num) @@ -318,41 +354,59 @@ def generate_datas(batch_size, gmm1_output_dim = moe_intermediate_size * 2 gmm2_input_dim = moe_intermediate_size gmm2_output_dim = token_hidden_size - x = torch.rand([actual_bs, token_hidden_size]) * 10 - 5 + x = torch.rand([actual_bs, token_hidden_size]) * 0.5 - 0.5 expert_ids = torch.arange( global_rank_id * batch_size * top_k, global_rank_id * batch_size * top_k + actual_bs * top_k).to( torch.int32).view(actual_bs, top_k) expert_ids = expert_ids % moe_expert_num - if is_shared_expert: - gmm1_weight = torch.ones([ - local_expert_num, gmm1_input_dim, gmm1_output_dim - ]).to(torch.int8) * 4 - gmm2_weight = torch.ones([ - local_expert_num, gmm2_input_dim, gmm2_output_dim - ]).to(torch.int8) * 4 - gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1 - gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1 - gmm1_weight_scale = torch.ones([local_expert_num, gmm1_output_dim - ]) * 0.0015 - gmm2_weight_scale = torch.ones([local_expert_num, gmm2_output_dim - ]) * 0.0015 + gmm1_weight_scale = None + gmm2_weight_scale = None + if w8a8_dynamic: + if is_shared_expert: + gmm1_weight = torch.ones([ + local_expert_num, gmm1_input_dim, gmm1_output_dim + ]).to(torch.int8) * 4 + gmm2_weight = torch.ones([ + local_expert_num, gmm2_input_dim, gmm2_output_dim + ]).to(torch.int8) * 4 + gmm1_weight[:, :, ::2] = gmm1_weight[:, :, ::2] * -1 + gmm2_weight[:, :, ::2] = gmm2_weight[:, :, ::2] * -1 + gmm1_weight_scale = torch.ones([local_expert_num, gmm1_output_dim + ]) * 0.0015 + gmm2_weight_scale = torch.ones([local_expert_num, gmm2_output_dim + ]) * 0.0015 + else: + gmm1_weight = torch.randint( + -16, 16, + [local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.int8) + gmm2_weight = torch.randint( + -16, 16, + [local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.int8) + gmm1_weight_scale = torch.rand([local_expert_num, gmm1_output_dim + ]) * 0.003 + 0.0015 + gmm2_weight_scale = torch.rand([local_expert_num, gmm2_output_dim + ]) * 0.003 + 0.0015 else: - gmm1_weight = torch.randint( - -16, 16, - [local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.int8) - gmm2_weight = torch.randint( - -16, 16, - [local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.int8) - gmm1_weight_scale = torch.rand([local_expert_num, gmm1_output_dim - ]) * 0.003 + 0.0015 - gmm2_weight_scale = torch.rand([local_expert_num, gmm2_output_dim - ]) * 0.003 + 0.0015 + if is_shared_expert: + gmm1_weight = torch.ones([ + local_expert_num, gmm1_input_dim, gmm1_output_dim + ]).to(torch.bfloat16 if test_bfloat16 else torch.float16) * 0.5 + gmm2_weight = torch.ones([ + local_expert_num, gmm2_input_dim, gmm2_output_dim + ]).to(torch.bfloat16 if test_bfloat16 else torch.float16) * 0.5 + else: + gmm1_weight = torch.rand([local_expert_num, gmm1_input_dim, gmm1_output_dim]).to(torch.bfloat16 if test_bfloat16 else torch.float16) * 0.25 + gmm2_weight = torch.rand([local_expert_num, gmm2_input_dim, gmm2_output_dim]).to(torch.bfloat16 if test_bfloat16 else torch.float16) * 0.25 + gmm1_weight[:, ::2, :] = gmm1_weight[:, ::2, :] * -1 + gmm2_weight[:, ::2, :] = gmm2_weight[:, ::2, :] * -1 expert_scales = torch.rand(actual_bs, top_k) if test_bfloat16: x = x.bfloat16() - gmm1_weight_scale = gmm1_weight_scale.bfloat16() - gmm2_weight_scale = gmm2_weight_scale.bfloat16() + if w8a8_dynamic: + assert (gmm1_weight_scale is not None and gmm2_weight_scale is not None), "gmm1_weight_scale and gmm2_weight_scale must be provided for w8a8_dynamic" + gmm1_weight_scale = gmm1_weight_scale.bfloat16() + gmm2_weight_scale = gmm2_weight_scale.bfloat16() else: x = x.half() smooth_sales = None @@ -380,7 +434,9 @@ def run_once(local_rank_id, enable_dynamic_bs=False, test_graph=False, with_mc2_mask=False, - dynamic_eplb=False): + dynamic_eplb=False, + w8a8_dynamic=True, + is_nz=True): log_file = redirect_output(f"local_rank_{local_rank_id}.log" ) if output_to_file(local_rank_id) else None global_rank_id = local_rank_id # 单机 @@ -407,7 +463,7 @@ def run_once(local_rank_id, ep_world_size, moe_expert_num, global_rank_id, shared_expert_rank_num) input_datas, weight_datas, actual_bs, valid_token_num = generate_datas( - *parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask) + *parameter, top_k, test_bfloat16, enable_dynamic_bs, with_mc2_mask, w8a8_dynamic) input_datas = [ data.npu() if data is not None else None for data in input_datas ] @@ -415,27 +471,52 @@ def run_once(local_rank_id, data.npu() if data is not None else None for data in weight_datas ] small_ops = SmallOps(*weight_datas, ep_hcomm_info_small, *parameter, - dynamic_eplb).npu() # type: ignore + dynamic_eplb, w8a8_dynamic, is_nz).npu() # type: ignore fused_ops = FusionOp(*weight_datas, ep_hcomm_info_fused, *parameter, - dynamic_eplb).npu() # type: ignore + dynamic_eplb, w8a8_dynamic, is_nz).npu() # type: ignore if test_graph: config = torchair.CompilerConfig() config.mode = "reduce-overhead" npu_backend = torchair.get_npu_backend(compiler_config=config) fused_ops = torch.compile(fused_ops, backend=npu_backend) + + # test performance + start_time = time.perf_counter() + for _ in range(100): + small_op_token_output, small_op_count_output = small_ops(*input_datas) + torch_npu.npu.synchronize(device_id) + end_time = time.perf_counter() + elapsed_time = end_time - start_time + elapsed_time_us = elapsed_time * 1000000 + print(f"rank-{global_rank_id} small {elapsed_time_us} us") + start_time = time.perf_counter() + for _ in range(100): + fused_op_token_output, fused_op_count_output = fused_ops(*input_datas) + torch_npu.npu.synchronize(device_id) + end_time = time.perf_counter() + elapsed_time = end_time - start_time + elapsed_time_us = elapsed_time * 1000000 + print(f"rank-{global_rank_id} fused {elapsed_time_us} us") small_op_token_output, small_op_count_output = small_ops(*input_datas) + torch_npu.npu.synchronize(device_id) + print(f"rank-{global_rank_id} Small op End") fused_op_token_output, fused_op_count_output = fused_ops(*input_datas) torch_npu.npu.synchronize(device_id) + print(f"rank-{global_rank_id} Fused op End") dist.destroy_process_group() if log_file is not None: log_file.close() - - torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(), - fused_op_token_output[0:valid_token_num].cpu(), - atol=2.0, - rtol=0.02) - torch.testing.assert_close(small_op_count_output.cpu(), - fused_op_count_output.cpu()) + try: + torch.testing.assert_close(small_op_token_output[0:valid_token_num].cpu(), + fused_op_token_output[0:valid_token_num].cpu(), + atol=2.0, + rtol=0.02) + torch.testing.assert_close(small_op_count_output.cpu(), + fused_op_count_output.cpu()) + except Exception as e: + print(f"rank-{global_rank_id} Assert close Failed: {e}") + else: + print(f"rank-{global_rank_id} Assert close Pass") gc.collect() torch.npu.empty_cache() torch.npu.reset_peak_memory_stats() @@ -444,9 +525,16 @@ def run_once(local_rank_id, @torch.inference_mode() def test_dispatch_gmm_combine_decode_base(): custom_kwargs = BASE_KWARGS + custom_kwargs["batch_size"] = 32 + custom_kwargs["ep_world_size"] = 8 + custom_kwargs["moe_expert_num"] = 32 + custom_kwargs["w8a8_dynamic"] = False + custom_kwargs["is_nz"] = True ep_world_size = custom_kwargs["ep_world_size"] custom_args = tuple(custom_kwargs.values()) + print(f"{custom_kwargs=}") mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True) + print(f"{custom_kwargs=}") @torch.inference_mode() @@ -465,3 +553,6 @@ def test_dispatch_gmm_combine_decode_dynamic_eplb(): ep_world_size = custom_kwargs["ep_world_size"] custom_args = tuple(custom_kwargs.values()) mp.spawn(run_once, args=custom_args, nprocs=ep_world_size, join=True) + +if __name__ == "__main__": + test_dispatch_gmm_combine_decode_base()