/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #ifndef CSRC_TORCH_API_UTILS_H_ #define CSRC_TORCH_API_UTILS_H_ #include #include #include #include #include "ATen/ScalarType.h" #include "ATen/Tensor.h" #include "aten/cnnl/cnnlHandle.h" #include "c10/util/Exception.h" #include "cnnl.h" #include "framework/core/MLUStream.h" #include "framework/core/caching_allocator.h" #include "framework/core/device.h" #include "framework/core/mlu_guard.h" #include "torch/torch.h" #include "torch/version.h" namespace tmo { namespace torch_api { using TensorDesc = std::unique_ptr, decltype(&cnnlDestroyTensorDescriptor)>; std::vector createTensorDescs(const std::initializer_list &tensors); cnnlDataType_t getCnnlDataType(const at::ScalarType &data_type); template bool isMlu(const T &tensor) { #if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1) return tensor.device().is_privateuseone(); #else return tensor.is_mlu(); #endif } enum class TensorAttr { DEVICE, DTYPE, ALL }; struct attr_t { int64_t device_id; at::ScalarType dtype; }; inline void checkDevice(int64_t &device_id, const at::Tensor &tensor) { auto tensor_device_id = tensor.get_device(); if (device_id == -1) { device_id = tensor_device_id; return; } TORCH_CHECK(tensor_device_id == device_id, "Tensor device id is not same, original device_id: ", device_id, "now device_id is: ", tensor_device_id); } inline void checkDtype(at::ScalarType &dtype, const at::Tensor &tensor) { auto tensor_dtype = tensor.scalar_type(); if (dtype == at::ScalarType::Undefined) { dtype = tensor_dtype; return; } TORCH_CHECK(tensor_dtype == dtype, "Tensor dtype is not same. original dtype: ", dtype, "now dtype is: ", tensor_dtype); } template inline void checkTensorAttr(attr_t &attr_states, const at::Tensor &tensor) { if (attr == TensorAttr::DEVICE) { checkDevice(attr_states.device_id, tensor); } else if (attr == TensorAttr::DTYPE) { checkDtype(attr_states.dtype, tensor); } else if (attr == TensorAttr::ALL) { checkDevice(attr_states.device_id, tensor); checkDtype(attr_states.dtype, tensor); } } template ::type, at::Tensor>::value>::type> void checkTensorSameWithSpecificAttr(attr_t &attr_states, const c10::optional &tensor) { if (!tensor.has_value() || !tensor->defined()) return; auto temp_tensor = tensor.value(); TORCH_CHECK(isMlu(temp_tensor), "Only support mlu tensor."); checkTensorAttr(attr_states, temp_tensor); } template ::type, at::Tensor>::value>::type> void checkTensorSameWithSpecificAttr(attr_t &attr_states, const T &tensor) { if (!tensor.defined()) return; TORCH_CHECK(isMlu(tensor), "Only support mlu tensor."); checkTensorAttr(attr_states, tensor); } template void checkTensorSameWithSpecificAttr(attr_t &attr_states, const T &tensor, Args &&...args) { checkTensorSameWithSpecificAttr(attr_states, tensor); checkTensorSameWithSpecificAttr(attr_states, std::forward(args)...); } template void checkTensorSameAttr(Args &&...args) { attr_t attr_states = {-1, at::ScalarType::Undefined}; checkTensorSameWithSpecificAttr(attr_states, std::forward(args)...); } inline at::ScalarType str2TorchDtype(const std::string &type) { static std::map dtype_map = { {"float", torch::kFloat32}, {"half", torch::kHalf}, {"bfloat16", torch::kBFloat16}, {"int32", torch::kInt32}, {"int8", torch::kInt8}, }; return dtype_map.at(type); } inline std::string &torchDtype2Str(const at::ScalarType type) { static std::map torch_dtype_map = { {torch::kFloat32, "float"}, {torch::kHalf, "half"}, {torch::kBFloat16, "bfloat16"}, {torch::kInt32, "int32"}, {torch::kInt8, "int8"}, }; return torch_dtype_map.at(type); } inline cnnlDataType_t str2CnnlDtype(const std::string &type) { static std::map cnnl_dtype_map = { {"float", CNNL_DTYPE_FLOAT}, {"half", CNNL_DTYPE_HALF}, {"bfloat16", CNNL_DTYPE_BFLOAT16}, {"int32", CNNL_DTYPE_INT32}, {"int8", CNNL_DTYPE_INT8}, }; return cnnl_dtype_map.at(type); } #define CHECK_SHAPE(x, ...) \ TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \ #x " must have shape (" #__VA_ARGS__ ")") #define CHECK_TENSOR_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.") #define CHECK_OPTIONAL_TENSOR_CONTIGUOUS(x) \ if (x.has_value()) TORCH_CHECK(x.value().is_contiguous(), #x " must be contiguous.") inline void *getAtTensorPtr(const c10::optional &tensor) { return tensor.has_value() ? tensor.value().data_ptr() : nullptr; } inline void *getAtTensorPtr(const at::Tensor &tensor) { return tensor.defined() ? tensor.data_ptr() : nullptr; } } // namespace torch_api } // namespace tmo #endif // CSRC_TORCH_API_UTILS_H_