/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #include "utils.h" #include "common/utils.h" namespace tmo { namespace torch_api { #define CNNL_TYPE_AND_SCALAR_TYPE_WITHOUT_64BIT(_) \ _(CNNL_DTYPE_FLOAT, at::kFloat) \ _(CNNL_DTYPE_BFLOAT16, at::kBFloat16) \ _(CNNL_DTYPE_HALF, at::kHalf) \ _(CNNL_DTYPE_INT32, at::kInt) \ _(CNNL_DTYPE_INT8, at::kChar) \ _(CNNL_DTYPE_UINT8, at::kByte) \ _(CNNL_DTYPE_BOOL, at::kBool) \ _(CNNL_DTYPE_INT16, at::kShort) \ _(CNNL_DTYPE_COMPLEX_HALF, at::kComplexHalf) \ _(CNNL_DTYPE_COMPLEX_FLOAT, at::kComplexFloat) cnnlDataType_t getCnnlDataType(const at::ScalarType &data_type) { switch (data_type) { #define DEFINE_CASE(cnnl_dtype, scalar_type) \ case scalar_type: \ return cnnl_dtype; CNNL_TYPE_AND_SCALAR_TYPE_WITHOUT_64BIT(DEFINE_CASE) #undef DEFINE_CASE case at::kLong: return CNNL_DTYPE_INT32; case at::kDouble: return CNNL_DTYPE_FLOAT; case at::kComplexDouble: return CNNL_DTYPE_COMPLEX_FLOAT; default: std::string msg("getCnnlDataType() not supported for "); throw std::runtime_error(msg + c10::toString(data_type)); } } std::vector createTensorDescs(const std::initializer_list &tensors) { std::vector descs; for (size_t i = 0; i < tensors.size(); ++i) { descs.emplace_back(TensorDesc{nullptr, cnnlDestroyTensorDescriptor}); auto tensor = tensors.begin()[i]; if (!tensor.defined()) { continue; } cnnlTensorDescriptor_t desc; CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&desc)); descs[i].reset(desc); cnnlDataType_t data_type = getCnnlDataType(tensor.scalar_type()); if (tensor.strides().size() == 0) { CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(descs[i].get(), CNNL_LAYOUT_ARRAY, data_type, tensor.sizes().size(), tensor.sizes().data())); } else { CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(descs[i].get(), CNNL_LAYOUT_ARRAY, data_type, tensor.sizes().size(), tensor.sizes().data(), tensor.strides().data())); } } return descs; } } // namespace torch_api } // namespace tmo