This commit is contained in:
Chranos
2026-02-04 17:39:32 +08:00
parent 8511fe8530
commit 79dfc69789
299 changed files with 55927 additions and 0 deletions

View File

@@ -0,0 +1,89 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cstddef>
#include <vector>
#include "kernel_api.h"
namespace tmo {
namespace ops {
size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle, GroupGemmDesc &desc, const int num_expert) {
size_t workspace_size = 0;
CNNL_CHECK_FATAL(
cnnlGetGroupGemmWorkspaceSize(handle, // handle
desc.group_gemm_desc(), // cnnlGroupGemmDescriptor_t
nullptr, // cnnlGroupGemmAlgo_t
num_expert, // groups
false, // is_trans_a
true, // is_trans_b
desc.group_host_tensor(), // m_desc
desc.group_host_tensor(), // n_desc
desc.group_host_tensor(), // k_desc
nullptr, // alpha_desc
desc.group_host_tensor(), // lda_desc
desc.a_desc(), // a_desc
desc.group_host_tensor(), // ldb_desc
desc.b_desc(), // b_desc
nullptr, // beta_desc
nullptr, // ldc_desc
nullptr, // c_desc
desc.group_host_tensor(), // ldd_desc
desc.d_desc(), // d_desc
&workspace_size));
return workspace_size;
}
void GroupGemm(const cnnlHandle_t &handle,
GroupGemmDesc &desc,
void *m,
void *alpha,
void *beta,
void *workspace,
size_t workspace_size,
int num_expert,
int k,
int n,
int lda,
std::vector<int> &ldb) {
std::vector<int> n_array(num_expert, n);
std::vector<int> k_array(num_expert, k);
std::vector<int> lda_array(num_expert, lda);
std::vector<int> ldd_array(num_expert, n);
CNNL_CHECK_FATAL(cnnlGroupGemm(handle, desc.group_gemm_desc(), /*desc*/
nullptr, /*algo*/
num_expert, /*groups*/
false, /*is_trans_a*/
true, /*is_trans_b*/
desc.group_device_tensor(), m, /*m*/
desc.group_host_tensor(), n_array.data(), /*n*/
desc.group_host_tensor(), k_array.data(), /*k*/
alpha ? desc.scale_factor_tensor() : nullptr, /*alpha_desc*/
alpha, /*alpha*/
desc.group_host_tensor(), lda_array.data(), /*lda*/
desc.a_desc(), /*a*/
ldb.empty() ? nullptr : desc.group_host_tensor(),
ldb.empty() ? nullptr : ldb.data(), /*ldb*/
desc.b_desc(), /*b*/
beta ? desc.scale_factor_tensor() : nullptr, /*beta_desc*/
beta, /*beta*/
desc.group_host_tensor(), /*ldc_desc*/
ldd_array.data(), /*ldc*/
desc.has_c() ? desc.c_desc() : desc.d_desc(), /*c*/
workspace, workspace_size, /*workspace*/
desc.group_host_tensor(), ldd_array.data(), /*ldd*/
desc.d_desc())); /*d*/
}
} // namespace ops
} // namespace tmo

View File

@@ -0,0 +1,117 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "kernel_api.h"
namespace tmo {
namespace ops {
void SmoothQuant(const cnnlHandle_t &handle,
void *input,
void *smooth_scale,
void *token_count,
void *gather_idx,
void *gather_idx_start_position,
void *output,
void *output_scale,
int n,
int c,
int e,
int input_stride,
int output_stride,
int topk,
cnnlDataType_t input_dtype) {
int dim_nb = 2;
int output_scale_dim_nb = 1;
if (!gather_idx) {
topk = 1;
}
int64_t input_dims[] = {n, c};
int64_t output_dims[] = {n * topk, c};
int64_t output_scale_dims[] = {n * topk};
int64_t input_strides[] = {input_stride, 1};
int64_t output_strides[] = {output_stride, 1};
cnnlTensorDescriptor_t input_tensor, smooth_scale_tensor, token_count_tensor, gather_idx_tensor,
gather_idx_start_pos_tensor, output_tensor, output_scale_tensor;
cnnlCreateTensorDescriptor(&input_tensor);
cnnlCreateTensorDescriptor(&smooth_scale_tensor);
cnnlCreateTensorDescriptor(&output_tensor);
cnnlCreateTensorDescriptor(&output_scale_tensor);
CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(input_tensor, CNNL_LAYOUT_ARRAY, input_dtype,
dim_nb, input_dims, input_strides));
CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(output_tensor, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_INT8,
dim_nb, output_dims, output_strides));
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(output_scale_tensor, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, output_scale_dim_nb,
output_scale_dims));
if (token_count) {
int64_t smooth_scale_dims[] = {e, c};
int64_t token_count_dims[] = {e};
cnnlCreateTensorDescriptor(&token_count_tensor);
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(token_count_tensor, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_INT32, 1, token_count_dims));
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, 2, smooth_scale_dims));
} else {
int64_t smooth_scale_dims[] = {c};
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(smooth_scale_tensor, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, 1, smooth_scale_dims));
}
if (gather_idx) {
int64_t gather_idx_dims[] = {n * topk};
cnnlCreateTensorDescriptor(&gather_idx_tensor);
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_tensor, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_INT32, 1, gather_idx_dims));
}
if (gather_idx_start_position) {
int64_t gather_idx_start_pos_dims[] = {1};
cnnlCreateTensorDescriptor(&gather_idx_start_pos_tensor);
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(gather_idx_start_pos_tensor, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_INT32, 1, gather_idx_start_pos_dims));
}
CNNL_CHECK_FATAL(cnnlSmoothQuantOnline_v4(
handle, nullptr /*smooth_quant_online_desc*/, input_tensor /*input_desc*/,
input /*input_ptr*/, smooth_scale_tensor /*input_smooth_quant_scale_desc*/,
smooth_scale /*input_sq_scale_ptr*/, nullptr /*input_smooth_quant_zero_desc*/,
nullptr /*input_sq_zero_ptr*/,
token_count ? token_count_tensor /*token_count_desc*/ : nullptr,
token_count /*token_count_ptr*/, gather_idx ? gather_idx_tensor /*gather_idx_desc*/ : nullptr,
gather_idx /*gather_idx_ptr*/,
gather_idx_start_position /*gather_idx_start_pos_desc*/ ? gather_idx_start_pos_tensor
: nullptr,
gather_idx_start_position /*gather_idx_start_pos_ptr*/, output_tensor /*output_desc*/,
output /*output_ptr*/, output_scale_tensor /*output_scale_desc*/,
output_scale /*output_scale_ptr*/, nullptr /*output_zero_desc*/, nullptr /*output_zero_ptr*/,
nullptr /*worksapce*/, 0 /*workspace_size*/));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(input_tensor));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(smooth_scale_tensor));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_tensor));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(output_scale_tensor));
if (token_count) {
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(token_count_tensor));
}
if (gather_idx) {
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_tensor));
}
if (gather_idx_start_position) {
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(gather_idx_start_pos_tensor));
}
}
} // namespace ops
} // namespace tmo

