Optimize prefill performance on cpu backend (#8750)

Author:    Ma Mingfei
Date:      2025-08-29 08:21:55 +08:00
Committed: GitHub
Parent:    9f81d741a2
Commit:    5ad296bda1

9 changed files with 680 additions and 273 deletions

@@ -174,18 +174,18 @@ void fused_experts_fp8_kernel_impl(
const int64_t stride_e = 2 * N * K;
const int64_t stride_n = K;
int64_t avg_M = std::max(int64_t(1), M * topk / E);
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(avg_M);
int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers
int tid = at::get_thread_num();
int tid = get_thread_num();
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
// B shape [K, n_size] in vnni format
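
Note on the new loop structure: the flat at::parallel_for over MB * NB blocks is replaced by parallel_2d, which hands each thread a rectangle [mb0, mb1) x [nb0, nb1) of blocks, and loop_2d, which walks that rectangle while reporting an nb_offset used to index a slot in the thread's B_tmp cache. Both helpers are defined elsewhere in this PR (the real loop_2d is also templated on the packed dtype and takes the per-block element count, presumably to cap how many column blocks fit under MAX_CACHE_BLOCK_SIZE). The sketch below is only an assumed shape of these helpers, written with OpenMP for illustration; it is not the PR's implementation.

// Hedged sketch, not the PR's actual helpers: 2D thread partition over the
// (MB, NB) block grid plus a tile walker exposing the per-thread slot index.
#include <algorithm>
#include <cstdint>
#include <omp.h>

template <typename F>
void parallel_2d_sketch(int64_t MB, int64_t NB, const F& f) {
#pragma omp parallel
  {
    const int nth = omp_get_num_threads();
    const int tid = omp_get_thread_num();
    // Assumed factorization: pick a thread grid Tm x Tn = nth that keeps the
    // per-thread NB range small so its unpacked B blocks fit in B_tmp.
    int Tn = 1;
    while (Tn * 2 <= nth && nth % (Tn * 2) == 0 && NB >= Tn * 2) Tn *= 2;
    const int Tm = nth / Tn;
    const int tm = tid / Tn;
    const int tn = tid % Tn;
    const int64_t mb_chunk = (MB + Tm - 1) / Tm;
    const int64_t nb_chunk = (NB + Tn - 1) / Tn;
    const int64_t mb0 = std::min<int64_t>(MB, tm * mb_chunk);
    const int64_t mb1 = std::min<int64_t>(MB, mb0 + mb_chunk);
    const int64_t nb0 = std::min<int64_t>(NB, tn * nb_chunk);
    const int64_t nb1 = std::min<int64_t>(NB, nb0 + nb_chunk);
    if (mb0 < mb1 && nb0 < nb1) f(mb0, mb1, nb0, nb1);
  }
}

template <typename F>
void loop_2d_sketch(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1,
                    int64_t /*btmp_block_elems*/, const F& f) {
  // Row blocks outermost so that, for a fixed column block, the unpacked B
  // tile in slot (nb - nb0) survives across consecutive mb iterations.
  for (int64_t mb = mb0; mb < mb1; ++mb) {
    for (int64_t nb = nb0; nb < nb1; ++nb) {
      f(mb, nb, /*nb_offset=*/nb - nb0);
    }
  }
}
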
@@ -194,13 +194,14 @@ void fused_experts_fp8_kernel_impl(
const float* __restrict__ Bs =
w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
// do unpacking for the first row or a new expert
int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
// 1.a load A
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
int64_t m_size = offsets[mb + 1] - offsets[mb];
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m] / topk;
copy_stub(A + m * K, input + index * K, K);
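
The do_unpack flag introduced here is what turns the per-thread B_tmp slots into a cache: within one thread's rectangle, the fp8 weight block only has to be unpacked again when the thread starts its first row block or when the current row block belongs to a different expert than the previous one. A minimal restatement of that condition (the wrapper function itself is hypothetical, not part of the PR):

// Restates the condition from the diff; needs_unpack is a hypothetical helper.
#include <cstdint>

inline bool needs_unpack(int64_t mb, int64_t mb0, const int32_t* expert_ids) {
  const int32_t expert_id = expert_ids[mb];
  const int32_t pre_expert_id = (mb == 0) ? -1 : expert_ids[mb - 1];
  // First row block handled by this thread, or a new expert: B must be
  // unpacked into the thread's B_tmp slot before the gemm.
  return (mb == mb0) || (expert_id != pre_expert_id);
}
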
@@ -211,7 +212,7 @@ void fused_experts_fp8_kernel_impl(
/* A */ A,
/* B */ B,
/* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ Bs,
/* M */ m_size,
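
The Btmp argument now points at one slot inside the thread's enlarged cache rather than at a single scratch block: tid * B_tmp_size_per_thread selects the thread's region and nb_offset * BLOCK_N * K selects the slot for the current column block. A small sketch of that addressing (the helper name is made up for illustration):

// Hypothetical helper showing the slot layout implied by the new Btmp
// expression: [num_threads][MAX_CACHE_BLOCK_SIZE][block_elems] elements.
#include <cstdint>

template <typename T>
T* btmp_slot(T* B_tmp, int tid, int64_t B_tmp_size_per_thread,
             int64_t nb_offset, int64_t block_elems /* e.g. BLOCK_N * K */) {
  return B_tmp + tid * B_tmp_size_per_thread + nb_offset * block_elems;
}
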
@@ -221,10 +222,11 @@ void fused_experts_fp8_kernel_impl(
/* ldb */ n_size,
/* ldc */ 2 * N,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
}
/* block_size_K */ block_size_K,
/* do_unpack */ do_unpack);
});
if (is_brgemm_used) {
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
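
This change also moves the brgemm decision out of the inner loop: instead of calling can_use_brgemm on each block's m_size and tracking is_brgemm_used to guard the final brgemm_release, the kernel decides once from the average number of sorted rows per expert, avg_M = max(1, M * topk / E). A hedged illustration of the idea, with a made-up threshold (the real cutoff lives inside can_use_brgemm<>):

// Sketch only: a single M estimate drives both the gemm path and whether
// brgemm_release() is needed at the end of the parallel region.
#include <algorithm>
#include <cstdint>

inline bool can_use_brgemm_sketch(int64_t m) {
  constexpr int64_t kMinRowsForBrgemm = 4;  // assumed, not the PR's value
  return m >= kMinRowsForBrgemm;
}

inline bool decide_brgemm_for_moe(int64_t M, int64_t topk, int64_t E) {
  const int64_t avg_M = std::max(int64_t(1), M * topk / E);
  return can_use_brgemm_sketch(avg_M);
}
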
@@ -248,22 +250,14 @@ void fused_experts_fp8_kernel_impl(
const int64_t stride_oc = IC;
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
int tid = get_thread_num();
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t m_size = offsets[mb + 1] - offsets[mb];
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
// A ptr from ic1 of [M * topk, N] in sorted order
// so as to avoid copy A to tmp buffer again
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
@@ -275,11 +269,15 @@ void fused_experts_fp8_kernel_impl(
const float* __restrict__ Bs =
w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
// do unpacking for the first row or a new expert
int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ Bs,
/* M */ m_size,
@@ -289,7 +287,8 @@ void fused_experts_fp8_kernel_impl(
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
/* block_size_K */ block_size_K,
/* do_unpack */ do_unpack);
// 2.b copy from C to ic2 in original order
// and also mul topk_weights in float32
@@ -298,9 +297,9 @@ void fused_experts_fp8_kernel_impl(
float weight = topk_weights[index];
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
}
}
});
if (is_brgemm_used) {
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
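
For reference, the write-back above takes each row of the per-block result C (which is in sorted, per-expert order), scales it by the token's routing weight, and stores it into ic2 at the token's original position. A sketch of what a copy_mul_stub-style helper does, with the signature and dtypes assumed from the call site rather than taken from the PR:

// Assumed semantics of copy_mul_stub based on its call site:
// out[0..n) = src[0..n) * weight, multiplied in float32 as the kernel's
// comment notes, then stored back in the activation dtype.
#include <cstdint>

template <typename scalar_t>
inline void copy_mul_stub_sketch(scalar_t* __restrict__ out,
                                 const scalar_t* __restrict__ src,
                                 float weight, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    out[i] = static_cast<scalar_t>(static_cast<float>(src[i]) * weight);
  }
}
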
@@ -374,20 +373,23 @@ void shared_expert_fp8_kernel_impl(
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
int tid = get_thread_num();
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
// do unpacking for the first row
bool do_unpack = (mb == mb0);
tinygemm_kernel<scalar_t>(
/* A */ input + mb * BLOCK_M * K,
/* B */ packed_w1 + nb * BLOCK_N * K,
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ m_size,
@@ -397,8 +399,9 @@ void shared_expert_fp8_kernel_impl(
/* ldb */ n_size,
/* ldc */ 2 * N,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
}
/* block_size_K */ block_size_K,
/* do_unpack */ do_unpack);
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
@@ -421,22 +424,23 @@ void shared_expert_fp8_kernel_impl(
scale_size_K = div_up(N, block_size_K);
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
int tid = get_thread_num();
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
// do unpacking for the first row
bool do_unpack = (mb == mb0);
// 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>(
/* A */ ic1 + mb * BLOCK_M * N,
/* B */ packed_w2 + nb * BLOCK_N * N,
/* C */ C,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ m_size,
@@ -446,7 +450,8 @@ void shared_expert_fp8_kernel_impl(
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K);
/* block_size_K */ block_size_K,
/* do_unpack */ do_unpack);
// 2.b copy from C to output and add fused_experts_out
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
@@ -454,7 +459,7 @@ void shared_expert_fp8_kernel_impl(
for (int64_t m = 0; m < m_size; ++m) {
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
}
}
});
});
if (use_brgemm) {
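
The shared-expert epilogue above fuses the combine step into the second gemm's write-back: each row of C is added to the corresponding fused_experts_out row scaled by routed_scaling_factor, then written to output. A sketch of that combine, with the exact formula treated as an assumption read off the argument order of add_mul_stub:

// Assumed semantics of add_mul_stub from its call site:
// out[i] = x[i] + y[i] * scale, i.e. shared-expert output plus the routed
// experts' output scaled by routed_scaling_factor; float32 accumulation is
// assumed.
#include <cstdint>

template <typename scalar_t>
inline void add_mul_stub_sketch(scalar_t* __restrict__ out,
                                const scalar_t* __restrict__ x,
                                const scalar_t* __restrict__ y,
                                float scale, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    out[i] = static_cast<scalar_t>(
        static_cast<float>(x[i]) + static_cast<float>(y[i]) * scale);
  }
}
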