add ops
This commit is contained in:
75
torch_mlu_ops-v1.3.2/csrc/torch_api/utils.cpp
Normal file
75
torch_mlu_ops-v1.3.2/csrc/torch_api/utils.cpp
Normal file
@@ -0,0 +1,75 @@
|
||||
/*************************************************************************
|
||||
* Copyright (C) [2023-2024] by Cambricon, Inc.
|
||||
*
|
||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
|
||||
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
|
||||
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
|
||||
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
|
||||
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||
*************************************************************************/
|
||||
|
||||
#include "utils.h"
|
||||
#include "common/utils.h"
|
||||
|
||||
namespace tmo {
|
||||
namespace torch_api {
|
||||
|
||||
#define CNNL_TYPE_AND_SCALAR_TYPE_WITHOUT_64BIT(_) \
|
||||
_(CNNL_DTYPE_FLOAT, at::kFloat) \
|
||||
_(CNNL_DTYPE_BFLOAT16, at::kBFloat16) \
|
||||
_(CNNL_DTYPE_HALF, at::kHalf) \
|
||||
_(CNNL_DTYPE_INT32, at::kInt) \
|
||||
_(CNNL_DTYPE_INT8, at::kChar) \
|
||||
_(CNNL_DTYPE_UINT8, at::kByte) \
|
||||
_(CNNL_DTYPE_BOOL, at::kBool) \
|
||||
_(CNNL_DTYPE_INT16, at::kShort) \
|
||||
_(CNNL_DTYPE_COMPLEX_HALF, at::kComplexHalf) \
|
||||
_(CNNL_DTYPE_COMPLEX_FLOAT, at::kComplexFloat)
|
||||
|
||||
cnnlDataType_t getCnnlDataType(const at::ScalarType &data_type) {
|
||||
switch (data_type) {
|
||||
#define DEFINE_CASE(cnnl_dtype, scalar_type) \
|
||||
case scalar_type: \
|
||||
return cnnl_dtype;
|
||||
CNNL_TYPE_AND_SCALAR_TYPE_WITHOUT_64BIT(DEFINE_CASE)
|
||||
#undef DEFINE_CASE
|
||||
case at::kLong:
|
||||
return CNNL_DTYPE_INT32;
|
||||
case at::kDouble:
|
||||
return CNNL_DTYPE_FLOAT;
|
||||
case at::kComplexDouble:
|
||||
return CNNL_DTYPE_COMPLEX_FLOAT;
|
||||
default:
|
||||
std::string msg("getCnnlDataType() not supported for ");
|
||||
throw std::runtime_error(msg + c10::toString(data_type));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<TensorDesc> createTensorDescs(const std::initializer_list<at::Tensor> &tensors) {
|
||||
std::vector<TensorDesc> descs;
|
||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||
descs.emplace_back(TensorDesc{nullptr, cnnlDestroyTensorDescriptor});
|
||||
auto tensor = tensors.begin()[i];
|
||||
if (!tensor.defined()) {
|
||||
continue;
|
||||
}
|
||||
cnnlTensorDescriptor_t desc;
|
||||
CNNL_CHECK_FATAL(cnnlCreateTensorDescriptor(&desc));
|
||||
descs[i].reset(desc);
|
||||
cnnlDataType_t data_type = getCnnlDataType(tensor.scalar_type());
|
||||
if (tensor.strides().size() == 0) {
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptor_v2(descs[i].get(), CNNL_LAYOUT_ARRAY, data_type,
|
||||
tensor.sizes().size(), tensor.sizes().data()));
|
||||
} else {
|
||||
CNNL_CHECK_FATAL(cnnlSetTensorDescriptorEx_v2(descs[i].get(), CNNL_LAYOUT_ARRAY, data_type,
|
||||
tensor.sizes().size(), tensor.sizes().data(),
|
||||
tensor.strides().data()));
|
||||
}
|
||||
}
|
||||
return descs;
|
||||
}
|
||||
|
||||
} // namespace torch_api
|
||||
} // namespace tmo
|
||||
Reference in New Issue
Block a user