Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/kernels/quantize.mluh
2026-02-04 17:39:32 +08:00

61 lines
3.0 KiB
Plaintext

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_KERNELS_QUANTIZE_MLUH_
#define CSRC_KERNELS_QUANTIZE_MLUH_
#include "cnnl.h"
#include "kernel_utils.h"
namespace tmo {
/**
 * @brief Quantize a tensor per head, producing one scale for every (batch, seq, head) slice.
 *
 * Judging by the shapes below, each [head_size] vector of src is quantized into the
 * corresponding vector of dst, using the matching entry of scale ([bs, seq, head_num]).
 *
 * @param queue: The CNRT queue for the MLU.
 * @param dst: Output. Pointer to the destination MLU memory. Shape is [bs, seq, head_num,
 *             head_size]; may be non-contiguous (described by the dst_*_stride arguments).
 * @param scale: Output. Pointer to the destination scale MLU memory. Shape is
 *               [bs, seq, head_num]; must be contiguous.
 * @param src: Input. Pointer to the source MLU memory (read-only). Shape is
 *             [bs, seq, head_num, head_size]; may be non-contiguous (described by the
 *             src_*_stride arguments).
 * @param dst_dtype: Data type of the destination tensor. Must be int8.
 * @param scale_dtype: Data type of the destination scale tensor. Must be float32.
 * @param src_dtype: Data type of the source tensor. Must be float or half.
 * @param bs: Batch size of the dst and src tensors.
 * @param seq_len: Sequence length of the dst and src tensors.
 * @param head_num: Number of heads of the dst and src tensors.
 * @param head_size: Size of each head (innermost dimension) of the dst and src tensors.
 * @param dst_bs_stride: Stride of the batch dimension of dst.
 * @param dst_seq_stride: Stride of the seq_len dimension of dst.
 * @param dst_head_stride: Stride of the head_num dimension of dst.
 * @param src_bs_stride: Stride of the batch dimension of src.
 * @param src_seq_stride: Stride of the seq_len dimension of src.
 * @param src_head_stride: Stride of the head_num dimension of src.
 * @return A KernelStatus for the call (presumably success vs. error of the launch — confirm
 *         against its definition); callers should check it.
 *
 * @note No stride is accepted for the head_size dimension, so that innermost dimension is
 *       presumably required to be contiguous (stride 1) in both dst and src — TODO confirm.
 * @note NOTE(review): the unit of the stride arguments (elements vs. bytes) is not stated
 *       here; presumably elements — confirm against the kernel implementation.
 */
KernelStatus invokeMluQuantizePerHead(cnrtQueue_t queue,
                                      void *dst,
                                      void *scale,
                                      const void *src,
                                      cnnlDataType_t dst_dtype,
                                      cnnlDataType_t scale_dtype,
                                      cnnlDataType_t src_dtype,
                                      int bs,
                                      int seq_len,
                                      int head_num,
                                      int head_size,
                                      int dst_bs_stride,
                                      int dst_seq_stride,
                                      int dst_head_stride,
                                      int src_bs_stride,
                                      int src_seq_stride,
                                      int src_head_stride);
}  // namespace tmo
#endif // CSRC_KERNELS_QUANTIZE_MLUH_