View File

@@ -0,0 +1,66 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_OPS_KERNEL_API_H_
#define CSRC_OPS_KERNEL_API_H_
#include <map>
#include <vector>
#include "cnrt.h"
#include "op_descriptor/attn_proj_descriptor.h"
#include "op_descriptor/batchmatmul_descriptor.h"
#include "op_descriptor/ffn_descriptor.h"
#include "op_descriptor/group_gemm_descriptor.h"
#include "op_descriptor/matmul_descriptor.h"
#include "op_descriptor/quant_matmul_descriptor.h"
namespace tmo {
namespace ops {
using GroupGemmDesc = tmo::op_desc::GroupGemmDesc;
size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle, GroupGemmDesc &desc, const int num_expert);
void GroupGemm(const cnnlHandle_t &handle,
GroupGemmDesc &desc,
void *m,
void *alpha,
void *beta,
void *workspace,
size_t workspace_size,
int num_expert,
int k,
int n,
int lda,
std::vector<int> &ldb);
void SmoothQuant(const cnnlHandle_t &handle,
void *input,
void *smooth_scale,
void *token_count,
void *gather_idx,
void *gather_idx_start_position,
void *output,
void *output_scale,
int n,
int c,
int e,
int input_stride,
int output_stride,
int topk,
cnnlDataType_t input_dtype);
} // namespace ops
} // namespace tmo
#endif // CSRC_OPS_KERNEL_API_H_

View File

@@ -0,0 +1,58 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "attn_proj_descriptor.h"
#include "base_descriptor.h"
namespace tmo {
namespace op_desc {
namespace {
static OpDescPool<AttnProjDescImpl> attn_proj_instance;
}
AttnProjDesc::AttnProjDesc() {
impl_ = attn_proj_instance.get();
}
void AttnProjDesc::setDesc(
const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode,
const cnnlDataType_t compute_dtype,
const bool q_has_value,
const bool k_has_value,
const bool v_has_value,
const bool has_bias,
const bool is_pack_mode,
const int packed_max_seq_len,
const bool trans_out,
const bool store_layernorm_result,
const float alpha,
const float beta,
const float layernorm_eps,
const bool is_quant) {
CNNL_CHECK_FATAL(cnnlSetTransformerAttnProjDescriptor(
impl_->attn_proj_desc_, layernorm_residual_mode, nullptr, /*activation desc*/
compute_dtype, q_has_value, k_has_value, v_has_value, has_bias, is_pack_mode,
packed_max_seq_len, trans_out, store_layernorm_result, alpha, beta, layernorm_eps));
const int allow_tf32 = 0; /*Not allow*/
CNNL_CHECK_FATAL(
cnnlSetTransformerAttnProjDescriptorAllowTF32(impl_->attn_proj_desc_, allow_tf32));
this->is_quant_ = is_quant;
}
AttnProjDescImpl::AttnProjDescImpl() {
CNNL_CHECK_FATAL(cnnlCreateTransformerAttnProjDescriptor(&this->attn_proj_desc_));
CNNL_CHECK_FATAL(cnnlCreateTransformerAttnProjQuantifyDescriptor(&this->attn_proj_quant_desc_));
}
} // namespace op_desc
} // namespace tmo

View File

@@ -0,0 +1,82 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_OPS_OP_DESCRIPTOR_ATTN_PROJ_DESCRIPTOR_H_
#define CSRC_OPS_OP_DESCRIPTOR_ATTN_PROJ_DESCRIPTOR_H_
#include <functional>
#include <memory>
#include "cnnl.h"
#include "cnnl_extra.h"
#include "common/utils.h"
namespace tmo {
namespace op_desc {
class TMO_HIDDEN AttnProjDescImpl {
public:
AttnProjDescImpl();
// using clear function to avoid compilation errors.
~AttnProjDescImpl() { clear(); }
DELETE_COPY_ASSIGN_CONSTRUCT(AttnProjDescImpl);
cnnlTransformerAttnProjDescriptor_t attn_proj_desc_;
cnnlTransformerAttnProjQuantizeDescriptor_t attn_proj_quant_desc_;
private:
inline void clear() {
CNNL_CHECK_FATAL(cnnlDestroyTransformerAttnProjDescriptor(this->attn_proj_desc_));
CNNL_CHECK_FATAL(cnnlDestroyTransformerAttnProjQuantifyDescriptor(this->attn_proj_quant_desc_));
}
};
class TMO_EXPORT AttnProjDesc {
public:
using impl_type_ = std::unique_ptr<AttnProjDescImpl, std::function<void(AttnProjDescImpl *)>>;
AttnProjDesc();
void setDesc(const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode,
const cnnlDataType_t compute_dtype,
const bool q_has_value,
const bool k_has_value,
const bool v_has_value,
const bool has_bias,
const bool is_pack_mode,
const int packed_max_seq_len,
const bool trans_out,
const bool store_layernorm_result,
const float alpha,
const float beta,
const float layernorm_eps,
const bool is_quant = false);
operator cnnlTransformerAttnProjDescriptor_t() { return this->impl_->attn_proj_desc_; }
operator cnnlTransformerAttnProjDescriptor_t() const { return this->impl_->attn_proj_desc_; }
operator cnnlTransformerAttnProjQuantizeDescriptor_t() {
return is_quant_ ? this->impl_->attn_proj_quant_desc_ : nullptr;
}
operator cnnlTransformerAttnProjQuantizeDescriptor_t() const {
return is_quant_ ? this->impl_->attn_proj_quant_desc_ : nullptr;
}
private:
impl_type_ impl_;
bool is_quant_ = false;
};
} // namespace op_desc
} // namespace tmo
#endif // CSRC_OPS_OP_DESCRIPTOR_ATTN_PROJ_DESCRIPTOR_H_

View File

