/*************************************************************************
 * Copyright (C) [2023-2024] by Cambricon, Inc.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
 * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
 * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
 * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
 * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
 *************************************************************************/

#ifndef TEST_KERNELS_PYTEST_UTILS_H_
#define TEST_KERNELS_PYTEST_UTILS_H_

#include <iostream>
#include <stdexcept>
#include <string>

#include <cnrt.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include "aten/cnnl/cnnlHandle.h"
#include "common/utils.h"
#include "framework/core/MLUStream.h"
#include "framework/core/caching_allocator.h"
#include "framework/core/device.h"
#include "framework/core/mlu_guard.h"
#include "kernels/kernel_utils.h"

namespace tmo {
|
||
|
|
// Translates a PyTorch scalar type into the corresponding CNNL data type.
//
// Supported: float32/16/64, bfloat16, int8/16/32/64, uint8 and bool.
// Throws std::runtime_error("Unsupported torch::Dtype") for any other dtype.
static cnnlDataType_t torch2cnnlDataType(torch::Dtype dtype) {
  // One row per supported conversion; scanned linearly (only ten entries).
  struct DtypePair {
    torch::Dtype from;
    cnnlDataType_t to;
  };
  static constexpr DtypePair kTorchToCnnl[] = {
      {torch::kFloat32, CNNL_DTYPE_FLOAT},
      {torch::kFloat16, CNNL_DTYPE_HALF},
      {torch::kFloat64, CNNL_DTYPE_DOUBLE},
      {torch::kInt8, CNNL_DTYPE_INT8},
      {torch::kInt16, CNNL_DTYPE_INT16},
      {torch::kInt32, CNNL_DTYPE_INT32},
      {torch::kInt64, CNNL_DTYPE_INT64},
      {torch::kUInt8, CNNL_DTYPE_UINT8},
      {torch::kBool, CNNL_DTYPE_BOOL},
      {torch::kBFloat16, CNNL_DTYPE_BFLOAT16},
  };

  for (const auto& pair : kTorchToCnnl) {
    if (pair.from == dtype) {
      return pair.to;
    }
  }
  throw std::runtime_error("Unsupported torch::Dtype");
}

static constexpr int dtype_size_map[] = {
|
||
|
|
[CNNL_DTYPE_INVALID] = 0,
|
||
|
|
[CNNL_DTYPE_HALF] = 2,
|
||
|
|
[CNNL_DTYPE_FLOAT] = 4,
|
||
|
|
[CNNL_DTYPE_INT8] = 1,
|
||
|
|
[CNNL_DTYPE_INT16] = 2,
|
||
|
|
[CNNL_DTYPE_INT31] = 4,
|
||
|
|
[CNNL_DTYPE_INT32] = 4,
|
||
|
|
[CNNL_DTYPE_UINT8] = 1,
|
||
|
|
[CNNL_DTYPE_BOOL] = 1,
|
||
|
|
[CNNL_DTYPE_INT64] = 8,
|
||
|
|
[10] = 0,
|
||
|
|
[CNNL_DTYPE_UINT32] = 4,
|
||
|
|
[CNNL_DTYPE_UINT64] = 8,
|
||
|
|
[CNNL_DTYPE_UINT16] = 2,
|
||
|
|
[CNNL_DTYPE_DOUBLE] = 8,
|
||
|
|
[CNNL_DTYPE_COMPLEX_HALF] = 4,
|
||
|
|
[CNNL_DTYPE_COMPLEX_FLOAT] = 8,
|
||
|
|
[CNNL_DTYPE_BFLOAT16] = 2,
|
||
|
|
};
|
||
|
|
|
||
|
|
// Translates a CNNL data type back into the corresponding PyTorch scalar type.
// This is the inverse of torch2cnnlDataType for the ten types both handle.
//
// Throws std::runtime_error("Unsupported cnnlDataType_t") for any other value,
// e.g. CNNL_DTYPE_INT31 or the unsigned 16/32/64-bit types.
static torch::Dtype cnnl2torchDataType(cnnlDataType_t dtype) {
  switch (dtype) {
    // Floating-point types.
    case CNNL_DTYPE_HALF:     return torch::kFloat16;
    case CNNL_DTYPE_BFLOAT16: return torch::kBFloat16;
    case CNNL_DTYPE_FLOAT:    return torch::kFloat32;
    case CNNL_DTYPE_DOUBLE:   return torch::kFloat64;
    // Signed integer types.
    case CNNL_DTYPE_INT8:     return torch::kInt8;
    case CNNL_DTYPE_INT16:    return torch::kInt16;
    case CNNL_DTYPE_INT32:    return torch::kInt32;
    case CNNL_DTYPE_INT64:    return torch::kInt64;
    // Unsigned byte and boolean.
    case CNNL_DTYPE_UINT8:    return torch::kUInt8;
    case CNNL_DTYPE_BOOL:     return torch::kBool;
    default:                  break;
  }
  throw std::runtime_error("Unsupported cnnlDataType_t");
}

static float getBandWidth() {
|
||
|
|
int card = -1;
|
||
|
|
CNRT_CHECK(cnrtGetDevice(&card));
|
||
|
|
if (cndevInit(0) != CNDEV_SUCCESS) {
|
||
|
|
abort();
|
||
|
|
}
|
||
|
|
cndevDDRInfo_t ddrinfo;
|
||
|
|
ddrinfo.version = CNDEV_VERSION_5;
|
||
|
|
if (cndevGetDDRInfo(&ddrinfo, card) != CNDEV_SUCCESS) {
|
||
|
|
abort();
|
||
|
|
}
|
||
|
|
double band_width = ddrinfo.bandWidth;
|
||
|
|
double band_width_decimal = ddrinfo.bandWidthDecimal;
|
||
|
|
do {
|
||
|
|
band_width_decimal /= 10;
|
||
|
|
} while (band_width_decimal > 1);
|
||
|
|
return float(band_width + band_width_decimal);
|
||
|
|
}
|
||
|
|
|
||
|
|
static void print_info(float time_usec, size_t io_bytes) {
|
||
|
|
float io_bandwidth = getBandWidth();
|
||
|
|
std::cout << "kernel time: " << time_usec << "us" << std::endl;
|
||
|
|
std::cout << "io_bandwidth: " << io_bandwidth << "GB/s" << std::endl;
|
||
|
|
std::cout << "IO efficiency: " << io_bytes / (time_usec * 1000 * io_bandwidth) << std::endl;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Throws std::runtime_error if `tensor` is on a CPU or CUDA device. Any other
// device type passes the check.
//
// Wrapped in do { ... } while (0) so the macro expands to exactly one statement.
// That makes it safe inside an unbraced if/else and means it takes a trailing
// semicolon. The argument is parenthesised and evaluated only once.
#define MLU_TENSOR_CHECK_FATAL(tensor)                                              \
  do {                                                                              \
    const auto mlu_check_device_type_ = (tensor).device().type();                   \
    if (mlu_check_device_type_ == c10::kCPU || mlu_check_device_type_ == c10::kCUDA) { \
      throw std::runtime_error("Check failed: " #tensor " is not a MLU tensor.");   \
    }                                                                               \
  } while (0)

} // namespace tmo
|
||
|
|
#endif // TEST_KERNELS_PYTEST_UTILS_H_
|