Add fp8 shared_expert kernel for CPU in sgl-kernel and add UT (#6339)
Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com> Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
@@ -104,6 +104,24 @@ void shared_expert_int8_kernel_impl(
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
@@ -134,3 +152,20 @@ void tinygemm_kernel(
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K);
|
||||
|
||||
@@ -248,38 +248,6 @@ struct brgemm {
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
struct brgemm<scalar_t, scalar_t, has_bias> {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc) {
|
||||
UNUSED(scale);
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <bool has_bias>
|
||||
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
|
||||
static inline void apply(
|
||||
@@ -469,6 +437,46 @@ void fp8_scaled_mm_kernel_impl(
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, \
|
||||
const at::Float8_e4m3fn* __restrict__ B, \
|
||||
TYPE* __restrict__ C, \
|
||||
TYPE* __restrict__ Btmp, \
|
||||
float* __restrict__ Ctmp, \
|
||||
const float* __restrict__ scale, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t lda, \
|
||||
int64_t ldb, \
|
||||
int64_t ldc, \
|
||||
bool brg, \
|
||||
int64_t block_size_K)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor fp8_scaled_mm_cpu(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
|
||||
@@ -1137,8 +1137,10 @@ at::Tensor shared_expert_cpu(
|
||||
double routed_scaling_factor,
|
||||
bool inplace,
|
||||
bool use_int8_w8a8,
|
||||
bool use_fp8_w8a16,
|
||||
std::optional<at::Tensor>& w1_scale,
|
||||
std::optional<at::Tensor>& w2_scale,
|
||||
std::optional<std::vector<int64_t>> block_size,
|
||||
std::optional<at::Tensor>& a1_scale,
|
||||
std::optional<at::Tensor>& a2_scale,
|
||||
bool is_vnni) {
|
||||
@@ -1180,6 +1182,11 @@ at::Tensor shared_expert_cpu(
|
||||
TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
|
||||
TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for fp8 w8a16.");
|
||||
TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for fp8 w8a16.");
|
||||
TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
|
||||
}
|
||||
|
||||
at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
|
||||
|
||||
@@ -1191,12 +1198,18 @@ at::Tensor shared_expert_cpu(
|
||||
// 3. Aq_tmp : [M, K] or [M, N]
|
||||
// 4. As_tmp : [M]
|
||||
//
|
||||
// for fp8 w8a16:
|
||||
// 5. intermediate_cache0 : [M, 2N]
|
||||
//
|
||||
int num_threads = at::get_num_threads();
|
||||
int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
|
||||
|
||||
if (use_int8_w8a8) {
|
||||
buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
buffer_size_nbytes += M * 2 * N * 2;
|
||||
}
|
||||
|
||||
auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] {
|
||||
@@ -1228,6 +1241,36 @@ at::Tensor shared_expert_cpu(
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
} else if (use_fp8_w8a16) {
|
||||
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
|
||||
|
||||
auto w1s = w1_scale.value();
|
||||
auto w2s = w2_scale.value();
|
||||
auto block_size_val = block_size.value();
|
||||
TORCH_CHECK(block_size_val.size() == 2, "shared_expert: expect block_size.size() to be 2.");
|
||||
int64_t block_size_N = block_size_val[0];
|
||||
int64_t block_size_K = block_size_val[1];
|
||||
TORCH_CHECK(w1s.size(0) == 2 * N / block_size_N);
|
||||
TORCH_CHECK(w1s.size(1) == K / block_size_K);
|
||||
TORCH_CHECK(w2s.size(0) == K / block_size_N);
|
||||
TORCH_CHECK(w2s.size(1) == N / block_size_K);
|
||||
|
||||
shared_expert_fp8_kernel_impl<scalar_t>(
|
||||
out_hidden_states.data_ptr<scalar_t>(),
|
||||
intermediate_cache0,
|
||||
intermediate_cache1,
|
||||
hidden_states.data_ptr<scalar_t>(),
|
||||
packed_w1.data_ptr<at::Float8_e4m3fn>(),
|
||||
packed_w2.data_ptr<at::Float8_e4m3fn>(),
|
||||
w1s.data_ptr<float>(),
|
||||
w2s.data_ptr<float>(),
|
||||
block_size_N,
|
||||
block_size_K,
|
||||
fused_experts_out.data_ptr<scalar_t>(),
|
||||
routed_scaling_factor,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
} else {
|
||||
shared_expert_kernel_impl<scalar_t>(
|
||||
out_hidden_states.data_ptr<scalar_t>(),
|
||||
|
||||
205
sgl-kernel/csrc/cpu/moe_fp8.cpp
Normal file
205
sgl-kernel/csrc/cpu/moe_fp8.cpp
Normal file
@@ -0,0 +1,205 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// out = input + input2 * scale
|
||||
template <typename scalar_t>
|
||||
inline void add_mul_stub(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2,
|
||||
float scale,
|
||||
int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec s_vec = fVec(scale);
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
bVec y_bvec = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
|
||||
|
||||
x0 = x0 + y0 * s_vec;
|
||||
x1 = x1 + y1 * s_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void silu_and_mul_stub(
|
||||
scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const scalar_t* __restrict__ input2, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
const fVec one = fVec(1.f);
|
||||
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += bVec::size()) {
|
||||
bVec x = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x);
|
||||
bVec y = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y);
|
||||
x0 = x0 / (one + x0.neg().exp_u20());
|
||||
x1 = x1 / (one + x1.neg().exp_u20());
|
||||
x0 = x0 * y0;
|
||||
x1 = x1 * y1;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 1: intermediate_cache0 = hidden_states @ w1
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(2 * N, BLOCK_N);
|
||||
int64_t scale_size_K = div_up(K, block_size_K);
|
||||
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t nb_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ input + mb * BLOCK_M * K,
|
||||
/* B */ packed_w1 + nb * BLOCK_N * K,
|
||||
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ Btmp,
|
||||
/* Ctmp */ Ctmp,
|
||||
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
silu_and_mul_stub(ic1 + m * N, ic0 + m * 2 * N, ic0 + m * 2 * N + N, N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [K, N] as [OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(K, BLOCK_N);
|
||||
scale_size_K = div_up(N, block_size_K);
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N];
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_K];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
int64_t mb_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t nb_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ ic1 + mb * BLOCK_M * N,
|
||||
/* B */ packed_w2 + nb * BLOCK_N * N,
|
||||
/* C */ C,
|
||||
/* Btmp */ Btmp,
|
||||
/* Ctmp */ Ctmp,
|
||||
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
for (int64_t m = 0; m < mb_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, nb_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE) \
|
||||
template void shared_expert_fp8_kernel_impl<TYPE>( \
|
||||
TYPE* __restrict__ output, \
|
||||
TYPE* __restrict__ ic0, \
|
||||
TYPE* __restrict__ ic1, \
|
||||
const TYPE* __restrict__ input, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, \
|
||||
const float* __restrict__ w2s, \
|
||||
int64_t block_size_N, \
|
||||
int64_t block_size_K, \
|
||||
const TYPE* __restrict__ fused_experts_out, \
|
||||
float routed_scaling_factor, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K)
|
||||
|
||||
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half);
|
||||
@@ -139,8 +139,10 @@ at::Tensor shared_expert_cpu(
|
||||
double routed_scaling_factor,
|
||||
bool inplace,
|
||||
bool use_int8_w8a8,
|
||||
bool use_fp8_w8a16,
|
||||
std::optional<at::Tensor>& w1_scale,
|
||||
std::optional<at::Tensor>& w2_scale,
|
||||
std::optional<std::vector<int64_t>> block_size,
|
||||
std::optional<at::Tensor>& a1_scale,
|
||||
std::optional<at::Tensor>& a2_scale,
|
||||
bool is_vnni);
|
||||
|
||||
@@ -61,6 +61,7 @@ sources = [
|
||||
"csrc/cpu/gemm_fp8.cpp",
|
||||
"csrc/cpu/gemm_int8.cpp",
|
||||
"csrc/cpu/moe.cpp",
|
||||
"csrc/cpu/moe_fp8.cpp",
|
||||
"csrc/cpu/moe_int8.cpp",
|
||||
"csrc/cpu/norm.cpp",
|
||||
"csrc/cpu/qkv_proj.cpp",
|
||||
|
||||
Reference in New Issue
Block a user