130 lines
5.5 KiB
C++
130 lines
5.5 KiB
C++
/*************************************************************************
|
|
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*************************************************************************/
|
|
|
|
#include "torch_ops_api.h"
|
|
namespace tmo {
|
|
namespace torch_api {
|
|
|
|
// Fused 2-D matrix multiply on MLU, dispatched through cnnlMatMulEx.
// Roughly: output = act(alpha * op(a) x op(b) + bias) + beta * c, where
// op() applies the optional transposes — NOTE(review): the exact fusion
// semantics (activation placement, beta path) live inside MatMulDesc /
// cnnlMatMulEx; confirm against the CNNL documentation.
//
// Parameters:
//   a, b        : 2-D inputs; must share dtype, last dim must be contiguous.
//   d           : optional pre-allocated output; when present it is used as
//                 the result buffer directly (no allocation happens).
//   bias        : optional 1-D bias; viewed as (1, n) via unsqueeze(0).
//   c           : optional residual of shape (m, n); enables the beta path.
//   dtype       : output dtype name ("float" | "half" | "bfloat16") used
//                 only when d is absent; defaults to a's dtype.
//   act_mode    : activation-mode string forwarded into MatMulDesc.
//   alpha, beta : scaling factors; narrowed to float before the CNNL call.
//   fast_act, approximate : activation flags forwarded into MatMulDesc.
//   a_scale, b_scale      : quantization scales, used only on the int8 path.
//   trans_a, trans_b      : whether a / b enter the product transposed.
//
// Returns: the (m, n) result tensor (either d or a freshly allocated tensor).
at::Tensor matmul(const at::Tensor &a,
                  const at::Tensor &b,
                  const c10::optional<at::Tensor> &d,
                  const c10::optional<at::Tensor> &bias,
                  const c10::optional<at::Tensor> &c,
                  const c10::optional<std::string> &dtype,
                  const std::string &act_mode,
                  double alpha,
                  double beta,
                  bool fast_act,
                  bool approximate,
                  double a_scale,
                  double b_scale,
                  bool trans_a,
                  bool trans_b) {
  bool has_bias = bias.has_value();
  bool has_res = c.has_value();
  // beta is only applied when a residual tensor c is supplied.
  bool use_beta = false;
  // Logical GEMM dims after the optional transposes: output is (m, n),
  // contraction length is k.
  int m = trans_a ? a.size(1) : a.size(0);
  int n = trans_b ? b.size(0) : b.size(1);
  int k = trans_a ? a.size(0) : a.size(1);
  // check device: all provided tensors must live on the same device
  checkTensorSameAttr<TensorAttr::DEVICE>(a, b, c, bias);
  // check contiguous: CNNL needs the innermost dim dense; outer strides are
  // passed separately as lda/ldb/ldc below.
  TORCH_CHECK(a.stride(-1) == 1, "a last dim must be contiguous.")
  TORCH_CHECK(b.stride(-1) == 1, "b last dim must be contiguous.")
  if (has_bias) TORCH_CHECK(bias.value().is_contiguous(), "bias must be contiguous.")
  if (has_res) {
    use_beta = true;
    TORCH_CHECK(c.value().is_contiguous(), "c must be contiguous.")
    CHECK_SHAPE(c.value(), m, n);
  }
  // Bias is reshaped (1, n)-style so it broadcasts over output rows.
  at::Tensor bias_view;
  if (has_bias) {
    TORCH_CHECK(bias.value().dim() == 1, "bias must be 1-D tensor.")
    bias_view = bias.value().unsqueeze(0);
  }
  // get cnnl data type and init output
  auto a_dtype = getCnnlDataType(a.scalar_type());
  auto b_dtype = getCnnlDataType(b.scalar_type());
  TORCH_CHECK(a_dtype == b_dtype, "a, b must be same dtype.");
  // int8 inputs cannot infer an output dtype from a; caller must supply
  // either an output tensor d or an explicit dtype string.
  if (a_dtype == CNNL_DTYPE_INT8) {
    TORCH_CHECK(d.has_value() || dtype.has_value(),
                "d and d_dtype cant be none at the same time when input dtype is int8");
  }

  at::Tensor output;
  if (d.has_value()) {
    // Caller-provided output buffer — written in place and returned.
    output = d.value();
  } else {
    torch::Dtype out_dtype = a.scalar_type();
    if (dtype.has_value()) {
      TORCH_CHECK(
          dtype.value() == "float" || dtype.value() == "half" || dtype.value() == "bfloat16",
          "data_type must be 'float', 'half' or 'bfloat16'.");
      out_dtype = dtype.value() == "float" ? torch::kFloat32
                  : dtype.value() == "half" ? torch::kFloat16
                                            : torch::kBFloat16;
    }
    output = at::empty({m, n}, a.options().dtype(out_dtype));
  }

  // On-chip accumulate/output dtype for c and output descriptors.
  // NOTE(review): bf16 inputs force a float on-chip dtype here — presumably
  // a CNNL accumulation requirement; confirm against CNNL docs.
  auto cd_dtype = getCnnlDataType(output.scalar_type());
  if (a_dtype == CNNL_DTYPE_BFLOAT16) cd_dtype = CNNL_DTYPE_FLOAT;
  // create tensor descs — positional mapping used below:
  //   descs[0]=a, descs[1]=b, descs[2]=bias_view, descs[3]=c, descs[4]=output.
  // When c is absent, a default-constructed tensor is used as a placeholder.
  auto descs = createTensorDescs({a, b, bias_view, c.value_or(at::Tensor()), output});
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[0].get(), a_dtype));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[1].get(), b_dtype));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[4].get(), cd_dtype));
  if (has_res) {
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[3].get(), cd_dtype));
  }
  // int8 path: convert the caller's float scales into CNNL's
  // position-and-scale fixed-point representation (value ~= 2^pos * scale).
  // NOTE(review): the pos/scale split below looks like the standard CNNL
  // quantization derivation — verify against the CNNL quantization docs.
  if (a_dtype == CNNL_DTYPE_INT8) {
    float int_max = 127.0;
    int quant_bit = 8;
    // Largest representable magnitude implied by each scale.
    float max_a = int_max / a_scale;
    float max_b = int_max / b_scale;
    int pos_a = std::floor(std::log2(max_a) - (quant_bit - 2));
    int pos_b = std::floor(std::log2(max_b) - (quant_bit - 2));
    // Residual scale after factoring out the power-of-two position.
    float new_a_scale = std::pow(2.0f, pos_a) * a_scale;
    float new_b_scale = std::pow(2.0f, pos_b) * b_scale;
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[0].get(), pos_a, new_a_scale));
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[1].get(), pos_b, new_b_scale));
    TORCH_CHECK(cd_dtype != CNNL_DTYPE_BFLOAT16,
                "output dtype cannot be bfloat16 when a/b is fixed-point")
  }

  // get && set Op desc — leading dimensions come from the row strides, so
  // non-contiguous outer dims are supported for a/b/output.
  int lda = a.stride(0);
  int ldb = b.stride(0);
  int ldc = output.stride(0);
  // workspace_size is an out-parameter filled in by MatMulDesc.
  size_t workspace_size = 0;
  auto matmul_desc =
      tmo::op_desc::MatMulDesc(descs[2].get(), getAtTensorPtr(bias_view), workspace_size, lda, ldb,
                               ldc, act_mode, use_beta, trans_a, trans_b, fast_act, approximate);
  // Ensure the CNNL handle and workspace allocation target a's device.
  const torch_mlu::mlu::MLUGuard device_guard(a.device());
  auto handle = torch_mlu::getCurrentHandle();
  auto workspace = at::empty({static_cast<int64_t>(workspace_size)}, a.options().dtype(at::kByte));

  // run forward — alpha/beta narrowed to float as cnnlMatMulEx expects.
  float alpha_f = alpha;
  float beta_f = beta;
  // Profiling annotation (cnpx range) around the kernel launch.
  MatmulTheory obj(m, k, n, has_res, a_dtype, cd_dtype);
  cnpxPush(obj);
  CNNL_CHECK_FATAL(cnnlMatMulEx(handle, matmul_desc, &alpha_f, descs[0].get(), getAtTensorPtr(a),
                                descs[1].get(), getAtTensorPtr(b), &beta_f, descs[3].get(),
                                getAtTensorPtr(c), descs[4].get(), getAtTensorPtr(output), nullptr,
                                getAtTensorPtr(workspace), workspace_size));
  cnpxPop();
  return output;
}
|
|
} // namespace torch_api
|
|
} // namespace tmo
|