Optimize prefill performance on cpu backend (#8750)
This commit is contained in:
@@ -4,6 +4,61 @@
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_N>
|
||||
struct scale_C {
|
||||
static inline void apply(
|
||||
scalar_t* __restrict__ C,
|
||||
const int32_t* __restrict__ Ctmp,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias,
|
||||
float As,
|
||||
const float* __restrict__ Bs) {
|
||||
TORCH_CHECK(false, "scale_C: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_N>
|
||||
struct scale_C<at::BFloat16, has_bias, BLOCK_N> {
|
||||
static inline void apply(
|
||||
at::BFloat16* __restrict__ C,
|
||||
const int32_t* __restrict__ Ctmp,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias,
|
||||
float As,
|
||||
const float* __restrict__ Bs) {
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512 vc[COLS];
|
||||
__m512 vd0 = _mm512_set1_ps(As);
|
||||
|
||||
auto compute = [&](auto col) {
|
||||
__m512 vd1 = _mm512_loadu_ps(Bs + col * 16);
|
||||
__m512i vcomp = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
__m512i vc32 = _mm512_loadu_si512(Ctmp + col * 16);
|
||||
vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp));
|
||||
if constexpr (has_bias) {
|
||||
__m512 vbias = _mm512_loadu_ps(bias + col * 16);
|
||||
vc[col] = _mm512_fmadd_ps(_mm512_mul_ps(vc[col], vd0), vd1, vbias);
|
||||
} else {
|
||||
vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vd0), vd1);
|
||||
}
|
||||
};
|
||||
Unroll<COLS>{}(compute);
|
||||
|
||||
auto storec = [&](auto col) {
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc[col + 1], vc[col + 0])));
|
||||
}
|
||||
};
|
||||
Unroll<COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
@@ -169,6 +224,17 @@ void tinygemm_kernel(
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
if (brg) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
|
||||
|
||||
// apply compensation and scale
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
scale_C<scalar_t, has_bias, BLOCK_N>::apply(C + m * ldc, Ctmp + m * BLOCK_N, Bcomp, bias, As[m], Bs);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
@@ -233,22 +299,17 @@ void int8_scaled_mm_kernel_impl(
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
|
||||
const bool use_brgemm = false;
|
||||
const bool use_brgemm = can_use_brgemm<int8_t>(M);
|
||||
|
||||
// K + 4 after compensation
|
||||
const int64_t packed_row_size = get_row_size<int8_t>(K);
|
||||
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// for brgemm, use int32_t for accumulate
|
||||
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
@@ -269,10 +330,7 @@ void int8_scaled_mm_kernel_impl(
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ N,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
|
||||
Reference in New Issue
Block a user