Optimize prefill performance on cpu backend (#8750)
This commit is contained in:
@@ -105,7 +105,19 @@ namespace {
|
||||
|
||||
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
// parallel routines
|
||||
// [NB] Parallel Routines
|
||||
//
|
||||
// * at::parallel_for - applies for most of generic use cases, this will be compiled
|
||||
// against openmp in default torch release.
|
||||
//
|
||||
// * parallel_for - same function as above, can choose payload partition scheme in
|
||||
// balance211.
|
||||
//
|
||||
// * parallel_2d - parallel for 2 dimensions, used in GEMM, etc.
|
||||
// this one will do payload balance across 2 dimensions.
|
||||
//
|
||||
|
||||
// grain size for each thread
|
||||
constexpr int GRAIN_SIZE = 1024;
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
|
||||
@@ -113,6 +125,17 @@ inline T div_up(T x, T y) {
|
||||
return (x + y - 1) / y;
|
||||
}
|
||||
|
||||
// you can only use at::get_thread_num() with at::parallel_for()
|
||||
// as it is lazy initialized, otherwise it will always return 0.
|
||||
inline int get_thread_num() {
|
||||
#if defined(_OPENMP)
|
||||
return omp_get_thread_num();
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
// balance payload across each thread
|
||||
template <typename T>
|
||||
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
|
||||
#if 0
|
||||
@@ -153,6 +176,100 @@ inline void parallel_for(int n, const func_t& f) {
|
||||
#endif
|
||||
}
|
||||
|
||||
// for 1d parallel, use `actual_nth`
|
||||
// for 2d parallel, use even nths, e.g. 43->42
|
||||
int inline adjust_num_threads(int m) {
|
||||
int actual_nth = at::get_num_threads();
|
||||
if (m == 1) {
|
||||
return actual_nth;
|
||||
}
|
||||
return std::max(1, (actual_nth >> 1) * 2);
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_2d(int m, int n, const func_t& f) {
|
||||
// make sure we have even num_threads
|
||||
int nth = adjust_num_threads(m);
|
||||
|
||||
// [NOTE] thread blocking:
|
||||
//
|
||||
// 1) prefer square block per thread
|
||||
// 2) use even number of CPU cores
|
||||
// 3) use all `num_threads` cores
|
||||
//
|
||||
// we have:
|
||||
// TM * TN = T
|
||||
// BM / TM = BN / TN
|
||||
// then:
|
||||
// TM = ((BM / BN) * T) ^ 0.5
|
||||
//
|
||||
float r = float(m) / n;
|
||||
int nth_m = std::ceil(std::sqrt(r * nth));
|
||||
int nth_n = 1;
|
||||
for (; nth_m > 0; --nth_m) {
|
||||
nth_n = nth / nth_m;
|
||||
if (nth_m * nth_n == nth) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel num_threads(nth)
|
||||
{
|
||||
int ith = omp_get_thread_num();
|
||||
int ith_m = ith / nth_n;
|
||||
int ith_n = ith % nth_n;
|
||||
|
||||
int thread_block_m = div_up(m, nth_m);
|
||||
int thread_block_n = div_up(n, nth_n);
|
||||
|
||||
int begin_m = ith_m * thread_block_m;
|
||||
int end_m = std::min(m, begin_m + thread_block_m);
|
||||
int begin_n = ith_n * thread_block_n;
|
||||
int end_n = std::min(n, begin_n + thread_block_n);
|
||||
|
||||
f(begin_m, end_m, begin_n, end_n);
|
||||
}
|
||||
#else
|
||||
f(0, m, 0, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
// limit max cache blocks
|
||||
// when we need to do pre-unpack for weights, e.g. fp8
|
||||
#define MAX_CACHE_BLOCK_SIZE 4
|
||||
|
||||
template <typename T>
|
||||
inline int get_cache_blocks(int chunk_size) {
|
||||
// L2 2MB and ratio of 50%
|
||||
const int L2_size = 2048 * 1024 >> 1;
|
||||
return std::max(1, int(L2_size / (chunk_size * sizeof(T))));
|
||||
}
|
||||
|
||||
template <>
|
||||
inline int get_cache_blocks<at::Float8_e4m3fn>(int chunk_size) {
|
||||
// fp8 uses bf16 as accumulate type
|
||||
int cache_block_size = get_cache_blocks<at::BFloat16>(chunk_size);
|
||||
return std::min(MAX_CACHE_BLOCK_SIZE, cache_block_size);
|
||||
}
|
||||
|
||||
// 2d sequential loop in range : [mb0, mb1), [nb0, nb1)
|
||||
template <typename T, typename func_t>
|
||||
inline void loop_2d(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1, int64_t chunk_size, const func_t& f) {
|
||||
// get number of blocks for L2 in most inner loop
|
||||
int64_t cache_blocks_nb = get_cache_blocks<T>(chunk_size);
|
||||
|
||||
// loop order: [NB / cache_blocks_nb, MB, cache_blocks_nb]
|
||||
// TODO: implement reverse order of [MB / cache_blocks_mb, NB, cache_blocks_mb]
|
||||
for (int64_t nbb = nb0; nbb < nb1; nbb += cache_blocks_nb) {
|
||||
for (int64_t mb = mb0; mb < mb1; ++mb) {
|
||||
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, nb1); ++nb) {
|
||||
f(mb, nb, nb - nbb);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// data indexing for dimension collapse
|
||||
template <typename T>
|
||||
inline T data_index_init(T offset) {
|
||||
|
||||
@@ -254,7 +254,7 @@ void tinygemm_kernel(
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
// pattern: 1-4-16, N = 16, 32, 48, 64
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
@@ -268,35 +268,59 @@ void tinygemm_kernel(
|
||||
|
||||
switch (mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x11:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
|
||||
break;
|
||||
case 0x12:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
|
||||
break;
|
||||
case 0x13:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
|
||||
break;
|
||||
case 0x14:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
|
||||
break;
|
||||
// mb_size = 2
|
||||
case 0x21:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
|
||||
break;
|
||||
case 0x22:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
|
||||
break;
|
||||
case 0x23:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
|
||||
break;
|
||||
case 0x24:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
|
||||
break;
|
||||
// mb_size = 3
|
||||
case 0x31:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
|
||||
break;
|
||||
case 0x32:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
|
||||
break;
|
||||
case 0x33:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
|
||||
break;
|
||||
case 0x34:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
|
||||
break;
|
||||
// mb_size = 4
|
||||
case 0x41:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
|
||||
break;
|
||||
case 0x42:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
|
||||
break;
|
||||
case 0x43:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
|
||||
break;
|
||||
case 0x44:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, " x ", nb_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -318,20 +342,15 @@ void weight_packed_linear_kernel_impl(
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx c) N is small
|
||||
const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>) || (N < 64);
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// for brgemm, use float32 for accumulate
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
@@ -350,10 +369,7 @@ void weight_packed_linear_kernel_impl(
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
|
||||
@@ -27,10 +27,10 @@ template <>
|
||||
inline bool can_use_brgemm<at::Half>(int M) {
|
||||
return true;
|
||||
}
|
||||
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
|
||||
// this requires PyTorch 2.7 or above
|
||||
template <>
|
||||
inline bool can_use_brgemm<int8_t>(int M) {
|
||||
return false;
|
||||
return M > 4;
|
||||
}
|
||||
|
||||
template <>
|
||||
@@ -198,4 +198,5 @@ void tinygemm_kernel(
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K);
|
||||
int64_t block_size_K,
|
||||
bool do_unpack = true);
|
||||
|
||||
@@ -2,9 +2,6 @@
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
// we use 4x32 for BLOCK_M
|
||||
#define BLOCK_SIZE_M_SCALE 4
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
@@ -250,7 +247,8 @@ struct brgemm {
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc) {
|
||||
int ldc,
|
||||
bool do_unpack = true) {
|
||||
TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
|
||||
}
|
||||
};
|
||||
@@ -270,17 +268,20 @@ struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc) {
|
||||
int ldc,
|
||||
bool do_unpack = true) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
|
||||
// [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
|
||||
const int ldb_tmp = BLOCK_N;
|
||||
|
||||
for (int k = 0; k < K; k += BLOCK_K) {
|
||||
int kb_size = std::min(BLOCK_K, K - k);
|
||||
if (do_unpack) {
|
||||
for (int k = 0; k < K; k += BLOCK_K) {
|
||||
int kb_size = std::min(BLOCK_K, K - k);
|
||||
|
||||
int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
|
||||
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
|
||||
int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
|
||||
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
|
||||
}
|
||||
}
|
||||
|
||||
at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
|
||||
@@ -312,9 +313,11 @@ void tinygemm_kernel(
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K) {
|
||||
int64_t block_size_K,
|
||||
bool do_unpack = true) {
|
||||
if (brg) {
|
||||
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc);
|
||||
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
|
||||
A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc, do_unpack);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -366,7 +369,7 @@ void fp8_scaled_mm_kernel_impl(
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
int64_t buffer_size_per_thread) {
|
||||
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
@@ -378,16 +381,12 @@ void fp8_scaled_mm_kernel_impl(
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
int tid = at::get_thread_num();
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
int tid = get_thread_num();
|
||||
scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
|
||||
float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));
|
||||
float* __restrict__ Ctmp = (float*)((void*)(Btmp + MAX_CACHE_BLOCK_SIZE * BLOCK_N * K));
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
@@ -395,11 +394,14 @@ void fp8_scaled_mm_kernel_impl(
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
// only do unpacking for the first row
|
||||
bool do_unpack = (mb == mb0);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
|
||||
/* C */ out + mb_start * out_strideM + nb_start,
|
||||
/* Btmp */ Btmp,
|
||||
/* Btmp */ Btmp + nb_offset * BLOCK_N * K,
|
||||
/* Ctmp */ Ctmp,
|
||||
/* scale */ scale_ptr,
|
||||
/* bias */ bias + nb_start,
|
||||
@@ -410,11 +412,9 @@ void fp8_scaled_mm_kernel_impl(
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
/* block_size_K */ block_size_K,
|
||||
/* do_unpack */ do_unpack);
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
@@ -441,8 +441,10 @@ void tinygemm_kernel(
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K);
|
||||
int64_t block_size_K,
|
||||
bool do_unpack) {
|
||||
tinygemm_kernel<scalar_t, false>(
|
||||
A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
@@ -460,7 +462,8 @@ void tinygemm_kernel(
|
||||
int64_t ldb, \
|
||||
int64_t ldc, \
|
||||
bool brg, \
|
||||
int64_t block_size_K)
|
||||
int64_t block_size_K, \
|
||||
bool do_unpack)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
@@ -495,7 +498,7 @@ at::Tensor fp8_scaled_mm_cpu(
|
||||
int64_t block_size_N = block_size[0];
|
||||
int64_t block_size_K = block_size[1];
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
|
||||
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
|
||||
@@ -523,7 +526,7 @@ at::Tensor fp8_scaled_mm_cpu(
|
||||
// Btmp : [T, BLOCK_N * K]
|
||||
// Ctmp : [T, BLOCK_M * BLOCK_N]
|
||||
int num_threads = at::get_num_threads();
|
||||
int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
|
||||
int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
|
||||
auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
|
||||
|
||||
@@ -4,6 +4,61 @@
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_N>
|
||||
struct scale_C {
|
||||
static inline void apply(
|
||||
scalar_t* __restrict__ C,
|
||||
const int32_t* __restrict__ Ctmp,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias,
|
||||
float As,
|
||||
const float* __restrict__ Bs) {
|
||||
TORCH_CHECK(false, "scale_C: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_N>
|
||||
struct scale_C<at::BFloat16, has_bias, BLOCK_N> {
|
||||
static inline void apply(
|
||||
at::BFloat16* __restrict__ C,
|
||||
const int32_t* __restrict__ Ctmp,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias,
|
||||
float As,
|
||||
const float* __restrict__ Bs) {
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512 vc[COLS];
|
||||
__m512 vd0 = _mm512_set1_ps(As);
|
||||
|
||||
auto compute = [&](auto col) {
|
||||
__m512 vd1 = _mm512_loadu_ps(Bs + col * 16);
|
||||
__m512i vcomp = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
__m512i vc32 = _mm512_loadu_si512(Ctmp + col * 16);
|
||||
vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp));
|
||||
if constexpr (has_bias) {
|
||||
__m512 vbias = _mm512_loadu_ps(bias + col * 16);
|
||||
vc[col] = _mm512_fmadd_ps(_mm512_mul_ps(vc[col], vd0), vd1, vbias);
|
||||
} else {
|
||||
vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vd0), vd1);
|
||||
}
|
||||
};
|
||||
Unroll<COLS>{}(compute);
|
||||
|
||||
auto storec = [&](auto col) {
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc[col + 1], vc[col + 0])));
|
||||
}
|
||||
};
|
||||
Unroll<COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
@@ -169,6 +224,17 @@ void tinygemm_kernel(
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
if (brg) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
|
||||
|
||||
// apply compensation and scale
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
scale_C<scalar_t, has_bias, BLOCK_N>::apply(C + m * ldc, Ctmp + m * BLOCK_N, Bcomp, bias, As[m], Bs);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
@@ -233,22 +299,17 @@ void int8_scaled_mm_kernel_impl(
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
|
||||
const bool use_brgemm = false;
|
||||
const bool use_brgemm = can_use_brgemm<int8_t>(M);
|
||||
|
||||
// K + 4 after compensation
|
||||
const int64_t packed_row_size = get_row_size<int8_t>(K);
|
||||
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// for brgemm, use int32_t for accumulate
|
||||
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
@@ -269,10 +330,7 @@ void int8_scaled_mm_kernel_impl(
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ N,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
|
||||
@@ -579,36 +579,31 @@ void fused_experts_kernel_impl(
|
||||
const int64_t stride_e = 2 * N * K;
|
||||
const int64_t stride_n = K;
|
||||
|
||||
int64_t avg_M = std::max(int64_t(1), M * topk / E);
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(avg_M);
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
int tid = get_thread_num();
|
||||
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
// nb_upper from top half and nb_lower from bottom half
|
||||
int64_t nb_upper = nb, nb_lower = nb + NB;
|
||||
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
|
||||
const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
|
||||
const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
|
||||
const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
|
||||
copy_stub(A + m * K, input + index * K, K);
|
||||
@@ -659,9 +654,9 @@ void fused_experts_kernel_impl(
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (is_brgemm_used) {
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
@@ -676,24 +671,16 @@ void fused_experts_kernel_impl(
|
||||
const int64_t stride_oc = IC;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
int tid = get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copy A to tmp buffer again
|
||||
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
|
||||
@@ -736,9 +723,9 @@ void fused_experts_kernel_impl(
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (is_brgemm_used) {
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
@@ -776,36 +763,27 @@ void shared_expert_kernel_impl(
|
||||
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
|
||||
const int64_t stride_n = K;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
int tid = get_thread_num();
|
||||
float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
// nb_upper from top half and nb_lower from bottom half
|
||||
int64_t nb_upper = nb, nb_lower = nb + NB;
|
||||
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
|
||||
// int64_t mb_start = mb * BLOCK_M;
|
||||
// int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
|
||||
// A shape [m_size, K]
|
||||
const scalar_t* A = input + mb * BLOCK_M * K;
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
|
||||
const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
const scalar_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
|
||||
const scalar_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;
|
||||
|
||||
if (use_brgemm) {
|
||||
// 1.b gemm: C0 = A @ B0
|
||||
@@ -850,9 +828,9 @@ void shared_expert_kernel_impl(
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (is_brgemm_used) {
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
@@ -866,24 +844,16 @@ void shared_expert_kernel_impl(
|
||||
const int64_t stride_oc = IC;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
int tid = get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
// A shape [m_size, IC]
|
||||
const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N;
|
||||
|
||||
@@ -922,9 +892,9 @@ void shared_expert_kernel_impl(
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (is_brgemm_used) {
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
@@ -1086,7 +1056,7 @@ at::Tensor fused_experts_cpu(
|
||||
//
|
||||
// for fp8 w8a16:
|
||||
// 7. intermediate_cache0 : [M * topk, 2N]
|
||||
// 8. B_tmp : [T, BLOCK_N, std::max(K, N)]
|
||||
// 8. B_tmp : [T, MAX_CACHE_BLOCK_SIZE, BLOCK_N, std::max(K, N)]
|
||||
//
|
||||
int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
|
||||
num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
|
||||
@@ -1096,7 +1066,7 @@ at::Tensor fused_experts_cpu(
|
||||
buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2;
|
||||
buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N) * 2;
|
||||
}
|
||||
|
||||
auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
|
||||
@@ -1268,7 +1238,7 @@ at::Tensor shared_expert_cpu(
|
||||
//
|
||||
// for fp8 w8a16:
|
||||
// 5. intermediate_cache0 : [M, 2N]
|
||||
// 6. B_tmp: [T, BLOCK_M, max(K, N)]
|
||||
// 6. B_tmp: [T, MAX_CACHE_BLOCK_SIZE, BLOCK_M, max(K, N)]
|
||||
//
|
||||
int num_threads = at::get_num_threads();
|
||||
int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
|
||||
@@ -1277,7 +1247,7 @@ at::Tensor shared_expert_cpu(
|
||||
buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2;
|
||||
buffer_size_nbytes += M * 2 * N * 2 + num_threads * MAX_CACHE_BLOCK_SIZE * BLOCK_M * std::max(K, N) * 2;
|
||||
}
|
||||
|
||||
auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
|
||||
|
||||
@@ -174,18 +174,18 @@ void fused_experts_fp8_kernel_impl(
|
||||
const int64_t stride_e = 2 * N * K;
|
||||
const int64_t stride_n = K;
|
||||
|
||||
int64_t avg_M = std::max(int64_t(1), M * topk / E);
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(avg_M);
|
||||
|
||||
int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
int tid = get_thread_num();
|
||||
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
@@ -194,13 +194,14 @@ void fused_experts_fp8_kernel_impl(
|
||||
const float* __restrict__ Bs =
|
||||
w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
// do unpacking for the first row or a new expert
|
||||
int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
|
||||
bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
|
||||
copy_stub(A + m * K, input + index * K, K);
|
||||
@@ -211,7 +212,7 @@ void fused_experts_fp8_kernel_impl(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ Bs,
|
||||
/* M */ m_size,
|
||||
@@ -221,10 +222,11 @@ void fused_experts_fp8_kernel_impl(
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
}
|
||||
/* block_size_K */ block_size_K,
|
||||
/* do_unpack */ do_unpack);
|
||||
});
|
||||
|
||||
if (is_brgemm_used) {
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
@@ -248,22 +250,14 @@ void fused_experts_fp8_kernel_impl(
|
||||
const int64_t stride_oc = IC;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
int tid = get_thread_num();
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copy A to tmp buffer again
|
||||
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
|
||||
@@ -275,11 +269,15 @@ void fused_experts_fp8_kernel_impl(
|
||||
const float* __restrict__ Bs =
|
||||
w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
// do unpacking for the first row or a new expert
|
||||
int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
|
||||
bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ Bs,
|
||||
/* M */ m_size,
|
||||
@@ -289,7 +287,8 @@ void fused_experts_fp8_kernel_impl(
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
/* block_size_K */ block_size_K,
|
||||
/* do_unpack */ do_unpack);
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
@@ -298,9 +297,9 @@ void fused_experts_fp8_kernel_impl(
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (is_brgemm_used) {
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
@@ -374,20 +373,23 @@ void shared_expert_fp8_kernel_impl(
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
int tid = get_thread_num();
|
||||
|
||||
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// do unpacking for the first row
|
||||
bool do_unpack = (mb == mb0);
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ input + mb * BLOCK_M * K,
|
||||
/* B */ packed_w1 + nb * BLOCK_N * K,
|
||||
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ m_size,
|
||||
@@ -397,8 +399,9 @@ void shared_expert_fp8_kernel_impl(
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
}
|
||||
/* block_size_K */ block_size_K,
|
||||
/* do_unpack */ do_unpack);
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
@@ -421,22 +424,23 @@ void shared_expert_fp8_kernel_impl(
|
||||
scale_size_K = div_up(N, block_size_K);
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
int tid = get_thread_num();
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// do unpacking for the first row
|
||||
bool do_unpack = (mb == mb0);
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ ic1 + mb * BLOCK_M * N,
|
||||
/* B */ packed_w2 + nb * BLOCK_N * N,
|
||||
/* C */ C,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ m_size,
|
||||
@@ -446,7 +450,8 @@ void shared_expert_fp8_kernel_impl(
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
/* block_size_K */ block_size_K,
|
||||
/* do_unpack */ do_unpack);
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
@@ -454,7 +459,7 @@ void shared_expert_fp8_kernel_impl(
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
|
||||
@@ -109,6 +109,120 @@ inline void add_mul_stub(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int BLOCK_N>
|
||||
inline void silu_and_mul(
|
||||
scalar_t* __restrict__ C,
|
||||
const int32_t* __restrict__ C0, // x: x0, x1
|
||||
const int32_t* __restrict__ C1, // y: y0, y1
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs0,
|
||||
const float* __restrict__ Bs1,
|
||||
const int32_t* __restrict__ Bcomp0,
|
||||
const int32_t* __restrict__ Bcomp1,
|
||||
int64_t m_size,
|
||||
int64_t N) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512 vc0[COLS];
|
||||
__m512 vc1[COLS];
|
||||
__m512i vcomp0[COLS];
|
||||
__m512i vcomp1[COLS];
|
||||
__m512 vas;
|
||||
__m512 vbs0[COLS];
|
||||
__m512 vbs1[COLS];
|
||||
|
||||
auto load_scale_and_comp = [&](auto col) {
|
||||
vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
|
||||
vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
|
||||
vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
|
||||
vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
|
||||
};
|
||||
Unroll<COLS>{}(load_scale_and_comp);
|
||||
|
||||
auto scalec = [&](auto col, int64_t m) {
|
||||
// update As
|
||||
vas = _mm512_set1_ps(As[m]);
|
||||
// C = As * (C - Bcomp) * Bs
|
||||
__m512i vc32_0 = _mm512_loadu_si512(C0 + m * BLOCK_N + col * 16);
|
||||
__m512i vc32_1 = _mm512_loadu_si512(C1 + m * BLOCK_N + col * 16);
|
||||
vc0[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_0, vcomp0[col]));
|
||||
vc1[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_1, vcomp1[col]));
|
||||
vc0[col] = _mm512_mul_ps(_mm512_mul_ps(vc0[col], vas), vbs0[col]);
|
||||
vc1[col] = _mm512_mul_ps(_mm512_mul_ps(vc1[col], vas), vbs1[col]);
|
||||
};
|
||||
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
const fVec one = fVec(1.f);
|
||||
auto silu_and_mul = [&](auto col) {
|
||||
fVec x = fVec(vc0[col]);
|
||||
fVec y = fVec(vc1[col]);
|
||||
x = x / (one + x.neg().exp_u20());
|
||||
vc0[col] = x * y;
|
||||
};
|
||||
|
||||
auto storec = [&](auto col, int64_t m) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
fVec x0 = fVec(vc0[col + 0]);
|
||||
fVec x1 = fVec(vc0[col + 1]);
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(C + m * N + col * 16);
|
||||
}
|
||||
};
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
Unroll<COLS>{}(scalec, m);
|
||||
Unroll<COLS>{}(silu_and_mul);
|
||||
Unroll<COLS>{}(storec, m);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "silu_and_mul: scalar path not implemented!");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int BLOCK_N>
|
||||
inline void scale_C(
|
||||
float* __restrict__ C,
|
||||
const int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
int64_t m_size) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512 vc[COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 vas;
|
||||
__m512 vbs[COLS];
|
||||
|
||||
auto load_scale_and_comp = [&](auto col) {
|
||||
vcomp[col] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vbs[col] = _mm512_loadu_ps(Bs + col * 16);
|
||||
};
|
||||
Unroll<COLS>{}(load_scale_and_comp);
|
||||
|
||||
auto scalec = [&](auto col, int64_t m) {
|
||||
// update As
|
||||
vas = _mm512_set1_ps(As[m]);
|
||||
// C = As * (C - Bcomp) * Bs
|
||||
__m512i vc32 = _mm512_loadu_si512(Ctmp + m * BLOCK_N + col * 16);
|
||||
vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp[col]));
|
||||
vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vas), vbs[col]);
|
||||
_mm512_storeu_ps(C + m * BLOCK_N + col * 16, vc[col]);
|
||||
};
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
Unroll<COLS>{}(scalec, m);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "scale_C: scalar path not implemented!");
|
||||
#endif
|
||||
}
|
||||
|
||||
/// gemm for w13
|
||||
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni {
|
||||
@@ -515,28 +629,31 @@ void fused_experts_int8_kernel_impl(
|
||||
|
||||
const int64_t stride_e = 2 * N * packed_K;
|
||||
const int64_t stride_n = packed_K;
|
||||
|
||||
int64_t avg_M = std::max(int64_t(1), M * topk / E);
|
||||
const bool use_brgemm = can_use_brgemm<int8_t>(avg_M);
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
int tid = get_thread_num();
|
||||
uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
|
||||
|
||||
alignas(64) float As[BLOCK_M];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
// nb_upper from top half and nb_lower from bottom half
|
||||
int64_t nb_upper = nb, nb_lower = nb + NB;
|
||||
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N;
|
||||
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb_upper * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb_lower * BLOCK_N;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
@@ -548,22 +665,62 @@ void fused_experts_int8_kernel_impl(
|
||||
As[m] = As_tmp[index];
|
||||
}
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + offset * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
if (use_brgemm) {
|
||||
// 1.b gemm: C0 = A @ B0
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ A,
|
||||
/* B */ B0,
|
||||
/* C */ C0);
|
||||
|
||||
// 1.c gemm: C1 = A @ B1
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ A,
|
||||
/* B */ B1,
|
||||
/* C */ C1);
|
||||
|
||||
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
|
||||
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
|
||||
|
||||
// 1.d silu and mul
|
||||
const int64_t offset = offsets[mb];
|
||||
silu_and_mul<scalar_t, BLOCK_N>(
|
||||
ic1 + offset * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
|
||||
} else {
|
||||
// fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + offset * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -584,16 +741,13 @@ void fused_experts_int8_kernel_impl(
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
int tid = get_thread_num();
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
@@ -609,18 +763,36 @@ void fused_experts_int8_kernel_impl(
|
||||
const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C32);
|
||||
|
||||
// apply scales
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
|
||||
scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
|
||||
} else {
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
}
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
@@ -629,6 +801,10 @@ void fused_experts_int8_kernel_impl(
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -708,15 +884,19 @@ void shared_expert_int8_kernel_impl(
|
||||
const int64_t packed_N = get_row_size<int8_t>(N);
|
||||
const int64_t stride_n = packed_K;
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
const bool use_brgemm = can_use_brgemm<int8_t>(M);
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = get_thread_num();
|
||||
int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
|
||||
|
||||
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
// nb_upper from top half and nb_lower from bottom half
|
||||
int64_t nb_upper = nb, nb_lower = nb + NB;
|
||||
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
|
||||
// A shape [m_size, K]
|
||||
@@ -724,26 +904,65 @@ void shared_expert_int8_kernel_impl(
|
||||
const float* As = As_tmp + mb * BLOCK_M;
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
|
||||
const int8_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + nb_upper * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + nb_lower * BLOCK_N;
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
if (use_brgemm) {
|
||||
// 1.b gemm: C0 = A @ B0
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ A,
|
||||
/* B */ B0,
|
||||
/* C */ C0);
|
||||
|
||||
// 1.c gemm: C1 = A @ B1
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ A,
|
||||
/* B */ B1,
|
||||
/* C */ C1);
|
||||
|
||||
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
|
||||
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
|
||||
|
||||
// 1.d silu and mul
|
||||
silu_and_mul<scalar_t, BLOCK_N>(
|
||||
ic1 + mb * BLOCK_M * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
|
||||
} else {
|
||||
// fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -763,16 +982,13 @@ void shared_expert_int8_kernel_impl(
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
int tid = get_thread_num();
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
@@ -784,19 +1000,37 @@ void shared_expert_int8_kernel_impl(
|
||||
const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C32);
|
||||
|
||||
// apply scales
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
|
||||
scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
|
||||
} else {
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
}
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
@@ -804,6 +1038,10 @@ void shared_expert_int8_kernel_impl(
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
@@ -100,8 +100,7 @@ void segment_gemm_kernel_impl(
|
||||
const int64_t NB1 = div_up(N1, BLOCK_N);
|
||||
const int64_t NB = NB0 + NB1;
|
||||
|
||||
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
|
||||
const bool use_brgemm = false;
|
||||
const bool use_brgemm = can_use_brgemm<int8_t>(M);
|
||||
|
||||
// K + 4 after compensation
|
||||
const int64_t packed_row_size = get_row_size<int8_t>(K);
|
||||
|
||||
Reference in New Issue
Block a user