Files
enginex-mthreads-vllm/csrc/vendors/musa.h
xiezhongtao 2bd9bd4cc2 refactor: 统一硬件相关头文件引用
将分散在各文件中的CUDA/HIP/MUSA硬件相关头文件引用统一到vendors目录下的对应头文件中,提高代码可维护性。移除重复的头文件引用,优化构建配置。
2026-01-20 10:14:31 +08:00

182 lines
7.5 KiB
C

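// MUSA vendor header: gathers the hardware-specific includes in one place and
// maps CUDA / cuBLAS identifiers onto their MUSA / muBLAS counterparts.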
#pragma once
#include <torch/all.h>
#include <musa_runtime.h>
#include <musa.h>
#include <mublas.h>
#include <musa_bf16.h>
#include <musa_fp16.h>
#include <c10/core/ScalarType.h>
#include <c10/musa/MUSAStream.h>
#include <musa/std/limits>
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#include <cudaTypedefs>
#include <c10/musa/MUSAException.h>
#include <c10/musa/MUSAGuard.h>
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/util/Optional.h>
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/util_type.cuh>
#include <ATen/musa/MUSAContext.h>
// Aliases for CUB's Sum and Max operator types.
typedef cub::Sum CubAddOp;
typedef cub::Max CubMaxOp;
// ---------------------------------------------------------------------------
// cuBLAS -> muBLAS name mappings.
// All entries are object-like macros; they expand lazily, so their relative
// order within this table does not matter.
// ---------------------------------------------------------------------------

// Data-type tags.
#define CUDA_R_16F                    MUSA_R_16F
#define CUDA_R_16BF                   MUSA_R_16BF
#define CUDA_R_32F                    MUSA_R_32F

// Compute types. cublasComputeType_t expands to cudaDataType_t here, and the
// 16F / 32F compute selectors expand to the data-type tags above.
#define cublasComputeType_t           cudaDataType_t
#define CUBLAS_COMPUTE_16F            CUDA_R_16F
#define CUBLAS_COMPUTE_32F            CUDA_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F   MUBLAS_COMPUTE_32F_FAST_16F

// GEMM algorithm selectors (both expand to MUBLAS_GEMM_DEFAULT).
#define CUBLAS_GEMM_DEFAULT           MUBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP MUBLAS_GEMM_DEFAULT

// Math modes.
#define CUBLAS_DEFAULT_MATH           MUBLAS_DEFAULT_MATH
#define CUBLAS_TF32_TENSOR_OP_MATH    MUBLAS_TENSOR_OP_MATH

// Operation, side, fill-mode, diagonal and status constants.
#define CUBLAS_OP_N                   MUBLAS_OP_N
#define CUBLAS_OP_T                   MUBLAS_OP_T
#define CUBLAS_SIDE_RIGHT             MUBLAS_SIDE_RIGHT
#define CUBLAS_FILL_MODE_UPPER        MUBLAS_FILL_MODE_UPPER
#define CUBLAS_DIAG_NON_UNIT          MUBLAS_DIAG_NON_UNIT
#define CUBLAS_STATUS_SUCCESS         MUBLAS_STATUS_SUCCESS

// Types.
#define cublasHandle_t                mublasHandle_t
#define cublasStatus_t                mublasStatus_t
#define cublasOperation_t             mublasOperation_t

// Handle management and configuration.
#define cublasCreate                  mublasCreate
#define cublasDestroy                 mublasDestroy
#define cublasSetStream               mublasSetStream
#define cublasSetMathMode             mublasSetMathMode
#define cublasGetStatusString         mublasGetStatusString

// Compute routines.
#define cublasSgemm                   mublasSgemm
#define cublasGemmEx                  mublasGemmEx
#define cublasGemmBatchedEx           mublasGemmBatchedEx
#define cublasGemmStridedBatchedEx    mublasGemmStridedBatchedEx
#define cublasStrsmBatched            mublasStrsmBatched
// ---------------------------------------------------------------------------
// CUDA runtime -> MUSA runtime name mappings.
// Every entry is a plain rename (cudaXxx -> musaXxx), grouped by topic.
// ---------------------------------------------------------------------------

// Data-type enum name.
#define cudaDataType_t                     musaDataType_t

// Error type, error codes and error queries.
#define cudaError_t                        musaError_t
#define cudaSuccess                        musaSuccess
#define cudaErrorPeerAccessAlreadyEnabled  musaErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled      musaErrorPeerAccessNotEnabled
#define cudaGetErrorString                 musaGetErrorString
#define cudaGetLastError                   musaGetLastError

// Device selection, properties and peer access.
#define cudaDeviceProp                     musaDeviceProp
#define cudaGetDevice                      musaGetDevice
#define cudaGetDeviceCount                 musaGetDeviceCount
#define cudaGetDeviceProperties            musaGetDeviceProperties
#define cudaSetDevice                      musaSetDevice
#define cudaDeviceSynchronize              musaDeviceSynchronize
#define cudaDeviceCanAccessPeer            musaDeviceCanAccessPeer
#define cudaDeviceEnablePeerAccess         musaDeviceEnablePeerAccess
#define cudaDeviceDisablePeerAccess        musaDeviceDisablePeerAccess

// Events.
#define cudaEvent_t                        musaEvent_t
#define cudaEventCreateWithFlags           musaEventCreateWithFlags
#define cudaEventDisableTiming             musaEventDisableTiming
#define cudaEventRecord                    musaEventRecord
#define cudaEventSynchronize               musaEventSynchronize
#define cudaEventDestroy                   musaEventDestroy

// Streams and host callbacks.
#define cudaStream_t                       musaStream_t
#define cudaStreamCreateWithFlags          musaStreamCreateWithFlags
#define cudaStreamDestroy                  musaStreamDestroy
#define cudaStreamFireAndForget            musaStreamFireAndForget
#define cudaStreamNonBlocking              musaStreamNonBlocking
#define cudaStreamPerThread                musaStreamPerThread
#define cudaStreamSynchronize              musaStreamSynchronize
#define cudaStreamWaitEvent                musaStreamWaitEvent
#define cudaLaunchHostFunc                 musaLaunchHostFunc

// Allocation, release and memory info.
#define cudaMalloc                         musaMalloc
#define cudaMallocHost                     musaMallocHost
#define cudaMallocManaged                  musaMallocManaged
#define cudaFree                           musaFree
#define cudaFreeHost                       musaFreeHost
#define cudaMemGetInfo                     musaMemGetInfo

// Host memory registration.
#define cudaHostRegister                   musaHostRegister
#define cudaHostRegisterPortable           musaHostRegisterPortable
#define cudaHostRegisterReadOnly           musaHostRegisterReadOnly
#define cudaHostUnregister                 musaHostUnregister

// Copies and fills.
#define cudaMemcpy                         musaMemcpy
#define cudaMemcpyAsync                    musaMemcpyAsync
#define cudaMemcpyPeerAsync                musaMemcpyPeerAsync
#define cudaMemcpy2DAsync                  musaMemcpy2DAsync
#define cudaMemcpyKind                     musaMemcpyKind
#define cudaMemcpyDeviceToDevice           musaMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost             musaMemcpyDeviceToHost
#define cudaMemcpyHostToDevice             musaMemcpyHostToDevice
#define cudaMemset                         musaMemset
#define cudaMemsetAsync                    musaMemsetAsync

// Occupancy query.
#define cudaOccupancyMaxPotentialBlockSize musaOccupancyMaxPotentialBlockSize
// ---------------------------------------------------------------------------
// Additional mappings for the MUSA virtual memory pool
// (driver-API CU* / cu* names -> MU* / mu* names, plus a few runtime names).
// ---------------------------------------------------------------------------

// Device attribute. Note the spelling difference: the CUDA name says
// VIRTUAL_MEMORY_MANAGEMENT, the MUSA name says VIRTUAL_ADDRESS_MANAGEMENT.
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED \
    MU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED

// Memory-management constants.
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE   MU_MEM_ACCESS_FLAGS_PROT_READWRITE
#define CU_MEM_ALLOC_GRANULARITY_RECOMMENDED MU_MEM_ALLOC_GRANULARITY_RECOMMENDED
#define CU_MEM_ALLOCATION_TYPE_PINNED        MU_MEM_ALLOCATION_TYPE_PINNED
#define CU_MEM_LOCATION_TYPE_DEVICE          MU_MEM_LOCATION_TYPE_DEVICE

// Driver-API types.
#define CUdevice                             MUdevice
#define CUdeviceptr                          MUdeviceptr
#define CUmemAccessDesc                      MUmemAccessDesc
#define CUmemAllocationProp                  MUmemAllocationProp
#define CUmemGenericAllocationHandle         MUmemGenericAllocationHandle

// Driver-API device queries.
#define cuDeviceGet                          muDeviceGet
#define cuDeviceGetAttribute                 muDeviceGetAttribute

// Driver-API memory functions.
#define cuMemGetAllocationGranularity        muMemGetAllocationGranularity
#define cuMemAddressReserve                  muMemAddressReserve
#define cuMemAddressFree                     muMemAddressFree
#define cuMemCreate                          muMemCreate
#define cuMemRelease                         muMemRelease
#define cuMemMap                             muMemMap
#define cuMemUnmap                           muMemUnmap
#define cuMemSetAccess                       muMemSetAccess

// Runtime names: kernel function attributes.
#define cudaFuncSetAttribute                 musaFuncSetAttribute
#define cudaFuncAttributeMaxDynamicSharedMemorySize \
    musaFuncAttributeMaxDynamicSharedMemorySize

// Runtime names: 3D peer copy parameters and their constructors.
#define cudaMemcpy3DPeerParms                musaMemcpy3DPeerParms
#define make_cudaExtent                      make_musaExtent
#define make_cudaPitchedPtr                  make_musaPitchedPtr
// ---------------------------------------------------------------------------
// Additional mappings for MUSA graphs (plus stream capture and one
// occupancy query).
// ---------------------------------------------------------------------------

// Driver-API result type, success code and error string lookup.
#define CUresult                          MUresult
#define CUDA_SUCCESS                      MUSA_SUCCESS
#define cuGetErrorString                  muGetErrorString

// Runtime error codes.
#define cudaErrorGraphExecUpdateFailure   musaErrorGraphExecUpdateFailure
#define cudaErrorInvalidDeviceFunction    musaErrorInvalidDeviceFunction

// Graph types and enumerators.
#define cudaGraph_t                       musaGraph_t
#define cudaGraphExec_t                   musaGraphExec_t
#define cudaGraphNode_t                   musaGraphNode_t
#define cudaGraphNodeType                 musaGraphNodeType
#define cudaGraphNodeTypeKernel           musaGraphNodeTypeKernel
#define cudaGraphExecUpdateResult         musaGraphExecUpdateResult
#define cudaKernelNodeParams              musaKernelNodeParams

// Graph functions.
#define cudaGraphInstantiate              musaGraphInstantiate
#define cudaGraphLaunch                   musaGraphLaunch
#define cudaGraphExecUpdate               musaGraphExecUpdate
#define cudaGraphGetNodes                 musaGraphGetNodes
#define cudaGraphNodeGetType              musaGraphNodeGetType
#define cudaGraphKernelNodeGetParams      musaGraphKernelNodeGetParams
#define cudaGraphKernelNodeSetParams      musaGraphKernelNodeSetParams
#define cudaGraphDestroy                  musaGraphDestroy
#define cudaGraphExecDestroy              musaGraphExecDestroy

// Stream capture.
#define cudaStreamCaptureModeRelaxed      musaStreamCaptureModeRelaxed
#define cudaStreamBeginCapture            musaStreamBeginCapture
#define cudaStreamEndCapture              musaStreamEndCapture

// Occupancy query.
#define cudaOccupancyMaxActiveBlocksPerMultiprocessor \
    musaOccupancyMaxActiveBlocksPerMultiprocessor
// Stand-in for CUDA's __ldg(ptr): expands to a plain dereference of `arg`.
// NOTE(review): presumably provided because the MUSA toolchain has no
// equivalent read-only-cache load intrinsic -- confirm.
//
// The whole expansion is parenthesized so that postfix operators applied to
// the result bind to the loaded value rather than to the pointer: without the
// outer parentheses `__ldg(p)[i]` would parse as `*((p)[i])` and
// `__ldg(p).x` as `*((p).x)`.
#define __ldg(arg) (*(arg))
// Expose the MUSA (__mt_*) bfloat16 and fp8 types under the NVIDIA type names
// used by CUDA sources. Each alias is declared exactly once.

// Scalar bfloat16, with and without the leading double underscore.
typedef __mt_bfloat16 nv_bfloat16;
typedef __mt_bfloat16 __nv_bfloat16;

// Two-element bfloat16 vector type.
typedef __mt_bfloat162 nv_bfloat162;
typedef __mt_bfloat162 __nv_bfloat162;

// Raw bfloat16 representation type.
typedef __mt_bfloat16_raw __nv_bfloat16_raw;

// fp8 e4m3: single value and four-element vector type.
typedef __mt_fp8_e4m3 __nv_fp8_e4m3;
typedef __mt_fp8x4_e4m3 __nv_fp8x4_e4m3;