[CPU] Optimize FP16 decode_attention_cpu (#10652)
This commit is contained in:
@@ -540,7 +540,10 @@ class ParallelLMHead(VocabParallelEmbedding):
|
|||||||
|
|
||||||
# We only support pack LMHead if it's not quantized.
|
# We only support pack LMHead if it's not quantized.
|
||||||
if _is_cpu and _is_cpu_amx_available:
|
if _is_cpu and _is_cpu_amx_available:
|
||||||
if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
|
if hasattr(self, "weight") and self.weight.dtype in [
|
||||||
|
torch.bfloat16,
|
||||||
|
torch.float16,
|
||||||
|
]:
|
||||||
self.quant_method = PackWeightMethod(weight_names=["weight"])
|
self.quant_method = PackWeightMethod(weight_names=["weight"])
|
||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
|
|||||||
@@ -308,6 +308,93 @@ struct tinygemm_kernel_nt<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
|
|||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(CPU_CAPABILITY_AVX512)
|
||||||
|
template <typename index_t, int BLOCK_M, int BLOCK_N>
|
||||||
|
struct tinygemm_kernel_nt<at::Half, index_t, BLOCK_M, BLOCK_N> {
|
||||||
|
static inline void apply(
|
||||||
|
const at::Half* __restrict__ A,
|
||||||
|
const at::Half* __restrict__ B,
|
||||||
|
float* __restrict__ C,
|
||||||
|
const index_t* __restrict__ indices,
|
||||||
|
float scale,
|
||||||
|
int64_t lda,
|
||||||
|
int64_t ldb,
|
||||||
|
int64_t ldc,
|
||||||
|
int64_t K,
|
||||||
|
int64_t max_tokens) {
|
||||||
|
constexpr int ROWS = BLOCK_M;
|
||||||
|
constexpr int COLS = BLOCK_N;
|
||||||
|
|
||||||
|
__m512 va0, va1;
|
||||||
|
__m512 vb0[COLS], vb1[COLS];
|
||||||
|
__m512 vc[ROWS * COLS];
|
||||||
|
__m512 vscale = _mm512_set1_ps(scale);
|
||||||
|
|
||||||
|
auto loadc = [&](auto i) { vc[i] = _mm512_setzero_ps(); };
|
||||||
|
Unroll<ROWS * COLS>{}(loadc);
|
||||||
|
|
||||||
|
auto compute = [&](auto i, int64_t k) {
|
||||||
|
constexpr int row = i / COLS;
|
||||||
|
constexpr int col = i % COLS;
|
||||||
|
|
||||||
|
if constexpr (col == 0) {
|
||||||
|
__m512i a16 = _mm512_loadu_si512((__m512i const*)(A + row * lda + k));
|
||||||
|
va0 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
|
||||||
|
va1 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (row == 0) {
|
||||||
|
int64_t b_idx = indices[col];
|
||||||
|
TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
|
||||||
|
__m512i b16 = _mm512_loadu_si512((__m512i const*)(B + b_idx * ldb + k));
|
||||||
|
vb0[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
|
||||||
|
vb1[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
vc[i] = _mm512_fmadd_ps(va0, vb0[col], _mm512_fmadd_ps(va1, vb1[col], vc[i]));
|
||||||
|
};
|
||||||
|
|
||||||
|
auto compute2 = [&](auto i, int64_t k, __mmask32 mask) {
|
||||||
|
constexpr int row = i / COLS;
|
||||||
|
constexpr int col = i % COLS;
|
||||||
|
|
||||||
|
if constexpr (col == 0) {
|
||||||
|
__m512i a16 = _mm512_maskz_loadu_epi16(mask, (const void*)(A + row * lda + k));
|
||||||
|
va0 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
|
||||||
|
va1 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
if constexpr (row == 0) {
|
||||||
|
int64_t b_idx = indices[col];
|
||||||
|
TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
|
||||||
|
__m512i b16 = _mm512_maskz_loadu_epi16(mask, (const void*)(B + b_idx * ldb + k));
|
||||||
|
vb0[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
|
||||||
|
vb1[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
|
||||||
|
}
|
||||||
|
|
||||||
|
vc[i] = _mm512_fmadd_ps(va0, vb0[col], _mm512_fmadd_ps(va1, vb1[col], vc[i]));
|
||||||
|
};
|
||||||
|
|
||||||
|
int64_t k = 0;
|
||||||
|
for (; k <= K - 32; k += 32) {
|
||||||
|
Unroll<ROWS * COLS>{}(compute, k);
|
||||||
|
}
|
||||||
|
int64_t count = K - k;
|
||||||
|
if (count > 0) {
|
||||||
|
__mmask32 mask = (1ULL << count) - 1;
|
||||||
|
Unroll<ROWS * COLS>{}(compute2, k, mask);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto storec = [&](auto i) {
|
||||||
|
constexpr int row = i / COLS;
|
||||||
|
constexpr int col = i % COLS;
|
||||||
|
C[row * ldc + col] = _mm512_reduce_add_ps(_mm512_mul_ps(vc[i], vscale));
|
||||||
|
};
|
||||||
|
Unroll<ROWS * COLS>{}(storec);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
#define LAUNCH_TINYGEMM_KERNEL_NT(MB_SIZE, NB_SIZE) \
|
#define LAUNCH_TINYGEMM_KERNEL_NT(MB_SIZE, NB_SIZE) \
|
||||||
tinygemm_kernel_nt<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
|
tinygemm_kernel_nt<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
|
||||||
A + mb_start * lda, B, C + mb_start * ldc + nb_start, indices + nb_start, scale, lda, ldb, ldc, K, max_tokens);
|
A + mb_start * lda, B, C + mb_start * ldc + nb_start, indices + nb_start, scale, lda, ldb, ldc, K, max_tokens);
|
||||||
@@ -443,6 +530,87 @@ struct tinygemm_kernel_nn<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
|
|||||||
};
|
};
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#if defined(CPU_CAPABILITY_AVX512)
|
||||||
|
template <typename index_t, int BLOCK_M, int BLOCK_N>
|
||||||
|
struct tinygemm_kernel_nn<at::Half, index_t, BLOCK_M, BLOCK_N> {
|
||||||
|
static inline void apply(
|
||||||
|
const float* __restrict__ A,
|
||||||
|
const at::Half* __restrict__ B,
|
||||||
|
float* __restrict__ C,
|
||||||
|
const index_t* __restrict__ indices,
|
||||||
|
const float* __restrict__ scale,
|
||||||
|
int64_t lda,
|
||||||
|
int64_t ldb,
|
||||||
|
int64_t ldc,
|
||||||
|
int64_t K,
|
||||||
|
int64_t max_tokens) {
|
||||||
|
constexpr int ROWS = BLOCK_M;
|
||||||
|
constexpr int COLS = BLOCK_N / 16;
|
||||||
|
|
||||||
|
__m512 va;
|
||||||
|
__m512 vb[COLS];
|
||||||
|
__m512 vc[ROWS * COLS];
|
||||||
|
__m512 vscale;
|
||||||
|
|
||||||
|
auto loadc = [&](auto i) {
|
||||||
|
constexpr int row = i / COLS;
|
||||||
|
constexpr int col = i % COLS;
|
||||||
|
#pragma GCC diagnostic push
|
||||||
|
#pragma GCC diagnostic ignored "-Warray-bounds"
|
||||||
|
if constexpr (col == 0) {
|
||||||
|
vscale = _mm512_set1_ps(scale[row]);
|
||||||
|
}
|
||||||
|
#pragma GCC diagnostic pop
|
||||||
|
vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
|
||||||
|
vc[i] = _mm512_mul_ps(vc[i], vscale);
|
||||||
|
};
|
||||||
|
Unroll<ROWS * COLS>{}(loadc);
|
||||||
|
|
||||||
|
auto compute = [&](auto i, int64_t k) {
|
||||||
|
constexpr int row = i / COLS;
|
||||||
|
constexpr int col = i % COLS;
|
||||||
|
|
||||||
|
if constexpr (col == 0) {
|
||||||
|
va = _mm512_set1_ps(A[row * lda + k]);
|
||||||
|
}
|
||||||
|
if constexpr (row == 0) {
|
||||||
|
if (k + 1 < K) {
|
||||||
|
int64_t b_idx_prefetch = indices[k + 1];
|
||||||
|
_mm_prefetch(B + b_idx_prefetch * ldb + col * 16, _MM_HINT_T0);
|
||||||
|
}
|
||||||
|
int64_t b_idx = indices[k];
|
||||||
|
TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
|
||||||
|
|
||||||
|
// for COLS = 2, 4, 6, 8 use 512 bit load
|
||||||
|
// for COLS = 1, 3, 5, 7 use 256 bit load
|
||||||
|
if constexpr (COLS % 2 == 0) {
|
||||||
|
if constexpr (col % 2 == 0) {
|
||||||
|
__m512i b16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(B + b_idx * ldb + col * 16));
|
||||||
|
vb[col + 0] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
|
||||||
|
vb[col + 1] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
__m256i b16 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(B + b_idx * ldb + col * 16));
|
||||||
|
vb[col] = CVT_FP16_TO_FP32(b16);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
vc[i] = _mm512_fmadd_ps(va, vb[col], vc[i]);
|
||||||
|
};
|
||||||
|
|
||||||
|
for (int64_t k = 0; k < K; ++k) {
|
||||||
|
Unroll<ROWS * COLS>{}(compute, k);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto storec = [&](auto i) {
|
||||||
|
constexpr int row = i / COLS;
|
||||||
|
constexpr int col = i % COLS;
|
||||||
|
_mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
|
||||||
|
};
|
||||||
|
Unroll<ROWS * COLS>{}(storec);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
#endif
|
||||||
|
|
||||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||||
tinygemm_kernel_nn<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
|
tinygemm_kernel_nn<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
|
||||||
A + mb_start * lda, \
|
A + mb_start * lda, \
|
||||||
@@ -512,9 +680,10 @@ void index_gemm_kernel_nt(
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// pattern: 1-6-24
|
// default pattern: 1-6-24
|
||||||
|
// FP16 pattern: 2-8-16
|
||||||
constexpr int64_t BLOCK_M = 4;
|
constexpr int64_t BLOCK_M = 4;
|
||||||
constexpr int64_t BLOCK_N = 6;
|
constexpr int64_t BLOCK_N = std::is_same_v<scalar_t, at::Half> ? 4 : 6;
|
||||||
const int64_t MB = div_up(M, BLOCK_M);
|
const int64_t MB = div_up(M, BLOCK_M);
|
||||||
const int64_t NB = div_up(N, BLOCK_N);
|
const int64_t NB = div_up(N, BLOCK_N);
|
||||||
|
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorize
|
|||||||
|
|
||||||
#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
|
#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
|
||||||
|
|
||||||
#define CVT_FP16_TO_FP32(a) _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
|
#define CVT_FP16_TO_FP32(a) _mm512_cvtph_ps(a)
|
||||||
|
|
||||||
// this doesn't handle NaN.
|
// this doesn't handle NaN.
|
||||||
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
|
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
|
||||||
|
|||||||
@@ -59,8 +59,7 @@ class TestDecodeAttention(CustomTestCase):
|
|||||||
|
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device):
|
def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, dtype, device):
|
||||||
dtype = torch.bfloat16
|
|
||||||
# This represents the number of tokens already in the sequence
|
# This represents the number of tokens already in the sequence
|
||||||
seq_len = 1024
|
seq_len = 1024
|
||||||
total_tokens = B * seq_len
|
total_tokens = B * seq_len
|
||||||
@@ -158,9 +157,10 @@ class TestDecodeAttention(CustomTestCase):
|
|||||||
]
|
]
|
||||||
|
|
||||||
for B, H_Q, H_KV, D, D_V in configs:
|
for B, H_Q, H_KV, D, D_V in configs:
|
||||||
self._test_grouped_decode_attention_once(
|
for dtype in [torch.bfloat16, torch.float16]:
|
||||||
B, H_Q, H_KV, D, D_V, device=device
|
self._test_grouped_decode_attention_once(
|
||||||
)
|
B, H_Q, H_KV, D, D_V, dtype=dtype, device=device
|
||||||
|
)
|
||||||
|
|
||||||
def test_grouped_decode_attention(self):
|
def test_grouped_decode_attention(self):
|
||||||
self._test_grouped_decode_attention("cpu")
|
self._test_grouped_decode_attention("cpu")
|
||||||
|
|||||||
Reference in New Issue
Block a user