@@ -0,0 +1,62 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_OPS_OP_DESCRIPTOR_BASE_DESCRIPTOR_H_
#define CSRC_OPS_OP_DESCRIPTOR_BASE_DESCRIPTOR_H_
#include <deque>
#include <mutex>
#include "common/utils.h"
namespace tmo {
namespace op_desc {
static constexpr int OpCreateStage = 0;
// Add BaseDescPool is only for store different op
// desc pool in a specific container.
template <typename T>
class TMO_HIDDEN OpDescPool {
public:
using desc_unique_ptr = std::unique_ptr<T>;
OpDescPool() { createMultiItems(); }
// delete copy construct and assign construct.
DELETE_COPY_ASSIGN_CONSTRUCT(OpDescPool);
desc_unique_ptr get() {
auto ptr = std::make_unique<T>();
return desc_unique_ptr(ptr.release());
}
~OpDescPool() {
std::lock_guard<std::mutex> lock(this->mutex_);
for (size_t i = 0; i < this->container_.size(); ++i) {
this->container_[i].release();
}
this->container_.clear();
}
private:
void createMultiItems() {
for (int i = 0; i < OpCreateStage; ++i) {
container_.emplace_back(std::make_unique<T>());
}
}
// variable
std::mutex mutex_;
std::deque<std::unique_ptr<T>> container_;
};
} // namespace op_desc
} // namespace tmo
#endif // CSRC_OPS_OP_DESCRIPTOR_BASE_DESCRIPTOR_H_

View File

@@ -0,0 +1,52 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "batchmatmul_descriptor.h"
namespace tmo {
namespace op_desc {
namespace {
static OpDescPool<BatchMatMulDescImpl> bmm_instance;
}
BatchMatMulDesc::BatchMatMulDesc(cnnlMatMulHeuristicResult_t &heuristic_result,
cnnlMatMulAlgo_t &algo,
bool use_beta,
bool trans_a,
bool trans_b) {
impl_ = bmm_instance.get();
auto bmm_desc_ = impl_->bmm_desc_;
heuristic_result = impl_->heuristic_result_;
algo = impl_->algo_;
int32_t matmul_trans_a = int(trans_a);
int32_t matmul_trans_b = int(trans_b);
int32_t matmul_use_tf32 = 0;
int32_t matmul_use_beta = int(use_beta);
CNNL_CHECK_FATAL(cnnlSetMatMulDescAttr(bmm_desc_, CNNL_MATMUL_DESC_TRANSA, &(matmul_trans_a),
sizeof(int32_t)));
CNNL_CHECK_FATAL(cnnlSetMatMulDescAttr(bmm_desc_, CNNL_MATMUL_DESC_TRANSB, &(matmul_trans_b),
sizeof(int32_t)));
CNNL_CHECK_FATAL(cnnlSetMatMulDescAttr(bmm_desc_, CNNL_MATMUL_ALLOW_TF32, &(matmul_use_tf32),
sizeof(int32_t)));
CNNL_CHECK_FATAL(
cnnlSetMatMulDescAttr(bmm_desc_, CNNL_MATMUL_USE_BETA, &(matmul_use_beta), sizeof(int32_t)));
}
BatchMatMulDescImpl::BatchMatMulDescImpl() {
CNNL_CHECK_FATAL(cnnlCreateMatMulDescriptor(&this->bmm_desc_));
CNNL_CHECK_FATAL(cnnlCreateMatMulHeuristicResult(&this->heuristic_result_));
CNNL_CHECK_FATAL(cnnlCreateMatMulAlgo(&this->algo_));
}
} // namespace op_desc
} // namespace tmo

View File

@@ -0,0 +1,64 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_OPS_OP_DESCRIPTOR_BATCHMATMUL_DESCRIPTOR_H_
#define CSRC_OPS_OP_DESCRIPTOR_BATCHMATMUL_DESCRIPTOR_H_
#include <functional>
#include <memory>
#include "base_descriptor.h"
#include "common/utils.h"
namespace tmo {
namespace op_desc {
class TMO_HIDDEN BatchMatMulDescImpl {
public:
BatchMatMulDescImpl();
// using clear function to avoid compilation errors.
~BatchMatMulDescImpl() { clear(); }
DELETE_COPY_ASSIGN_CONSTRUCT(BatchMatMulDescImpl);
cnnlMatMulDescriptor_t bmm_desc_;
cnnlMatMulHeuristicResult_t heuristic_result_;
cnnlMatMulAlgo_t algo_;
private:
inline void clear() {
CNNL_CHECK_FATAL(cnnlDestroyMatMulDescriptor(this->bmm_desc_));
CNNL_CHECK_FATAL(cnnlDestroyMatMulHeuristicResult(this->heuristic_result_));
CNNL_CHECK_FATAL(cnnlDestroyMatMulAlgo(this->algo_));
}
};
class TMO_EXPORT BatchMatMulDesc {
public:
using impl_type_ =
std::unique_ptr<BatchMatMulDescImpl, std::function<void(BatchMatMulDescImpl *)>>;
BatchMatMulDesc(cnnlMatMulHeuristicResult_t &heuristic_result,
cnnlMatMulAlgo_t &algo,
bool use_beta = false,
bool trans_a = false,
bool trans_b = true);
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlMatMulDescriptor_t, impl_->bmm_desc_)
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlMatMulHeuristicResult_t, impl_->heuristic_result_)
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlMatMulAlgo_t, impl_->algo_)
private:
impl_type_ impl_;
};
} // namespace op_desc
} // namespace tmo
#endif // CSRC_OPS_OP_DESCRIPTOR_BATCHMATMUL_DESCRIPTOR_H_

View File

@@ -0,0 +1,67 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "ffn_descriptor.h"
namespace tmo {
namespace op_desc {
namespace {
static OpDescPool<FeedForwardDescImpl> instance;
} // end of anonymous namespace
FeedForwardDesc::FeedForwardDesc() : is_quanti_(false) {
this->impl_ = instance.get();
}
FeedForwardDesc::FeedForwardDesc(
const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode,
const std::string &act_name,
cnnlDataType_t compute_type,
float eps,
float alpha,
float beta,
float act_coef)
: is_quanti_(false) {
this->impl_ = instance.get();
this->setFeedForwardDesc(layernorm_residual_mode, act_name, compute_type, eps, alpha, beta,
act_coef);
}
void FeedForwardDesc::setFeedForwardDesc(
const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode,
const std::string &act_name,
cnnlDataType_t compute_type,
float eps,
float alpha,
float beta,
float act_coef) {
// Now only using bfloat16 and half.
// alpha must be 1.0f when NO_RESIDUAL mode.
CNNL_CHECK_FATAL(cnnlSetTransformerFeedForwardDescriptor_v2(
impl_->ffn_desc_, eps, alpha, beta, compute_type, layernorm_residual_mode));
act_coef = act_name == "silu" ? 1.0f : act_coef;
cnnlActivationMode_t act_mode = strToActivationMode(act_name);
CNNL_CHECK_FATAL(cnnlSetActivationDescriptor_v5(impl_->act_desc_, act_mode, CNNL_ACTIVATION_FAST,
CNNL_NOT_PROPAGATE_NAN, act_coef,
// following parameters are not used
0, 0, 0, 0));
}
FeedForwardDescImpl::FeedForwardDescImpl() {
CNNL_CHECK_FATAL(cnnlCreateTransformerFeedForwardDescriptor(&this->ffn_desc_));
CNNL_CHECK_FATAL(cnnlCreateTransformerFeedForwardQuantizeDescriptor(&this->quant_desc_));
CNNL_CHECK_FATAL(cnnlCreateActivationDescriptor(&this->act_desc_));
}
} // namespace op_desc
} // namespace tmo

