forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
66
torch_mlu_ops-v1.3.2/csrc/ops/kernel_api.h
Normal file
66
torch_mlu_ops-v1.3.2/csrc/ops/kernel_api.h
Normal file
@@ -0,0 +1,66 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#ifndef CSRC_OPS_KERNEL_API_H_
|
||||
#define CSRC_OPS_KERNEL_API_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "cnrt.h"
|
||||
|
||||
#include "op_descriptor/attn_proj_descriptor.h"
|
||||
#include "op_descriptor/batchmatmul_descriptor.h"
|
||||
#include "op_descriptor/ffn_descriptor.h"
|
||||
#include "op_descriptor/group_gemm_descriptor.h"
|
||||
#include "op_descriptor/matmul_descriptor.h"
|
||||
#include "op_descriptor/quant_matmul_descriptor.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace ops {
|
||||
|
||||
using GroupGemmDesc = tmo::op_desc::GroupGemmDesc;
|
||||
|
||||
size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle, GroupGemmDesc &desc, const int num_expert);
|
||||
|
||||
void GroupGemm(const cnnlHandle_t &handle,
|
||||
GroupGemmDesc &desc,
|
||||
void *m,
|
||||
void *alpha,
|
||||
void *beta,
|
||||
void *workspace,
|
||||
size_t workspace_size,
|
||||
int num_expert,
|
||||
int k,
|
||||
int n,
|
||||
int lda,
|
||||
std::vector<int> &ldb);
|
||||
|
||||
void SmoothQuant(const cnnlHandle_t &handle,
|
||||
void *input,
|
||||
void *smooth_scale,
|
||||
void *token_count,
|
||||
void *gather_idx,
|
||||
void *gather_idx_start_position,
|
||||
void *output,
|
||||
void *output_scale,
|
||||
int n,
|
||||
int c,
|
||||
int e,
|
||||
int input_stride,
|
||||
int output_stride,
|
||||
int topk,
|
||||
cnnlDataType_t input_dtype);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_OPS_KERNEL_API_H_
|
||||
Reference in New Issue
Block a user