refactor: 统一硬件相关头文件引用
将分散在各文件中的CUDA/HIP/MUSA硬件相关头文件引用统一到vendors目录下的对应头文件中,提高代码可维护性。移除重复的头文件引用,优化构建配置。
This commit is contained in:
50
csrc/vendors/cuda.h
vendored
Normal file
50
csrc/vendors/cuda.h
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
#pragma once
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/device/device_radix_sort.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <cuda/std/limits>
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <cuda/annotated_ptr>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#if CUDART_VERSION >= 12050
|
||||
#include <cuda_fp8.h>
|
||||
#endif // CUDART_VERSION >= 12050
|
||||
|
||||
#if CUDART_VERSION < 11020
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
|
||||
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
||||
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
||||
#define cublasComputeType_t cudaDataType_t
|
||||
#endif // CUDART_VERSION < 11020
|
||||
|
||||
#if CUB_VERSION >= 200800
|
||||
#include <cuda/std/functional>
|
||||
using CubAddOp = cuda::std::plus<>;
|
||||
using CubMaxOp = cuda::maximum<>;
|
||||
#else // if CUB_VERSION < 200800
|
||||
using CubAddOp = cub::Sum;
|
||||
using CubMaxOp = cub::Max;
|
||||
#endif // CUB_VERSION
|
||||
9
csrc/vendors/functions.h
vendored
Normal file
9
csrc/vendors/functions.h
vendored
Normal file
@@ -0,0 +1,9 @@
|
||||
#ifdef USE_MUSA
|
||||
#include "musa.h"
|
||||
#elif USE_HIP
|
||||
#include "hip.h"
|
||||
#elif USE_CUDA
|
||||
#include "cuda.h"
|
||||
#else
|
||||
"No Support"
|
||||
#endif
|
||||
307
csrc/vendors/hip.h
vendored
Normal file
307
csrc/vendors/hip.h
vendored
Normal file
@@ -0,0 +1,307 @@
|
||||
#pragma once
|
||||
#include <torch/all.h>
|
||||
|
||||
#define HIP_DISABLE_WARP_SYNC_BUILTINS 1
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hipblas/hipblas.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
#include <c10/core/ScalarType.h>
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/hip/HIPException.h>
|
||||
|
||||
#include <hipcub/hipcub.hpp>
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
namespace cub = hipcub;
|
||||
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
typedef __hip_bfloat162 __nv_bfloat162;
|
||||
|
||||
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_OP_N HIPBLAS_OP_N
|
||||
#define CUBLAS_OP_T HIPBLAS_OP_T
|
||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH 0
|
||||
#define CUDA_R_16F HIPBLAS_R_16F
|
||||
#define CUDA_R_16BF HIPBLAS_R_16B
|
||||
#define CUDA_R_32F HIPBLAS_R_32F
|
||||
#define CUBLAS_SIDE_RIGHT HIPBLAS_SIDE_RIGHT
|
||||
#define CUBLAS_FILL_MODE_UPPER HIPBLAS_FILL_MODE_UPPER
|
||||
#define CUBLAS_DIAG_NON_UNIT HIPBLAS_DIAG_NON_UNIT
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED hipDeviceAttributeVirtualMemoryManagementSupported
|
||||
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED hipMemAllocationGranularityRecommended
|
||||
#define CU_MEM_ALLOCATION_TYPE_PINNED hipMemAllocationTypePinned
|
||||
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
|
||||
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
|
||||
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
|
||||
#define __shfl_up_sync(mask, var, laneMask, width) __shfl_up(var, laneMask, width)
|
||||
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
|
||||
#define __all_sync(mask, var) __all(var)
|
||||
#define __any_sync(mask, var) __any(var)
|
||||
#define cublasStrsmBatched hipblasStrsmBatched
|
||||
#define cublasCreate hipblasCreate
|
||||
#define cublasDestroy hipblasDestroy
|
||||
#define cublasGemmEx hipblasGemmEx
|
||||
#define cublasGemmBatchedEx hipblasGemmBatchedEx
|
||||
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
|
||||
#define cublasHandle_t hipblasHandle_t
|
||||
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
|
||||
#define cublasSetStream hipblasSetStream
|
||||
#define cublasSgemm hipblasSgemm
|
||||
#define cublasStatus_t hipblasStatus_t
|
||||
#define cublasOperation_t hipblasOperation_t
|
||||
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
||||
#define cudaEventCreateWithFlags hipEventCreateWithFlags
|
||||
#define cudaEventDisableTiming hipEventDisableTiming
|
||||
#define cudaEventRecord hipEventRecord
|
||||
#define cudaEventSynchronize hipEventSynchronize
|
||||
#define cudaEvent_t hipEvent_t
|
||||
#define cudaEventDestroy hipEventDestroy
|
||||
#define cudaFree hipFree
|
||||
#define cudaFreeHost hipHostFree
|
||||
#define cudaGetDevice hipGetDevice
|
||||
#define cudaGetDeviceCount hipGetDeviceCount
|
||||
#define cudaGetDeviceProperties hipGetDeviceProperties
|
||||
#define cudaGetErrorString hipGetErrorString
|
||||
#define cudaGetLastError hipGetLastError
|
||||
#define cudaHostRegister hipHostRegister
|
||||
#define cudaHostRegisterPortable hipHostRegisterPortable
|
||||
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
|
||||
#define cudaHostUnregister hipHostUnregister
|
||||
#define cudaLaunchHostFunc hipLaunchHostFunc
|
||||
#define cudaMalloc hipMalloc
|
||||
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
|
||||
#define cudaMallocManaged hipMallocManaged
|
||||
#define cudaMemAdvise hipMemAdvise
|
||||
#define cudaMemcpy hipMemcpy
|
||||
#define cudaMemcpyAsync hipMemcpyAsync
|
||||
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
|
||||
#define cudaMemcpy2DAsync hipMemcpy2DAsync
|
||||
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
|
||||
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
|
||||
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
|
||||
#define cudaMemcpyKind hipMemcpyKind
|
||||
#define cudaMemset hipMemset
|
||||
#define cudaMemsetAsync hipMemsetAsync
|
||||
#define cudaMemGetInfo hipMemGetInfo
|
||||
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
|
||||
#define cudaSetDevice hipSetDevice
|
||||
#define cuDeviceGet hipDeviceGet
|
||||
#define CUdevice hipDevice_t
|
||||
#define CUdeviceptr hipDeviceptr_t
|
||||
#define cuMemUnmap hipMemUnmap
|
||||
#define CUmemAccessDesc hipMemAccessDesc
|
||||
#define cuMemAddressFree hipMemAddressFree
|
||||
#define cuMemRelease hipMemRelease
|
||||
#define CUmemGenericAllocationHandle hipMemGenericAllocationHandle_t
|
||||
#define cuMemCreate hipMemCreate
|
||||
#define cuMemAddressReserve hipMemAddressReserve
|
||||
#define cuMemMap hipMemMap
|
||||
#define cuMemSetAccess hipMemSetAccess
|
||||
#define cuMemGetAllocationGranularity hipMemGetAllocationGranularity
|
||||
#define CUmemAllocationProp hipMemAllocationProp
|
||||
#define cuDeviceGetAttribute hipDeviceGetAttribute
|
||||
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
|
||||
#define cudaStreamDestroy hipStreamDestroy
|
||||
#define cudaStreamFireAndForget hipStreamFireAndForget
|
||||
#define cudaStreamNonBlocking hipStreamNonBlocking
|
||||
#define cudaStreamPerThread hipStreamPerThread
|
||||
#define cudaStreamSynchronize hipStreamSynchronize
|
||||
#define cudaStreamWaitEvent hipStreamWaitEvent
|
||||
#define cudaGraphExec_t hipGraphExec_t
|
||||
#define cudaGraphNode_t hipGraphNode_t
|
||||
#define cudaKernelNodeParams hipKernelNodeParams
|
||||
#define cudaKernelNodeParams hipKernelNodeParams
|
||||
#define cudaGraphExecDestroy hipGraphExecDestroy
|
||||
#define cudaGraphLaunch hipGraphLaunch
|
||||
#define cudaErrorGraphExecUpdateFailure hipErrorGraphExecUpdateFailure
|
||||
#define cudaGraphExecUpdateResult hipGraphExecUpdateResult
|
||||
#define cudaGraphNodeType hipGraphNodeType
|
||||
#define cudaGraphNodeTypeKernel hipGraphNodeTypeKernel
|
||||
#define cudaGraphInstantiate hipGraphInstantiate
|
||||
#define cudaStreamEndCapture hipStreamEndCapture
|
||||
#define cudaGraphDestroy hipGraphDestroy
|
||||
#define cudaGraphKernelNodeSetParams hipGraphKernelNodeSetParams
|
||||
#define cudaErrorInvalidDeviceFunction hipErrorInvalidDeviceFunction
|
||||
#define cudaGraphKernelNodeGetParams hipGraphKernelNodeGetParams
|
||||
#define cudaGraphNodeGetType hipGraphNodeGetType
|
||||
#define cudaGraphGetNodes hipGraphGetNodes
|
||||
#define cudaGraphExecUpdate hipGraphExecUpdate
|
||||
#define cudaStreamCaptureModeRelaxed hipStreamCaptureModeRelaxed
|
||||
#define cudaStreamBeginCapture hipStreamBeginCapture
|
||||
#define cudaGraph_t hipGraph_t
|
||||
#define cudaStream_t hipStream_t
|
||||
#define cudaSuccess hipSuccess
|
||||
#define cudaFuncSetAttribute hipFuncSetAttribute
|
||||
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor hipOccupancyMaxActiveBlocksPerMultiprocessor
|
||||
#define __trap() do { abort(); __builtin_unreachable(); } while(0)
|
||||
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
|
||||
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
|
||||
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
|
||||
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
|
||||
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
|
||||
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
|
||||
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
|
||||
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
|
||||
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
|
||||
|
||||
|
||||
#define __ldg(arg) *(arg)
|
||||
|
||||
|
||||
#if HIP_VERSION >= 60500000
|
||||
#define CUBLAS_COMPUTE_16F HIPBLAS_COMPUTE_16F
|
||||
#define CUBLAS_COMPUTE_32F HIPBLAS_COMPUTE_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_COMPUTE_32F_FAST_16F
|
||||
#define cublasComputeType_t hipblasComputeType_t
|
||||
#define cudaDataType_t hipDataType
|
||||
#else
|
||||
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
|
||||
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
|
||||
#define cublasComputeType_t hipblasDatatype_t
|
||||
#define cudaDataType_t hipblasDatatype_t
|
||||
#endif // HIP_VERSION >= 6050000
|
||||
|
||||
#if !defined(__HIP_PLATFORM_AMD__)
|
||||
#error "The HIP backend supports only AMD targets"
|
||||
#endif // !defined(__HIP_PLATFORM_AMD__)
|
||||
|
||||
#define __CUDA_ARCH__ 1300
|
||||
|
||||
#if defined(__gfx900__) || defined(__gfx906__)
|
||||
#define GCN5
|
||||
#endif // defined(__gfx900__) || defined(__gfx906__)
|
||||
|
||||
#if defined(__gfx803__)
|
||||
#define GCN4
|
||||
#endif // defined(__gfx803__)
|
||||
|
||||
#if defined(GCN5) || defined(GCN4)
|
||||
#define GCN
|
||||
#endif // defined(GCN5) || defined(GCN4)
|
||||
|
||||
#if defined(__gfx942__)
|
||||
#define CDNA3
|
||||
#endif // defined(__gfx942__)
|
||||
|
||||
#if defined(__gfx90a__)
|
||||
#define CDNA2
|
||||
#endif // defined(__gfx90a__)
|
||||
|
||||
#if defined(__gfx908__)
|
||||
#define CDNA1
|
||||
#endif // defined(__gfx908__)
|
||||
|
||||
#if defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
|
||||
#define CDNA // For the entire family
|
||||
#endif // defined(CDNA3) || defined(CDNA2) || defined(CDNA1)
|
||||
|
||||
#if defined(__GFX12__)
|
||||
#define RDNA4
|
||||
#endif // defined(__GFX12__)
|
||||
|
||||
#if defined(__GFX11__)
|
||||
#define RDNA3
|
||||
#endif // defined(__GFX11__)
|
||||
|
||||
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
|
||||
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
|
||||
#define RDNA2
|
||||
#endif
|
||||
|
||||
#if defined(__gfx1010__) || defined(__gfx1012__)
|
||||
#define RDNA1
|
||||
#endif // defined(__gfx1010__) || defined(__gfx1012__)
|
||||
|
||||
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
|
||||
#define RDNA // For the entire family
|
||||
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(RDNA1)
|
||||
|
||||
#ifndef __has_builtin
|
||||
#define __has_builtin(x) 0
|
||||
#endif
|
||||
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
typedef __hip_bfloat162 nv_bfloat162;
|
||||
|
||||
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
|
||||
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
|
||||
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
|
||||
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
|
||||
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
|
||||
#if __has_builtin(__builtin_elementwise_sub_sat)
|
||||
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
|
||||
return reinterpret_cast<const int &>(c);
|
||||
#else
|
||||
int8x4_t c;
|
||||
int16_t tmp;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tmp = va[i] - vb[i];
|
||||
if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
|
||||
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
|
||||
c[i] = tmp;
|
||||
}
|
||||
return reinterpret_cast<int &>(c);
|
||||
#endif // __has_builtin(__builtin_elementwise_sub_sat)
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ int __vsub4(const int a, const int b) {
|
||||
return __vsubss4(a, b);
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
|
||||
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
||||
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
||||
unsigned int c;
|
||||
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
static __device__ __forceinline__ unsigned int __vcmpne4(unsigned int a, unsigned int b) {
|
||||
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
|
||||
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
|
||||
unsigned int c;
|
||||
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
vc[i] = va[i] == vb[i] ? 0x00 : 0xff;
|
||||
}
|
||||
return c;
|
||||
}
|
||||
|
||||
|
||||
typedef __hip_bfloat162 __nv_bfloat162;
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
typedef __hip_bfloat16_raw __nv_bfloat16_raw;
|
||||
#if defined(HIP_FP8_TYPE_OCP)
|
||||
typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
|
||||
typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3;
|
||||
#else
|
||||
// ROCm 6.2 fallback: only *_fnuz types exist
|
||||
typedef __hip_fp8_e4m3_fnuz __nv_fp8_e4m3;
|
||||
typedef __hip_fp8x4_e4m3_fnuz __nv_fp8x4_e4m3;
|
||||
#include <hipcub/hipcub.hpp>
|
||||
namespace cub = hipcub;
|
||||
using CubAddOp = hipcub::Sum;
|
||||
using CubMaxOp = hipcub::Max;
|
||||
181
csrc/vendors/musa.h
vendored
Normal file
181
csrc/vendors/musa.h
vendored
Normal file
@@ -0,0 +1,181 @@
|
||||
// All header files
|
||||
|
||||
#pragma once
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <musa_runtime.h>
|
||||
#include <musa.h>
|
||||
#include <mublas.h>
|
||||
#include <musa_bf16.h>
|
||||
#include <musa_fp16.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/musa/MUSAStream.h>
|
||||
#include <musa/std/limits>
|
||||
#include <cooperative_groups.h>
|
||||
#include <cooperative_groups/reduce.h>
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <cudaTypedefs>
|
||||
#include <c10/musa/MUSAException.h>
|
||||
#include <c10/musa/MUSAGuard.h>
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/device/device_radix_sort.cuh>
|
||||
#include <cub/util_type.cuh>
|
||||
#include <ATen/musa/MUSAContext.h>
|
||||
|
||||
using CubAddOp = cub::Sum;
|
||||
using CubMaxOp = cub::Max;
|
||||
|
||||
#define CUBLAS_COMPUTE_16F CUDA_R_16F
|
||||
#define CUBLAS_COMPUTE_32F CUDA_R_32F
|
||||
#define CUBLAS_COMPUTE_32F_FAST_16F MUBLAS_COMPUTE_32F_FAST_16F
|
||||
#define CUBLAS_GEMM_DEFAULT MUBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT
|
||||
#define CUBLAS_OP_N MUBLAS_OP_N
|
||||
#define CUBLAS_OP_T MUBLAS_OP_T
|
||||
#define CUBLAS_DEFAULT_MATH MUBLAS_DEFAULT_MATH
|
||||
#define CUBLAS_SIDE_RIGHT MUBLAS_SIDE_RIGHT
|
||||
#define CUBLAS_FILL_MODE_UPPER MUBLAS_FILL_MODE_UPPER
|
||||
#define CUBLAS_DIAG_NON_UNIT MUBLAS_DIAG_NON_UNIT
|
||||
#define CUBLAS_STATUS_SUCCESS MUBLAS_STATUS_SUCCESS
|
||||
#define CUBLAS_TF32_TENSOR_OP_MATH MUBLAS_TENSOR_OP_MATH
|
||||
#define CUDA_R_16F MUSA_R_16F
|
||||
#define CUDA_R_16BF MUSA_R_16BF
|
||||
#define CUDA_R_32F MUSA_R_32F
|
||||
#define cublasStrsmBatched mublasStrsmBatched
|
||||
#define cublasComputeType_t cudaDataType_t
|
||||
#define cublasCreate mublasCreate
|
||||
#define cublasDestroy mublasDestroy
|
||||
#define cublasGemmEx mublasGemmEx
|
||||
#define cublasGemmBatchedEx mublasGemmBatchedEx
|
||||
#define cublasGemmStridedBatchedEx mublasGemmStridedBatchedEx
|
||||
#define cublasHandle_t mublasHandle_t
|
||||
#define cublasSetMathMode mublasSetMathMode
|
||||
#define cublasSetStream mublasSetStream
|
||||
#define cublasSgemm mublasSgemm
|
||||
#define cublasStatus_t mublasStatus_t
|
||||
#define cublasOperation_t mublasOperation_t
|
||||
#define cublasGetStatusString mublasGetStatusString
|
||||
#define cudaDataType_t musaDataType_t
|
||||
#define cudaDeviceCanAccessPeer musaDeviceCanAccessPeer
|
||||
#define cudaDeviceDisablePeerAccess musaDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess musaDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp musaDeviceProp
|
||||
#define cudaDeviceSynchronize musaDeviceSynchronize
|
||||
#define cudaError_t musaError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled musaErrorPeerAccessAlreadyEnabled
|
||||
#define cudaErrorPeerAccessNotEnabled musaErrorPeerAccessNotEnabled
|
||||
#define cudaEventCreateWithFlags musaEventCreateWithFlags
|
||||
#define cudaEventDisableTiming musaEventDisableTiming
|
||||
#define cudaEventRecord musaEventRecord
|
||||
#define cudaEventSynchronize musaEventSynchronize
|
||||
#define cudaEvent_t musaEvent_t
|
||||
#define cudaEventDestroy musaEventDestroy
|
||||
#define cudaFree musaFree
|
||||
#define cudaFreeHost musaFreeHost
|
||||
#define cudaGetDevice musaGetDevice
|
||||
#define cudaGetDeviceCount musaGetDeviceCount
|
||||
#define cudaGetDeviceProperties musaGetDeviceProperties
|
||||
#define cudaGetErrorString musaGetErrorString
|
||||
#define cudaGetLastError musaGetLastError
|
||||
#define cudaHostRegister musaHostRegister
|
||||
#define cudaHostRegisterPortable musaHostRegisterPortable
|
||||
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly
|
||||
#define cudaHostUnregister musaHostUnregister
|
||||
#define cudaLaunchHostFunc musaLaunchHostFunc
|
||||
#define cudaMalloc musaMalloc
|
||||
#define cudaMallocHost musaMallocHost
|
||||
#define cudaMallocManaged musaMallocManaged
|
||||
#define cudaMemcpy musaMemcpy
|
||||
#define cudaMemcpyAsync musaMemcpyAsync
|
||||
#define cudaMemcpyPeerAsync musaMemcpyPeerAsync
|
||||
#define cudaMemcpy2DAsync musaMemcpy2DAsync
|
||||
#define cudaMemcpyDeviceToDevice musaMemcpyDeviceToDevice
|
||||
#define cudaMemcpyDeviceToHost musaMemcpyDeviceToHost
|
||||
#define cudaMemcpyHostToDevice musaMemcpyHostToDevice
|
||||
#define cudaMemcpyKind musaMemcpyKind
|
||||
#define cudaMemset musaMemset
|
||||
#define cudaMemsetAsync musaMemsetAsync
|
||||
#define cudaMemGetInfo musaMemGetInfo
|
||||
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
|
||||
#define cudaSetDevice musaSetDevice
|
||||
#define cudaStreamCreateWithFlags musaStreamCreateWithFlags
|
||||
#define cudaStreamDestroy musaStreamDestroy
|
||||
#define cudaStreamFireAndForget musaStreamFireAndForget
|
||||
#define cudaStreamNonBlocking musaStreamNonBlocking
|
||||
#define cudaStreamPerThread musaStreamPerThread
|
||||
#define cudaStreamSynchronize musaStreamSynchronize
|
||||
#define cudaStreamWaitEvent musaStreamWaitEvent
|
||||
#define cudaStream_t musaStream_t
|
||||
#define cudaSuccess musaSuccess
|
||||
|
||||
// Additional mappings for MUSA virtual memory pool
|
||||
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
|
||||
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE MU_MEM_ACCESS_FLAGS_PROT_READWRITE
|
||||
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
|
||||
#define CU_MEM_ALLOCATION_TYPE_PINNED MU_MEM_ALLOCATION_TYPE_PINNED
|
||||
#define CU_MEM_LOCATION_TYPE_DEVICE MU_MEM_LOCATION_TYPE_DEVICE
|
||||
#define CUdevice MUdevice
|
||||
#define CUdeviceptr MUdeviceptr
|
||||
#define CUmemAccessDesc MUmemAccessDesc
|
||||
#define CUmemAllocationProp MUmemAllocationProp
|
||||
#define CUmemGenericAllocationHandle MUmemGenericAllocationHandle
|
||||
#define cuDeviceGet muDeviceGet
|
||||
#define cuDeviceGetAttribute muDeviceGetAttribute
|
||||
#define cuMemAddressFree muMemAddressFree
|
||||
#define cuMemAddressReserve muMemAddressReserve
|
||||
#define cuMemCreate muMemCreate
|
||||
#define cuMemGetAllocationGranularity muMemGetAllocationGranularity
|
||||
#define cuMemMap muMemMap
|
||||
#define cuMemRelease muMemRelease
|
||||
#define cuMemSetAccess muMemSetAccess
|
||||
#define cuMemUnmap muMemUnmap
|
||||
#define cudaFuncAttributeMaxDynamicSharedMemorySize musaFuncAttributeMaxDynamicSharedMemorySize
|
||||
#define cudaFuncSetAttribute musaFuncSetAttribute
|
||||
#define cudaMemcpy3DPeerParms musaMemcpy3DPeerParms
|
||||
#define make_cudaExtent make_musaExtent
|
||||
#define make_cudaPitchedPtr make_musaPitchedPtr
|
||||
|
||||
// Additional mappings for MUSA graphs
|
||||
#define CUDA_SUCCESS MUSA_SUCCESS
|
||||
#define CUresult MUresult
|
||||
#define cuGetErrorString muGetErrorString
|
||||
#define cudaErrorGraphExecUpdateFailure musaErrorGraphExecUpdateFailure
|
||||
#define cudaErrorInvalidDeviceFunction musaErrorInvalidDeviceFunction
|
||||
#define cudaGraphDestroy musaGraphDestroy
|
||||
#define cudaGraphExecDestroy musaGraphExecDestroy
|
||||
#define cudaGraphExec_t musaGraphExec_t
|
||||
#define cudaGraphExecUpdate musaGraphExecUpdate
|
||||
#define cudaGraphExecUpdateResult musaGraphExecUpdateResult
|
||||
#define cudaGraphGetNodes musaGraphGetNodes
|
||||
#define cudaGraphInstantiate musaGraphInstantiate
|
||||
#define cudaGraphKernelNodeGetParams musaGraphKernelNodeGetParams
|
||||
#define cudaGraphKernelNodeSetParams musaGraphKernelNodeSetParams
|
||||
#define cudaGraphLaunch musaGraphLaunch
|
||||
#define cudaGraphNodeGetType musaGraphNodeGetType
|
||||
#define cudaGraphNode_t musaGraphNode_t
|
||||
#define cudaGraphNodeType musaGraphNodeType
|
||||
#define cudaGraphNodeTypeKernel musaGraphNodeTypeKernel
|
||||
#define cudaGraph_t musaGraph_t
|
||||
#define cudaKernelNodeParams musaKernelNodeParams
|
||||
#define cudaStreamCaptureModeRelaxed musaStreamCaptureModeRelaxed
|
||||
#define cudaStreamBeginCapture musaStreamBeginCapture
|
||||
#define cudaStreamEndCapture musaStreamEndCapture
|
||||
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor musaOccupancyMaxActiveBlocksPerMultiprocessor
|
||||
#define __ldg(arg) *(arg)
|
||||
typedef __mt_bfloat16 nv_bfloat16;
|
||||
typedef __mt_bfloat16 __nv_bfloat16;
|
||||
typedef __mt_bfloat162 nv_bfloat162;
|
||||
typedef __mt_bfloat162 __nv_bfloat162;
|
||||
typedef __mt_bfloat162 __nv_bfloat162;
|
||||
typedef __mt_bfloat16 __nv_bfloat16;
|
||||
typedef __mt_bfloat16_raw __nv_bfloat16_raw;
|
||||
typedef __mt_fp8_e4m3 __nv_fp8_e4m3;
|
||||
typedef __mt_fp8x4_e4m3 __nv_fp8x4_e4m3;
|
||||
Reference in New Issue
Block a user