61 lines
3.0 KiB
Plaintext
61 lines
3.0 KiB
Plaintext
/*************************************************************************
|
|
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*************************************************************************/
|
|
#ifndef CSRC_KERNELS_QUANTIZE_MLUH_
|
|
#define CSRC_KERNELS_QUANTIZE_MLUH_
|
|
|
|
#include "cnnl.h"
|
|
#include "kernel_utils.h"
|
|
namespace tmo {
|
|
/**
|
|
* @brief quantize tensor by per head.
|
|
* @param queue: The queue for mlu.
|
|
* @param dst: Output. Pointer to the destination MLU memory. Shape is [bs, seq, head_num,
|
|
* head_size], may not be continuous.
|
|
* @param scale: Input. Pointer to the destination scale MLU memory. Shape is [bs, seq, head_num],
|
|
* must be continuous.
|
|
* @param src: Input. Pointer to the source MLU memory. Shape is [bs, seq, head_num, head_size], may
|
|
* not be continuous.
|
|
* @param dst_dtype: Data type of destination tensor. Must be int8.
|
|
* @param scale_dtype: Data type of destination scale tensor. Must be float32.
|
|
* @param src_dtype: Data type of src tensor. Must be float or half.
|
|
* @param bs: batch_size of dst or src tensor.
|
|
* @param seq_len: seq_len of dst or src tensor.
|
|
* @param head_num: head_num of dst or src tensor.
|
|
* @param head_size: head_size of dst or src tensor.
|
|
* @param dst_bs_stride: stride of batch_size dim of dst tensor.
|
|
* @param dst_seq_stride: stride of seq_len dim of dst tensor.
|
|
* @param dst_head_stride: stride of head_num dim of dst tensor.
|
|
* @param src_bs_stride: stride of batch_size dim of src tensor.
|
|
* @param src_seq_stride: stride of seq_len dim of src tensor.
|
|
* @param src_head_stride: stride of head_num dim of src tensor.
|
|
*/
|
|
KernelStatus invokeMluQuantizePerHead(cnrtQueue_t queue,
|
|
void *dst,
|
|
void *scale,
|
|
const void *src,
|
|
cnnlDataType_t dst_dtype,
|
|
cnnlDataType_t scale_dtype,
|
|
cnnlDataType_t src_dtype,
|
|
int bs,
|
|
int seq_len,
|
|
int head_num,
|
|
int head_size,
|
|
int dst_bs_stride,
|
|
int dst_seq_stride,
|
|
int dst_head_stride,
|
|
int src_bs_stride,
|
|
int src_seq_stride,
|
|
int src_head_stride);
|
|
} // namespace tmo
|
|
|
|
#endif // CSRC_KERNELS_QUANTIZE_MLUH_
|