Optimize prefill performance on cpu backend (#8750)

Author: Ma Mingfei
Date: 2025-08-29 08:21:55 +08:00
Committed by: GitHub
Parent: 9f81d741a2
Commit: 5ad296bda1
9 changed files with 680 additions and 273 deletions

@@ -2,9 +2,6 @@
#include "gemm.h"
#include "vec.h"
// we use 4x32 for BLOCK_M
#define BLOCK_SIZE_M_SCALE 4
namespace {
template <typename scalar_t>
@@ -250,7 +247,8 @@ struct brgemm {
       int K,
       int lda,
       int ldb,
-      int ldc) {
+      int ldc,
+      bool do_unpack = true) {
     TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
   }
 };
@@ -270,17 +268,20 @@ struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
       int K,
       int lda,
       int ldb,
-      int ldc) {
+      int ldc,
+      bool do_unpack = true) {
     constexpr int BLOCK_N = block_size_n();
     // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
     const int ldb_tmp = BLOCK_N;
-    for (int k = 0; k < K; k += BLOCK_K) {
-      int kb_size = std::min(BLOCK_K, K - k);
-      int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
-      unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
+    if (do_unpack) {
+      for (int k = 0; k < K; k += BLOCK_K) {
+        int kb_size = std::min(BLOCK_K, K - k);
+        int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
+        unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
+      }
     }
     at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
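
The hunk above is the core of the change: unpack_B dequantizes one fp8 weight block into the Btmp scratch buffer (applying the per-128-row scale group), and the new do_unpack flag lets a caller skip that conversion when the block was already unpacked for an earlier M-block. A minimal standalone sketch of the pattern follows; unpack_tile and tiles_sketch are made-up stand-ins, not the kernel's real helpers.

// Sketch: unpack the fp8 B tile once per N-block and reuse it for every M-block.
// unpack_tile() is a hypothetical stand-in for unpack_B(); the "dequantization"
// below is fake and only marks where the real conversion would happen.
#include <cstdint>
#include <vector>

static void unpack_tile(float* Btmp, const uint8_t* B, int64_t elems, float scale) {
  for (int64_t i = 0; i < elems; ++i) {
    Btmp[i] = static_cast<float>(B[i]) * scale;  // pretend fp8 -> float dequant
  }
}

void tiles_sketch(const uint8_t* B, int64_t MB, int64_t NB, int64_t K, int64_t BLOCK_N) {
  std::vector<float> Btmp(K * BLOCK_N);          // per-thread scratch, one tile
  for (int64_t nb = 0; nb < NB; ++nb) {          // this thread's N-blocks
    const uint8_t* B_nb = B + nb * K * BLOCK_N;  // fp8 source of this N-block
    for (int64_t mb = 0; mb < MB; ++mb) {        // this thread's M-blocks
      bool do_unpack = (mb == 0);                // only the first M-block converts
      if (do_unpack) {
        unpack_tile(Btmp.data(), B_nb, K * BLOCK_N, /*scale=*/0.5f);
      }
      // ... the (mb, nb) GEMM would consume Btmp here ...
    }
  }
}

During prefill M is large, so every N-block of B is visited by many M-blocks; skipping the repeated fp8-to-bf16 conversion is where the saving comes from.
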
@@ -312,9 +313,11 @@ void tinygemm_kernel(
     int64_t ldb,
     int64_t ldc,
     bool brg,
-    int64_t block_size_K) {
+    int64_t block_size_K,
+    bool do_unpack = true) {
   if (brg) {
-    brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc);
+    brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
+        A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc, do_unpack);
     return;
   }
@@ -366,7 +369,7 @@ void fp8_scaled_mm_kernel_impl(
     int64_t block_size_N,
     int64_t block_size_K,
     int64_t buffer_size_per_thread) {
-  constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
+  constexpr int64_t BLOCK_M = block_size_m();
   constexpr int64_t BLOCK_N = block_size_n();
   const int64_t MB = div_up(M, BLOCK_M);
   const int64_t NB = div_up(N, BLOCK_N);
@@ -378,16 +381,12 @@ void fp8_scaled_mm_kernel_impl(
   // parallel on [MB, NB]
   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
-    at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
-      int64_t mb{0}, nb{0};
-      data_index_init(begin, mb, MB, nb, NB);
-      int tid = at::get_thread_num();
+    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
+      int tid = get_thread_num();
       scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
-      float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));
+      float* __restrict__ Ctmp = (float*)((void*)(Btmp + MAX_CACHE_BLOCK_SIZE * BLOCK_N * K));
-      for (int64_t i = begin; i < end; ++i) {
-        UNUSED(i);
+      loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
         const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
         int64_t mb_start = mb * BLOCK_M;
@@ -395,11 +394,14 @@ void fp8_scaled_mm_kernel_impl(
         int64_t nb_start = nb * BLOCK_N;
         int64_t nb_size = std::min(N - nb_start, BLOCK_N);
+        // only do unpacking for the first row
+        bool do_unpack = (mb == mb0);
         tinygemm_kernel<scalar_t, has_bias>(
             /* A */ mat1 + mb_start * mat1_strideM,
             /* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
             /* C */ out + mb_start * out_strideM + nb_start,
-            /* Btmp */ Btmp,
+            /* Btmp */ Btmp + nb_offset * BLOCK_N * K,
             /* Ctmp */ Ctmp,
             /* scale */ scale_ptr,
             /* bias */ bias + nb_start,
@@ -410,11 +412,9 @@ void fp8_scaled_mm_kernel_impl(
             /* ldb */ nb_size,
             /* ldc */ out_strideM,
             /* brg */ use_brgemm,
-            /* block_size_K */ block_size_K);
-        // move to the next index
-        data_index_step(mb, MB, nb, NB);
-      }
+            /* block_size_K */ block_size_K,
+            /* do_unpack */ do_unpack);
+      });
 
       if (use_brgemm) {
         at::native::cpublas::brgemm_release();
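
Replacing the flat parallel_for over MB * NB with parallel_2d/loop_2d gives each thread a rectangle of (mb, nb) tiles, so consecutive M-blocks revisit the same N-blocks and can reuse the unpacked copies parked at Btmp + nb_offset * BLOCK_N * K; the conversion runs only on the first row of the rectangle (mb == mb0). The real parallel_2d/loop_2d helpers are defined elsewhere in this PR; the sketch below only illustrates a traversal order under which do_unpack = (mb == mb0) is valid, assuming one chunk of cached N-blocks fits the per-thread buffer.

// Sketch: one thread walking its [mb0, mb1) x [nb0, nb1) rectangle while caching
// up to CACHE_BLOCKS unpacked B tiles. Illustrative only, not the real loop_2d.
#include <algorithm>
#include <cstdint>

constexpr int64_t CACHE_BLOCKS = 4;  // assumed stand-in for MAX_CACHE_BLOCK_SIZE

template <typename F>
void walk_tiles(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1,
                int64_t tile_elems, float* Btmp, const F& tile_fn) {
  for (int64_t nb = nb0; nb < nb1; nb += CACHE_BLOCKS) {  // one chunk of N-blocks
    int64_t nb_end = std::min(nb + CACHE_BLOCKS, nb1);
    for (int64_t mb = mb0; mb < mb1; ++mb) {              // every row reuses the chunk
      for (int64_t n = nb; n < nb_end; ++n) {
        int64_t nb_offset = n - nb;                       // cache slot of this N-block
        bool do_unpack = (mb == mb0);                     // convert only on the first row
        tile_fn(mb, n, Btmp + nb_offset * tile_elems, do_unpack);
      }
    }
  }
}
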
@@ -441,8 +441,10 @@ void tinygemm_kernel(
     int64_t ldb,
     int64_t ldc,
     bool brg,
-    int64_t block_size_K) {
-  tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K);
+    int64_t block_size_K,
+    bool do_unpack) {
+  tinygemm_kernel<scalar_t, false>(
+      A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack);
 }
 
 #define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
@@ -460,7 +462,8 @@ void tinygemm_kernel(
       int64_t ldb, \
       int64_t ldc, \
       bool brg, \
-      int64_t block_size_K)
+      int64_t block_size_K, \
+      bool do_unpack)
 
 INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
 INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
@@ -495,7 +498,7 @@ at::Tensor fp8_scaled_mm_cpu(
   int64_t block_size_N = block_size[0];
   int64_t block_size_K = block_size[1];
-  constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
+  constexpr int64_t BLOCK_M = block_size_m();
   constexpr int64_t BLOCK_N = block_size_n();
   TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
   TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
@@ -523,7 +526,7 @@ at::Tensor fp8_scaled_mm_cpu(
   // Btmp : [T, BLOCK_N * K]
   // Ctmp : [T, BLOCK_M * BLOCK_N]
   int num_threads = at::get_num_threads();
-  int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
+  int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
   auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
   AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
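
The per-thread scratch grows to match: Btmp now reserves room for MAX_CACHE_BLOCK_SIZE unpacked N-blocks of B (each BLOCK_N * K elements of the output dtype), followed by the float Ctmp accumulator (BLOCK_M * BLOCK_N * 2 elements of the 2-byte dtype). A back-of-the-envelope sizing check is below; every constant in it is an assumed example value, since the real block_size_m(), block_size_n() and MAX_CACHE_BLOCK_SIZE definitions are not part of this hunk.

// Sketch: rough per-thread buffer size under assumed constants.
#include <cstdint>
#include <cstdio>

int main() {
  constexpr int64_t MAX_CACHE_BLOCK_SIZE = 4;  // assumed
  constexpr int64_t BLOCK_M = 32;              // assumed block_size_m()
  constexpr int64_t BLOCK_N = 32;              // assumed block_size_n()
  constexpr int64_t K = 4096;                  // example reduction size

  // Elements of the 2-byte output dtype: cached B tiles plus a float accumulator
  // (the "* 2" reserves BLOCK_M * BLOCK_N floats inside a 2-byte-element buffer).
  const int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
  std::printf("per-thread scratch: %lld elements (~%lld KiB as bf16)\n",
              static_cast<long long>(size_per_thread),
              static_cast<long long>(size_per_thread * 2 / 1024));
}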