Make sm100 fp8 kernels available on sm103 (#9789)
Signed-off-by: Hao Lu <14827759+hlu1@users.noreply.github.com>
This commit is contained in:
@@ -260,7 +260,11 @@ torch::Tensor fp8_blockwise_scaled_mm(
|
|||||||
|
|
||||||
#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
||||||
if (sm_version == 100) {
|
if (sm_version == 100
|
||||||
|
#if CUDA_VERSION >= 12090
|
||||||
|
|| sm_version == 103
|
||||||
|
#endif
|
||||||
|
) {
|
||||||
if (out_dtype == torch::kBFloat16) {
|
if (out_dtype == torch::kBFloat16) {
|
||||||
sm100_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
|
sm100_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
|
||||||
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
|
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
|
||||||
|
|||||||
@@ -1212,7 +1212,11 @@ torch::Tensor fp8_scaled_mm(
|
|||||||
auto sm_version = getSMVersion();
|
auto sm_version = getSMVersion();
|
||||||
|
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
||||||
if (sm_version >= 100) {
|
if (sm_version == 100
|
||||||
|
#if CUDA_VERSION >= 12090
|
||||||
|
|| sm_version == 103
|
||||||
|
#endif
|
||||||
|
) {
|
||||||
if (out_dtype == torch::kBFloat16) {
|
if (out_dtype == torch::kBFloat16) {
|
||||||
sm100_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
sm100_fp8_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -708,7 +708,11 @@ void fp8_blockwise_scaled_grouped_mm(
|
|||||||
|
|
||||||
#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
||||||
if (sm_version == 100) {
|
if (sm_version == 100
|
||||||
|
#if CUDA_VERSION >= 12090
|
||||||
|
|| sm_version == 103
|
||||||
|
#endif
|
||||||
|
) {
|
||||||
if (output.scalar_type() == torch::kBFloat16) {
|
if (output.scalar_type() == torch::kBFloat16) {
|
||||||
sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
|
sm100_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
|
||||||
output,
|
output,
|
||||||
@@ -802,5 +806,5 @@ void fp8_blockwise_scaled_grouped_mm(
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||||
can_implement, "No implemented fp8_blockwise_scaled_mm for current compute capability: ", sm_version);
|
can_implement, "No implemented fp8_blockwise_scaled_grouped_mm for current compute capability: ", sm_version);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user