diff --git a/.github/workflows/pr-test-amd.yml b/.github/workflows/pr-test-amd.yml index 02f79f7cb..7835b1ec0 100644 --- a/.github/workflows/pr-test-amd.yml +++ b/.github/workflows/pr-test-amd.yml @@ -342,6 +342,7 @@ jobs: docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_moe_topk_softmax.py docker exec -w /sglang-checkout/sgl-kernel/tests/speculative ci_sglang python3 -m pytest test_eagle_utils.py docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_apply_token_bitmask_inplace.py + docker exec -w /sglang-checkout/sgl-kernel/tests ci_sglang python3 -m pytest test_activation.py pr-test-amd-finish: if: always() diff --git a/sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh b/sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh index 2e064d704..ba0bc33fd 100644 --- a/sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh +++ b/sgl-kernel/csrc/allreduce/mscclpp_allreduce.cuh @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. #pragma once -#if defined(__HIP_PLATFORM_AMD__) +#ifdef USE_ROCM #include #else #include diff --git a/sgl-kernel/csrc/elementwise/activation.cu b/sgl-kernel/csrc/elementwise/activation.cu index 20b889530..43617f87f 100644 --- a/sgl-kernel/csrc/elementwise/activation.cu +++ b/sgl-kernel/csrc/elementwise/activation.cu @@ -25,7 +25,7 @@ #include "utils.h" #else -#include "hip_act_and_mul.cuh" +#include "hip/hip_act_and_mul.cuh" #endif // Adapted from flashinfer activation diff --git a/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu index 6da13d079..7afff7794 100644 --- a/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu +++ b/sgl-kernel/csrc/gemm/per_tensor_quant_fp8.cu @@ -69,7 +69,7 @@ __global__ void per_tensor_quant_fp8_kernel( #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { float val = fmax(fmin(static_cast(input_vec[j]) * scale_val, FP8_E4M3_MAX), -FP8_E4M3_MAX); -#ifndef USE_ROCM +#if !defined(USE_ROCM) || defined(HIP_FP8_TYPE_E4M3) output_arr[j] = static_cast(val); #else output_arr[j] = c10::Float8_e4m3fnuz( @@ -83,7 +83,7 @@ __global__ void per_tensor_quant_fp8_kernel( const int32_t remaining_start = num_vec_elems * VEC_SIZE; for (int32_t idx = remaining_start + gid; idx < num_elements; idx += grid_size) { float val = fmax(-FP8_E4M3_MAX, fmin(static_cast(input[idx]) * scale_val, FP8_E4M3_MAX)); -#ifndef USE_ROCM +#if !defined(USE_ROCM) || defined(HIP_FP8_TYPE_E4M3) output[idx] = static_cast(val); #else output[idx] = c10::Float8_e4m3fnuz( diff --git a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu index c71022fd1..e73716c86 100644 --- a/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu +++ b/sgl-kernel/csrc/gemm/per_token_quant_fp8.cu @@ -67,7 +67,7 @@ __global__ void per_token_quant_fp8_kernel( for (uint32_t j = 0; j < kVecSize; ++j) { float val = static_cast(input_vec[j]) * scale_inv; val = fmaxf(fminf(val, FP8_E4M3_MAX), -FP8_E4M3_MAX); -#ifndef USE_ROCM +#if !defined(USE_ROCM) || defined(HIP_FP8_TYPE_E4M3) output_arr[j] = static_cast(val); #else output_arr[j] = c10::Float8_e4m3fnuz( @@ -143,7 +143,7 @@ __global__ void per_token_quant_fp8_small_batch_kernel( #pragma unroll for (uint32_t j = 0; j < kVecSize; ++j) { float val = fmaxf(fminf(static_cast(input_vec[j]) * scale_inv, FP8_E4M3_MAX), -FP8_E4M3_MAX); -#ifndef USE_ROCM +#if !defined(USE_ROCM) || defined(HIP_FP8_TYPE_E4M3) output_arr[j] = static_cast(val); #else output_arr[j] = c10::Float8_e4m3fnuz( diff --git a/sgl-kernel/csrc/moe/moe_align_kernel.cu b/sgl-kernel/csrc/moe/moe_align_kernel.cu index 19d0cc7a9..92fd34270 100644 --- a/sgl-kernel/csrc/moe/moe_align_kernel.cu +++ b/sgl-kernel/csrc/moe/moe_align_kernel.cu @@ -21,8 +21,6 @@ limitations under the License. #include "utils.h" -#define WARP_SIZE 32 - #define VEC_SIZE 4 using Vec = int4; diff --git a/sgl-kernel/include/hip_act_and_mul.cuh b/sgl-kernel/include/hip/hip_act_and_mul.cuh similarity index 100% rename from sgl-kernel/include/hip_act_and_mul.cuh rename to sgl-kernel/include/hip/hip_act_and_mul.cuh diff --git a/sgl-kernel/include/hip_math_def.h b/sgl-kernel/include/hip/hip_math_def.h similarity index 98% rename from sgl-kernel/include/hip_math_def.h rename to sgl-kernel/include/hip/hip_math_def.h index 21cc67456..356ed953f 100644 --- a/sgl-kernel/include/hip_math_def.h +++ b/sgl-kernel/include/hip/hip_math_def.h @@ -15,7 +15,7 @@ limitations under the License. #pragma once -#if defined(__HIP_PLATFORM_AMD__) +#ifdef USE_ROCM #include #include diff --git a/sgl-kernel/include/hip_vec_dtypes.h b/sgl-kernel/include/hip/hip_vec_dtypes.h similarity index 100% rename from sgl-kernel/include/hip_vec_dtypes.h rename to sgl-kernel/include/hip/hip_vec_dtypes.h diff --git a/sgl-kernel/include/impl/hip_vec_bf16_impl.h b/sgl-kernel/include/hip/impl/hip_vec_bf16_impl.h similarity index 100% rename from sgl-kernel/include/impl/hip_vec_bf16_impl.h rename to sgl-kernel/include/hip/impl/hip_vec_bf16_impl.h diff --git a/sgl-kernel/include/impl/hip_vec_fp32_impl.h b/sgl-kernel/include/hip/impl/hip_vec_fp32_impl.h similarity index 100% rename from sgl-kernel/include/impl/hip_vec_fp32_impl.h rename to sgl-kernel/include/hip/impl/hip_vec_fp32_impl.h diff --git a/sgl-kernel/include/impl/hip_vec_half_impl.h b/sgl-kernel/include/hip/impl/hip_vec_half_impl.h similarity index 100% rename from sgl-kernel/include/impl/hip_vec_half_impl.h rename to sgl-kernel/include/hip/impl/hip_vec_half_impl.h diff --git a/sgl-kernel/include/utils.h b/sgl-kernel/include/utils.h index d78049a68..56f322764 100644 --- a/sgl-kernel/include/utils.h +++ b/sgl-kernel/include/utils.h @@ -331,13 +331,15 @@ inline bool getEnvEnablePDL() { #ifndef USE_ROCM #define WARP_SIZE 32 #else -#define WARP_SIZE warpSize // 64 +#include +#include +#define WARP_SIZE C10_WARP_SIZE #endif -#if defined(__HIP_PLATFORM_AMD__) +#ifdef USE_ROCM -#include "hip_math_def.h" -#include "hip_vec_dtypes.h" +#include "hip/hip_math_def.h" +#include "hip/hip_vec_dtypes.h" #else @@ -354,14 +356,11 @@ __device__ __forceinline__ dstDtype castFromFloat(float val) { #endif // add FP8 support - #ifndef USE_ROCM #include using FP8_TYPE = c10::Float8_e4m3fn; C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits::max(); - #else // USE_ROCM - #if HIP_FP8_TYPE_FNUZ #include using FP8_TYPE = c10::Float8_e4m3fnuz; diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index ac61e4df9..02c2019ff 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -72,6 +72,9 @@ if amdgpu_target not in ["gfx942", "gfx950"]: ) sys.exit(1) +fp8_macro = ( + "-DHIP_FP8_TYPE_FNUZ" if amdgpu_target == "gfx942" else "-DHIP_FP8_TYPE_E4M3" +) hipcc_flags = [ "-DNDEBUG", @@ -80,10 +83,10 @@ hipcc_flags = [ "-Xcompiler", "-fPIC", "-std=c++17", - "-D__HIP_PLATFORM_AMD__=1", f"--amdgpu-target={amdgpu_target}", "-DENABLE_BF16", "-DENABLE_FP8", + fp8_macro, ] ext_modules = [