forked from EngineX-Cambricon/enginex-mlu370-vllm
67 lines
2.2 KiB
C++
67 lines
2.2 KiB
C++
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
|
|
|
|
#ifndef CSRC_OPS_KERNEL_API_H_
|
|
#define CSRC_OPS_KERNEL_API_H_
|
|
|
|
#include <map>
|
|
#include <vector>
|
|
#include "cnrt.h"
|
|
|
|
#include "op_descriptor/attn_proj_descriptor.h"
|
|
#include "op_descriptor/batchmatmul_descriptor.h"
|
|
#include "op_descriptor/ffn_descriptor.h"
|
|
#include "op_descriptor/group_gemm_descriptor.h"
|
|
#include "op_descriptor/matmul_descriptor.h"
|
|
#include "op_descriptor/quant_matmul_descriptor.h"
|
|
|
|
namespace tmo {
|
|
namespace ops {
|
|
|
|
// Alias that lets code in tmo::ops refer to tmo::op_desc::GroupGemmDesc by its
// unqualified name.
using GroupGemmDesc = tmo::op_desc::GroupGemmDesc;
|
|
|
|
/**
 * @brief Declaration of the group GEMM workspace-size query.
 *
 * Only the prototype is visible in this header; the definition lives elsewhere.
 * NOTE(review): parameter semantics below are inferred from names alone —
 * confirm against the implementation.
 *
 * @param handle     CNNL handle, taken by value.
 * @param desc       Group GEMM descriptor, taken by non-const reference.
 * @param num_expert Presumably the number of experts (GEMM groups) — TODO confirm.
 * @return A size_t; presumably the workspace size (in bytes) expected by
 *         GroupGemm's `workspace_size` argument — TODO confirm.
 */
size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle,
                                 GroupGemmDesc &desc,
                                 const int num_expert);
|
|
|
|
void GroupGemm(const cnnlHandle_t &handle,
|
|
GroupGemmDesc &desc,
|
|
void *m,
|
|
void *alpha,
|
|
void *beta,
|
|
void *workspace,
|
|
size_t workspace_size,
|
|
int num_expert,
|
|
int k,
|
|
int n,
|
|
int lda,
|
|
std::vector<int> &ldb);
|
|
|
|
void SmoothQuant(const cnnlHandle_t &handle,
|
|
void *input,
|
|
void *smooth_scale,
|
|
void *token_count,
|
|
void *gather_idx,
|
|
void *gather_idx_start_position,
|
|
void *output,
|
|
void *output_scale,
|
|
int n,
|
|
int c,
|
|
int e,
|
|
int input_stride,
|
|
int output_stride,
|
|
int topk,
|
|
cnnlDataType_t input_dtype);
|
|
|
|
} // namespace ops
|
|
} // namespace tmo
|
|
|
|
#endif // CSRC_OPS_KERNEL_API_H_
|