Files
2026-02-04 17:39:32 +08:00

90 lines
5.0 KiB
C++

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include <cstddef>
#include <vector>
#include "kernel_api.h"
namespace tmo {
namespace ops {
// Queries CNNL for the scratch-space size (in bytes) needed by the grouped
// GEMM configured in `desc`: `num_expert` groups, A not transposed, B
// transposed. The shared host tensor descriptor stands in for each per-group
// integer array (m/n/k/lda/ldb/ldd).
//
// NOTE(review): alpha/beta/ldc/c descriptors are passed as nullptr here while
// GroupGemm() below may supply real ones at launch — presumably the query is
// an upper bound either way; confirm against the CNNL contract.
size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle, GroupGemmDesc &desc, const int num_expert) {
  size_t bytes = 0;
  CNNL_CHECK_FATAL(cnnlGetGroupGemmWorkspaceSize(
      handle,
      desc.group_gemm_desc(),    // operation descriptor
      nullptr,                   // algo: let CNNL choose
      num_expert,                // groups
      false,                     // is_trans_a
      true,                      // is_trans_b
      desc.group_host_tensor(),  // m_desc
      desc.group_host_tensor(),  // n_desc
      desc.group_host_tensor(),  // k_desc
      nullptr,                   // alpha_desc
      desc.group_host_tensor(),  // lda_desc
      desc.a_desc(),             // a_desc
      desc.group_host_tensor(),  // ldb_desc
      desc.b_desc(),             // b_desc
      nullptr,                   // beta_desc
      nullptr,                   // ldc_desc
      nullptr,                   // c_desc
      desc.group_host_tensor(),  // ldd_desc
      desc.d_desc(),             // d_desc
      &bytes));
  return bytes;
}
// Launches a CNNL grouped GEMM over `num_expert` groups (A not transposed,
// B transposed). Per-group m counts live in the device buffer `m`; the
// scalar n, k, lda and the derived ldd (== n) are broadcast host-side into
// one entry per group. `alpha`/`beta` are optional scale factors (their
// descriptors are only supplied when the pointers are non-null), and `ldb`
// may be empty, in which case no ldb descriptor/array is passed.
void GroupGemm(const cnnlHandle_t &handle,
               GroupGemmDesc &desc,
               void *m,
               void *alpha,
               void *beta,
               void *workspace,
               size_t workspace_size,
               int num_expert,
               int k,
               int n,
               int lda,
               std::vector<int> &ldb) {
  // Replicate the scalar sizes/strides once per group for the host arrays.
  std::vector<int> per_group_n(num_expert, n);
  std::vector<int> per_group_k(num_expert, k);
  std::vector<int> per_group_lda(num_expert, lda);
  std::vector<int> per_group_ldd(num_expert, n);  // ldd mirrors n — presumably row-major D; confirm
  CNNL_CHECK_FATAL(
      cnnlGroupGemm(handle, desc.group_gemm_desc(),                     /*desc*/
                    nullptr,                                            /*algo*/
                    num_expert,                                         /*groups*/
                    false,                                              /*is_trans_a*/
                    true,                                               /*is_trans_b*/
                    desc.group_device_tensor(), m,                      /*m (device)*/
                    desc.group_host_tensor(), per_group_n.data(),       /*n*/
                    desc.group_host_tensor(), per_group_k.data(),       /*k*/
                    alpha ? desc.scale_factor_tensor() : nullptr,       /*alpha_desc*/
                    alpha,                                              /*alpha*/
                    desc.group_host_tensor(), per_group_lda.data(),     /*lda*/
                    desc.a_desc(),                                      /*a*/
                    ldb.empty() ? nullptr : desc.group_host_tensor(),   /*ldb_desc*/
                    ldb.empty() ? nullptr : ldb.data(),                 /*ldb*/
                    desc.b_desc(),                                      /*b*/
                    beta ? desc.scale_factor_tensor() : nullptr,        /*beta_desc*/
                    beta,                                               /*beta*/
                    desc.group_host_tensor(),                           /*ldc_desc*/
                    // NOTE(review): ldc reuses the ldd values (all n);
                    // confirm this matches the c tensor layout when has_c().
                    per_group_ldd.data(),                               /*ldc*/
                    desc.has_c() ? desc.c_desc() : desc.d_desc(),       /*c*/
                    workspace, workspace_size,                          /*workspace*/
                    desc.group_host_tensor(), per_group_ldd.data(),     /*ldd*/
                    desc.d_desc()));                                    /*d*/
}
} // namespace ops
} // namespace tmo