Optimize prefill performance on cpu backend (#8750)

This commit is contained in:
Ma Mingfei
2025-08-29 08:21:55 +08:00
committed by GitHub
parent 9f81d741a2
commit 5ad296bda1
9 changed files with 680 additions and 273 deletions

View File

@@ -579,36 +579,31 @@ void fused_experts_kernel_impl(
const int64_t stride_e = 2 * N * K;
const int64_t stride_n = K;
int64_t avg_M = std::max(int64_t(1), M * topk / E);
const bool use_brgemm = can_use_brgemm<scalar_t>(avg_M);
// here we only parallelize over half of 2N so that silu_and_mul can be fused with the gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers
int tid = at::get_thread_num();
int tid = get_thread_num();
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
// nb_upper from top half and nb_lower from bottom half
int64_t nb_upper = nb, nb_lower = nb + NB;
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
// B shape [K, n_size] in vnni format
int32_t expert_id = expert_ids[mb];
const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;
// 1.a load A
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
int64_t m_size = offsets[mb + 1] - offsets[mb];
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m] / topk;
copy_stub(A + m * K, input + index * K, K);
@@ -659,9 +654,9 @@ void fused_experts_kernel_impl(
/* ldb */ n_size,
/* ldc */ N);
}
}
});
if (is_brgemm_used) {
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
@@ -676,24 +671,16 @@ void fused_experts_kernel_impl(
const int64_t stride_oc = IC;
// parallelize over [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers
int tid = at::get_thread_num();
int tid = get_thread_num();
// we won't be using C1 for gemm2
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t m_size = offsets[mb + 1] - offsets[mb];
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
// A points into ic1 of [M * topk, N], whose rows are in sorted order,
// so as to avoid copying A to the tmp buffer again
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
@@ -736,9 +723,9 @@ void fused_experts_kernel_impl(
float weight = topk_weights[index];
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
}
}
});
if (is_brgemm_used) {
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
@@ -776,36 +763,27 @@ void shared_expert_kernel_impl(
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
const int64_t stride_n = K;
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
// here we only parallelize over half of 2N so that silu_and_mul can be fused with the gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers
int tid = at::get_thread_num();
int tid = get_thread_num();
float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
// nb_upper from top half and nb_lower from bottom half
int64_t nb_upper = nb, nb_lower = nb + NB;
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
// int64_t mb_start = mb * BLOCK_M;
// int64_t mb_size = std::min(M - mb_start, BLOCK_M);
// A shape [m_size, K]
const scalar_t* A = input + mb * BLOCK_M * K;
// B shape [K, n_size] in vnni format
const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
const scalar_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
const scalar_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;
if (use_brgemm) {
// 1.b gemm: C0 = A @ B0
@@ -850,9 +828,9 @@ void shared_expert_kernel_impl(
/* ldb */ n_size,
/* ldc */ N);
}
}
});
if (is_brgemm_used) {
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
@@ -866,24 +844,16 @@ void shared_expert_kernel_impl(
const int64_t stride_oc = IC;
// parallelize over [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers
int tid = at::get_thread_num();
int tid = get_thread_num();
// we won't be using C1 for gemm2
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
// A shape [m_size, IC]
const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N;
@@ -922,9 +892,9 @@ void shared_expert_kernel_impl(
for (int64_t m = 0; m < m_size; ++m) {
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
}
}
});
if (is_brgemm_used) {
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
@@ -1086,7 +1056,7 @@ at::Tensor fused_experts_cpu(
//
// for fp8 w8a16:
// 7. intermediate_cache0 : [M * topk, 2N]
// 8. B_tmp : [T, BLOCK_N, std::max(K, N)]
// 8. B_tmp : [T, MAX_CACHE_BLOCK_SIZE, BLOCK_N, std::max(K, N)]
//
int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
@@ -1096,7 +1066,7 @@ at::Tensor fused_experts_cpu(
buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
}
if (use_fp8_w8a16) {
buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2;
buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N) * 2;
}
auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
@@ -1268,7 +1238,7 @@ at::Tensor shared_expert_cpu(
//
// for fp8 w8a16:
// 5. intermediate_cache0 : [M, 2N]
// 6. B_tmp: [T, BLOCK_M, max(K, N)]
// 6. B_tmp: [T, MAX_CACHE_BLOCK_SIZE, BLOCK_M, max(K, N)]
//
int num_threads = at::get_num_threads();
int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
@@ -1277,7 +1247,7 @@ at::Tensor shared_expert_cpu(
buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
}
if (use_fp8_w8a16) {
buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2;
buffer_size_nbytes += M * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_M * std::max(K, N) * 2;
}
auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));