// forked from EngineX-Cambricon/enginex-mlu370-vllm
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "kernel_api.h"
namespace tmo {
|
|
namespace ops {
|
|
|
|
void SmoothQuant(const cnnlHandle_t &handle,
|
|
void *input,
|
|
void *smooth_scale,
|
|
void *token_count,
|
|
void *gather_idx,
|
|
void *gather_idx_start_position,
|
|
void *output,
|
|
void *output_scale,
|
|
int n,
|
|
int c,
|
|
int e,
|
|
int input_stride,
|
|
int output_stride,
|
|
int topk,
|
|
cnnlDataType_t input_dtype) {
|
|
int dim_nb = 2;
|
|
int output_scale_dim_nb = 1;
|
|
if (!gather_idx) {
|
|
topk = 1;
|
|
}
|
|
int64_t input_dims[] = {n, c};
|
|
int64_t output_dims[] = {n * topk, c};
|
|
int64_t output_scale_dims[] = {n * topk};
|
|
int64_t input_strides[] = {input_stride, 1};
|
|
int64_t output_strides[] = {output_stride, 1};
|
|
|
|
cnnlTensorDescriptor_t input_tensor, smooth_scale_tensor, token_count_tensor, gather_idx_tensor,
|
|
gather_idx_start_pos_tensor, output_tensor, output_scale_tensor;
|
|
cnnlCreateTensorDescriptor(&input_tensor);
|
|
cnnlCreateTensorDescriptor(&smooth_scale_tensor);
|
|
cnnlCreateTensorDescriptor(&output_tensor);
|
|
cnnlCreateTensorDescriptor(&output_scale_tensor);
|
|
CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(input_tensor, CNNL_LAYOUT_ARRAY, input_dtype,
|
|
dim_nb, input_dims, input_strides));
|
|
CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(output_tensor, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8,
|
|
dim_nb, output_dims, output_strides));
|
|
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(output_scale_tensor, CNNL_LAYOUT_ARRAY,
|
|
CNNL_DTYPE_FLOAT, output_scale_dim_nb,
|
|
output_scale_dims));
|
|
|
|
if (token_count) {
|
|
int64_t smooth_scale_dims[] = {e, c};
|
|
int64_t token_count_dims[] = {e};
|
|
cnnlCreateTensorDescriptor(&token_count_tensor);
|
|
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(token_count_tensor, CNNL_LAYOUT_ARRAY,
|
|
CNNL_DTYPE_INT32, 1, token_count_dims));
|
|
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY,
|
|
CNNL_DTYPE_FLOAT, 2, smooth_scale_dims));
|
|
} else {
|
|
int64_t smooth_scale_dims[] = {c};
|
|
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY,
|
|
CNNL_DTYPE_FLOAT, 1, smooth_scale_dims));
|
|
}
|
|
|
|
if (gather_idx) {
|
|
int64_t gather_idx_dims[] = {n * topk};
|
|
cnnlCreateTensorDescriptor(&gather_idx_tensor);
|
|
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_tensor, CNNL_LAYOUT_ARRAY,
|
|
CNNL_DTYPE_INT32, 1, gather_idx_dims));
|
|
}
|
|
|
|
if (gather_idx_start_position) {
|
|
int64_t gather_idx_start_pos_dims[] = {1};
|
|
cnnlCreateTensorDescriptor(&gather_idx_start_pos_tensor);
|
|
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_start_pos_tensor, CNNL_LAYOUT_ARRAY,
|
|
CNNL_DTYPE_INT32, 1, gather_idx_start_pos_dims));
|
|
}
|
|
|
|
CNNL_CHECK_FATAL(cnnlSmoothQuantOnline_v4(
|
|
handle, nullptr /*smooth_quant_online_desc*/, input_tensor /*input_desc*/,
|
|
input /*input_ptr*/, smooth_scale_tensor /*input_smooth_quant_scale_desc*/,
|
|
smooth_scale /*input_sq_scale_ptr*/, nullptr /*input_smooth_quant_zero_desc*/,
|
|
nullptr /*input_sq_zero_ptr*/,
|
|
token_count ? token_count_tensor /*token_count_desc*/ : nullptr,
|
|
token_count /*token_count_ptr*/, gather_idx ? gather_idx_tensor /*gather_idx_desc*/ : nullptr,
|
|
gather_idx /*gather_idx_ptr*/,
|
|
gather_idx_start_position /*gather_idx_start_pos_desc*/ ? gather_idx_start_pos_tensor
|
|
: nullptr,
|
|
gather_idx_start_position /*gather_idx_start_pos_ptr*/, output_tensor /*output_desc*/,
|
|
output /*output_ptr*/, output_scale_tensor /*output_scale_desc*/,
|
|
output_scale /*output_scale_ptr*/, nullptr /*output_zero_desc*/, nullptr /*output_zero_ptr*/,
|
|
nullptr /*worksapce*/, 0 /*workspace_size*/));
|
|
|
|
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(input_tensor));
|
|
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(smooth_scale_tensor));
|
|
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_tensor));
|
|
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_scale_tensor));
|
|
if (token_count) {
|
|
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(token_count_tensor));
|
|
}
|
|
if (gather_idx) {
|
|
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_tensor));
|
|
}
|
|
if (gather_idx_start_position) {
|
|
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_start_pos_tensor));
|
|
}
|
|
}
|
|
|
|
} // namespace ops
|
|
} // namespace tmo