View File

@@ -0,0 +1,102 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_OPS_OP_DESCRIPTOR_FFN_DESCRIPTOR_H_
#define CSRC_OPS_OP_DESCRIPTOR_FFN_DESCRIPTOR_H_
#include <functional>
#include <memory>
#include "base_descriptor.h"
#include "common/utils.h"
namespace tmo {
namespace op_desc {
/**
* Note [FeedForwardDesc]
* ~~~~~~~~~~~~~~~~
* FeedForwardDesc is class to create a Transformer FeedForward op descriptor.
*
* ffn_inner_size is intermediate_size;
* act_name mean act type, and will convert act type when set op desc;
* act_coef is only using when act_name is CNNL_ACTIVATION_CLIPPED_RELU,
* CNNL_ACTIVATION_ELU, CNNL_ACTIVATION_ELU_V2, CNNL_ACTIVATION_LEAKYRELU,
* CNNL_ACTIVATION_TF_LEAKYRELU, CNNL_ACTIVATION_CAFFE_RELU6.
*
*/
struct TMO_HIDDEN FeedForwardDescImpl {
FeedForwardDescImpl();
// using clear function to avoid compilation errors.
~FeedForwardDescImpl() { clear(); }
DELETE_COPY_ASSIGN_CONSTRUCT(FeedForwardDescImpl);
// variable
cnnlTransformerFeedForwardDescriptor_t ffn_desc_;
cnnlTransformerFeedForwardQuantizeDescriptor_t quant_desc_;
cnnlActivationDescriptor_t act_desc_;
private:
inline void clear() {
CNNL_CHECK_FATAL(cnnlDestroyTransformerFeedForwardDescriptor(this->ffn_desc_));
CNNL_CHECK_FATAL(cnnlDestroyTransformerFeedForwardQuantizeDescriptor(this->quant_desc_));
CNNL_CHECK_FATAL(cnnlDestroyActivationDescriptor(this->act_desc_));
}
};
class TMO_EXPORT FeedForwardDesc {
public:
using impl_type =
std::unique_ptr<FeedForwardDescImpl, std::function<void(FeedForwardDescImpl *)>>;
FeedForwardDesc();
FeedForwardDesc(const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode,
const std::string &act_name,
cnnlDataType_t compute_type,
float eps,
float alpha,
float beta,
float act_coef = 0);
void setFeedForwardDesc(const cnnlTransformerLayernormResidualStructure_t layernorm_residual_mode,
const std::string &act_name,
cnnlDataType_t compute_type,
float eps,
float alpha,
float beta,
float act_coef = 0);
// TODO(SG): quanti op desc is not support now, will support later.
// void setQuantiDesc();
// TODO(SG): using std::string to pass layer_norm type? It's not clear now.
// Add layer_norm fused later.
// void setLayerNormInfo();
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlTransformerFeedForwardDescriptor_t, impl_->ffn_desc_);
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlActivationDescriptor_t, impl_->act_desc_);
operator cnnlTransformerFeedForwardQuantizeDescriptor_t() const {
return is_quanti_ == false
? nullptr
: const_cast<cnnlTransformerFeedForwardQuantizeDescriptor_t>(impl_->quant_desc_);
}
operator cnnlTransformerFeedForwardQuantizeDescriptor_t() {
return is_quanti_ == false ? nullptr : impl_->quant_desc_;
}
private:
impl_type impl_;
bool is_quanti_{false};
};
} // namespace op_desc
} // namespace tmo
#endif // CSRC_OPS_OP_DESCRIPTOR_FFN_DESCRIPTOR_H_

View File

