/************************************************************************* * Copyright (C) [2023-2024] by Cambricon, Inc. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ #ifndef TEST_KERNELS_PYTEST_UTILS_H_ #define TEST_KERNELS_PYTEST_UTILS_H_ #include #include #include #include "aten/cnnl/cnnlHandle.h" #include "common/utils.h" #include "framework/core/MLUStream.h" #include "framework/core/caching_allocator.h" #include "framework/core/device.h" #include "framework/core/mlu_guard.h" #include "kernels/kernel_utils.h" namespace tmo { static cnnlDataType_t torch2cnnlDataType(torch::Dtype dtype) { switch (dtype) { case torch::kFloat32: return CNNL_DTYPE_FLOAT; case torch::kFloat16: return CNNL_DTYPE_HALF; case torch::kFloat64: return CNNL_DTYPE_DOUBLE; case torch::kInt8: return CNNL_DTYPE_INT8; case torch::kInt16: return CNNL_DTYPE_INT16; case torch::kInt32: return CNNL_DTYPE_INT32; case torch::kInt64: return CNNL_DTYPE_INT64; case torch::kUInt8: return CNNL_DTYPE_UINT8; case torch::kBool: return CNNL_DTYPE_BOOL; case torch::kBFloat16: return CNNL_DTYPE_BFLOAT16; default: throw std::runtime_error("Unsupported torch::Dtype"); } } static constexpr int dtype_size_map[] = { [CNNL_DTYPE_INVALID] = 0, [CNNL_DTYPE_HALF] = 2, [CNNL_DTYPE_FLOAT] = 4, [CNNL_DTYPE_INT8] = 1, [CNNL_DTYPE_INT16] = 2, [CNNL_DTYPE_INT31] = 4, [CNNL_DTYPE_INT32] = 4, [CNNL_DTYPE_UINT8] = 1, [CNNL_DTYPE_BOOL] = 1, [CNNL_DTYPE_INT64] = 8, [10] = 0, [CNNL_DTYPE_UINT32] = 4, [CNNL_DTYPE_UINT64] = 8, [CNNL_DTYPE_UINT16] = 2, [CNNL_DTYPE_DOUBLE] = 8, [CNNL_DTYPE_COMPLEX_HALF] = 4, [CNNL_DTYPE_COMPLEX_FLOAT] = 8, [CNNL_DTYPE_BFLOAT16] = 2, }; static torch::Dtype cnnl2torchDataType(cnnlDataType_t dtype) { switch (dtype) { case CNNL_DTYPE_FLOAT: return torch::kFloat32; case CNNL_DTYPE_HALF: return torch::kFloat16; case CNNL_DTYPE_DOUBLE: return torch::kFloat64; case CNNL_DTYPE_INT8: return torch::kInt8; case CNNL_DTYPE_INT16: return torch::kInt16; case CNNL_DTYPE_INT32: return torch::kInt32; case CNNL_DTYPE_INT64: return torch::kInt64; case CNNL_DTYPE_UINT8: return torch::kUInt8; case CNNL_DTYPE_BOOL: return torch::kBool; case CNNL_DTYPE_BFLOAT16: return torch::kBFloat16; default: throw std::runtime_error("Unsupported cnnlDataType_t"); } } static float getBandWidth() { int card = -1; CNRT_CHECK(cnrtGetDevice(&card)); if (cndevInit(0) != CNDEV_SUCCESS) { abort(); } cndevDDRInfo_t ddrinfo; ddrinfo.version = CNDEV_VERSION_5; if (cndevGetDDRInfo(&ddrinfo, card) != CNDEV_SUCCESS) { abort(); } double band_width = ddrinfo.bandWidth; double band_width_decimal = ddrinfo.bandWidthDecimal; do { band_width_decimal /= 10; } while (band_width_decimal > 1); return float(band_width + band_width_decimal); } static void print_info(float time_usec, size_t io_bytes) { float io_bandwidth = getBandWidth(); std::cout << "kernel time: " << time_usec << "us" << std::endl; std::cout << "io_bandwidth: " << io_bandwidth << "GB/s" << std::endl; std::cout << "IO efficiency: " << io_bytes / (time_usec * 1000 * io_bandwidth) << std::endl; } #define MLU_TENSOR_CHECK_FATAL(tensor) \ if (tensor.device().type() == c10::kCPU or tensor.device().type() == c10::kCUDA) { \ throw std::runtime_error("Check failed: " #tensor " is not a MLU tensor."); \ } } // namespace tmo #endif // TEST_KERNELS_PYTEST_UTILS_H_