Optimize prefill performance on cpu backend (#8750)
This commit is contained in:
@@ -109,6 +109,120 @@ inline void add_mul_stub(
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int BLOCK_N>
|
||||
inline void silu_and_mul(
|
||||
scalar_t* __restrict__ C,
|
||||
const int32_t* __restrict__ C0, // x: x0, x1
|
||||
const int32_t* __restrict__ C1, // y: y0, y1
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs0,
|
||||
const float* __restrict__ Bs1,
|
||||
const int32_t* __restrict__ Bcomp0,
|
||||
const int32_t* __restrict__ Bcomp1,
|
||||
int64_t m_size,
|
||||
int64_t N) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512 vc0[COLS];
|
||||
__m512 vc1[COLS];
|
||||
__m512i vcomp0[COLS];
|
||||
__m512i vcomp1[COLS];
|
||||
__m512 vas;
|
||||
__m512 vbs0[COLS];
|
||||
__m512 vbs1[COLS];
|
||||
|
||||
auto load_scale_and_comp = [&](auto col) {
|
||||
vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
|
||||
vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
|
||||
vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
|
||||
vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
|
||||
};
|
||||
Unroll<COLS>{}(load_scale_and_comp);
|
||||
|
||||
auto scalec = [&](auto col, int64_t m) {
|
||||
// update As
|
||||
vas = _mm512_set1_ps(As[m]);
|
||||
// C = As * (C - Bcomp) * Bs
|
||||
__m512i vc32_0 = _mm512_loadu_si512(C0 + m * BLOCK_N + col * 16);
|
||||
__m512i vc32_1 = _mm512_loadu_si512(C1 + m * BLOCK_N + col * 16);
|
||||
vc0[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_0, vcomp0[col]));
|
||||
vc1[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32_1, vcomp1[col]));
|
||||
vc0[col] = _mm512_mul_ps(_mm512_mul_ps(vc0[col], vas), vbs0[col]);
|
||||
vc1[col] = _mm512_mul_ps(_mm512_mul_ps(vc1[col], vas), vbs1[col]);
|
||||
};
|
||||
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
const fVec one = fVec(1.f);
|
||||
auto silu_and_mul = [&](auto col) {
|
||||
fVec x = fVec(vc0[col]);
|
||||
fVec y = fVec(vc1[col]);
|
||||
x = x / (one + x.neg().exp_u20());
|
||||
vc0[col] = x * y;
|
||||
};
|
||||
|
||||
auto storec = [&](auto col, int64_t m) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
fVec x0 = fVec(vc0[col + 0]);
|
||||
fVec x1 = fVec(vc0[col + 1]);
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(C + m * N + col * 16);
|
||||
}
|
||||
};
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
Unroll<COLS>{}(scalec, m);
|
||||
Unroll<COLS>{}(silu_and_mul);
|
||||
Unroll<COLS>{}(storec, m);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "silu_and_mul: scalar path not implemented!");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int BLOCK_N>
|
||||
inline void scale_C(
|
||||
float* __restrict__ C,
|
||||
const int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
int64_t m_size) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512 vc[COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 vas;
|
||||
__m512 vbs[COLS];
|
||||
|
||||
auto load_scale_and_comp = [&](auto col) {
|
||||
vcomp[col] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vbs[col] = _mm512_loadu_ps(Bs + col * 16);
|
||||
};
|
||||
Unroll<COLS>{}(load_scale_and_comp);
|
||||
|
||||
auto scalec = [&](auto col, int64_t m) {
|
||||
// update As
|
||||
vas = _mm512_set1_ps(As[m]);
|
||||
// C = As * (C - Bcomp) * Bs
|
||||
__m512i vc32 = _mm512_loadu_si512(Ctmp + m * BLOCK_N + col * 16);
|
||||
vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp[col]));
|
||||
vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vas), vbs[col]);
|
||||
_mm512_storeu_ps(C + m * BLOCK_N + col * 16, vc[col]);
|
||||
};
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
Unroll<COLS>{}(scalec, m);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "scale_C: scalar path not implemented!");
|
||||
#endif
|
||||
}
|
||||
|
||||
/// gemm for w13
|
||||
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni {
|
||||
@@ -515,28 +629,31 @@ void fused_experts_int8_kernel_impl(
|
||||
|
||||
const int64_t stride_e = 2 * N * packed_K;
|
||||
const int64_t stride_n = packed_K;
|
||||
|
||||
int64_t avg_M = std::max(int64_t(1), M * topk / E);
|
||||
const bool use_brgemm = can_use_brgemm<int8_t>(avg_M);
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
int tid = get_thread_num();
|
||||
uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
|
||||
|
||||
alignas(64) float As[BLOCK_M];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
// nb_upper from top half and nb_lower from bottom half
|
||||
int64_t nb_upper = nb, nb_lower = nb + NB;
|
||||
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N;
|
||||
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb_upper * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb_lower * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb_upper * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb_lower * BLOCK_N;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
@@ -548,22 +665,62 @@ void fused_experts_int8_kernel_impl(
|
||||
As[m] = As_tmp[index];
|
||||
}
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + offset * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
if (use_brgemm) {
|
||||
// 1.b gemm: C0 = A @ B0
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ A,
|
||||
/* B */ B0,
|
||||
/* C */ C0);
|
||||
|
||||
// 1.c gemm: C1 = A @ B1
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ A,
|
||||
/* B */ B1,
|
||||
/* C */ C1);
|
||||
|
||||
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
|
||||
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
|
||||
|
||||
// 1.d silu and mul
|
||||
const int64_t offset = offsets[mb];
|
||||
silu_and_mul<scalar_t, BLOCK_N>(
|
||||
ic1 + offset * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
|
||||
} else {
|
||||
// fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + offset * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -584,16 +741,13 @@ void fused_experts_int8_kernel_impl(
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
int tid = get_thread_num();
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
@@ -609,18 +763,36 @@ void fused_experts_int8_kernel_impl(
|
||||
const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C32);
|
||||
|
||||
// apply scales
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
|
||||
scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
|
||||
} else {
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
}
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
@@ -629,6 +801,10 @@ void fused_experts_int8_kernel_impl(
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -708,15 +884,19 @@ void shared_expert_int8_kernel_impl(
|
||||
const int64_t packed_N = get_row_size<int8_t>(N);
|
||||
const int64_t stride_n = packed_K;
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
const bool use_brgemm = can_use_brgemm<int8_t>(M);
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = get_thread_num();
|
||||
int32_t* __restrict__ C0 = reinterpret_cast<int32_t*>(C_tmp) + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
int32_t* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
|
||||
|
||||
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K * 2, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
// nb_upper from top half and nb_lower from bottom half
|
||||
int64_t nb_upper = nb, nb_lower = nb + NB;
|
||||
int64_t n_size = std::min(N - nb * BLOCK_N, BLOCK_N);
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
|
||||
// A shape [m_size, K]
|
||||
@@ -724,26 +904,65 @@ void shared_expert_int8_kernel_impl(
|
||||
const float* As = As_tmp + mb * BLOCK_M;
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
|
||||
const int8_t* __restrict__ B0 = packed_w1 + nb_upper * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + nb_lower * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + nb_upper * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + nb_lower * BLOCK_N;
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
if (use_brgemm) {
|
||||
// 1.b gemm: C0 = A @ B0
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ A,
|
||||
/* B */ B0,
|
||||
/* C */ C0);
|
||||
|
||||
// 1.c gemm: C1 = A @ B1
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ A,
|
||||
/* B */ B1,
|
||||
/* C */ C1);
|
||||
|
||||
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
|
||||
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
|
||||
|
||||
// 1.d silu and mul
|
||||
silu_and_mul<scalar_t, BLOCK_N>(
|
||||
ic1 + mb * BLOCK_M * N + nb * BLOCK_N, C0, C1, As, Bs0, Bs1, Bcomp0, Bcomp1, m_size, N);
|
||||
} else {
|
||||
// fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
@@ -763,16 +982,13 @@ void shared_expert_int8_kernel_impl(
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
int tid = get_thread_num();
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
int32_t* __restrict__ C32 = reinterpret_cast<int32_t*>(C + BLOCK_M * BLOCK_N);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
@@ -784,19 +1000,37 @@ void shared_expert_int8_kernel_impl(
|
||||
const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C32);
|
||||
|
||||
// apply scales
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * IC);
|
||||
scale_C<BLOCK_N>(C, C32, As, Bs, Bcomp, m_size);
|
||||
} else {
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
}
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
@@ -804,6 +1038,10 @@ void shared_expert_int8_kernel_impl(
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user