add ops
This commit is contained in:
89
torch_mlu_ops-v1.3.2/csrc/ops/GroupGemm.cpp
Normal file
89
torch_mlu_ops-v1.3.2/csrc/ops/GroupGemm.cpp
Normal file
@@ -0,0 +1,89 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include <cstddef>
|
||||
#include <vector>
|
||||
#include "kernel_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace ops {
|
||||
|
||||
// Asks CNNL how much workspace a group GEMM over `num_expert` groups needs and returns
// that size. The query is issued with A not transposed and B transposed, without alpha,
// beta or C descriptors. A failing CNNL status is handled by CNNL_CHECK_FATAL.
size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle, GroupGemmDesc &desc, const int num_expert) {
  size_t workspace_size = 0;
  CNNL_CHECK_FATAL(cnnlGetGroupGemmWorkspaceSize(
      handle,                   /* handle */
      desc.group_gemm_desc(),   /* cnnlGroupGemmDescriptor_t */
      nullptr,                  /* cnnlGroupGemmAlgo_t */
      num_expert,               /* groups */
      false,                    /* is_trans_a */
      true,                     /* is_trans_b */
      desc.group_host_tensor(), /* m_desc */
      desc.group_host_tensor(), /* n_desc */
      desc.group_host_tensor(), /* k_desc */
      nullptr,                  /* alpha_desc */
      desc.group_host_tensor(), /* lda_desc */
      desc.a_desc(),            /* a_desc */
      desc.group_host_tensor(), /* ldb_desc */
      desc.b_desc(),            /* b_desc */
      nullptr,                  /* beta_desc */
      nullptr,                  /* ldc_desc */
      nullptr,                  /* c_desc */
      desc.group_host_tensor(), /* ldd_desc */
      desc.d_desc(),            /* d_desc */
      &workspace_size));
  return workspace_size;
}
|
||||
|
||||
void GroupGemm(const cnnlHandle_t &handle,
|
||||
GroupGemmDesc &desc,
|
||||
void *m,
|
||||
void *alpha,
|
||||
void *beta,
|
||||
void *workspace,
|
||||
size_t workspace_size,
|
||||
int num_expert,
|
||||
int k,
|
||||
int n,
|
||||
int lda,
|
||||
std::vector<int> &ldb) {
|
||||
std::vector<int> n_array(num_expert, n);
|
||||
std::vector<int> k_array(num_expert, k);
|
||||
std::vector<int> lda_array(num_expert, lda);
|
||||
std::vector<int> ldd_array(num_expert, n);
|
||||
|
||||
CNNL_CHECK_FATAL(cnnlGroupGemm(handle, desc.group_gemm_desc(), /*desc*/
|
||||
nullptr, /*algo*/
|
||||
num_expert, /*groups*/
|
||||
false, /*is_trans_a*/
|
||||
true, /*is_trans_b*/
|
||||
desc.group_device_tensor(), m, /*m*/
|
||||
desc.group_host_tensor(), n_array.data(), /*n*/
|
||||
desc.group_host_tensor(), k_array.data(), /*k*/
|
||||
alpha ? desc.scale_factor_tensor() : nullptr, /*alpha_desc*/
|
||||
alpha, /*alpha*/
|
||||
desc.group_host_tensor(), lda_array.data(), /*lda*/
|
||||
desc.a_desc(), /*a*/
|
||||
ldb.empty() ? nullptr : desc.group_host_tensor(),
|
||||
ldb.empty() ? nullptr : ldb.data(), /*ldb*/
|
||||
desc.b_desc(), /*b*/
|
||||
beta ? desc.scale_factor_tensor() : nullptr, /*beta_desc*/
|
||||
beta, /*beta*/
|
||||
desc.group_host_tensor(), /*ldc_desc*/
|
||||
ldd_array.data(), /*ldc*/
|
||||
desc.has_c() ? desc.c_desc() : desc.d_desc(), /*c*/
|
||||
workspace, workspace_size, /*workspace*/
|
||||
desc.group_host_tensor(), ldd_array.data(), /*ldd*/
|
||||
desc.d_desc())); /*d*/
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tmo
|
||||
117
torch_mlu_ops-v1.3.2/csrc/ops/SmoothQuant.cpp
Normal file
117
torch_mlu_ops-v1.3.2/csrc/ops/SmoothQuant.cpp
Normal file
@@ -0,0 +1,117 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "kernel_api.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace ops {
|
||||
|
||||
void SmoothQuant(const cnnlHandle_t &handle,
|
||||
void *input,
|
||||
void *smooth_scale,
|
||||
void *token_count,
|
||||
void *gather_idx,
|
||||
void *gather_idx_start_position,
|
||||
void *output,
|
||||
void *output_scale,
|
||||
int n,
|
||||
int c,
|
||||
int e,
|
||||
int input_stride,
|
||||
int output_stride,
|
||||
int topk,
|
||||
cnnlDataType_t input_dtype) {
|
||||
int dim_nb = 2;
|
||||
int output_scale_dim_nb = 1;
|
||||
if (!gather_idx) {
|
||||
topk = 1;
|
||||
}
|
||||
int64_t input_dims[] = {n, c};
|
||||
int64_t output_dims[] = {n * topk, c};
|
||||
int64_t output_scale_dims[] = {n * topk};
|
||||
int64_t input_strides[] = {input_stride, 1};
|
||||
int64_t output_strides[] = {output_stride, 1};
|
||||
|
||||
cnnlTensorDescriptor_t input_tensor, smooth_scale_tensor, token_count_tensor, gather_idx_tensor,
|
||||
gather_idx_start_pos_tensor, output_tensor, output_scale_tensor;
|
||||
cnnlCreateTensorDescriptor(&input_tensor);
|
||||
cnnlCreateTensorDescriptor(&smooth_scale_tensor);
|
||||
cnnlCreateTensorDescriptor(&output_tensor);
|
||||
cnnlCreateTensorDescriptor(&output_scale_tensor);
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(input_tensor, CNNL_LAYOUT_ARRAY, input_dtype,
|
||||
dim_nb, input_dims, input_strides));
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(output_tensor, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8,
|
||||
dim_nb, output_dims, output_strides));
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(output_scale_tensor, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, output_scale_dim_nb,
|
||||
output_scale_dims));
|
||||
|
||||
if (token_count) {
|
||||
int64_t smooth_scale_dims[] = {e, c};
|
||||
int64_t token_count_dims[] = {e};
|
||||
cnnlCreateTensorDescriptor(&token_count_tensor);
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(token_count_tensor, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_INT32, 1, token_count_dims));
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, 2, smooth_scale_dims));
|
||||
} else {
|
||||
int64_t smooth_scale_dims[] = {c};
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, 1, smooth_scale_dims));
|
||||
}
|
||||
|
||||
if (gather_idx) {
|
||||
int64_t gather_idx_dims[] = {n * topk};
|
||||
cnnlCreateTensorDescriptor(&gather_idx_tensor);
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_tensor, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_INT32, 1, gather_idx_dims));
|
||||
}
|
||||
|
||||
if (gather_idx_start_position) {
|
||||
int64_t gather_idx_start_pos_dims[] = {1};
|
||||
cnnlCreateTensorDescriptor(&gather_idx_start_pos_tensor);
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_start_pos_tensor, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_INT32, 1, gather_idx_start_pos_dims));
|
||||
}
|
||||
|
||||
CNNL_CHECK_FATAL(cnnlSmoothQuantOnline_v4(
|
||||
handle, nullptr /*smooth_quant_online_desc*/, input_tensor /*input_desc*/,
|
||||
input /*input_ptr*/, smooth_scale_tensor /*input_smooth_quant_scale_desc*/,
|
||||
smooth_scale /*input_sq_scale_ptr*/, nullptr /*input_smooth_quant_zero_desc*/,
|
||||
nullptr /*input_sq_zero_ptr*/,
|
||||
token_count ? token_count_tensor /*token_count_desc*/ : nullptr,
|
||||
token_count /*token_count_ptr*/, gather_idx ? gather_idx_tensor /*gather_idx_desc*/ : nullptr,
|
||||
gather_idx /*gather_idx_ptr*/,
|
||||
gather_idx_start_position /*gather_idx_start_pos_desc*/ ? gather_idx_start_pos_tensor
|
||||
: nullptr,
|
||||
gather_idx_start_position /*gather_idx_start_pos_ptr*/, output_tensor /*output_desc*/,
|
||||
output /*output_ptr*/, output_scale_tensor /*output_scale_desc*/,
|
||||
output_scale /*output_scale_ptr*/, nullptr /*output_zero_desc*/, nullptr /*output_zero_ptr*/,
|
||||
nullptr /*worksapce*/, 0 /*workspace_size*/));
|
||||
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(input_tensor));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(smooth_scale_tensor));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_tensor));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_scale_tensor));
|
||||
if (token_count) {
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(token_count_tensor));
|
||||
}
|
||||
if (gather_idx) {
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_tensor));
|
||||
}
|
||||
if (gather_idx_start_position) {
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_start_pos_tensor));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tmo
|
||||
66
torch_mlu_ops-v1.3.2/csrc/ops/kernel_api.h
Normal file
66
torch_mlu_ops-v1.3.2/csrc/ops/kernel_api.h
Normal file
@@ -0,0 +1,66 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#ifndef CSRC_OPS_KERNEL_API_H_
|
||||
#define CSRC_OPS_KERNEL_API_H_
|
||||
|
||||
#include <map>
|
||||
#include <vector>
|
||||
#include "cnrt.h"
|
||||
|
||||
#include "op_descriptor/attn_proj_descriptor.h"
|
||||
#include "op_descriptor/batchmatmul_descriptor.h"
|
||||
#include "op_descriptor/ffn_descriptor.h"
|
||||
#include "op_descriptor/group_gemm_descriptor.h"
|
||||
#include "op_descriptor/matmul_descriptor.h"
|
||||
#include "op_descriptor/quant_matmul_descriptor.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace ops {
|
||||
|
||||
using GroupGemmDesc = tmo::op_desc::GroupGemmDesc;
|
||||
|
||||
// Returns the workspace size CNNL reports for a group GEMM described by `desc`
// with `num_expert` groups.
size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle,
                                 GroupGemmDesc &desc,
                                 const int num_expert);
