@@ -308,6 +308,93 @@ struct tinygemm_kernel_nt<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
};
#endif

#if defined(CPU_CAPABILITY_AVX512)
template <typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nt<at::Half, index_t, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const at::Half* __restrict__ A,
      const at::Half* __restrict__ B,
      float* __restrict__ C,
      const index_t* __restrict__ indices,
      float scale,
      int64_t lda,
      int64_t ldb,
      int64_t ldc,
      int64_t K,
      int64_t max_tokens) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N;

    // fp16 is widened to fp32 in pairs of 16 lanes: each 512-bit load of A or B
    // brings 32 fp16 values, converted into two fp32 vectors (hence va0/va1, vb0/vb1).
    __m512 va0, va1;
    __m512 vb0[COLS], vb1[COLS];
    __m512 vc[ROWS * COLS];
    __m512 vscale = _mm512_set1_ps(scale);

    auto loadc = [&](auto i) { vc[i] = _mm512_setzero_ps(); };
    Unroll<ROWS * COLS>{}(loadc);

    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        __m512i a16 = _mm512_loadu_si512((__m512i const*)(A + row * lda + k));
        va0 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
        va1 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
      }

      if constexpr (row == 0) {
        int64_t b_idx = indices[col];
        TORCH_CHECK(b_idx < max_tokens, "token index out of bounds!");
        __m512i b16 = _mm512_loadu_si512((__m512i const*)(B + b_idx * ldb + k));
        vb0[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
        vb1[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
      }

      vc[i] = _mm512_fmadd_ps(va0, vb0[col], _mm512_fmadd_ps(va1, vb1[col], vc[i]));
    };

    // masked variant of `compute` for the K tail (fewer than 32 elements left)
    auto compute2 = [&](auto i, int64_t k, __mmask32 mask) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        __m512i a16 = _mm512_maskz_loadu_epi16(mask, (const void*)(A + row * lda + k));
        va0 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
        va1 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
      }

      if constexpr (row == 0) {
        int64_t b_idx = indices[col];
        TORCH_CHECK(b_idx < max_tokens, "token index out of bounds!");
        __m512i b16 = _mm512_maskz_loadu_epi16(mask, (const void*)(B + b_idx * ldb + k));
        vb0[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
        vb1[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
      }

      vc[i] = _mm512_fmadd_ps(va0, vb0[col], _mm512_fmadd_ps(va1, vb1[col], vc[i]));
    };

    int64_t k = 0;
    for (; k <= K - 32; k += 32) {
      Unroll<ROWS * COLS>{}(compute, k);
    }
    int64_t count = K - k;
    if (count > 0) {
      // keep only the low `count` fp16 lanes; masked-off lanes load as zero.
      // e.g. K = 40: the main loop covers k in [0, 32); count = 8, mask = 0xFF.
      __mmask32 mask = (1ULL << count) - 1;
      Unroll<ROWS * COLS>{}(compute2, k, mask);
    }

    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      C[row * ldc + col] = _mm512_reduce_add_ps(_mm512_mul_ps(vc[i], vscale));
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
#endif
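
Both specializations lean on the CVT_FP16_TO_FP32 helper to widen 16 fp16 lanes into a full __m512 of fp32. A minimal sketch of that macro, assuming it is a thin wrapper over the AVX512F conversion intrinsic (the actual definition lives elsewhere in this file or its headers):

#include <immintrin.h>

// Sketch only: assumes the helper wraps _mm512_cvtph_ps, which converts
// 16 packed fp16 values (passed as __m256i) to 16 packed fp32 (__m512).
#define CVT_FP16_TO_FP32(a) _mm512_cvtph_ps(a)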

#define LAUNCH_TINYGEMM_KERNEL_NT(MB_SIZE, NB_SIZE) \
  tinygemm_kernel_nt<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
      A + mb_start * lda, B, C + mb_start * ldc + nb_start, indices + nb_start, scale, lda, ldb, ldc, K, max_tokens);
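
The `constexpr int row = i / COLS;` lines inside the lambdas compile only because `Unroll` passes the iteration index as a compile-time constant. A minimal sketch of such a helper, assuming a `std::integral_constant`-based recursion (the actual definition is elsewhere in this file or its headers):

#include <type_traits>

// Sketch only: invokes f(integral_constant<int, 0>{}, args...) through
// f(integral_constant<int, N-1>{}, args...), fully unrolled at compile time,
// so the index converts to int inside a constant expression in the lambda.
template <int N>
struct Unroll {
  template <typename Func, typename... Args>
  inline void operator()(const Func& f, Args... args) const {
    Unroll<N - 1>{}(f, args...);
    f(std::integral_constant<int, N - 1>{}, args...);
  }
};

template <>
struct Unroll<1> {
  template <typename Func, typename... Args>
  inline void operator()(const Func& f, Args... args) const {
    f(std::integral_constant<int, 0>{}, args...);
  }
};
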
@@ -443,6 +530,87 @@ struct tinygemm_kernel_nn<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
};
#endif

#if defined(CPU_CAPABILITY_AVX512)
template <typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::Half, index_t, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const float* __restrict__ A,
      const at::Half* __restrict__ B,
      float* __restrict__ C,
      const index_t* __restrict__ indices,
      const float* __restrict__ scale,
      int64_t lda,
      int64_t ldb,
      int64_t ldc,
      int64_t K,
      int64_t max_tokens) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N / 16;

    __m512 va;
    __m512 vb[COLS];
    __m512 vc[ROWS * COLS];
    __m512 vscale;

    auto loadc = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
      if constexpr (col == 0) {
        vscale = _mm512_set1_ps(scale[row]);
      }
#pragma GCC diagnostic pop
      vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
      vc[i] = _mm512_mul_ps(vc[i], vscale);
    };
    Unroll<ROWS * COLS>{}(loadc);

    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        va = _mm512_set1_ps(A[row * lda + k]);
      }
      if constexpr (row == 0) {
        // prefetch the next indexed row of B to hide the gather latency
        if (k + 1 < K) {
          int64_t b_idx_prefetch = indices[k + 1];
          _mm_prefetch(B + b_idx_prefetch * ldb + col * 16, _MM_HINT_T0);
        }
        int64_t b_idx = indices[k];
        TORCH_CHECK(b_idx < max_tokens, "token index out of bounds!");

        // for COLS = 2, 4, 6, 8: one 512-bit load covers two 16-lane columns
        // for COLS = 1, 3, 5, 7: fall back to one 256-bit load per column
        if constexpr (COLS % 2 == 0) {
          if constexpr (col % 2 == 0) {
            __m512i b16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(B + b_idx * ldb + col * 16));
            vb[col + 0] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
            vb[col + 1] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
          }
        } else {
          __m256i b16 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(B + b_idx * ldb + col * 16));
          vb[col] = CVT_FP16_TO_FP32(b16);
        }
      }
      vc[i] = _mm512_fmadd_ps(va, vb[col], vc[i]);
    };

    for (int64_t k = 0; k < K; ++k) {
      Unroll<ROWS * COLS>{}(compute, k);
    }

    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
#endif
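
The even/odd split in `compute` is load-width arithmetic: one 512-bit load moves 32 fp16 values, i.e. exactly two 16-lane output columns, so an even COLS lets each pair of columns share a single load (issued at the even col and skipped at the odd one). An odd COLS cannot pair every column, so it takes one 256-bit load (16 fp16 values, one column) per column instead.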

#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
  tinygemm_kernel_nn<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
      A + mb_start * lda, \

@@ -512,9 +680,10 @@ void index_gemm_kernel_nt(
    return;
  }

  // default pattern: 1-6-24
  // FP16 pattern: 2-8-16
  constexpr int64_t BLOCK_M = 4;
  constexpr int64_t BLOCK_N = std::is_same_v<scalar_t, at::Half> ? 4 : 6;
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);
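
The pattern comments appear to count the live vector registers in va-vb-vc order: the default BLOCK_M = 4, BLOCK_N = 6 shape keeps 1 + 6 + 24 = 31 zmm registers hot, while the FP16 shape (BLOCK_N = 4 in the NT kernel above: va0/va1, vb0[4]/vb1[4], vc[16]) keeps 2 + 8 + 16 = 26. Both stay within the 32 architectural zmm registers, which is why BLOCK_N drops from 6 to 4 for at::Half.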