@@ -0,0 +1,178 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "group_gemm_descriptor.h"
#include <cstddef>
#include <cstdint>
namespace tmo {
namespace op_desc {
namespace {
static OpDescPool<GroupGemmDescImpl> instance;
} // end of anonymous namespace
GroupGemmDesc::GroupGemmDesc(int num_expert,
int max_m,
int n,
int k,
cnnlDataType_t dtype,
bool idx_mode,
QuantMode quant_mode) {
this->impl_ = instance.get();
this->idx_mode_ = idx_mode;
this->quant_mode_ = quant_mode;
set(num_expert, max_m, n, k, dtype);
}
void GroupGemmDesc::set(int num_expert, int max_m, int n, int k, cnnlDataType_t dtype) {
this->num_expert_ = num_expert;
this->max_m_ = max_m;
this->n_ = n;
this->k_ = k;
this->dtype_ = dtype;
// no need to set compute_dtype, default is CNNL_DTYPE_FLOAT
CNNL_CHECK_FATAL(cnnlSetGroupGemmDescAttr(impl_->group_gemm_desc_, CNNL_MATMUL_DESC_GROUP_M_MAX,
&max_m, sizeof(int32_t)));
CNNL_CHECK_FATAL(cnnlSetGroupGemmDescAttr(impl_->group_gemm_desc_, CNNL_MATMUL_DESC_GROUP_N_MAX,
&n, sizeof(int32_t)));
CNNL_CHECK_FATAL(cnnlSetGroupGemmDescAttr(impl_->group_gemm_desc_, CNNL_MATMUL_DESC_GROUP_K_MAX,
&k, sizeof(int32_t)));
int dim_nb = 1;
int64_t dims[] = {num_expert};
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(impl_->group_host_tensor_, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_INT32, dim_nb, dims));
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(impl_->group_device_tensor_, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_INT32, dim_nb, dims));
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(impl_->scale_factor_tensor_, CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, dim_nb, dims));
CNNL_CHECK_FATAL(
cnnlSetTensorDescriptorPointerMode(impl_->group_host_tensor_, CNNL_POINTER_MODE_HOST));
CNNL_CHECK_FATAL(
cnnlSetTensorDescriptorPointerMode(impl_->group_device_tensor_, CNNL_POINTER_MODE_DEVICE));
CNNL_CHECK_FATAL(
cnnlSetTensorDescriptorPointerMode(impl_->scale_factor_tensor_, CNNL_POINTER_MODE_DEVICE));
}
void GroupGemmDesc::setPerRowColScaleBiasAct(void *a_scale, void *b_scale) {
setPerRowColScaleBiasAct(a_scale, b_scale, nullptr, nullptr, CNNL_DTYPE_INVALID,
CNNL_ACTIVATION_IDENTITY, CNNL_PROPAGATE_NAN, 0.f, false, 0, 0, 0);
}
void GroupGemmDesc::setPerRowColScaleBiasAct(void *a_scale,
void *b_scale,
void *quant_flag,
void *bias,
cnnlDataType_t bias_dtype,
int channels,
int blk,
int pos) {
setPerRowColScaleBiasAct(a_scale, b_scale, quant_flag, bias, bias_dtype, CNNL_ACTIVATION_IDENTITY,
CNNL_PROPAGATE_NAN, 0.f, false, channels, blk, pos);
}
void GroupGemmDesc::setPerRowColScaleBiasAct(void *a_scale,
void *b_scale,
void *quant_flag,
void *bias,
cnnlDataType_t bias_dtype,
cnnlActivationMode_t mode,
cnnlNanPropagation_t nan_prop,
float coef,
bool has_active,
int channels,
int scale_block_size,
int channel_pos) {
bool is_a_quant = a_scale != nullptr;
bool use_b_scale = false;
bool quant_grouped = (scale_block_size > 0 && channels / scale_block_size > 0) ? true : false;
if (a_scale) {
CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorDescriptor(a_scale_desc(), CNNL_DTYPE_FLOAT, a_scale,
CNNL_POINTER_MODE_DEVICE, nullptr));
}
if (!quant_grouped && b_scale) {
use_b_scale = true;
CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorDescriptor(b_scale_desc(), CNNL_DTYPE_FLOAT, b_scale,
CNNL_POINTER_MODE_DEVICE, nullptr));
}
if (bias) {
this->has_bias_ = true;
CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorDescriptor(bias_desc(), bias_dtype, bias,
CNNL_POINTER_MODE_DEVICE, nullptr));
}
this->has_active_ = has_active;
if (has_active) {
CNNL_CHECK_FATAL(cnnlSetActivationDescriptor(active_desc(), mode, nan_prop, coef));
}
if (quant_grouped) {
const int dim_nb = 2;
int group_num = channels / scale_block_size;
int64_t dims[dim_nb] = {group_num, num_expert_ * n_};
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(b_scale_quant_grouped_desc(), CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_FLOAT, dim_nb, dims));
CNNL_CHECK_FATAL(cnnlSetGroupGemmGroupwiseScale(
group_gemm_desc(), CNNL_MATMUL_DESC_B_SCALE_POINTER, channels, scale_block_size,
channel_pos, b_scale_quant_grouped_desc(), b_scale));
if (quant_flag) {
const int dim_qf = 2;
int64_t dims[dim_qf] = {num_expert_, group_num};
CNNL_CHECK_FATAL(
cnnlSetTensorDescriptorPointerMode(quant_flag_desc(), CNNL_POINTER_MODE_HOST));
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(quant_flag_desc(), CNNL_LAYOUT_ARRAY,
CNNL_DTYPE_INT32, dim_qf, dims));
CNNL_CHECK_FATAL(cnnlSetGroupGemmGroupMixedQuantBitFlag(
group_gemm_desc(), CNNL_MATMUL_DESC_B_QUANT_FLAG_POINTER, channels, scale_block_size,
quant_flag_desc(), quant_flag));
}
}
CNNL_CHECK_FATAL(cnnlSetGroupGemmPerRowColScaleBiasAct(
group_gemm_desc(), is_a_quant ? a_scale_desc() : nullptr,
use_b_scale ? b_scale_desc() : nullptr, this->has_bias_ ? bias_desc() : nullptr,
has_active ? active_desc() : nullptr));
}
void GroupGemmDesc::setInputOutputTensor(cnnlDataType_t a_dtype,
cnnlDataType_t b_dtype,
cnnlDataType_t d_dtype,
cnnlDataType_t idx_dtype,
void *a,
void *b,
void *c,
void *d,
void *idx,
int64_t k,
int64_t k_stride,
int64_t total_m,
bool has_c,
int64_t *b_offset) {
this->has_c_ = has_c;
if (idx_mode()) {
cnnlSetGroupGemmTensorGatherIdxDesc(a_desc(), a_dtype, idx_dtype, a, idx, k, k_stride, total_m);
} else {
CNNL_CHECK_FATAL(
cnnlSetGroupGemmTensorDescriptor(a_desc(), a_dtype, a, CNNL_POINTER_MODE_DEVICE, nullptr));
}
auto b_ptr_mode = b_offset ? CNNL_POINTER_MODE_HOST : CNNL_POINTER_MODE_DEVICE;
CNNL_CHECK_FATAL(cnnlSetGroupGemmTensorDescriptor(b_desc(), b_dtype, b, b_ptr_mode, b_offset));
if (has_c) {
CNNL_CHECK_FATAL(
cnnlSetGroupGemmTensorDescriptor(c_desc(), d_dtype, c, CNNL_POINTER_MODE_DEVICE, nullptr));
}
CNNL_CHECK_FATAL(
cnnlSetGroupGemmTensorDescriptor(d_desc(), d_dtype, d, CNNL_POINTER_MODE_DEVICE, nullptr));
}
} // namespace op_desc
} // namespace tmo

View File

