forked from EngineX-Cambricon/enginex-mlu370-vllm
add ops
This commit is contained in:
117
torch_mlu_ops-v1.3.2/csrc/ops/SmoothQuant.cpp
Normal file
117
torch_mlu_ops-v1.3.2/csrc/ops/SmoothQuant.cpp
Normal file
@@ -0,0 +1,117 @@
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/

#include "kernel_api.h"
namespace tmo {
|
||||
namespace ops {
|
||||
|
||||
void SmoothQuant(const cnnlHandle_t &handle,
|
||||
void *input,
|
||||
void *smooth_scale,
|
||||
void *token_count,
|
||||
void *gather_idx,
|
||||
void *gather_idx_start_position,
|
||||
void *output,
|
||||
void *output_scale,
|
||||
int n,
|
||||
int c,
|
||||
int e,
|
||||
int input_stride,
|
||||
int output_stride,
|
||||
int topk,
|
||||
cnnlDataType_t input_dtype) {
|
||||
int dim_nb = 2;
|
||||
int output_scale_dim_nb = 1;
|
||||
if (!gather_idx) {
|
||||
topk = 1;
|
||||
}
|
||||
int64_t input_dims[] = {n, c};
|
||||
int64_t output_dims[] = {n * topk, c};
|
||||
int64_t output_scale_dims[] = {n * topk};
|
||||
int64_t input_strides[] = {input_stride, 1};
|
||||
int64_t output_strides[] = {output_stride, 1};
|
||||
|
||||
cnnlTensorDescriptor_t input_tensor, smooth_scale_tensor, token_count_tensor, gather_idx_tensor,
|
||||
gather_idx_start_pos_tensor, output_tensor, output_scale_tensor;
|
||||
cnnlCreateTensorDescriptor(&input_tensor);
|
||||
cnnlCreateTensorDescriptor(&smooth_scale_tensor);
|
||||
cnnlCreateTensorDescriptor(&output_tensor);
|
||||
cnnlCreateTensorDescriptor(&output_scale_tensor);
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(input_tensor, CNNL_LAYOUT_ARRAY, input_dtype,
|
||||
dim_nb, input_dims, input_strides));
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(output_tensor, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8,
|
||||
dim_nb, output_dims, output_strides));
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(output_scale_tensor, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, output_scale_dim_nb,
|
||||
output_scale_dims));
|
||||
|
||||
if (token_count) {
|
||||
int64_t smooth_scale_dims[] = {e, c};
|
||||
int64_t token_count_dims[] = {e};
|
||||
cnnlCreateTensorDescriptor(&token_count_tensor);
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(token_count_tensor, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_INT32, 1, token_count_dims));
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, 2, smooth_scale_dims));
|
||||
} else {
|
||||
int64_t smooth_scale_dims[] = {c};
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, 1, smooth_scale_dims));
|
||||
}
|
||||
|
||||
if (gather_idx) {
|
||||
int64_t gather_idx_dims[] = {n * topk};
|
||||
cnnlCreateTensorDescriptor(&gather_idx_tensor);
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_tensor, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_INT32, 1, gather_idx_dims));
|
||||
}
|
||||
|
||||
if (gather_idx_start_position) {
|
||||
int64_t gather_idx_start_pos_dims[] = {1};
|
||||
cnnlCreateTensorDescriptor(&gather_idx_start_pos_tensor);
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_start_pos_tensor, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_INT32, 1, gather_idx_start_pos_dims));
|
||||
}
|
||||
|
||||
CNNL_CHECK_FATAL(cnnlSmoothQuantOnline_v4(
|
||||
handle, nullptr /*smooth_quant_online_desc*/, input_tensor /*input_desc*/,
|
||||
input /*input_ptr*/, smooth_scale_tensor /*input_smooth_quant_scale_desc*/,
|
||||
smooth_scale /*input_sq_scale_ptr*/, nullptr /*input_smooth_quant_zero_desc*/,
|
||||
nullptr /*input_sq_zero_ptr*/,
|
||||
token_count ? token_count_tensor /*token_count_desc*/ : nullptr,
|
||||
token_count /*token_count_ptr*/, gather_idx ? gather_idx_tensor /*gather_idx_desc*/ : nullptr,
|
||||
gather_idx /*gather_idx_ptr*/,
|
||||
gather_idx_start_position /*gather_idx_start_pos_desc*/ ? gather_idx_start_pos_tensor
|
||||
: nullptr,
|
||||
gather_idx_start_position /*gather_idx_start_pos_ptr*/, output_tensor /*output_desc*/,
|
||||
output /*output_ptr*/, output_scale_tensor /*output_scale_desc*/,
|
||||
output_scale /*output_scale_ptr*/, nullptr /*output_zero_desc*/, nullptr /*output_zero_ptr*/,
|
||||
nullptr /*worksapce*/, 0 /*workspace_size*/));
|
||||
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(input_tensor));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(smooth_scale_tensor));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_tensor));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_scale_tensor));
|
||||
if (token_count) {
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(token_count_tensor));
|
||||
}
|
||||
if (gather_idx) {
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_tensor));
|
||||
}
|
||||
if (gather_idx_start_position) {
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_start_pos_tensor));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tmo