将分散在各文件中的CUDA/HIP/MUSA硬件相关头文件引用统一到vendors目录下的对应头文件中,提高代码可维护性。移除重复的头文件引用,优化构建配置。
182 lines
7.5 KiB
C
182 lines
7.5 KiB
C
// All header files
|
|
|
|
#pragma once
|
|
#include <torch/all.h>
|
|
|
|
#include <musa_runtime.h>
|
|
#include <musa.h>
|
|
#include <mublas.h>
|
|
#include <musa_bf16.h>
|
|
#include <musa_fp16.h>
|
|
#include <c10/core/ScalarType.h>
|
|
#include <c10/musa/MUSAStream.h>
|
|
#include <musa/std/limits>
|
|
#include <cooperative_groups.h>
|
|
#include <cooperative_groups/reduce.h>
|
|
#include <c10/util/BFloat16.h>
|
|
#include <c10/util/Half.h>
|
|
#include <cudaTypedefs>
|
|
#include <c10/musa/MUSAException.h>
|
|
#include <c10/musa/MUSAGuard.h>
|
|
#include <c10/util/Float8_e4m3fnuz.h>
|
|
#include <c10/util/Float8_e4m3fn.h>
|
|
#include <ATen/cuda/CUDAContext.h>
|
|
#include <c10/util/Optional.h>
|
|
#include <cub/block/block_load.cuh>
|
|
#include <cub/block/block_store.cuh>
|
|
#include <cub/block/block_scan.cuh>
|
|
#include <cub/cub.cuh>
|
|
#include <cub/device/device_radix_sort.cuh>
|
|
#include <cub/util_type.cuh>
|
|
#include <ATen/musa/MUSAContext.h>
|
|
|
|
// Short aliases for the CUB reduction functors used with this backend.
using CubAddOp = cub::Sum;  // addition functor
using CubMaxOp = cub::Max;  // maximum functor
|
|
|
|
// ---------------------------------------------------------------------------
// cuBLAS -> muBLAS compatibility names.
//
// cublasComputeType_t is mapped onto cudaDataType_t further down, so the
// CUBLAS_COMPUTE_16F / CUBLAS_COMPUTE_32F compute types resolve to the
// CUDA_R_* data-type enumerators (and from there to MUSA_R_*).
// ---------------------------------------------------------------------------

// Enumerators / constants.
#define CUBLAS_COMPUTE_16F            CUDA_R_16F
#define CUBLAS_COMPUTE_32F            CUDA_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F   MUBLAS_COMPUTE_32F_FAST_16F
#define CUBLAS_GEMM_DEFAULT           MUBLAS_GEMM_DEFAULT
// Maps to the same target as CUBLAS_GEMM_DEFAULT.
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N                   MUBLAS_OP_N
#define CUBLAS_OP_T                   MUBLAS_OP_T
#define CUBLAS_DEFAULT_MATH           MUBLAS_DEFAULT_MATH
#define CUBLAS_SIDE_RIGHT             MUBLAS_SIDE_RIGHT
#define CUBLAS_FILL_MODE_UPPER        MUBLAS_FILL_MODE_UPPER
#define CUBLAS_DIAG_NON_UNIT          MUBLAS_DIAG_NON_UNIT
#define CUBLAS_STATUS_SUCCESS         MUBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH    MUBLAS_TENSOR_OP_MATH

// Data-type enumerators.
#define CUDA_R_16F                    MUSA_R_16F
#define CUDA_R_16BF                   MUSA_R_16BF
#define CUDA_R_32F                    MUSA_R_32F

// Functions and types.
#define cublasStrsmBatched            mublasStrsmBatched
#define cublasComputeType_t           cudaDataType_t
#define cublasCreate                  mublasCreate
#define cublasDestroy                 mublasDestroy
#define cublasGemmEx                  mublasGemmEx
#define cublasGemmBatchedEx           mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx    mublasGemmStridedBatchedEx
#define cublasHandle_t                mublasHandle_t
#define cublasSetMathMode             mublasSetMathMode
#define cublasSetStream               mublasSetStream
#define cublasSgemm                   mublasSgemm
#define cublasStatus_t                mublasStatus_t
#define cublasOperation_t             mublasOperation_t
#define cublasGetStatusString         mublasGetStatusString

