Optimize cutlass int8 gemm kernel for large M on SM89 Ada GPU (#10714)
This commit is contained in:
@@ -409,8 +409,8 @@ void sm89_dispatch_shape(
|
||||
cutlass_int8_scaled_mm<
|
||||
ElementOutput,
|
||||
ArchTag,
|
||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||
InstructionShape,
|
||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user