Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/csrc/torch_api/utils.h
2026-02-04 17:39:32 +08:00

165 lines
6.0 KiB
C++

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef CSRC_TORCH_API_UTILS_H_
#define CSRC_TORCH_API_UTILS_H_
#include <cstdint>
#include <initializer_list>
#include <map>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
#include <vector>
#include "ATen/ScalarType.h"
#include "ATen/Tensor.h"
#include "aten/cnnl/cnnlHandle.h"
#include "c10/util/Exception.h"
#include "cnnl.h"
#include "framework/core/MLUStream.h"
#include "framework/core/caching_allocator.h"
#include "framework/core/device.h"
#include "framework/core/mlu_guard.h"
#include "torch/torch.h"
#include "torch/version.h"
namespace tmo {
namespace torch_api {
// RAII owner for a CNNL tensor descriptor: the custom deleter invokes
// cnnlDestroyTensorDescriptor when the unique_ptr is destroyed.
using TensorDesc = std::unique_ptr<std::remove_pointer_t<cnnlTensorDescriptor_t>,
decltype(&cnnlDestroyTensorDescriptor)>;
// Builds one CNNL descriptor per input tensor (declared here, defined elsewhere).
std::vector<TensorDesc> createTensorDescs(const std::initializer_list<at::Tensor> &tensors);
// Translates an ATen scalar type to the matching cnnlDataType_t
// (declared here, defined elsewhere).
cnnlDataType_t getCnnlDataType(const at::ScalarType &data_type);
// Returns true when `tensor` (anything exposing a torch-style device query)
// resides on an MLU device. On torch >= 2.1 the MLU backend is registered as
// PrivateUse1; older builds expose an `is_mlu()` member instead.
template <typename T>
bool isMlu(const T &tensor) {
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 1)
  const auto device = tensor.device();
  return device.is_privateuseone();
#else
  return tensor.is_mlu();
#endif
}
// Selects which tensor attribute(s) the checkTensorSame* helpers compare.
enum class TensorAttr { DEVICE, DTYPE, ALL };
// Accumulated "expected" attributes while scanning a list of tensors.
// device_id == -1 and dtype == ScalarType::Undefined act as "not yet seen"
// sentinels (see checkDevice/checkDtype and checkTensorSameAttr's init).
struct attr_t {
int64_t device_id;
at::ScalarType dtype;
};
// Records the device id of the first tensor seen and validates every
// subsequent tensor against it.
//
// @param device_id in/out expected device id; -1 means "not yet recorded".
// @param tensor    tensor whose device id is recorded or validated.
// @throws c10::Error (via TORCH_CHECK) on a device mismatch.
inline void checkDevice(int64_t &device_id, const at::Tensor &tensor) {
  auto tensor_device_id = tensor.get_device();
  if (device_id == -1) {
    // First tensor: remember its device as the expected one.
    device_id = tensor_device_id;
    return;
  }
  // Fix: add the missing separator so the message does not render as
  // "original device_id: 0now device_id is: 1".
  TORCH_CHECK(tensor_device_id == device_id,
              "Tensor device id is not same, original device_id: ", device_id,
              ", now device_id is: ", tensor_device_id);
}
// Records the dtype of the first tensor seen and validates every subsequent
// tensor against it.
//
// @param dtype  in/out expected dtype; ScalarType::Undefined means "not yet
//               recorded".
// @param tensor tensor whose dtype is recorded or validated.
// @throws c10::Error (via TORCH_CHECK) on a dtype mismatch.
inline void checkDtype(at::ScalarType &dtype, const at::Tensor &tensor) {
  auto tensor_dtype = tensor.scalar_type();
  if (dtype == at::ScalarType::Undefined) {
    // First tensor: remember its dtype as the expected one.
    dtype = tensor_dtype;
    return;
  }
  // Fix: add the missing separator so the message does not render as
  // "original dtype: Floatnow dtype is: Half".
  TORCH_CHECK(tensor_dtype == dtype, "Tensor dtype is not same. original dtype: ", dtype,
              ", now dtype is: ", tensor_dtype);
}
// Dispatches to the device/dtype checkers selected by the compile-time
// template parameter `attr`. Uses `if constexpr` (C++17, which this header
// already requires via <optional>) so the untaken branches are discarded at
// compile time instead of relying on dead-code elimination.
template <TensorAttr attr>
inline void checkTensorAttr(attr_t &attr_states, const at::Tensor &tensor) {
  if constexpr (attr == TensorAttr::DEVICE) {
    checkDevice(attr_states.device_id, tensor);
  } else if constexpr (attr == TensorAttr::DTYPE) {
    checkDtype(attr_states.dtype, tensor);
  } else if constexpr (attr == TensorAttr::ALL) {
    checkDevice(attr_states.device_id, tensor);
    checkDtype(attr_states.dtype, tensor);
  }
}
// Optional-tensor overload: an absent optional or a present-but-undefined
// tensor is skipped; otherwise the tensor must be on MLU and must match the
// attributes accumulated in `attr_states`.
// @throws c10::Error for a non-MLU tensor or an attribute mismatch.
template <TensorAttr attr,
          typename T,
          typename = typename std::enable_if<
              std::is_same<typename std::decay<T>::type, at::Tensor>::value>::type>
void checkTensorSameWithSpecificAttr(attr_t &attr_states, const c10::optional<T> &tensor) {
  if (!tensor.has_value() || !tensor->defined()) return;
  // Bind by const reference: the previous `auto temp = tensor.value()` made
  // a needless copy of the Tensor handle.
  const T &temp_tensor = *tensor;
  TORCH_CHECK(isMlu(temp_tensor), "Only support mlu tensor.");
  checkTensorAttr<attr>(attr_states, temp_tensor);
}
// Plain-tensor overload: undefined tensors are ignored; defined tensors must
// be MLU tensors and agree with the attributes recorded so far.
// @throws c10::Error for a non-MLU tensor or an attribute mismatch.
template <TensorAttr attr,
          typename T,
          typename = typename std::enable_if<
              std::is_same<typename std::decay<T>::type, at::Tensor>::value>::type>
void checkTensorSameWithSpecificAttr(attr_t &attr_states, const T &tensor) {
  if (tensor.defined()) {
    TORCH_CHECK(isMlu(tensor), "Only support mlu tensor.");
    checkTensorAttr<attr>(attr_states, tensor);
  }
}
// Variadic fan-out: validates the head tensor via the single-tensor
// overloads above, then recurses on the remaining pack, preserving the
// left-to-right processing order (the first tensor seen defines the
// expected device/dtype).
template <TensorAttr attr, typename T, typename... Args>
void checkTensorSameWithSpecificAttr(attr_t &attr_states, const T &tensor, Args &&...args) {
checkTensorSameWithSpecificAttr<attr>(attr_states, tensor);
checkTensorSameWithSpecificAttr<attr>(attr_states, std::forward<Args>(args)...);
}
template <TensorAttr attr, typename... Args>
void checkTensorSameAttr(Args &&...args) {
attr_t attr_states = {-1, at::ScalarType::Undefined};
checkTensorSameWithSpecificAttr<attr>(attr_states, std::forward<Args>(args)...);
}
// Maps a dtype spelling ("float", "half", "bfloat16", "int32", "int8") to
// the corresponding torch scalar type.
// @throws c10::Error naming the offending string for unsupported inputs
//         (the previous `.at()` threw a bare std::out_of_range with no
//         context, inconsistent with this file's TORCH_CHECK convention).
inline at::ScalarType str2TorchDtype(const std::string &type) {
  static const std::map<std::string, at::ScalarType> dtype_map = {
      {"float", torch::kFloat32}, {"half", torch::kHalf}, {"bfloat16", torch::kBFloat16},
      {"int32", torch::kInt32},   {"int8", torch::kInt8},
  };
  auto it = dtype_map.find(type);
  TORCH_CHECK(it != dtype_map.end(), "Unsupported dtype string: ", type);
  return it->second;
}
// Maps a torch scalar type back to its string spelling.
// NOTE(review): keeps the existing signature, which returns a mutable
// reference into the static map; callers must not modify the result.
// @throws c10::Error for unsupported scalar types (previously a bare
//         std::out_of_range from map::at with no context).
inline std::string &torchDtype2Str(const at::ScalarType type) {
  static std::map<at::ScalarType, std::string> torch_dtype_map = {
      {torch::kFloat32, "float"}, {torch::kHalf, "half"}, {torch::kBFloat16, "bfloat16"},
      {torch::kInt32, "int32"},   {torch::kInt8, "int8"},
  };
  auto it = torch_dtype_map.find(type);
  TORCH_CHECK(it != torch_dtype_map.end(), "Unsupported torch dtype: ", type);
  return it->second;
}
// Maps a dtype spelling ("float", "half", "bfloat16", "int32", "int8") to
// the corresponding CNNL data type enum.
// @throws c10::Error naming the offending string for unsupported inputs
//         (the previous `.at()` threw a bare std::out_of_range with no
//         context, inconsistent with this file's TORCH_CHECK convention).
inline cnnlDataType_t str2CnnlDtype(const std::string &type) {
  static const std::map<std::string, cnnlDataType_t> cnnl_dtype_map = {
      {"float", CNNL_DTYPE_FLOAT}, {"half", CNNL_DTYPE_HALF}, {"bfloat16", CNNL_DTYPE_BFLOAT16},
      {"int32", CNNL_DTYPE_INT32}, {"int8", CNNL_DTYPE_INT8},
  };
  auto it = cnnl_dtype_map.find(type);
  TORCH_CHECK(it != cnnl_dtype_map.end(), "Unsupported dtype string: ", type);
  return it->second;
}
// Asserts tensor `x` has exactly the shape listed in the variadic dims.
// Macro arguments are parenthesized for hygiene with expression arguments.
#define CHECK_SHAPE(x, ...) \
  TORCH_CHECK((x).sizes() == torch::IntArrayRef({__VA_ARGS__}), \
              #x " must have shape (" #__VA_ARGS__ ")")
// Asserts tensor `x` is contiguous in memory.
#define CHECK_TENSOR_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous.")
// Asserts an optional tensor, when present, is contiguous. Wrapped in
// do/while(0) so the macro behaves as one statement: the previous bare
// `if (...) TORCH_CHECK(...)` would silently capture a following `else`.
#define CHECK_OPTIONAL_TENSOR_CONTIGUOUS(x) \
  do { \
    if ((x).has_value()) TORCH_CHECK((x).value().is_contiguous(), #x " must be contiguous."); \
  } while (0)
// Returns the data pointer of an optional tensor, or nullptr when absent.
// Also treats a present-but-undefined Tensor as "no tensor" — previously
// such a value reached data_ptr() on an undefined Tensor; this now matches
// both the plain-Tensor overload and the has_value()+defined() convention
// used by checkTensorSameWithSpecificAttr.
inline void *getAtTensorPtr(const c10::optional<at::Tensor> &tensor) {
  return (tensor.has_value() && tensor->defined()) ? tensor->data_ptr() : nullptr;
}
// Returns the tensor's data pointer, or nullptr for a default-constructed
// (undefined) tensor.
inline void *getAtTensorPtr(const at::Tensor &tensor) {
  if (!tensor.defined()) {
    return nullptr;
  }
  return tensor.data_ptr();
}
} // namespace torch_api
} // namespace tmo
#endif // CSRC_TORCH_API_UTILS_H_