// Data-type enum type (also the final target of cublasComputeType_t).
#define cudaDataType_t                musaDataType_t
|
|
// ---------------------------------------------------------------------------
// CUDA runtime API -> MUSA runtime API compatibility names.
// Every entry is a plain token rename from the cuda* spelling to musa*.
// ---------------------------------------------------------------------------

// Device management and peer access.
#define cudaDeviceCanAccessPeer            musaDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess        musaDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess         musaDeviceEnablePeerAccess
#define cudaDeviceProp                     musaDeviceProp
#define cudaDeviceSynchronize              musaDeviceSynchronize

// Error type and peer-access error codes.
#define cudaError_t                        musaError_t
#define cudaErrorPeerAccessAlreadyEnabled  musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled      musaErrorPeerAccessNotEnabled

// Events.
#define cudaEventCreateWithFlags           musaEventCreateWithFlags
#define cudaEventDisableTiming             musaEventDisableTiming
#define cudaEventRecord                    musaEventRecord
#define cudaEventSynchronize               musaEventSynchronize
#define cudaEvent_t                        musaEvent_t
#define cudaEventDestroy                   musaEventDestroy

// Deallocation, device queries and error reporting.
#define cudaFree                           musaFree
#define cudaFreeHost                       musaFreeHost
#define cudaGetDevice                      musaGetDevice
#define cudaGetDeviceCount                 musaGetDeviceCount
#define cudaGetDeviceProperties            musaGetDeviceProperties
#define cudaGetErrorString                 musaGetErrorString
#define cudaGetLastError                   musaGetLastError

// Host memory registration and host callbacks.
#define cudaHostRegister                   musaHostRegister
#define cudaHostRegisterPortable           musaHostRegisterPortable
#define cudaHostRegisterReadOnly           musaHostRegisterReadOnly
#define cudaHostUnregister                 musaHostUnregister
#define cudaLaunchHostFunc                 musaLaunchHostFunc

// Allocation, copies and memset.
#define cudaMalloc                         musaMalloc
#define cudaMallocHost                     musaMallocHost
#define cudaMallocManaged                  musaMallocManaged
#define cudaMemcpy                         musaMemcpy
#define cudaMemcpyAsync                    musaMemcpyAsync
#define cudaMemcpyPeerAsync                musaMemcpyPeerAsync
#define cudaMemcpy2DAsync                  musaMemcpy2DAsync
#define cudaMemcpyDeviceToDevice           musaMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost             musaMemcpyDeviceToHost
#define cudaMemcpyHostToDevice             musaMemcpyHostToDevice
#define cudaMemcpyKind                     musaMemcpyKind
#define cudaMemset                         musaMemset
#define cudaMemsetAsync                    musaMemsetAsync
#define cudaMemGetInfo                     musaMemGetInfo

// Occupancy and device selection.
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
#define cudaSetDevice                      musaSetDevice

// Streams.
#define cudaStreamCreateWithFlags          musaStreamCreateWithFlags
#define cudaStreamDestroy                  musaStreamDestroy
#define cudaStreamFireAndForget            musaStreamFireAndForget
#define cudaStreamNonBlocking              musaStreamNonBlocking
#define cudaStreamPerThread                musaStreamPerThread
#define cudaStreamSynchronize              musaStreamSynchronize
#define cudaStreamWaitEvent                musaStreamWaitEvent
#define cudaStream_t                       musaStream_t

// Success status code.
#define cudaSuccess                        musaSuccess
|
|
|
|
// ---------------------------------------------------------------------------
// Additional mappings for the MUSA virtual memory pool.
// CU*/cu* names are renamed to their MU*/mu* spellings.
// ---------------------------------------------------------------------------

// Constants. Note that the device-attribute target is spelled
// VIRTUAL_ADDRESS_MANAGEMENT rather than VIRTUAL_MEMORY_MANAGEMENT.
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED \
    MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE   MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED        MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE          MU_MEM_LOCATION_TYPE_DEVICE