@@ -0,0 +1,177 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_OPS_OP_DESCRIPTOR_GROUP_GEMM_DESCRIPTOR_H_
#define CSRC_OPS_OP_DESCRIPTOR_GROUP_GEMM_DESCRIPTOR_H_
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include "base_descriptor.h"
#include "common/utils.h"
namespace tmo {
namespace op_desc {
struct TMO_HIDDEN GroupGemmDescImpl {
GroupGemmDescImpl() {
CNNL_CHECK_FATAL(cnnlCreateGroupGemmDescriptor(&group_gemm_desc_));
CNNL_CHECK_FATAL(
cnnlCreateGroupGemmTensorDescriptor(&a_idx_desc_, CNNL_GROUP_GEMM_ADDRESSING_GATHER_IDX));
CNNL_CHECK_FATAL(
cnnlCreateGroupGemmTensorDescriptor(&a_offset_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
CNNL_CHECK_FATAL(
cnnlCreateGroupGemmTensorDescriptor(&b_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
CNNL_CHECK_FATAL(
cnnlCreateGroupGemmTensorDescriptor(&c_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
CNNL_CHECK_FATAL(
cnnlCreateGroupGemmTensorDescriptor(&d_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
CNNL_CHECK_FATAL(
cnnlCreateGroupGemmTensorDescriptor(&a_scale_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
CNNL_CHECK_FATAL(
cnnlCreateGroupGemmTensorDescriptor(&b_scale_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
CNNL_CHECK_FATAL(
cnnlCreateGroupGemmTensorDescriptor(&bias_desc_, CNNL_GROUP_GEMM_ADDRESSING_OFFSETS));
CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&b_scale_quant_grouped_desc_));
CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&group_host_tensor_));
CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&group_device_tensor_));
CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&scale_factor_tensor_));
CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&quant_flag_tensor_));
CNNL_CHECK_FATAL(cnnlCreateActivationDescriptor(&active_desc_));
}
// using clear function to avoid compilation errors.
~GroupGemmDescImpl() { clear(); }
DELETE_COPY_ASSIGN_CONSTRUCT(GroupGemmDescImpl);
// variable
cnnlGroupGemmDescriptor_t group_gemm_desc_;
cnnlGroupGemmTensorDescriptor_t a_idx_desc_;
cnnlGroupGemmTensorDescriptor_t a_offset_desc_;
cnnlGroupGemmTensorDescriptor_t b_desc_;
cnnlGroupGemmTensorDescriptor_t c_desc_;
cnnlGroupGemmTensorDescriptor_t d_desc_;
cnnlGroupGemmTensorDescriptor_t a_scale_desc_;
cnnlGroupGemmTensorDescriptor_t b_scale_desc_;
cnnlGroupGemmTensorDescriptor_t bias_desc_;
cnnlTensorDescriptor_t b_scale_quant_grouped_desc_;
cnnlTensorDescriptor_t group_host_tensor_;
cnnlTensorDescriptor_t group_device_tensor_;
cnnlTensorDescriptor_t scale_factor_tensor_;
cnnlTensorDescriptor_t quant_flag_tensor_;
cnnlActivationDescriptor_t active_desc_;
private:
inline void clear() {
CNNL_CHECK_FATAL(
cnnlSetTensorDescriptorPointerMode(group_host_tensor_, CNNL_POINTER_MODE_DEVICE));
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmDescriptor(group_gemm_desc_));
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(a_idx_desc_));
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(a_offset_desc_));
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(b_desc_));
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(c_desc_));
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(d_desc_));
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(a_scale_desc_));
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(b_scale_desc_));
CNNL_CHECK_FATAL(cnnlDestroyGroupGemmTensorDescriptor(bias_desc_));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(b_scale_quant_grouped_desc_));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(group_host_tensor_));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(group_device_tensor_));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(scale_factor_tensor_));
CNNL_CHECK_FATAL(cnnlDestroyTensorDescriptor(quant_flag_tensor_));
CNNL_CHECK_FATAL(cnnlDestroyActivationDescriptor(active_desc_));
}
};
class TMO_EXPORT GroupGemmDesc {
public:
enum class QuantMode { noQuant, W8, W4, W4W8 };
using impl_type = std::unique_ptr<GroupGemmDescImpl, std::function<void(GroupGemmDescImpl *)>>;
GroupGemmDesc() = delete;
GroupGemmDesc(int num_expert,
int max_m,
int n,
int k,
cnnlDataType_t dtype,
bool idx_mode,
QuantMode quant_mode);
void set(int num_expert, int max_m, int n, int k, cnnlDataType_t dtype);
void setPerRowColScaleBiasAct(void *a_scale, void *b_scale);
void setPerRowColScaleBiasAct(void *a_scale,
void *b_scale,
void *quant_flag,
void *bias,
cnnlDataType_t bias_dtype,
int channels,
int blk,
int pos);
void setPerRowColScaleBiasAct(void *a_scale,
void *b_scale,
void *quant_flag, // for w4/w8 quantize
void *bias, // nullptr
cnnlDataType_t bias_dtype, // CNNL_DTYPE_INVALID
cnnlActivationMode_t mode, // CNNL_ACTIVATION_SILU
cnnlNanPropagation_t nan_prop, // CNNL_PROPAGATE_NAN
float coef, // 0.f
bool has_active, // false
int channels,
int scale_block_size,
int channel_pos);
void setInputOutputTensor(cnnlDataType_t a_dtype,
cnnlDataType_t b_dtype,
cnnlDataType_t d_dtype,
cnnlDataType_t idx_dtype,
void *a,
void *b,
void *c,
void *d,
void *idx,
int64_t k,
int64_t k_stride,
int64_t total_m,
bool has_c,
int64_t *b_offset = nullptr);
cnnlGroupGemmDescriptor_t group_gemm_desc() { return impl_->group_gemm_desc_; }
cnnlGroupGemmTensorDescriptor_t a_desc() {
return idx_mode_ ? impl_->a_idx_desc_ : impl_->a_offset_desc_;
}
cnnlGroupGemmTensorDescriptor_t b_desc() { return impl_->b_desc_; }
cnnlGroupGemmTensorDescriptor_t c_desc() { return impl_->c_desc_; }
cnnlGroupGemmTensorDescriptor_t d_desc() { return impl_->d_desc_; }
cnnlGroupGemmTensorDescriptor_t a_scale_desc() { return impl_->a_scale_desc_; }
cnnlGroupGemmTensorDescriptor_t b_scale_desc() { return impl_->b_scale_desc_; }
cnnlGroupGemmTensorDescriptor_t bias_desc() { return impl_->bias_desc_; }
cnnlTensorDescriptor_t b_scale_quant_grouped_desc() { return impl_->b_scale_quant_grouped_desc_; }
cnnlTensorDescriptor_t group_host_tensor() { return impl_->group_host_tensor_; }
cnnlTensorDescriptor_t group_device_tensor() { return impl_->group_device_tensor_; }
cnnlTensorDescriptor_t scale_factor_tensor() { return impl_->scale_factor_tensor_; }
cnnlTensorDescriptor_t quant_flag_desc() { return impl_->quant_flag_tensor_; }
cnnlActivationDescriptor_t active_desc() { return impl_->active_desc_; }
bool idx_mode() { return idx_mode_; }
bool has_c() { return has_c_; }
private:
impl_type impl_;
bool idx_mode_ = false;
bool has_c_ = false;
bool has_bias_ = false;
bool has_active_ = false;
int num_expert_;
int max_m_;
int n_;
int k_;
cnnlDataType_t dtype_;
QuantMode quant_mode_ = QuantMode::noQuant;
};
} // namespace op_desc
} // namespace tmo
#endif // CSRC_OPS_OP_DESCRIPTOR_GROUP_GEMM_DESCRIPTOR_H_

