forked from EngineX-Cambricon/enginex-mlu370-vllm
58 lines
2.6 KiB
Plaintext
58 lines
2.6 KiB
Plaintext
/*************************************************************************
|
|
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*************************************************************************/
|
|
#ifndef CSRC_KERNELS_DEQUANTIFY_MLUH_
|
|
#define CSRC_KERNELS_DEQUANTIFY_MLUH_
|
|
|
|
#include "cnnl.h"
|
|
#include "kernel_utils.h"
|
|
namespace tmo {
|
|
/**
|
|
* @brief Dequantify per tensor.
|
|
* @param handle: The handle of cnnl.
|
|
* @param src: Input. Pointer to the MLU memory that stores the input.
|
|
* @param src_bitwidth: The bitwidth of input quantized data.
|
|
* @param dst: Output. Pointer to the MLU memory that stores the output.
|
|
* @param dst_dtype: The data type of output.
|
|
* @param src_count: The number of elements in input.
|
|
* @param scale: The scale for dequantify.
|
|
*/
|
|
KernelStatus invokeDequantifyPerTensor(cnnlHandle_t handle,
|
|
const void *src,
|
|
int src_bitwidth,
|
|
void *dst,
|
|
cnnlDataType_t dst_dtype,
|
|
size_t src_count,
|
|
float scale);
|
|
|
|
/**
|
|
* @brief Dequantify per channel.
|
|
* @param handle: The handle of cnnl.
|
|
* @param src: Input. Pointer to the MLU memory that stores the input.
|
|
* @param src_bitwidth: The bitwidth of input quantized data.
|
|
* @param dst: Output. Pointer to the MLU memory that stores the output.
|
|
* @param dst_dtype: The data type of output.
|
|
* @param src_ci: The ci of input.
|
|
* @param co: The co of input.
|
|
* @param scale: Pointer to the MLU memory that stores the scale for dequantify.
|
|
*/
|
|
KernelStatus invokeDequantifyPerChannel(cnnlHandle_t handle,
|
|
const void *src,
|
|
int src_bitwidth,
|
|
void *dst,
|
|
cnnlDataType_t dst_dtype,
|
|
int src_ci,
|
|
int co,
|
|
const void *scale);
|
|
} // namespace tmo
|
|
|
|
#endif // CSRC_KERNELS_DEQUANTIFY_MLUH_
|