/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "kernel_api.h"

namespace tmo {
namespace ops {

/**
 * @brief Builds the CNNL tensor descriptors for a smooth-quant call and
 *        dispatches it to cnnlSmoothQuantOnline_v4.
 *
 * Descriptor shapes, as set up by this function:
 *   - input:        {n, c}, strides {input_stride, 1}, dtype `input_dtype`
 *   - output:       {n * topk, c}, strides {output_stride, 1}, INT8
 *   - output_scale: {n * topk}, FLOAT
 *   - smooth_scale: {e, c} FLOAT when `token_count` is non-null,
 *                   otherwise {c} FLOAT
 *   - token_count:  {e}, INT32 (only described when non-null)
 *   - gather_idx:   {n * topk}, INT32 (only described when non-null)
 *   - gather_idx_start_position: {1}, INT32 (only described when non-null)
 *
 * When `gather_idx` is null, `topk` is ignored and treated as 1.
 *
 * No workspace is passed to the kernel (nullptr / size 0).
 *
 * NOTE(review): the data pointers are presumably device addresses and the
 * call is presumably enqueued on the queue bound to `handle` - confirm
 * against the CNNL documentation and the callers.
 *
 * @param handle                     CNNL handle forwarded to the kernel.
 * @param input                      Input data pointer.
 * @param smooth_scale               Smooth-quant scale data pointer.
 * @param token_count                Optional token-count data; may be null.
 * @param gather_idx                 Optional gather-index data; may be null.
 * @param gather_idx_start_position  Optional start-position data; may be null.
 * @param output                     Output (INT8) data pointer.
 * @param output_scale               Output scale (FLOAT) data pointer.
 * @param n                          First dimension of the input.
 * @param c                          Second dimension of the input / output.
 * @param e                          Length of `token_count` and leading
 *                                   dimension of `smooth_scale` when
 *                                   `token_count` is provided.
 * @param input_stride               Stride of the input's first dimension.
 * @param output_stride              Stride of the output's first dimension.
 * @param topk                       Multiplier applied to `n` for the output
 *                                   row count when `gather_idx` is provided.
 * @param input_dtype                CNNL data type of the input tensor.
 */
void SmoothQuant(const cnnlHandle_t &handle,
                 void *input,
                 void *smooth_scale,
                 void *token_count,
                 void *gather_idx,
                 void *gather_idx_start_position,
                 void *output,
                 void *output_scale,
                 int n,
                 int c,
                 int e,
                 int input_stride,
                 int output_stride,
                 int topk,
                 cnnlDataType_t input_dtype) {
  int dim_nb = 2;
  int output_scale_dim_nb = 1;
  if (!gather_idx) {
    topk = 1;
  }

  // Multiply in 64-bit so a large n * topk cannot overflow `int` before it is
  // widened for the descriptor dimensions.
  const int64_t output_rows = static_cast<int64_t>(n) * topk;

  int64_t input_dims[] = {n, c};
  int64_t output_dims[] = {output_rows, c};
  int64_t output_scale_dims[] = {output_rows};
  int64_t input_strides[] = {input_stride, 1};
  int64_t output_strides[] = {output_stride, 1};

  // Optional descriptors stay null unless their data pointer is provided, so
  // they can be passed to the kernel as-is.
  cnnlTensorDescriptor_t input_tensor = nullptr;
  cnnlTensorDescriptor_t smooth_scale_tensor = nullptr;
  cnnlTensorDescriptor_t token_count_tensor = nullptr;
  cnnlTensorDescriptor_t gather_idx_tensor = nullptr;
  cnnlTensorDescriptor_t gather_idx_start_pos_tensor = nullptr;
  cnnlTensorDescriptor_t output_tensor = nullptr;
  cnnlTensorDescriptor_t output_scale_tensor = nullptr;

  // Descriptor creation can fail; check it like every other CNNL call here.
  CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&input_tensor));
  CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&smooth_scale_tensor));
  CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&output_tensor));
  CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&output_scale_tensor));

  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(input_tensor, CNNL_LAYOUT_ARRAY, input_dtype,
                                                dim_nb, input_dims, input_strides));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(output_tensor, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8,
                                                dim_nb, output_dims, output_strides));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(output_scale_tensor, CNNL_LAYOUT_ARRAY,
                                              CNNL_DTYPE_FLOAT, output_scale_dim_nb,
                                              output_scale_dims));

  if (token_count) {
    // With token counts, the smooth scale is described as a 2-D {e, c} tensor.
    int64_t smooth_scale_dims[] = {e, c};
    int64_t token_count_dims[] = {e};
    CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&token_count_tensor));
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(token_count_tensor, CNNL_LAYOUT_ARRAY,
                                                CNNL_DTYPE_INT32, 1, token_count_dims));
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY,
                                                CNNL_DTYPE_FLOAT, 2, smooth_scale_dims));
  } else {
    // Without token counts, the smooth scale is a single {c} vector.
    int64_t smooth_scale_dims[] = {c};
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY,
                                                CNNL_DTYPE_FLOAT, 1, smooth_scale_dims));
  }

  if (gather_idx) {
    int64_t gather_idx_dims[] = {output_rows};
    CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&gather_idx_tensor));
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_tensor, CNNL_LAYOUT_ARRAY,
                                                CNNL_DTYPE_INT32, 1, gather_idx_dims));
  }

  if (gather_idx_start_position) {
    int64_t gather_idx_start_pos_dims[] = {1};
    CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&gather_idx_start_pos_tensor));
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_start_pos_tensor, CNNL_LAYOUT_ARRAY,
                                                CNNL_DTYPE_INT32, 1, gather_idx_start_pos_dims));
  }

  CNNL_CHECK_FATAL(cnnlSmoothQuantOnline_v4(
      handle, nullptr /*smooth_quant_online_desc*/,
      input_tensor /*input_desc*/, input /*input_ptr*/,
      smooth_scale_tensor /*input_smooth_quant_scale_desc*/,
      smooth_scale /*input_sq_scale_ptr*/,
      nullptr /*input_smooth_quant_zero_desc*/, nullptr /*input_sq_zero_ptr*/,
      token_count_tensor /*token_count_desc, null when token_count is null*/,
      token_count /*token_count_ptr*/,
      gather_idx_tensor /*gather_idx_desc, null when gather_idx is null*/,
      gather_idx /*gather_idx_ptr*/,
      gather_idx_start_pos_tensor /*gather_idx_start_pos_desc, null when unused*/,
      gather_idx_start_position /*gather_idx_start_pos_ptr*/,
      output_tensor /*output_desc*/, output /*output_ptr*/,
      output_scale_tensor /*output_scale_desc*/, output_scale /*output_scale_ptr*/,
      nullptr /*output_zero_desc*/, nullptr /*output_zero_ptr*/,
      nullptr /*workspace*/, 0 /*workspace_size*/));

  CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(input_tensor));
  CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(smooth_scale_tensor));
  CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_tensor));
  CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_scale_tensor));
  if (token_count) {
    CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(token_count_tensor));
  }
  if (gather_idx) {
    CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_tensor));
  }
  if (gather_idx_start_position) {
    CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_start_pos_tensor));
  }
}

}  // namespace ops
}  // namespace tmo