Optimize prefill performance on CPU backend (#8750)

Ma Mingfei
2025-08-29 08:21:55 +08:00
committed by GitHub
parent 9f81d741a2
commit 5ad296bda1
9 changed files with 680 additions and 273 deletions


@@ -254,7 +254,7 @@ void tinygemm_kernel(
     return;
   }
-  // pattern: 1-4-16
+  // pattern: 1-4-16, N = 16, 32, 48, 64
   constexpr int64_t BLOCK_M = 4;
   constexpr int64_t BLOCK_N = 64;
   const int64_t MB = div_up(M, BLOCK_M);
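
For reference, the 1-4-16 pattern tiles the output into BLOCK_M x BLOCK_N = 4 x 64 blocks, and div_up rounds upward so ragged edges still get a smaller tail block. A standalone sketch of the tiling arithmetic, assuming div_up is the usual round-up division (the example sizes are made up):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>

    // round-up integer division, as assumed for the tile counts above
    inline int64_t div_up(int64_t x, int64_t y) { return (x + y - 1) / y; }

    int main() {
      constexpr int64_t BLOCK_M = 4, BLOCK_N = 64;
      const int64_t M = 10, N = 80;           // example problem sizes
      const int64_t MB = div_up(M, BLOCK_M);  // 3 row blocks: 4, 4, 2
      const int64_t NB = div_up(N, BLOCK_N);  // 2 col blocks: 64, 16
      for (int64_t mb = 0; mb < MB; ++mb) {
        for (int64_t nb = 0; nb < NB; ++nb) {
          int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M);
          int64_t nb_size = std::min(N - nb * BLOCK_N, BLOCK_N);
          std::printf("tile (%lld, %lld): %lld x %lld\n",
                      (long long)mb, (long long)nb,
                      (long long)mb_size, (long long)nb_size);
        }
      }
      return 0;
    }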
@@ -268,35 +268,59 @@ void tinygemm_kernel(
   switch (mb_size << 4 | nb_size >> 4) {
     // mb_size = 1
+    case 0x11:
+      LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
+      break;
     case 0x12:
       LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
       break;
+    case 0x13:
+      LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
+      break;
     case 0x14:
       LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
       break;
     // mb_size = 2
+    case 0x21:
+      LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
+      break;
     case 0x22:
       LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
       break;
+    case 0x23:
+      LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
+      break;
     case 0x24:
       LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
       break;
     // mb_size = 3
+    case 0x31:
+      LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
+      break;
     case 0x32:
       LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
       break;
+    case 0x33:
+      LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
+      break;
     case 0x34:
       LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
       break;
     // mb_size = 4
+    case 0x41:
+      LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
+      break;
     case 0x42:
       LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
       break;
+    case 0x43:
+      LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
+      break;
     case 0x44:
       LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
       break;
     default:
-      TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
+      TORCH_CHECK(false, "Unexpected block size, ", mb_size, " x ", nb_size);
   }
 }
 }
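
The switch key packs both tile dimensions into one byte: mb_size ranges over 1..4 and nb_size is a multiple of 16 up to BLOCK_N = 64, so mb_size << 4 | nb_size >> 4 lands in 0x11..0x44, one case per fixed-size micro-kernel. A standalone check of the encoding:

    #include <cassert>
    #include <cstdint>

    int main() {
      // mb_size in [1, 4], nb_size in {16, 32, 48, 64}
      const int64_t mb_size = 3, nb_size = 48;
      // high nibble: row count; low nibble: column count / 16
      const int64_t key = mb_size << 4 | nb_size >> 4;
      assert(key == 0x33);  // would dispatch LAUNCH_TINYGEMM_KERNEL_NN(3, 48)
      return 0;
    }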
@@ -318,20 +342,15 @@ void weight_packed_linear_kernel_impl(
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);
-  // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx c) N is small
-  const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>) || (N < 64);
+  const bool use_brgemm = can_use_brgemm<scalar_t>(M);
   // parallel on [MB, NB]
   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
-    at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-      int64_t mb{0}, nb{0};
-      data_index_init(begin, mb, MB, nb, NB);
+    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
       // for brgemm, use float32 for accumulate
       alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
-      for (int64_t i = begin; i < end; ++i) {
-        UNUSED(i);
+      loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
         int64_t mb_start = mb * BLOCK_M;
         int64_t mb_size = std::min(M - mb_start, BLOCK_M);
         int64_t nb_start = nb * BLOCK_N;
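
parallel_2d and loop_2d replace the flat index walk (data_index_init/data_index_step over MB * NB) with an explicit 2D range split: each thread owns a contiguous [mb0, mb1) x [nb0, nb1) sub-grid, and the offset into the packed weight buffer (nb * BLOCK_N * K at this call site) is computed once per column block instead of per tile. The helpers themselves are defined elsewhere in this PR; the sketch below is only a hypothetical shape of loop_2d (the real one is templated on scalar_t), not the actual implementation:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical sketch of the traversal helper; the name, loop order,
    // and signature are assumptions, not the PR's actual code.
    template <typename F>
    void loop_2d_sketch(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1,
                        int64_t nb_stride, const F& f) {
      for (int64_t nb = nb0; nb < nb1; ++nb) {
        const int64_t nb_offset = nb * nb_stride;  // packed-B offset, hoisted
        for (int64_t mb = mb0; mb < mb1; ++mb) {
          f(mb, nb, nb_offset);  // same weight block reused across row tiles
        }
      }
    }

    int main() {
      // walk a 2 x 2 sub-grid with a packed-B stride of BLOCK_N * K = 64 * 128
      loop_2d_sketch(0, 2, 0, 2, 64 * 128,
                     [](int64_t mb, int64_t nb, int64_t nb_offset) {
        std::printf("mb=%lld nb=%lld nb_offset=%lld\n",
                    (long long)mb, (long long)nb, (long long)nb_offset);
      });
      return 0;
    }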
@@ -350,10 +369,7 @@ void weight_packed_linear_kernel_impl(
           /* ldb */ nb_size,
           /* ldc */ out_strideM,
           /* brg */ use_brgemm);
-        // move to the next index
-        data_index_step(mb, MB, nb, NB);
-      }
       });
   if (use_brgemm) {
     at::native::cpublas::brgemm_release();
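
brgemm accumulates in float32 (hence the Ctmp tile above), so each finished block is cast back to the output dtype, after which brgemm_release() frees the brgemm hardware context (on x86, the AMX tile configuration). A minimal sketch of the copy-out step; copy_out and its signature are illustrative, not the kernel's actual helper:

    #include <cstdint>

    // Illustrative copy-out: cast a BLOCK_M x BLOCK_N float32 accumulator
    // tile back to the output dtype, row by row (ldc is the output row stride).
    template <typename scalar_t>
    void copy_out(scalar_t* C, const float* Ctmp, int64_t mb_size,
                  int64_t nb_size, int64_t ldc, int64_t block_n) {
      for (int64_t m = 0; m < mb_size; ++m) {
        for (int64_t n = 0; n < nb_size; ++n) {
          C[m * ldc + n] = static_cast<scalar_t>(Ctmp[m * block_n + n]);
        }
      }
    }

    int main() {
      float Ctmp[4 * 64] = {};  // fp32 accumulator tile (BLOCK_M x BLOCK_N)
      float out[2 * 100] = {};  // output rows with stride ldc = 100
      copy_out(out, Ctmp, /*mb_size=*/2, /*nb_size=*/16, /*ldc=*/100,
               /*block_n=*/64);
      return 0;
    }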