refactor: 统一硬件相关头文件引用

将分散在各文件中的CUDA/HIP/MUSA硬件相关头文件引用统一到vendors目录下的对应头文件中,提高代码可维护性。移除重复的头文件引用,优化构建配置。
This commit is contained in:
2026-01-20 10:14:31 +08:00
parent 5aef6c175a
commit 2bd9bd4cc2
98 changed files with 1757 additions and 1286 deletions

View File

@@ -6,12 +6,7 @@
#pragma once
#ifndef USE_ROCM
#include <cuda_bf16.h>
#else
#include <hip/hip_bf16.h>
#endif
#include <cuda_fp16.h>
#include "vendors/functions.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
struct SSMParamsBase {

View File

@@ -1,27 +1,9 @@
// clang-format off
// adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan_fwd_kernel.cuh
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "vendors/functions.h"
#include "selective_scan.h"
#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
#ifdef USE_ROCM
#include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
#else
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
#endif
#ifndef USE_ROCM
#include <cub/block/block_load.cuh>
#include <cub/block/block_store.cuh>
#include <cub/block/block_scan.cuh>
#else
#include <hipcub/hipcub.hpp>
namespace cub = hipcub;
#endif
#include "selective_scan.h"
#include "static_switch.h"