Files
2026-02-04 17:39:32 +08:00

118 lines
5.4 KiB
C++

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernel_api.h"
namespace tmo {
namespace ops {
void SmoothQuant(const cnnlHandle_t &handle,
void *input,
void *smooth_scale,
void *token_count,
void *gather_idx,
void *gather_idx_start_position,
void *output,
void *output_scale,
int n,
int c,
int e,
int input_stride,
int output_stride,
int topk,
cnnlDataType_t input_dtype) {
int dim_nb = 2;
int output_scale_dim_nb = 1;
if (!gather_idx) {
topk = 1;
}
int64_t input_dims[] = {n, c};
int64_t output_dims[] = {n * topk, c};
int64_t output_scale_dims[] = {n * topk};
int64_t input_strides[] = {input_stride, 1};
int64_t output_strides[] = {output_stride, 1};
cnnlTensorDescriptor_t input_tensor, smooth_scale_tensor, token_count_tensor, gather_idx_tensor,
gather_idx_start_pos_tensor, output_tensor, output_scale_tensor;
cnnlCreateTensorDescriptor(&input_tensor);
cnnlCreateTensorDescriptor(&smooth_scale_tensor);
cnnlCreateTensorDescriptor(&output_tensor);
cnnlCreateTensorDescriptor(&output_scale_tensor);
CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(input_tensor, CNNL_LAYOUT_ARRAY, input_dtype,
dim_nb, input_dims, input_strides));
CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(output_tensor, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8,
dim_nb, output_dims, output_strides));
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(output_scale_tensor, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, output_scale_dim_nb,
output_scale_dims));
if (token_count) {
int64_t smooth_scale_dims[] = {e, c};
int64_t token_count_dims[] = {e};
cnnlCreateTensorDescriptor(&token_count_tensor);
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(token_count_tensor, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_INT32, 1, token_count_dims));
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, 2, smooth_scale_dims));
} else {
int64_t smooth_scale_dims[] = {c};
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, 1, smooth_scale_dims));
}
if (gather_idx) {
int64_t gather_idx_dims[] = {n * topk};
cnnlCreateTensorDescriptor(&gather_idx_tensor);
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_tensor, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_INT32, 1, gather_idx_dims));
}
if (gather_idx_start_position) {
int64_t gather_idx_start_pos_dims[] = {1};
cnnlCreateTensorDescriptor(&gather_idx_start_pos_tensor);
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_start_pos_tensor, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_INT32, 1, gather_idx_start_pos_dims));
}
CNNL_CHECK_FATAL(cnnlSmoothQuantOnline_v4(
handle, nullptr /*smooth_quant_online_desc*/, input_tensor /*input_desc*/,
input /*input_ptr*/, smooth_scale_tensor /*input_smooth_quant_scale_desc*/,
smooth_scale /*input_sq_scale_ptr*/, nullptr /*input_smooth_quant_zero_desc*/,
nullptr /*input_sq_zero_ptr*/,
token_count ? token_count_tensor /*token_count_desc*/ : nullptr,
token_count /*token_count_ptr*/, gather_idx ? gather_idx_tensor /*gather_idx_desc*/ : nullptr,
gather_idx /*gather_idx_ptr*/,
gather_idx_start_position /*gather_idx_start_pos_desc*/ ? gather_idx_start_pos_tensor
: nullptr,
gather_idx_start_position /*gather_idx_start_pos_ptr*/, output_tensor /*output_desc*/,
output /*output_ptr*/, output_scale_tensor /*output_scale_desc*/,
output_scale /*output_scale_ptr*/, nullptr /*output_zero_desc*/, nullptr /*output_zero_ptr*/,
nullptr /*worksapce*/, 0 /*workspace_size*/));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(input_tensor));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(smooth_scale_tensor));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_tensor));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_scale_tensor));
if (token_count) {
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(token_count_tensor));
}
if (gather_idx) {
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_tensor));
}
if (gather_idx_start_position) {
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_start_pos_tensor));
}
}
} // namespace ops
} // namespace tmo