/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/
#include "torch_ops_api.h"

namespace tmo {
namespace torch_api {

// Matrix multiply on MLU via cnnlMatMulEx:
//   output = act( alpha * op(a) @ op(b) + bias + beta * c )
// where op() optionally transposes a/b.
//
// Parameters:
//   a, b       : 2-D input matrices; last dim must be contiguous.
//   d          : optional preallocated output tensor; when absent the output
//                is allocated here (shape {m, n}).
//   bias       : optional 1-D bias, broadcast over rows.
//   c          : optional residual input of shape {m, n}; enables the beta term.
//   dtype      : optional output dtype name: "float", "half" or "bfloat16".
//                Required (or d) when inputs are int8.
//   act_mode   : activation fused into the matmul (interpreted by MatMulDesc).
//   alpha,beta : GEMM scaling factors (converted to float for CNNL).
//   fast_act, approximate : activation tuning flags forwarded to MatMulDesc.
//   a_scale, b_scale      : quantization scales used only for int8 inputs.
//   trans_a, trans_b      : transpose flags for a / b.
// Returns: the output tensor (d when provided, otherwise newly allocated).
//
// NOTE(review): the c10::optional template arguments and the static_cast
// target type had been stripped (likely by text extraction); they are
// restored here from usage: d/bias/c hold tensors, dtype holds a string.
at::Tensor matmul(const at::Tensor &a, const at::Tensor &b,
                  const c10::optional<at::Tensor> &d,
                  const c10::optional<at::Tensor> &bias,
                  const c10::optional<at::Tensor> &c,
                  const c10::optional<std::string> &dtype,
                  const std::string &act_mode, double alpha, double beta,
                  bool fast_act, bool approximate, double a_scale,
                  double b_scale, bool trans_a, bool trans_b) {
  bool has_bias = bias.has_value();
  bool has_res = c.has_value();
  bool use_beta = false;
  // Logical GEMM dims after the optional transposes.
  int m = trans_a ? a.size(1) : a.size(0);
  int n = trans_b ? b.size(0) : b.size(1);
  int k = trans_a ? a.size(0) : a.size(1);
  // check device and dtype
  checkTensorSameAttr(a, b, c, bias);
  // check contiguous
  TORCH_CHECK(a.stride(-1) == 1, "a last dim must be contiguous.");
  TORCH_CHECK(b.stride(-1) == 1, "b last dim must be contiguous.");
  if (has_bias)
    TORCH_CHECK(bias.value().is_contiguous(), "bias must be contiguous.");
  if (has_res) {
    use_beta = true;
    TORCH_CHECK(c.value().is_contiguous(), "c must be contiguous.");
    CHECK_SHAPE(c.value(), m, n);
  }
  at::Tensor bias_view;
  if (has_bias) {
    TORCH_CHECK(bias.value().dim() == 1, "bias must be 1-D tensor.");
    // CNNL expects a 2-D bias descriptor; view the 1-D bias as {1, n}.
    bias_view = bias.value().unsqueeze(0);
  }
  // get cnnl data type and init output
  auto a_dtype = getCnnlDataType(a.scalar_type());
  auto b_dtype = getCnnlDataType(b.scalar_type());
  TORCH_CHECK(a_dtype == b_dtype, "a, b must be same dtype.");
  if (a_dtype == CNNL_DTYPE_INT8) {
    // int8 inputs cannot infer an output dtype from a; caller must pick one.
    TORCH_CHECK(
        d.has_value() || dtype.has_value(),
        "d and d_dtype cant be none at the same time when input dtype is int8");
  }
  at::Tensor output;
  if (d.has_value()) {
    output = d.value();
  } else {
    torch::Dtype out_dtype = a.scalar_type();
    if (dtype.has_value()) {
      TORCH_CHECK(dtype.value() == "float" || dtype.value() == "half" ||
                      dtype.value() == "bfloat16",
                  "data_type must be 'float', 'half' or 'bfloat16'.");
      out_dtype = dtype.value() == "float"  ? torch::kFloat32
                  : dtype.value() == "half" ? torch::kFloat16
                                            : torch::kBFloat16;
    }
    output = at::empty({m, n}, a.options().dtype(out_dtype));
  }
  auto cd_dtype = getCnnlDataType(output.scalar_type());
  // bf16 accumulation happens in fp32 on chip.
  if (a_dtype == CNNL_DTYPE_BFLOAT16) cd_dtype = CNNL_DTYPE_FLOAT;
  // create tensor desc
  // Desc order: [0]=a, [1]=b, [2]=bias, [3]=c (residual), [4]=output.
  auto descs =
      createTensorDescs({a, b, bias_view, c.value_or(at::Tensor()), output});
  CNNL_CHECK_FATAL(
      cnnlSetTensorDescriptorOnchipDataType(descs[0].get(), a_dtype));
  CNNL_CHECK_FATAL(
      cnnlSetTensorDescriptorOnchipDataType(descs[1].get(), b_dtype));
  CNNL_CHECK_FATAL(
      cnnlSetTensorDescriptorOnchipDataType(descs[4].get(), cd_dtype));
  if (has_res) {
    CNNL_CHECK_FATAL(
        cnnlSetTensorDescriptorOnchipDataType(descs[3].get(), cd_dtype));
  }
  if (a_dtype == CNNL_DTYPE_INT8) {
    // Convert the caller-supplied float scales into CNNL's
    // position-and-scale fixed-point representation:
    // value ~= int8 * 2^pos * scale.
    float int_max = 127.0;
    int quant_bit = 8;
    float max_a = int_max / a_scale;
    float max_b = int_max / b_scale;
    int pos_a = std::floor(std::log2(max_a) - (quant_bit - 2));
    int pos_b = std::floor(std::log2(max_b) - (quant_bit - 2));
    float new_a_scale = std::pow(2.0f, pos_a) * a_scale;
    float new_b_scale = std::pow(2.0f, pos_b) * b_scale;
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(
        descs[0].get(), pos_a, new_a_scale));
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(
        descs[1].get(), pos_b, new_b_scale));
    TORCH_CHECK(cd_dtype != CNNL_DTYPE_BFLOAT16,
                "output dtype cannot be bfloat16 when a/b is fixed-point");
  }
  // get && set Op desc
  // Leading dims come from the row strides (last dim is contiguous, checked
  // above). workspace_size is filled in by MatMulDesc.
  int lda = a.stride(0);
  int ldb = b.stride(0);
  int ldc = output.stride(0);
  size_t workspace_size = 0;
  auto matmul_desc = tmo::op_desc::MatMulDesc(
      descs[2].get(), getAtTensorPtr(bias_view), workspace_size, lda, ldb, ldc,
      act_mode, use_beta, trans_a, trans_b, fast_act, approximate);
  const torch_mlu::mlu::MLUGuard device_guard(a.device());
  auto handle = torch_mlu::getCurrentHandle();
  auto workspace = at::empty({static_cast<int64_t>(workspace_size)},
                             a.options().dtype(at::kByte));
  // run forward
  float alpha_f = alpha;
  float beta_f = beta;
  // Profiling range annotated with the theoretical cost of this GEMM.
  MatmulTheory obj(m, k, n, has_res, a_dtype, cd_dtype);
  cnpxPush(obj);
  CNNL_CHECK_FATAL(cnnlMatMulEx(
      handle, matmul_desc, &alpha_f, descs[0].get(), getAtTensorPtr(a),
      descs[1].get(), getAtTensorPtr(b), &beta_f, descs[3].get(),
      getAtTensorPtr(c), descs[4].get(), getAtTensorPtr(output), nullptr,
      getAtTensorPtr(workspace), workspace_size));
  cnpxPop();
  return output;
}

}  // namespace torch_api
}  // namespace tmo