90 lines
5.0 KiB
C++
90 lines
5.0 KiB
C++
/*************************************************************************
|
|
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*************************************************************************/
|
|
|
|
#include <cstddef>
|
|
#include <vector>
|
|
#include "kernel_api.h"
|
|
|
|
namespace tmo {
|
|
namespace ops {
|
|
|
|
// Queries CNNL for the scratch-workspace size required by a grouped GEMM
// launched with the fixed configuration used in GroupGemm below:
// A not transposed, B transposed, no C operand, no alpha/beta descriptors.
//
// handle:     CNNL handle the query is issued against.
// desc:       bundle providing the group-gemm descriptor plus the shared
//             per-group host tensor descriptor and A/B/D descriptors.
// num_expert: number of independent GEMM groups (one per expert).
//
// Returns the required workspace size in bytes.
size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle, GroupGemmDesc &desc, const int num_expert) {
  size_t bytes = 0;
  // The same host-side group descriptor is reused for every per-group
  // integer array (m/n/k/lda/ldb/ldd).
  const auto host_group = desc.group_host_tensor();
  CNNL_CHECK_FATAL(cnnlGetGroupGemmWorkspaceSize(handle,
                                                 desc.group_gemm_desc(),
                                                 /*algo=*/nullptr,
                                                 /*groups=*/num_expert,
                                                 /*is_trans_a=*/false,
                                                 /*is_trans_b=*/true,
                                                 /*m_desc=*/host_group,
                                                 /*n_desc=*/host_group,
                                                 /*k_desc=*/host_group,
                                                 /*alpha_desc=*/nullptr,
                                                 /*lda_desc=*/host_group,
                                                 desc.a_desc(),
                                                 /*ldb_desc=*/host_group,
                                                 desc.b_desc(),
                                                 /*beta_desc=*/nullptr,
                                                 /*ldc_desc=*/nullptr,
                                                 /*c_desc=*/nullptr,
                                                 /*ldd_desc=*/host_group,
                                                 desc.d_desc(),
                                                 &bytes));
  return bytes;
}
|
|
|
|
void GroupGemm(const cnnlHandle_t &handle,
|
|
GroupGemmDesc &desc,
|
|
void *m,
|
|
void *alpha,
|
|
void *beta,
|
|
void *workspace,
|
|
size_t workspace_size,
|
|
int num_expert,
|
|
int k,
|
|
int n,
|
|
int lda,
|
|
std::vector<int> &ldb) {
|
|
std::vector<int> n_array(num_expert, n);
|
|
std::vector<int> k_array(num_expert, k);
|
|
std::vector<int> lda_array(num_expert, lda);
|
|
std::vector<int> ldd_array(num_expert, n);
|
|
|
|
CNNL_CHECK_FATAL(cnnlGroupGemm(handle, desc.group_gemm_desc(), /*desc*/
|
|
nullptr, /*algo*/
|
|
num_expert, /*groups*/
|
|
false, /*is_trans_a*/
|
|
true, /*is_trans_b*/
|
|
desc.group_device_tensor(), m, /*m*/
|
|
desc.group_host_tensor(), n_array.data(), /*n*/
|
|
desc.group_host_tensor(), k_array.data(), /*k*/
|
|
alpha ? desc.scale_factor_tensor() : nullptr, /*alpha_desc*/
|
|
alpha, /*alpha*/
|
|
desc.group_host_tensor(), lda_array.data(), /*lda*/
|
|
desc.a_desc(), /*a*/
|
|
ldb.empty() ? nullptr : desc.group_host_tensor(),
|
|
ldb.empty() ? nullptr : ldb.data(), /*ldb*/
|
|
desc.b_desc(), /*b*/
|
|
beta ? desc.scale_factor_tensor() : nullptr, /*beta_desc*/
|
|
beta, /*beta*/
|
|
desc.group_host_tensor(), /*ldc_desc*/
|
|
ldd_array.data(), /*ldc*/
|
|
desc.has_c() ? desc.c_desc() : desc.d_desc(), /*c*/
|
|
workspace, workspace_size, /*workspace*/
|
|
desc.group_host_tensor(), ldd_array.data(), /*ldd*/
|
|
desc.d_desc())); /*d*/
|
|
}
|
|
|
|
} // namespace ops
|
|
} // namespace tmo
|