View File

@@ -0,0 +1,104 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "matmul_descriptor.h"
namespace tmo {
namespace op_desc {
namespace {
static OpDescPool<MatMulDescImpl> matmul_instance;
}
MatMulDesc::MatMulDesc(const cnnlTensorDescriptor_t &bias_tensor,
const void *bias_data,
size_t &workspace_size,
int32_t lda,
int32_t ldb,
int32_t ldc,
const std::string &act_name,
bool use_beta,
bool trans_a,
bool trans_b,
bool fast_act,
bool approximate) {
impl_ = matmul_instance.get();
auto matmul_desc_ = impl_->matmul_desc_;
auto act_desc_ = impl_->act_desc_;
workspace_size = max_workspace_size_;
int32_t matmul_trans_a = int(trans_a);
int32_t matmul_trans_b = int(trans_b);
int32_t matmul_use_tf32 = 0;
int32_t matmul_use_beta = int(use_beta);
CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_DESC_TRANSA,
&(matmul_trans_a), sizeof(int32_t)));
CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_DESC_TRANSB,
&(matmul_trans_b), sizeof(int32_t)));
CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_ALLOW_TF32,
&(matmul_use_tf32), sizeof(int32_t)));
CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_USE_BETA,
&(matmul_use_beta), sizeof(int32_t)));
CNNL_CHECK_FATAL(
cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_DESC_LDA, &(lda), sizeof(int32_t)));
CNNL_CHECK_FATAL(
cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_DESC_LDB, &(ldb), sizeof(int32_t)));
CNNL_CHECK_FATAL(
cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_DESC_LDC, &(ldc), sizeof(int32_t)));
CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(matmul_desc_, CNNL_MATMUL_EX_PREF_MAX_WORKSPACE_BYTES,
&(max_workspace_size_), sizeof(int64_t)));
if (act_name == "none") {
if (bias_data == nullptr) {
cnnlMatMulEpilogueType_t fuse_type = CNNL_MATMUL_EPI_NONE;
CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(
matmul_desc_, cnnlMatMulExDescAttribute_t::CNNL_MATMUL_EX_DESC_EPILOGUE_TYPE,
(void *)&fuse_type, sizeof(fuse_type)));
} else {
cnnlMatMulEpilogueType_t fuse_type = CNNL_MATMUL_EPI_BIAS;
CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(
matmul_desc_, cnnlMatMulExDescAttribute_t::CNNL_MATMUL_EX_DESC_EPILOGUE_TYPE,
(void *)&fuse_type, sizeof(fuse_type)));
CNNL_CHECK_FATAL(cnnlSetMatMulExBias(matmul_desc_, bias_tensor, (void *)bias_data));
}
} else {
float act_coef = 0.0f;
cnnlActivationMode_t act_mode = strToActivationMode(act_name);
if (act_name == "silu") {
act_coef = 1.0f;
act_mode = CNNL_ACTIVATION_SILU;
}
cnnlActivationPreference_t act_pref = CNNL_ACTIVATION_HIGH_PRECISION;
if (fast_act) {
act_pref = CNNL_ACTIVATION_FAST;
}
CNNL_CHECK_FATAL(cnnlSetActivationDescriptor_v6(
act_desc_, act_mode, act_pref, CNNL_NOT_PROPAGATE_NAN, act_coef, 0, 0, 0, 0, approximate));
// fuse info
cnnlMatMulEpilogueType_t fuse_type = CNNL_MATMUL_EPI_BIAS_SCALE_BN_ACTIVATION;
CNNL_CHECK_FATAL(cnnlSetMatMulExDescAttr(
matmul_desc_, cnnlMatMulExDescAttribute_t::CNNL_MATMUL_EX_DESC_EPILOGUE_TYPE, &(fuse_type),
sizeof(fuse_type)));
CNNL_CHECK_FATAL(cnnlSetMatMulExBiasScaleBNActive(matmul_desc_, bias_tensor, bias_data, nullptr,
nullptr, nullptr, nullptr, nullptr, nullptr,
0, 0, 0, 0, 0, act_desc_));
}
}
MatMulDescImpl::MatMulDescImpl() {
CNNL_CHECK_FATAL(cnnlCreateMatMulExDescriptor(&this->matmul_desc_));
CNNL_CHECK_FATAL(cnnlCreateActivationDescriptor(&this->act_desc_));
}
} // namespace op_desc
} // namespace tmo

View File

@@ -0,0 +1,68 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_OPS_OP_DESCRIPTOR_MATMUL_DESCRIPTOR_H_
#define CSRC_OPS_OP_DESCRIPTOR_MATMUL_DESCRIPTOR_H_
#include <functional>
#include <memory>
#include "base_descriptor.h"
#include "common/utils.h"
namespace tmo {
namespace op_desc {
class TMO_HIDDEN MatMulDescImpl {
public:
MatMulDescImpl();
// using clear function to avoid compilation errors.
~MatMulDescImpl() { clear(); }
DELETE_COPY_ASSIGN_CONSTRUCT(MatMulDescImpl);
cnnlMatMulExDescriptor_t matmul_desc_;
cnnlActivationDescriptor_t act_desc_;
private:
inline void clear() {
CNNL_CHECK_FATAL(cnnlDestroyMatMulExDescriptor(this->matmul_desc_));
CNNL_CHECK_FATAL(cnnlDestroyActivationDescriptor(this->act_desc_));
}
};
class TMO_EXPORT MatMulDesc {
public:
using impl_type_ = std::unique_ptr<MatMulDescImpl, std::function<void(MatMulDescImpl *)>>;
MatMulDesc(const cnnlTensorDescriptor_t &bias_tensor,
const void *bias_data,
size_t &workspace_size,
int32_t lda,
int32_t ldb,
int32_t ldc,
const std::string &act_name = "none",
bool use_beta = false,
bool trans_a = false,
bool trans_b = true,
bool fast_act = true,
bool approximate = true);
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlMatMulExDescriptor_t, impl_->matmul_desc_)
CLASS_CAST_TYPE_OPERATOR_DEFINE(cnnlActivationDescriptor_t, impl_->act_desc_)
private:
size_t max_workspace_size_ = 128 * 1024 * 1024;
impl_type_ impl_;
};
} // namespace op_desc
} // namespace tmo
#endif // CSRC_OPS_OP_DESCRIPTOR_MATMUL_DESCRIPTOR_H_

