chore: upgrade cutlass 3.9.2 (#6004)
Co-authored-by: yizhang2077 <1109276519@qq.com>
This commit is contained in:
@@ -45,7 +45,7 @@ include(FetchContent)
|
||||
FetchContent_Declare(
|
||||
repo-cutlass
|
||||
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
|
||||
GIT_TAG e94e888df3551224738bfa505787b515eae8352f
|
||||
GIT_TAG ad7b2f5e84fcfa124cb02b91d5bd26d238c0459e
|
||||
GIT_SHALLOW OFF
|
||||
)
|
||||
FetchContent_Populate(repo-cutlass)
|
||||
|
||||
@@ -384,16 +384,23 @@ torch::Tensor fp8_blockwise_scaled_mm(
|
||||
|
||||
auto sm_version = getSMVersion();
|
||||
|
||||
int64_t original_rows = mat_a.size(0);
|
||||
torch::Tensor mat_a_padded = pad_tensor(mat_a, /*alignment=*/4);
|
||||
torch::Tensor scales_a_padded = pad_tensor(scales_a, /*alignment=*/4, /*col_major=*/true);
|
||||
torch::Tensor out_padded = torch::empty({mat_a_padded.size(0), mat_b.size(1)}, out.options());
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
|
||||
if (sm_version == 90) {
|
||||
torch::Tensor scales_b_contiguous = scales_b.contiguous();
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(out, mat_a, mat_b, scales_a, scales_b_contiguous);
|
||||
sm90_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
|
||||
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
|
||||
} else {
|
||||
sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(out, mat_a, mat_b, scales_a, scales_b_contiguous);
|
||||
sm90_fp8_blockwise_dispatch_shape<cutlass::half_t>(
|
||||
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b_contiguous);
|
||||
}
|
||||
return out;
|
||||
return out_padded.slice(0, 0, original_rows);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
@@ -401,12 +408,6 @@ torch::Tensor fp8_blockwise_scaled_mm(
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100A_SUPPORTED) || defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
|
||||
if (sm_version == 100) {
|
||||
int64_t original_rows = mat_a.size(0);
|
||||
|
||||
torch::Tensor mat_a_padded = pad_tensor(mat_a, /*alignment=*/4);
|
||||
torch::Tensor scales_a_padded = pad_tensor(scales_a, /*alignment=*/4, /*col_major=*/true);
|
||||
torch::Tensor out_padded = torch::empty({mat_a_padded.size(0), mat_b.size(1)}, out.options());
|
||||
|
||||
if (out_dtype == torch::kBFloat16) {
|
||||
sm100_fp8_blockwise_dispatch_shape<cutlass::bfloat16_t>(
|
||||
out_padded, mat_a_padded, mat_b, scales_a_padded, scales_b);
|
||||
|
||||
Reference in New Issue
Block a user