Add optimized native kernels in sgl-kernel (#5150)
Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com> Co-authored-by: YanbingJiang <yanbing.jiang@intel.com> Co-authored-by: blzheng <beilei.zheng@intel.com>
This commit is contained in:
79
sgl-kernel/csrc/cpu/activation.cpp
Normal file
79
sgl-kernel/csrc/cpu/activation.cpp
Normal file
@@ -0,0 +1,79 @@
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t, typename func_t, typename vec_func_t>
|
||||
void act_and_mul_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
const scalar_t* __restrict__ input,
|
||||
int64_t num_tokens,
|
||||
int64_t dim,
|
||||
const func_t& f,
|
||||
const vec_func_t& vf) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
|
||||
constexpr int64_t kVecSize = bVec::size();
|
||||
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
// local ptrs
|
||||
const scalar_t* __restrict__ input_ptr = input + i * 2 * dim;
|
||||
const scalar_t* __restrict__ input_other_ptr = input_ptr + dim;
|
||||
scalar_t* __restrict__ output_ptr = output + i * dim;
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= dim - kVecSize; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input_ptr + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
bVec y_bvec = bVec::loadu(input_other_ptr + d);
|
||||
fVec y_fvec0, y_fvec1;
|
||||
std::tie(y_fvec0, y_fvec1) = at::vec::convert_to_float(y_bvec);
|
||||
|
||||
x_fvec0 = vf(x_fvec0);
|
||||
x_fvec1 = vf(x_fvec1);
|
||||
|
||||
x_fvec0 = x_fvec0 * y_fvec0;
|
||||
x_fvec1 = x_fvec1 * y_fvec1;
|
||||
|
||||
x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
||||
x_bvec.store(output_ptr + d);
|
||||
}
|
||||
#pragma GCC unroll 4
|
||||
for (; d < dim; ++d) {
|
||||
float x_val = static_cast<float>(input_ptr[d]);
|
||||
float y_val = static_cast<float>(input_other_ptr[d]);
|
||||
output_ptr[d] = f(x_val) * y_val;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// input : {num_tokens, 2 * d}
|
||||
// output : {num_tokens, d}
|
||||
at::Tensor silu_and_mul_cpu(at::Tensor& input) {
|
||||
RECORD_FUNCTION("sgl-kernel::silu_and_mul_cpu", std::vector<c10::IValue>({input}));
|
||||
auto sizes = input.sizes().vec();
|
||||
int64_t last_dim = input.ndimension() - 1;
|
||||
int64_t d = sizes[last_dim] / 2;
|
||||
sizes[last_dim] = d;
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
at::Tensor out = at::empty(sizes, input.options());
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
act_and_mul_kernel_impl(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
num_tokens,
|
||||
d,
|
||||
[](float x) { return x / (1.f + std::exp(-x)); },
|
||||
[](Vec x) { return x / (Vec(1.f) + x.neg().exp()); });
|
||||
});
|
||||
return out;
|
||||
}
|
||||
122
sgl-kernel/csrc/cpu/bmm.cpp
Normal file
122
sgl-kernel/csrc/cpu/bmm.cpp
Normal file
@@ -0,0 +1,122 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
void bmm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const scalar_t* __restrict__ mat2,
|
||||
int64_t B,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideB,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideB,
|
||||
int64_t out_strideM,
|
||||
float scale = 0.f) {
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// mat2 contiguous in [B, N, K]
|
||||
int64_t mat2_strideB = N * K;
|
||||
int64_t mat2_strideN = K;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
|
||||
|
||||
// parallel on [B, MB, NB]
|
||||
at::parallel_for(0, B * MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t bs{0}, mb{0}, nb{0};
|
||||
data_index_init(begin, bs, B, mb, MB, nb, NB);
|
||||
|
||||
// for brgemm, use float32 for accumulate
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ mat1 + bs * mat1_strideB + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + bs * mat2_strideB + nb_start * mat2_strideN /* nb * BLOCK_N * K */,
|
||||
/* C */ out + bs * out_strideB + mb_start * out_strideM + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(bs, B, mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// mat1 : [B, M, K]
|
||||
// mat2 : [B, N, K] or [B, OC, IC]
|
||||
// out : [B, M, N]
|
||||
// scale: [] 0-dim tensor for per tensor quant
|
||||
//
|
||||
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale) {
|
||||
RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
// input and out could be non-contiguous
|
||||
// weight needs to be contiguous in [OC, IC] order
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(out);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_DIM(3, out);
|
||||
CHECK_DIM(3, mat1);
|
||||
CHECK_DIM(3, mat2);
|
||||
|
||||
int64_t B = mat1.size(0);
|
||||
int64_t M = mat1.size(1);
|
||||
int64_t N = mat2.size(1);
|
||||
int64_t K = mat1.size(2);
|
||||
|
||||
TORCH_CHECK(!scale.has_value(), "bmm: do not support fp8 weight for now.")
|
||||
TORCH_CHECK(N % 32 == 0, "tinygemm requires N to be 32x.");
|
||||
|
||||
int64_t mat1_strideB = mat1.stride(0);
|
||||
int64_t mat1_strideM = mat1.stride(1);
|
||||
int64_t out_strideB = out.stride(0);
|
||||
int64_t out_strideM = out.stride(1);
|
||||
|
||||
// check shapes
|
||||
TORCH_CHECK(mat2.size(0) == B && mat2.size(2) == K, "bmm: mat2 shape mismatch!");
|
||||
TORCH_CHECK(out.size(0) == B && out.size(1) == M, "bmm: out shape mismatch!");
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "bmm_kernel_impl", [&] {
|
||||
bmm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<scalar_t>(),
|
||||
B,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideB,
|
||||
mat1_strideM,
|
||||
out_strideB,
|
||||
out_strideM);
|
||||
});
|
||||
}
|
||||
164
sgl-kernel/csrc/cpu/common.h
Normal file
164
sgl-kernel/csrc/cpu/common.h
Normal file
@@ -0,0 +1,164 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
// dispatch bool
|
||||
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
|
||||
[&] { \
|
||||
if (BOOL_V) { \
|
||||
constexpr bool BOOL_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool BOOL_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
// dispatch: bfloat16, float16, int8_t
|
||||
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
|
||||
[&] { \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::BFloat16: { \
|
||||
using packed_t = at::BFloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using packed_t = at::Half; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Char: { \
|
||||
using packed_t = int8_t; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
|
||||
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
CHECK_LAST_DIM_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
// parallel routines
|
||||
constexpr int GRAIN_SIZE = 1024;
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
|
||||
inline T div_up(T x, T y) {
|
||||
return (x + y - 1) / y;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
|
||||
#if 0
|
||||
// onednn partition pattern
|
||||
T& n_my = n_end;
|
||||
if (nth <= 1 || n == 0) {
|
||||
n_start = 0;
|
||||
n_my = n;
|
||||
} else {
|
||||
T n1 = div_up(n, nth);
|
||||
T n2 = n1 - 1;
|
||||
T T1 = n - n2 * nth;
|
||||
n_my = ith < T1 ? n1 : n2;
|
||||
n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
|
||||
}
|
||||
n_end += n_start;
|
||||
#else
|
||||
// pytorch aten partition pattern
|
||||
T n_my = div_up(n, nth);
|
||||
n_start = ith * n_my;
|
||||
n_end = std::min(n_start + n_my, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_for(int n, const func_t& f) {
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel
|
||||
{
|
||||
int nth = omp_get_num_threads();
|
||||
int ith = omp_get_thread_num();
|
||||
int tbegin, tend;
|
||||
balance211(n, nth, ith, tbegin, tend);
|
||||
f(tbegin, tend);
|
||||
}
|
||||
#else
|
||||
f(0, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
// data indexing for dimension collapse
|
||||
template <typename T>
|
||||
inline T data_index_init(T offset) {
|
||||
return offset;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
|
||||
offset = data_index_init(offset, std::forward<Args>(args)...);
|
||||
x = offset % X;
|
||||
return offset / X;
|
||||
}
|
||||
|
||||
inline bool data_index_step() {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline bool data_index_step(T& x, const T& X, Args&&... args) {
|
||||
if (data_index_step(std::forward<Args>(args)...)) {
|
||||
x = ((x + 1) == X) ? 0 : (x + 1);
|
||||
return x == 0;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// forced unroll for perf critical path
|
||||
|
||||
#if __has_attribute(always_inline)
|
||||
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
|
||||
#else
|
||||
#define ALWAYS_INLINE inline
|
||||
#endif
|
||||
|
||||
template <int n>
|
||||
struct Unroll {
|
||||
template <typename Func, typename... Args>
|
||||
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
||||
Unroll<n - 1>{}(f, args...);
|
||||
f(std::integral_constant<int, n - 1>{}, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Unroll<1> {
|
||||
template <typename Func, typename... Args>
|
||||
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
||||
f(std::integral_constant<int, 0>{}, args...);
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
1119
sgl-kernel/csrc/cpu/decode.cpp
Normal file
1119
sgl-kernel/csrc/cpu/decode.cpp
Normal file
File diff suppressed because it is too large
Load Diff
621
sgl-kernel/csrc/cpu/extend.cpp
Normal file
621
sgl-kernel/csrc/cpu/extend.cpp
Normal file
@@ -0,0 +1,621 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// [NOTE]: extend attention for CPU
|
||||
// 1. tune BLOCK_M and BLOCK_N
|
||||
// 2. can handle non-contiguous k_exttend and v_extend
|
||||
// 3. computes attention for prefix and extend separately
|
||||
// 4. TODO: vectorize `pack_vnni` and `pack_vnni2`
|
||||
//
|
||||
template <typename index_t>
|
||||
inline index_t get_index(index_t* ind, int i) {
|
||||
return (ind == nullptr) ? (index_t)i : ind[i];
|
||||
}
|
||||
|
||||
// convert to vnni format
|
||||
// from [N, K/2, 2] to [K/2, N, 2] for bfloat16 and float16
|
||||
template <typename scalar_t, typename index_t>
|
||||
void pack_vnni(
|
||||
scalar_t* __restrict__ dst,
|
||||
const scalar_t* __restrict__ src,
|
||||
const index_t* __restrict__ ind,
|
||||
int N,
|
||||
int K,
|
||||
int ld_src,
|
||||
int ld_dst) {
|
||||
for (int n = 0; n < N; ++n) {
|
||||
index_t index = get_index(ind, n);
|
||||
for (int k = 0; k < K / 2; ++k) {
|
||||
for (int d = 0; d < 2; ++d) {
|
||||
dst[k * ld_dst * 2 + n * 2 + d] = src[index * ld_src + k * 2 + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// convert to vnni format
|
||||
// from [K/2, 2, N] to [K/2, N, 2] for bfloat16 and float16
|
||||
template <typename scalar_t, typename index_t>
|
||||
void pack_vnni2(
|
||||
scalar_t* __restrict__ dst,
|
||||
const scalar_t* __restrict__ src,
|
||||
const index_t* __restrict__ ind,
|
||||
int K,
|
||||
int N,
|
||||
int ld_src,
|
||||
int ld_dst) {
|
||||
int k = 0;
|
||||
for (; k < (K >> 1) * 2; k += 2) {
|
||||
index_t index0 = get_index(ind, k + 0);
|
||||
index_t index1 = get_index(ind, k + 1);
|
||||
for (int n = 0; n < N; ++n) {
|
||||
dst[(k >> 1) * ld_dst * 2 + n * 2 + 0] = src[index0 * ld_src + n];
|
||||
dst[(k >> 1) * ld_dst * 2 + n * 2 + 1] = src[index1 * ld_src + n];
|
||||
}
|
||||
}
|
||||
if (K % 2 != 0) {
|
||||
index_t index = get_index(ind, K - 1);
|
||||
for (int n = 0; n < N; ++n) {
|
||||
dst[(K >> 1) * ld_dst * 2 + n * 2 + 0] = src[index * ld_src + n];
|
||||
dst[(K >> 1) * ld_dst * 2 + n * 2 + 1] = 0;
|
||||
}
|
||||
k += 2;
|
||||
}
|
||||
// TODO: check whether we can skip this!
|
||||
// const int padded_K = div_up(K, TILE_K) * TILE_K;
|
||||
// for (; k < padded_K; ++k) {
|
||||
// for (int n = 0; n < N; ++n) {
|
||||
// dst[k * ld_dst + n] = static_cast<scalar_t>(0);
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void fill_stub(scalar_t* __restrict__ out, float val, int size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
const Vec data_vec = Vec(static_cast<scalar_t>(val));
|
||||
int d = 0;
|
||||
for (; d <= size - Vec::size(); d += Vec::size()) {
|
||||
data_vec.store(out + d);
|
||||
}
|
||||
if (size - d > 0) {
|
||||
data_vec.store(out + d, size - d);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int BLOCK_N>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input) {
|
||||
static_assert(BLOCK_N % 32 == 0);
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
auto store = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
fVec a_fvec0 = fVec::loadu(input + col * 16);
|
||||
fVec a_fvec1 = fVec::loadu(input + col * 16 + 16);
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
|
||||
out_bvec.store(out + col * 16);
|
||||
}
|
||||
};
|
||||
Unroll<COLS>{}(store);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
const fVec s_fvec = fVec(s);
|
||||
int d = 0;
|
||||
for (; d <= size - bVec::size(); d += bVec::size()) {
|
||||
fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
|
||||
fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
|
||||
out_bvec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(acc[d] * s);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
|
||||
void extend_attention_kernel_impl(
|
||||
scalar_t* __restrict__ o_extend,
|
||||
const scalar_t* __restrict__ q_extend,
|
||||
const scalar_t* __restrict__ k_extend,
|
||||
const scalar_t* __restrict__ v_extend,
|
||||
const scalar_t* __restrict__ k_buffer,
|
||||
const scalar_t* __restrict__ v_buffer,
|
||||
const index_t* __restrict__ req_to_token,
|
||||
const int64_t* __restrict__ req_pool_indices,
|
||||
const int64_t* __restrict__ seq_lens,
|
||||
const index_t* __restrict__ extend_seq_lens,
|
||||
const index_t* __restrict__ extend_start_loc,
|
||||
const void* __restrict__ buffer,
|
||||
int batches,
|
||||
int num_heads,
|
||||
int num_heads_kv,
|
||||
int head_size,
|
||||
int head_size_v,
|
||||
int ke_strideN,
|
||||
int ke_strideH,
|
||||
int ve_strideN,
|
||||
int ve_strideH,
|
||||
int k_strideN,
|
||||
int k_strideH,
|
||||
int v_strideN,
|
||||
int v_strideH,
|
||||
float scaling,
|
||||
float logit_cap,
|
||||
int max_num_reqs,
|
||||
int max_context_len,
|
||||
int max_total_num_tokens,
|
||||
int max_len_extend,
|
||||
int buffer_size_per_thread,
|
||||
bool is_prefix_skipped) {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
|
||||
// strides
|
||||
const int q_strideM = num_heads * head_size;
|
||||
const int q_strideH = head_size;
|
||||
const int o_strideM = num_heads * head_size_v;
|
||||
const int o_strideH = head_size_v;
|
||||
|
||||
// we use same buffer for packed key and value
|
||||
const int ldb_tmp = std::max(head_size, head_size_v);
|
||||
|
||||
const bool has_logit_cap = logit_cap > 0;
|
||||
float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f;
|
||||
|
||||
const int num_groups = num_heads / num_heads_kv;
|
||||
TORCH_CHECK(num_groups * num_heads_kv == num_heads);
|
||||
|
||||
// number of blocks along M
|
||||
int MB = div_up(max_len_extend, BLOCK_M);
|
||||
|
||||
// parallel on [batches, num_heads, BM]
|
||||
at::parallel_for(0, batches * num_heads * MB, 0, [&](int begin, int end) {
|
||||
int bs{0}, head_id{0}, mb{0};
|
||||
data_index_init(begin, bs, batches, head_id, num_heads, mb, MB);
|
||||
|
||||
int tid = at::get_thread_num();
|
||||
// s_i and s_delta: [BLOCK_M, BLOCK_N]
|
||||
float* __restrict__ s_i = reinterpret_cast<float*>((char*)(buffer) + tid * buffer_size_per_thread);
|
||||
float* __restrict__ s_delta = s_i;
|
||||
|
||||
// v_prime: [BLOCK_M, head_size_v]
|
||||
float* __restrict__ v_prime = s_i + BLOCK_M * BLOCK_N;
|
||||
|
||||
// s_delta2: [BLOCK_M, BLOCK_N]; copy of s_delta in scalar_t
|
||||
scalar_t* __restrict__ s_delta2 = reinterpret_cast<scalar_t*>(v_prime + BLOCK_N * head_size_v);
|
||||
|
||||
// Btmp: [BLOCK_N, max(head_size, head_size_v)]
|
||||
scalar_t* __restrict__ Btmp = s_delta2 + BLOCK_M * BLOCK_N;
|
||||
|
||||
// init Btmp just once for each thread to prevent NaN
|
||||
fill_stub(Btmp, 0.f, BLOCK_N * ldb_tmp);
|
||||
|
||||
alignas(64) float s_prime[BLOCK_M];
|
||||
alignas(64) float m_prime[BLOCK_M];
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
// seq_len = prefix + extend
|
||||
int head_kv_id = head_id / num_groups;
|
||||
int seq_len = seq_lens[bs];
|
||||
int seq_len_extend = extend_seq_lens[bs];
|
||||
int seq_len_prefix = seq_len - seq_len_extend;
|
||||
int seq_extend_start_loc = extend_start_loc[bs];
|
||||
|
||||
int req_pool_id = req_pool_indices[bs];
|
||||
TORCH_CHECK(seq_len_prefix >= 0, "prefix len < 0!");
|
||||
TORCH_CHECK(seq_len <= max_context_len, "seq_len out of scope!");
|
||||
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!");
|
||||
|
||||
if (is_prefix_skipped) {
|
||||
TORCH_CHECK(seq_len_prefix == 0, "extend attention: expect seq_len_prefix to be 0, got ", seq_len_prefix);
|
||||
}
|
||||
|
||||
// offset and size in MB
|
||||
int m = mb * BLOCK_N;
|
||||
int m_size = std::min(BLOCK_M, seq_len_extend - m);
|
||||
|
||||
if (m_size <= 0) {
|
||||
data_index_step(bs, batches, head_id, num_heads, mb, MB);
|
||||
continue;
|
||||
}
|
||||
|
||||
// get query
|
||||
const scalar_t* __restrict__ q_ptr = q_extend + (seq_extend_start_loc + m) * q_strideM + head_id * q_strideH;
|
||||
|
||||
// init v', s' and m'
|
||||
fill_stub(v_prime, 0.f, m_size * head_size_v);
|
||||
fill_stub(s_prime, 0.f, m_size);
|
||||
fill_stub(m_prime, -std::numeric_limits<scalar_t>::infinity(), m_size);
|
||||
|
||||
// stage 1: compute scores with prefix
|
||||
for (int n = 0; n < seq_len_prefix; n += BLOCK_N) {
|
||||
int n_size = std::min(BLOCK_N, seq_len_prefix - n);
|
||||
|
||||
// `n_size` is K in 2nd gemm, pad to TILE_K;
|
||||
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
|
||||
|
||||
// get key and pack
|
||||
pack_vnni<scalar_t, index_t>(
|
||||
/* dst */ Btmp,
|
||||
/* src */ k_buffer + head_kv_id * k_strideH,
|
||||
/* ind */ req_to_token + req_pool_id * max_context_len + n,
|
||||
/* N */ n_size,
|
||||
/* K */ head_size,
|
||||
/* ld_src */ k_strideN,
|
||||
/* ld_dst */ BLOCK_N);
|
||||
|
||||
// calculate s_i <- Q @ K
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ head_size,
|
||||
/* lda */ q_strideM,
|
||||
/* ldb */ BLOCK_N,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ q_ptr,
|
||||
/* B */ Btmp,
|
||||
/* C */ s_i);
|
||||
|
||||
const Vec scale_vec = Vec(scaling);
|
||||
for (int row = 0; row < m_size; ++row) {
|
||||
// s_i <- s_i * scale
|
||||
at::vec::map<float>(
|
||||
[scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
|
||||
|
||||
// TODO: `tanh` from torch uses sleef u10, going to be slow
|
||||
if (has_logit_cap) {
|
||||
at::vec::map<float>(
|
||||
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
|
||||
s_i + row * BLOCK_N,
|
||||
s_i + row * BLOCK_N,
|
||||
n_size);
|
||||
}
|
||||
|
||||
// m_i: max value per row
|
||||
float m_i = at::vec::reduce_all<float>(
|
||||
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size);
|
||||
m_i = std::max(m_i, m_prime[row]);
|
||||
|
||||
// m_delta <- exp(m' - m_i)
|
||||
float m_delta = std::exp(m_prime[row] - m_i);
|
||||
|
||||
// s_delta <- exp(s_i - m_i)
|
||||
at::vec::map<float>(
|
||||
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
|
||||
|
||||
// s' <- s' * m_delta + sum(s_delta)
|
||||
s_prime[row] *= m_delta;
|
||||
s_prime[row] +=
|
||||
at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size);
|
||||
|
||||
m_prime[row] = m_i;
|
||||
|
||||
// v' <- v' * m_delta
|
||||
at::vec::map<float>(
|
||||
[m_delta](Vec x) { return x * Vec(m_delta); },
|
||||
v_prime + row * head_size_v,
|
||||
v_prime + row * head_size_v,
|
||||
head_size_v);
|
||||
|
||||
// pad s_delta with 0 first and then convert to scalar_t
|
||||
fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
|
||||
copy_stub<scalar_t, BLOCK_N>(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N);
|
||||
}
|
||||
|
||||
// get value and pack
|
||||
pack_vnni2<scalar_t, index_t>(
|
||||
/* dst */ Btmp,
|
||||
/* src */ v_buffer + head_kv_id * v_strideH,
|
||||
/* ind */ req_to_token + req_pool_id * max_context_len + n,
|
||||
/* K */ n_size,
|
||||
/* N */ head_size_v,
|
||||
/* ld_src */ v_strideN,
|
||||
/* ld_dst */ head_size_v);
|
||||
|
||||
// caculate V' <- s_delta @ V + V'
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ head_size_v,
|
||||
/* K */ padded_n_size, // n_size
|
||||
/* lda */ BLOCK_N,
|
||||
/* ldb */ head_size_v,
|
||||
/* ldc */ head_size_v,
|
||||
/* add_C */ true,
|
||||
/* A */ s_delta2,
|
||||
/* B */ Btmp,
|
||||
/* C */ v_prime);
|
||||
} // loop with seq_len_prefix
|
||||
|
||||
// stage 2: compute the triangle part
|
||||
int num_keys = std::min(seq_len_extend, m + BLOCK_M);
|
||||
for (int n = 0; n < num_keys; n += BLOCK_N) {
|
||||
int n_size = std::min(BLOCK_N, num_keys - n);
|
||||
|
||||
// `n_size` is K in 2nd gemm, pad to TILE_K;
|
||||
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
|
||||
|
||||
// get key and pack
|
||||
pack_vnni<scalar_t, index_t>(
|
||||
/* dst */ Btmp,
|
||||
/* src */ k_extend + (seq_extend_start_loc + n) * ke_strideN + head_kv_id * ke_strideH,
|
||||
/* ind */ nullptr,
|
||||
/* N */ n_size,
|
||||
/* K */ head_size,
|
||||
/* ld_src */ ke_strideN,
|
||||
/* ld_dst */ BLOCK_N);
|
||||
|
||||
// calculate s_i <- Q @ K
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ head_size,
|
||||
/* lda */ q_strideM,
|
||||
/* ldb */ BLOCK_N,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ q_ptr,
|
||||
/* B */ Btmp,
|
||||
/* C */ s_i);
|
||||
|
||||
// apply causal mask
|
||||
if (num_keys - n <= BLOCK_N) {
|
||||
for (int row = 0; row < m_size; ++row) {
|
||||
int last_col = m + row - n;
|
||||
// fill [last_col + 1, n_size) to -inf
|
||||
float* row_ptr = s_i + row * BLOCK_N;
|
||||
fill_stub(row_ptr + last_col + 1, -std::numeric_limits<float>::infinity(), n_size - last_col - 1);
|
||||
}
|
||||
}
|
||||
|
||||
const Vec scale_vec = Vec(scaling);
|
||||
for (int row = 0; row < m_size; ++row) {
|
||||
// s_i <- s_i * scale
|
||||
at::vec::map<float>(
|
||||
[scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
|
||||
|
||||
// TODO: `tanh` from torch uses sleef u10, going to be slow
|
||||
if (has_logit_cap) {
|
||||
at::vec::map<float>(
|
||||
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
|
||||
s_i + row * BLOCK_N,
|
||||
s_i + row * BLOCK_N,
|
||||
n_size);
|
||||
}
|
||||
|
||||
// m_i: max value per row
|
||||
float m_i = at::vec::reduce_all<float>(
|
||||
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size);
|
||||
m_i = std::max(m_i, m_prime[row]);
|
||||
|
||||
// m_delta <- exp(m' - m_i)
|
||||
float m_delta = std::exp(m_prime[row] - m_i);
|
||||
|
||||
// s_delta <- exp(s_i - m_i)
|
||||
at::vec::map<float>(
|
||||
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
|
||||
|
||||
// s' <- s' * m_delta + sum(s_delta)
|
||||
s_prime[row] *= m_delta;
|
||||
s_prime[row] +=
|
||||
at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size);
|
||||
|
||||
m_prime[row] = m_i;
|
||||
|
||||
// v' <- v' * m_delta
|
||||
at::vec::map<float>(
|
||||
[m_delta](Vec x) { return x * Vec(m_delta); },
|
||||
v_prime + row * head_size_v,
|
||||
v_prime + row * head_size_v,
|
||||
head_size_v);
|
||||
|
||||
// pad s_delta with 0 first and then convert to scalar_t
|
||||
fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
|
||||
copy_stub<scalar_t, BLOCK_N>(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N);
|
||||
}
|
||||
|
||||
// get value and pack
|
||||
pack_vnni2<scalar_t, index_t>(
|
||||
/* dst */ Btmp,
|
||||
/* src */ v_extend + (seq_extend_start_loc + n) * ve_strideN + head_kv_id * ve_strideH,
|
||||
/* ind */ nullptr,
|
||||
/* K */ n_size,
|
||||
/* N */ head_size_v,
|
||||
/* ld_src */ ve_strideN,
|
||||
/* ld_dst */ head_size_v);
|
||||
|
||||
// caculate V' <- s_delta @ V + V'
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ head_size_v,
|
||||
/* K */ padded_n_size, // n_size
|
||||
/* lda */ BLOCK_N,
|
||||
/* ldb */ head_size_v,
|
||||
/* ldc */ head_size_v,
|
||||
/* add_C */ true,
|
||||
/* A */ s_delta2,
|
||||
/* B */ Btmp,
|
||||
/* C */ v_prime);
|
||||
} // loop with seq_len_extend
|
||||
|
||||
scalar_t* __restrict__ out_ptr = o_extend + (seq_extend_start_loc + m) * o_strideM + head_id * o_strideH;
|
||||
for (int row = 0; row < m_size; ++row) {
|
||||
float s = 1 / s_prime[row];
|
||||
copy_stub<scalar_t>(out_ptr + row * o_strideM, v_prime + row * head_size_v, s, head_size_v);
|
||||
}
|
||||
|
||||
// move to the next index
|
||||
data_index_step(bs, batches, head_id, num_heads, mb, MB);
|
||||
}
|
||||
at::native::cpublas::brgemm_release();
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||
// k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
|
||||
//
|
||||
// q_extend: [num_tokens, num_heads, head_size]
|
||||
// k_extend: [num_extend_tokens, num_heads, head_size]
|
||||
// v_extend: [num_extend_tokens, num_heads, head_size]
|
||||
// o_extend: [num_tokens, num_heads, head_size]
|
||||
// k_buffer: [max_total_num_tokens, num_heads, head_size]
|
||||
// v_buffer: [max_total_num_tokens, num_heads, head_size]
|
||||
// req_to_token: [max_num_reqs, max_context_len] int32 or int64
|
||||
// req_pool_indices: [num_seqs] int64
|
||||
// seq_lens: [num_seqs] int64
|
||||
// extend_seq_lens: [num_seqs]
|
||||
// extend_start_loc: [num_seqs]
|
||||
//
|
||||
void extend_attention_cpu(
|
||||
at::Tensor& q_extend,
|
||||
at::Tensor& k_extend,
|
||||
at::Tensor& v_extend,
|
||||
at::Tensor& o_extend,
|
||||
at::Tensor& k_buffer,
|
||||
at::Tensor& v_buffer,
|
||||
at::Tensor& req_to_token,
|
||||
at::Tensor& req_pool_indices,
|
||||
at::Tensor& seq_lens,
|
||||
at::Tensor& extend_seq_lens,
|
||||
at::Tensor& extend_start_loc,
|
||||
int64_t max_len_extend,
|
||||
double sm_scale,
|
||||
double logit_cap) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::extend_attention_cpu",
|
||||
std::vector<c10::IValue>(
|
||||
{q_extend,
|
||||
k_extend,
|
||||
v_extend,
|
||||
o_extend,
|
||||
k_buffer,
|
||||
v_buffer,
|
||||
req_to_token,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
extend_seq_lens,
|
||||
extend_start_loc}));
|
||||
|
||||
CHECK_INPUT(q_extend);
|
||||
CHECK_INPUT(o_extend);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
|
||||
|
||||
int num_seqs = seq_lens.size(0);
|
||||
int max_num_reqs = req_to_token.size(0);
|
||||
int max_context_len = req_to_token.size(1);
|
||||
int max_total_num_tokens = k_buffer.size(0);
|
||||
|
||||
int num_heads = q_extend.size(1);
|
||||
int num_heads_kv = k_extend.size(1);
|
||||
int head_size = q_extend.size(2);
|
||||
int head_size_v = v_extend.size(2);
|
||||
|
||||
// strides for k_extend and v_extend
|
||||
int ke_strideN = k_extend.stride(0);
|
||||
int ke_strideH = k_extend.stride(1);
|
||||
int ve_strideN = v_extend.stride(0);
|
||||
int ve_strideH = v_extend.stride(1);
|
||||
|
||||
// strides for k_buffer and v_buffer
|
||||
int k_strideN = k_buffer.stride(0);
|
||||
int k_strideH = k_buffer.stride(1);
|
||||
int v_strideN = v_buffer.stride(0);
|
||||
int v_strideH = v_buffer.stride(1);
|
||||
|
||||
// check sizes
|
||||
CHECK_EQ(req_pool_indices.size(0), num_seqs);
|
||||
CHECK_EQ(extend_seq_lens.size(0), num_seqs);
|
||||
CHECK_EQ(extend_start_loc.size(0), num_seqs);
|
||||
CHECK_EQ(v_extend.size(1), num_heads_kv);
|
||||
CHECK_EQ(k_buffer.size(1), v_buffer.size(1));
|
||||
|
||||
// MLA will skip prefix part
|
||||
const bool is_prefix_skipped = k_buffer.size(1) != num_heads_kv;
|
||||
|
||||
// check index data types
|
||||
const auto index_dtype = req_to_token.scalar_type();
|
||||
TORCH_CHECK(
|
||||
index_dtype == at::kInt || index_dtype == at::kLong,
|
||||
"extend: expect req_to_token to be int32 or int64, got ",
|
||||
index_dtype);
|
||||
TORCH_CHECK(seq_lens.scalar_type() == at::kLong, "extend: expect req_lens to be int64, got ", seq_lens.scalar_type());
|
||||
TORCH_CHECK(
|
||||
req_pool_indices.scalar_type() == at::kLong,
|
||||
"extend: expect req_pool_indices to be int64, got ",
|
||||
req_pool_indices.scalar_type());
|
||||
TORCH_CHECK(
|
||||
extend_seq_lens.scalar_type() == index_dtype && extend_start_loc.scalar_type() == index_dtype,
|
||||
"extend: expect extend_seq_lens and extend_start_loc to have same dtype as req_to_token.");
|
||||
|
||||
// D and DV need to be 32x as we transpose by 512-bit
|
||||
TORCH_CHECK(head_size % 32 == 0, "invalid head_size ", head_size);
|
||||
TORCH_CHECK(head_size_v % 32 == 0, "invalid head_size_v ", head_size_v);
|
||||
|
||||
// block size for query seq length
|
||||
constexpr int BLOCK_M = 32;
|
||||
// block size for key/value seq length
|
||||
constexpr int BLOCK_N = 32;
|
||||
|
||||
const int size_per_thread =
|
||||
/* s_i */ BLOCK_M * BLOCK_N * sizeof(float) +
|
||||
/* v_prime */ BLOCK_M * head_size_v * sizeof(float) +
|
||||
/* s_delta */ BLOCK_M * BLOCK_N * sizeof(uint16_t) +
|
||||
/* Btmp */ BLOCK_N * std::max(head_size, head_size_v) * sizeof(uint16_t);
|
||||
|
||||
int num_threads = at::get_num_threads();
|
||||
auto buffer = at::empty({num_threads, size_per_thread}, q_extend.options().dtype(at::kChar));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(q_extend.scalar_type(), "extend_attention_kernel", [&] {
|
||||
AT_DISPATCH_INDEX_TYPES(index_dtype, "extend_attention_indices", [&] {
|
||||
extend_attention_kernel_impl<scalar_t, index_t, BLOCK_M, BLOCK_N>(
|
||||
o_extend.data_ptr<scalar_t>(),
|
||||
q_extend.data_ptr<scalar_t>(),
|
||||
k_extend.data_ptr<scalar_t>(),
|
||||
v_extend.data_ptr<scalar_t>(),
|
||||
k_buffer.data_ptr<scalar_t>(),
|
||||
v_buffer.data_ptr<scalar_t>(),
|
||||
req_to_token.data_ptr<index_t>(),
|
||||
req_pool_indices.data_ptr<int64_t>(),
|
||||
seq_lens.data_ptr<int64_t>(),
|
||||
extend_seq_lens.data_ptr<index_t>(),
|
||||
extend_start_loc.data_ptr<index_t>(),
|
||||
buffer.data_ptr(),
|
||||
num_seqs,
|
||||
num_heads,
|
||||
num_heads_kv,
|
||||
head_size,
|
||||
head_size_v,
|
||||
ke_strideN,
|
||||
ke_strideH,
|
||||
ve_strideN,
|
||||
ve_strideH,
|
||||
k_strideN,
|
||||
k_strideH,
|
||||
v_strideN,
|
||||
v_strideH,
|
||||
sm_scale,
|
||||
logit_cap,
|
||||
max_num_reqs,
|
||||
max_context_len,
|
||||
max_total_num_tokens,
|
||||
max_len_extend,
|
||||
size_per_thread,
|
||||
is_prefix_skipped);
|
||||
});
|
||||
});
|
||||
}
|
||||
507
sgl-kernel/csrc/cpu/gemm.cpp
Normal file
507
sgl-kernel/csrc/cpu/gemm.cpp
Normal file
@@ -0,0 +1,507 @@
|
||||
#include "gemm.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// packed layout:
|
||||
// quants {N, K} int8_t
|
||||
// comp {N} int32_t
|
||||
template <int BLOCK_N>
|
||||
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
__m512i vcomp[COLS];
|
||||
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
vcomp[col] = _mm512_setzero_si512();
|
||||
}
|
||||
|
||||
const int64_t offset = BLOCK_N * K;
|
||||
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
|
||||
for (int k = 0; k < K / 4; ++k) {
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
__m512i vb = _mm512_loadu_si512((const __m512i*)(packed + k * BLOCK_N * 4 + col * 64));
|
||||
vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
|
||||
}
|
||||
}
|
||||
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
_mm512_storeu_si512((__m512i*)(packed + offset + col * 64), vcomp[col]);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "s8s8_compensation not implemented!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert to vnni format
|
||||
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
|
||||
template <typename packed_t>
|
||||
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
|
||||
const int VNNI_BLK = 2;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K / VNNI_BLK; ++k) {
|
||||
for (int d = 0; d < VNNI_BLK; ++d) {
|
||||
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
TORCH_CHECK(N == BLOCK_N);
|
||||
|
||||
const int VNNI_BLK = 4;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K / VNNI_BLK; ++k) {
|
||||
for (int d = 0; d < VNNI_BLK; ++d) {
|
||||
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
s8s8_compensation<BLOCK_N>(packed, K);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_add_stub(
|
||||
scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
const float* __restrict__ bias,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A,
|
||||
const at::BFloat16* __restrict__ B,
|
||||
at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_set1_ps(0.f);
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K2 = K >> 1;
|
||||
const int64_t lda2 = lda >> 1;
|
||||
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const float* b_ptr = reinterpret_cast<const float*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K2; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
// for COLS = 1, 3 use 256bit store
|
||||
if constexpr (COLS % 2 == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
} else {
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(C + row * ldc + col * 16), (__m256i)(_mm512_cvtneps_pbh(vc[i])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, \
|
||||
B + nb_start * 2, \
|
||||
C + mb_start * ldc + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, \
|
||||
K, \
|
||||
lda, \
|
||||
ldb, \
|
||||
ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
struct brgemm {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
if (brg) {
|
||||
brgemm<scalar_t, has_bias>::apply(A, B, C, Ctmp, bias, M, N, K, lda, ldb, ldc);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch (mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x12:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
|
||||
break;
|
||||
case 0x14:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
|
||||
break;
|
||||
// mb_size = 2
|
||||
case 0x22:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
|
||||
break;
|
||||
case 0x24:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
|
||||
break;
|
||||
// mb_size = 3
|
||||
case 0x32:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
|
||||
break;
|
||||
case 0x34:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
|
||||
break;
|
||||
// mb_size = 4
|
||||
case 0x42:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
|
||||
break;
|
||||
case 0x44:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void weight_packed_linear_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const scalar_t* __restrict__ mat2,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideM) {
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx
|
||||
const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
// for brgemm, use float32 for accumulate
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
|
||||
/* C */ out + mb_start * out_strideM + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* bias*/ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, \
|
||||
const TYPE* __restrict__ B, \
|
||||
TYPE* __restrict__ C, \
|
||||
float* __restrict__ Ctmp, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t lda, \
|
||||
int64_t ldb, \
|
||||
int64_t ldc, \
|
||||
bool brg)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight) {
|
||||
// for 3d moe weights
|
||||
// weight : [E, OC, IC]
|
||||
// w1 : [E, 2N, K]
|
||||
// w2 : [E, K, N]
|
||||
CHECK_INPUT(weight);
|
||||
|
||||
const int64_t ndim = weight.ndimension();
|
||||
TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor.");
|
||||
const auto st = weight.scalar_type();
|
||||
const int64_t E = ndim == 3 ? weight.size(0) : 1;
|
||||
const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
|
||||
const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);
|
||||
|
||||
// we handle 2 TILE_N at a time.
|
||||
TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
|
||||
TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
|
||||
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t NB = div_up(OC, BLOCK_N);
|
||||
|
||||
// use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
|
||||
auto packed_weight = at::empty({}, weight.options());
|
||||
const int64_t stride = OC * IC;
|
||||
|
||||
TORCH_CHECK(
|
||||
st == at::kBFloat16 || st == at::kHalf || st == at::kChar, "expect weight to be bfloat16, float16 or int8.");
|
||||
|
||||
CPU_DISPATCH_PACKED_TYPES(st, [&] {
|
||||
// adjust most inner dimension size
|
||||
const int packed_row_size = get_row_size<packed_t>(IC);
|
||||
auto sizes = weight.sizes().vec();
|
||||
sizes[ndim - 1] = packed_row_size;
|
||||
packed_weight.resize_(sizes);
|
||||
|
||||
const packed_t* w_data = weight.data_ptr<packed_t>();
|
||||
packed_t* packed_data = packed_weight.data_ptr<packed_t>();
|
||||
|
||||
// parallel on {E, NB}
|
||||
at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t e{0}, nb{0};
|
||||
data_index_init(begin, e, E, nb, NB);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
|
||||
int64_t n = nb * BLOCK_N;
|
||||
int64_t n_size = std::min(BLOCK_N, OC - n);
|
||||
pack_vnni<packed_t>(
|
||||
packed_data + e * OC * packed_row_size + n * packed_row_size, w_data + e * stride + n * IC, n_size, IC);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(e, E, nb, NB);
|
||||
}
|
||||
});
|
||||
});
|
||||
return packed_weight;
|
||||
}
|
||||
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat2.size(1);
|
||||
CHECK_EQ(mat1.size(1), K);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
auto out = at::empty({M, N}, mat1.options());
|
||||
|
||||
// strides
|
||||
int64_t mat1_strideM = mat1.stride(0);
|
||||
int64_t out_strideM = out.stride(0);
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] {
|
||||
weight_packed_linear_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<scalar_t>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideM,
|
||||
out_strideM);
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
130
sgl-kernel/csrc/cpu/gemm.h
Normal file
130
sgl-kernel/csrc/cpu/gemm.h
Normal file
@@ -0,0 +1,130 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
|
||||
// amx-bf16
|
||||
#define TILE_M 16
|
||||
#define TILE_N 16
|
||||
#define TILE_K 32
|
||||
|
||||
// block size for AMX gemm
|
||||
constexpr int block_size_m() {
|
||||
return 2 * TILE_M;
|
||||
}
|
||||
constexpr int block_size_n() {
|
||||
return 2 * TILE_N;
|
||||
}
|
||||
|
||||
// define threshold using brgemm (intel AMX)
|
||||
template <typename T>
|
||||
inline bool can_use_brgemm(int M);
|
||||
template <>
|
||||
inline bool can_use_brgemm<at::BFloat16>(int M) {
|
||||
return M > 4;
|
||||
}
|
||||
template <>
|
||||
inline bool can_use_brgemm<at::Half>(int M) {
|
||||
return true;
|
||||
}
|
||||
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
|
||||
template <>
|
||||
inline bool can_use_brgemm<int8_t>(int M) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// work around compiler internal error
|
||||
#define BLOCK_K 128 // 4 * TILE_K
|
||||
|
||||
// adjust leading dimension size for K
|
||||
template <typename T>
|
||||
inline int64_t get_row_size(int64_t K) {
|
||||
return K;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline int64_t get_row_size<int8_t>(int64_t K) {
|
||||
return K + sizeof(int32_t);
|
||||
}
|
||||
|
||||
inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
|
||||
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
|
||||
}
|
||||
|
||||
// pack weight to vnni format
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
// moe implementations for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
uint8_t* __restrict__ A_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// shared expert implememntation for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
489
sgl-kernel/csrc/cpu/gemm_int8.cpp
Normal file
489
sgl-kernel/csrc/cpu/gemm_int8.cpp
Normal file
@@ -0,0 +1,489 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512i va;
|
||||
__m512i vb[COLS];
|
||||
__m512i vc[ROWS * COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 vd0;
|
||||
__m512 vd1[COLS];
|
||||
|
||||
// oops! 4x4 spills but luckly we use 4x2
|
||||
__m512 vbias[COLS];
|
||||
|
||||
// [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
//
|
||||
// avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate:
|
||||
//
|
||||
// a * b = (a + 128) * b - 128 * b
|
||||
// s s u s u s
|
||||
//
|
||||
// 1) 128 * b is pre-computed when packing B to vnni formats
|
||||
// 2) a + 128 is fused when dynamically quantize A
|
||||
//
|
||||
auto loadc = [&](auto i) { vc[i] = _mm512_set1_epi32(0); };
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr (col == 0) {
|
||||
vd0 = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp per 2 vectors
|
||||
// also load bias if any
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16);
|
||||
vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
|
||||
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
|
||||
if constexpr (has_bias) {
|
||||
vbias[col + 0] = _mm512_loadu_ps(bias + col * 16);
|
||||
vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
__m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0]));
|
||||
__m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1]));
|
||||
if constexpr (has_bias) {
|
||||
vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]);
|
||||
vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]);
|
||||
} else {
|
||||
vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]);
|
||||
vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]);
|
||||
}
|
||||
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0)));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, \
|
||||
B + nb_start * 4, \
|
||||
C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, \
|
||||
Bs + nb_start, \
|
||||
Bcomp + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, \
|
||||
K, \
|
||||
lda, \
|
||||
ldb, \
|
||||
ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch (mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x12:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
|
||||
break;
|
||||
case 0x14:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
|
||||
break;
|
||||
// mb_size = 2
|
||||
case 0x22:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
|
||||
break;
|
||||
case 0x24:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
|
||||
break;
|
||||
// mb_size = 3
|
||||
case 0x32:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
|
||||
break;
|
||||
case 0x34:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
|
||||
break;
|
||||
// mb_size = 4
|
||||
case 0x42:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
|
||||
break;
|
||||
case 0x44:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void int8_scaled_mm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const uint8_t* __restrict__ mat1,
|
||||
const int8_t* __restrict__ mat2,
|
||||
const float* __restrict__ scales1,
|
||||
const float* __restrict__ scales2,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
|
||||
const bool use_brgemm = false;
|
||||
|
||||
// K + 4 after compensation
|
||||
const int64_t packed_row_size = get_row_size<int8_t>(K);
|
||||
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
// for brgemm, use int32_t for accumulate
|
||||
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * K,
|
||||
/* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
|
||||
/* C */ out + mb_start * N + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* As */ scales1 + mb_start,
|
||||
/* Bs */ scales2 + nb_start,
|
||||
/* bias*/ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ N,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const uint8_t* __restrict__ A, \
|
||||
const int8_t* __restrict__ B, \
|
||||
TYPE* __restrict__ C, \
|
||||
int32_t* __restrict__ Ctmp, \
|
||||
const float* __restrict__ As, \
|
||||
const float* __restrict__ Bs, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t lda, \
|
||||
int64_t ldb, \
|
||||
int64_t ldc, \
|
||||
bool brg)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
|
||||
RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A}));
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(A);
|
||||
CHECK_DIM(2, A);
|
||||
|
||||
int64_t M = A.size(0);
|
||||
int64_t K = A.size(1);
|
||||
int64_t lda = A.stride(0);
|
||||
|
||||
const auto st = A.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "per_token_quant_int8: expect A to be bfloat16 or half.");
|
||||
|
||||
auto Aq = at::empty({M, K}, A.options().dtype(at::kByte));
|
||||
auto As = at::empty({M}, A.options().dtype(at::kFloat));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] {
|
||||
uint8_t* __restrict__ Aq_data = Aq.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = As.data_ptr<float>();
|
||||
const scalar_t* __restrict__ A_data = A.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
|
||||
}
|
||||
});
|
||||
});
|
||||
return std::make_tuple(Aq, As);
|
||||
}
|
||||
|
||||
// weight : static, per-channel, symmetric
|
||||
// activation : dynamic, per-token, symmetric
|
||||
//
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// scales1 : [M]
|
||||
// scales2 : [N]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor int8_scaled_mm_cpu(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales1,
|
||||
at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales1);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales1.numel(), M);
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8.");
|
||||
TORCH_CHECK(
|
||||
scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat,
|
||||
"int8_scaled_mm: expect scales to be float32.");
|
||||
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] {
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<uint8_t>(),
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
scales1.data_ptr<float>(),
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`
|
||||
at::Tensor int8_scaled_mm_with_quant(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
int64_t lda = mat1.stride(0);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
const auto st = mat1.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
|
||||
TORCH_CHECK(st == out_dtype, "int8_scaled_mm_with_quant: expect A has same dtype with out_dtype.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm_with_quant: expect mat2 to be int8.");
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "int8_scaled_mm_with_quant: expect scales to be float32.");
|
||||
|
||||
const int64_t buffer_size = M * K + M * sizeof(float);
|
||||
auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
|
||||
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K));
|
||||
const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
|
||||
}
|
||||
});
|
||||
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
Aq_data,
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
As_data,
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
120
sgl-kernel/csrc/cpu/interface.cpp
Normal file
120
sgl-kernel/csrc/cpu/interface.cpp
Normal file
@@ -0,0 +1,120 @@
|
||||
#include <ATen/record_function.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include "shm.h"
|
||||
|
||||
// Communication settings
|
||||
static int world_rank = -1;
|
||||
static int world_size = -1;
|
||||
|
||||
static bool is_initialized = false;
|
||||
|
||||
static bool all_ranks_local_p = false;
|
||||
|
||||
void initialize(int size, int rank) {
|
||||
if (is_initialized) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check whether all ranks is on the same physical machine.
|
||||
// If true, we will use an SHM based low latency allreduce
|
||||
|
||||
auto ls_string = std::getenv("LOCAL_SIZE");
|
||||
int ls = 0;
|
||||
if (ls_string != NULL) {
|
||||
ls = std::stoi(std::getenv("LOCAL_SIZE"));
|
||||
}
|
||||
|
||||
if (size >= 1 && size == ls) {
|
||||
all_ranks_local_p = true;
|
||||
}
|
||||
|
||||
world_size = size;
|
||||
world_rank = rank;
|
||||
is_initialized = true;
|
||||
|
||||
auto addr_string = std::getenv("MASTER_ADDR");
|
||||
if (addr_string == NULL) {
|
||||
addr_string = "";
|
||||
}
|
||||
auto port_string = std::getenv("MASTER_PORT");
|
||||
if (port_string == NULL) {
|
||||
port_string = "";
|
||||
}
|
||||
|
||||
if (all_ranks_local_p) {
|
||||
shm_initialize(size, rank, addr_string, port_string);
|
||||
}
|
||||
}
|
||||
|
||||
void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op) {
|
||||
RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));
|
||||
|
||||
static py::object ReduceOp = py::module_::import("torch.distributed").attr("ReduceOp");
|
||||
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
|
||||
TORCH_CHECK(py::int_(op.attr("value")) == ReduceOpSum, "Only torch.distributed.ReduceOp.SUM is supported");
|
||||
|
||||
auto numel = data.numel();
|
||||
|
||||
int data_size = 0;
|
||||
bool data_type_fallback = false;
|
||||
|
||||
switch (data.scalar_type()) {
|
||||
case c10::ScalarType::BFloat16:
|
||||
data_size = numel * 2;
|
||||
break;
|
||||
case c10::ScalarType::Float:
|
||||
data_size = numel * 4;
|
||||
break;
|
||||
default:
|
||||
data_type_fallback = true;
|
||||
}
|
||||
|
||||
if (data_type_fallback || !all_ranks_local_p) {
|
||||
// Fallback to torch distributed allreduce
|
||||
std::vector<torch::Tensor> tensors = {data};
|
||||
process_group->allreduce(tensors)->wait();
|
||||
} else {
|
||||
all_reduce_outer_loop(data, numel, data_size);
|
||||
}
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim) {
|
||||
RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));
|
||||
|
||||
auto numel = data.numel();
|
||||
|
||||
int data_size = 0;
|
||||
bool data_type_fallback = false;
|
||||
|
||||
switch (data.scalar_type()) {
|
||||
case c10::ScalarType::BFloat16:
|
||||
data_size = numel * 2;
|
||||
break;
|
||||
case c10::ScalarType::Float:
|
||||
data_size = numel * 4;
|
||||
break;
|
||||
default:
|
||||
data_type_fallback = true;
|
||||
}
|
||||
if (dim < 0) {
|
||||
dim += data.dim();
|
||||
}
|
||||
if (data_type_fallback || !all_ranks_local_p) {
|
||||
// Fallback to torch distributed allreduce
|
||||
std::vector<std::vector<torch::Tensor>> output_tensors(1);
|
||||
auto world_size = process_group->getSize();
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
output_tensors[0].push_back(torch::empty_like(data));
|
||||
}
|
||||
std::vector<torch::Tensor> input_tensors = {data};
|
||||
process_group->allgather(output_tensors, input_tensors)->wait();
|
||||
return torch::cat(output_tensors[0], dim).contiguous();
|
||||
}
|
||||
std::vector<int64_t> result_shape = data.sizes().vec();
|
||||
result_shape[dim] *= world_size;
|
||||
torch::Tensor result_tensor = torch::empty(result_shape, data.options());
|
||||
return all_gather(result_tensor, data, dim, numel, data_size);
|
||||
}
|
||||
1247
sgl-kernel/csrc/cpu/moe.cpp
Normal file
1247
sgl-kernel/csrc/cpu/moe.cpp
Normal file
File diff suppressed because it is too large
Load Diff
830
sgl-kernel/csrc/cpu/moe_int8.cpp
Normal file
830
sgl-kernel/csrc/cpu/moe_int8.cpp
Normal file
@@ -0,0 +1,830 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += Vec::size()) {
|
||||
Vec data = Vec::loadu(input + d);
|
||||
data.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void copy_stub<uint8_t>(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) {
|
||||
// size might be 64x + 32
|
||||
std::memcpy(out, input, size * sizeof(uint8_t));
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec weight_vec = fVec(weight);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) * weight_vec;
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
// acc from [topk, K] to [K]
|
||||
template <typename scalar_t>
|
||||
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
if (topk == 1) {
|
||||
// do copy for topk = 1
|
||||
copy_stub(out, input, K);
|
||||
} else {
|
||||
// do sum for topk != 1
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= K - kVecSize; d += kVecSize) {
|
||||
fVec sum_fvec0 = fVec(0.f);
|
||||
fVec sum_fvec1 = fVec(0.f);
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
bVec x_bvec = bVec::loadu(input + t * K + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
sum_fvec0 += x_fvec0;
|
||||
sum_fvec1 += x_fvec1;
|
||||
}
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
|
||||
out_bvec.store(out + d);
|
||||
}
|
||||
for (; d < K; ++d) {
|
||||
float sum_val = 0.f;
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
sum_val += static_cast<float>(input[t * K + d]);
|
||||
}
|
||||
out[d] = static_cast<scalar_t>(sum_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// out = input + input2 * scale
|
||||
template <typename scalar_t>
|
||||
inline void add_mul_stub(
|
||||
scalar_t* __restrict__ out,
|
||||
const float* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2,
|
||||
float scale,
|
||||
int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec s_vec = fVec(scale);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec x0 = fVec::loadu(input + d);
|
||||
fVec x1 = fVec::loadu(input + d + fVec::size());
|
||||
|
||||
bVec y_bvec = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
|
||||
|
||||
x0 = x0 + y0 * s_vec;
|
||||
x1 = x1 + y1 * s_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
/// gemm for w13
|
||||
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B0,
|
||||
const int8_t* __restrict__ B1,
|
||||
scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs0,
|
||||
const float* __restrict__ Bs1,
|
||||
const int32_t* __restrict__ Bcomp0,
|
||||
const int32_t* __restrict__ Bcomp1,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni<at::BFloat16, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B0,
|
||||
const int8_t* __restrict__ B1,
|
||||
at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs0,
|
||||
const float* __restrict__ Bs1,
|
||||
const int32_t* __restrict__ Bcomp0,
|
||||
const int32_t* __restrict__ Bcomp1,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512i va;
|
||||
__m512i vb0[COLS];
|
||||
__m512i vb1[COLS];
|
||||
__m512i vc0[ROWS * COLS];
|
||||
__m512i vc1[ROWS * COLS];
|
||||
__m512i vcomp0[COLS];
|
||||
__m512i vcomp1[COLS];
|
||||
__m512 vas;
|
||||
__m512 vbs0[COLS];
|
||||
__m512 vbs1[COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
vc0[i] = _mm512_set1_epi32(0);
|
||||
vc1[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b0_ptr = reinterpret_cast<const int32_t*>(B0);
|
||||
const int32_t* b1_ptr = reinterpret_cast<const int32_t*>(B1);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16);
|
||||
vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16);
|
||||
}
|
||||
vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]);
|
||||
vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto scalec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr (col == 0) {
|
||||
vas = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp
|
||||
if constexpr (row == 0) {
|
||||
vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
|
||||
vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
|
||||
vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
|
||||
vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
|
||||
}
|
||||
__m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col]));
|
||||
__m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col]));
|
||||
vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, vas), vbs0[col]));
|
||||
vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, vas), vbs1[col]));
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(scalec);
|
||||
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
const Vec one = Vec(1.f);
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]);
|
||||
Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]);
|
||||
Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]);
|
||||
Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]);
|
||||
// silu
|
||||
x0 = x0 / (one + x0.neg().exp_u20());
|
||||
x1 = x1 / (one + x1.neg().exp_u20());
|
||||
// mul
|
||||
x0 = x0 * y0;
|
||||
x1 = x1 * y1;
|
||||
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0))));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_vnni<scalar_t, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, \
|
||||
B0 + nb_start * 4, \
|
||||
B1 + nb_start * 4, \
|
||||
C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, \
|
||||
Bs0 + nb_start, \
|
||||
Bs1 + nb_start, \
|
||||
Bcomp0 + nb_start, \
|
||||
Bcomp1 + nb_start, \
|
||||
K, \
|
||||
lda, \
|
||||
ldb, \
|
||||
ldc);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B0,
|
||||
const int8_t* __restrict__ B1,
|
||||
scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs0,
|
||||
const float* __restrict__ Bs1,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
|
||||
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
|
||||
|
||||
// pattern: 1-(2+2)-(8+8)
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 32;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch (mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12:
|
||||
LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32);
|
||||
break;
|
||||
case 0x22:
|
||||
LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32);
|
||||
break;
|
||||
case 0x32:
|
||||
LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32);
|
||||
break;
|
||||
case 0x42:
|
||||
LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// gemm for w2
|
||||
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni2 {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
float* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
float* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512i va;
|
||||
__m512i vb[COLS];
|
||||
__m512i vc[ROWS * COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 vas;
|
||||
__m512 vbs[COLS];
|
||||
|
||||
auto loadc = [&](auto i) { vc[i] = _mm512_set1_epi32(0); };
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
|
||||
}
|
||||
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr (col == 0) {
|
||||
vas = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp per 2 vectors
|
||||
// also load bias if any
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16);
|
||||
vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
|
||||
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
|
||||
}
|
||||
}
|
||||
__m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col]));
|
||||
x = _mm512_mul_ps(_mm512_mul_ps(x, vas), vbs[col]);
|
||||
_mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_vnni2<scalar_t, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, \
|
||||
B + nb_start * 4, \
|
||||
C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, \
|
||||
Bs + nb_start, \
|
||||
Bcomp + nb_start, \
|
||||
K, \
|
||||
lda, \
|
||||
ldb, \
|
||||
ldc);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
float* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch (mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12:
|
||||
LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32);
|
||||
break;
|
||||
case 0x22:
|
||||
LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32);
|
||||
break;
|
||||
case 0x32:
|
||||
LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32);
|
||||
break;
|
||||
case 0x42:
|
||||
LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
uint8_t* __restrict__ A_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad) {
|
||||
// handle 2 tiles per block
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 0: quantize input to uint8, [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(Aq_tmp + m * K, As_tmp[m], input + m * K, K);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
|
||||
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// strides for w1: [E, 2N, K]
|
||||
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
|
||||
|
||||
// K and N are packed for int8
|
||||
const int64_t packed_K = get_row_size<int8_t>(K);
|
||||
const int64_t packed_N = get_row_size<int8_t>(N);
|
||||
|
||||
const int64_t stride_e = 2 * N * packed_K;
|
||||
const int64_t stride_n = packed_K;
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
|
||||
alignas(64) float As[BLOCK_M];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
|
||||
copy_stub(A + m * K, Aq_tmp + index * K, K);
|
||||
As[m] = As_tmp[index];
|
||||
}
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + offset * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: quantize ic1 to uint8, [M * topk, N]
|
||||
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(Aq_tmp + m * N, As_tmp[m], ic1 + m * N, N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [E, K, N] as [E, OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
const int64_t stride_e2 = OC * packed_N;
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copy A to tmp buffer again
|
||||
const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N;
|
||||
const float* __restrict__ As = As_tmp + offsets[mb];
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m];
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// stage 3: out = intermediate_cache2.sum(dim=1)
|
||||
// from [M, topk, K] to [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE) \
|
||||
template void fused_experts_int8_kernel_impl<TYPE>( \
|
||||
TYPE* __restrict__ output, \
|
||||
TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ ic2, \
|
||||
uint8_t* __restrict__ A_tmp, \
|
||||
float* __restrict__ C_tmp, \
|
||||
uint8_t* __restrict__ Aq_tmp, \
|
||||
float* __restrict__ As_tmp, \
|
||||
const TYPE* __restrict__ input, \
|
||||
const int8_t* __restrict__ packed_w1, \
|
||||
const int8_t* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, \
|
||||
const float* __restrict__ w2s, \
|
||||
const float* __restrict__ topk_weights, \
|
||||
const int32_t* __restrict__ sorted_ids, \
|
||||
const int32_t* __restrict__ expert_ids, \
|
||||
const int32_t* __restrict__ offsets, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t E, \
|
||||
int64_t topk, \
|
||||
int64_t num_tokens_post_pad)
|
||||
|
||||
INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_MOE_INT8_TEMPLATE(at::Half);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
// handle 2 tiles per block
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 0: quantize input to uint8, [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(Aq_tmp + m * K, As_tmp[m], input + m * K, K);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
|
||||
|
||||
// K and N are packed for int8
|
||||
const int64_t packed_K = get_row_size<int8_t>(K);
|
||||
const int64_t packed_N = get_row_size<int8_t>(N);
|
||||
const int64_t stride_n = packed_K;
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
|
||||
// A shape [m_size, K]
|
||||
const uint8_t* A = Aq_tmp + mb * BLOCK_M * K;
|
||||
const float* As = As_tmp + mb * BLOCK_M;
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: quantize ic1 to uint8, [M * topk, N]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(Aq_tmp + m * N, As_tmp[m], ic1 + m * N, N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [K, N] as [OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// A shape [m_size, IC]
|
||||
const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N;
|
||||
const float* __restrict__ As = As_tmp + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \
|
||||
template void shared_expert_int8_kernel_impl<TYPE>( \
|
||||
TYPE* __restrict__ output, \
|
||||
TYPE* __restrict__ ic1, \
|
||||
float* __restrict__ C_tmp, \
|
||||
uint8_t* __restrict__ Aq_tmp, \
|
||||
float* __restrict__ As_tmp, \
|
||||
const TYPE* __restrict__ input, \
|
||||
const int8_t* __restrict__ packed_w1, \
|
||||
const int8_t* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, \
|
||||
const float* __restrict__ w2s, \
|
||||
const TYPE* __restrict__ fused_experts_out, \
|
||||
float routed_scaling_factor, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K)
|
||||
|
||||
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half);
|
||||
221
sgl-kernel/csrc/cpu/norm.cpp
Normal file
221
sgl-kernel/csrc/cpu/norm.cpp
Normal file
@@ -0,0 +1,221 @@
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// NB: avoid using `at::vec::map<>` on bfloat16 or half
|
||||
template <typename scalar_t>
|
||||
void rmsnorm_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
const scalar_t* __restrict__ input,
|
||||
const scalar_t* __restrict__ weight,
|
||||
int64_t batch_size,
|
||||
int64_t hidden_size,
|
||||
float eps = 1e-5) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
|
||||
constexpr int kVecSize = bVec::size();
|
||||
at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
// local ptrs
|
||||
scalar_t* __restrict__ out_ptr = output + i * hidden_size;
|
||||
const scalar_t* __restrict__ input_ptr = input + i * hidden_size;
|
||||
|
||||
fVec sum_fvec = fVec(float(0));
|
||||
float sum_val = float(0);
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input_ptr + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
sum_fvec += x_fvec0 * x_fvec0;
|
||||
sum_fvec += x_fvec1 * x_fvec1;
|
||||
}
|
||||
#pragma GCC unroll 4
|
||||
for (; d < hidden_size; ++d) {
|
||||
float x_val = static_cast<float>(input_ptr[d]);
|
||||
sum_val += x_val * x_val;
|
||||
}
|
||||
|
||||
sum_val += vec_reduce_sum(sum_fvec);
|
||||
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
|
||||
const fVec scale_fvec = fVec(rsqrt_var);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input_ptr + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
bVec w_bvec = bVec::loadu(weight + d);
|
||||
fVec w_fvec0, w_fvec1;
|
||||
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
|
||||
|
||||
x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
|
||||
x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
|
||||
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
||||
out_bvec.store(out_ptr + d);
|
||||
}
|
||||
#pragma GCC unroll 4
|
||||
for (; d < hidden_size; ++d) {
|
||||
float x_val = static_cast<float>(input_ptr[d]);
|
||||
float w_val = static_cast<float>(weight[d]);
|
||||
out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var * w_val);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_add_rmsnorm_kernel_impl(
|
||||
scalar_t* __restrict__ input,
|
||||
scalar_t* __restrict__ residual,
|
||||
const scalar_t* __restrict__ weight,
|
||||
float* __restrict__ buffer,
|
||||
int64_t batch_size,
|
||||
int64_t hidden_size,
|
||||
float eps = 1e-5) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
|
||||
constexpr int kVecSize = bVec::size();
|
||||
at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
float* __restrict__ buffer_ptr = buffer + tid * hidden_size;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
// local ptrs
|
||||
scalar_t* __restrict__ input_ptr = input + i * hidden_size;
|
||||
scalar_t* __restrict__ residual_ptr = residual + i * hidden_size;
|
||||
|
||||
fVec sum_fvec = fVec(float(0));
|
||||
float sum_val = float(0);
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input_ptr + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
bVec r_bvec = bVec::loadu(residual_ptr + d);
|
||||
fVec r_fvec0, r_fvec1;
|
||||
std::tie(r_fvec0, r_fvec1) = at::vec::convert_to_float(r_bvec);
|
||||
|
||||
x_fvec0 += r_fvec0;
|
||||
x_fvec1 += r_fvec1;
|
||||
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
||||
out_bvec.store(residual_ptr + d);
|
||||
|
||||
sum_fvec += x_fvec0 * x_fvec0;
|
||||
sum_fvec += x_fvec1 * x_fvec1;
|
||||
|
||||
x_fvec0.store(buffer_ptr + d);
|
||||
x_fvec1.store(buffer_ptr + d + fVec::size());
|
||||
}
|
||||
#pragma GCC unroll 4
|
||||
for (; d < hidden_size; ++d) {
|
||||
float x_val = static_cast<float>(input_ptr[d]);
|
||||
float r_val = static_cast<float>(residual_ptr[d]);
|
||||
|
||||
x_val += r_val;
|
||||
residual_ptr[d] = static_cast<scalar_t>(x_val);
|
||||
|
||||
sum_val += x_val * x_val;
|
||||
buffer_ptr[d] = x_val;
|
||||
}
|
||||
|
||||
sum_val += vec_reduce_sum(sum_fvec);
|
||||
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
|
||||
const fVec scale_fvec = fVec(rsqrt_var);
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
||||
fVec x_fvec0 = fVec::loadu(buffer_ptr + d);
|
||||
fVec x_fvec1 = fVec::loadu(buffer_ptr + d + fVec::size());
|
||||
|
||||
bVec w_bvec = bVec::loadu(weight + d);
|
||||
fVec w_fvec0, w_fvec1;
|
||||
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
|
||||
|
||||
x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
|
||||
x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
|
||||
bVec x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
||||
x_bvec.store(input_ptr + d);
|
||||
}
|
||||
#pragma GCC unroll 4
|
||||
for (; d < hidden_size; ++d) {
|
||||
float x_val = buffer_ptr[d] * rsqrt_var * static_cast<float>(weight[d]);
|
||||
input_ptr[d] = x_val;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// input : {batch_size, hidden_size}
|
||||
// weight: {hidden_size}
|
||||
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
|
||||
RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));
|
||||
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(weight);
|
||||
CHECK_DIM(2, input);
|
||||
CHECK_DIM(1, weight);
|
||||
CHECK_EQ(input.size(1), weight.size(0));
|
||||
int64_t batch_size = input.size(0);
|
||||
int64_t hidden_size = input.size(1);
|
||||
at::Tensor output = at::empty_like(input);
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] {
|
||||
rmsnorm_kernel_impl<scalar_t>(
|
||||
output.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(),
|
||||
batch_size,
|
||||
hidden_size,
|
||||
eps);
|
||||
});
|
||||
return output;
|
||||
}
|
||||
|
||||
// input : {batch_size, hidden_size}
|
||||
// residual: {batch_size, hidden_size}
|
||||
// weight : {hidden_size}
|
||||
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) {
|
||||
RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight}));
|
||||
CHECK_INPUT(input);
|
||||
CHECK_INPUT(residual);
|
||||
CHECK_INPUT(weight);
|
||||
CHECK_DIM(2, input);
|
||||
CHECK_DIM(2, residual);
|
||||
CHECK_DIM(1, weight);
|
||||
CHECK_EQ(input.size(0), residual.size(0));
|
||||
CHECK_EQ(input.size(1), residual.size(1));
|
||||
CHECK_EQ(input.size(1), weight.size(0));
|
||||
int64_t batch_size = input.size(0);
|
||||
int64_t hidden_size = input.size(1);
|
||||
|
||||
// allocate temp buffer to store x in float32 per thread
|
||||
// TODO: implement a singleton for context
|
||||
int64_t num_threads = at::get_num_threads();
|
||||
at::Tensor buffer = at::empty({num_threads, hidden_size}, input.options().dtype(at::kFloat));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fused_add_rmsnorm_kernel", [&] {
|
||||
fused_add_rmsnorm_kernel_impl<scalar_t>(
|
||||
input.data_ptr<scalar_t>(),
|
||||
residual.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(),
|
||||
buffer.data_ptr<float>(),
|
||||
batch_size,
|
||||
hidden_size,
|
||||
eps);
|
||||
});
|
||||
}
|
||||
504
sgl-kernel/csrc/cpu/qkv_proj.cpp
Normal file
504
sgl-kernel/csrc/cpu/qkv_proj.cpp
Normal file
@@ -0,0 +1,504 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// [NOTE]: Fused kernel for QKV projection with weight absorption and RoPE
|
||||
//
|
||||
// 1. `q_a_proj` and `kv_a_proj_with_mqa` fused into one gemm,
|
||||
// otherwise we need to split IC for the 2nd gemm.
|
||||
// 2. `q_a_layernorm` and `kv_a_layernorm` fused into one parallel loop.
|
||||
// 3. k_input and v_input share the same storage, the torch API did
|
||||
// this in `set_kv_buffer`. No additional memory movement.
|
||||
//
|
||||
|
||||
// [C0, C1] = A @ [B0, B1]
|
||||
template <typename scalar_t>
|
||||
void segment_gemm_kernel_impl(
|
||||
scalar_t* __restrict__ C0,
|
||||
scalar_t* __restrict__ C1,
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B0,
|
||||
const scalar_t* __restrict__ B1,
|
||||
int64_t M,
|
||||
int64_t N0,
|
||||
int64_t N1,
|
||||
int64_t K) {
|
||||
// convert_weight_packed make sure N0 and N1 are 32x
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB0 = div_up(N0, BLOCK_N);
|
||||
const int64_t NB1 = div_up(N1, BLOCK_N);
|
||||
const int64_t NB = NB0 + NB1;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
|
||||
|
||||
// parallel on [MB, NB0 + NB1]
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
// for brgemm, use float32 for accumulate
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = BLOCK_N;
|
||||
|
||||
const scalar_t* __restrict__ B = nb < NB0 ? B0 : B1;
|
||||
scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
|
||||
int64_t ldc = nb < NB0 ? N0 : N1;
|
||||
int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A + mb_start * K,
|
||||
/* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
|
||||
/* C */ C + mb_start * ldc + local_nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ ldc,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// [C0, C1] = A @ [B0, B1]
|
||||
template <typename scalar_t>
|
||||
void segment_gemm_kernel_impl(
|
||||
scalar_t* __restrict__ C0,
|
||||
scalar_t* __restrict__ C1,
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B0,
|
||||
const int8_t* __restrict__ B1,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs0,
|
||||
const float* __restrict__ Bs1,
|
||||
int64_t M,
|
||||
int64_t N0,
|
||||
int64_t N1,
|
||||
int64_t K) {
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB0 = div_up(N0, BLOCK_N);
|
||||
const int64_t NB1 = div_up(N1, BLOCK_N);
|
||||
const int64_t NB = NB0 + NB1;
|
||||
|
||||
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
|
||||
const bool use_brgemm = false;
|
||||
|
||||
// K + 4 after compensation
|
||||
const int64_t packed_row_size = get_row_size<int8_t>(K);
|
||||
|
||||
// parallel on [MB, NB0 + NB1]
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
// for brgemm, use float32 for accumulate
|
||||
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = BLOCK_N;
|
||||
|
||||
const int8_t* __restrict__ B = nb < NB0 ? B0 : B1;
|
||||
const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
|
||||
scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
|
||||
int64_t ldc = nb < NB0 ? N0 : N1;
|
||||
int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A + mb_start * K,
|
||||
/* B */ B + local_nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
|
||||
/* C */ C + mb_start * ldc + local_nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* As */ As + mb_start,
|
||||
/* Bs */ Bs + local_nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ ldc,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline float reduce(const scalar_t* __restrict__ x, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
fVec sum_fvec = fVec(float(0));
|
||||
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += bVec::size()) {
|
||||
bVec x_bvec = bVec::loadu(x + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
sum_fvec += x_fvec0 * x_fvec0;
|
||||
sum_fvec += x_fvec1 * x_fvec1;
|
||||
}
|
||||
return vec_reduce_sum(sum_fvec);
|
||||
}
|
||||
|
||||
// map2 from aten functional doesn't have fast bf16->fp32 conversion
|
||||
template <typename scalar_t>
|
||||
inline void map2(scalar_t* y, const scalar_t* x, const scalar_t* __restrict__ w, float scale, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
fVec scale_fvec = fVec(scale);
|
||||
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += bVec::size()) {
|
||||
bVec x_bvec = bVec::loadu(x + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
bVec w_bvec = bVec::loadu(w + d);
|
||||
fVec w_fvec0, w_fvec1;
|
||||
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
|
||||
x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
|
||||
x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
||||
out_bvec.store(y + d);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void rms_norm_kernel_impl(
|
||||
scalar_t* __restrict__ input0,
|
||||
scalar_t* __restrict__ input1,
|
||||
const scalar_t* __restrict__ weight0,
|
||||
const scalar_t* __restrict__ weight1,
|
||||
int64_t M,
|
||||
int64_t N0,
|
||||
int64_t N1,
|
||||
int64_t stride1,
|
||||
float eps = 1e-5) {
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
scalar_t* x0 = input0 + m * N0;
|
||||
scalar_t* x1 = input1 + m * stride1;
|
||||
float scale0 = reduce(x0, N0);
|
||||
float scale1 = reduce(x1, N1);
|
||||
scale0 = float(1) / std::sqrt(scale0 / N0 + eps);
|
||||
scale1 = float(1) / std::sqrt(scale1 / N1 + eps);
|
||||
map2(x0, x0, weight0, scale0, N0);
|
||||
map2(x1, x1, weight1, scale1, N1);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void rotary(const scalar_t* input, scalar_t* out, const scalar_t* cos, const scalar_t* sin, int64_t size) {
|
||||
TORCH_CHECK(false, "rotary scalar path not implemented.");
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <>
|
||||
inline void rotary<at::BFloat16>(
|
||||
const at::BFloat16* input, at::BFloat16* out, const at::BFloat16* cos, const at::BFloat16* sin, int64_t size) {
|
||||
// permute indices
|
||||
const __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
||||
const __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1);
|
||||
const __m512i idy1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
|
||||
const __m512i idy2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8);
|
||||
|
||||
// rotary dim is 64, just 2 iters
|
||||
#pragma GCC unroll 2
|
||||
for (int64_t d = 0; d < size; d += 32) {
|
||||
int64_t d2 = d >> 1;
|
||||
// load coefs
|
||||
__m512 vcos = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(cos + d2)));
|
||||
__m512 vsin = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(sin + d2)));
|
||||
// load input
|
||||
__m512i a16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(input + d));
|
||||
__m512 a = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
|
||||
__m512 b = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
|
||||
// from [16, 2] to [2, 16]
|
||||
__m512 in1 = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
|
||||
__m512 in2 = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
|
||||
// out1 = in1 * cos - in2 * sin;
|
||||
// out2 = in2 * cos + in1 * sin
|
||||
__m512 out1 = _mm512_sub_ps(_mm512_mul_ps(in1, vcos), _mm512_mul_ps(in2, vsin));
|
||||
__m512 out2 = _mm512_add_ps(_mm512_mul_ps(in2, vcos), _mm512_mul_ps(in1, vsin));
|
||||
// from [2, 16] to [16, 2]
|
||||
a = _mm512_mask_permutex2var_ps(out1, 0xffff, idy1, out2);
|
||||
b = _mm512_mask_permutex2var_ps(out1, 0xffff, idy2, out2);
|
||||
|
||||
_mm512_storeu_si512(reinterpret_cast<__m512i*>((out + d)), (__m512i)(_mm512_cvtne2ps_pbh(b, a)));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename scalar_t>
|
||||
void rotary_emb_kernel_impl(
|
||||
scalar_t* q_pe_out,
|
||||
scalar_t* k_pe_out,
|
||||
const scalar_t* q_pe,
|
||||
const scalar_t* k_pe,
|
||||
const int64_t* pos,
|
||||
const scalar_t* cos_sin,
|
||||
int64_t num_seqs,
|
||||
int64_t num_heads,
|
||||
int64_t rotary_dim,
|
||||
int64_t q_strideB,
|
||||
int64_t q_strideH,
|
||||
int64_t k_strideB,
|
||||
int64_t oq_strideB,
|
||||
int64_t oq_strideH,
|
||||
int64_t ok_strideB) {
|
||||
TORCH_CHECK(rotary_dim % 32 == 0, "rotary_dim is not 32x.");
|
||||
const int64_t rotary_offset = rotary_dim / 2;
|
||||
|
||||
// parallel on [num_seqs, num_heads + 1]
|
||||
// top [num_heads] handle q_pe and bottom [1] handle k_pe
|
||||
at::parallel_for(0, num_seqs * (num_heads + 1), GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
||||
int64_t seq{0}, head_id{0};
|
||||
data_index_init(begin, seq, num_seqs, head_id, num_heads + 1);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
// get cos and sin cache ptr
|
||||
int64_t index = pos[seq];
|
||||
const scalar_t* cos = cos_sin + index * rotary_dim;
|
||||
const scalar_t* sin = cos + rotary_offset;
|
||||
|
||||
const scalar_t* input =
|
||||
(head_id < num_heads) ? q_pe + seq * q_strideB + head_id * q_strideH : k_pe + seq * k_strideB;
|
||||
scalar_t* out =
|
||||
(head_id < num_heads) ? q_pe_out + seq * oq_strideB + head_id * oq_strideH : k_pe_out + seq * ok_strideB;
|
||||
rotary<scalar_t>(input, out, cos, sin, rotary_dim);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(seq, num_seqs, head_id, num_heads + 1);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
extern at::Tensor
|
||||
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni);
|
||||
|
||||
extern at::Tensor int8_scaled_mm_with_quant(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni);
|
||||
|
||||
extern void
|
||||
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale);
|
||||
|
||||
// NB: shapes in DeepDeek R1
|
||||
//
|
||||
// hidden_states : [num_seqs, hidden_size] [1, 7168]
|
||||
// q_a_proj_weight : [q_lora_rank, hidden_size] [1536, 7168]
|
||||
// q_b_proj_weight : [num_heads * qk_head_dim, q_lora_rank] [4224, 1536]
|
||||
// kv_a_proj_weight : [kv_lora_rank + qk_rope_head_dim, hidden_size] [576, 7168]
|
||||
// w_kc : [num_heads, kv_lora_rank, qk_nope_head_dim] [22, 512, 128]
|
||||
// q_a_layernorm_weight : [q_lora_rank] [1536]
|
||||
// kv_a_layernorm_weight : [kv_lora_rank] [512]
|
||||
//
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& q_a_proj_weight,
|
||||
at::Tensor& q_b_proj_weight,
|
||||
at::Tensor& kv_a_proj_weight,
|
||||
at::Tensor& w_kc,
|
||||
at::Tensor& q_a_layernorm_weight,
|
||||
at::Tensor& kv_a_layernorm_weight,
|
||||
at::Tensor& positions,
|
||||
at::Tensor& cos_sin_cache,
|
||||
double eps,
|
||||
bool use_int8_w8a8,
|
||||
std::optional<at::Tensor>& q_a_proj_scale,
|
||||
std::optional<at::Tensor>& q_b_proj_scale,
|
||||
std::optional<at::Tensor>& kv_a_proj_scale,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::qkv_proj_with_rope",
|
||||
std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc}));
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_INPUT(hidden_states);
|
||||
CHECK_INPUT(positions);
|
||||
CHECK_INPUT(cos_sin_cache);
|
||||
CHECK_EQ(q_a_layernorm_weight.scalar_type(), st);
|
||||
CHECK_EQ(kv_a_layernorm_weight.scalar_type(), st);
|
||||
CHECK_EQ(positions.scalar_type(), at::kLong);
|
||||
CHECK_EQ(cos_sin_cache.scalar_type(), st);
|
||||
CHECK_DIM(2, hidden_states);
|
||||
CHECK_DIM(3, w_kc);
|
||||
CHECK_DIM(1, q_a_layernorm_weight);
|
||||
CHECK_DIM(1, kv_a_layernorm_weight);
|
||||
CHECK_DIM(1, positions);
|
||||
CHECK_DIM(2, cos_sin_cache);
|
||||
|
||||
// skip contiguous checks for weights, expect prepacked
|
||||
TORCH_CHECK(is_vnni, "qkv_proj_with_rope: expect weights are prepacked!");
|
||||
|
||||
int64_t num_seqs = hidden_states.size(0);
|
||||
int64_t hidden_size = hidden_states.size(1);
|
||||
int64_t q_lora_rank = q_a_proj_weight.size(0);
|
||||
int64_t num_heads = w_kc.size(0);
|
||||
int64_t kv_lora_rank = w_kc.size(1);
|
||||
int64_t qk_head_dim = q_b_proj_weight.size(0) / num_heads;
|
||||
int64_t qk_nope_head_dim = w_kc.size(2);
|
||||
int64_t qk_rope_head_dim = kv_a_proj_weight.size(0) - kv_lora_rank;
|
||||
int64_t rotary_dim = cos_sin_cache.size(1);
|
||||
|
||||
CHECK_EQ(positions.numel(), num_seqs);
|
||||
CHECK_EQ(rotary_dim, qk_rope_head_dim);
|
||||
CHECK_EQ(q_a_layernorm_weight.numel(), q_lora_rank);
|
||||
CHECK_EQ(kv_a_layernorm_weight.numel(), kv_lora_rank);
|
||||
|
||||
// check the packed dimension
|
||||
CHECK_EQ(q_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
|
||||
CHECK_EQ(q_b_proj_weight.size(1), get_row_size(q_lora_rank, use_int8_w8a8));
|
||||
CHECK_EQ(kv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
|
||||
|
||||
if (use_int8_w8a8) {
|
||||
TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for int8 w8a8.");
|
||||
TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8.");
|
||||
TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for int8 w8a8.");
|
||||
}
|
||||
|
||||
// outputs and temp buffer
|
||||
const auto options = hidden_states.options();
|
||||
auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options);
|
||||
auto k_input = at::empty({num_seqs, 1, kv_lora_rank + qk_rope_head_dim}, options);
|
||||
auto v_input = k_input.narrow(-1, 0, kv_lora_rank);
|
||||
|
||||
// outputs of q_a_proj and q_b_proj
|
||||
auto qa = at::empty({num_seqs, q_lora_rank}, options);
|
||||
|
||||
// stage 1: q_a_proj and kv_a_proj
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "qkv_proj_kernel_impl", [&] {
|
||||
if (use_int8_w8a8) {
|
||||
auto q_a_proj_s = q_a_proj_scale.value();
|
||||
auto kv_a_proj_s = kv_a_proj_scale.value();
|
||||
TORCH_CHECK(q_a_proj_s.numel() == q_lora_rank);
|
||||
TORCH_CHECK(kv_a_proj_s.numel() == kv_lora_rank + qk_rope_head_dim);
|
||||
|
||||
auto buffer = at::empty({num_seqs * hidden_size + num_seqs * 4}, options.dtype(at::kByte));
|
||||
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = (float*)((void*)(Aq_data + num_seqs * hidden_size));
|
||||
const scalar_t* __restrict__ A_data = hidden_states.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, num_seqs, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(Aq_data + m * hidden_size, As_data[m], A_data + m * hidden_size, hidden_size);
|
||||
}
|
||||
});
|
||||
|
||||
segment_gemm_kernel_impl<scalar_t>(
|
||||
qa.data_ptr<scalar_t>(),
|
||||
k_input.data_ptr<scalar_t>(),
|
||||
Aq_data,
|
||||
q_a_proj_weight.data_ptr<int8_t>(),
|
||||
kv_a_proj_weight.data_ptr<int8_t>(),
|
||||
As_data,
|
||||
q_a_proj_s.data_ptr<float>(),
|
||||
kv_a_proj_s.data_ptr<float>(),
|
||||
num_seqs,
|
||||
q_lora_rank,
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
hidden_size);
|
||||
} else {
|
||||
segment_gemm_kernel_impl<scalar_t>(
|
||||
qa.data_ptr<scalar_t>(),
|
||||
k_input.data_ptr<scalar_t>(),
|
||||
hidden_states.data_ptr<scalar_t>(),
|
||||
q_a_proj_weight.data_ptr<scalar_t>(),
|
||||
kv_a_proj_weight.data_ptr<scalar_t>(),
|
||||
num_seqs,
|
||||
q_lora_rank,
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
hidden_size);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: apply rmsnorm inplace
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rms_norm_kernel_impl", [&] {
|
||||
rms_norm_kernel_impl<scalar_t>(
|
||||
qa.data_ptr<scalar_t>(),
|
||||
v_input.data_ptr<scalar_t>(),
|
||||
q_a_layernorm_weight.data_ptr<scalar_t>(),
|
||||
kv_a_layernorm_weight.data_ptr<scalar_t>(),
|
||||
num_seqs,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
eps);
|
||||
});
|
||||
|
||||
// stage 3: q_b_proj
|
||||
at::Tensor qb;
|
||||
std::optional<at::Tensor> bias;
|
||||
if (use_int8_w8a8) {
|
||||
qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni);
|
||||
} else {
|
||||
qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni);
|
||||
}
|
||||
qb.as_strided_({num_seqs, num_heads, qk_head_dim}, {num_heads * qk_head_dim, qk_head_dim, 1});
|
||||
|
||||
// stage 4: bmm
|
||||
std::optional<at::Tensor> scale;
|
||||
auto q_nope = qb.narrow(2, 0, qk_nope_head_dim).transpose_(0, 1);
|
||||
auto q_nope_out = q_input.narrow(2, 0, kv_lora_rank).transpose_(0, 1);
|
||||
bmm_cpu(q_nope_out, q_nope, w_kc, is_vnni, scale);
|
||||
|
||||
// stage 5: rope
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rotary_emb_kernel_impl", [&] {
|
||||
rotary_emb_kernel_impl<scalar_t>(
|
||||
q_input.data_ptr<scalar_t>() + kv_lora_rank,
|
||||
k_input.data_ptr<scalar_t>() + kv_lora_rank,
|
||||
qb.data_ptr<scalar_t>() + qk_nope_head_dim,
|
||||
k_input.data_ptr<scalar_t>() + kv_lora_rank,
|
||||
positions.data_ptr<int64_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
num_seqs,
|
||||
num_heads,
|
||||
rotary_dim,
|
||||
num_heads * qk_head_dim,
|
||||
qk_head_dim,
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
num_heads * (kv_lora_rank + qk_rope_head_dim),
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
kv_lora_rank + qk_rope_head_dim);
|
||||
});
|
||||
|
||||
return std::make_tuple(q_input, k_input, v_input);
|
||||
}
|
||||
129
sgl-kernel/csrc/cpu/rope.cpp
Normal file
129
sgl-kernel/csrc/cpu/rope.cpp
Normal file
@@ -0,0 +1,129 @@
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
void rope_kernel_impl(
|
||||
scalar_t* __restrict__ q_pe_out,
|
||||
scalar_t* __restrict__ k_pe_out,
|
||||
int64_t* __restrict__ t_pos,
|
||||
scalar_t* __restrict__ q_pe,
|
||||
scalar_t* __restrict__ k_pe,
|
||||
scalar_t* __restrict__ t_emb_pos,
|
||||
int64_t seq_len,
|
||||
int64_t num_head,
|
||||
int64_t rotary_dim,
|
||||
int64_t HR,
|
||||
int64_t q_pe_stride_s,
|
||||
int64_t out_stride_qs,
|
||||
int64_t out_stride_ks,
|
||||
int64_t HK,
|
||||
int64_t k_pe_stride_s,
|
||||
int64_t q_pe_stride_n,
|
||||
int64_t out_stride_qn) {
|
||||
int64_t COFF = HR / 2;
|
||||
at::parallel_for(0, seq_len * num_head, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
||||
int64_t seq{0}, head_id{0};
|
||||
data_index_init(begin, seq, seq_len, head_id, num_head);
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t in_offset_q = seq * q_pe_stride_s + head_id * q_pe_stride_n;
|
||||
int64_t out_offset_q = seq * out_stride_qs + head_id * out_stride_qn;
|
||||
int64_t out_offset_k = seq * out_stride_ks;
|
||||
int64_t p = 0;
|
||||
scalar_t* sin_start = nullptr;
|
||||
scalar_t* cos_start = nullptr;
|
||||
// step 0) get the rotary position embedding for the current position
|
||||
p = t_pos[seq];
|
||||
sin_start = t_emb_pos + p * HR + COFF;
|
||||
cos_start = t_emb_pos + p * HR;
|
||||
// step 1) apply_rotary_pos_emb for the rotary_dim elements in every
|
||||
// head of query/key
|
||||
for (int64_t h = 0; h < rotary_dim; h += 2) {
|
||||
scalar_t cos = cos_start[h >> 1];
|
||||
scalar_t sin = sin_start[h >> 1];
|
||||
scalar_t in1 = q_pe[in_offset_q + h];
|
||||
scalar_t in2 = q_pe[in_offset_q + h + 1];
|
||||
scalar_t out1 = in1 * cos - in2 * sin;
|
||||
scalar_t out2 = in2 * cos + in1 * sin;
|
||||
q_pe_out[out_offset_q + h] = out1;
|
||||
q_pe_out[out_offset_q + h + 1] = out2;
|
||||
}
|
||||
for (int64_t h = 0; h < HK; h += 2) {
|
||||
scalar_t cos = cos_start[h >> 1];
|
||||
scalar_t sin = sin_start[h >> 1];
|
||||
int64_t k_pe_offset = seq * k_pe_stride_s;
|
||||
scalar_t in1_k = k_pe[k_pe_offset + h];
|
||||
scalar_t in2_k = k_pe[k_pe_offset + h + 1];
|
||||
scalar_t out1_k = in1_k * cos - in2_k * sin;
|
||||
scalar_t out2_k = in2_k * cos + in1_k * sin;
|
||||
k_pe_out[out_offset_k + h] = out1_k;
|
||||
k_pe_out[out_offset_k + h + 1] = out2_k;
|
||||
}
|
||||
// move to the next index
|
||||
data_index_step(seq, seq_len, head_id, num_head);
|
||||
}
|
||||
});
|
||||
}
|
||||
} // namespace
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::rotary_position_embedding_cpu", std::vector<c10::IValue>({t_pos, q_pe, k_pe, t_emb_pos}));
|
||||
CHECK_INPUT(t_pos);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_pe);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_pe);
|
||||
CHECK_INPUT(t_emb_pos);
|
||||
CHECK_DIM(1, t_pos);
|
||||
CHECK_DIM(3, q_pe);
|
||||
CHECK_DIM(3, k_pe);
|
||||
CHECK_DIM(2, t_emb_pos);
|
||||
|
||||
int64_t seq_len = q_pe.size(0);
|
||||
int64_t num_head = q_pe.size(1);
|
||||
int64_t rotary_dim = q_pe.size(2);
|
||||
int64_t HK = k_pe.size(2);
|
||||
int64_t HR = t_emb_pos.size(1);
|
||||
CHECK_EQ(HR, rotary_dim);
|
||||
CHECK_EQ(k_pe.size(0), seq_len);
|
||||
CHECK_EQ(k_pe.size(1), 1);
|
||||
CHECK_EQ(t_pos.size(0), seq_len);
|
||||
CHECK_EQ(HK, rotary_dim);
|
||||
|
||||
at::Tensor q_pe_out = at::empty_like(q_pe);
|
||||
at::Tensor k_pe_out = at::empty_like(k_pe);
|
||||
int64_t q_pe_stride_s = q_pe.stride(0);
|
||||
int64_t q_pe_stride_n = q_pe.stride(1);
|
||||
int64_t k_pe_stride_s = k_pe.stride(0);
|
||||
int64_t out_stride_qs = q_pe_out.stride(0);
|
||||
int64_t out_stride_qn = q_pe_out.stride(1);
|
||||
int64_t out_stride_ks = k_pe_out.stride(0);
|
||||
|
||||
const auto input_dtype = q_pe.scalar_type();
|
||||
TORCH_CHECK(t_pos.scalar_type() == at::kLong, "expect positions to be int64, got ", t_pos.scalar_type());
|
||||
TORCH_CHECK(input_dtype == k_pe.scalar_type(), "q_pe and k_pe must have the same data type");
|
||||
TORCH_CHECK(input_dtype == t_emb_pos.scalar_type(), "q_pe and t_emb_pos must have the same data type");
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_position_embedding_cpu", [&] {
|
||||
rope_kernel_impl<scalar_t>(
|
||||
q_pe_out.data_ptr<scalar_t>(),
|
||||
k_pe_out.data_ptr<scalar_t>(),
|
||||
t_pos.data_ptr<int64_t>(),
|
||||
q_pe.data_ptr<scalar_t>(),
|
||||
k_pe.data_ptr<scalar_t>(),
|
||||
t_emb_pos.data_ptr<scalar_t>(),
|
||||
seq_len,
|
||||
num_head,
|
||||
rotary_dim,
|
||||
HR,
|
||||
q_pe_stride_s,
|
||||
out_stride_qs,
|
||||
out_stride_ks,
|
||||
HK,
|
||||
k_pe_stride_s,
|
||||
q_pe_stride_n,
|
||||
out_stride_qn);
|
||||
});
|
||||
return std::make_tuple(q_pe_out, k_pe_out);
|
||||
}
|
||||
659
sgl-kernel/csrc/cpu/shm.cpp
Normal file
659
sgl-kernel/csrc/cpu/shm.cpp
Normal file
@@ -0,0 +1,659 @@
|
||||
#include "shm.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <errno.h>
|
||||
#include <fcntl.h>
|
||||
#include <immintrin.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
|
||||
// states for collectives
|
||||
enum coll_state {
|
||||
coll_begin = 0,
|
||||
coll_allreduce_naive__copy_in_done,
|
||||
coll_allreduce_naive__reduce_done,
|
||||
// alternative state when allreduce is working on alternative buffer
|
||||
// of the double buffer.
|
||||
coll_alt1_allreduce_naive__copy_in_done,
|
||||
coll_alt2_allreduce_naive__copy_in_done,
|
||||
coll_alt1_allreduce_naive__reduce_done,
|
||||
coll_allgather_naive__copy_in_done,
|
||||
coll_alt1_allgather_naive__copy_in_done,
|
||||
coll_alt2_allgather_naive__copy_in_done,
|
||||
};
|
||||
|
||||
// SHM building blocks
|
||||
struct SharedData {
|
||||
const char* name;
|
||||
int descriptor;
|
||||
void* bytes;
|
||||
size_t nbytes;
|
||||
};
|
||||
|
||||
void shared_open(SharedData* data, const char* name, size_t nbytes) {
|
||||
int d = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR);
|
||||
if (d != -1) {
|
||||
void* bytes = mmap(NULL, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, d, 0);
|
||||
data->name = name;
|
||||
data->descriptor = d;
|
||||
data->bytes = bytes;
|
||||
data->nbytes = nbytes;
|
||||
} else {
|
||||
if (errno != ENOENT) {
|
||||
// don't print if shm can not be found because we want to loop over from
|
||||
// caller again until the other ranks created the shm
|
||||
printf("shared_open %s failed, errno=%d\n", name, errno);
|
||||
}
|
||||
data->descriptor = -1;
|
||||
}
|
||||
}
|
||||
|
||||
void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) {
|
||||
int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR);
|
||||
if (d != -1) {
|
||||
if (nbytes = write(d, bytes, nbytes)) {
|
||||
shared_open(data, name, nbytes);
|
||||
}
|
||||
} else {
|
||||
printf("shared_create %s failed\n", name);
|
||||
}
|
||||
}
|
||||
|
||||
static int world_size;
|
||||
|
||||
// SHM based allreduce helper functions
|
||||
// buffer that holds shm name
|
||||
#define NAME_BUF_SIZE 1000
|
||||
#define MAX_BUF_SIZE 1048576 * 32
|
||||
#define NAIVE_ALLREDUCE_THRESHOLD 1048576
|
||||
#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer"
|
||||
struct allreduce_workspace {
|
||||
enum coll_state states[2]; // idx=0 -- state for symmetric_naive_all_reduce
|
||||
// idx=1 -- state for distributed_naive_all_reduce
|
||||
// double buffer to avoid syncing between rounds
|
||||
// offset=0 -- 2*NAIVE_ALLREDUCE_THRESHOLD : buffer for
|
||||
// symmetric_naive_all_reduce after that : buffer for
|
||||
// distributed_naive_all_reduce
|
||||
char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE];
|
||||
};
|
||||
|
||||
#define BUFFER0_OFFSET(current_buffer) current_buffer* NAIVE_ALLREDUCE_THRESHOLD
|
||||
#define BUFFER1_OFFSET(current_buffer) 2 * NAIVE_ALLREDUCE_THRESHOLD + current_buffer* MAX_BUF_SIZE
|
||||
|
||||
struct allreduce_workspace** workspace;
|
||||
|
||||
// buffer for small messages, double buffer
|
||||
char** symmetric_buffer[2];
|
||||
// buffer for large messages, double buffer
|
||||
char** distributed_buffer[2];
|
||||
|
||||
void wait_buffer_state_until_2(int index, enum coll_state state0, enum coll_state state1, int state_group) {
|
||||
volatile enum coll_state* state_ptr = &(workspace[index]->states[state_group]);
|
||||
|
||||
while (1) {
|
||||
volatile enum coll_state cur_state = *state_ptr;
|
||||
if (cur_state == state0 || cur_state == state1) break;
|
||||
}
|
||||
}
|
||||
|
||||
__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
|
||||
inline __m512 cvt_bf16_to_fp32(const __m256i src) {
|
||||
auto y = _mm512_cvtepu16_epi32(src);
|
||||
return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2));
|
||||
}
|
||||
|
||||
inline __m256i cvt_fp32_to_bf16(const __m512 src) __attribute__((target("avx512bw")));
|
||||
inline __m256i cvt_fp32_to_bf16(const __m512 src) {
|
||||
__m512i value = _mm512_castps_si512(src);
|
||||
__m512i nan = _mm512_set1_epi32(0xffff);
|
||||
auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
|
||||
__m512i ones = _mm512_set1_epi32(0x1);
|
||||
__m512i vec_bias = _mm512_set1_epi32(0x7fff);
|
||||
// uint32_t lsb = (input >> 16) & 1;
|
||||
auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
|
||||
// uint32_t rounding_bias = 0x7fff + lsb;
|
||||
t_value = _mm512_add_epi32(t_value, vec_bias);
|
||||
// input += rounding_bias;
|
||||
t_value = _mm512_add_epi32(t_value, value);
|
||||
// input = input >> 16;
|
||||
t_value = _mm512_srli_epi32(t_value, 16);
|
||||
// Check NaN before converting back to bf16
|
||||
t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
|
||||
return _mm512_cvtusepi32_epi16(t_value);
|
||||
}
|
||||
|
||||
__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
|
||||
inline __m512 cvt_fp16_to_fp32(const __m256i src) {
|
||||
return _mm512_cvtph_ps(src);
|
||||
}
|
||||
|
||||
inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw")));
|
||||
inline __m256i cvt_fp32_to_fp16(const __m512 src) {
|
||||
return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
}
|
||||
|
||||
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
|
||||
__attribute__((target("avx512bw")));
|
||||
|
||||
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
|
||||
__attribute__((target("avx512bw")));
|
||||
|
||||
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
|
||||
__attribute__((target("avx512bw")));
|
||||
|
||||
void reduce_all_buffers(
|
||||
int start_elements,
|
||||
int num_elements,
|
||||
c10::ScalarType scalar_type,
|
||||
int to_buffer_idx,
|
||||
char* to_buffer,
|
||||
char** buffers) {
|
||||
switch (scalar_type) {
|
||||
case c10::ScalarType::BFloat16:
|
||||
reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers);
|
||||
break;
|
||||
case c10::ScalarType::Half:
|
||||
reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers);
|
||||
break;
|
||||
case c10::ScalarType::Float:
|
||||
reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers);
|
||||
break;
|
||||
default:
|
||||
assert(!"Should not get here");
|
||||
}
|
||||
}
|
||||
|
||||
#define CVT_ADD_BF16(x) \
|
||||
do { \
|
||||
auto in##x##_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
|
||||
inout_val = _mm512_add_ps(inout_val, in##x##_val); \
|
||||
} while (0)
|
||||
|
||||
// Reduce functions down below use vectorized algorithm, the number of bytes
|
||||
// processed each iteration depends on vector length. 256bit vector ==> 32
|
||||
// bytes, 512bit vector ==> 64 bytes If you change implementation of
|
||||
// reduce_bf16_buffers, etc. , check whether this number needs to be changed
|
||||
#define VECTOR_LENGTH_IN_BYTES 32
|
||||
|
||||
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
|
||||
const int element_size = 2;
|
||||
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
|
||||
int main_elements = num_elements - (num_elements % vector_length);
|
||||
int remain_elements = num_elements % vector_length;
|
||||
|
||||
// process aligned part
|
||||
#pragma omp parallel for
|
||||
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
|
||||
i += VECTOR_LENGTH_IN_BYTES) {
|
||||
auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
|
||||
switch (world_size) {
|
||||
case 16:
|
||||
CVT_ADD_BF16(15);
|
||||
case 15:
|
||||
CVT_ADD_BF16(14);
|
||||
case 14:
|
||||
CVT_ADD_BF16(13);
|
||||
case 13:
|
||||
CVT_ADD_BF16(12);
|
||||
case 12:
|
||||
CVT_ADD_BF16(11);
|
||||
case 11:
|
||||
CVT_ADD_BF16(10);
|
||||
case 10:
|
||||
CVT_ADD_BF16(9);
|
||||
case 9:
|
||||
CVT_ADD_BF16(8);
|
||||
case 8:
|
||||
CVT_ADD_BF16(7);
|
||||
case 7:
|
||||
CVT_ADD_BF16(6);
|
||||
case 6:
|
||||
CVT_ADD_BF16(5);
|
||||
case 5:
|
||||
CVT_ADD_BF16(4);
|
||||
case 4:
|
||||
CVT_ADD_BF16(3);
|
||||
case 3:
|
||||
CVT_ADD_BF16(2);
|
||||
case 2:
|
||||
CVT_ADD_BF16(1);
|
||||
case 1:
|
||||
break;
|
||||
default:
|
||||
for (int j = 1; j < world_size; j++) {
|
||||
auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
|
||||
inout_val = _mm512_add_ps(inout_val, in_val);
|
||||
}
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val));
|
||||
}
|
||||
|
||||
// process remaining part
|
||||
int i = (start_elements + main_elements) * element_size;
|
||||
while (remain_elements > 0) {
|
||||
float val = 0.0f;
|
||||
for (int j = 0; j < world_size; j++) {
|
||||
val += *(at::BFloat16*)(buffers[j] + i);
|
||||
}
|
||||
*(at::BFloat16*)(to_buffer + i) = val;
|
||||
remain_elements--;
|
||||
i += element_size;
|
||||
}
|
||||
}
|
||||
|
||||
#define CVT_ADD_FP16(x) \
|
||||
do { \
|
||||
auto in##x##_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
|
||||
inout_val = _mm512_add_ps(inout_val, in##x##_val); \
|
||||
} while (0)
|
||||
|
||||
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
|
||||
const int element_size = 2;
|
||||
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
|
||||
int main_elements = num_elements - (num_elements % vector_length);
|
||||
int remain_elements = num_elements % vector_length;
|
||||
|
||||
// process aligned part
|
||||
#pragma omp parallel for
|
||||
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
|
||||
i += VECTOR_LENGTH_IN_BYTES) {
|
||||
auto inout_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
|
||||
switch (world_size) {
|
||||
case 16:
|
||||
CVT_ADD_FP16(15);
|
||||
case 15:
|
||||
CVT_ADD_FP16(14);
|
||||
case 14:
|
||||
CVT_ADD_FP16(13);
|
||||
case 13:
|
||||
CVT_ADD_FP16(12);
|
||||
case 12:
|
||||
CVT_ADD_FP16(11);
|
||||
case 11:
|
||||
CVT_ADD_FP16(10);
|
||||
case 10:
|
||||
CVT_ADD_FP16(9);
|
||||
case 9:
|
||||
CVT_ADD_FP16(8);
|
||||
case 8:
|
||||
CVT_ADD_FP16(7);
|
||||
case 7:
|
||||
CVT_ADD_FP16(6);
|
||||
case 6:
|
||||
CVT_ADD_FP16(5);
|
||||
case 5:
|
||||
CVT_ADD_FP16(4);
|
||||
case 4:
|
||||
CVT_ADD_FP16(3);
|
||||
case 3:
|
||||
CVT_ADD_FP16(2);
|
||||
case 2:
|
||||
CVT_ADD_FP16(1);
|
||||
case 1:
|
||||
break;
|
||||
default:
|
||||
for (int j = 1; j < world_size; j++) {
|
||||
auto in_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
|
||||
inout_val = _mm512_add_ps(inout_val, in_val);
|
||||
}
|
||||
}
|
||||
_mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val));
|
||||
}
|
||||
|
||||
// process remaining part
|
||||
int i = (start_elements + main_elements) * element_size;
|
||||
while (remain_elements > 0) {
|
||||
float val = 0.0f;
|
||||
for (int j = 0; j < world_size; j++) {
|
||||
val += *(at::Half*)(buffers[j] + i);
|
||||
}
|
||||
*(at::Half*)(to_buffer + i) = val;
|
||||
remain_elements--;
|
||||
i += element_size;
|
||||
}
|
||||
}
|
||||
|
||||
#define CVT_ADD_F32(x) \
|
||||
do { \
|
||||
auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \
|
||||
inout_val = _mm256_add_ps(inout_val, in##x##_val); \
|
||||
} while (0)
|
||||
|
||||
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
|
||||
const int element_size = 4;
|
||||
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
|
||||
int main_elements = num_elements - (num_elements % vector_length);
|
||||
int remain_elements = num_elements % vector_length;
|
||||
|
||||
// process aligned part
|
||||
#pragma omp parallel for
|
||||
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
|
||||
i += VECTOR_LENGTH_IN_BYTES) {
|
||||
auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i));
|
||||
switch (world_size) {
|
||||
case 16:
|
||||
CVT_ADD_F32(15);
|
||||
case 15:
|
||||
CVT_ADD_F32(14);
|
||||
case 14:
|
||||
CVT_ADD_F32(13);
|
||||
case 13:
|
||||
CVT_ADD_F32(12);
|
||||
case 12:
|
||||
CVT_ADD_F32(11);
|
||||
case 11:
|
||||
CVT_ADD_F32(10);
|
||||
case 10:
|
||||
CVT_ADD_F32(9);
|
||||
case 9:
|
||||
CVT_ADD_F32(8);
|
||||
case 8:
|
||||
CVT_ADD_F32(7);
|
||||
case 7:
|
||||
CVT_ADD_F32(6);
|
||||
case 6:
|
||||
CVT_ADD_F32(5);
|
||||
case 5:
|
||||
CVT_ADD_F32(4);
|
||||
case 4:
|
||||
CVT_ADD_F32(3);
|
||||
case 3:
|
||||
CVT_ADD_F32(2);
|
||||
case 2:
|
||||
CVT_ADD_F32(1);
|
||||
case 1:
|
||||
break;
|
||||
default:
|
||||
for (int j = 1; j < world_size; j++) {
|
||||
auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i));
|
||||
inout_val = _mm256_add_ps(inout_val, in_val);
|
||||
}
|
||||
}
|
||||
_mm256_storeu_ps((float*)(to_buffer + i), inout_val);
|
||||
}
|
||||
|
||||
// process remaining part
|
||||
int i = (start_elements + main_elements) * element_size;
|
||||
while (remain_elements > 0) {
|
||||
float val = 0.0f;
|
||||
for (int j = 0; j < world_size; j++) {
|
||||
val += *(float*)(buffers[j] + i);
|
||||
}
|
||||
*(float*)(to_buffer + i) = val;
|
||||
remain_elements--;
|
||||
i += element_size;
|
||||
}
|
||||
}
|
||||
|
||||
static bool is_initialized = false;
|
||||
static int world_rank;
|
||||
|
||||
void shm_initialize(int size, int rank, char* addr_string, char* port_string) {
|
||||
if (is_initialized) {
|
||||
return;
|
||||
}
|
||||
is_initialized = true;
|
||||
|
||||
world_size = size;
|
||||
world_rank = rank;
|
||||
|
||||
char shm_name_prefix[NAME_BUF_SIZE];
|
||||
char shm_name[NAME_BUF_SIZE];
|
||||
snprintf(shm_name_prefix, NAME_BUF_SIZE, "%s_%d_%s_%s", SHM_BUFFER_NAME, getuid(), addr_string, port_string);
|
||||
// create shared workspace for SHM based allreduce
|
||||
SharedData allreduce_buffer;
|
||||
// allocate workspace_buf for current rank
|
||||
struct allreduce_workspace* workspace_buf;
|
||||
struct allreduce_workspace* workspace_buf_other;
|
||||
workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace));
|
||||
snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank);
|
||||
shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace));
|
||||
workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes;
|
||||
workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done;
|
||||
workspace_buf->states[1] = coll_begin;
|
||||
|
||||
// create the workspace pointer list
|
||||
workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*));
|
||||
symmetric_buffer[0] = (char**)malloc(size * sizeof(char**));
|
||||
symmetric_buffer[1] = (char**)malloc(size * sizeof(char**));
|
||||
distributed_buffer[0] = (char**)malloc(size * sizeof(char**));
|
||||
distributed_buffer[1] = (char**)malloc(size * sizeof(char**));
|
||||
|
||||
// map shm of all ranks
|
||||
for (int i = 0; i < size; i++) {
|
||||
if (i != rank) {
|
||||
snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i);
|
||||
// printf("open %s, %d\n", shm_name, rank);
|
||||
do {
|
||||
shared_open(&allreduce_buffer, shm_name, sizeof(struct allreduce_workspace));
|
||||
} while (allreduce_buffer.descriptor == -1 && errno == ENOENT);
|
||||
workspace_buf_other = (struct allreduce_workspace*)allreduce_buffer.bytes;
|
||||
workspace[i] = workspace_buf_other;
|
||||
} else {
|
||||
workspace[i] = workspace_buf;
|
||||
}
|
||||
symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0);
|
||||
symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1);
|
||||
distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0);
|
||||
distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1);
|
||||
}
|
||||
}
|
||||
|
||||
static void parallel_memcpy(void* to, void* from, size_t n_bytes) __attribute__((target("avx512bw")));
|
||||
static void parallel_memcpy(void* to, void* from, size_t n_bytes) {
|
||||
auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES);
|
||||
// process aligned part
|
||||
#pragma omp parallel for
|
||||
for (int i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) {
|
||||
auto val = _mm256_loadu_si256((__m256i*)((char*)from + i));
|
||||
_mm256_storeu_si256((__m256i*)((char*)to + i), val);
|
||||
}
|
||||
|
||||
// process remaining part
|
||||
for (int i = aligned_bytes; i < n_bytes; i++) {
|
||||
*((char*)to + i) = *((char*)from + i);
|
||||
}
|
||||
}
|
||||
|
||||
#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod))
|
||||
#define rank_mod(rank) positive_mod(rank, world_size)
|
||||
size_t slice_size(size_t chunk_el, int slice_idx) {
|
||||
size_t slice_size = chunk_el / world_size;
|
||||
return slice_idx == world_size - 1 ? slice_size + (chunk_el % world_size) : slice_size;
|
||||
}
|
||||
|
||||
char* slice_data(char* data_ptr, size_t chunk_el, int el_size, int slice_idx) {
|
||||
size_t slice_size = chunk_el / world_size;
|
||||
size_t el_offset = slice_size * slice_idx;
|
||||
return data_ptr + el_offset * el_size;
|
||||
}
|
||||
|
||||
size_t slice_el_start(size_t chunk_el, int slice_idx) {
|
||||
size_t slice_size = chunk_el / world_size;
|
||||
return slice_size * slice_idx;
|
||||
}
|
||||
|
||||
void symmetric_naive_all_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
|
||||
const int state_group = 0;
|
||||
static int current_buffer = 0;
|
||||
static int state_idx = 0;
|
||||
|
||||
enum coll_state copy_current, copy_next;
|
||||
|
||||
switch (state_idx) {
|
||||
case 0:
|
||||
copy_current = coll_allreduce_naive__copy_in_done;
|
||||
copy_next = coll_alt1_allreduce_naive__copy_in_done;
|
||||
break;
|
||||
case 1:
|
||||
copy_current = coll_alt1_allreduce_naive__copy_in_done;
|
||||
copy_next = coll_alt2_allreduce_naive__copy_in_done;
|
||||
break;
|
||||
case 2:
|
||||
copy_current = coll_alt2_allreduce_naive__copy_in_done;
|
||||
copy_next = coll_allreduce_naive__copy_in_done;
|
||||
break;
|
||||
default:
|
||||
assert(!"Should not get here.");
|
||||
}
|
||||
state_idx = (state_idx + 1) % 3;
|
||||
|
||||
parallel_memcpy(symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->states[state_group] = copy_current;
|
||||
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
// wait until the other rank copy the buffer
|
||||
if (i != world_rank) {
|
||||
wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
|
||||
}
|
||||
}
|
||||
|
||||
// each rank reduce the buffer independently so therre is no need for
|
||||
// synchronization afterward
|
||||
reduce_all_buffers(0, chunk_el, scalar_type, world_rank, data_ptr, symmetric_buffer[current_buffer]);
|
||||
|
||||
// switch buffer
|
||||
current_buffer = 1 - current_buffer;
|
||||
}
|
||||
|
||||
// naive allreduce distributed, each rank do naive reduce on its slice
|
||||
void distributed_naive_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
|
||||
const int state_group = 1;
|
||||
static int current_buffer = 0;
|
||||
static int state_idx = 0;
|
||||
|
||||
enum coll_state copy_current, copy_next, reduce_current;
|
||||
|
||||
// similar to symmetric_naive_allreduce, but here we only need two sets of
|
||||
// states, because distributed naive reduce has two barriers in the algorithm
|
||||
switch (state_idx) {
|
||||
case 0:
|
||||
copy_current = coll_allreduce_naive__copy_in_done;
|
||||
reduce_current = coll_allreduce_naive__reduce_done;
|
||||
copy_next = coll_alt1_allreduce_naive__copy_in_done;
|
||||
break;
|
||||
case 1:
|
||||
copy_current = coll_alt1_allreduce_naive__copy_in_done;
|
||||
reduce_current = coll_alt1_allreduce_naive__reduce_done;
|
||||
copy_next = coll_allreduce_naive__copy_in_done;
|
||||
break;
|
||||
default:
|
||||
assert(!"Should not get here.");
|
||||
}
|
||||
state_idx = (state_idx + 1) % 2;
|
||||
|
||||
int data_size = chunk_size / chunk_el;
|
||||
parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->states[state_group] = copy_current;
|
||||
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
// wait until all the other ranks copy the buffer
|
||||
if (i != world_rank) wait_buffer_state_until_2(i, copy_current, reduce_current, state_group);
|
||||
}
|
||||
|
||||
// reduce scatter
|
||||
reduce_all_buffers(
|
||||
slice_el_start(chunk_el, world_rank),
|
||||
slice_size(chunk_el, world_rank),
|
||||
scalar_type,
|
||||
world_rank,
|
||||
distributed_buffer[current_buffer][world_rank],
|
||||
distributed_buffer[current_buffer]);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->states[state_group] = reduce_current;
|
||||
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
// wait until all the other ranks reduce the buffer
|
||||
if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group);
|
||||
}
|
||||
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
int rank = (i + world_rank) % world_size;
|
||||
parallel_memcpy(
|
||||
slice_data(data_ptr, chunk_el, data_size, rank),
|
||||
slice_data(distributed_buffer[current_buffer][rank], chunk_el, chunk_size / chunk_el, rank),
|
||||
slice_size(chunk_el, rank) * data_size);
|
||||
}
|
||||
|
||||
current_buffer = 1 - current_buffer;
|
||||
}
|
||||
|
||||
void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size) {
|
||||
for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) {
|
||||
auto data_ptr = ((char*)(data.data_ptr()) + offset);
|
||||
size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset;
|
||||
size_t chunk_el = chunk_size / (data_size / numel);
|
||||
if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) {
|
||||
symmetric_naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
|
||||
} else {
|
||||
distributed_naive_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_t chunk_size, size_t chunk_el) {
|
||||
const int state_group = 1;
|
||||
static int current_buffer = 0;
|
||||
static int state_idx = 0;
|
||||
|
||||
enum coll_state copy_current, copy_next;
|
||||
|
||||
switch (state_idx) {
|
||||
case 0:
|
||||
copy_current = coll_allgather_naive__copy_in_done;
|
||||
copy_next = coll_alt1_allgather_naive__copy_in_done;
|
||||
break;
|
||||
case 1:
|
||||
copy_current = coll_alt1_allgather_naive__copy_in_done;
|
||||
copy_next = coll_alt2_allgather_naive__copy_in_done;
|
||||
break;
|
||||
case 2:
|
||||
copy_current = coll_alt2_allgather_naive__copy_in_done;
|
||||
copy_next = coll_allgather_naive__copy_in_done;
|
||||
break;
|
||||
default:
|
||||
assert(!"Should not get here.");
|
||||
}
|
||||
state_idx = (state_idx + 1) % 3;
|
||||
|
||||
int data_size = chunk_size / chunk_el;
|
||||
parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
|
||||
std::atomic_thread_fence(std::memory_order_release);
|
||||
workspace[world_rank]->states[state_group] = copy_current;
|
||||
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
// wait until all the other ranks copy the buffer
|
||||
if (i != world_rank) wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
|
||||
}
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
parallel_memcpy(result_ptr + i * res_stride, distributed_buffer[current_buffer][i], chunk_size);
|
||||
}
|
||||
current_buffer = 1 - current_buffer;
|
||||
}
|
||||
|
||||
torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size) {
|
||||
size_t dim_el = data.stride(dim) * data.size(dim);
|
||||
int dtype_size = data_size / numel;
|
||||
size_t dim_size = dim_el * dtype_size;
|
||||
int dim_count = data_size / dim_size;
|
||||
auto data_ptr = (char*)(data.data_ptr());
|
||||
auto result_ptr = (char*)(result.data_ptr());
|
||||
for (int i = 0; i < dim_count; i++) {
|
||||
for (int offset = 0; offset < dim_size; offset += MAX_BUF_SIZE) {
|
||||
size_t chunk_size = dim_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : dim_size - offset;
|
||||
size_t chunk_el = chunk_size / dtype_size;
|
||||
naive_all_gather(
|
||||
result_ptr + i * dim_size * world_size + offset,
|
||||
data_ptr + i * dim_size + offset,
|
||||
dim_size,
|
||||
chunk_size,
|
||||
chunk_el);
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
11
sgl-kernel/csrc/cpu/shm.h
Normal file
11
sgl-kernel/csrc/cpu/shm.h
Normal file
@@ -0,0 +1,11 @@
|
||||
#include <torch/torch.h>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||
|
||||
#ifndef __SHM_COLLECTIVES__
|
||||
#define __SHM_COLLECTIVES__
|
||||
#define VECTOR_LENGTH_IN_BYTES 32
|
||||
void shm_initialize(int size, int rank, char* addr_string, char* port_string);
|
||||
void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size);
|
||||
torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size);
|
||||
#endif
|
||||
406
sgl-kernel/csrc/cpu/topk.cpp
Normal file
406
sgl-kernel/csrc/cpu/topk.cpp
Normal file
@@ -0,0 +1,406 @@
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t, int SIZE>
|
||||
inline void softmax(float* __restrict__ out, const scalar_t* __restrict__ input) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
// step 1: get max
|
||||
fVec max_fvec = fVec(-std::numeric_limits<float>::infinity());
|
||||
if constexpr (SIZE < kVecSize) {
|
||||
// SIZE = 1, 2, 4, 8, 16; only the top half is used
|
||||
bVec x_bvec = bVec::loadu(input, SIZE);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
x_fvec0 = fVec::set(max_fvec, x_fvec0, SIZE);
|
||||
max_fvec = at::vec::maximum(max_fvec, x_fvec0);
|
||||
x_fvec0.store(out, SIZE);
|
||||
} else {
|
||||
for (int d = 0; d < SIZE; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
max_fvec = at::vec::maximum(max_fvec, x_fvec0);
|
||||
max_fvec = at::vec::maximum(max_fvec, x_fvec1);
|
||||
x_fvec0.store(out + d);
|
||||
x_fvec1.store(out + d + fVec::size());
|
||||
}
|
||||
}
|
||||
float max_val = vec_reduce_max(max_fvec);
|
||||
max_fvec = fVec(max_val);
|
||||
|
||||
// step 2: sum of (x - max).exp()
|
||||
fVec sum_fvec = fVec(float(0));
|
||||
if constexpr (SIZE < fVec::size()) {
|
||||
// SIZE = 1, 2, 4, 8
|
||||
fVec x_fvec = (fVec::loadu(out, SIZE) - max_fvec).exp_u20();
|
||||
x_fvec = fVec::set(sum_fvec, x_fvec, SIZE);
|
||||
sum_fvec += x_fvec;
|
||||
x_fvec.store(out, SIZE);
|
||||
} else {
|
||||
for (int d = 0; d < SIZE; d += fVec::size()) {
|
||||
fVec x_fvec = (fVec::loadu(out + d) - max_fvec).exp_u20();
|
||||
sum_fvec += x_fvec;
|
||||
x_fvec.store(out + d);
|
||||
}
|
||||
}
|
||||
float sum_val = vec_reduce_sum(sum_fvec);
|
||||
|
||||
// step 3: x * (1 / sum)
|
||||
sum_fvec = fVec(1.f / sum_val);
|
||||
if constexpr (SIZE < fVec::size()) {
|
||||
// SIZE = 1, 2, 4, 8
|
||||
fVec out_fvec = fVec::loadu(out, SIZE) * sum_fvec;
|
||||
out_fvec.store(out, SIZE);
|
||||
} else {
|
||||
for (int d = 0; d < SIZE; d += fVec::size()) {
|
||||
fVec out_fvec = fVec::loadu(out + d) * sum_fvec;
|
||||
out_fvec.store(out + d);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int NUM_EXPERTS>
|
||||
void grouped_topk_kernel_impl(
|
||||
float* __restrict__ topk_weights,
|
||||
int32_t* __restrict__ topk_ids,
|
||||
const scalar_t* __restrict__ gating_output,
|
||||
int64_t num_tokens,
|
||||
int64_t topk,
|
||||
int64_t num_groups,
|
||||
int64_t topk_group,
|
||||
bool renormalize) {
|
||||
const int64_t num_experts_per_group = NUM_EXPERTS / num_groups;
|
||||
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
|
||||
alignas(64) float scores[NUM_EXPERTS];
|
||||
|
||||
using elem_t = std::pair<float, int32_t>;
|
||||
std::vector<elem_t> queue(num_groups);
|
||||
std::vector<elem_t> queue2(topk_group * num_experts_per_group);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
// do softmax to get scores
|
||||
softmax<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
|
||||
|
||||
// find max score per group
|
||||
for (int64_t g = 0; g < num_groups; ++g) {
|
||||
float gmax = -std::numeric_limits<float>::infinity();
|
||||
for (int64_t e = 0; e < num_experts_per_group; ++e) {
|
||||
gmax = std::max(gmax, scores[g * num_experts_per_group + e]);
|
||||
}
|
||||
queue[g] = {gmax, g};
|
||||
}
|
||||
|
||||
// find group topk
|
||||
std::partial_sort(
|
||||
queue.begin(), queue.begin() + topk_group, queue.end(), [](const elem_t& x, const elem_t& y) -> bool {
|
||||
return x.first > y.first;
|
||||
});
|
||||
|
||||
for (int64_t g = 0; g < topk_group; ++g) {
|
||||
int32_t group_idx = queue[g].second;
|
||||
for (int64_t e = 0; e < num_experts_per_group; ++e) {
|
||||
int32_t expert_idx = group_idx * num_experts_per_group + e;
|
||||
queue2[g * num_experts_per_group + e] = {scores[expert_idx], expert_idx};
|
||||
}
|
||||
}
|
||||
|
||||
// find global topk
|
||||
std::partial_sort(
|
||||
queue2.begin(), queue2.begin() + topk, queue2.end(), [](const elem_t& x, const elem_t& y) -> bool {
|
||||
return x.first > y.first;
|
||||
});
|
||||
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
topk_weights[i * topk + j] = queue2[j].first;
|
||||
topk_ids[i * topk + j] = queue2[j].second;
|
||||
}
|
||||
|
||||
if (renormalize) {
|
||||
float sum = 0.f;
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
sum += topk_weights[i * topk + j];
|
||||
}
|
||||
float scale = 1.f / sum;
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
topk_weights[i * topk + j] *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
template <typename scalar_t, int SIZE>
|
||||
inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
|
||||
const fVec one = fVec(1.f);
|
||||
|
||||
constexpr int kVecSize = bVec::size();
|
||||
for (int d = 0; d < SIZE; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
x_fvec0 = one / (one + x_fvec0.neg().exp_u20());
|
||||
x_fvec1 = one / (one + x_fvec1.neg().exp_u20());
|
||||
|
||||
x_fvec0.store(out + d);
|
||||
x_fvec1.store(out + d + fVec::size());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int SIZE>
|
||||
inline void
|
||||
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
for (int d = 0; d < SIZE; d += bVec::size()) {
|
||||
bVec bias_vec = bVec::loadu(bias + d);
|
||||
fVec bias0, bias1;
|
||||
std::tie(bias0, bias1) = at::vec::convert_to_float(bias_vec);
|
||||
|
||||
fVec x0 = fVec::loadu(scores + d) + bias0;
|
||||
fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1;
|
||||
x0.store(scores2 + d);
|
||||
x1.store(scores2 + d + fVec::size());
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, int NUM_EXPERTS, int TOPK>
|
||||
void biased_grouped_topk_kernel_impl(
|
||||
float* __restrict__ topk_weights,
|
||||
int32_t* __restrict__ topk_ids,
|
||||
const scalar_t* __restrict__ gating_output,
|
||||
const scalar_t* __restrict__ bias,
|
||||
int64_t num_tokens,
|
||||
int64_t num_groups,
|
||||
int64_t topk_group,
|
||||
bool renormalize) {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
|
||||
const int64_t num_experts_per_group = NUM_EXPERTS / num_groups;
|
||||
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
|
||||
// scores: sigmoid
|
||||
alignas(64) float scores[NUM_EXPERTS];
|
||||
// scores for choice: sigmoid + bias
|
||||
alignas(64) float scores2[NUM_EXPERTS];
|
||||
|
||||
using elem_t = std::pair<float, int32_t>;
|
||||
std::vector<elem_t> queue(num_groups);
|
||||
std::vector<elem_t> queue2(topk_group * num_experts_per_group);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
// do sigmoid to get scores
|
||||
sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
|
||||
apply_bias<scalar_t, NUM_EXPERTS>(scores2, scores, bias);
|
||||
|
||||
for (int64_t g = 0; g < num_groups; ++g) {
|
||||
// find the max
|
||||
float gmax = at::vec::reduce_all<float>(
|
||||
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
|
||||
scores2 + g * num_experts_per_group,
|
||||
num_experts_per_group);
|
||||
|
||||
// find position of first max,
|
||||
// note that we may have multiple max values.
|
||||
int first_max_idx = -1;
|
||||
for (int64_t e = 0; e < num_experts_per_group; ++e) {
|
||||
if (scores2[g * num_experts_per_group + e] == gmax) {
|
||||
first_max_idx = g * num_experts_per_group + e;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// find the 2nd max
|
||||
scores2[first_max_idx] = -std::numeric_limits<float>::infinity();
|
||||
float gmax2 = at::vec::reduce_all<float>(
|
||||
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
|
||||
scores2 + g * num_experts_per_group,
|
||||
num_experts_per_group);
|
||||
// restore scores for choice
|
||||
scores2[first_max_idx] = gmax;
|
||||
|
||||
queue[g] = {gmax + gmax2, g};
|
||||
}
|
||||
|
||||
// find group topk
|
||||
std::partial_sort(
|
||||
queue.begin(), queue.begin() + topk_group, queue.end(), [](const elem_t& x, const elem_t& y) -> bool {
|
||||
return x.first > y.first;
|
||||
});
|
||||
|
||||
for (int64_t g = 0; g < topk_group; ++g) {
|
||||
int32_t group_idx = queue[g].second;
|
||||
for (int64_t e = 0; e < num_experts_per_group; ++e) {
|
||||
int32_t expert_idx = group_idx * num_experts_per_group + e;
|
||||
queue2[g * num_experts_per_group + e] = {scores2[expert_idx], expert_idx};
|
||||
}
|
||||
}
|
||||
|
||||
// find global topk
|
||||
std::partial_sort(
|
||||
queue2.begin(), queue2.begin() + TOPK, queue2.end(), [](const elem_t& x, const elem_t& y) -> bool {
|
||||
return x.first > y.first;
|
||||
});
|
||||
|
||||
for (int j = 0; j < TOPK; ++j) {
|
||||
int32_t index = queue2[j].second;
|
||||
topk_ids[i * TOPK + j] = index;
|
||||
topk_weights[i * TOPK + j] = scores[index];
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
if (renormalize) {
|
||||
__mmask16 mask = (1ULL << TOPK) - 1;
|
||||
__m512 x = _mm512_maskz_loadu_ps(mask, topk_weights + i * TOPK);
|
||||
float sum = _mm512_reduce_add_ps(x);
|
||||
__m512 vscale = _mm512_set1_ps(1.f / sum);
|
||||
__m512 y = _mm512_mul_ps(x, vscale);
|
||||
_mm512_mask_storeu_ps(topk_weights + i * TOPK, mask, y);
|
||||
}
|
||||
#else
|
||||
if (renormalize) {
|
||||
float sum = 0.f;
|
||||
for (int64_t j = 0; j < TOPK; ++j) {
|
||||
sum += topk_weights[i * TOPK + j];
|
||||
}
|
||||
float scale = 1.f / sum;
|
||||
for (int64_t j = 0; j < TOPK; ++j) {
|
||||
topk_weights[i * TOPK + j] *= scale;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define LAUNCH_GROUPED_TOPK_KERNEL(NE) \
|
||||
grouped_topk_kernel_impl<scalar_t, NE>( \
|
||||
topk_weights.data_ptr<float>(), \
|
||||
topk_ids.data_ptr<int32_t>(), \
|
||||
gating_output.data_ptr<scalar_t>(), \
|
||||
num_tokens, \
|
||||
topk, \
|
||||
num_expert_group, \
|
||||
topk_group, \
|
||||
renormalize);
|
||||
|
||||
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
|
||||
biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
|
||||
topk_weights.data_ptr<float>(), \
|
||||
topk_ids.data_ptr<int32_t>(), \
|
||||
gating_output.data_ptr<scalar_t>(), \
|
||||
correction_bias.data_ptr<scalar_t>(), \
|
||||
num_tokens, \
|
||||
num_expert_group, \
|
||||
topk_group, \
|
||||
renormalize);
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// grouped topk for DeepSeek V2
|
||||
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& gating_output,
|
||||
int64_t topk,
|
||||
bool renormalize,
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group) {
|
||||
RECORD_FUNCTION("sgl-kernel::grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
|
||||
CHECK_INPUT(gating_output);
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_EQ(gating_output.scalar_type(), st);
|
||||
|
||||
int64_t num_tokens = hidden_states.size(0);
|
||||
int64_t num_experts = gating_output.size(1);
|
||||
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
|
||||
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "grouped_topk_kernel", [&] {
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(1);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(4);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(8);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(16);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(32);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(64);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(128);
|
||||
break;
|
||||
case 160:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(160);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(256);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
|
||||
}
|
||||
});
|
||||
return std::make_tuple(topk_weights, topk_ids);
|
||||
}
|
||||
|
||||
// biased grouped topk DeepSeek V3/R1
|
||||
std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& gating_output,
|
||||
at::Tensor& correction_bias,
|
||||
int64_t topk,
|
||||
bool renormalize,
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::biased_grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output, correction_bias}));
|
||||
|
||||
CHECK_INPUT(gating_output);
|
||||
CHECK_INPUT(correction_bias);
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_EQ(gating_output.scalar_type(), st);
|
||||
CHECK_EQ(correction_bias.scalar_type(), st);
|
||||
|
||||
int64_t num_tokens = hidden_states.size(0);
|
||||
int64_t num_experts = gating_output.size(1);
|
||||
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
|
||||
TORCH_CHECK(correction_bias.numel() == num_experts, "Bias shape mismatch");
|
||||
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "biased_grouped_topk_kernel", [&] {
|
||||
// NOW only support DSv3 configs
|
||||
TORCH_CHECK(topk == 8, "Unexpected topk: ", topk);
|
||||
switch (num_experts) {
|
||||
case 256:
|
||||
LAUNCH_BIASED_GROUPED_TOPK_KERNEL(256, 8);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
|
||||
}
|
||||
});
|
||||
return std::make_tuple(topk_weights, topk_ids);
|
||||
}
|
||||
224
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
Normal file
224
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
Normal file
@@ -0,0 +1,224 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/extension.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "shm.h"
|
||||
|
||||
// silu_and_mul
|
||||
at::Tensor silu_and_mul_cpu(at::Tensor& input);
|
||||
|
||||
// rmsnorm
|
||||
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
|
||||
|
||||
// fused_add_rmsnorm
|
||||
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);
|
||||
|
||||
// topk
|
||||
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& gating_output,
|
||||
int64_t topk,
|
||||
bool renormalize,
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group);
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& gating_output,
|
||||
at::Tensor& correction_bias,
|
||||
int64_t topk,
|
||||
bool renormalize,
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group);
|
||||
|
||||
// attention
|
||||
void decode_attention_cpu(
|
||||
at::Tensor& query,
|
||||
at::Tensor& output,
|
||||
at::Tensor& k_cache,
|
||||
at::Tensor& v_cahce,
|
||||
at::Tensor& attn_logits,
|
||||
at::Tensor& req_to_token,
|
||||
at::Tensor& req_pool_indices,
|
||||
at::Tensor& seq_lens,
|
||||
double sm_scale,
|
||||
double logit_cap);
|
||||
|
||||
void extend_attention_cpu(
|
||||
at::Tensor& q_extend,
|
||||
at::Tensor& k_extend,
|
||||
at::Tensor& v_extend,
|
||||
at::Tensor& o_extend,
|
||||
at::Tensor& k_buffer,
|
||||
at::Tensor& v_buffer,
|
||||
at::Tensor& req_to_token,
|
||||
at::Tensor& req_pool_indices,
|
||||
at::Tensor& seq_lens,
|
||||
at::Tensor& extend_seq_lens,
|
||||
at::Tensor& extend_start_loc,
|
||||
int64_t max_len_extend,
|
||||
double sm_scale,
|
||||
double logit_cap);
|
||||
|
||||
// weight prepack
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
// quant
|
||||
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);
|
||||
|
||||
// gemm
|
||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni);
|
||||
|
||||
// igemm
|
||||
at::Tensor int8_scaled_mm_cpu(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales1,
|
||||
at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni);
|
||||
|
||||
// quant + igemm
|
||||
at::Tensor int8_scaled_mm_with_quant(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni);
|
||||
|
||||
// bmm
|
||||
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale);
|
||||
|
||||
// fused moe
|
||||
at::Tensor fused_experts_cpu(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& w1,
|
||||
at::Tensor& w2,
|
||||
at::Tensor& topk_weights,
|
||||
at::Tensor& topk_ids,
|
||||
bool inplace,
|
||||
bool use_int8_w8a8,
|
||||
std::optional<at::Tensor>& w1_scale,
|
||||
std::optional<at::Tensor>& w2_scale,
|
||||
std::optional<at::Tensor>& a1_scale,
|
||||
std::optional<at::Tensor>& a2_scale,
|
||||
bool is_vnni);
|
||||
|
||||
at::Tensor shared_expert_cpu(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& w1,
|
||||
at::Tensor& w2,
|
||||
at::Tensor& fused_experts_out,
|
||||
double routed_scaling_factor,
|
||||
bool inplace,
|
||||
bool use_int8_w8a8,
|
||||
std::optional<at::Tensor>& w1_scale,
|
||||
std::optional<at::Tensor>& w2_scale,
|
||||
std::optional<at::Tensor>& a1_scale,
|
||||
std::optional<at::Tensor>& a2_scale,
|
||||
bool is_vnni);
|
||||
|
||||
// weight absorption
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& q_a_proj_weight,
|
||||
at::Tensor& q_b_proj_weight,
|
||||
at::Tensor& kv_a_proj_weight,
|
||||
at::Tensor& w_kc,
|
||||
at::Tensor& q_a_layernorm_weight,
|
||||
at::Tensor& kv_a_layernorm_weight,
|
||||
at::Tensor& positions,
|
||||
at::Tensor& cos_sin_cache,
|
||||
double eps,
|
||||
bool use_int8_w8a8,
|
||||
std::optional<at::Tensor>& q_a_proj_scale,
|
||||
std::optional<at::Tensor>& q_b_proj_scale,
|
||||
std::optional<at::Tensor>& kv_a_proj_scale,
|
||||
bool is_vnni);
|
||||
|
||||
// shared memory init
|
||||
void initialize(int size, int rank);
|
||||
|
||||
// shared mmeory all_reduce
|
||||
void shm_allreduce(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op);
|
||||
|
||||
// shared memory all_gather
|
||||
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim);
|
||||
|
||||
// rope
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
// activation
|
||||
m.def("silu_and_mul_cpu", &silu_and_mul_cpu, "SiLU and mul for CPU");
|
||||
|
||||
// norm
|
||||
m.def("rmsnorm_cpu", &rmsnorm_cpu, "Root mean square normalization for CPU");
|
||||
m.def("fused_add_rmsnorm_cpu", &fused_add_rmsnorm_cpu, "Fused add root mean square normalization for CPU");
|
||||
|
||||
// topk
|
||||
m.def("grouped_topk_cpu", &grouped_topk_cpu, "Grouped TopK for CPU");
|
||||
|
||||
// biased group topk
|
||||
m.def("biased_grouped_topk_cpu", &biased_grouped_topk_cpu, "Biased Grouped TopK for CPU");
|
||||
|
||||
// decode
|
||||
m.def("decode_attention_cpu", &decode_attention_cpu, "Attention decoding for CPU");
|
||||
|
||||
// extend
|
||||
m.def("extend_attention_cpu", &extend_attention_cpu, "Attention extend for CPU");
|
||||
|
||||
// weight prepack
|
||||
m.def("convert_weight_packed", &convert_weight_packed, "prepack weight to vnni format for intel AMX");
|
||||
|
||||
// quant
|
||||
m.def("per_token_quant_int8_cpu", &per_token_quant_int8_cpu, "dynamic quantization for CPU");
|
||||
|
||||
// gemm
|
||||
m.def("weight_packed_linear", &weight_packed_linear, "weight packed linear for intel AMX");
|
||||
|
||||
// igemm
|
||||
m.def("int8_scaled_mm_cpu", &int8_scaled_mm_cpu, "int8 weight packed linear for intel AMX");
|
||||
|
||||
// quant + igemm
|
||||
m.def(
|
||||
"int8_scaled_mm_with_quant", &int8_scaled_mm_with_quant, "fused per row quant and int8 scaled mm for intel AMX");
|
||||
|
||||
// bmm
|
||||
m.def("bmm_cpu", &bmm_cpu, "bmm kernel for intel AMX");
|
||||
|
||||
// moe
|
||||
m.def("fused_experts_cpu", &fused_experts_cpu, "fused moe kernel for CPU");
|
||||
|
||||
// weight absorption
|
||||
m.def("qkv_proj_with_rope", &qkv_proj_with_rope, "fused qkv projection kernel with weight absorption for intel AMX");
|
||||
|
||||
// shared expert
|
||||
m.def("shared_expert_cpu", &shared_expert_cpu, "shared expert kernel for CPU");
|
||||
|
||||
// all reduce
|
||||
m.def("initialize", &initialize, "shared memory initialization for CPU");
|
||||
m.def("shm_allreduce", &shm_allreduce, "low latency all_reduce implementation for CPU");
|
||||
m.def("shm_allgather", &shm_allgather, "low latency all_gather implementation for CPU");
|
||||
|
||||
// rope
|
||||
m.def("rotary_position_embedding_cpu", &rotary_position_embedding_cpu, "rotary position embedding for CPU");
|
||||
}
|
||||
115
sgl-kernel/csrc/cpu/vec.h
Normal file
115
sgl-kernel/csrc/cpu/vec.h
Normal file
@@ -0,0 +1,115 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
|
||||
#define CPU_CAPABILITY_AVX512
|
||||
#endif
|
||||
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace at::vec;
|
||||
|
||||
template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||
inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||
return at::vec::convert_from_float<scalar_t>(a, b);
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
|
||||
// use native instruction for bfloat16->float32 conversion
|
||||
template <>
|
||||
inline Vectorized<at::BFloat16>
|
||||
convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||
return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a)));
|
||||
}
|
||||
|
||||
#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
|
||||
|
||||
#define CVT_FP16_TO_FP32(a) _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
|
||||
|
||||
#endif
|
||||
|
||||
// vector to scalar reduction
|
||||
#if defined(CPU_CAPABILITY_AVX512) && 0
|
||||
inline float vec_reduce_sum(const Vectorized<float>& a) {
|
||||
return _mm512_reduce_add_ps(__m512(a));
|
||||
}
|
||||
|
||||
inline float vec_reduce_max(const Vectorized<float>& a) {
|
||||
return _mm512_reduce_max_ps(__m512(a));
|
||||
}
|
||||
#else
|
||||
inline float vec_reduce_sum(const Vectorized<float>& a) {
|
||||
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return x + y; }, a);
|
||||
}
|
||||
|
||||
inline float vec_reduce_max(const Vectorized<float>& a) {
|
||||
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return maximum(x, y); }, a);
|
||||
}
|
||||
#endif
|
||||
|
||||
// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
|
||||
template <typename scalar_t>
|
||||
inline void
|
||||
quantize_row_int8(uint8_t* __restrict__ Aq, float& As, const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) {
|
||||
float amax = 0.f; // absolute max
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
const float val = static_cast<float>(A[k]);
|
||||
amax = std::max(amax, std::abs(val));
|
||||
}
|
||||
|
||||
amax = std::max(amax, eps);
|
||||
const float scale = amax / 127;
|
||||
const float inv_scale = 127 / amax;
|
||||
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
const float val = static_cast<float>(A[k]) * inv_scale;
|
||||
Aq[k] = (uint8_t)(std::round(val)) + 128;
|
||||
}
|
||||
As = scale;
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <>
|
||||
inline void quantize_row_int8<at::BFloat16>(
|
||||
uint8_t* __restrict__ Aq, float& As, const at::BFloat16* __restrict__ A, int64_t K, float eps) {
|
||||
const __m512 signBit = _mm512_set1_ps(-0.0f);
|
||||
const __m512i off = _mm512_set1_epi32(128);
|
||||
|
||||
// K is 32x, no remainder
|
||||
float amax = 0.f;
|
||||
__m512 vamax0 = _mm512_set1_ps(0.f);
|
||||
__m512 vamax1 = _mm512_set1_ps(0.f);
|
||||
for (int64_t k = 0; k < K; k += 32) {
|
||||
__m512i va = _mm512_loadu_si512((void*)(A + k));
|
||||
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
|
||||
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
|
||||
vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0));
|
||||
vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1));
|
||||
}
|
||||
amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1));
|
||||
amax = std::max(amax, eps);
|
||||
const float scale = amax / 127;
|
||||
const float inv_scale = 127 / amax;
|
||||
const __m512 vd = _mm512_set1_ps(inv_scale);
|
||||
|
||||
for (int64_t k = 0; k < K; k += 32) {
|
||||
__m512i va = _mm512_loadu_si512((void*)(A + k));
|
||||
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
|
||||
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
|
||||
va0 = _mm512_mul_ps(va0, vd);
|
||||
va1 = _mm512_mul_ps(va1, vd);
|
||||
va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
__m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off));
|
||||
__m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off));
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0));
|
||||
}
|
||||
As = scale;
|
||||
}
|
||||
#endif
|
||||
|
||||
} // anonymous namespace
|
||||
95
sgl-kernel/setup_cpu.py
Normal file
95
sgl-kernel/setup_cpu.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import os
|
||||
import shutil
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from setuptools import find_packages, setup
|
||||
from setuptools.command.build_py import build_py
|
||||
from torch.utils.cpp_extension import BuildExtension, CppExtension
|
||||
|
||||
root = Path(__file__).parent.resolve()
|
||||
|
||||
if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv:
|
||||
sys.argv.extend(["--plat-name", "manylinux2014_x86_64"])
|
||||
|
||||
|
||||
def _get_version():
|
||||
with open(root / "pyproject.toml") as f:
|
||||
for line in f:
|
||||
if line.startswith("version"):
|
||||
return line.split("=")[1].strip().strip('"')
|
||||
|
||||
|
||||
operator_namespace = "sgl_kernel"
|
||||
include_dirs = []
|
||||
|
||||
sources = [
|
||||
"csrc/cpu/activation.cpp",
|
||||
"csrc/cpu/bmm.cpp",
|
||||
"csrc/cpu/decode.cpp",
|
||||
"csrc/cpu/extend.cpp",
|
||||
"csrc/cpu/gemm.cpp",
|
||||
"csrc/cpu/gemm_int8.cpp",
|
||||
"csrc/cpu/moe.cpp",
|
||||
"csrc/cpu/moe_int8.cpp",
|
||||
"csrc/cpu/norm.cpp",
|
||||
"csrc/cpu/qkv_proj.cpp",
|
||||
"csrc/cpu/topk.cpp",
|
||||
"csrc/cpu/interface.cpp",
|
||||
"csrc/cpu/shm.cpp",
|
||||
"csrc/cpu/torch_extension_cpu.cpp",
|
||||
]
|
||||
|
||||
extra_compile_args = {
|
||||
"cxx": [
|
||||
"-O3",
|
||||
"-Wno-unknown-pragmas",
|
||||
"-march=native",
|
||||
"-fopenmp",
|
||||
]
|
||||
}
|
||||
libraries = ["c10", "torch", "torch_python"]
|
||||
cmdclass = {
|
||||
"build_ext": BuildExtension.with_options(use_ninja=True),
|
||||
}
|
||||
Extension = CppExtension
|
||||
|
||||
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
|
||||
|
||||
ext_modules = [
|
||||
Extension(
|
||||
name="sgl_kernel.common_ops",
|
||||
sources=sources,
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args=extra_compile_args,
|
||||
libraries=libraries,
|
||||
extra_link_args=extra_link_args,
|
||||
py_limited_api=True,
|
||||
),
|
||||
]
|
||||
|
||||
setup(
|
||||
name="sgl-kernel",
|
||||
version=_get_version(),
|
||||
packages=find_packages(where="python"),
|
||||
package_dir={"": "python"},
|
||||
ext_modules=ext_modules,
|
||||
cmdclass=cmdclass,
|
||||
options={"bdist_wheel": {"py_limited_api": "cp39"}},
|
||||
)
|
||||
Reference in New Issue
Block a user