forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
60
torch_mlu_ops-v1.3.2/csrc/kernels/quantize.mluh
Normal file
60
torch_mlu_ops-v1.3.2/csrc/kernels/quantize.mluh
Normal file
@@ -0,0 +1,60 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_KERNELS_QUANTIZE_MLUH_
|
||||
#define CSRC_KERNELS_QUANTIZE_MLUH_
|
||||
|
||||
#include "cnnl.h"
|
||||
#include "kernel_utils.h"
|
||||
namespace tmo {
|
||||
/**
|
||||
* @brief quantize tensor by per head.
|
||||
* @param queue: The queue for mlu.
|
||||
* @param dst: Output. Pointer to the destination MLU memory. Shape is [bs, seq, head_num,
|
||||
* head_size], may not be continuous.
|
||||
* @param scale: Input. Pointer to the destination scale MLU memory. Shape is [bs, seq, head_num],
|
||||
* must be continuous.
|
||||
* @param src: Input. Pointer to the source MLU memory. Shape is [bs, seq, head_num, head_size], may
|
||||
* not be continuous.
|
||||
* @param dst_dtype: Data type of destination tensor. Must be int8.
|
||||
* @param scale_dtype: Data type of destination scale tensor. Must be float32.
|
||||
* @param src_dtype: Data type of src tensor. Must be float or half.
|
||||
* @param bs: batch_size of dst or src tensor.
|
||||
* @param seq_len: seq_len of dst or src tensor.
|
||||
* @param head_num: head_num of dst or src tensor.
|
||||
* @param head_size: head_size of dst or src tensor.
|
||||
* @param dst_bs_stride: stride of batch_size dim of dst tensor.
|
||||
* @param dst_seq_stride: stride of seq_len dim of dst tensor.
|
||||
* @param dst_head_stride: stride of head_num dim of dst tensor.
|
||||
* @param src_bs_stride: stride of batch_size dim of src tensor.
|
||||
* @param src_seq_stride: stride of seq_len dim of src tensor.
|
||||
* @param src_head_stride: stride of head_num dim of src tensor.
|
||||
*/
|
||||
KernelStatus invokeMluQuantizePerHead(cnrtQueue_t queue,
|
||||
void *dst,
|
||||
void *scale,
|
||||
const void *src,
|
||||
cnnlDataType_t dst_dtype,
|
||||
cnnlDataType_t scale_dtype,
|
||||
cnnlDataType_t src_dtype,
|
||||
int bs,
|
||||
int seq_len,
|
||||
int head_num,
|
||||
int head_size,
|
||||
int dst_bs_stride,
|
||||
int dst_seq_stride,
|
||||
int dst_head_stride,
|
||||
int src_bs_stride,
|
||||
int src_seq_stride,
|
||||
int src_head_stride);
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_KERNELS_QUANTIZE_MLUH_
|
||||
Reference in New Issue
Block a user