refactor: 统一硬件相关头文件引用
将分散在各文件中的CUDA/HIP/MUSA硬件相关头文件引用统一到vendors目录下的对应头文件中,提高代码可维护性。移除重复的头文件引用,优化构建配置。
This commit is contained in:
@@ -6,12 +6,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#endif
|
||||
#include <cuda_fp16.h>
|
||||
#include "vendors/functions.h"
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct SSMParamsBase {
|
||||
|
||||
@@ -1,27 +1,9 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "vendors/functions.h"
|
||||
|
||||
#include "selective_scan.h"
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#ifdef USE_ROCM
|
||||
#include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
|
||||
#else
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
#include <cub/block/block_scan.cuh>
|
||||
#else
|
||||
#include <hipcub/hipcub.hpp>
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "selective_scan.h"
|
||||
#include "static_switch.h"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user