|
||||
|
||||
void GroupGemm(const cnnlHandle_t &handle,
|
||||
GroupGemmDesc &desc,
|
||||
void *m,
|
||||
void *alpha,
|
||||
void *beta,
|
||||
void *workspace,
|
||||
size_t workspace_size,
|
||||
int num_expert,
|
||||
int k,
|
||||
int n,
|
||||
int lda,
|
||||
std::vector<int> &ldb);
|
||||
|
||||
void SmoothQuant(const cnnlHandle_t &handle,
|
||||
void *input,
|
||||
void *smooth_scale,
|
||||
void *token_count,
|
||||
void *gather_idx,
|
||||
void *gather_idx_start_position,
|
||||
void *output,
|
||||
void *output_scale,
|
||||
int n,
|
||||
int c,
|
||||
int e,
|
||||
int input_stride,
|
||||
int output_stride,
|
||||
int topk,
|
||||
cnnlDataType_t input_dtype);
|
||||
|
||||
} // namespace ops
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_OPS_KERNEL_API_H_
|
||||
@@ -0,0 +1,58 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "attn_proj_descriptor.h"
|
||||
#include "base_descriptor.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
namespace {
|
||||
static OpDescPool<AttnProjDescImpl> attn_proj_instance;
|
||||
}
|
||||
|
||||
// Takes a descriptor implementation object from the file-local pool.
AttnProjDesc::AttnProjDesc() : impl_(attn_proj_instance.get()) {}
|
||||
|
||||
void AttnProjDesc::setDesc(
|
||||
const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode,
|
||||
const cnnlDataType_t compute_dtype,
|
||||
const bool q_has_value,
|
||||
const bool k_has_value,
|
||||
const bool v_has_value,
|
||||
const bool has_bias,
|
||||
const bool is_pack_mode,
|
||||
const int packed_max_seq_len,
|
||||
const bool trans_out,
|
||||
const bool store_layernorm_result,
|
||||
const float alpha,
|
||||
const float beta,
|
||||
const float layernorm_eps,
|
||||
const bool is_quant) {
|
||||
CNNL_CHECK_FATAL(cnnlSetTransformerAttnProjDescriptor(
|
||||
impl_->attn_proj_desc_, layernorm_residual_mode, nullptr, /*activation desc*/
|
||||
compute_dtype, q_has_value, k_has_value, v_has_value, has_bias, is_pack_mode,
|
||||
packed_max_seq_len, trans_out, store_layernorm_result, alpha, beta, layernorm_eps));
|
||||
const int allow_tf32 = 0; /*Not allow*/
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlSetTransformerAttnProjDescriptorAllowTF32(impl_->attn_proj_desc_, allow_tf32));
|
||||
this->is_quant_ = is_quant;
|
||||
}
|
||||
|
||||
// Creates the two CNNL descriptors: the attention-projection descriptor and
// its quantization descriptor.
AttnProjDescImpl::AttnProjDescImpl() {
  CNNL_CHECK_FATAL(cnnlCreateTransformerAttnProjDescriptor(&this->attn_proj_desc_));
  CNNL_CHECK_FATAL(cnnlCreateTransformerAttnProjQuantifyDescriptor(&this->attn_proj_quant_desc_));
}
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,82 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#ifndef CSRC_OPS_OP_DESCRIPTOR_ATTN_PROJ_DESCRIPTOR_H_
|
||||
#define CSRC_OPS_OP_DESCRIPTOR_ATTN_PROJ_DESCRIPTOR_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include "cnnl.h"
|
||||
#include "cnnl_extra.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
class TMO_HIDDEN AttnProjDescImpl {
|
||||
public:
|
||||
AttnProjDescImpl();
|
||||
// using clear function to avoid compilation errors.
|
||||
~AttnProjDescImpl() { clear(); }
|
||||
|
||||
DELETE_COPY_ASSIGN_CONSTRUCT(AttnProjDescImpl);
|
||||
|
||||
cnnlTransformerAttnProjDescriptor_t attn_proj_desc_;
|
||||
cnnlTransformerAttnProjQuantizeDescriptor_t attn_proj_quant_desc_;
|
||||
|
||||
private:
|
||||
inline void clear() {
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTransformerAttnProjDescriptor(this->attn_proj_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTransformerAttnProjQuantifyDescriptor(this->attn_proj_quant_desc_));
|
||||
}
|
||||
};
|
||||
|
||||
class TMO_EXPORT AttnProjDesc {
|
||||
public:
|
||||
using impl_type_ = std::unique_ptr<AttnProjDescImpl, std::function<void(AttnProjDescImpl *)>>;
|
||||
AttnProjDesc();
|
||||
void setDesc(const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode,
|
||||
const cnnlDataType_t compute_dtype,
|
||||
const bool q_has_value,
|
||||
const bool k_has_value,
|
||||
const bool v_has_value,
|
||||
const bool has_bias,
|
||||
const bool is_pack_mode,
|
||||
const int packed_max_seq_len,
|
||||
const bool trans_out,
|
||||
const bool store_layernorm_result,
|
||||
const float alpha,
|
||||
const float beta,
|
||||
const float layernorm_eps,
|
||||
const bool is_quant = false);
|
||||
|
||||
operator cnnlTransformerAttnProjDescriptor_t() { return this->impl_->attn_proj_desc_; }
|
||||
|
||||
operator cnnlTransformerAttnProjDescriptor_t() const { return this->impl_->attn_proj_desc_; }
|
||||
|
||||
operator cnnlTransformerAttnProjQuantizeDescriptor_t() {
|
||||
return is_quant_ ? this->impl_->attn_proj_quant_desc_ : nullptr;
|
||||
}
|
||||
|
||||
operator cnnlTransformerAttnProjQuantizeDescriptor_t() const {
|
||||
return is_quant_ ? this->impl_->attn_proj_quant_desc_ : nullptr;
|
||||
}
|
||||
|
||||
private:
|
||||
impl_type_ impl_;
|
||||
bool is_quant_ = false;
|
||||
};
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_OPS_OP_DESCRIPTOR_ATTN_PROJ_DESCRIPTOR_H_
|
||||
@@ -0,0 +1,62 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_OPS_OP_DESCRIPTOR_BASE_DESCRIPTOR_H_
|
||||
#define CSRC_OPS_OP_DESCRIPTOR_BASE_DESCRIPTOR_H_
|
||||
|
||||
#include <deque>
|
||||
#include <mutex>
|
||||
#include "common/utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
static constexpr int OpCreateStage = 0;
|
||||
|
||||
// Add BaseDescPool is only for store different op
|
||||
// desc pool in a specific container.
|
||||
template <typename T>
|
||||
class TMO_HIDDEN OpDescPool {
|
||||
public:
|
||||
using desc_unique_ptr = std::unique_ptr<T>;
|
||||
OpDescPool() { createMultiItems(); }
|
||||
|
||||
// delete copy construct and assign construct.
|
||||
DELETE_COPY_ASSIGN_CONSTRUCT(OpDescPool);
|
||||
|
||||
desc_unique_ptr get() {
|
||||
auto ptr = std::make_unique<T>();
|
||||
return desc_unique_ptr(ptr.release());
|
||||
}
|
||||
|
||||
~OpDescPool() {
|
||||
std::lock_guard<std::mutex> lock(this->mutex_);
|
||||
for (size_t i = 0; i < this->container_.size(); ++i) {
|
||||
this->container_[i].release();
|
||||
}
|
||||
this->container_.clear();
|
||||
}
|
||||
|
||||
private:
|
||||
void createMultiItems() {
|
||||
for (int i = 0; i < OpCreateStage; ++i) {
|
||||
container_.emplace_back(std::make_unique<T>());
|
||||
}
|
||||
}
|
||||
// variable
|
||||
std::mutex mutex_;
|
||||
std::deque<std::unique_ptr<T>> container_;
|
||||
};
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_OPS_OP_DESCRIPTOR_BASE_DESCRIPTOR_H_
|
||||
@@ -0,0 +1,52 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "batchmatmul_descriptor.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
namespace {
|
||||
static OpDescPool<BatchMatMulDescImpl> bmm_instance;
|
||||
}
|
||||
|
||||
// Takes an implementation object from the pool, hands its heuristic-result and algo
// handles back through the output references, and sets the transpose / TF32 / beta
// attributes on the matmul descriptor. TF32 is always disabled.
BatchMatMulDesc::BatchMatMulDesc(cnnlMatMulHeuristicResult_t &heuristic_result,
                                 cnnlMatMulAlgo_t &algo,
                                 bool use_beta,
                                 bool trans_a,
                                 bool trans_b)
    : impl_(bmm_instance.get()) {
  auto bmm_desc_ = impl_->bmm_desc_;
  heuristic_result = impl_->heuristic_result_;
  algo = impl_->algo_;

  // Attributes are passed to CNNL as int32_t values by address.
  int32_t matmul_trans_a = static_cast<int32_t>(trans_a);
  int32_t matmul_trans_b = static_cast<int32_t>(trans_b);
  int32_t matmul_use_tf32 = 0;
  int32_t matmul_use_beta = static_cast<int32_t>(use_beta);

  CNNL_CHECK_FATAL(cnnlSetMatMulDescAttr(bmm_desc_, CNNL_MATMUL_DESC_TRANSA, &(matmul_trans_a),
                                         sizeof(int32_t)));
  CNNL_CHECK_FATAL(cnnlSetMatMulDescAttr(bmm_desc_, CNNL_MATMUL_DESC_TRANSB, &(matmul_trans_b),
                                         sizeof(int32_t)));
  CNNL_CHECK_FATAL(cnnlSetMatMulDescAttr(bmm_desc_, CNNL_MATMUL_ALLOW_TF32, &(matmul_use_tf32),
                                         sizeof(int32_t)));
  CNNL_CHECK_FATAL(
      cnnlSetMatMulDescAttr(bmm_desc_, CNNL_MATMUL_USE_BETA, &(matmul_use_beta), sizeof(int32_t)));
}
|
||||
|
||||
// Creates the three CNNL objects: matmul descriptor, heuristic result and algo.
BatchMatMulDescImpl::BatchMatMulDescImpl() {
  CNNL_CHECK_FATAL(cnnlCreateMatMulDescriptor(&this->bmm_desc_));
  CNNL_CHECK_FATAL(cnnlCreateMatMulHeuristicResult(&this->heuristic_result_));
  CNNL_CHECK_FATAL(cnnlCreateMatMulAlgo(&this->algo_));
}
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,64 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#ifndef CSRC_OPS_OP_DESCRIPTOR_BATCHMATMUL_DESCRIPTOR_H_
|
||||
#define CSRC_OPS_OP_DESCRIPTOR_BATCHMATMUL_DESCRIPTOR_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include "base_descriptor.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
class TMO_HIDDEN BatchMatMulDescImpl {
|
||||
public:
|
||||
BatchMatMulDescImpl();
|
||||
// using clear function to avoid compilation errors.
|
||||
~BatchMatMulDescImpl() { clear(); }
|
||||
DELETE_COPY_ASSIGN_CONSTRUCT(BatchMatMulDescImpl);
|
||||
|
||||
cnnlMatMulDescriptor_t bmm_desc_;
|
||||
cnnlMatMulHeuristicResult_t heuristic_result_;
|
||||
cnnlMatMulAlgo_t algo_;
|
||||
|
||||
private:
|
||||
inline void clear() {
|
||||
CNNL_CHECK_FATAL(cnnlDestroyMatMulDescriptor(this->bmm_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyMatMulHeuristicResult(this->heuristic_result_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyMatMulAlgo(this->algo_));
|
||||
}
|
||||
};
|
||||
|
||||
class TMO_EXPORT BatchMatMulDesc {
|
||||
public:
|
||||
using impl_type_ =
|
||||
std::unique_ptr<BatchMatMulDescImpl, std::function<void(BatchMatMulDescImpl *)>>;
|
||||
BatchMatMulDesc(cnnlMatMulHeuristicResult_t &heuristic_result,
|
||||
cnnlMatMulAlgo_t &algo,
|
||||
bool use_beta = false,
|
||||
bool trans_a = false,
|
||||
bool trans_b = true);
|
||||
|
||||
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlMatMulDescriptor_t, impl_->bmm_desc_)
|
||||
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlMatMulHeuristicResult_t, impl_->heuristic_result_)
|
||||
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlMatMulAlgo_t, impl_->algo_)
|
||||
|
||||
private:
|
||||
impl_type_ impl_;
|
||||
};
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_OPS_OP_DESCRIPTOR_BATCHMATMUL_DESCRIPTOR_H_
|
||||
@@ -0,0 +1,67 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "ffn_descriptor.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
namespace {
|
||||
static OpDescPool<FeedForwardDescImpl> instance;
|
||||
} // end of anonymous namespace
|
||||
|
||||
// Takes an implementation object from the pool; the descriptor is left unconfigured
// and quantization is off.
FeedForwardDesc::FeedForwardDesc() : impl_(instance.get()), is_quanti_(false) {}
|
||||
|
||||
// Takes an implementation object from the pool and configures it immediately by
// forwarding every argument to setFeedForwardDesc(). Quantization is off.
FeedForwardDesc::FeedForwardDesc(
    const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode,
    const std::string &act_name,
    cnnlDataType_t compute_type,
    float eps,
    float alpha,
    float beta,
    float act_coef)
    : impl_(instance.get()), is_quanti_(false) {
  setFeedForwardDesc(layernorm_residual_mode, act_name, compute_type, eps, alpha, beta, act_coef);
}
|
||||
|
||||
void FeedForwardDesc::setFeedForwardDesc(
|
||||
const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode,
|
||||
const std::string &act_name,
|
||||
cnnlDataType_t compute_type,
|
||||
float eps,
|
||||
float alpha,
|
||||
float beta,
|
||||
float act_coef) {
|
||||
// Now only using bfloat16 and half.
|
||||
// alpha must be 1.0f when NO_RESIDUAL mode.
|
||||
CNNL_CHECK_FATAL(cnnlSetTransformerFeedForwardDescriptor_v2(
|
||||
impl_->ffn_desc_, eps, alpha, beta, compute_type, layernorm_residual_mode));
|
||||
act_coef = act_name == "silu" ? 1.0f : act_coef;
|
||||
cnnlActivationMode_t act_mode = strToActivationMode(act_name);
|
||||
CNNL_CHECK_FATAL(cnnlSetActivationDescriptor_v5(impl_->act_desc_, act_mode, CNNL_ACTIVATION_FAST,
|
||||
CNNL_NOT_PROPAGATE_NAN, act_coef,
|
||||
// following parameters are not used
|
||||
0, 0, 0, 0));
|
||||
}
|
||||
|
||||
// Creates the three CNNL descriptors: feed-forward, feed-forward quantize and activation.
FeedForwardDescImpl::FeedForwardDescImpl() {
  CNNL_CHECK_FATAL(cnnlCreateTransformerFeedForwardDescriptor(&this->ffn_desc_));
  CNNL_CHECK_FATAL(cnnlCreateTransformerFeedForwardQuantizeDescriptor(&this->quant_desc_));
  CNNL_CHECK_FATAL(cnnlCreateActivationDescriptor(&this->act_desc_));
}
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
102
torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/ffn_descriptor.h
Normal file
102
torch_mlu_ops-v1.3.2/csrc/ops/op_descriptor/ffn_descriptor.h
Normal file
@@ -0,0 +1,102 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_OPS_OP_DESCRIPTOR_FFN_DESCRIPTOR_H_
|
||||
#define CSRC_OPS_OP_DESCRIPTOR_FFN_DESCRIPTOR_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include "base_descriptor.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
/**
 * Note [FeedForwardDesc]
 * ~~~~~~~~~~~~~~~~~~~~~~
 * FeedForwardDesc is a class that creates a Transformer FeedForward op descriptor.
 *
 * ffn_inner_size is the intermediate_size;
 * act_name names the activation type and is converted to the activation mode when the
 * op descriptor is set;
 * act_coef is only used when the activation is CNNL_ACTIVATION_CLIPPED_RELU,
 * CNNL_ACTIVATION_ELU, CNNL_ACTIVATION_ELU_V2, CNNL_ACTIVATION_LEAKYRELU,
 * CNNL_ACTIVATION_TF_LEAKYRELU, CNNL_ACTIVATION_CAFFE_RELU6.
 *
 */
