/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #ifndef CSRC_OPS_KERNEL_API_H_ #define CSRC_OPS_KERNEL_API_H_ #include #include #include "cnrt.h" #include "op_descriptor/attn_proj_descriptor.h" #include "op_descriptor/batchmatmul_descriptor.h" #include "op_descriptor/ffn_descriptor.h" #include "op_descriptor/group_gemm_descriptor.h" #include "op_descriptor/matmul_descriptor.h" #include "op_descriptor/quant_matmul_descriptor.h" namespace tmo { namespace ops { using GroupGemmDesc = tmo::op_desc::GroupGemmDesc; size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle, GroupGemmDesc &desc, const int num_expert); void GroupGemm(const cnnlHandle_t &handle, GroupGemmDesc &desc, void *m, void *alpha, void *beta, void *workspace, size_t workspace_size, int num_expert, int k, int n, int lda, std::vector &ldb); void SmoothQuant(const cnnlHandle_t &handle, void *input, void *smooth_scale, void *token_count, void *gather_idx, void *gather_idx_start_position, void *output, void *output_scale, int n, int c, int e, int input_stride, int output_stride, int topk, cnnlDataType_t input_dtype); } // namespace ops } // namespace tmo #endif // CSRC_OPS_KERNEL_API_H_