Update int8 gemm config (#2774)

This commit is contained in:
Ke Bao
2025-01-07 19:47:37 +08:00
committed by GitHub
parent bdc1acf6cd
commit 58f9060efe
2 changed files with 16 additions and 5 deletions

View File

@@ -88,10 +88,11 @@ void cutlass_int8_scaled_mm(torch::Tensor& out, const torch::Tensor& mat_a, cons
auto stream = at::cuda::getCurrentCUDAStream(mat_a.get_device());
auto can_implement = gemm_op.can_implement(args);
TORCH_CHECK(can_implement == cutlass::Status::kSuccess)
TORCH_CHECK(can_implement == cutlass::Status::kSuccess,
"gemm cannot implement, error: ", cutlassGetStatusString(can_implement));
auto status = gemm_op(args, workspace.data_ptr(), stream);
TORCH_CHECK(status == cutlass::Status::kSuccess)
TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm execution failed, error: ", cutlassGetStatusString(status));
}
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
@@ -144,7 +145,17 @@ void sm80_dispatch_shape(torch::Tensor& out, const torch::Tensor& mat_a, const t
cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
scales_b, bias);
}
} else if (m <= 64 || (m <= 128 && n < 8192)) {
} else if (m <= 64) {
if (n <= 4096) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 64, 128>,
cutlass::gemm::GemmShape<32, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
scales_b, bias);
} else {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
scales_b, bias);
}
} else if (m <= 128 && n < 8192) {
cutlass_int8_scaled_mm<ElementOutput, ArchTag, cutlass::gemm::GemmShape<64, 128, 128>,
cutlass::gemm::GemmShape<64, 64, 64>, InstructionShape, 5>(out, mat_a, mat_b, scales_a,
scales_b, bias);