Optimize prefill performance on cpu backend (#8750)

Ma Mingfei
2025-08-29 08:21:55 +08:00
committed by GitHub
parent 9f81d741a2
commit 5ad296bda1
9 changed files with 680 additions and 273 deletions


@@ -109,6 +109,120 @@ inline void add_mul_stub(
}
}
template <typename scalar_t, int BLOCK_N>
inline void silu_and_mul(
scalar_t* __restrict__ C,
const int32_t* __restrict__ C0, // x: x0, x1
const int32_t* __restrict__ C1, // y: y0, y1
const float* __restrict__ As,
const float* __restrict__ Bs0,
const float* __restrict__ Bs1,
const int32_t* __restrict__ Bcomp0,
const int32_t* __restrict__ Bcomp1,
int64_t m_size,
int64_t N) {
#if defined(CPU_CAPABILITY_AVX512)
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
__m512 vc0[COLS];
__m512 vc1[COLS];
__m512i vcomp0[COLS];
__m512i vcomp1[COLS];
__m512 vas;
__m512 vbs0[COLS];
__m512 vbs1[COLS];
auto load_scale_and_comp = [&](auto col) {
vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
};
Unroll<COLS>{}(load_scale_and_comp);
auto scalec = [&](auto col, int64_t m) {
// update As
vas = _mm512_set1_ps(As[m]);
// C = As * (C - Bcomp) * Bs
__m512i vc32_0 = _mm512_loadu_si512(C0 + m * BLOCK_N + col * 16);
__m512i vc32_1 = _mm512_loadu_si512(C1 + m * BLOCK_N + col * 16);
vc0[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_0, vcomp0[col]));
vc1[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_1, vcomp1[col]));
vc0[col] = _mm512_mul_ps(_mm512_mul_ps(vc0[col], vas), vbs0[col]);
vc1[col] = _mm512_mul_ps(_mm512_mul_ps(vc1[col], vas), vbs1[col]);
};
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
const fVec one = fVec(1.f);
auto silu_and_mul = [&](auto col) {
fVec x = fVec(vc0[col]);
fVec y = fVec(vc1[col]);
x = x / (one + x.neg().exp_u20());
vc0[col] = x * y;
};
auto storec = [&](auto col, int64_t m) {
if constexpr (col % 2 == 0) {
fVec x0 = fVec(vc0[col + 0]);
fVec x1 = fVec(vc0[col + 1]);
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(C + m * N + col * 16);
}
};
for (int64_t m = 0; m < m_size; ++m) {
Unroll<COLS>{}(scalec, m);
Unroll<COLS>{}(silu_and_mul);
Unroll<COLS>{}(storec, m);
}
#else
TORCH_CHECK(false, "silu_and_mul: scalar path not implemented!");
#endif
}
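// For reference, a plain scalar equivalent of the AVX512 loop above. Illustration
// only: `silu_and_mul_ref` is not part of this file (assumes <cmath>/<cstdint>).
// Each int32 accumulator is dequantized as As[m] * (acc - Bcomp[n]) * Bs[n], then the
// two halves are combined as silu(x) * y.
template <typename scalar_t>
inline void silu_and_mul_ref(
    scalar_t* C, const int32_t* C0, const int32_t* C1,
    const float* As, const float* Bs0, const float* Bs1,
    const int32_t* Bcomp0, const int32_t* Bcomp1,
    int64_t m_size, int64_t block_n, int64_t N) {
  for (int64_t m = 0; m < m_size; ++m) {
    for (int64_t n = 0; n < block_n; ++n) {
      float x = As[m] * float(C0[m * block_n + n] - Bcomp0[n]) * Bs0[n];
      float y = As[m] * float(C1[m * block_n + n] - Bcomp1[n]) * Bs1[n];
      C[m * N + n] = static_cast<scalar_t>(x / (1.f + std::exp(-x)) * y);
    }
  }
}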
template <int BLOCK_N>
inline void scale_C(
float* __restrict__ C,
const int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
const int32_t* __restrict__ Bcomp,
int64_t m_size) {
#if defined(CPU_CAPABILITY_AVX512)
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
__m512 vc[COLS];
__m512i vcomp[COLS];
__m512 vas;
__m512 vbs[COLS];
auto load_scale_and_comp = [&](auto col) {
vcomp[col] = _mm512_loadu_si512(Bcomp + col * 16);
vbs[col] = _mm512_loadu_ps(Bs + col * 16);
};
Unroll<COLS>{}(load_scale_and_comp);
auto scalec = [&](auto col, int64_t m) {
// update As
vas = _mm512_set1_ps(As[m]);
// C = As * (C - Bcomp) * Bs
__m512i vc32 = _mm512_loadu_si512(Ctmp + m * BLOCK_N + col * 16);
vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp[col]));
vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vas), vbs[col]);
_mm512_storeu_ps(C + m * BLOCK_N + col * 16, vc[col]);
};
for (int64_t m = 0; m < m_size; ++m) {
Unroll<COLS>{}(scalec, m);
}
#else
TORCH_CHECK(false, "scale_C: scalar path not implemented!");
#endif
}
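// Likewise, a scalar sketch of what scale_C computes (illustration only, not part of
// this file): C[m][n] = As[m] * (Ctmp[m][n] - Bcomp[n]) * Bs[n], with BLOCK_N as the
// row stride of both buffers.
template <int BLOCK_N>
inline void scale_C_ref(
    float* C, const int32_t* Ctmp, const float* As, const float* Bs,
    const int32_t* Bcomp, int64_t m_size) {
  for (int64_t m = 0; m < m_size; ++m) {
    for (int64_t n = 0; n < BLOCK_N; ++n) {
      C[m * BLOCK_N + n] = As[m] * float(Ctmp[m * BLOCK_N + n] - Bcomp[n]) * Bs[n];
    }
  }
}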
/// gemm for w13
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni {
@@ -515,28 +629,31 @@ void fused_experts_int8_kernel_impl(
const int64_t stride_e = 2 * N * packed_K;
const int64_t stride_n = packed_K;
int64_t avg_M = std::max(int64_t(1), M * topk / E);
const bool use_brgemm = can_use_brgemm<int8_t>(avg_M);
// here we only parallelize over half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers
int tid = at::get_thread_num();
int tid = get_thread_num();
uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
alignas(64) float As[BLOCK_M];
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
// nb_upper from top half and nb_lower from bottom half
int64_t nb_upper = nb, nb_lower = nb + NB;
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
// B shape [K, n_size] in vnni format
int32_t expert_id = expert_ids[mb];
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N;
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N;
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb_upper * BLOCK_N;
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb_lower * BLOCK_N;
// 1.a load A
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
@@ -548,22 +665,62 @@ void fused_experts_int8_kernel_impl(
As[m] = As_tmp[index];
}
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
const int64_t offset = offsets[mb];
tinygemm_kernel(
/* A */ A,
/* B0 */ B0,
/* B1 */ B1,
/* C */ ic1 + offset * N + nb * BLOCK_N,
/* As */ As,
/* Bs0 */ Bs0,
/* Bs1 */ Bs1,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ N);
if (use_brgemm) {
// 1.b gemm: C0 = A @ B0
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B0,
/* C */ C0);
// 1.c gemm: C1 = A @ B1
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B1,
/* C */ C1);
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
// 1.d silu and mul
const int64_t offset = offsets[mb];
silu_and_mul<scalar_t, BLOCK_N>(
ic1 + offset * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
} else {
// fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
const int64_t offset = offsets[mb];
tinygemm_kernel(
/* A */ A,
/* B0 */ B0,
/* B1 */ B1,
/* C */ ic1 + offset * N + nb * BLOCK_N,
/* As */ As,
/* Bs0 */ Bs0,
/* Bs1 */ Bs1,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ N);
}
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
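// parallel_2d, loop_2d and get_thread_num are helpers introduced elsewhere in this
// commit and are not shown in this hunk. Based only on how they are called above, a
// hypothetical sketch of their shape (not the actual implementation): parallel_2d
// hands each thread a 2D tile range instead of a flat [begin, end) index range, and
// loop_2d walks that range, passing the packed size of the current weight tile as
// nb_offset (unused by the lambdas shown here).
template <typename F>
inline void parallel_2d_sketch(int64_t MB, int64_t NB, const F& f) {
  at::parallel_for(0, MB, 1, [&](int64_t mb0, int64_t mb1) {
    // simplest possible split: a block of rows per thread, all columns
    f(mb0, mb1, /*nb0=*/0, /*nb1=*/NB);
  });
}
template <typename packed_t, typename F>
inline void loop_2d_sketch(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1,
                           int64_t tile_size, const F& f) {
  for (int64_t mb = mb0; mb < mb1; ++mb) {
    for (int64_t nb = nb0; nb < nb1; ++nb) {
      f(mb, nb, /*nb_offset=*/(nb - nb0) * tile_size);
    }
  }
}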
@@ -584,16 +741,13 @@ void fused_experts_int8_kernel_impl(
const int64_t stride_oc = packed_N;
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers
int tid = at::get_thread_num();
// we won't be using C1 for gemm2
int tid = get_thread_num();
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t m_size = offsets[mb + 1] - offsets[mb];
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
@@ -609,18 +763,36 @@ void fused_experts_int8_kernel_impl(
const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;
// 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* As */ As,
/* Bs */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N);
if (use_brgemm) {
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B,
/* C */ C32);
// apply scales
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
} else {
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* As */ As,
/* Bs */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N);
}
// 2.b copy from C to ic2 in original order
// and also mul topk_weights in float32
@@ -629,6 +801,10 @@ void fused_experts_int8_kernel_impl(
float weight = topk_weights[index];
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
}
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
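// copy_mul_stub is defined earlier in this file; per the comments above it copies a
// row of the per-thread scratch C into ic2 while applying the router weight in
// float32. A scalar sketch of that behaviour (illustration only, assuming ic2 is a
// float32 intermediate buffer):
inline void copy_mul_stub_ref(
    float* __restrict__ dst, const float* __restrict__ src, float weight, int64_t size) {
  for (int64_t n = 0; n < size; ++n) {
    dst[n] = src[n] * weight;  // ic2 row = C row * topk_weight
  }
}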
@@ -708,15 +884,19 @@ void shared_expert_int8_kernel_impl(
const int64_t packed_N = get_row_size<int8_t>(N);
const int64_t stride_n = packed_K;
// here we only parallelize over half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
const bool use_brgemm = can_use_brgemm<int8_t>(M);
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
// here we only parallelize over half of 2N to fuse silu_and_mul with gemm
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers
int tid = get_thread_num();
int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
// nb_upper from top half and nb_lower from bottom half
int64_t nb_upper = nb, nb_lower = nb + NB;
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
// A shape [m_size, K]
@@ -724,26 +904,65 @@ void shared_expert_int8_kernel_impl(
const float* As = As_tmp + mb * BLOCK_M;
// B shape [K, n_size] in vnni format
const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
const int8_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
const int8_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;
const float* __restrict__ Bs0 = w1s + nb_upper * BLOCK_N;
const float* __restrict__ Bs1 = w1s + nb_lower * BLOCK_N;
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
tinygemm_kernel(
/* A */ A,
/* B0 */ B0,
/* B1 */ B1,
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
/* As */ As,
/* Bs0 */ Bs0,
/* Bs1 */ Bs1,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ N);
if (use_brgemm) {
// 1.b gemm: C0 = A @ B0
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B0,
/* C */ C0);
// 1.c gemm: C1 = A @ B1
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B1,
/* C */ C1);
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
// 1.d silu and mul
silu_and_mul<scalar_t, BLOCK_N>(
ic1 + mb * BLOCK_M * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
} else {
// fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
tinygemm_kernel(
/* A */ A,
/* B0 */ B0,
/* B1 */ B1,
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
/* As */ As,
/* Bs0 */ Bs0,
/* Bs1 */ Bs1,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ N);
}
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
@@ -763,16 +982,13 @@ void shared_expert_int8_kernel_impl(
const int64_t stride_oc = packed_N;
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers
int tid = at::get_thread_num();
// we won't be using C1 for gemm2
int tid = get_thread_num();
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
@@ -784,19 +1000,37 @@ void shared_expert_int8_kernel_impl(
const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
const float* __restrict__ Bs = w2s + nb * BLOCK_N;
// 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* As */ As,
/* Bs */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N);
if (use_brgemm) {
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B,
/* C */ C32);
// apply scales
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
} else {
// 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* As */ As,
/* Bs */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N);
}
// 2.b copy from C to output and add fused_experts_out
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
@@ -804,6 +1038,10 @@ void shared_expert_int8_kernel_impl(
for (int64_t m = 0; m < m_size; ++m) {
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
}
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
}
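// can_use_brgemm<int8_t>() is defined elsewhere in this file and gates the new brgemm
// path: the routed-expert kernel feeds it avg_M = max(1, M * topk / E), the expected
// row count per expert, while the shared expert passes the full M. Presumably the
// predicate enables brgemm only when enough rows are available to amortize its setup
// cost (the prefill case this commit targets). A hypothetical stand-in with a made-up
// threshold, purely to show the dispatch shape:
template <typename packed_t>
inline bool can_use_brgemm_sketch(int64_t M) {
  return M >= 16;  // illustrative threshold; the real check may also depend on dtype/ISA
}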