Optimize cutlass int8 gemm kernel for large M on SM89 Ada GPU (#10714)
This commit is contained in:
@@ -409,8 +409,8 @@ void sm89_dispatch_shape(
|
|||||||
cutlass_int8_scaled_mm<
|
cutlass_int8_scaled_mm<
|
||||||
ElementOutput,
|
ElementOutput,
|
||||||
ArchTag,
|
ArchTag,
|
||||||
cutlass::gemm::GemmShape<32, 64, 128>,
|
cutlass::gemm::GemmShape<128, 128, 64>,
|
||||||
cutlass::gemm::GemmShape<16, 64, 64>,
|
cutlass::gemm::GemmShape<64, 64, 64>,
|
||||||
InstructionShape,
|
InstructionShape,
|
||||||
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
5>(out, mat_a, mat_b, scales_a, scales_b, bias);
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user