|
||||
|
||||
struct TMO_HIDDEN FeedForwardDescImpl {
|
||||
FeedForwardDescImpl();
|
||||
// using clear function to avoid compilation errors.
|
||||
~FeedForwardDescImpl() { clear(); }
|
||||
DELETE_COPY_ASSIGN_CONSTRUCT(FeedForwardDescImpl);
|
||||
// variable
|
||||
cnnlTransformerFeedForwardDescriptor_t ffn_desc_;
|
||||
cnnlTransformerFeedForwardQuantizeDescriptor_t quant_desc_;
|
||||
cnnlActivationDescriptor_t act_desc_;
|
||||
|
||||
private:
|
||||
inline void clear() {
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTransformerFeedForwardDescriptor(this->ffn_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTransformerFeedForwardQuantizeDescriptor(this->quant_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyActivationDescriptor(this->act_desc_));
|
||||
}
|
||||
};
|
||||
|
||||
class TMO_EXPORT FeedForwardDesc {
|
||||
public:
|
||||
using impl_type =
|
||||
std::unique_ptr<FeedForwardDescImpl, std::function<void(FeedForwardDescImpl *)>>;
|
||||
FeedForwardDesc();
|
||||
FeedForwardDesc(const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode,
|
||||
const std::string &act_name,
|
||||
cnnlDataType_t compute_type,
|
||||
float eps,
|
||||
float alpha,
|
||||
float beta,
|
||||
float act_coef = 0);
|
||||
|
||||
void setFeedForwardDesc(const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode,
|
||||
const std::string &act_name,
|
||||
cnnlDataType_t compute_type,
|
||||
float eps,
|
||||
float alpha,
|
||||
float beta,
|
||||
float act_coef = 0);
|
||||
|
||||
// TODO(SG): quanti op desc is not support now, will support later.
|
||||
// void setQuantiDesc();
|
||||
|
||||
// TODO(SG): using std::string to pass layer_norm type? It's not clear now.
|
||||
// Add layer_norm fused later.
|
||||
// void setLayerNormInfo();
|
||||
|
||||
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlTransformerFeedForwardDescriptor_t, impl_->ffn_desc_);
|
||||
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlActivationDescriptor_t, impl_->act_desc_);
|
||||
|
||||
operator cnnlTransformerFeedForwardQuantizeDescriptor_t() const {
|
||||
return is_quanti_ == false
|
||||
? nullptr
|
||||
: const_cast<cnnlTransformerFeedForwardQuantizeDescriptor_t>(impl_->quant_desc_);
|
||||
}
|
||||
operator cnnlTransformerFeedForwardQuantizeDescriptor_t() {
|
||||
return is_quanti_ == false ? nullptr : impl_->quant_desc_;
|
||||
}
|
||||
|
||||
private:
|
||||
impl_type impl_;
|
||||
bool is_quanti_{false};
|
||||
};
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_OPS_OP_DESCRIPTOR_FFN_DESCRIPTOR_H_
|
||||
@@ -0,0 +1,178 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "group_gemm_descriptor.h"
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
namespace {
|
||||
static OpDescPool<GroupGemmDescImpl> instance;
|
||||
} // end of anonymous namespace
|
||||
|
||||
// Obtains an implementation object from the file-local pool, records the
// addressing mode (idx_mode) and quantization mode, then forwards the problem
// shape to set().
GroupGemmDesc::GroupGemmDesc(int num_expert,
                             int max_m,
                             int n,
                             int k,
                             cnnlDataType_t dtype,
                             bool idx_mode,
                             QuantMode quant_mode) {
  impl_ = instance.get();
  idx_mode_ = idx_mode;
  quant_mode_ = quant_mode;
  this->set(num_expert, max_m, n, k, dtype);
}
|
||||
|
||||
void GroupGemmDesc::set(int num_expert, int max_m, int n, int k, cnnlDataType_t dtype) {
|
||||
this->num_expert_ = num_expert;
|
||||
this->max_m_ = max_m;
|
||||
this->n_ = n;
|
||||
this->k_ = k;
|
||||
this->dtype_ = dtype;
|
||||
// no need to set compute_dtype, default is CNNL_DTYPE_FLOAT
|
||||
CNNL_CHECK_FATAL(cnnlSetGroupGemmDescAttr(impl_->group_gemm_desc_, CNNL_MATMUL_DESC_GROUP_M_MAX,
|
||||
&max_m, sizeof(int32_t)));
|
||||
CNNL_CHECK_FATAL(cnnlSetGroupGemmDescAttr(impl_->group_gemm_desc_, CNNL_MATMUL_DESC_GROUP_N_MAX,
|
||||
&n, sizeof(int32_t)));
|
||||
CNNL_CHECK_FATAL(cnnlSetGroupGemmDescAttr(impl_->group_gemm_desc_, CNNL_MATMUL_DESC_GROUP_K_MAX,
|
||||
&k, sizeof(int32_t)));
|
||||
int dim_nb = 1;
|
||||
int64_t dims[] = {num_expert};
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(impl_->group_host_tensor_, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_INT32, dim_nb, dims));
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(impl_->group_device_tensor_, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_INT32, dim_nb, dims));
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(impl_->scale_factor_tensor_, CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, dim_nb, dims));
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlSetTensorDescriptorPointerMode(impl_->group_host_tensor_, CNNL_POINTER_MODE_HOST));
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlSetTensorDescriptorPointerMode(impl_->group_device_tensor_, CNNL_POINTER_MODE_DEVICE));
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlSetTensorDescriptorPointerMode(impl_->scale_factor_tensor_, CNNL_POINTER_MODE_DEVICE));
|
||||
}
|
||||
|
||||
// Convenience overload: scales only. Forwards to the full overload with no
// quant flag, no bias, no activation, and zero for channels / block size /
// channel position.
void GroupGemmDesc::setPerRowColScaleBiasAct(void *a_scale, void *b_scale) {
  void *const no_quant_flag = nullptr;
  void *const no_bias = nullptr;
  const float unused_coef = 0.f;
  const bool no_activation = false;
  setPerRowColScaleBiasAct(a_scale, b_scale, no_quant_flag, no_bias, CNNL_DTYPE_INVALID,
                           CNNL_ACTIVATION_IDENTITY, CNNL_PROPAGATE_NAN, unused_coef,
                           no_activation, 0, 0, 0);
}
|
||||
|
||||
void GroupGemmDesc::setPerRowColScaleBiasAct(void *a_scale,
|
||||
void *b_scale,
|
||||
void *quant_flag,
|
||||
void *bias,
|
||||
cnnlDataType_t bias_dtype,
|
||||
int channels,
|
||||
int blk,
|
||||
int pos) {
|
||||
setPerRowColScaleBiasAct(a_scale, b_scale, quant_flag, bias, bias_dtype, CNNL_ACTIVATION_IDENTITY,
|
||||
CNNL_PROPAGATE_NAN, 0.f, false, channels, blk, pos);
|
||||
}
|
||||
|
||||
void GroupGemmDesc::setPerRowColScaleBiasAct(void *a_scale,
|
||||
void *b_scale,
|
||||
void *quant_flag,
|
||||
void *bias,
|
||||
cnnlDataType_t bias_dtype,
|
||||
cnnlActivationMode_t mode,
|
||||
cnnlNanPropagation_t nan_prop,
|
||||
float coef,
|
||||
bool has_active,
|
||||
int channels,
|
||||
int scale_block_size,
|
||||
int channel_pos) {
|
||||
bool is_a_quant = a_scale != nullptr;
|
||||
bool use_b_scale = false;
|
||||
bool quant_grouped = (scale_block_size > 0 && channels / scale_block_size > 0) ? true : false;
|
||||
if (a_scale) {
|
||||
CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorDescriptor(a_scale_desc(), CNNL_DTYPE_FLOAT, a_scale,
|
||||
CNNL_POINTER_MODE_DEVICE, nullptr));
|
||||
}
|
||||
if (!quant_grouped && b_scale) {
|
||||
use_b_scale = true;
|
||||
CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorDescriptor(b_scale_desc(), CNNL_DTYPE_FLOAT, b_scale,
|
||||
CNNL_POINTER_MODE_DEVICE, nullptr));
|
||||
}
|
||||
if (bias) {
|
||||
this->has_bias_ = true;
|
||||
CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorDescriptor(bias_desc(), bias_dtype, bias,
|
||||
CNNL_POINTER_MODE_DEVICE, nullptr));
|
||||
}
|
||||
this->has_active_ = has_active;
|
||||
if (has_active) {
|
||||
CNNL_CHECK_FATAL(cnnlSetActivationDescriptor(active_desc(), mode, nan_prop, coef));
|
||||
}
|
||||
|
||||
if (quant_grouped) {
|
||||
const int dim_nb = 2;
|
||||
int group_num = channels / scale_block_size;
|
||||
int64_t dims[dim_nb] = {group_num, num_expert_ * n_};
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(b_scale_quant_grouped_desc(), CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_FLOAT, dim_nb, dims));
|
||||
|
||||
CNNL_CHECK_FATAL(cnnlSetGroupGemmGroupwiseScale(
|
||||
group_gemm_desc(), CNNL_MATMUL_DESC_B_SCALE_POINTER, channels, scale_block_size,
|
||||
channel_pos, b_scale_quant_grouped_desc(), b_scale));
|
||||
if (quant_flag) {
|
||||
const int dim_qf = 2;
|
||||
int64_t dims[dim_qf] = {num_expert_, group_num};
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlSetTensorDescriptorPointerMode(quant_flag_desc(), CNNL_POINTER_MODE_HOST));
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(quant_flag_desc(), CNNL_LAYOUT_ARRAY,
|
||||
CNNL_DTYPE_INT32, dim_qf, dims));
|
||||
CNNL_CHECK_FATAL(cnnlSetGroupGemmGroupMixedQuantBitFlag(
|
||||
group_gemm_desc(), CNNL_MATMUL_DESC_B_QUANT_FLAG_POINTER, channels, scale_block_size,
|
||||
quant_flag_desc(), quant_flag));
|
||||
}
|
||||
}
|
||||
|
||||
CNNL_CHECK_FATAL(cnnlSetGroupGemmPerRowColScaleBiasAct(
|
||||
group_gemm_desc(), is_a_quant ? a_scale_desc() : nullptr,
|
||||
use_b_scale ? b_scale_desc() : nullptr, this->has_bias_ ? bias_desc() : nullptr,
|
||||
has_active ? active_desc() : nullptr));
|
||||
}
|
||||
|
||||
// Binds the A/B/C/D operand pointers and data types to their group-gemm
// tensor descriptors.
//
//   a_dtype, b_dtype, d_dtype  data types of A, B and D; C also uses d_dtype.
//   idx_dtype, idx, k, k_stride, total_m
//                              used only in idx mode, where A is described
//                              through the gather-index descriptor.
//   a, b, c, d                 operand pointers; c is bound only when has_c.
//   has_c                      recorded on the object and gates the C binding.
//   b_offset                   when non-null, B is bound with host pointer mode
//                              and these offsets; otherwise device pointer mode.
void GroupGemmDesc::setInputOutputTensor(cnnlDataType_t a_dtype,
                                         cnnlDataType_t b_dtype,
                                         cnnlDataType_t d_dtype,
                                         cnnlDataType_t idx_dtype,
                                         void *a,
                                         void *b,
                                         void *c,
                                         void *d,
                                         void *idx,
                                         int64_t k,
                                         int64_t k_stride,
                                         int64_t total_m,
                                         bool has_c,
                                         int64_t *b_offset) {
  this->has_c_ = has_c;

  if (idx_mode()) {
    // The status of this call was previously dropped; check it like every
    // other CNNL call in this file so a bad gather setup fails here.
    CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorGatherIdxDesc(a_desc(), a_dtype, idx_dtype, a, idx, k,
                                                         k_stride, total_m));
  } else {
    CNNL_CHECK_FATAL(
        cnnlSetGroupGemmTensorDescriptor(a_desc(), a_dtype, a, CNNL_POINTER_MODE_DEVICE, nullptr));
  }

  auto b_ptr_mode = b_offset ? CNNL_POINTER_MODE_HOST : CNNL_POINTER_MODE_DEVICE;
  CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorDescriptor(b_desc(), b_dtype, b, b_ptr_mode, b_offset));

  if (has_c) {
    CNNL_CHECK_FATAL(
        cnnlSetGroupGemmTensorDescriptor(c_desc(), d_dtype, c, CNNL_POINTER_MODE_DEVICE, nullptr));
  }
  CNNL_CHECK_FATAL(
      cnnlSetGroupGemmTensorDescriptor(d_desc(), d_dtype, d, CNNL_POINTER_MODE_DEVICE, nullptr));
}
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,177 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
#ifndef CSRC_OPS_OP_DESCRIPTOR_GROUP_GEMM_DESCRIPTOR_H_
|
||||
#define CSRC_OPS_OP_DESCRIPTOR_GROUP_GEMM_DESCRIPTOR_H_
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include "base_descriptor.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
struct TMO_HIDDEN GroupGemmDescImpl {
|
||||
GroupGemmDescImpl() {
|
||||
CNNL_CHECK_FATAL(cnnlCreateGroupGemmDescriptor(&group_gemm_desc_));
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlCreateGroupGemmTensorDescriptor(&a_idx_desc_, CNNL_GROUP_GEMM_ADDRESSING_GATHER_IDX));
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlCreateGroupGemmTensorDescriptor(&a_offset_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlCreateGroupGemmTensorDescriptor(&b_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlCreateGroupGemmTensorDescriptor(&c_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlCreateGroupGemmTensorDescriptor(&d_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlCreateGroupGemmTensorDescriptor(&a_scale_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlCreateGroupGemmTensorDescriptor(&b_scale_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlCreateGroupGemmTensorDescriptor(&bias_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
|
||||
CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&b_scale_quant_grouped_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&group_host_tensor_));
|
||||
CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&group_device_tensor_));
|
||||
CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&scale_factor_tensor_));
|
||||
CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&quant_flag_tensor_));
|
||||
CNNL_CHECK_FATAL(cnnlCreateActivationDescriptor(&active_desc_));
|
||||
}
|
||||
// using clear function to avoid compilation errors.
|
||||
~GroupGemmDescImpl() { clear(); }
|
||||
DELETE_COPY_ASSIGN_CONSTRUCT(GroupGemmDescImpl);
|
||||
// variable
|
||||
cnnlGroupGemmDescriptor_t group_gemm_desc_;
|
||||
cnnlGroupGemmTensorDescriptor_t a_idx_desc_;
|
||||
cnnlGroupGemmTensorDescriptor_t a_offset_desc_;
|
||||
cnnlGroupGemmTensorDescriptor_t b_desc_;
|
||||
cnnlGroupGemmTensorDescriptor_t c_desc_;
|
||||
cnnlGroupGemmTensorDescriptor_t d_desc_;
|
||||
cnnlGroupGemmTensorDescriptor_t a_scale_desc_;
|
||||
cnnlGroupGemmTensorDescriptor_t b_scale_desc_;
|
||||
cnnlGroupGemmTensorDescriptor_t bias_desc_;
|
||||
cnnlTensorDescriptor_t b_scale_quant_grouped_desc_;
|
||||
cnnlTensorDescriptor_t group_host_tensor_;
|
||||
cnnlTensorDescriptor_t group_device_tensor_;
|
||||
cnnlTensorDescriptor_t scale_factor_tensor_;
|
||||
cnnlTensorDescriptor_t quant_flag_tensor_;
|
||||
cnnlActivationDescriptor_t active_desc_;
|
||||
|
||||
private:
|
||||
inline void clear() {
|
||||
CNNL_CHECK_FATAL(
|
||||
cnnlSetTensorDescriptorPointerMode(group_host_tensor_, CNNL_POINTER_MODE_DEVICE));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmDescriptor(group_gemm_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(a_idx_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(a_offset_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(b_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(c_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(d_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(a_scale_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(b_scale_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(bias_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(b_scale_quant_grouped_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(group_host_tensor_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(group_device_tensor_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(scale_factor_tensor_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(quant_flag_tensor_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyActivationDescriptor(active_desc_));
|
||||
}
|
||||
};
|
||||
|
||||
class TMO_EXPORT GroupGemmDesc {
|
||||
public:
|
||||
enum class QuantMode { noQuant, W8, W4, W4W8 };
|
||||
using impl_type = std::unique_ptr<GroupGemmDescImpl, std::function<void(GroupGemmDescImpl *)>>;
|
||||
GroupGemmDesc() = delete;
|
||||
GroupGemmDesc(int num_expert,
|
||||
int max_m,
|
||||
int n,
|
||||
int k,
|
||||
cnnlDataType_t dtype,
|
||||
bool idx_mode,
|
||||
QuantMode quant_mode);
|
||||
void set(int num_expert, int max_m, int n, int k, cnnlDataType_t dtype);
|
||||
void setPerRowColScaleBiasAct(void *a_scale, void *b_scale);
|
||||
void setPerRowColScaleBiasAct(void *a_scale,
|
||||
void *b_scale,
|
||||
void *quant_flag,
|
||||
void *bias,
|
||||
cnnlDataType_t bias_dtype,
|
||||
int channels,
|
||||
int blk,
|
||||
int pos);
|
||||
void setPerRowColScaleBiasAct(void *a_scale,
|
||||
void *b_scale,
|
||||
void *quant_flag, // for w4/w8 quantize
|
||||
void *bias, // nullptr
|
||||
cnnlDataType_t bias_dtype, // CNNL_DTYPE_INVALID
|
||||
cnnlActivationMode_t mode, // CNNL_ACTIVATION_SILU
|
||||
cnnlNanPropagation_t nan_prop, // CNNL_PROPAGATE_NAN
|
||||
float coef, // 0.f
|
||||
bool has_active, // false
|
||||
int channels,
|
||||
int scale_block_size,
|
||||
int channel_pos);
|
||||
void setInputOutputTensor(cnnlDataType_t a_dtype,
|
||||
cnnlDataType_t b_dtype,
|
||||
cnnlDataType_t d_dtype,
|
||||
cnnlDataType_t idx_dtype,
|
||||
void *a,
|
||||
void *b,
|
||||
void *c,
|
||||
void *d,
|
||||
void *idx,
|
||||
int64_t k,
|
||||
int64_t k_stride,
|
||||
int64_t total_m,
|
||||
bool has_c,
|
||||
int64_t *b_offset = nullptr);
|
||||
cnnlGroupGemmDescriptor_t group_gemm_desc() { return impl_->group_gemm_desc_; }
|
||||
cnnlGroupGemmTensorDescriptor_t a_desc() {
|
||||
return idx_mode_ ? impl_->a_idx_desc_ : impl_->a_offset_desc_;
|
||||
}
|
||||
cnnlGroupGemmTensorDescriptor_t b_desc() { return impl_->b_desc_; }
|
||||
cnnlGroupGemmTensorDescriptor_t c_desc() { return impl_->c_desc_; }
|
||||
cnnlGroupGemmTensorDescriptor_t d_desc() { return impl_->d_desc_; }
|
||||
cnnlGroupGemmTensorDescriptor_t a_scale_desc() { return impl_->a_scale_desc_; }
|
||||
cnnlGroupGemmTensorDescriptor_t b_scale_desc() { return impl_->b_scale_desc_; }
|
||||
cnnlGroupGemmTensorDescriptor_t bias_desc() { return impl_->bias_desc_; }
|
||||
cnnlTensorDescriptor_t b_scale_quant_grouped_desc() { return impl_->b_scale_quant_grouped_desc_; }
|
||||
cnnlTensorDescriptor_t group_host_tensor() { return impl_->group_host_tensor_; }
|
||||
cnnlTensorDescriptor_t group_device_tensor() { return impl_->group_device_tensor_; }
|
||||
cnnlTensorDescriptor_t scale_factor_tensor() { return impl_->scale_factor_tensor_; }
|
||||
cnnlTensorDescriptor_t quant_flag_desc() { return impl_->quant_flag_tensor_; }
|
||||
cnnlActivationDescriptor_t active_desc() { return impl_->active_desc_; }
|
||||
bool idx_mode() { return idx_mode_; }
|
||||
bool has_c() { return has_c_; }
|
||||
|
||||
private:
|
||||
impl_type impl_;
|
||||
bool idx_mode_ = false;
|
||||
bool has_c_ = false;
|
||||
bool has_bias_ = false;
|
||||
bool has_active_ = false;
|
||||
int num_expert_;
|
||||
int max_m_;
|
||||
int n_;
|
||||
int k_;
|
||||
cnnlDataType_t dtype_;
|
||||
QuantMode quant_mode_ = QuantMode::noQuant;
|
||||
};
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_OPS_OP_DESCRIPTOR_GROUP_GEMM_DESCRIPTOR_H_
|
||||
@@ -0,0 +1,104 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "matmul_descriptor.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
namespace {
|
||||
static OpDescPool<MatMulDescImpl> matmul_instance;
|
||||
}
|
||||
|
||||
// Takes an implementation object from the pool and configures its MatMulEx
// descriptor: transpose flags, TF32 (disabled), beta usage, leading
// dimensions, preferred workspace limit and the fused epilogue.
//
//   bias_tensor / bias_data  optional bias; bias_data == nullptr means no bias
//                            on the no-activation path.
//   workspace_size           output: receives the preferred maximum workspace size.
//   lda / ldb / ldc          leading dimensions written to the descriptor.
//   act_name                 "none" selects a plain or bias-only epilogue; any
//                            other name is mapped through strToActivationMode().
//   fast_act                 selects CNNL_ACTIVATION_FAST instead of HIGH_PRECISION.
//   approximate              forwarded to cnnlSetActivationDescriptor_v6.
MatMulDesc::MatMulDesc(const cnnlTensorDescriptor_t &bias_tensor,
                       const void *bias_data,
                       size_t &workspace_size,
                       int32_t lda,
                       int32_t ldb,
                       int32_t ldc,
                       const std::string &act_name,
                       bool use_beta,
                       bool trans_a,
                       bool trans_b,
                       bool fast_act,
                       bool approximate) {
  impl_ = matmul_instance.get();
  auto mm_desc = impl_->matmul_desc_;
  auto activation_desc = impl_->act_desc_;
  workspace_size = max_workspace_size_;

  // CNNL reads these attributes as 32-bit integers.
  int32_t flag_trans_a = static_cast<int32_t>(trans_a);
  int32_t flag_trans_b = static_cast<int32_t>(trans_b);
  int32_t flag_allow_tf32 = 0;
  int32_t flag_use_beta = static_cast<int32_t>(use_beta);

  CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(mm_desc, CNNL_MATMUL_EX_DESC_TRANSA, &flag_trans_a,
                                           sizeof(int32_t)));
  CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(mm_desc, CNNL_MATMUL_EX_DESC_TRANSB, &flag_trans_b,
                                           sizeof(int32_t)));
  CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(mm_desc, CNNL_MATMUL_EX_ALLOW_TF32, &flag_allow_tf32,
                                           sizeof(int32_t)));
  CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(mm_desc, CNNL_MATMUL_EX_USE_BETA, &flag_use_beta,
                                           sizeof(int32_t)));
  CNNL_CHECK_FATAL(
      cnnlSetMatMulExDescAttr(mm_desc, CNNL_MATMUL_EX_DESC_LDA, &lda, sizeof(int32_t)));
  CNNL_CHECK_FATAL(
      cnnlSetMatMulExDescAttr(mm_desc, CNNL_MATMUL_EX_DESC_LDB, &ldb, sizeof(int32_t)));
  CNNL_CHECK_FATAL(
      cnnlSetMatMulExDescAttr(mm_desc, CNNL_MATMUL_EX_DESC_LDC, &ldc, sizeof(int32_t)));
  CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(mm_desc, CNNL_MATMUL_EX_PREF_MAX_WORKSPACE_BYTES,
                                           &max_workspace_size_, sizeof(int64_t)));

  if (act_name == "none") {
    // No activation: the epilogue is either empty or bias-only.
    const bool with_bias = bias_data != nullptr;
    cnnlMatMulEpilogueType_t epilogue = with_bias ? CNNL_MATMUL_EPI_BIAS : CNNL_MATMUL_EPI_NONE;
    CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(
        mm_desc, cnnlMatMulExDescAttribute_t::CNNL_MATMUL_EX_DESC_EPILOGUE_TYPE, &epilogue,
        sizeof(epilogue)));
    if (with_bias) {
      CNNL_CHECK_FATAL(cnnlSetMatMulExBias(mm_desc, bias_tensor, const_cast<void *>(bias_data)));
    }
    return;
  }

  // Activation requested: "silu" pins the mode and uses coefficient 1.0;
  // every other name keeps the mapped mode with coefficient 0.0.
  cnnlActivationMode_t act_mode = strToActivationMode(act_name);
  float act_coef = 0.0f;
  if (act_name == "silu") {
    act_coef = 1.0f;
    act_mode = CNNL_ACTIVATION_SILU;
  }
  const cnnlActivationPreference_t act_pref =
      fast_act ? CNNL_ACTIVATION_FAST : CNNL_ACTIVATION_HIGH_PRECISION;
  CNNL_CHECK_FATAL(cnnlSetActivationDescriptor_v6(activation_desc, act_mode, act_pref,
                                                  CNNL_NOT_PROPAGATE_NAN, act_coef, 0, 0, 0, 0,
                                                  approximate));

  // Fused bias + scale + BN + activation epilogue; only bias and activation are supplied.
  cnnlMatMulEpilogueType_t epilogue = CNNL_MATMUL_EPI_BIAS_SCALE_BN_ACTIVATION;
  CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(
      mm_desc, cnnlMatMulExDescAttribute_t::CNNL_MATMUL_EX_DESC_EPILOGUE_TYPE, &epilogue,
      sizeof(epilogue)));
  CNNL_CHECK_FATAL(cnnlSetMatMulExBiasScaleBNActive(mm_desc, bias_tensor, bias_data, nullptr,
                                                    nullptr, nullptr, nullptr, nullptr, nullptr,
                                                    0, 0, 0, 0, 0, activation_desc));
}
|
||||
|
||||
// Creates the MatMulEx descriptor and the activation descriptor owned by this
// object; both are destroyed in clear().
MatMulDescImpl::MatMulDescImpl() {
  CNNL_CHECK_FATAL(cnnlCreateMatMulExDescriptor(&matmul_desc_));
  CNNL_CHECK_FATAL(cnnlCreateActivationDescriptor(&act_desc_));
}
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,68 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#ifndef CSRC_OPS_OP_DESCRIPTOR_MATMUL_DESCRIPTOR_H_
|
||||
#define CSRC_OPS_OP_DESCRIPTOR_MATMUL_DESCRIPTOR_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include "base_descriptor.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
class TMO_HIDDEN MatMulDescImpl {
|
||||
public:
|
||||
MatMulDescImpl();
|
||||
// using clear function to avoid compilation errors.
|
||||
~MatMulDescImpl() { clear(); }
|
||||
DELETE_COPY_ASSIGN_CONSTRUCT(MatMulDescImpl);
|
||||
|
||||
cnnlMatMulExDescriptor_t matmul_desc_;
|
||||
cnnlActivationDescriptor_t act_desc_;
|
||||
|
||||
private:
|
||||
inline void clear() {
|
||||
CNNL_CHECK_FATAL(cnnlDestroyMatMulExDescriptor(this->matmul_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyActivationDescriptor(this->act_desc_));
|
||||
}
|
||||
};
|
||||
|
||||
class TMO_EXPORT MatMulDesc {
|
||||
public:
|
||||
using impl_type_ = std::unique_ptr<MatMulDescImpl, std::function<void(MatMulDescImpl *)>>;
|
||||
MatMulDesc(const cnnlTensorDescriptor_t &bias_tensor,
|
||||
const void *bias_data,
|
||||
size_t &workspace_size,
|
||||
int32_t lda,
|
||||
int32_t ldb,
|
||||
int32_t ldc,
|
||||
const std::string &act_name = "none",
|
||||
bool use_beta = false,
|
||||
bool trans_a = false,
|
||||
bool trans_b = true,
|
||||
bool fast_act = true,
|
||||
bool approximate = true);
|
||||
|
||||
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlMatMulExDescriptor_t, impl_->matmul_desc_)
|
||||
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlActivationDescriptor_t, impl_->act_desc_)
|
||||
|
||||
private:
|
||||
size_t max_workspace_size_ = 128 * 1024 * 1024;
|
||||
impl_type_ impl_;
|
||||
};
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_OPS_OP_DESCRIPTOR_MATMUL_DESCRIPTOR_H_
|
||||
@@ -0,0 +1,79 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "quant_matmul_descriptor.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
namespace {
|
||||
static OpDescPool<QuantMatmulDescImpl> instance;
|
||||
}
|
||||
|
||||
// Takes an implementation object from the file-local pool and configures it
// immediately by forwarding every argument to setQuantMatmulDesc().
QuantMatmulDesc::QuantMatmulDesc(const std::string &quant_algo,
                                 const std::string &a_quant_layout,
                                 const std::string &b_quant_layout,
                                 unsigned quant_bit_size,
                                 cnnlDataType_t compute_dtype,
                                 const std::string &act_name,
                                 bool use_hp_active,
                                 float act_coef,
                                 bool trans_a,
                                 bool trans_b) {
  impl_ = instance.get();
  setQuantMatmulDesc(quant_algo, a_quant_layout, b_quant_layout, quant_bit_size, compute_dtype,
                     act_name, use_hp_active, act_coef, trans_a, trans_b);
}
|
||||
|
||||
void QuantMatmulDesc::setQuantMatmulDesc(const std::string &quant_algo,
|
||||
const std::string &a_quant_layout,
|
||||
const std::string &b_quant_layout,
|
||||
unsigned quant_bit_size,
|
||||
cnnlDataType_t compute_dtype,
|
||||
const std::string &act_name,
|
||||
bool use_hp_active,
|
||||
float act_coef,
|
||||
bool trans_a,
|
||||
bool trans_b) {
|
||||
cnnlQuantizeDescriptor_t no_quant;
|
||||
CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&no_quant));
|
||||
// set quantize algo and layout
|
||||
cnnlLLMQuantAlgo_t cnnl_quant_algo = strToQuantizeAlgo(quant_algo);
|
||||
CNNL_CHECK_FATAL(cnnlSetQuantizeDescriptorBehaviour(
|
||||
impl_->a_quant_desc_, strToQuantizeLayout(a_quant_layout), quant_bit_size, 0, false));
|
||||
CNNL_CHECK_FATAL(cnnlSetQuantizeDescriptorBehaviour(
|
||||
impl_->b_quant_desc_, strToQuantizeLayout(b_quant_layout), quant_bit_size, 0, false));
|
||||
// set activation
|
||||
act_coef = act_name == "silu" ? 1.0f : act_coef;
|
||||
cnnlActivationMode_t act_mode = strToActivationMode(act_name);
|
||||
CNNL_CHECK_FATAL(cnnlSetActivationDescriptor_v5(
|
||||
impl_->act_desc_, act_mode,
|
||||
use_hp_active ? CNNL_ACTIVATION_HIGH_PRECISION : CNNL_ACTIVATION_FAST, CNNL_NOT_PROPAGATE_NAN,
|
||||
act_coef, 0, 0, 0, 0));
|
||||
// set into op_desc
|
||||
CNNL_CHECK_FATAL(cnnlSetLLMQuantMatmulDescriptor(
|
||||
impl_->op_desc_, impl_->a_quant_desc_, impl_->b_quant_desc_, no_quant, no_quant,
|
||||
impl_->act_desc_, compute_dtype, cnnl_quant_algo, trans_a, trans_b));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(no_quant));
|
||||
}
|
||||
|
||||
// Creates the op descriptor, the four quantize descriptors (a/b/c/d) and the
// activation descriptor owned by this object; all are destroyed in clear().
QuantMatmulDescImpl::QuantMatmulDescImpl() {
  CNNL_CHECK_FATAL(cnnlCreateLLMQuantMatmulDescriptor(&op_desc_));
  CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&a_quant_desc_));
  CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&b_quant_desc_));
  CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&c_quant_desc_));
  CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&d_quant_desc_));
  CNNL_CHECK_FATAL(cnnlCreateActivationDescriptor(&act_desc_));
}
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
@@ -0,0 +1,86 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#ifndef CSRC_OPS_OP_DESCRIPTOR_QUANT_MATMUL_DESCRIPTOR_H_
|
||||
#define CSRC_OPS_OP_DESCRIPTOR_QUANT_MATMUL_DESCRIPTOR_H_
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include "base_descriptor.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace op_desc {
|
||||
|
||||
class TMO_HIDDEN QuantMatmulDescImpl {
|
||||
public:
|
||||
QuantMatmulDescImpl();
|
||||
// using clear function to avoid compilation errors.
|
||||
~QuantMatmulDescImpl() { clear(); }
|
||||
DELETE_COPY_ASSIGN_CONSTRUCT(QuantMatmulDescImpl);
|
||||
// variable
|
||||
cnnlLLMQuantMatmulDescriptor_t op_desc_;
|
||||
cnnlQuantizeDescriptor_t a_quant_desc_;
|
||||
cnnlQuantizeDescriptor_t b_quant_desc_;
|
||||
cnnlQuantizeDescriptor_t c_quant_desc_;
|
||||
cnnlQuantizeDescriptor_t d_quant_desc_;
|
||||
cnnlActivationDescriptor_t act_desc_;
|
||||
|
||||
private:
|
||||
inline void clear() {
|
||||
CNNL_CHECK_FATAL(cnnlDestroyLLMQuantMatmulDescriptor(this->op_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(this->a_quant_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(this->b_quant_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(this->c_quant_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(this->d_quant_desc_));
|
||||
CNNL_CHECK_FATAL(cnnlDestroyActivationDescriptor(this->act_desc_));
|
||||
}
|
||||
};
|
||||
|
||||
class TMO_EXPORT QuantMatmulDesc {
|
||||
public:
|
||||
using impl_type_ =
|
||||
std::unique_ptr<QuantMatmulDescImpl, std::function<void(QuantMatmulDescImpl *)>>;
|
||||
QuantMatmulDesc(const std::string &quant_algo,
|
||||
const std::string &a_quant_layout,
|
||||
const std::string &b_quant_layout,
|
||||
unsigned quant_bit_size,
|
||||
cnnlDataType_t compute_dtype,
|
||||
const std::string &act_name,
|
||||
bool use_hp_active,
|
||||
float act_coef,
|
||||
bool trans_a,
|
||||
bool trans_b);
|
||||
|
||||
// set descriptor
|
||||
void setQuantMatmulDesc(const std::string &quant_algo,
|
||||
const std::string &a_quant_layout,
|
||||
const std::string &b_quant_layout,
|
||||
unsigned quant_bit_size,
|
||||
cnnlDataType_t compute_dtype,
|
||||
const std::string &act_name,
|
||||
bool use_hp_active,
|
||||
float act_coef,
|
||||
bool trans_a,
|
||||
bool trans_b);
|
||||
|
||||
operator cnnlLLMQuantMatmulDescriptor_t() { return this->impl_->op_desc_; }
|
||||
operator cnnlLLMQuantMatmulDescriptor_t() const { return this->impl_->op_desc_; }
|
||||
|
||||
private:
|
||||
impl_type_ impl_;
|
||||
};
|
||||
|
||||
} // namespace op_desc
|
||||
} // namespace tmo
|
||||
|
||||
#endif // CSRC_OPS_OP_DESCRIPTOR_QUANT_MATMUL_DESCRIPTOR_H_
|
||||
Reference in New Issue
Block a user