Optimize prefill performance on cpu backend (#8750)
This commit is contained in:
@@ -254,7 +254,7 @@ void tinygemm_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
// pattern: 1-4-16, N = 16, 32, 48, 64
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
@@ -268,35 +268,59 @@ void tinygemm_kernel(
|
||||
|
||||
switch (mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x11:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
|
||||
break;
|
||||
case 0x12:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
|
||||
break;
|
||||
case 0x13:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
|
||||
break;
|
||||
case 0x14:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
|
||||
break;
|
||||
// mb_size = 2
|
||||
case 0x21:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
|
||||
break;
|
||||
case 0x22:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
|
||||
break;
|
||||
case 0x23:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
|
||||
break;
|
||||
case 0x24:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
|
||||
break;
|
||||
// mb_size = 3
|
||||
case 0x31:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
|
||||
break;
|
||||
case 0x32:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
|
||||
break;
|
||||
case 0x33:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
|
||||
break;
|
||||
case 0x34:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
|
||||
break;
|
||||
// mb_size = 4
|
||||
case 0x41:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
|
||||
break;
|
||||
case 0x42:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
|
||||
break;
|
||||
case 0x43:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
|
||||
break;
|
||||
case 0x44:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, " x ", nb_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -318,20 +342,15 @@ void weight_packed_linear_kernel_impl(
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx c) N is small
|
||||
const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>) || (N < 64);
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// for brgemm, use float32 for accumulate
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
@@ -350,10 +369,7 @@ void weight_packed_linear_kernel_impl(
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
|
||||
Reference in New Issue
Block a user