diff --git a/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
index e69167a4d..4f9e3b959 100644
--- a/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
+++ b/sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
@@ -260,7 +260,11 @@ torch::Tensor fp8_blockwise_scaled_mm(
 
 #if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
 #if defined CUDA_VERSION && CUDA_VERSION >= 12080
-  if (sm_version == 100) {
+  if (sm_version == 100
+#if CUDA_VERSION >= 12090
+      || sm_version == 103
+#endif
+  ) {
     if (out_dtype == torch::kBFloat16) {
       sm100_fp8_blockwise_dispatch_shape(
           out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
diff --git a/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu b/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
index 77b5c500f..0a9e6b7a5 100644
--- a/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
+++ b/sgl-kernel/csrc/gemm/fp8_gemm_kernel.cu
@@ -1212,7 +1212,11 @@ torch::Tensor fp8_scaled_mm(
   auto sm_version = getSMVersion();
 
 #if defined CUDA_VERSION && CUDA_VERSION >= 12080
-  if (sm_version >= 100) {
+  if (sm_version == 100
+#if CUDA_VERSION >= 12090
+      || sm_version == 103
+#endif
+  ) {
     if (out_dtype == torch::kBFloat16) {
       sm100_fp8_dispatch_shape(out, mat_a, mat_b, scales_a, scales_b, bias);
     } else {
diff --git a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
index 1a11ce2d7..b2e1fc83c 100644
--- a/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
+++ b/sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
@@ -708,7 +708,11 @@ void fp8_blockwise_scaled_grouped_mm(
 
 #if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
 #if defined CUDA_VERSION && CUDA_VERSION >= 12080
-  if (sm_version == 100) {
+  if (sm_version == 100
+#if CUDA_VERSION >= 12090
+      || sm_version == 103
+#endif
+  ) {
     if (output.scalar_type() == torch::kBFloat16) {
       sm100_fp8_blockwise_group_mm_dispatch_shape(
           output,
@@ -802,5 +806,5 @@ void fp8_blockwise_scaled_grouped_mm(
   }
 #endif
   TORCH_CHECK_NOT_IMPLEMENTED(
-      can_implement, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
+      can_implement, "No implemented fp8_blockwise_scaled_grouped_mm for current compute capability: ", sm_version);
 }