refactor: 统一硬件相关头文件引用

将分散在各文件中的CUDA/HIP/MUSA硬件相关头文件引用统一到vendors目录下的对应头文件中,提高代码可维护性。移除重复的头文件引用,优化构建配置。
This commit is contained in:
2026-01-20 10:14:31 +08:00
parent 5aef6c175a
commit 2bd9bd4cc2
98 changed files with 1757 additions and 1286 deletions

View File

@@ -1,6 +1,5 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "../vendors/functions.h"
#include <cmath>
#include "core/math.hpp"
@@ -9,29 +8,8 @@
#include "quantization/w8a8/fp8/common.cuh"
#include <c10/util/Float8_e4m3fn.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp8.h>
typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
typedef __hip_bfloat16_raw __nv_bfloat16_raw;
#if defined(HIP_FP8_TYPE_OCP)
typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
typedef __hip_fp8x4_e4m3 __nv_fp8x4_e4m3;
#else
// ROCm 6.2 fallback: only *_fnuz types exist
typedef __hip_fp8_e4m3_fnuz __nv_fp8_e4m3;
typedef __hip_fp8x4_e4m3_fnuz __nv_fp8x4_e4m3;
#endif
#endif
#include "core/registration.h"
namespace vllm {

View File

@@ -7,12 +7,11 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
}
*/
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "dequantize.cuh"
#include <cuda_fp16.h>
#include "../../vendors/functions.h"
namespace vllm {
namespace awq {

View File

@@ -1,9 +1,7 @@
// see csrc/quantization/w8a8/cutlass/moe/get_group_starts.cuh
#pragma once
#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include "../../vendors/functions.h"
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"

View File

@@ -14,9 +14,8 @@
#include "cutlass/util/mixed_dtype_utils.hpp"
// vllm includes
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "../../vendors/functions.h"
#include "cutlass_extensions/torch_utils.hpp"
#include "cutlass_extensions/common.hpp"

View File

@@ -3,9 +3,10 @@
// https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu
//
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "../../vendors/functions.h"
#include "cutlass_extensions/torch_utils.hpp"
#include "w4a8_utils.cuh"
@@ -26,7 +27,6 @@
#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
#include <cuda_runtime.h>
namespace vllm::cutlass_w4a8 {

View File

@@ -1,7 +1,10 @@
#include "w4a8_utils.cuh"
#include "../../vendors/functions.h"
#include <array>
#include <cuda_runtime.h>
#include <cstdio>
namespace vllm::cutlass_w4a8_utils {

View File

@@ -14,15 +14,10 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include "../../vendors/functions.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"

View File

@@ -19,9 +19,9 @@
#include <torch/all.h>
#include <cutlass/arch/arch.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "../../vendors/functions.h"
#include "cutlass_extensions/common.hpp"
#include "cute/tensor.hpp"

View File

@@ -14,15 +14,10 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include "../../vendors/functions.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "nvfp4_utils.cuh"

View File

@@ -14,7 +14,9 @@
* limitations under the License.
*/
#include <torch/all.h>
#include "../../vendors/functions.h"
#if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \
(defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120)

View File

@@ -14,15 +14,9 @@
* limitations under the License.
*/
#include <torch/all.h>
#include "../../vendors/functions.h"
#include <cuda_runtime_api.h>
#include <cuda_runtime.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda_fp8.h>
#include "dispatch_utils.h"
#include "cuda_utils.h"

View File

@@ -14,8 +14,10 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../vendors/functions.h"
#include "cutlass_extensions/common.hpp"
#if defined ENABLE_NVFP4_SM100 && ENABLE_NVFP4_SM100

View File

@@ -14,10 +14,7 @@
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../vendors/functions.h"
#include "cutlass_extensions/common.hpp"

View File

@@ -14,10 +14,8 @@
* limitations under the License.
*/
#include <torch/all.h>
#include "../../vendors/functions.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "cutlass_extensions/common.hpp"

View File

@@ -16,8 +16,8 @@
#pragma once
#include <cuda_runtime.h>
#include <cuda_fp8.h>
#include "../../vendors/functions.h"
#define ELTS_PER_THREAD 8

View File

@@ -1,6 +1,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../vendors/functions.h"
#include "../../dispatch_utils.h"
#include "layernorm_utils.cuh"

View File

@@ -1,8 +1,5 @@
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "../../vendors/functions.h"
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../cuda_compat.h"
#include "dispatch_utils.h"

View File

@@ -6,8 +6,8 @@ https://github.com/turboderp/exllama
#ifndef _matrix_view_cuh
#define _matrix_view_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "../../vendors/functions.h"
#include "qdq_util.cuh"

View File

@@ -6,11 +6,8 @@ https://github.com/qwopqwop200/GPTQ-for-LLaMa
#include <cstdint>
#include <cstdio>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "../../vendors/functions.h"
#include "compat.cuh"
#include "matrix_view.cuh"

View File

@@ -1,5 +1,5 @@
#include "allspark_utils.cuh"
#include <torch/all.h>
#include "../../vendors/functions.h"
#include "core/registration.h"
#include <cublas_v2.h>

View File

@@ -1,5 +1,5 @@
#include "allspark_utils.cuh"
#include <torch/all.h>
#include "../../vendors/functions.h"
#include "core/registration.h"
namespace allspark {

View File

@@ -1,11 +1,6 @@
#pragma once
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include "../../vendors/functions.h"
#include <iostream>
#include "../gptq_marlin/marlin_dtypes.cuh"
using marlin::MarlinScalarType2;

View File

@@ -1,12 +1,7 @@
#pragma once
#include <torch/all.h>
#include "../../vendors/functions.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <iostream>
#ifndef MARLIN_NAMESPACE_NAME

View File

@@ -11,15 +11,9 @@ Redistribution and use in source and binary forms, with or without modification,
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
***********/
#include <torch/all.h>
#include "../../../vendors/functions.h"
#include <stdint.h>
#include <cuda_runtime.h>
#include <mma.h>
#include <cuda/annotated_ptr>
#include <c10/cuda/CUDAException.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include "core/registration.h"
#include "dispatch_utils.h"

View File

@@ -1,8 +1,6 @@
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "../../vendors/functions.h"
// clang-format off
// The cutlass include order matters (annoyingly)

View File

@@ -17,8 +17,7 @@
#pragma once
#include "base.h"
#include <cudaTypedefs.h>
#include "../../../../vendors/functions.h"
namespace marlin_24 {
// On CUDA earlier than 12.5, the ordered_metadata version of this instruction

View File

@@ -4,8 +4,7 @@
*/
// Include both AMD and NVIDIA fp8 types to avoid circular import
#include <c10/util/Float8_e4m3fnuz.h>
#include <c10/util/Float8_e4m3fn.h>
#include "../vendors/functions.h"
namespace vllm {

View File

@@ -2,9 +2,9 @@
// clang-format will break include orders
// clang-format off
#include <torch/all.h>
#include "../../../../vendors/functions.h"
#include <ATen/cuda/CUDAContext.h>
#include "cutlass/cutlass.h"

View File

@@ -1,4 +1,4 @@
#include <torch/all.h>
#include "../../../../vendors/functions.h"
#include "cuda_utils.h"
#include "cutlass_extensions/common.hpp"

View File

@@ -1,6 +1,6 @@
#pragma once
#include <torch/all.h>
#include "../../../../vendors/functions.h"
namespace vllm {

View File

@@ -1,11 +1,10 @@
#include "core/registration.h"
#include <torch/all.h>
#include <cutlass/arch/arch.h>
// #include <cutlass/arch/arch.h>
#include "../../../../vendors/functions.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"

View File

@@ -1,8 +1,6 @@
#pragma once
#include <cuda.h>
#include <torch/all.h>
#include <c10/cuda/CUDAStream.h>
#include "../../../../vendors/functions.h"
#include "core/scalar_type.hpp"
#include "cutlass/bfloat16.h"

View File

@@ -1,5 +1,7 @@
#pragma once
#include "../../../../vendors/functions.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/collective/collective_builder.hpp"

View File

@@ -1,7 +1,5 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "../../../../vendors/functions.h"
#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"

View File

@@ -1,7 +1,5 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "../../../../vendors/functions.h"
#include "cutlass/cutlass.h"
#include "grouped_mm_c3x.cuh"

View File

@@ -1,7 +1,4 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "../../../../vendors/functions.h"
#include <iostream>

View File

@@ -2,7 +2,7 @@
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "../../../../vendors/functions.h"
// clang-format will break include orders
// clang-format off

View File

@@ -1,7 +1,7 @@
#include <cudaTypedefs.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include "../../../vendors/functions.h"
#include "cutlass_extensions/common.hpp"

View File

@@ -2,8 +2,9 @@
#include "dispatch_utils.h"
#include "cub_helpers.h"
#include "quantization/vectorization_utils.cuh"
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/Exceptions.h>
#include "../../../../vendors/functions.h"
namespace vllm {

View File

@@ -1,5 +1,4 @@
#include <ATen/cuda/CUDAContext.h>
#include "../../../vendors/functions.h"
#include "quantization/w8a8/per_token_group_quant_8bit.h"
#include <cmath>

View File

@@ -1,6 +1,4 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include "../../../vendors/functions.h"
#include "quantization/w8a8/per_token_group_quant_8bit.h"
void per_token_group_quant_int8(const torch::Tensor& input,

View File

@@ -1,6 +1,6 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/all.h>
#include <c10/cuda/CUDAGuard.h>
#include "../../../vendors/functions.h"
#include <cmath>

View File

@@ -1,6 +1,5 @@
#pragma once
#include <torch/all.h>
#include "../../vendors/functions.h"
// 8-bit per-token-group quantization helper used by both FP8 and INT8
void per_token_group_quant_8bit(const torch::Tensor& input,
torch::Tensor& output_q,