// Types.
#define CUdevice                             MUdevice
#define CUdeviceptr                          MUdeviceptr
#define CUmemAccessDesc                      MUmemAccessDesc
#define CUmemAllocationProp                  MUmemAllocationProp
#define CUmemGenericAllocationHandle         MUmemGenericAllocationHandle

// Functions.
#define cuDeviceGet                          muDeviceGet
#define cuDeviceGetAttribute                 muDeviceGetAttribute
#define cuMemAddressFree                     muMemAddressFree
#define cuMemAddressReserve                  muMemAddressReserve
#define cuMemCreate                          muMemCreate
#define cuMemGetAllocationGranularity        muMemGetAllocationGranularity
#define cuMemMap                             muMemMap
#define cuMemRelease                         muMemRelease
#define cuMemSetAccess                       muMemSetAccess
#define cuMemUnmap                           muMemUnmap

// Further cuda* -> musa* renames listed in this section.
#define cudaFuncAttributeMaxDynamicSharedMemorySize \
    musaFuncAttributeMaxDynamicSharedMemorySize
#define cudaFuncSetAttribute                 musaFuncSetAttribute
#define cudaMemcpy3DPeerParms                musaMemcpy3DPeerParms
#define make_cudaExtent                      make_musaExtent
#define make_cudaPitchedPtr                  make_musaPitchedPtr
|
|
|
|
// ---------------------------------------------------------------------------
// Additional mappings for MUSA graphs.
// ---------------------------------------------------------------------------

// Result type, success code and error-string lookup (CU*/cu* -> MU*/mu*).
#define CUDA_SUCCESS                     MUSA_SUCCESS
#define CUresult                         MUresult
#define cuGetErrorString                 muGetErrorString

// Error codes.
#define cudaErrorGraphExecUpdateFailure  musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction   musaErrorInvalidDeviceFunction

// Graph objects, nodes and their operations.
#define cudaGraphDestroy                 musaGraphDestroy
#define cudaGraphExecDestroy             musaGraphExecDestroy
#define cudaGraphExec_t                  musaGraphExec_t
#define cudaGraphExecUpdate              musaGraphExecUpdate
#define cudaGraphExecUpdateResult        musaGraphExecUpdateResult
#define cudaGraphGetNodes                musaGraphGetNodes
#define cudaGraphInstantiate             musaGraphInstantiate
#define cudaGraphKernelNodeGetParams     musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams     musaGraphKernelNodeSetParams
#define cudaGraphLaunch                  musaGraphLaunch
#define cudaGraphNodeGetType             musaGraphNodeGetType
#define cudaGraphNode_t                  musaGraphNode_t
#define cudaGraphNodeType                musaGraphNodeType
#define cudaGraphNodeTypeKernel          musaGraphNodeTypeKernel
#define cudaGraph_t                      musaGraph_t
#define cudaKernelNodeParams             musaKernelNodeParams

// Stream capture.
#define cudaStreamCaptureModeRelaxed     musaStreamCaptureModeRelaxed
#define cudaStreamBeginCapture           musaStreamBeginCapture
#define cudaStreamEndCapture             musaStreamEndCapture

// Occupancy query.
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor \
    musaOccupancyMaxActiveBlocksPerMultiprocessor
|
|
// Stand-in for CUDA's __ldg(): expands to a plain dereference of `arg`.
// The whole expansion is parenthesized so that postfix operators applied to
// the result (e.g. __ldg(p).x or __ldg(pp)[i]) bind to the dereferenced value
// instead of binding to `arg` before the dereference happens.
#define __ldg(arg) (*(arg))
|
|
typedef __mt_bfloat16 nv_bfloat16;
|
|
typedef __mt_bfloat16 __nv_bfloat16;
|
|
typedef __mt_bfloat162 nv_bfloat162;
|
|
typedef __mt_bfloat162 __nv_bfloat162;
|
|
typedef __mt_bfloat162 __nv_bfloat162;
|
|
typedef __mt_bfloat16 __nv_bfloat16;
|
|
typedef __mt_bfloat16_raw __nv_bfloat16_raw;
|
|
typedef __mt_fp8_e4m3 __nv_fp8_e4m3;
|
|
typedef __mt_fp8x4_e4m3 __nv_fp8x4_e4m3;
|