Optimize prefill performance on cpu backend (#8750)
This commit is contained in:
@@ -579,36 +579,31 @@ void fused_experts_kernel_impl(
|
||||
const int64_t stride_e = 2 * N * K;
|
||||
const int64_t stride_n = K;
|
||||
|
||||
int64_t avg_M = std::max(int64_t(1), M * topk / E);
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(avg_M);
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
int tid = get_thread_num();
|
||||
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
// nb_upper from top half and nb_lower from bottom half
|
||||
int64_t nb_upper = nb, nb_lower = nb + NB;
|
||||
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
|
||||
const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
|
||||
const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
|
||||
const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
|
||||
copy_stub(A + m * K, input + index * K, K);
|
||||
@@ -659,9 +654,9 @@ void fused_experts_kernel_impl(
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (is_brgemm_used) {
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
@@ -676,24 +671,16 @@ void fused_experts_kernel_impl(
|
||||
const int64_t stride_oc = IC;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
int tid = get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copy A to tmp buffer again
|
||||
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
|
||||
@@ -736,9 +723,9 @@ void fused_experts_kernel_impl(
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (is_brgemm_used) {
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
@@ -776,36 +763,27 @@ void shared_expert_kernel_impl(
|
||||
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
|
||||
const int64_t stride_n = K;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
int tid = get_thread_num();
|
||||
float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
// nb_upper from top half and nb_lower from bottom half
|
||||
int64_t nb_upper = nb, nb_lower = nb + NB;
|
||||
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
|
||||
// int64_t mb_start = mb * BLOCK_M;
|
||||
// int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
|
||||
// A shape [m_size, K]
|
||||
const scalar_t* A = input + mb * BLOCK_M * K;
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
|
||||
const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
const scalar_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
|
||||
const scalar_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;
|
||||
|
||||
if (use_brgemm) {
|
||||
// 1.b gemm: C0 = A @ B0
|
||||
@@ -850,9 +828,9 @@ void shared_expert_kernel_impl(
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (is_brgemm_used) {
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
@@ -866,24 +844,16 @@ void shared_expert_kernel_impl(
|
||||
const int64_t stride_oc = IC;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
int tid = get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
// A shape [m_size, IC]
|
||||
const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N;
|
||||
|
||||
@@ -922,9 +892,9 @@ void shared_expert_kernel_impl(
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (is_brgemm_used) {
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
@@ -1086,7 +1056,7 @@ at::Tensor fused_experts_cpu(
|
||||
//
|
||||
// for fp8 w8a16:
|
||||
// 7. intermediate_cache0 : [M * topk, 2N]
|
||||
// 8. B_tmp : [T, BLOCK_N, std::max(K, N)]
|
||||
// 8. B_tmp : [T, MAX_CACHE_BLOCK_SIZE, BLOCK_N, std::max(K, N)]
|
||||
//
|
||||
int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
|
||||
num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
|
||||
@@ -1096,7 +1066,7 @@ at::Tensor fused_experts_cpu(
|
||||
buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2;
|
||||
buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N) * 2;
|
||||
}
|
||||
|
||||
auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
|
||||
@@ -1268,7 +1238,7 @@ at::Tensor shared_expert_cpu(
|
||||
//
|
||||
// for fp8 w8a16:
|
||||
// 5. intermediate_cache0 : [M, 2N]
|
||||
// 6. B_tmp: [T, BLOCK_M, max(K, N)]
|
||||
// 6. B_tmp: [T, MAX_CACHE_BLOCK_SIZE, BLOCK_M, max(K, N)]
|
||||
//
|
||||
int num_threads = at::get_num_threads();
|
||||
int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
|
||||
@@ -1277,7 +1247,7 @@ at::Tensor shared_expert_cpu(
|
||||
buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2;
|
||||
buffer_size_nbytes += M * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_M * std::max(K, N) * 2;
|
||||
}
|
||||
|
||||
auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
|
||||
|
||||
Reference in New Issue
Block a user