View File

@@ -0,0 +1,79 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "quant_matmul_descriptor.h"
namespace tmo {
namespace op_desc {
namespace {
static OpDescPool<QuantMatmulDescImpl> instance;
}
QuantMatmulDesc::QuantMatmulDesc(const std::string &quant_algo,
const std::string &a_quant_layout,
const std::string &b_quant_layout,
unsigned quant_bit_size,
cnnlDataType_t compute_dtype,
const std::string &act_name,
bool use_hp_active,
float act_coef,
bool trans_a,
bool trans_b) {
impl_ = instance.get();
this->setQuantMatmulDesc(quant_algo, a_quant_layout, b_quant_layout, quant_bit_size,
compute_dtype, act_name, use_hp_active, act_coef, trans_a, trans_b);
}
void QuantMatmulDesc::setQuantMatmulDesc(const std::string &quant_algo,
const std::string &a_quant_layout,
const std::string &b_quant_layout,
unsigned quant_bit_size,
cnnlDataType_t compute_dtype,
const std::string &act_name,
bool use_hp_active,
float act_coef,
bool trans_a,
bool trans_b) {
cnnlQuantizeDescriptor_t no_quant;
CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&no_quant));
// set quantize algo and layout
cnnlLLMQuantAlgo_t cnnl_quant_algo = strToQuantizeAlgo(quant_algo);
CNNL_CHECK_FATAL(cnnlSetQuantizeDescriptorBehaviour(
impl_->a_quant_desc_, strToQuantizeLayout(a_quant_layout), quant_bit_size, 0, false));
CNNL_CHECK_FATAL(cnnlSetQuantizeDescriptorBehaviour(
impl_->b_quant_desc_, strToQuantizeLayout(b_quant_layout), quant_bit_size, 0, false));
// set activation
act_coef = act_name == "silu" ? 1.0f : act_coef;
cnnlActivationMode_t act_mode = strToActivationMode(act_name);
CNNL_CHECK_FATAL(cnnlSetActivationDescriptor_v5(
impl_->act_desc_, act_mode,
use_hp_active ? CNNL_ACTIVATION_HIGH_PRECISION : CNNL_ACTIVATION_FAST, CNNL_NOT_PROPAGATE_NAN,
act_coef, 0, 0, 0, 0));
// set into op_desc
CNNL_CHECK_FATAL(cnnlSetLLMQuantMatmulDescriptor(
impl_->op_desc_, impl_->a_quant_desc_, impl_->b_quant_desc_, no_quant, no_quant,
impl_->act_desc_, compute_dtype, cnnl_quant_algo, trans_a, trans_b));
CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(no_quant));
}
QuantMatmulDescImpl::QuantMatmulDescImpl() {
CNNL_CHECK_FATAL(cnnlCreateLLMQuantMatmulDescriptor(&this->op_desc_));
CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&this->a_quant_desc_));
CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&this->b_quant_desc_));
CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&this->c_quant_desc_));
CNNL_CHECK_FATAL(cnnlCreateQuantizeDescriptor(&this->d_quant_desc_));
CNNL_CHECK_FATAL(cnnlCreateActivationDescriptor(&this->act_desc_));
}
} // namespace op_desc
} // namespace tmo

View File

@@ -0,0 +1,86 @@
/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_OPS_OP_DESCRIPTOR_QUANT_MATMUL_DESCRIPTOR_H_
#define CSRC_OPS_OP_DESCRIPTOR_QUANT_MATMUL_DESCRIPTOR_H_
#include <functional>
#include <memory>
#include "base_descriptor.h"
#include "common/utils.h"
namespace tmo {
namespace op_desc {
class TMO_HIDDEN QuantMatmulDescImpl {
public:
QuantMatmulDescImpl();
// using clear function to avoid compilation errors.
~QuantMatmulDescImpl() { clear(); }
DELETE_COPY_ASSIGN_CONSTRUCT(QuantMatmulDescImpl);
// variable
cnnlLLMQuantMatmulDescriptor_t op_desc_;
cnnlQuantizeDescriptor_t a_quant_desc_;
cnnlQuantizeDescriptor_t b_quant_desc_;
cnnlQuantizeDescriptor_t c_quant_desc_;
cnnlQuantizeDescriptor_t d_quant_desc_;
cnnlActivationDescriptor_t act_desc_;
private:
inline void clear() {
CNNL_CHECK_FATAL(cnnlDestroyLLMQuantMatmulDescriptor(this->op_desc_));
CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(this->a_quant_desc_));
CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(this->b_quant_desc_));
CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(this->c_quant_desc_));
CNNL_CHECK_FATAL(cnnlDestroyQuantizeDescriptor(this->d_quant_desc_));
CNNL_CHECK_FATAL(cnnlDestroyActivationDescriptor(this->act_desc_));
}
};
class TMO_EXPORT QuantMatmulDesc {
public:
using impl_type_ =
std::unique_ptr<QuantMatmulDescImpl, std::function<void(QuantMatmulDescImpl *)>>;
QuantMatmulDesc(const std::string &quant_algo,
const std::string &a_quant_layout,
const std::string &b_quant_layout,
unsigned quant_bit_size,
cnnlDataType_t compute_dtype,
const std::string &act_name,
bool use_hp_active,
float act_coef,
bool trans_a,
bool trans_b);
// set descriptor
void setQuantMatmulDesc(const std::string &quant_algo,
const std::string &a_quant_layout,
const std::string &b_quant_layout,
unsigned quant_bit_size,
cnnlDataType_t compute_dtype,
const std::string &act_name,
bool use_hp_active,
float act_coef,
bool trans_a,
bool trans_b);
operator cnnlLLMQuantMatmulDescriptor_t() { return this->impl_->op_desc_; }
operator cnnlLLMQuantMatmulDescriptor_t() const { return this->impl_->op_desc_; }
private:
impl_type_ impl_;
};
} // namespace op_desc
} // namespace tmo
#endif // CSRC_OPS_OP_DESCRIPTOR_QUANT_MATMUL_DESCRIPTOR_H_