Files
enginex-mlu370-vllm/torch_mlu_ops-v1.3.2/tests/kernels_pytest/src/utils.h
2026-02-04 17:39:32 +08:00

134 lines
4.2 KiB
C++

/*************************************************************************
* Copyright (C) [2023-2024] by Cambricon, Inc.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
* OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*************************************************************************/
#ifndef TEST_KERNELS_PYTEST_UTILS_H_
#define TEST_KERNELS_PYTEST_UTILS_H_
#include <cnrt.h>
#include <torch/extension.h>
#include <torch/torch.h>

#include <cstdlib>    // std::abort
#include <iostream>   // std::cout / std::cerr
#include <stdexcept>  // std::runtime_error

#include "aten/cnnl/cnnlHandle.h"
#include "common/utils.h"
#include "framework/core/MLUStream.h"
#include "framework/core/caching_allocator.h"
#include "framework/core/device.h"
#include "framework/core/mlu_guard.h"
#include "kernels/kernel_utils.h"
namespace tmo {
// Maps a torch::Dtype to the corresponding cnnlDataType_t.
// Throws std::runtime_error for dtypes that have no CNNL equivalent.
static cnnlDataType_t torch2cnnlDataType(torch::Dtype dtype) {
  switch (dtype) {
    // Floating-point types.
    case torch::kFloat16: return CNNL_DTYPE_HALF;
    case torch::kBFloat16: return CNNL_DTYPE_BFLOAT16;
    case torch::kFloat32: return CNNL_DTYPE_FLOAT;
    case torch::kFloat64: return CNNL_DTYPE_DOUBLE;
    // Signed integer types.
    case torch::kInt8: return CNNL_DTYPE_INT8;
    case torch::kInt16: return CNNL_DTYPE_INT16;
    case torch::kInt32: return CNNL_DTYPE_INT32;
    case torch::kInt64: return CNNL_DTYPE_INT64;
    // Unsigned / boolean types.
    case torch::kUInt8: return CNNL_DTYPE_UINT8;
    case torch::kBool: return CNNL_DTYPE_BOOL;
    default:
      throw std::runtime_error("Unsupported torch::Dtype");
  }
}
// Element size in bytes for each cnnlDataType_t, indexed by the enum's
// numeric value (e.g. dtype_size_map[CNNL_DTYPE_FLOAT] == 4).
// NOTE(review): array designated initializers ([idx] = val) are a C99/GNU
// extension, not standard C++ — this relies on GCC/Clang; confirm before
// porting to a stricter compiler.
static constexpr int dtype_size_map[] = {
    [CNNL_DTYPE_INVALID] = 0,
    [CNNL_DTYPE_HALF] = 2,
    [CNNL_DTYPE_FLOAT] = 4,
    [CNNL_DTYPE_INT8] = 1,
    [CNNL_DTYPE_INT16] = 2,
    // INT31 occupies 4 bytes here — presumably a 31-bit quantized format
    // stored in a 32-bit word; verify against the CNNL headers.
    [CNNL_DTYPE_INT31] = 4,
    [CNNL_DTYPE_INT32] = 4,
    [CNNL_DTYPE_UINT8] = 1,
    [CNNL_DTYPE_BOOL] = 1,
    [CNNL_DTYPE_INT64] = 8,
    // Enum value 10 appears to be unused/reserved in cnnlDataType_t, so it
    // is mapped to size 0 — TODO confirm against cnnl.h.
    [10] = 0,
    [CNNL_DTYPE_UINT32] = 4,
    [CNNL_DTYPE_UINT64] = 8,
    [CNNL_DTYPE_UINT16] = 2,
    [CNNL_DTYPE_DOUBLE] = 8,
    // Complex types: two components of half the listed size each.
    [CNNL_DTYPE_COMPLEX_HALF] = 4,
    [CNNL_DTYPE_COMPLEX_FLOAT] = 8,
    [CNNL_DTYPE_BFLOAT16] = 2,
};
// Maps a cnnlDataType_t back to the corresponding torch::Dtype
// (inverse of torch2cnnlDataType).
// Throws std::runtime_error for CNNL dtypes with no torch equivalent.
static torch::Dtype cnnl2torchDataType(cnnlDataType_t dtype) {
  switch (dtype) {
    // Floating-point types.
    case CNNL_DTYPE_HALF: return torch::kFloat16;
    case CNNL_DTYPE_BFLOAT16: return torch::kBFloat16;
    case CNNL_DTYPE_FLOAT: return torch::kFloat32;
    case CNNL_DTYPE_DOUBLE: return torch::kFloat64;
    // Signed integer types.
    case CNNL_DTYPE_INT8: return torch::kInt8;
    case CNNL_DTYPE_INT16: return torch::kInt16;
    case CNNL_DTYPE_INT32: return torch::kInt32;
    case CNNL_DTYPE_INT64: return torch::kInt64;
    // Unsigned / boolean types.
    case CNNL_DTYPE_UINT8: return torch::kUInt8;
    case CNNL_DTYPE_BOOL: return torch::kBool;
    default:
      throw std::runtime_error("Unsupported cnnlDataType_t");
  }
}
// Queries the current MLU device's peak DDR bandwidth in GB/s via CNDEV.
//
// The CNDEV struct reports the bandwidth split into an integral part
// (bandWidth) and a decimal-digits part (bandWidthDecimal, e.g. 25 for
// ".25"); the loop scales the decimal part into the [0, 1) range before
// recombining.
//
// Aborts the process on CNDEV failure (fail-fast, as before), but now
// prints a diagnostic to stderr first so the failure is attributable.
static float getBandWidth() {
  int card = -1;
  CNRT_CHECK(cnrtGetDevice(&card));
  if (cndevInit(0) != CNDEV_SUCCESS) {
    std::cerr << "getBandWidth: cndevInit failed" << std::endl;
    abort();
  }
  cndevDDRInfo_t ddrinfo{};  // zero-init so fields CNDEV leaves untouched are deterministic
  ddrinfo.version = CNDEV_VERSION_5;
  if (cndevGetDDRInfo(&ddrinfo, card) != CNDEV_SUCCESS) {
    std::cerr << "getBandWidth: cndevGetDDRInfo failed for card " << card
              << std::endl;
    abort();
  }
  double band_width = ddrinfo.bandWidth;
  double band_width_decimal = ddrinfo.bandWidthDecimal;
  // Convert the decimal digits into a fraction, e.g. 25 -> 0.25.
  do {
    band_width_decimal /= 10;
  } while (band_width_decimal > 1);
  return static_cast<float>(band_width + band_width_decimal);
}
static void print_info(float time_usec, size_t io_bytes) {
float io_bandwidth = getBandWidth();
std::cout << "kernel time: " << time_usec << "us" << std::endl;
std::cout << "io_bandwidth: " << io_bandwidth << "GB/s" << std::endl;
std::cout << "IO efficiency: " << io_bytes / (time_usec * 1000 * io_bandwidth) << std::endl;
}
// Throws std::runtime_error if `tensor` lives on CPU or CUDA instead of MLU.
// Wrapped in do { } while (0) so the macro expands to a single statement and
// is safe inside an unbraced if/else (the original bare `if` would capture a
// following `else`). The argument is parenthesized against operator-precedence
// surprises, and `or` is replaced by the conventional `||`.
#define MLU_TENSOR_CHECK_FATAL(tensor)                                           \
  do {                                                                           \
    if ((tensor).device().type() == c10::kCPU ||                                 \
        (tensor).device().type() == c10::kCUDA) {                                \
      throw std::runtime_error("Check failed: " #tensor " is not a MLU tensor."); \
    }                                                                            \
  } while (0)
} // namespace tmo
#endif // TEST_KERNELS_PYTEST_UTILS_H_