add ops
This commit is contained in:
89
torch_mlu_ops-v1.3.2/csrc/ops/GroupGemm.cpp
Normal file
89
torch_mlu_ops-v1.3.2/csrc/ops/GroupGemm.cpp
Normal file
@@ -0,0 +1,89 @@
/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include <cstddef>
#include <vector>

#include "kernel_api.h"

namespace tmo {
namespace ops {

// Queries CNNL for the scratch-buffer size required by a grouped GEMM with
// the fixed layout used by this module: A non-transposed, B transposed, and
// no separate C operand in the size query (ldc/c descriptors passed as null).
//
// handle:     CNNL execution handle.
// desc:       operand/descriptor bundle for this grouped GEMM.
// num_expert: number of GEMM groups (one per expert).
// Returns the required workspace size in bytes.
size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle, GroupGemmDesc &desc, const int num_expert) {
  size_t ws_bytes = 0;
  CNNL_CHECK_FATAL(cnnlGetGroupGemmWorkspaceSize(
      handle,                    // handle
      desc.group_gemm_desc(),    // cnnlGroupGemmDescriptor_t
      nullptr,                   // cnnlGroupGemmAlgo_t: let the library pick
      num_expert,                // groups
      false,                     // is_trans_a
      true,                      // is_trans_b
      desc.group_host_tensor(),  // m_desc
      desc.group_host_tensor(),  // n_desc
      desc.group_host_tensor(),  // k_desc
      nullptr,                   // alpha_desc
      desc.group_host_tensor(),  // lda_desc
      desc.a_desc(),             // a_desc
      desc.group_host_tensor(),  // ldb_desc
      desc.b_desc(),             // b_desc
      nullptr,                   // beta_desc
      nullptr,                   // ldc_desc
      nullptr,                   // c_desc
      desc.group_host_tensor(),  // ldd_desc
      desc.d_desc(),             // d_desc
      &ws_bytes));
  return ws_bytes;
}

// Launches a grouped GEMM over `num_expert` groups via cnnlGroupGemm, with
// A non-transposed and B transposed. Per-group M counts live in device
// memory (`m`); N, K, lda, and the output leading dimension (set to n) are
// uniform across groups and replicated into host-side arrays.
//
// m:              device buffer of per-group row counts.
// alpha/beta:     optional scale factors; nullptr disables the corresponding
//                 descriptor as well.
// workspace:      scratch buffer of at least getGroupGemmWorkspaceSize bytes.
// ldb:            optional per-group leading dims for B; empty means the
//                 library default.
// NOTE(review): when desc.has_c() is false, d_desc() is passed as the C
// operand and the ldd array doubles as ldc — presumably D is read/written
// in place for the beta term; confirm against the CNNL reference.
void GroupGemm(const cnnlHandle_t &handle,
               GroupGemmDesc &desc,
               void *m,
               void *alpha,
               void *beta,
               void *workspace,
               size_t workspace_size,
               int num_expert,
               int k,
               int n,
               int lda,
               std::vector<int> &ldb) {
  // Uniform per-group dimensions, replicated once per expert.
  std::vector<int> dims_n(num_expert, n);
  std::vector<int> dims_k(num_expert, k);
  std::vector<int> lead_a(num_expert, lda);
  std::vector<int> lead_d(num_expert, n);  // output leading dim fixed to n

  // Resolve the optional operands once, up front.
  const bool has_ldb = !ldb.empty();
  auto alpha_desc = alpha ? desc.scale_factor_tensor() : nullptr;
  auto beta_desc = beta ? desc.scale_factor_tensor() : nullptr;
  auto c_desc = desc.has_c() ? desc.c_desc() : desc.d_desc();

  CNNL_CHECK_FATAL(cnnlGroupGemm(
      handle, desc.group_gemm_desc(),                /*desc*/
      nullptr,                                       /*algo*/
      num_expert,                                    /*groups*/
      false,                                         /*is_trans_a*/
      true,                                          /*is_trans_b*/
      desc.group_device_tensor(), m,                 /*m: per-group, device*/
      desc.group_host_tensor(), dims_n.data(),       /*n*/
      desc.group_host_tensor(), dims_k.data(),       /*k*/
      alpha_desc,                                    /*alpha_desc*/
      alpha,                                         /*alpha*/
      desc.group_host_tensor(), lead_a.data(),       /*lda*/
      desc.a_desc(),                                 /*a*/
      has_ldb ? desc.group_host_tensor() : nullptr,  /*ldb_desc*/
      has_ldb ? ldb.data() : nullptr,                /*ldb*/
      desc.b_desc(),                                 /*b*/
      beta_desc,                                     /*beta_desc*/
      beta,                                          /*beta*/
      desc.group_host_tensor(),                      /*ldc_desc*/
      lead_d.data(),                                 /*ldc*/
      c_desc,                                        /*c*/
      workspace, workspace_size,                     /*workspace*/
      desc.group_host_tensor(), lead_d.data(),       /*ldd*/
      desc.d_desc()));                               /*d*/
}

}  // namespace ops
}  // namespace tmo
|
||||
Reference in New Issue
Block a user