CPU: map changes from developing branch in sgl-kernel (#6833)
Co-authored-by: mingfeima <mingfei.ma@intel.com>
@@ -10,11 +10,72 @@ namespace {
// 3. computes attention for prefix and extend separately
// 4. TODO: vectorize `pack_vnni` and `pack_vnni2`
//
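(Background, not part of the diff: VNNI packing pairs two adjacent 16-bit elements along the reduction dimension into one 32-bit lane, the operand layout that AVX-512/AMX bf16 and fp16 dot-product instructions expect; that is why the packed keys and values below gain a trailing [..., 2] dimension.)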
template <typename index_t>
inline index_t get_index(const index_t* ind, int i) {
  return (ind == nullptr) ? (index_t)i : ind[i];
}
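get_index folds the two addressing modes into one helper: a null ind reads rows contiguously, a non-null ind gathers them through an index table. A minimal illustration with hypothetical values:

// get_index<int32_t>(nullptr, 5) == 5                 // contiguous
// int32_t ind[] = {7, 3, 9}; get_index(ind, 1) == 3   // gathered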

#if defined(CPU_CAPABILITY_AVX512)
// key: from [N, 32] to [32/2, N, 2]
template <typename scalar_t, typename index_t>
inline void pack_vnni_Nx32(
    scalar_t* __restrict__ dst,
    const scalar_t* __restrict__ src,
    const index_t* __restrict__ ind,
    int N,
    int ld_src,
    int ld_dst) {
  __m512i vinputs[16];

  int n = 0;
  for (; n < N; ++n) {
    index_t index = get_index(ind, n);
    vinputs[n] = _mm512_loadu_si512(src + index * ld_src);
  }
  // pad the remaining rows with zeros to avoid uninitialized vectors
  for (; n < 16; ++n) {
    vinputs[n] = _mm512_set1_epi32(0);
  }

  // pack key
  transpose_16x16_32bit(vinputs);

  const __mmask16 vmask = (1 << N) - 1;
  for (int k = 0; k < 16; ++k) {
    _mm512_mask_storeu_epi32(dst + k * ld_dst * 2, vmask, vinputs[k]);
  }
}
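To make the layout transform concrete, here is a scalar reference sketch of what the intrinsics above compute; pack_vnni_Nx32_ref is illustrative only (not part of the commit) and assumes a 16-bit scalar_t, so one 512-bit row holds 32 elements:

template <typename scalar_t, typename index_t>
inline void pack_vnni_Nx32_ref(
    scalar_t* dst, const scalar_t* src, const index_t* ind, int N, int ld_src, int ld_dst) {
  for (int n = 0; n < N; ++n) {
    index_t index = get_index(ind, n);
    for (int k = 0; k < 16; ++k) {
      // dst viewed as [32/2][N][2]: element pair k of source row n lands at [k][n][0..1]
      dst[k * ld_dst * 2 + n * 2 + 0] = src[index * ld_src + 2 * k + 0];
      dst[k * ld_dst * 2 + n * 2 + 1] = src[index * ld_src + 2 * k + 1];
    }
  }
}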

// value: from [K, 32] to [K/2, 32, 2]
template <typename scalar_t, typename index_t>
inline void pack_vnni_Kx32(
    scalar_t* __restrict__ dst,
    const scalar_t* __restrict__ src,
    const index_t* __restrict__ ind,
    int K,
    int ld_src,
    int ld_dst) {
  __m512i vinputs[2];

  int k = 0;
  for (; k < K; ++k) {
    index_t index = get_index(ind, k);
    vinputs[k] = _mm512_loadu_si512(src + index * ld_src);
  }
  // pad the remaining row with zeros to avoid uninitialized vectors
  for (; k < 2; ++k) {
    vinputs[k] = _mm512_set1_epi32(0);
  }

  // pack value
  __m512i d0, d1;
  std::tie(d0, d1) = transpose_2x32_16bit(vinputs[0], vinputs[1]);
  _mm512_storeu_si512(dst + 0 * ld_dst * 2, d0);
  _mm512_storeu_si512(dst + 0 * ld_dst * 2 + 32, d1);
}
#endif
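The value path is the same idea with the pairing taken along K instead of N: rows k and k+1 are interleaved element-wise into [K/2, 32, 2]. A scalar sketch under the same caveats (pack_vnni_Kx32_ref is illustrative, not part of the commit):

template <typename scalar_t, typename index_t>
inline void pack_vnni_Kx32_ref(
    scalar_t* dst, const scalar_t* src, const index_t* ind, int K, int ld_src, int ld_dst) {
  index_t i0 = get_index(ind, 0);
  for (int n = 0; n < 32; ++n) {
    dst[n * 2 + 0] = src[i0 * ld_src + n];
    // K == 1 mirrors the zero padding done by the vector path
    dst[n * 2 + 1] = (K > 1) ? src[get_index(ind, 1) * ld_src + n] : static_cast<scalar_t>(0);
  }
}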

// convert to vnni format
// from [N, K/2, 2] to [K/2, N, 2] for bfloat16 and float16
template <typename scalar_t, typename index_t>
@@ -26,6 +87,25 @@ void pack_vnni(
    int K,
    int ld_src,
    int ld_dst) {
#if defined(CPU_CAPABILITY_AVX512)
  const int NB = div_up(N, 16);
  const int KB = K / 32;  // no remainder
  const bool is_indexed = ind != nullptr;

  for (int nb = 0; nb < NB; ++nb) {
    for (int kb = 0; kb < KB; ++kb) {
      // each block handles 16 x 512 bits
      int nb_size = std::min(N - nb * 16, 16);
      pack_vnni_Nx32<scalar_t, index_t>(
          /* dst    */ dst + ((kb * 32) >> 1) * ld_dst * 2 + nb * 16 * 2,
          /* src    */ src + kb * 32 + (is_indexed ? 0 : nb * 16 * ld_src),
          /* ind    */ is_indexed ? ind + nb * 16 : nullptr,
          /* N      */ nb_size,
          /* ld_src */ ld_src,
          /* ld_dst */ ld_dst);
    }
  }
#else
  for (int n = 0; n < N; ++n) {
    index_t index = get_index(ind, n);
    for (int k = 0; k < K / 2; ++k) {
@@ -34,6 +114,7 @@ void pack_vnni(
      }
    }
  }
#endif
}
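A worked reading of the blocking arithmetic; the numbers are illustrative, not from the commit:

// N = 40, K = 64:
//   NB = div_up(40, 16) = 3, KB = 64 / 32 = 2
//   last row block (nb = 2): nb_size = min(40 - 32, 16) = 8, so the masked
//   stores in pack_vnni_Nx32 write only 8 column pairs
//   block (nb = 2, kb = 1): dst offset = ((1 * 32) >> 1) * ld_dst * 2 + 2 * 16 * 2
//     = 16 * ld_dst * 2 + 64, i.e. packed row 16 of [K/2, N, 2], column pair 32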

// convert to vnni format
@@ -47,6 +128,25 @@ void pack_vnni2(
    int N,
    int ld_src,
    int ld_dst) {
#if defined(CPU_CAPABILITY_AVX512)
  const int KB = div_up(K, 2);
  const int NB = N / 32;  // no remainder
  const bool is_indexed = ind != nullptr;

  for (int kb = 0; kb < KB; ++kb) {
    for (int nb = 0; nb < NB; ++nb) {
      // each block handles 2 x 512 bits
      int kb_size = std::min(K - kb * 2, 2);
      pack_vnni_Kx32<scalar_t, index_t>(
          /* dst    */ dst + ((kb * 2) >> 1) * ld_dst * 2 + nb * 32 * 2,
          /* src    */ src + (is_indexed ? 0 : kb * 2 * ld_src) + nb * 32,
          /* ind    */ is_indexed ? ind + kb * 2 : nullptr,
          /* K      */ kb_size,
          /* ld_src */ ld_src,
          /* ld_dst */ ld_dst);
    }
  }
#else
  int k = 0;
  for (; k < (K >> 1) * 2; k += 2) {
    index_t index0 = get_index(ind, k + 0);
@@ -64,21 +164,17 @@ void pack_vnni2(
    }
    k += 2;
  }
  // TODO: check whether we can skip this!
  // const int padded_K = div_up(K, TILE_K) * TILE_K;
  // for (; k < padded_K; ++k) {
  //   for (int n = 0; n < N; ++n) {
  //     dst[k * ld_dst + n] = static_cast<scalar_t>(0);
  //   }
  // }
#endif
}
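One more illustrative case (not from the commit) for the odd-K handling: with K = 5 and N = 64, KB = div_up(5, 2) = 3 and NB = 64 / 32 = 2; the last block has kb_size = min(5 - 4, 2) = 1, so pack_vnni_Kx32 zero-pads the missing second row before interleaving.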

template <typename scalar_t>
inline void fill_stub(scalar_t* __restrict__ out, float val, int size) {
  using Vec = at::vec::Vectorized<scalar_t>;
  constexpr int kVecSize = Vec::size();
  const Vec data_vec = Vec(static_cast<scalar_t>(val));
  int d = 0;
#pragma GCC unroll 4
  for (; d <= size - kVecSize; d += kVecSize) {
    data_vec.store(out + d);
  }
  if (size - d > 0) {
@@ -110,9 +206,11 @@ template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  const fVec s_fvec = fVec(s);
  int d = 0;
#pragma GCC unroll 4
  for (; d <= size - kVecSize; d += kVecSize) {
    fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
    fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
    bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);