165 lines
6.0 KiB
C++
165 lines
6.0 KiB
C++
/*************************************************************************
|
|
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
|
*
|
|
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
|
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
|
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
|
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
|
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
|
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
*************************************************************************/
|
|
|
|
#ifndef CSRC_TORCH_API_UTILS_H_
|
|
#define CSRC_TORCH_API_UTILS_H_
|
|
|
|
#include <cstdint>
|
|
#include <map>
|
|
#include <optional>
|
|
#include <string>
|
|
#include "ATen/ScalarType.h"
|
|
#include "ATen/Tensor.h"
|
|
#include "aten/cnnl/cnnlHandle.h"
|
|
#include "c10/util/Exception.h"
|
|
#include "cnnl.h"
|
|
#include "framework/core/MLUStream.h"
|
|
#include "framework/core/caching_allocator.h"
|
|
#include "framework/core/device.h"
|
|
#include "framework/core/mlu_guard.h"
|
|
#include "torch/torch.h"
|
|
#include "torch/version.h"
|
|
|
|
namespace tmo {
|
|
namespace torch_api {
|
|
// RAII owner for a cnnlTensorDescriptor_t: the descriptor is released via
// cnnlDestroyTensorDescriptor when the unique_ptr goes out of scope.
using TensorDesc = std::unique_ptr<std::remove_pointer_t<cnnlTensorDescriptor_t>,
                                   decltype(&cnnlDestroyTensorDescriptor)>;

// Builds one CNNL tensor descriptor per input tensor (defined out-of-line).
std::vector<TensorDesc> createTensorDescs(const std::initializer_list<at::Tensor> &tensors);

// Maps a torch scalar type to the matching CNNL data type (defined out-of-line).
cnnlDataType_t getCnnlDataType(const at::ScalarType &data_type);
|
|
|
|
// Returns true when `tensor` is located on the MLU device.
// Templated so it accepts any type exposing the torch Tensor device queries.
template <typename T>
bool isMlu(const T &tensor) {
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
  // torch >= 2.1: out-of-tree device backends report as PrivateUse1.
  return tensor.device().is_privateuseone();
#else
  // Older torch builds expose a dedicated is_mlu() predicate instead.
  return tensor.is_mlu();
#endif
}
|
|
|
|
// Selects which tensor attribute(s) checkTensorAttr validates.
enum class TensorAttr { DEVICE, DTYPE, ALL };

// Accumulator for cross-tensor consistency checks. Seeded with
// {-1, Undefined} (see checkTensorSameAttr); the first tensor seen
// supplies the reference device id / dtype.
struct attr_t {
  int64_t device_id;     // -1 until the first tensor is inspected
  at::ScalarType dtype;  // Undefined until the first tensor is inspected
};
|
|
|
|
// Validates that `tensor` resides on the expected device.
// On the first call (device_id == -1) the tensor's device id is recorded
// into `device_id`; every later call must match it or TORCH_CHECK throws.
inline void checkDevice(int64_t &device_id, const at::Tensor &tensor) {
  auto tensor_device_id = tensor.get_device();
  if (device_id == -1) {
    // First tensor seen: adopt its device id as the reference value.
    device_id = tensor_device_id;
    return;
  }
  // Fix: the original message had no separator between the recorded id and
  // the next clause, producing e.g. "device_id: 0now device_id is: 1".
  TORCH_CHECK(tensor_device_id == device_id,
              "Tensor device id is not same, original device_id: ", device_id,
              ", now device_id is: ", tensor_device_id);
}
|
|
|
|
// Validates that `tensor` has the expected scalar type.
// On the first call (dtype == Undefined) the tensor's dtype is recorded
// into `dtype`; every later call must match it or TORCH_CHECK throws.
inline void checkDtype(at::ScalarType &dtype, const at::Tensor &tensor) {
  auto tensor_dtype = tensor.scalar_type();
  if (dtype == at::ScalarType::Undefined) {
    // First tensor seen: adopt its dtype as the reference value.
    dtype = tensor_dtype;
    return;
  }
  // Fix: the original message had no separator between the recorded dtype and
  // the next clause, running the two values together in the error text.
  TORCH_CHECK(tensor_dtype == dtype, "Tensor dtype is not same. original dtype: ", dtype,
              ", now dtype is: ", tensor_dtype);
}
|
|
|
|
template <TensorAttr attr>
|
|
inline void checkTensorAttr(attr_t &attr_states, const at::Tensor &tensor) {
|
|
if (attr == TensorAttr::DEVICE) {
|
|
checkDevice(attr_states.device_id, tensor);
|
|
} else if (attr == TensorAttr::DTYPE) {
|
|
checkDtype(attr_states.dtype, tensor);
|
|
} else if (attr == TensorAttr::ALL) {
|
|
checkDevice(attr_states.device_id, tensor);
|
|
checkDtype(attr_states.dtype, tensor);
|
|
}
|
|
}
|
|
|
|
template <TensorAttr attr,
|
|
typename T,
|
|
typename = typename std::enable_if<
|
|
std::is_same<typename std::decay<T>::type, at::Tensor>::value>::type>
|
|
void checkTensorSameWithSpecificAttr(attr_t &attr_states, const c10::optional<T> &tensor) {
|
|
if (!tensor.has_value() || !tensor->defined()) return;
|
|
auto temp_tensor = tensor.value();
|
|
TORCH_CHECK(isMlu(temp_tensor), "Only support mlu tensor.");
|
|
checkTensorAttr<attr>(attr_states, temp_tensor);
|
|
}
|
|
|
|
template <TensorAttr attr,
|
|
typename T,
|
|
typename = typename std::enable_if<
|
|
std::is_same<typename std::decay<T>::type, at::Tensor>::value>::type>
|
|
void checkTensorSameWithSpecificAttr(attr_t &attr_states, const T &tensor) {
|
|
if (!tensor.defined()) return;
|
|
TORCH_CHECK(isMlu(tensor), "Only support mlu tensor.");
|
|
checkTensorAttr<attr>(attr_states, tensor);
|
|
}
|
|
|
|
template <TensorAttr attr, typename T, typename... Args>
|
|
void checkTensorSameWithSpecificAttr(attr_t &attr_states, const T &tensor, Args &&...args) {
|
|
checkTensorSameWithSpecificAttr<attr>(attr_states, tensor);
|
|
checkTensorSameWithSpecificAttr<attr>(attr_states, std::forward<Args>(args)...);
|
|
}
|
|
|
|
template <TensorAttr attr, typename... Args>
|
|
void checkTensorSameAttr(Args &&...args) {
|
|
attr_t attr_states = {-1, at::ScalarType::Undefined};
|
|
checkTensorSameWithSpecificAttr<attr>(attr_states, std::forward<Args>(args)...);
|
|
}
|
|
|
|
// Maps a dtype name ("float", "half", "bfloat16", "int32", "int8") to the
// corresponding torch scalar type; throws std::out_of_range for other names.
inline at::ScalarType str2TorchDtype(const std::string &type) {
  static const std::map<std::string, at::ScalarType> name_to_torch_dtype = {
      {"float", torch::kFloat32},
      {"half", torch::kHalf},
      {"bfloat16", torch::kBFloat16},
      {"int32", torch::kInt32},
      {"int8", torch::kInt8},
  };
  return name_to_torch_dtype.at(type);
}
|
|
|
|
// Maps a torch scalar type back to its string name; throws std::out_of_range
// for unsupported types.
// Fix: return a const reference — the original returned a mutable reference
// into the shared static map, letting any caller corrupt the table for the
// whole process.
inline const std::string &torchDtype2Str(const at::ScalarType type) {
  static const std::map<at::ScalarType, std::string> torch_dtype_map = {
      {torch::kFloat32, "float"}, {torch::kHalf, "half"}, {torch::kBFloat16, "bfloat16"},
      {torch::kInt32, "int32"}, {torch::kInt8, "int8"},
  };
  return torch_dtype_map.at(type);
}
|
|
|
|
// Maps a dtype name ("float", "half", "bfloat16", "int32", "int8") to the
// corresponding CNNL data type; throws std::out_of_range for other names.
inline cnnlDataType_t str2CnnlDtype(const std::string &type) {
  static const std::map<std::string, cnnlDataType_t> name_to_cnnl_dtype = {
      {"float", CNNL_DTYPE_FLOAT},
      {"half", CNNL_DTYPE_HALF},
      {"bfloat16", CNNL_DTYPE_BFLOAT16},
      {"int32", CNNL_DTYPE_INT32},
      {"int8", CNNL_DTYPE_INT8},
  };
  return name_to_cnnl_dtype.at(type);
}
|
|
|
|
// Asserts that tensor `x` has exactly the shape given by the variadic
// dimension list, e.g. CHECK_SHAPE(t, 4, 8); raises via TORCH_CHECK otherwise.
#define CHECK_SHAPE(x, ...) \
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
#x " must have shape (" #__VA_ARGS__ ")")
|
|
|
|
// Asserts that tensor `x` is contiguous in memory.
#define CHECK_TENSOR_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
|
|
|
|
// Asserts contiguity only when the optional tensor `x` holds a value.
// NOTE(review): expands to a bare if-statement — using it inside an
// unbraced if/else invites the dangling-else problem; confirm call sites.
#define CHECK_OPTIONAL_TENSOR_CONTIGUOUS(x) \
if (x.has_value()) TORCH_CHECK(x.value().is_contiguous(), #x " must be contiguous.")
|
|
|
|
// Returns the raw data pointer of an optional tensor, or nullptr when the
// optional is empty.
inline void *getAtTensorPtr(const c10::optional<at::Tensor> &tensor) {
  if (!tensor.has_value()) {
    return nullptr;
  }
  return tensor->data_ptr();
}
|
|
|
|
// Returns the raw data pointer of `tensor`, or nullptr for an undefined tensor.
inline void *getAtTensorPtr(const at::Tensor &tensor) {
  if (!tensor.defined()) {
    return nullptr;
  }
  return tensor.data_ptr();
}
|
|
|
|
} // namespace torch_api
|
|
} // namespace tmo
|
|
|
|
#endif // CSRC_TORCH_API_UTILS_H_
|