Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/torch_api/matmul.cpp
2026-02-04 17:39:32 +08:00

130 lines
5.5 KiB
C++

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#include "torch_ops_api.h"
namespace tmo {
namespace torch_api {
// Fused 2-D matmul on MLU via cnnlMatMulEx:
//   output = act(alpha * op(a) x op(b) + beta * c) (+ bias broadcast over rows)
// where op() optionally transposes per trans_a/trans_b.
//
// Params:
//   a, b        input matrices; must share device and dtype, last dim contiguous.
//   d           optional preallocated output buffer; when absent an (m, n)
//               tensor is allocated (dtype from `dtype` string or a's dtype).
//   bias        optional 1-D bias; viewed as [1, n] for broadcast by CNNL.
//   c           optional residual input of shape (m, n); enables the beta term.
//   dtype       optional output dtype name: "float", "half" or "bfloat16".
//   act_mode    activation selector forwarded to MatMulDesc (semantics defined
//               by tmo::op_desc::MatMulDesc).
//   alpha, beta scaling factors passed to cnnlMatMulEx (narrowed to float).
//   fast_act, approximate  activation tuning flags forwarded to MatMulDesc.
//   a_scale, b_scale       quantization scales; only used on the int8 path.
//   trans_a, trans_b       transpose flags for a and b respectively.
// Returns: the output tensor (d if provided, otherwise freshly allocated).
at::Tensor matmul(const at::Tensor &a,
                  const at::Tensor &b,
                  const c10::optional<at::Tensor> &d,
                  const c10::optional<at::Tensor> &bias,
                  const c10::optional<at::Tensor> &c,
                  const c10::optional<std::string> &dtype,
                  const std::string &act_mode,
                  double alpha,
                  double beta,
                  bool fast_act,
                  bool approximate,
                  double a_scale,
                  double b_scale,
                  bool trans_a,
                  bool trans_b) {
  bool has_bias = bias.has_value();
  bool has_res = c.has_value();
  bool use_beta = false;
  // a and b are addressed strictly as 2-D matrices below
  // (size(0)/size(1)/stride(0)), so reject other ranks up front.
  TORCH_CHECK(a.dim() == 2, "a must be 2-D tensor.")
  TORCH_CHECK(b.dim() == 2, "b must be 2-D tensor.")
  int m = trans_a ? a.size(1) : a.size(0);
  int n = trans_b ? b.size(0) : b.size(1);
  int k = trans_a ? a.size(0) : a.size(1);
  // The contraction dimension must agree between op(a) and op(b).
  TORCH_CHECK((trans_b ? b.size(1) : b.size(0)) == k,
              "inner dimensions of a and b must match.")
  // check device (a/b dtype equality is checked after CNNL type lookup below)
  checkTensorSameAttr<TensorAttr::DEVICE>(a, b, c, bias);
  // check contiguous: leading dims are passed to CNNL via strides, but the
  // innermost dim must be dense.
  TORCH_CHECK(a.stride(-1) == 1, "a last dim must be contiguous.")
  TORCH_CHECK(b.stride(-1) == 1, "b last dim must be contiguous.")
  if (has_bias) TORCH_CHECK(bias.value().is_contiguous(), "bias must be contiguous.")
  if (has_res) {
    use_beta = true;
    TORCH_CHECK(c.value().is_contiguous(), "c must be contiguous.")
    CHECK_SHAPE(c.value(), m, n);
  }
  at::Tensor bias_view;
  if (has_bias) {
    TORCH_CHECK(bias.value().dim() == 1, "bias must be 1-D tensor.")
    // Present bias as a [1, n] row so CNNL broadcasts it over m rows.
    bias_view = bias.value().unsqueeze(0);
  }
  // get cnnl data type and init output
  auto a_dtype = getCnnlDataType(a.scalar_type());
  auto b_dtype = getCnnlDataType(b.scalar_type());
  TORCH_CHECK(a_dtype == b_dtype, "a, b must be same dtype.");
  if (a_dtype == CNNL_DTYPE_INT8) {
    TORCH_CHECK(d.has_value() || dtype.has_value(),
                "d and d_dtype cant be none at the same time when input dtype is int8");
  }
  at::Tensor output;
  if (d.has_value()) {
    // Caller-provided output buffer: validate it instead of trusting it.
    // The kernel writes m*n elements and ldc is taken from stride(0), so a
    // wrong shape or a non-dense last dim would corrupt memory silently.
    output = d.value();
    CHECK_SHAPE(output, m, n);
    TORCH_CHECK(output.stride(-1) == 1, "d last dim must be contiguous.")
  } else {
    torch::Dtype out_dtype = a.scalar_type();
    if (dtype.has_value()) {
      TORCH_CHECK(
          dtype.value() == "float" || dtype.value() == "half" || dtype.value() == "bfloat16",
          "data_type must be 'float', 'half' or 'bfloat16'.");
      out_dtype = dtype.value() == "float" ? torch::kFloat32
                  : dtype.value() == "half" ? torch::kFloat16
                                            : torch::kBFloat16;
    }
    output = at::empty({m, n}, a.options().dtype(out_dtype));
  }
  auto cd_dtype = getCnnlDataType(output.scalar_type());
  // For bf16 inputs the on-chip accumulate/output type is forced to fp32.
  if (a_dtype == CNNL_DTYPE_BFLOAT16) cd_dtype = CNNL_DTYPE_FLOAT;
  // create tensor desc; index order matches the cnnlMatMulEx call below:
  // [0]=a, [1]=b, [2]=bias, [3]=c (residual), [4]=output
  auto descs = createTensorDescs({a, b, bias_view, c.value_or(at::Tensor()), output});
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[0].get(), a_dtype));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[1].get(), b_dtype));
  CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[4].get(), cd_dtype));
  if (has_res) {
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorOnchipDataType(descs[3].get(), cd_dtype));
  }
  if (a_dtype == CNNL_DTYPE_INT8) {
    // Fixed-point path: derive a power-of-two position plus residual scale
    // from the caller's float scales, as required by the CNNL
    // position+scale quantization scheme.
    float int_max = 127.0;
    int quant_bit = 8;
    float max_a = int_max / a_scale;
    float max_b = int_max / b_scale;
    int pos_a = std::floor(std::log2(max_a) - (quant_bit - 2));
    int pos_b = std::floor(std::log2(max_b) - (quant_bit - 2));
    float new_a_scale = std::pow(2.0f, pos_a) * a_scale;
    float new_b_scale = std::pow(2.0f, pos_b) * b_scale;
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[0].get(), pos_a, new_a_scale));
    CNNL_CHECK_FATAL(cnnlSetTensorDescriptorPositionAndScale(descs[1].get(), pos_b, new_b_scale));
    TORCH_CHECK(cd_dtype != CNNL_DTYPE_BFLOAT16,
                "output dtype cannot be bfloat16 when a/b is fixed-point")
  }
  // get && set Op desc; leading dims come from the row strides.
  int lda = a.stride(0);
  int ldb = b.stride(0);
  int ldc = output.stride(0);
  size_t workspace_size = 0;  // filled in by MatMulDesc
  auto matmul_desc =
      tmo::op_desc::MatMulDesc(descs[2].get(), getAtTensorPtr(bias_view), workspace_size, lda, ldb,
                               ldc, act_mode, use_beta, trans_a, trans_b, fast_act, approximate);
  const torch_mlu::mlu::MLUGuard device_guard(a.device());
  auto handle = torch_mlu::getCurrentHandle();
  auto workspace = at::empty({static_cast<int64_t>(workspace_size)}, a.options().dtype(at::kByte));
  // run forward (cnnlMatMulEx takes float scalars)
  float alpha_f = alpha;
  float beta_f = beta;
  MatmulTheory obj(m, k, n, has_res, a_dtype, cd_dtype);
  cnpxPush(obj);
  CNNL_CHECK_FATAL(cnnlMatMulEx(handle, matmul_desc, &alpha_f, descs[0].get(), getAtTensorPtr(a),
                                descs[1].get(), getAtTensorPtr(b), &beta_f, descs[3].get(),
                                getAtTensorPtr(c), descs[4].get(), getAtTensorPtr(output), nullptr,
                                getAtTensorPtr(workspace), workspace_size));
  cnpxPop();
  return output;
}
} // namespace torch_api
} // namespace tmo