Files
2026-02-04 17:39:32 +08:00

67 lines
2.2 KiB
C++

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_OPS_KERNEL_API_H_
#define CSRC_OPS_KERNEL_API_H_
#include <map>
#include <vector>
#include "cnrt.h"
#include "op_descriptor/attn_proj_descriptor.h"
#include "op_descriptor/batchmatmul_descriptor.h"
#include "op_descriptor/ffn_descriptor.h"
#include "op_descriptor/group_gemm_descriptor.h"
#include "op_descriptor/matmul_descriptor.h"
#include "op_descriptor/quant_matmul_descriptor.h"
namespace tmo {
namespace ops {

// Local shorthand so the declarations below can name the group-GEMM
// descriptor type without spelling out tmo::op_desc each time.
using GroupGemmDesc = tmo::op_desc::GroupGemmDesc;

/**
 * @brief Queries the workspace size associated with a group-GEMM descriptor.
 *
 * Declaration only; the definition lives outside this header.
 *
 * @param handle     CNNL handle (taken by value here).
 * @param desc       Group-GEMM descriptor, passed by non-const reference.
 * @param num_expert Presumably the number of expert groups - TODO confirm
 *                   against the implementation.
 * @return A size_t; presumably the byte count to allocate for the `workspace`
 *         argument of GroupGemm - TODO confirm the unit.
 *
 * NOTE(review): `handle` is taken by value here but by const reference in
 * GroupGemm and SmoothQuant; confirm whether the difference is intentional.
 */
size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle,
                                 GroupGemmDesc &desc,
                                 const int num_expert);

/**
 * @brief Entry point for the grouped GEMM operation described by @p desc.
 *
 * Declaration only; the definition lives outside this header. All buffer
 * arguments are untyped pointers, so their element types, layouts and memory
 * space cannot be determined from this header.
 *
 * @param handle         CNNL handle.
 * @param desc           Group-GEMM descriptor, passed by non-const reference.
 * @param m              Untyped buffer; presumably the per-group row counts -
 *                       TODO confirm.
 * @param alpha          Untyped buffer; presumably the GEMM alpha scaling
 *                       factor(s) - TODO confirm.
 * @param beta           Untyped buffer; presumably the GEMM beta scaling
 *                       factor(s) - TODO confirm.
 * @param workspace      Untyped scratch buffer.
 * @param workspace_size Size associated with @p workspace; presumably the
 *                       value returned by getGroupGemmWorkspaceSize - TODO
 *                       confirm.
 * @param num_expert     Presumably the number of expert groups - TODO confirm.
 * @param k              Presumably the shared inner GEMM dimension - TODO
 *                       confirm.
 * @param n              Presumably the output column dimension - TODO confirm.
 * @param lda            Presumably the leading dimension of the A operand -
 *                       TODO confirm.
 * @param ldb            Vector of ints passed by non-const reference;
 *                       presumably one leading dimension per group for the B
 *                       operand - TODO confirm, including whether it is
 *                       modified by the call.
 */
void GroupGemm(const cnnlHandle_t &handle,
               GroupGemmDesc &desc,
               void *m,
               void *alpha,
               void *beta,
               void *workspace,
               size_t workspace_size,
               int num_expert,
               int k,
               int n,
               int lda,
               std::vector<int> &ldb);

/**
 * @brief Entry point for the smooth-quantization operation.
 *
 * Declaration only; the definition lives outside this header. All buffer
 * arguments are untyped pointers, so their element types, layouts and memory
 * space cannot be determined from this header.
 *
 * @param handle                    CNNL handle.
 * @param input                     Untyped input buffer; its element type is
 *                                  presumably given by @p input_dtype - TODO
 *                                  confirm.
 * @param smooth_scale              Untyped buffer; presumably the smoothing
 *                                  scales applied before quantization - TODO
 *                                  confirm.
 * @param token_count               Untyped buffer; presumably per-expert token
 *                                  counts - TODO confirm.
 * @param gather_idx                Untyped buffer; presumably row-gather
 *                                  indices - TODO confirm.
 * @param gather_idx_start_position Untyped buffer; presumably the starting
 *                                  offset into @p gather_idx - TODO confirm.
 * @param output                    Untyped output buffer.
 * @param output_scale              Untyped buffer; presumably receives the
 *                                  computed quantization scales - TODO confirm.
 * @param n                         Presumably the number of rows/tokens - TODO
 *                                  confirm.
 * @param c                         Presumably the channel (row length) count -
 *                                  TODO confirm.
 * @param e                         Presumably the number of experts - TODO
 *                                  confirm.
 * @param input_stride              Presumably the row stride of @p input -
 *                                  TODO confirm the unit (elements vs bytes).
 * @param output_stride             Presumably the row stride of @p output -
 *                                  TODO confirm the unit (elements vs bytes).
 * @param topk                      Presumably the top-k routing factor - TODO
 *                                  confirm.
 * @param input_dtype               CNNL data-type enumerator for the input.
 */
void SmoothQuant(const cnnlHandle_t &handle,
                 void *input,
                 void *smooth_scale,
                 void *token_count,
                 void *gather_idx,
                 void *gather_idx_start_position,
                 void *output,
                 void *output_scale,
                 int n,
                 int c,
                 int e,
                 int input_stride,
                 int output_stride,
                 int topk,
                 cnnlDataType_t input_dtype);

}  // namespace ops
}  // namespace tmo
#endif // CSRC_OPS_KERNEL_API_H_