/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #ifndef CSRC_KERNELS_QUANTIZE_MLUH_ #define CSRC_KERNELS_QUANTIZE_MLUH_ #include "cnnl.h" #include "kernel_utils.h" namespace tmo { /** * @brief quantize tensor by per head. * @param queue: The queue for mlu. * @param dst: Output. Pointer to the destination MLU memory. Shape is [bs, seq, head_num, * head_size], may not be continuous. * @param scale: Input. Pointer to the destination scale MLU memory. Shape is [bs, seq, head_num], * must be continuous. * @param src: Input. Pointer to the source MLU memory. Shape is [bs, seq, head_num, head_size], may * not be continuous. * @param dst_dtype: Data type of destination tensor. Must be int8. * @param scale_dtype: Data type of destination scale tensor. Must be float32. * @param src_dtype: Data type of src tensor. Must be float or half. * @param bs: batch_size of dst or src tensor. * @param seq_len: seq_len of dst or src tensor. * @param head_num: head_num of dst or src tensor. * @param head_size: head_size of dst or src tensor. * @param dst_bs_stride: stride of batch_size dim of dst tensor. * @param dst_seq_stride: stride of seq_len dim of dst tensor. * @param dst_head_stride: stride of head_num dim of dst tensor. * @param src_bs_stride: stride of batch_size dim of src tensor. * @param src_seq_stride: stride of seq_len dim of src tensor. * @param src_head_stride: stride of head_num dim of src tensor. */ KernelStatus invokeMluQuantizePerHead(cnrtQueue_t queue, void *dst, void *scale, const void *src, cnnlDataType_t dst_dtype, cnnlDataType_t scale_dtype, cnnlDataType_t src_dtype, int bs, int seq_len, int head_num, int head_size, int dst_bs_stride, int dst_seq_stride, int dst_head_stride, int src_bs_stride, int src_seq_stride, int src_head_stride); } // namespace tmo #endif // CSRC_KERNELS_QUANTIZE_MLUH_