/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #include #include #include "kernel_api.h" namespace tmo { namespace ops { size_t getGroupGemmWorkspaceSize(cnnlHandle_t handle, GroupGemmDesc &desc, const int num_expert) { size_t workspace_size = 0; CNNL_CHECK_FATAL( cnnlGetGroupGemmWorkspaceSize(handle, // handle desc.group_gemm_desc(), // cnnlGroupGemmDescriptor_t nullptr, // cnnlGroupGemmAlgo_t num_expert, // groups false, // is_trans_a true, // is_trans_b desc.group_host_tensor(), // m_desc desc.group_host_tensor(), // n_desc desc.group_host_tensor(), // k_desc nullptr, // alpha_desc desc.group_host_tensor(), // lda_desc desc.a_desc(), // a_desc desc.group_host_tensor(), // ldb_desc desc.b_desc(), // b_desc nullptr, // beta_desc nullptr, // ldc_desc nullptr, // c_desc desc.group_host_tensor(), // ldd_desc desc.d_desc(), // d_desc &workspace_size)); return workspace_size; } void GroupGemm(const cnnlHandle_t &handle, GroupGemmDesc &desc, void *m, void *alpha, void *beta, void *workspace, size_t workspace_size, int num_expert, int k, int n, int lda, std::vector &ldb) { std::vector n_array(num_expert, n); std::vector k_array(num_expert, k); std::vector lda_array(num_expert, lda); std::vector ldd_array(num_expert, n); CNNL_CHECK_FATAL(cnnlGroupGemm(handle, desc.group_gemm_desc(), /*desc*/ nullptr, /*algo*/ num_expert, /*groups*/ false, /*is_trans_a*/ true, /*is_trans_b*/ desc.group_device_tensor(), m, /*m*/ desc.group_host_tensor(), n_array.data(), /*n*/ desc.group_host_tensor(), k_array.data(), /*k*/ alpha ? desc.scale_factor_tensor() : nullptr, /*alpha_desc*/ alpha, /*alpha*/ desc.group_host_tensor(), lda_array.data(), /*lda*/ desc.a_desc(), /*a*/ ldb.empty() ? nullptr : desc.group_host_tensor(), ldb.empty() ? nullptr : ldb.data(), /*ldb*/ desc.b_desc(), /*b*/ beta ? desc.scale_factor_tensor() : nullptr, /*beta_desc*/ beta, /*beta*/ desc.group_host_tensor(), /*ldc_desc*/ ldd_array.data(), /*ldc*/ desc.has_c() ? desc.c_desc() : desc.d_desc(), /*c*/ workspace, workspace_size, /*workspace*/ desc.group_host_tensor(), ldd_array.data(), /*ldd*/ desc.d_desc())); /*d*/ } } // namespace ops } // namespace tmo