diff --git a/sgl-kernel/csrc/cpu/activation.cpp b/sgl-kernel/csrc/cpu/activation.cpp new file mode 100644 index 000000000..debf5b244 --- /dev/null +++ b/sgl-kernel/csrc/cpu/activation.cpp @@ -0,0 +1,79 @@ +#include "common.h" +#include "vec.h" + +namespace { + +template +void act_and_mul_kernel_impl( + scalar_t* __restrict__ output, + const scalar_t* __restrict__ input, + int64_t num_tokens, + int64_t dim, + const func_t& f, + const vec_func_t& vf) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + + constexpr int64_t kVecSize = bVec::size(); + at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + // local ptrs + const scalar_t* __restrict__ input_ptr = input + i * 2 * dim; + const scalar_t* __restrict__ input_other_ptr = input_ptr + dim; + scalar_t* __restrict__ output_ptr = output + i * dim; + + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= dim - kVecSize; d += kVecSize) { + bVec x_bvec = bVec::loadu(input_ptr + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + bVec y_bvec = bVec::loadu(input_other_ptr + d); + fVec y_fvec0, y_fvec1; + std::tie(y_fvec0, y_fvec1) = at::vec::convert_to_float(y_bvec); + + x_fvec0 = vf(x_fvec0); + x_fvec1 = vf(x_fvec1); + + x_fvec0 = x_fvec0 * y_fvec0; + x_fvec1 = x_fvec1 * y_fvec1; + + x_bvec = convert_from_float_ext(x_fvec0, x_fvec1); + x_bvec.store(output_ptr + d); + } +#pragma GCC unroll 4 + for (; d < dim; ++d) { + float x_val = static_cast(input_ptr[d]); + float y_val = static_cast(input_other_ptr[d]); + output_ptr[d] = f(x_val) * y_val; + } + } + }); +} + +} // anonymous namespace + +// input : {num_tokens, 2 * d} +// output : {num_tokens, d} +at::Tensor silu_and_mul_cpu(at::Tensor& input) { + RECORD_FUNCTION("sgl-kernel::silu_and_mul_cpu", std::vector({input})); + auto sizes = input.sizes().vec(); + int64_t last_dim = input.ndimension() - 1; + int64_t d = sizes[last_dim] / 2; + sizes[last_dim] = d; + int64_t num_tokens = input.numel() / input.size(-1); + at::Tensor out = at::empty(sizes, input.options()); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] { + using Vec = at::vec::Vectorized; + act_and_mul_kernel_impl( + out.data_ptr(), + input.data_ptr(), + num_tokens, + d, + [](float x) { return x / (1.f + std::exp(-x)); }, + [](Vec x) { return x / (Vec(1.f) + x.neg().exp()); }); + }); + return out; +} diff --git a/sgl-kernel/csrc/cpu/bmm.cpp b/sgl-kernel/csrc/cpu/bmm.cpp new file mode 100644 index 000000000..f7377a09c --- /dev/null +++ b/sgl-kernel/csrc/cpu/bmm.cpp @@ -0,0 +1,122 @@ +#include "common.h" +#include "gemm.h" +#include "vec.h" + +namespace { + +template +void bmm_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ mat1, + const scalar_t* __restrict__ mat2, + int64_t B, + int64_t M, + int64_t N, + int64_t K, + int64_t mat1_strideB, + int64_t mat1_strideM, + int64_t out_strideB, + int64_t out_strideM, + float scale = 0.f) { + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // mat2 contiguous in [B, N, K] + int64_t mat2_strideB = N * K; + int64_t mat2_strideN = K; + + const bool use_brgemm = can_use_brgemm(M); + + // parallel on [B, MB, NB] + at::parallel_for(0, B * MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t bs{0}, mb{0}, nb{0}; + data_index_init(begin, bs, B, mb, MB, nb, NB); + + // for brgemm, use float32 for 
accumulate + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + + for (int i = begin; i < end; ++i) { + UNUSED(i); + int mb_start = mb * BLOCK_M; + int mb_size = std::min(M - mb_start, BLOCK_M); + int nb_start = nb * BLOCK_N; + int nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + bs * mat1_strideB + mb_start * mat1_strideM, + /* B */ mat2 + bs * mat2_strideB + nb_start * mat2_strideN /* nb * BLOCK_N * K */, + /* C */ out + bs * out_strideB + mb_start * out_strideM + nb_start, + /* Ctmp*/ Ctmp, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ mat1_strideM, + /* ldb */ nb_size, + /* ldc */ out_strideM, + /* brg */ use_brgemm); + + // move to the next index + data_index_step(bs, B, mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); +} + +} // anonymous namespace + +// mat1 : [B, M, K] +// mat2 : [B, N, K] or [B, OC, IC] +// out : [B, M, N] +// scale: [] 0-dim tensor for per tensor quant +// +void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional& scale) { + RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector({out, mat1, mat2})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + // input and out could be non-contiguous + // weight needs to be contiguous in [OC, IC] order + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(out); + CHECK_INPUT(mat2); + CHECK_DIM(3, out); + CHECK_DIM(3, mat1); + CHECK_DIM(3, mat2); + + int64_t B = mat1.size(0); + int64_t M = mat1.size(1); + int64_t N = mat2.size(1); + int64_t K = mat1.size(2); + + TORCH_CHECK(!scale.has_value(), "bmm: do not support fp8 weight for now.") + TORCH_CHECK(N % 32 == 0, "tinygemm requires N to be 32x."); + + int64_t mat1_strideB = mat1.stride(0); + int64_t mat1_strideM = mat1.stride(1); + int64_t out_strideB = out.stride(0); + int64_t out_strideM = out.stride(1); + + // check shapes + TORCH_CHECK(mat2.size(0) == B && mat2.size(2) == K, "bmm: mat2 shape mismatch!"); + TORCH_CHECK(out.size(0) == B && out.size(1) == M, "bmm: out shape mismatch!"); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "bmm_kernel_impl", [&] { + bmm_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + B, + M, + N, + K, + mat1_strideB, + mat1_strideM, + out_strideB, + out_strideM); + }); +} diff --git a/sgl-kernel/csrc/cpu/common.h b/sgl-kernel/csrc/cpu/common.h new file mode 100644 index 000000000..0d340a756 --- /dev/null +++ b/sgl-kernel/csrc/cpu/common.h @@ -0,0 +1,164 @@ +#pragma once + +#include +#include +#include + +#if defined(_OPENMP) +#include +#endif + +namespace { + +// dispatch bool +#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \ + [&] { \ + if (BOOL_V) { \ + constexpr bool BOOL_NAME = true; \ + return __VA_ARGS__(); \ + } else { \ + constexpr bool BOOL_NAME = false; \ + return __VA_ARGS__(); \ + } \ + }() + +// dispatch: bfloat16, float16, int8_t +#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) 
\ + [&] { \ + switch (TYPE) { \ + case at::ScalarType::BFloat16: { \ + using packed_t = at::BFloat16; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Half: { \ + using packed_t = at::Half; \ + return __VA_ARGS__(); \ + } \ + case at::ScalarType::Char: { \ + using packed_t = int8_t; \ + return __VA_ARGS__(); \ + } \ + default: \ + TORCH_CHECK(false, "Unsupported floating data type.\n"); \ + } \ + }() + +#define UNUSED(x) (void)(x) + +#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor") + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention") + +#define CHECK_INPUT(x) \ + CHECK_CPU(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CPU(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +// parallel routines +constexpr int GRAIN_SIZE = 1024; + +template ::value, int>::type = 0> +inline T div_up(T x, T y) { + return (x + y - 1) / y; +} + +template +inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) { +#if 0 + // onednn partition pattern + T& n_my = n_end; + if (nth <= 1 || n == 0) { + n_start = 0; + n_my = n; + } else { + T n1 = div_up(n, nth); + T n2 = n1 - 1; + T T1 = n - n2 * nth; + n_my = ith < T1 ? n1 : n2; + n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2; + } + n_end += n_start; +#else + // pytorch aten partition pattern + T n_my = div_up(n, nth); + n_start = ith * n_my; + n_end = std::min(n_start + n_my, n); +#endif +} + +template +inline void parallel_for(int n, const func_t& f) { +#if defined(_OPENMP) +#pragma omp parallel + { + int nth = omp_get_num_threads(); + int ith = omp_get_thread_num(); + int tbegin, tend; + balance211(n, nth, ith, tbegin, tend); + f(tbegin, tend); + } +#else + f(0, n); +#endif +} + +// data indexing for dimension collapse +template +inline T data_index_init(T offset) { + return offset; +} + +template +inline T data_index_init(T offset, T& x, const T& X, Args&&... args) { + offset = data_index_init(offset, std::forward(args)...); + x = offset % X; + return offset / X; +} + +inline bool data_index_step() { + return true; +} + +template +inline bool data_index_step(T& x, const T& X, Args&&... args) { + if (data_index_step(std::forward(args)...)) { + x = ((x + 1) == X) ? 0 : (x + 1); + return x == 0; + } + return false; +} + +// forced unroll for perf critical path + +#if __has_attribute(always_inline) +#define ALWAYS_INLINE __attribute__((__always_inline__)) inline +#else +#define ALWAYS_INLINE inline +#endif + +template +struct Unroll { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + Unroll{}(f, args...); + f(std::integral_constant{}, args...); + } +}; + +template <> +struct Unroll<1> { + template + ALWAYS_INLINE void operator()(const Func& f, Args... args) const { + f(std::integral_constant{}, args...); + } +}; + +} // anonymous namespace diff --git a/sgl-kernel/csrc/cpu/decode.cpp b/sgl-kernel/csrc/cpu/decode.cpp new file mode 100644 index 000000000..e469ffdc5 --- /dev/null +++ b/sgl-kernel/csrc/cpu/decode.cpp @@ -0,0 +1,1119 @@ +#include "common.h" +#include "vec.h" + +namespace { + +// [NOTE] TODO list for this kernel: +// 1. tune the value for BLOCK_N +// 2. 
planning for {batches, num_heads, num_kv_splits} +// and use actual num_kv_splits for small seq length +// 3. try fast impl of `.tanh()` +// 4. provide amx kernel for index_gemm_kernel_nn when M = 16 +// + +inline void fill_stub(float* __restrict__ out, float val, int64_t size) { + using Vec = at::vec::Vectorized; + const Vec data_vec(val); + at::vec::map([data_vec](Vec out) { return out = data_vec; }, out, out, size); +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + const fVec s_fvec = fVec(s); + int64_t d = 0; + for (; d <= size - bVec::size(); d += bVec::size()) { + fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec; + fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec; + bVec out_bvec = convert_from_float_ext(a_fvec0, a_fvec1); + out_bvec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(acc[d] * s); + } +} + +// GEMM handles query @ key (indexed) x scale +// A : [M, K] +// B : [N, K] indexed +// C : [M, N] +// +template +struct tinygemm_kernel_nt { + static inline void apply( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + float* __restrict__ C, + const index_t* __restrict__ indices, + float scale, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t K, + int64_t max_tokens) { + for (int64_t m = 0; m < BLOCK_M; ++m) { + for (int64_t n = 0; n < BLOCK_N; ++n) { + float sum = 0.f; + int64_t b_idx = indices[n]; + TORCH_CHECK(b_idx < max_tokens, "token index out of scope!"); + for (int64_t k = 0; k < K; ++k) { + sum += scale * static_cast(A[m * lda + k]) * static_cast(B[b_idx * ldb + k]); + } + C[m * ldc + n] = sum; + } + } + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nt { + static inline void apply( + const at::BFloat16* __restrict__ A, + const at::BFloat16* __restrict__ B, + float* __restrict__ C, + const index_t* __restrict__ indices, + float scale, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t K, + int64_t max_tokens) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vscale = _mm512_set1_ps(scale); + + auto loadc = [&](auto i) { vc[i] = _mm512_setzero_ps(); }; + Unroll{}(loadc); + + // for main loop + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_loadu_si512(A + row * lda + k)); + } + if constexpr (row == 0) { + if constexpr (col + 1 < COLS) { + int64_t b_idx_prefetch = indices[col + 1]; + _mm_prefetch(B + b_idx_prefetch * ldb + k, _MM_HINT_T0); + } + int64_t b_idx = indices[col]; + TORCH_CHECK(b_idx < max_tokens, "token index out of scope!"); + vb[col] = (__m512bh)(_mm512_loadu_si512(B + b_idx * ldb + k)); + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + + // for remainder + auto compute2 = [&](auto i, int64_t k, __mmask32 mask) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_maskz_loadu_epi16(mask, A + row * lda + k)); + } + if constexpr (row == 0) { + int64_t b_idx = indices[col]; + TORCH_CHECK(b_idx < max_tokens, "token index out of scope!"); + vb[col] = (__m512bh)(_mm512_maskz_loadu_epi16(mask, B + b_idx * ldb + k)); + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + + int64_t k = 0; + for (; k <= K - 32; k += 32) { + Unroll{}(compute, k); + } + int64_t count = 
K - k; + if (count > 0) { + __mmask32 mask = (1ULL << count) - 1; + Unroll{}(compute2, k, mask); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + C[row * ldc + col] = _mm512_reduce_add_ps(_mm512_mul_ps(vc[i], vscale)); + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NT(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nt::apply( \ + A + mb_start * lda, B, C + mb_start * ldc + nb_start, indices + nb_start, scale, lda, ldb, ldc, K, max_tokens); + +// this is used when N isn't multiple of 16, +// N corresponds to `head_size_v` which should be 16x +template +inline void tinygemm_kernel_nn_scalar( + const float* __restrict__ A, + const scalar_t* __restrict__ B, + float* __restrict__ C, + const index_t* __restrict__ indices, + const float* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t max_tokens) { + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + C[m * ldc + n] *= scale[m]; + for (int64_t k = 0; k < K; ++k) { + int64_t b_idx = indices[k]; + TORCH_CHECK(b_idx < max_tokens, "token index out of scope!"); + C[m * ldc + n] += A[m * lda + k] * static_cast(B[b_idx * ldb + n]); + } + } + } +} + +// GEMM handles v' * scale + attn @ value (indexed) +// A : [M, K] +// B : [K, N] indexed +// C :[M, N] +// +template +struct tinygemm_kernel_nn { + static inline void apply( + const float* __restrict__ A, + const scalar_t* __restrict__ B, + float* __restrict__ C, + const index_t* __restrict__ indices, + const float* __restrict__ scale, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t K, + int64_t max_tokens) { + tinygemm_kernel_nn_scalar(A, B, C, indices, scale, BLOCK_M, BLOCK_N, K, lda, ldb, ldc, max_tokens); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const float* __restrict__ A, + const at::BFloat16* __restrict__ B, + float* __restrict__ C, + const index_t* __restrict__ indices, + const float* __restrict__ scale, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t K, + int64_t max_tokens) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + __m512 va; + __m512 vb[COLS]; + __m512 vc[ROWS * COLS]; + __m512 vscale; + + auto loadc = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; +#pragma GCC diagnostic push +#pragma GCC diagnostic ignored "-Warray-bounds" + if constexpr (col == 0) { + vscale = _mm512_set1_ps(scale[row]); + } +#pragma GCC diagnostic pop + vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16); + vc[i] = _mm512_mul_ps(vc[i], vscale); + }; + Unroll{}(loadc); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_ps(A[row * lda + k]); + } + if constexpr (row == 0) { + if (k + 1 < K) { + int64_t b_idx_prefetch = indices[k + 1]; + _mm_prefetch(B + b_idx_prefetch * ldb + col * 16, _MM_HINT_T0); + } + int64_t b_idx = indices[k]; + TORCH_CHECK(b_idx < max_tokens, "token index out of scope!"); + + // for COLS = 2, 4, 6, 8 use 512 bit load + // for COLS = 1, 3, 5, 7 use 256 bit load + if constexpr (COLS % 2 == 0) { + if constexpr (col % 2 == 0) { + __m512i b16 = _mm512_loadu_si512(reinterpret_cast(B + b_idx * ldb + col * 16)); + vb[col + 0] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0)); + vb[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1)); + } + } else { + __m256i b16 = 
_mm256_loadu_si256(reinterpret_cast(B + b_idx * ldb + col * 16)); + vb[col] = CVT_BF16_TO_FP32(b16); + } + } + vc[i] = _mm512_fmadd_ps(va, vb[col], vc[i]); + }; + + for (int64_t k = 0; k < K; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]); + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, \ + B + nb_start, \ + C + mb_start * ldc + nb_start, \ + indices, \ + scale + mb_start, \ + lda, \ + ldb, \ + ldc, \ + K, \ + max_tokens); + +template +void index_gemm_kernel_nt( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + float* __restrict__ C, + const index_t* __restrict__ indices, + float scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t max_tokens) { + // pattern: 1-8-8 + if (M == 1) { + constexpr int64_t BLOCK_N = 8; + const int64_t NB = div_up(N, BLOCK_N); + int64_t mb_start = 0, lda = 1, ldc = 1; + + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (nb_size) { + case 1: + LAUNCH_TINYGEMM_KERNEL_NT(1, 1); + break; + case 2: + LAUNCH_TINYGEMM_KERNEL_NT(1, 2); + break; + case 3: + LAUNCH_TINYGEMM_KERNEL_NT(1, 3); + break; + case 4: + LAUNCH_TINYGEMM_KERNEL_NT(1, 4); + break; + case 5: + LAUNCH_TINYGEMM_KERNEL_NT(1, 5); + break; + case 6: + LAUNCH_TINYGEMM_KERNEL_NT(1, 6); + break; + case 7: + LAUNCH_TINYGEMM_KERNEL_NT(1, 7); + break; + case 8: + LAUNCH_TINYGEMM_KERNEL_NT(1, 8); + break; + default: + TORCH_CHECK(false, "Unexpected block size, 1x", "nb_size"); + } + } + return; + } + + // pattern: 1-6-24 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 6; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size) { + // mb_size = 1 + case 0x11: + LAUNCH_TINYGEMM_KERNEL_NT(1, 1); + break; + case 0x12: + LAUNCH_TINYGEMM_KERNEL_NT(1, 2); + break; + case 0x13: + LAUNCH_TINYGEMM_KERNEL_NT(1, 3); + break; + case 0x14: + LAUNCH_TINYGEMM_KERNEL_NT(1, 4); + break; + case 0x15: + LAUNCH_TINYGEMM_KERNEL_NT(1, 5); + break; + case 0x16: + LAUNCH_TINYGEMM_KERNEL_NT(1, 6); + break; + // mb_size = 2 + case 0x21: + LAUNCH_TINYGEMM_KERNEL_NT(2, 1); + break; + case 0x22: + LAUNCH_TINYGEMM_KERNEL_NT(2, 2); + break; + case 0x23: + LAUNCH_TINYGEMM_KERNEL_NT(2, 3); + break; + case 0x24: + LAUNCH_TINYGEMM_KERNEL_NT(2, 4); + break; + case 0x25: + LAUNCH_TINYGEMM_KERNEL_NT(2, 5); + break; + case 0x26: + LAUNCH_TINYGEMM_KERNEL_NT(2, 6); + break; + // mb_size = 3 + case 0x31: + LAUNCH_TINYGEMM_KERNEL_NT(3, 1); + break; + case 0x32: + LAUNCH_TINYGEMM_KERNEL_NT(3, 2); + break; + case 0x33: + LAUNCH_TINYGEMM_KERNEL_NT(3, 3); + break; + case 0x34: + LAUNCH_TINYGEMM_KERNEL_NT(3, 4); + break; + case 0x35: + LAUNCH_TINYGEMM_KERNEL_NT(3, 5); + break; + case 0x36: + LAUNCH_TINYGEMM_KERNEL_NT(3, 6); + break; + // mb_size = 4 + case 0x41: + LAUNCH_TINYGEMM_KERNEL_NT(4, 1); + break; + case 0x42: + LAUNCH_TINYGEMM_KERNEL_NT(4, 2); + break; + case 0x43: + LAUNCH_TINYGEMM_KERNEL_NT(4, 3); + break; + case 0x44: + 
LAUNCH_TINYGEMM_KERNEL_NT(4, 4); + break; + case 0x45: + LAUNCH_TINYGEMM_KERNEL_NT(4, 5); + break; + case 0x46: + LAUNCH_TINYGEMM_KERNEL_NT(4, 6); + break; + default: + TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void index_gemm_kernel_nn( + const float* __restrict__ A, + const scalar_t* __restrict__ B, + float* __restrict__ C, + const index_t* __restrict__ indices, + float* __restrict__ scale, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + int64_t max_tokens) { + constexpr int kVecSize = 16; + if ((N & (kVecSize - 1)) != 0) { + tinygemm_kernel_nn_scalar(A, B, C, indices, scale, M, N, K, lda, ldb, ldc, max_tokens); + return; + } + + // pattern: 1-8-8 + if (M == 1) { + constexpr int64_t BLOCK_N = 8 * kVecSize; + const int64_t NB = div_up(N, BLOCK_N); + int64_t mb_start = 0, lda = 1, ldc = 1; + + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (nb_size >> 4) { + case 1: + LAUNCH_TINYGEMM_KERNEL_NN(1, 16); + break; + case 2: + LAUNCH_TINYGEMM_KERNEL_NN(1, 32); + break; + case 3: + LAUNCH_TINYGEMM_KERNEL_NN(1, 48); + break; + case 4: + LAUNCH_TINYGEMM_KERNEL_NN(1, 64); + break; + case 5: + LAUNCH_TINYGEMM_KERNEL_NN(1, 80); + break; + case 6: + LAUNCH_TINYGEMM_KERNEL_NN(1, 96); + break; + case 7: + LAUNCH_TINYGEMM_KERNEL_NN(1, 112); + break; + case 8: + LAUNCH_TINYGEMM_KERNEL_NN(1, 128); + break; + default: + TORCH_CHECK(false, "Unexpected block size, 1x", "nb_size"); + } + } + return; + } + + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 6 * kVecSize; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x11: + LAUNCH_TINYGEMM_KERNEL_NN(1, 16); + break; + case 0x12: + LAUNCH_TINYGEMM_KERNEL_NN(1, 32); + break; + case 0x13: + LAUNCH_TINYGEMM_KERNEL_NN(1, 48); + break; + case 0x14: + LAUNCH_TINYGEMM_KERNEL_NN(1, 64); + break; + case 0x15: + LAUNCH_TINYGEMM_KERNEL_NN(1, 80); + break; + case 0x16: + LAUNCH_TINYGEMM_KERNEL_NN(1, 96); + break; + // mb_size = 2 + case 0x21: + LAUNCH_TINYGEMM_KERNEL_NN(2, 16); + break; + case 0x22: + LAUNCH_TINYGEMM_KERNEL_NN(2, 32); + break; + case 0x23: + LAUNCH_TINYGEMM_KERNEL_NN(2, 48); + break; + case 0x24: + LAUNCH_TINYGEMM_KERNEL_NN(2, 64); + break; + case 0x25: + LAUNCH_TINYGEMM_KERNEL_NN(2, 80); + break; + case 0x26: + LAUNCH_TINYGEMM_KERNEL_NN(2, 96); + break; + // mb_size = 3 + case 0x31: + LAUNCH_TINYGEMM_KERNEL_NN(3, 16); + break; + case 0x32: + LAUNCH_TINYGEMM_KERNEL_NN(3, 32); + break; + case 0x33: + LAUNCH_TINYGEMM_KERNEL_NN(3, 48); + break; + case 0x34: + LAUNCH_TINYGEMM_KERNEL_NN(3, 64); + break; + case 0x35: + LAUNCH_TINYGEMM_KERNEL_NN(3, 80); + break; + case 0x36: + LAUNCH_TINYGEMM_KERNEL_NN(3, 96); + break; + // mb_size = 4 + case 0x41: + LAUNCH_TINYGEMM_KERNEL_NN(4, 16); + break; + case 0x42: + LAUNCH_TINYGEMM_KERNEL_NN(4, 32); + break; + case 0x43: + LAUNCH_TINYGEMM_KERNEL_NN(4, 48); + break; + case 0x44: + LAUNCH_TINYGEMM_KERNEL_NN(4, 64); + break; + case 0x45: + LAUNCH_TINYGEMM_KERNEL_NN(4, 80); + break; + case 0x46: + LAUNCH_TINYGEMM_KERNEL_NN(4, 96); + break; + default: + TORCH_CHECK(false, 
"Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void decode_attention_kernel_impl( + scalar_t* __restrict__ output, + float* __restrict__ attn_logits, + const scalar_t* __restrict__ query, + const scalar_t* __restrict__ k_buffer, + const scalar_t* __restrict__ v_buffer, + const index_t* __restrict__ req_to_token, + const int64_t* __restrict__ req_pool_indices, + const int64_t* __restrict__ seq_lens, + int64_t batches, + int64_t num_heads, + int64_t head_size, + int64_t head_size_v, + int64_t num_kv_splits, + int64_t k_strideN, + int64_t k_strideH, + int64_t v_strideN, + int64_t v_strideH, + float scaling, + float logit_cap, + int64_t max_num_reqs, + int64_t max_context_len, + int64_t max_total_num_tokens) { + using Vec = at::vec::Vectorized; + + // block length for k_buffer and v_buffer + constexpr int64_t BLOCK_N = 256; + + // strides + const int64_t q_strideM = num_heads * head_size; + const int64_t q_strideH = head_size; + const int64_t l_stride1 = num_kv_splits * (head_size_v + 1); + const int64_t l_stride2 = head_size_v + 1; + + const bool has_logit_cap = logit_cap > 0; + float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f; + + // parallel on [batches, num_heads, num_kv_splits] + at::parallel_for(0, batches * num_heads * num_kv_splits, 0, [&](int64_t begin, int64_t end) { + int64_t bs{0}, head_id{0}, kv_id{0}; + data_index_init(begin, bs, batches, head_id, num_heads, kv_id, num_kv_splits); + + // s_prime and s_delta + alignas(64) float s_i[BLOCK_N]; + float* __restrict__ s_delta = s_i; + + for (int64_t i = begin; i < end; ++i) { + // get query + const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + head_id * q_strideH; + + // get key/value + int64_t seq_len_kv = seq_lens[bs]; + int64_t req_pool_id = req_pool_indices[bs]; + TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!"); + TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!"); + + const int64_t SPLIT_SIZE = div_up(seq_len_kv, num_kv_splits); + const int64_t kv_start = kv_id * SPLIT_SIZE; + const int64_t kv_end = std::min(kv_start + SPLIT_SIZE, seq_len_kv); + + float m_prime = -std::numeric_limits::infinity(); + float s_prime = 0.f; + + // get v_prime, and init to zero + float* __restrict__ v_prime = attn_logits + i * (head_size_v + 1); + fill_stub(v_prime, 0.f, head_size_v); + + // loop over K and V sequence with BLOCK_N + for (int64_t n = kv_start; n < kv_end; n += BLOCK_N) { + int64_t n_size = std::min(BLOCK_N, kv_end - n); + + // calculate s_i <- scale * Q @ K + index_gemm_kernel_nt( + /* A */ q_ptr, + /* B */ k_buffer + head_id * k_strideH, + /* C */ s_i, + /* ind */ req_to_token + req_pool_id * max_context_len + n, + /* scl */ scaling, + /* M */ 1, + /* N */ n_size, + /* K */ head_size, + /* lda */ 1, + /* ldb */ k_strideN, + /* ldc */ 1, + /* mtt */ max_total_num_tokens); + + // TODO: `tanh` from torch uses sleef u10, going to be slow + if (has_logit_cap) { + at::vec::map( + [logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); }, + s_i, + s_i, + n_size); + } + + // m_i: max value per row + float m_i = at::vec::reduce_all([](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i, n_size); + m_i = std::max(m_i, m_prime); + + // m_delta <- exp(m' - m_i) + float m_delta = std::exp(m_prime - m_i); + + // s_delta <- exp(s_i - m_i) + at::vec::map([m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta, s_i, n_size); + + // s' <- s' * m_delta + sum(s_delta) + s_prime *= m_delta; + s_prime += at::vec::reduce_all([](Vec& 
x, Vec& y) { return x + y; }, s_delta, n_size); + + m_prime = m_i; + + // caculate V' <- s_delta @ V + V' * m_delta + index_gemm_kernel_nn( + /* A */ s_delta, + /* B */ v_buffer + head_id * v_strideH, + /* C */ v_prime, + /* ind */ req_to_token + req_pool_id * max_context_len + n, + /* scl */ &m_delta, + /* M */ 1, + /* N */ head_size_v, + /* K */ n_size, + /* lda */ 1, + /* ldb */ v_strideN, + /* ldc */ 1, + /* mtt */ max_total_num_tokens); + } // loop with KV blocks + + // only update v' when kv_split_size > 0 + if (kv_end > kv_start) { + float s = 1 / s_prime; + at::vec::map([s](Vec out) { return out * Vec(s); }, v_prime, v_prime, head_size_v); + + v_prime[head_size_v] = m_prime + std::log(s_prime); + } + + // move to the next index + data_index_step(bs, batches, head_id, num_heads, kv_id, num_kv_splits); + } + }); + + // parallel on [batches, num_heads] + at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) { + // NB: here we use logits[b][h][0] as acc, since + // for the first kv split (kv_id == 0): + // m_delta = std::exp(-inf) = 0 + // e_logic = std::exp(0) = 1 + // acc = acc * m_delta + tv * e_logic = tv + for (int64_t i = begin; i < end; ++i) { + float* __restrict__ acc = attn_logits + i * l_stride1; + + float s_prime = 0.f; + float m_prime = -std::numeric_limits::infinity(); + + // update acc with from each kv_split + for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) { + float* __restrict__ tv = acc + kv_id * l_stride2; + const float tlogic = (acc + kv_id * l_stride2)[head_size_v]; + + float m_i = std::max(tlogic, m_prime); + float m_delta = std::exp(m_prime - m_i); + float e_logic = std::exp(tlogic - m_i); + if (kv_id != 0) { + at::vec::map2( + [m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); }, + acc, + acc, + tv, + head_size_v); + } + + s_prime = s_prime * m_delta + e_logic; + m_prime = m_i; + } + + copy_stub(output + i * head_size_v, acc, 1 / s_prime, head_size_v); + } + }); +} + +template +void decode_attention_grouped_kernel_impl( + scalar_t* __restrict__ output, + float* __restrict__ attn_logits, + const scalar_t* __restrict__ query, + const scalar_t* __restrict__ k_buffer, + const scalar_t* __restrict__ v_buffer, + const index_t* __restrict__ req_to_token, + const int64_t* __restrict__ req_pool_indices, + const int64_t* __restrict__ seq_lens, + int64_t batches, + int64_t num_heads, + int64_t num_heads_kv, + int64_t head_size, + int64_t head_size_v, + int64_t num_kv_splits, + int64_t k_strideN, + int64_t k_strideH, + int64_t v_strideN, + int64_t v_strideH, + float scaling, + float logit_cap, + int64_t max_num_reqs, + int64_t max_context_len, + int64_t max_total_num_tokens) { + using Vec = at::vec::Vectorized; + + // block length for k_buffer and v_buffer + constexpr int64_t BLOCK_N = 256; + // block length for heads + // we parallel on [batches, divup(num_heads, BLOCK_H), num_kv_splits] + // use smaller BLOCK_H when batches is small to utilize all cores + constexpr int64_t kBLOCK_H = 16; + const int64_t BLOCK_H = std::min(4 * batches, kBLOCK_H); + + // strides + const int64_t q_strideM = num_heads * head_size; + const int64_t q_strideH = head_size; + const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1); + const int64_t l_stride1 = num_kv_splits * (head_size_v + 1); + const int64_t l_stride2 = head_size_v + 1; + + const bool has_logit_cap = logit_cap > 0; + float rlogit_cap = has_logit_cap ? 
1 / logit_cap : 0.f; + + // partition the heads into blocks for parallel + const int64_t num_groups = num_heads / num_heads_kv; + const int64_t num_blocks = div_up(num_heads, std::min(BLOCK_H, num_groups)); + const int64_t num_groups_per_block = div_up(num_groups, BLOCK_H); + const int64_t num_heads_per_block = std::min(num_groups, BLOCK_H); + + // parallel on [batches, num_blocks, num_kv_splits] + at::parallel_for(0, batches * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) { + int64_t bs{0}, head_id{0}, kv_id{0}; + data_index_init(begin, bs, batches, head_id, num_blocks, kv_id, num_kv_splits); + + alignas(64) float s_i[BLOCK_H * BLOCK_N]; + float* __restrict__ s_delta = s_i; + + alignas(64) float s_prime[BLOCK_H]; + alignas(64) float m_prime[BLOCK_H]; + alignas(64) float m_delta[BLOCK_H]; + + for (int64_t i = begin; i < end; ++i) { + const int64_t h_start = head_id * num_heads_per_block; + const int64_t h_end = std::min(h_start + num_heads_per_block, num_heads); + const int64_t h_size = h_end - h_start; + + // get query + const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + h_start * q_strideH; + + // kv head id and valid block head size + int64_t head_kv_id = head_id / num_groups_per_block; + int64_t seq_len_kv = seq_lens[bs]; + int64_t req_pool_id = req_pool_indices[bs]; + TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!"); + TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!"); + + const int64_t SPLIT_SIZE = div_up(seq_len_kv, num_kv_splits); + const int64_t kv_start = kv_id * SPLIT_SIZE; + const int64_t kv_end = std::min(kv_start + SPLIT_SIZE, seq_len_kv); + + fill_stub(s_prime, 0.f, BLOCK_H); + fill_stub(m_prime, -std::numeric_limits::infinity(), BLOCK_H); + + // get v_prime, and init to zero + float* __restrict__ v_prime = attn_logits + bs * l_stride0 + h_start * l_stride1 + kv_id * l_stride2; + for (int64_t h = 0; h < h_size; ++h) { + fill_stub(v_prime + h * l_stride1, 0.f, head_size_v); + } + + // loop over K and V sequence with BLOCK_N + for (int64_t n = kv_start; n < kv_end; n += BLOCK_N) { + int64_t n_size = std::min(BLOCK_N, kv_end - n); + + // calculate Q @ K + index_gemm_kernel_nt( + /* A */ q_ptr, + /* B */ k_buffer + head_kv_id * k_strideH, + /* C */ s_i, + /* ind */ req_to_token + req_pool_id * max_context_len + n, + /* scl */ scaling, + /* M */ h_size, + /* N */ n_size, + /* K */ head_size, + /* lda */ q_strideH, + /* ldb */ k_strideN, + /* ldc */ BLOCK_N, + /* mtt */ max_total_num_tokens); + + if (has_logit_cap) { + at::vec::map( + [logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); }, + s_i, + s_i, + n_size); + } + + // update the scaling coefficients + for (int64_t h = 0; h < h_size; ++h) { + // m_i: max value per row + float m_i = at::vec::reduce_all( + [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + h * BLOCK_N, n_size); + m_i = std::max(m_i, m_prime[h]); + + // m_delta <- exp(m' - m_i) + m_delta[h] = std::exp(m_prime[h] - m_i); + + // s_delta <- exp(s_i - m_i) + at::vec::map( + [m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + h * BLOCK_N, s_i + h * BLOCK_N, n_size); + + // s' <- s' * m_delta + sum(s_delta) + s_prime[h] *= m_delta[h]; + s_prime[h] += at::vec::reduce_all([](Vec& x, Vec& y) { return x + y; }, s_delta + h * BLOCK_N, n_size); + + m_prime[h] = m_i; + } + + // caculate V' <- s_delta @ V + V' * m_delta + index_gemm_kernel_nn( + /* A */ s_delta, + /* B */ v_buffer + head_kv_id * v_strideH, + /* C */ v_prime, + /* ind */ req_to_token + 
req_pool_id * max_context_len + n, + /* scl */ m_delta, + /* M */ h_size, + /* N */ head_size_v, + /* K */ n_size, + /* lda */ BLOCK_N, + /* ldb */ v_strideN, + /* ldc */ l_stride1, + /* mtt */ max_total_num_tokens); + } // loop with KV blocks + + // only update v' when kv_split_size > 0 + if (kv_end > kv_start) { + for (int64_t h = 0; h < h_size; ++h) { + float s = 1 / s_prime[h]; + at::vec::map( + [s](Vec out) { return out * Vec(s); }, v_prime + h * l_stride1, v_prime + h * l_stride1, head_size_v); + (v_prime + h * l_stride1)[head_size_v] = m_prime[h] + std::log(s_prime[h]); + } + } + + // move to the next index + data_index_step(bs, batches, head_id, num_blocks, kv_id, num_kv_splits); + } + }); + + // parallel on [batches, num_heads] + at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) { + // NB: same as above + for (int64_t i = begin; i < end; ++i) { + float* __restrict__ acc = attn_logits + i * l_stride1; + + float s_prime = 0.f; + float m_prime = -std::numeric_limits::infinity(); + + // update acc with from each kv_split + for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) { + float* __restrict__ tv = acc + kv_id * l_stride2; + const float tlogic = (acc + kv_id * l_stride2)[head_size_v]; + + float m_i = std::max(tlogic, m_prime); + float m_delta = std::exp(m_prime - m_i); + float e_logic = std::exp(tlogic - m_i); + if (kv_id != 0) { + at::vec::map2( + [m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); }, + acc, + acc, + tv, + head_size_v); + } + + s_prime = s_prime * m_delta + e_logic; + m_prime = m_i; + } + + copy_stub(output + i * head_size_v, acc, 1 / s_prime, head_size_v); + } + }); +} + +} // anonymous namespace + +// query: [num_tokens, num_heads, head_size] +// output: [num_tokens, num_heads, head_size] +// k_buffer: [max_total_num_tokens, num_heads, head_size] +// v_buffer: [max_total_num_tokens, num_heads, head_size_v] +// attn_logits: [num_seqs, num_heads, num_kv_splits, head_size_v + 1] +// req_to_token: [max_num_reqs, max_context_len] int32 or int64 +// req_pool_indices: [num_seqs] int64 +// seq_lens: [num_seqs] int64 +// +void decode_attention_cpu( + at::Tensor& query, + at::Tensor& output, + at::Tensor& k_buffer, + at::Tensor& v_buffer, + at::Tensor& attn_logits, + at::Tensor& req_to_token, + at::Tensor& req_pool_indices, + at::Tensor& seq_lens, + double sm_scale, + double logit_cap) { + RECORD_FUNCTION( + "sgl-kernel::decode_attention_cpu", + std::vector( + {query, output, k_buffer, v_buffer, attn_logits, req_to_token, req_pool_indices, seq_lens})); + + CHECK_INPUT(query); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer); + CHECK_DIM(3, query); + CHECK_DIM(3, k_buffer); + CHECK_DIM(3, v_buffer); + + int64_t num_seqs = seq_lens.size(0); + int64_t max_num_reqs = req_to_token.size(0); + int64_t max_context_len = req_to_token.size(1); + int64_t max_total_num_tokens = k_buffer.size(0); + + int64_t num_heads = query.size(1); + int64_t num_heads_kv = k_buffer.size(1); + int64_t head_size = query.size(2); + int64_t head_size_v = v_buffer.size(2); + + int64_t num_kv_splits = attn_logits.size(2); + + CHECK_EQ(attn_logits.size(0), num_seqs); + CHECK_EQ(attn_logits.size(1), num_heads); + CHECK_EQ(attn_logits.size(3), head_size_v + 1); + CHECK_EQ(attn_logits.scalar_type(), at::kFloat); + + // strides for k_buffer and v_buffer + int64_t k_strideN = k_buffer.stride(0); + int64_t k_strideH = k_buffer.stride(1); + int64_t v_strideN = v_buffer.stride(0); + int64_t v_strideH = 
v_buffer.stride(1); + + // check index data types + const auto index_dtype = req_to_token.scalar_type(); + TORCH_CHECK( + index_dtype == at::kInt || index_dtype == at::kLong, + "decode: expect req_to_token to be int32 or int64, got ", + index_dtype); + TORCH_CHECK(seq_lens.scalar_type() == at::kLong, "decode: expect req_lens to be int64, got ", seq_lens.scalar_type()); + TORCH_CHECK( + req_pool_indices.scalar_type() == at::kLong, + "decode: expect req_pool_indices to be int64, got ", + req_pool_indices.scalar_type()); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(query.scalar_type(), "decode_attention_kernel", [&] { + AT_DISPATCH_INDEX_TYPES(index_dtype, "decode_attention_indices", [&] { + if (num_heads == num_heads_kv) { + // MHA + decode_attention_kernel_impl( + output.data_ptr(), + attn_logits.data_ptr(), + query.data_ptr(), + k_buffer.data_ptr(), + v_buffer.data_ptr(), + req_to_token.data_ptr(), + req_pool_indices.data_ptr(), + seq_lens.data_ptr(), + num_seqs, + num_heads, + head_size, + head_size_v, + num_kv_splits, + k_strideN, + k_strideH, + v_strideN, + v_strideH, + sm_scale, + logit_cap, + max_num_reqs, + max_context_len, + max_total_num_tokens); + } else { + // GQA/MQA/MLA + decode_attention_grouped_kernel_impl( + output.data_ptr(), + attn_logits.data_ptr(), + query.data_ptr(), + k_buffer.data_ptr(), + v_buffer.data_ptr(), + req_to_token.data_ptr(), + req_pool_indices.data_ptr(), + seq_lens.data_ptr(), + num_seqs, + num_heads, + num_heads_kv, + head_size, + head_size_v, + num_kv_splits, + k_strideN, + k_strideH, + v_strideN, + v_strideH, + sm_scale, + logit_cap, + max_num_reqs, + max_context_len, + max_total_num_tokens); + } + }); + }); +} diff --git a/sgl-kernel/csrc/cpu/extend.cpp b/sgl-kernel/csrc/cpu/extend.cpp new file mode 100644 index 000000000..503cef538 --- /dev/null +++ b/sgl-kernel/csrc/cpu/extend.cpp @@ -0,0 +1,621 @@ +#include "common.h" +#include "gemm.h" +#include "vec.h" + +namespace { + +// [NOTE]: extend attention for CPU +// 1. tune BLOCK_M and BLOCK_N +// 2. can handle non-contiguous k_exttend and v_extend +// 3. computes attention for prefix and extend separately +// 4. TODO: vectorize `pack_vnni` and `pack_vnni2` +// +template +inline index_t get_index(index_t* ind, int i) { + return (ind == nullptr) ? 
(index_t)i : ind[i]; +} + +// convert to vnni format +// from [N, K/2, 2] to [K/2, N, 2] for bfloat16 and float16 +template +void pack_vnni( + scalar_t* __restrict__ dst, + const scalar_t* __restrict__ src, + const index_t* __restrict__ ind, + int N, + int K, + int ld_src, + int ld_dst) { + for (int n = 0; n < N; ++n) { + index_t index = get_index(ind, n); + for (int k = 0; k < K / 2; ++k) { + for (int d = 0; d < 2; ++d) { + dst[k * ld_dst * 2 + n * 2 + d] = src[index * ld_src + k * 2 + d]; + } + } + } +} + +// convert to vnni format +// from [K/2, 2, N] to [K/2, N, 2] for bfloat16 and float16 +template +void pack_vnni2( + scalar_t* __restrict__ dst, + const scalar_t* __restrict__ src, + const index_t* __restrict__ ind, + int K, + int N, + int ld_src, + int ld_dst) { + int k = 0; + for (; k < (K >> 1) * 2; k += 2) { + index_t index0 = get_index(ind, k + 0); + index_t index1 = get_index(ind, k + 1); + for (int n = 0; n < N; ++n) { + dst[(k >> 1) * ld_dst * 2 + n * 2 + 0] = src[index0 * ld_src + n]; + dst[(k >> 1) * ld_dst * 2 + n * 2 + 1] = src[index1 * ld_src + n]; + } + } + if (K % 2 != 0) { + index_t index = get_index(ind, K - 1); + for (int n = 0; n < N; ++n) { + dst[(K >> 1) * ld_dst * 2 + n * 2 + 0] = src[index * ld_src + n]; + dst[(K >> 1) * ld_dst * 2 + n * 2 + 1] = 0; + } + k += 2; + } + // TODO: check whether we can skip this! + // const int padded_K = div_up(K, TILE_K) * TILE_K; + // for (; k < padded_K; ++k) { + // for (int n = 0; n < N; ++n) { + // dst[k * ld_dst + n] = static_cast(0); + // } + // } +} + +template +inline void fill_stub(scalar_t* __restrict__ out, float val, int size) { + using Vec = at::vec::Vectorized; + const Vec data_vec = Vec(static_cast(val)); + int d = 0; + for (; d <= size - Vec::size(); d += Vec::size()) { + data_vec.store(out + d); + } + if (size - d > 0) { + data_vec.store(out + d, size - d); + } +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input) { + static_assert(BLOCK_N % 32 == 0); + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + + constexpr int COLS = BLOCK_N / 16; + auto store = [&](auto i) { + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + fVec a_fvec0 = fVec::loadu(input + col * 16); + fVec a_fvec1 = fVec::loadu(input + col * 16 + 16); + bVec out_bvec = convert_from_float_ext(a_fvec0, a_fvec1); + out_bvec.store(out + col * 16); + } + }; + Unroll{}(store); +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + const fVec s_fvec = fVec(s); + int d = 0; + for (; d <= size - bVec::size(); d += bVec::size()) { + fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec; + fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec; + bVec out_bvec = convert_from_float_ext(a_fvec0, a_fvec1); + out_bvec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(acc[d] * s); + } +} + +template +void extend_attention_kernel_impl( + scalar_t* __restrict__ o_extend, + const scalar_t* __restrict__ q_extend, + const scalar_t* __restrict__ k_extend, + const scalar_t* __restrict__ v_extend, + const scalar_t* __restrict__ k_buffer, + const scalar_t* __restrict__ v_buffer, + const index_t* __restrict__ req_to_token, + const int64_t* __restrict__ req_pool_indices, + const int64_t* __restrict__ seq_lens, + const index_t* __restrict__ extend_seq_lens, + const index_t* __restrict__ extend_start_loc, + const 
void* __restrict__ buffer, + int batches, + int num_heads, + int num_heads_kv, + int head_size, + int head_size_v, + int ke_strideN, + int ke_strideH, + int ve_strideN, + int ve_strideH, + int k_strideN, + int k_strideH, + int v_strideN, + int v_strideH, + float scaling, + float logit_cap, + int max_num_reqs, + int max_context_len, + int max_total_num_tokens, + int max_len_extend, + int buffer_size_per_thread, + bool is_prefix_skipped) { + using Vec = at::vec::Vectorized; + + // strides + const int q_strideM = num_heads * head_size; + const int q_strideH = head_size; + const int o_strideM = num_heads * head_size_v; + const int o_strideH = head_size_v; + + // we use same buffer for packed key and value + const int ldb_tmp = std::max(head_size, head_size_v); + + const bool has_logit_cap = logit_cap > 0; + float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f; + + const int num_groups = num_heads / num_heads_kv; + TORCH_CHECK(num_groups * num_heads_kv == num_heads); + + // number of blocks along M + int MB = div_up(max_len_extend, BLOCK_M); + + // parallel on [batches, num_heads, BM] + at::parallel_for(0, batches * num_heads * MB, 0, [&](int begin, int end) { + int bs{0}, head_id{0}, mb{0}; + data_index_init(begin, bs, batches, head_id, num_heads, mb, MB); + + int tid = at::get_thread_num(); + // s_i and s_delta: [BLOCK_M, BLOCK_N] + float* __restrict__ s_i = reinterpret_cast((char*)(buffer) + tid * buffer_size_per_thread); + float* __restrict__ s_delta = s_i; + + // v_prime: [BLOCK_M, head_size_v] + float* __restrict__ v_prime = s_i + BLOCK_M * BLOCK_N; + + // s_delta2: [BLOCK_M, BLOCK_N]; copy of s_delta in scalar_t + scalar_t* __restrict__ s_delta2 = reinterpret_cast(v_prime + BLOCK_N * head_size_v); + + // Btmp: [BLOCK_N, max(head_size, head_size_v)] + scalar_t* __restrict__ Btmp = s_delta2 + BLOCK_M * BLOCK_N; + + // init Btmp just once for each thread to prevent NaN + fill_stub(Btmp, 0.f, BLOCK_N * ldb_tmp); + + alignas(64) float s_prime[BLOCK_M]; + alignas(64) float m_prime[BLOCK_M]; + + for (int i = begin; i < end; ++i) { + // seq_len = prefix + extend + int head_kv_id = head_id / num_groups; + int seq_len = seq_lens[bs]; + int seq_len_extend = extend_seq_lens[bs]; + int seq_len_prefix = seq_len - seq_len_extend; + int seq_extend_start_loc = extend_start_loc[bs]; + + int req_pool_id = req_pool_indices[bs]; + TORCH_CHECK(seq_len_prefix >= 0, "prefix len < 0!"); + TORCH_CHECK(seq_len <= max_context_len, "seq_len out of scope!"); + TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!"); + + if (is_prefix_skipped) { + TORCH_CHECK(seq_len_prefix == 0, "extend attention: expect seq_len_prefix to be 0, got ", seq_len_prefix); + } + + // offset and size in MB + int m = mb * BLOCK_N; + int m_size = std::min(BLOCK_M, seq_len_extend - m); + + if (m_size <= 0) { + data_index_step(bs, batches, head_id, num_heads, mb, MB); + continue; + } + + // get query + const scalar_t* __restrict__ q_ptr = q_extend + (seq_extend_start_loc + m) * q_strideM + head_id * q_strideH; + + // init v', s' and m' + fill_stub(v_prime, 0.f, m_size * head_size_v); + fill_stub(s_prime, 0.f, m_size); + fill_stub(m_prime, -std::numeric_limits::infinity(), m_size); + + // stage 1: compute scores with prefix + for (int n = 0; n < seq_len_prefix; n += BLOCK_N) { + int n_size = std::min(BLOCK_N, seq_len_prefix - n); + + // `n_size` is K in 2nd gemm, pad to TILE_K; + const int padded_n_size = div_up(n_size, TILE_K) * TILE_K; + + // get key and pack + pack_vnni( + /* dst */ Btmp, + /* src */ k_buffer + head_kv_id * 
k_strideH, + /* ind */ req_to_token + req_pool_id * max_context_len + n, + /* N */ n_size, + /* K */ head_size, + /* ld_src */ k_strideN, + /* ld_dst */ BLOCK_N); + + // calculate s_i <- Q @ K + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ head_size, + /* lda */ q_strideM, + /* ldb */ BLOCK_N, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ q_ptr, + /* B */ Btmp, + /* C */ s_i); + + const Vec scale_vec = Vec(scaling); + for (int row = 0; row < m_size; ++row) { + // s_i <- s_i * scale + at::vec::map( + [scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size); + + // TODO: `tanh` from torch uses sleef u10, going to be slow + if (has_logit_cap) { + at::vec::map( + [logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); }, + s_i + row * BLOCK_N, + s_i + row * BLOCK_N, + n_size); + } + + // m_i: max value per row + float m_i = at::vec::reduce_all( + [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size); + m_i = std::max(m_i, m_prime[row]); + + // m_delta <- exp(m' - m_i) + float m_delta = std::exp(m_prime[row] - m_i); + + // s_delta <- exp(s_i - m_i) + at::vec::map( + [m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size); + + // s' <- s' * m_delta + sum(s_delta) + s_prime[row] *= m_delta; + s_prime[row] += + at::vec::reduce_all([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size); + + m_prime[row] = m_i; + + // v' <- v' * m_delta + at::vec::map( + [m_delta](Vec x) { return x * Vec(m_delta); }, + v_prime + row * head_size_v, + v_prime + row * head_size_v, + head_size_v); + + // pad s_delta with 0 first and then convert to scalar_t + fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size); + copy_stub(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N); + } + + // get value and pack + pack_vnni2( + /* dst */ Btmp, + /* src */ v_buffer + head_kv_id * v_strideH, + /* ind */ req_to_token + req_pool_id * max_context_len + n, + /* K */ n_size, + /* N */ head_size_v, + /* ld_src */ v_strideN, + /* ld_dst */ head_size_v); + + // caculate V' <- s_delta @ V + V' + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ head_size_v, + /* K */ padded_n_size, // n_size + /* lda */ BLOCK_N, + /* ldb */ head_size_v, + /* ldc */ head_size_v, + /* add_C */ true, + /* A */ s_delta2, + /* B */ Btmp, + /* C */ v_prime); + } // loop with seq_len_prefix + + // stage 2: compute the triangle part + int num_keys = std::min(seq_len_extend, m + BLOCK_M); + for (int n = 0; n < num_keys; n += BLOCK_N) { + int n_size = std::min(BLOCK_N, num_keys - n); + + // `n_size` is K in 2nd gemm, pad to TILE_K; + const int padded_n_size = div_up(n_size, TILE_K) * TILE_K; + + // get key and pack + pack_vnni( + /* dst */ Btmp, + /* src */ k_extend + (seq_extend_start_loc + n) * ke_strideN + head_kv_id * ke_strideH, + /* ind */ nullptr, + /* N */ n_size, + /* K */ head_size, + /* ld_src */ ke_strideN, + /* ld_dst */ BLOCK_N); + + // calculate s_i <- Q @ K + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ head_size, + /* lda */ q_strideM, + /* ldb */ BLOCK_N, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ q_ptr, + /* B */ Btmp, + /* C */ s_i); + + // apply causal mask + if (num_keys - n <= BLOCK_N) { + for (int row = 0; row < m_size; ++row) { + int last_col = m + row - n; + // fill [last_col + 1, n_size) to -inf + float* row_ptr = s_i + row * BLOCK_N; + fill_stub(row_ptr + last_col + 1, 
-std::numeric_limits::infinity(), n_size - last_col - 1); + } + } + + const Vec scale_vec = Vec(scaling); + for (int row = 0; row < m_size; ++row) { + // s_i <- s_i * scale + at::vec::map( + [scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size); + + // TODO: `tanh` from torch uses sleef u10, going to be slow + if (has_logit_cap) { + at::vec::map( + [logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); }, + s_i + row * BLOCK_N, + s_i + row * BLOCK_N, + n_size); + } + + // m_i: max value per row + float m_i = at::vec::reduce_all( + [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size); + m_i = std::max(m_i, m_prime[row]); + + // m_delta <- exp(m' - m_i) + float m_delta = std::exp(m_prime[row] - m_i); + + // s_delta <- exp(s_i - m_i) + at::vec::map( + [m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size); + + // s' <- s' * m_delta + sum(s_delta) + s_prime[row] *= m_delta; + s_prime[row] += + at::vec::reduce_all([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size); + + m_prime[row] = m_i; + + // v' <- v' * m_delta + at::vec::map( + [m_delta](Vec x) { return x * Vec(m_delta); }, + v_prime + row * head_size_v, + v_prime + row * head_size_v, + head_size_v); + + // pad s_delta with 0 first and then convert to scalar_t + fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size); + copy_stub(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N); + } + + // get value and pack + pack_vnni2( + /* dst */ Btmp, + /* src */ v_extend + (seq_extend_start_loc + n) * ve_strideN + head_kv_id * ve_strideH, + /* ind */ nullptr, + /* K */ n_size, + /* N */ head_size_v, + /* ld_src */ ve_strideN, + /* ld_dst */ head_size_v); + + // caculate V' <- s_delta @ V + V' + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ head_size_v, + /* K */ padded_n_size, // n_size + /* lda */ BLOCK_N, + /* ldb */ head_size_v, + /* ldc */ head_size_v, + /* add_C */ true, + /* A */ s_delta2, + /* B */ Btmp, + /* C */ v_prime); + } // loop with seq_len_extend + + scalar_t* __restrict__ out_ptr = o_extend + (seq_extend_start_loc + m) * o_strideM + head_id * o_strideH; + for (int row = 0; row < m_size; ++row) { + float s = 1 / s_prime[row]; + copy_stub(out_ptr + row * o_strideM, v_prime + row * head_size_v, s, head_size_v); + } + + // move to the next index + data_index_step(bs, batches, head_id, num_heads, mb, MB); + } + at::native::cpublas::brgemm_release(); + }); +} + +} // anonymous namespace + +// q_extend, k_extend, v_extend, o_extend: contiguous tensors +// k_buffer, v_buffer: (prefix + extend) tensors in mem_manager +// +// q_extend: [num_tokens, num_heads, head_size] +// k_extend: [num_extend_tokens, num_heads, head_size] +// v_extend: [num_extend_tokens, num_heads, head_size] +// o_extend: [num_tokens, num_heads, head_size] +// k_buffer: [max_total_num_tokens, num_heads, head_size] +// v_buffer: [max_total_num_tokens, num_heads, head_size] +// req_to_token: [max_num_reqs, max_context_len] int32 or int64 +// req_pool_indices: [num_seqs] int64 +// seq_lens: [num_seqs] int64 +// extend_seq_lens: [num_seqs] +// extend_start_loc: [num_seqs] +// +void extend_attention_cpu( + at::Tensor& q_extend, + at::Tensor& k_extend, + at::Tensor& v_extend, + at::Tensor& o_extend, + at::Tensor& k_buffer, + at::Tensor& v_buffer, + at::Tensor& req_to_token, + at::Tensor& req_pool_indices, + at::Tensor& seq_lens, + at::Tensor& extend_seq_lens, + at::Tensor& 
extend_start_loc, + int64_t max_len_extend, + double sm_scale, + double logit_cap) { + RECORD_FUNCTION( + "sgl-kernel::extend_attention_cpu", + std::vector( + {q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_token, + req_pool_indices, + seq_lens, + extend_seq_lens, + extend_start_loc})); + + CHECK_INPUT(q_extend); + CHECK_INPUT(o_extend); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer); + + int num_seqs = seq_lens.size(0); + int max_num_reqs = req_to_token.size(0); + int max_context_len = req_to_token.size(1); + int max_total_num_tokens = k_buffer.size(0); + + int num_heads = q_extend.size(1); + int num_heads_kv = k_extend.size(1); + int head_size = q_extend.size(2); + int head_size_v = v_extend.size(2); + + // strides for k_extend and v_extend + int ke_strideN = k_extend.stride(0); + int ke_strideH = k_extend.stride(1); + int ve_strideN = v_extend.stride(0); + int ve_strideH = v_extend.stride(1); + + // strides for k_buffer and v_buffer + int k_strideN = k_buffer.stride(0); + int k_strideH = k_buffer.stride(1); + int v_strideN = v_buffer.stride(0); + int v_strideH = v_buffer.stride(1); + + // check sizes + CHECK_EQ(req_pool_indices.size(0), num_seqs); + CHECK_EQ(extend_seq_lens.size(0), num_seqs); + CHECK_EQ(extend_start_loc.size(0), num_seqs); + CHECK_EQ(v_extend.size(1), num_heads_kv); + CHECK_EQ(k_buffer.size(1), v_buffer.size(1)); + + // MLA will skip prefix part + const bool is_prefix_skipped = k_buffer.size(1) != num_heads_kv; + + // check index data types + const auto index_dtype = req_to_token.scalar_type(); + TORCH_CHECK( + index_dtype == at::kInt || index_dtype == at::kLong, + "extend: expect req_to_token to be int32 or int64, got ", + index_dtype); + TORCH_CHECK(seq_lens.scalar_type() == at::kLong, "extend: expect req_lens to be int64, got ", seq_lens.scalar_type()); + TORCH_CHECK( + req_pool_indices.scalar_type() == at::kLong, + "extend: expect req_pool_indices to be int64, got ", + req_pool_indices.scalar_type()); + TORCH_CHECK( + extend_seq_lens.scalar_type() == index_dtype && extend_start_loc.scalar_type() == index_dtype, + "extend: expect extend_seq_lens and extend_start_loc to have same dtype as req_to_token."); + + // D and DV need to be 32x as we transpose by 512-bit + TORCH_CHECK(head_size % 32 == 0, "invalid head_size ", head_size); + TORCH_CHECK(head_size_v % 32 == 0, "invalid head_size_v ", head_size_v); + + // block size for query seq length + constexpr int BLOCK_M = 32; + // block size for key/value seq length + constexpr int BLOCK_N = 32; + + const int size_per_thread = + /* s_i */ BLOCK_M * BLOCK_N * sizeof(float) + + /* v_prime */ BLOCK_M * head_size_v * sizeof(float) + + /* s_delta */ BLOCK_M * BLOCK_N * sizeof(uint16_t) + + /* Btmp */ BLOCK_N * std::max(head_size, head_size_v) * sizeof(uint16_t); + + int num_threads = at::get_num_threads(); + auto buffer = at::empty({num_threads, size_per_thread}, q_extend.options().dtype(at::kChar)); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(q_extend.scalar_type(), "extend_attention_kernel", [&] { + AT_DISPATCH_INDEX_TYPES(index_dtype, "extend_attention_indices", [&] { + extend_attention_kernel_impl( + o_extend.data_ptr(), + q_extend.data_ptr(), + k_extend.data_ptr(), + v_extend.data_ptr(), + k_buffer.data_ptr(), + v_buffer.data_ptr(), + req_to_token.data_ptr(), + req_pool_indices.data_ptr(), + seq_lens.data_ptr(), + extend_seq_lens.data_ptr(), + 
extend_start_loc.data_ptr(), + buffer.data_ptr(), + num_seqs, + num_heads, + num_heads_kv, + head_size, + head_size_v, + ke_strideN, + ke_strideH, + ve_strideN, + ve_strideH, + k_strideN, + k_strideH, + v_strideN, + v_strideH, + sm_scale, + logit_cap, + max_num_reqs, + max_context_len, + max_total_num_tokens, + max_len_extend, + size_per_thread, + is_prefix_skipped); + }); + }); +} diff --git a/sgl-kernel/csrc/cpu/gemm.cpp b/sgl-kernel/csrc/cpu/gemm.cpp new file mode 100644 index 000000000..97c0e7935 --- /dev/null +++ b/sgl-kernel/csrc/cpu/gemm.cpp @@ -0,0 +1,507 @@ +#include "gemm.h" + +#include "common.h" +#include "vec.h" + +namespace { + +// packed layout: +// quants {N, K} int8_t +// comp {N} int32_t +template +inline void s8s8_compensation(int8_t* __restrict__ packed, int K) { +#if defined(CPU_CAPABILITY_AVX512) + constexpr int COLS = BLOCK_N / 16; + __m512i vcomp[COLS]; + + for (int col = 0; col < COLS; ++col) { + vcomp[col] = _mm512_setzero_si512(); + } + + const int64_t offset = BLOCK_N * K; + const __m512i off = _mm512_set1_epi8(static_cast(0x80)); + for (int k = 0; k < K / 4; ++k) { + for (int col = 0; col < COLS; ++col) { + __m512i vb = _mm512_loadu_si512((const __m512i*)(packed + k * BLOCK_N * 4 + col * 64)); + vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb); + } + } + + for (int col = 0; col < COLS; ++col) { + _mm512_storeu_si512((__m512i*)(packed + offset + col * 64), vcomp[col]); + } +#else + TORCH_CHECK(false, "s8s8_compensation not implemented!"); +#endif +} + +// convert to vnni format +// from [N, K] to [K/2, N, 2] for bfloat16 and float16 +template +inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) { + const int VNNI_BLK = 2; + for (int n = 0; n < N; ++n) { + for (int k = 0; k < K / VNNI_BLK; ++k) { + for (int d = 0; d < VNNI_BLK; ++d) { + packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; + } + } + } +} + +template <> +inline void pack_vnni(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) { + constexpr int BLOCK_N = block_size_n(); + TORCH_CHECK(N == BLOCK_N); + + const int VNNI_BLK = 4; + for (int n = 0; n < N; ++n) { + for (int k = 0; k < K / VNNI_BLK; ++k) { + for (int d = 0; d < VNNI_BLK; ++d) { + packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d]; + } + } + } + s8s8_compensation(packed, K); +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d); + fVec data1 = fVec::loadu(input + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d]); + } +} + +template +inline void copy_add_stub( + scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d); + fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size()); + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + 
d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + bias[d]); + } +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + const float* __restrict__ bias, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, + const at::BFloat16* __restrict__ B, + at::BFloat16* __restrict__ C, + const float* __restrict__ bias, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](auto i) { + constexpr int col = i % COLS; + if constexpr (has_bias) { + vc[i] = _mm512_loadu_ps(bias + col * 16); + } else { + vc[i] = _mm512_set1_ps(0.f); + } + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + // for COLS = 1, 3 use 256bit store + if constexpr (COLS % 2 == 0) { + if constexpr (col % 2 == 0) { + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col]))); + } + } else { + _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + row * ldc + col * 16), (__m256i)(_mm512_cvtneps_pbh(vc[i]))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, \ + B + nb_start * 2, \ + C + mb_start * ldc + nb_start, \ + has_bias ? 
bias + nb_start : nullptr, \ + K, \ + lda, \ + ldb, \ + ldc); + +template +struct brgemm { + static inline void apply( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + constexpr int BLOCK_N = block_size_n(); + at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp); + + // copy from Ctmp to C + for (int64_t m = 0; m < M; ++m) { + if constexpr (has_bias) { + copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N); + } else { + copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N); + } + } + } +}; + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + float* __restrict__ Ctmp, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg) { + if (brg) { + brgemm::apply(A, B, C, Ctmp, bias, M, N, K, lda, ldb, ldc); + return; + } + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: + LAUNCH_TINYGEMM_KERNEL_NN(1, 32); + break; + case 0x14: + LAUNCH_TINYGEMM_KERNEL_NN(1, 64); + break; + // mb_size = 2 + case 0x22: + LAUNCH_TINYGEMM_KERNEL_NN(2, 32); + break; + case 0x24: + LAUNCH_TINYGEMM_KERNEL_NN(2, 64); + break; + // mb_size = 3 + case 0x32: + LAUNCH_TINYGEMM_KERNEL_NN(3, 32); + break; + case 0x34: + LAUNCH_TINYGEMM_KERNEL_NN(3, 64); + break; + // mb_size = 4 + case 0x42: + LAUNCH_TINYGEMM_KERNEL_NN(4, 32); + break; + case 0x44: + LAUNCH_TINYGEMM_KERNEL_NN(4, 64); + break; + default: + TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void weight_packed_linear_kernel_impl( + scalar_t* __restrict__ out, + const scalar_t* __restrict__ mat1, + const scalar_t* __restrict__ mat2, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t mat1_strideM, + int64_t out_strideM) { + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx + const bool use_brgemm = (M > 4) || (!std::is_same_v); + + // parallel on [MB, NB] + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + // for brgemm, use float32 for accumulate + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(M - mb_start, BLOCK_M); + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * mat1_strideM, + /* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */, + /* C */ out + mb_start * out_strideM + nb_start, + /* Ctmp*/ Ctmp, + /* bias*/ bias + nb_start, + /* M */ 
mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ mat1_strideM, + /* ldb */ nb_size, + /* ldc */ out_strideM, + /* brg */ use_brgemm); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + float* __restrict__ Ctmp, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg) { + tinygemm_kernel(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const TYPE* __restrict__ A, \ + const TYPE* __restrict__ B, \ + TYPE* __restrict__ C, \ + float* __restrict__ Ctmp, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t lda, \ + int64_t ldb, \ + int64_t ldc, \ + bool brg) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +at::Tensor convert_weight_packed(at::Tensor& weight) { + // for 3d moe weights + // weight : [E, OC, IC] + // w1 : [E, 2N, K] + // w2 : [E, K, N] + CHECK_INPUT(weight); + + const int64_t ndim = weight.ndimension(); + TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor."); + const auto st = weight.scalar_type(); + const int64_t E = ndim == 3 ? weight.size(0) : 1; + const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0); + const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1); + + // we handle 2 TILE_N at a time. + TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC); + TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC); + + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t NB = div_up(OC, BLOCK_N); + + // use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2] + auto packed_weight = at::empty({}, weight.options()); + const int64_t stride = OC * IC; + + TORCH_CHECK( + st == at::kBFloat16 || st == at::kHalf || st == at::kChar, "expect weight to be bfloat16, float16 or int8."); + + CPU_DISPATCH_PACKED_TYPES(st, [&] { + // adjust most inner dimension size + const int packed_row_size = get_row_size(IC); + auto sizes = weight.sizes().vec(); + sizes[ndim - 1] = packed_row_size; + packed_weight.resize_(sizes); + + const packed_t* w_data = weight.data_ptr(); + packed_t* packed_data = packed_weight.data_ptr(); + + // parallel on {E, NB} + at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) { + int64_t e{0}, nb{0}; + data_index_init(begin, e, E, nb, NB); + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + + int64_t n = nb * BLOCK_N; + int64_t n_size = std::min(BLOCK_N, OC - n); + pack_vnni( + packed_data + e * OC * packed_row_size + n * packed_row_size, w_data + e * stride + n * IC, n_size, IC); + + // move to the next index + data_index_step(e, E, nb, NB); + } + }); + }); + return packed_weight; +} + +// mat1 : [M, K] +// mat2 : [N, K] +// bias : [N] +// out : [M, N] +// +at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional& bias, bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector({mat1, mat2, bias})); + + auto packed_w = is_vnni ? 
mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat2.size(1); + CHECK_EQ(mat1.size(1), K); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + auto out = at::empty({M, N}, mat1.options()); + + // strides + int64_t mat1_strideM = mat1.stride(0); + int64_t out_strideM = out.stride(0); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] { + weight_packed_linear_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + bias_data, + M, + N, + K, + mat1_strideM, + out_strideM); + }); + + return out; +} diff --git a/sgl-kernel/csrc/cpu/gemm.h b/sgl-kernel/csrc/cpu/gemm.h new file mode 100644 index 000000000..010f50a0c --- /dev/null +++ b/sgl-kernel/csrc/cpu/gemm.h @@ -0,0 +1,130 @@ +#pragma once + +#include + +// amx-bf16 +#define TILE_M 16 +#define TILE_N 16 +#define TILE_K 32 + +// block size for AMX gemm +constexpr int block_size_m() { + return 2 * TILE_M; +} +constexpr int block_size_n() { + return 2 * TILE_N; +} + +// define threshold using brgemm (intel AMX) +template +inline bool can_use_brgemm(int M); +template <> +inline bool can_use_brgemm(int M) { + return M > 4; +} +template <> +inline bool can_use_brgemm(int M) { + return true; +} +// TODO: add u8s8 brgemm, this requires PyTorch 2.7 +template <> +inline bool can_use_brgemm(int M) { + return false; +} + +// work around compiler internal error +#define BLOCK_K 128 // 4 * TILE_K + +// adjust leading dimension size for K +template +inline int64_t get_row_size(int64_t K) { + return K; +} + +template <> +inline int64_t get_row_size(int64_t K) { + return K + sizeof(int32_t); +} + +inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) { + return use_int8_w8a8 ? 
K + sizeof(int32_t) : K; +} + +// pack weight to vnni format +at::Tensor convert_weight_packed(at::Tensor& weight); + +// moe implementations for int8 w8a8 +template +void fused_experts_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + uint8_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad); + +// shared expert implememntation for int8 w8a8 +template +void shared_expert_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K); + +// tinygemm interface +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + scalar_t* __restrict__ C, + float* __restrict__ Ctmp, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, + const float* __restrict__ As, + const float* __restrict__ Bs, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg); diff --git a/sgl-kernel/csrc/cpu/gemm_int8.cpp b/sgl-kernel/csrc/cpu/gemm_int8.cpp new file mode 100644 index 000000000..ba383076a --- /dev/null +++ b/sgl-kernel/csrc/cpu/gemm_int8.cpp @@ -0,0 +1,489 @@ +#include "common.h" +#include "gemm.h" +#include "vec.h" + +namespace { + +template +struct tinygemm_kernel_nn { + static inline void apply( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + scalar_t* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs, + const int32_t* __restrict__ Bcomp, + const float* __restrict__ bias, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + at::BFloat16* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs, + const int32_t* __restrict__ Bcomp, + const float* __restrict__ bias, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512i va; + __m512i vb[COLS]; + __m512i vc[ROWS * COLS]; + __m512i vcomp[COLS]; + __m512 vd0; + __m512 vd1[COLS]; + + // oops! 
4x4 spills but luckly we use 4x2 + __m512 vbias[COLS]; + + // [NOTE]: s8s8 igemm compensation in avx512-vnni + // + // avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate: + // + // a * b = (a + 128) * b - 128 * b + // s s u s u s + // + // 1) 128 * b is pre-computed when packing B to vnni formats + // 2) a + 128 is fused when dynamically quantize A + // + auto loadc = [&](auto i) { vc[i] = _mm512_set1_epi32(0); }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr (col == 0) { + vd0 = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp per 2 vectors + // also load bias if any + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16); + vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); + vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); + vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); + if constexpr (has_bias) { + vbias[col + 0] = _mm512_loadu_ps(bias + col * 16); + vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16); + } + } + } + + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + __m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0])); + __m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1])); + if constexpr (has_bias) { + vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]); + vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]); + } else { + vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]); + vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]); + } + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, \ + B + nb_start * 4, \ + C + mb_start * ldc + nb_start, \ + As + mb_start, \ + Bs + nb_start, \ + Bcomp + nb_start, \ + has_bias ? 
bias + nb_start : nullptr, \ + K, \ + lda, \ + ldb, \ + ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, + const float* __restrict__ As, + const float* __restrict__ Bs, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg) { + // B compensation + const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: + LAUNCH_TINYGEMM_KERNEL_NN(1, 32); + break; + case 0x14: + LAUNCH_TINYGEMM_KERNEL_NN(1, 64); + break; + // mb_size = 2 + case 0x22: + LAUNCH_TINYGEMM_KERNEL_NN(2, 32); + break; + case 0x24: + LAUNCH_TINYGEMM_KERNEL_NN(2, 64); + break; + // mb_size = 3 + case 0x32: + LAUNCH_TINYGEMM_KERNEL_NN(3, 32); + break; + case 0x34: + LAUNCH_TINYGEMM_KERNEL_NN(3, 64); + break; + // mb_size = 4 + case 0x42: + LAUNCH_TINYGEMM_KERNEL_NN(4, 32); + break; + case 0x44: + LAUNCH_TINYGEMM_KERNEL_NN(4, 64); + break; + default: + TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void int8_scaled_mm_kernel_impl( + scalar_t* __restrict__ out, + const uint8_t* __restrict__ mat1, + const int8_t* __restrict__ mat2, + const float* __restrict__ scales1, + const float* __restrict__ scales2, + const float* __restrict__ bias, + int64_t M, + int64_t N, + int64_t K) { + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // TODO: brgemm u8s8 depends on PyTorch 2.7 release. 
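+  // Until that lands, this kernel always takes the AVX512-VNNI tinygemm path below;
+  // the int32_t Ctmp accumulator is reserved for the future u8s8 brgemm path and is
+  // unused today, so only this flag needs to flip once the dependency is available.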
+ const bool use_brgemm = false; + + // K + 4 after compensation + const int64_t packed_row_size = get_row_size(K); + + AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] { + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + // for brgemm, use int32_t for accumulate + alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N]; + + for (int i = begin; i < end; ++i) { + UNUSED(i); + int mb_start = mb * BLOCK_M; + int mb_size = std::min(M - mb_start, BLOCK_M); + int nb_start = nb * BLOCK_N; + int nb_size = std::min(N - nb_start, BLOCK_N); + + tinygemm_kernel( + /* A */ mat1 + mb_start * K, + /* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */, + /* C */ out + mb_start * N + nb_start, + /* Ctmp*/ Ctmp, + /* As */ scales1 + mb_start, + /* Bs */ scales2 + nb_start, + /* bias*/ bias + nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ K, + /* ldb */ nb_size, + /* ldc */ N, + /* brg */ use_brgemm); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); + }); +} + +} // anonymous namespace + +// tinygemm interface +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + scalar_t* __restrict__ C, + int32_t* __restrict__ Ctmp, + const float* __restrict__ As, + const float* __restrict__ Bs, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc, + bool brg) { + tinygemm_kernel(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg); +} + +#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \ + template void tinygemm_kernel( \ + const uint8_t* __restrict__ A, \ + const int8_t* __restrict__ B, \ + TYPE* __restrict__ C, \ + int32_t* __restrict__ Ctmp, \ + const float* __restrict__ As, \ + const float* __restrict__ Bs, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t lda, \ + int64_t ldb, \ + int64_t ldc, \ + bool brg) + +INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16); +INSTANTIATE_TINYGEMM_TEMPLATE(at::Half); + +std::tuple per_token_quant_int8_cpu(at::Tensor& A) { + RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector({A})); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(A); + CHECK_DIM(2, A); + + int64_t M = A.size(0); + int64_t K = A.size(1); + int64_t lda = A.stride(0); + + const auto st = A.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "per_token_quant_int8: expect A to be bfloat16 or half."); + + auto Aq = at::empty({M, K}, A.options().dtype(at::kByte)); + auto As = at::empty({M}, A.options().dtype(at::kFloat)); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] { + uint8_t* __restrict__ Aq_data = Aq.data_ptr(); + float* __restrict__ As_data = As.data_ptr(); + const scalar_t* __restrict__ A_data = A.data_ptr(); + + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8(Aq_data + m * K, As_data[m], A_data + m * lda, K); + } + }); + }); + return std::make_tuple(Aq, As); +} + +// weight : static, per-channel, symmetric +// activation : dynamic, per-token, symmetric +// +// mat1 : [M, K] +// mat2 : [N, K] +// scales1 : [M] +// scales2 : [N] +// bias : [N] +// out : [M, N] +// +at::Tensor int8_scaled_mm_cpu( + at::Tensor& mat1, + at::Tensor& mat2, + at::Tensor& scales1, + at::Tensor& scales2, + std::optional& bias, + at::ScalarType out_dtype, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", 
std::vector({mat1, mat2, scales1, scales2, bias})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + CHECK_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales1); + CHECK_INPUT(scales2); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat1.size(1); + + // see [NOTE]: s8s8 igemm compensation in avx512-vnni + CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K)); + CHECK_EQ(scales1.numel(), M); + CHECK_EQ(scales2.numel(), N); + + TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8."); + TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8."); + TORCH_CHECK( + scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat, + "int8_scaled_mm: expect scales to be float32."); + + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] { + int8_scaled_mm_kernel_impl( + out.data_ptr(), + mat1.data_ptr(), + packed_w.data_ptr(), + scales1.data_ptr(), + scales2.data_ptr(), + bias_data, + M, + N, + K); + }); + return out; +} + +// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu` +at::Tensor int8_scaled_mm_with_quant( + at::Tensor& mat1, + at::Tensor& mat2, + at::Tensor& scales2, + std::optional& bias, + at::ScalarType out_dtype, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector({mat1, mat2, scales2, bias})); + + auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2); + + CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1); + CHECK_INPUT(mat2); + CHECK_INPUT(scales2); + CHECK_DIM(2, mat1); + CHECK_DIM(2, mat2); + + int64_t M = mat1.size(0); + int64_t N = mat2.size(0); + int64_t K = mat1.size(1); + int64_t lda = mat1.stride(0); + + // see [NOTE]: s8s8 igemm compensation in avx512-vnni + CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? 
K + sizeof(int32_t) : K)); + CHECK_EQ(scales2.numel(), N); + + const auto st = mat1.scalar_type(); + TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "int8_scaled_mm_with_quant: expect A to be bfloat16 or half."); + TORCH_CHECK(st == out_dtype, "int8_scaled_mm_with_quant: expect A has same dtype with out_dtype."); + TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm_with_quant: expect mat2 to be int8."); + TORCH_CHECK(scales2.scalar_type() == at::kFloat, "int8_scaled_mm_with_quant: expect scales to be float32."); + + const int64_t buffer_size = M * K + M * sizeof(float); + auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte)); + auto out = at::empty({M, N}, mat1.options().dtype(out_dtype)); + + const bool has_bias = bias.has_value(); + const float* bias_data = nullptr; + if (has_bias) { + CHECK_EQ(bias.value().size(0), N); + bias_data = bias.value().data_ptr(); + } + + AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] { + uint8_t* __restrict__ Aq_data = buffer.data_ptr(); + float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K)); + const scalar_t* __restrict__ A_data = mat1.data_ptr(); + + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8(Aq_data + m * K, As_data[m], A_data + m * lda, K); + } + }); + + int8_scaled_mm_kernel_impl( + out.data_ptr(), + Aq_data, + packed_w.data_ptr(), + As_data, + scales2.data_ptr(), + bias_data, + M, + N, + K); + }); + return out; +} diff --git a/sgl-kernel/csrc/cpu/interface.cpp b/sgl-kernel/csrc/cpu/interface.cpp new file mode 100644 index 000000000..cc11c4928 --- /dev/null +++ b/sgl-kernel/csrc/cpu/interface.cpp @@ -0,0 +1,120 @@ +#include +#include + +#include "shm.h" + +// Communication settings +static int world_rank = -1; +static int world_size = -1; + +static bool is_initialized = false; + +static bool all_ranks_local_p = false; + +void initialize(int size, int rank) { + if (is_initialized) { + return; + } + + // Check whether all ranks is on the same physical machine. 
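+  // The check below compares the global world size against the LOCAL_SIZE environment
+  // variable (assumed to be set by the launcher to the number of ranks on this host).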
+ // If true, we will use an SHM based low latency allreduce + + auto ls_string = std::getenv("LOCAL_SIZE"); + int ls = 0; + if (ls_string != NULL) { + ls = std::stoi(std::getenv("LOCAL_SIZE")); + } + + if (size >= 1 && size == ls) { + all_ranks_local_p = true; + } + + world_size = size; + world_rank = rank; + is_initialized = true; + + auto addr_string = std::getenv("MASTER_ADDR"); + if (addr_string == NULL) { + addr_string = ""; + } + auto port_string = std::getenv("MASTER_PORT"); + if (port_string == NULL) { + port_string = ""; + } + + if (all_ranks_local_p) { + shm_initialize(size, rank, addr_string, port_string); + } +} + +void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr process_group, py::object op) { + RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector({data})); + + static py::object ReduceOp = py::module_::import("torch.distributed").attr("ReduceOp"); + static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value")); + TORCH_CHECK(py::int_(op.attr("value")) == ReduceOpSum, "Only torch.distributed.ReduceOp.SUM is supported"); + + auto numel = data.numel(); + + int data_size = 0; + bool data_type_fallback = false; + + switch (data.scalar_type()) { + case c10::ScalarType::BFloat16: + data_size = numel * 2; + break; + case c10::ScalarType::Float: + data_size = numel * 4; + break; + default: + data_type_fallback = true; + } + + if (data_type_fallback || !all_ranks_local_p) { + // Fallback to torch distributed allreduce + std::vector tensors = {data}; + process_group->allreduce(tensors)->wait(); + } else { + all_reduce_outer_loop(data, numel, data_size); + } + + return; +} + +torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr process_group, int dim) { + RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector({data})); + + auto numel = data.numel(); + + int data_size = 0; + bool data_type_fallback = false; + + switch (data.scalar_type()) { + case c10::ScalarType::BFloat16: + data_size = numel * 2; + break; + case c10::ScalarType::Float: + data_size = numel * 4; + break; + default: + data_type_fallback = true; + } + if (dim < 0) { + dim += data.dim(); + } + if (data_type_fallback || !all_ranks_local_p) { + // Fallback to torch distributed allreduce + std::vector> output_tensors(1); + auto world_size = process_group->getSize(); + for (int i = 0; i < world_size; i++) { + output_tensors[0].push_back(torch::empty_like(data)); + } + std::vector input_tensors = {data}; + process_group->allgather(output_tensors, input_tensors)->wait(); + return torch::cat(output_tensors[0], dim).contiguous(); + } + std::vector result_shape = data.sizes().vec(); + result_shape[dim] *= world_size; + torch::Tensor result_tensor = torch::empty(result_shape, data.options()); + return all_gather(result_tensor, data, dim, numel, data_size); +} diff --git a/sgl-kernel/csrc/cpu/moe.cpp b/sgl-kernel/csrc/cpu/moe.cpp new file mode 100644 index 000000000..05825e04f --- /dev/null +++ b/sgl-kernel/csrc/cpu/moe.cpp @@ -0,0 +1,1247 @@ +#include "common.h" +#include "gemm.h" +#include "vec.h" + +namespace { + +// [NOTE]: Fused MoE kernel with AMX +// +// This file contains implementations for +// * `moe_align_block_size` +// * `fused_moe` +// +// The functionality is identical to triton kernel, excepts: +// * fuse silu_and_mul with gemm1, therefore this kernel +// allocates 2 intermediate_caches instead of 3 +// * add `offsets` in `moe_align_block_size` which keeps track +// of starting offset for each M block. 
this is for keeping +// output of silu_and_mul in sorted order, thus load_A for +// the 2nd gemm would be contiguous, therefore we can directly +// load A from intermediate_cache1. +// +// TODO: +// 1. tune BLOCK_M and BLOCK_N (BLOCK_N * K fit L2) +// 2. add prefetch for load A which is indexed access +// 3. abstract at::native::cpublas::brgemm with WoQ gemm (M = 1 & M != 1) +// + +template +inline void fill_stub(scalar_t* __restrict__ out, scalar_t val, int64_t size) { + using Vec = at::vec::Vectorized; + const Vec data_vec(val); + at::vec::map([data_vec](Vec out) { return out = data_vec; }, out, out, size); +} + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; +// no remainder +#pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) * weight_vec; + fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec; + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub( + scalar_t* __restrict__ out, + const float* __restrict__ input, + const scalar_t* __restrict__ input2, + float scale, + int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec x0 = fVec::loadu(input + d); + fVec x1 = fVec::loadu(input + d + fVec::size()); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +template +int moe_align_block_size( + int32_t* __restrict__ sorted_ids, + int32_t* 
__restrict__ expert_ids, + int32_t* __restrict__ topk_ids, + int32_t* __restrict__ total_cnts, + int32_t* __restrict__ cumsums, + int32_t* __restrict__ offsets, + int num_experts, + int numel, + int num_threads) { +#define T_INDEX(tt) total_cnts + (tt) * num_experts + + // accumulate count of expert ids locally + at::parallel_for(0, numel, 0, [&](int begin, int end) { + int tid = at::get_thread_num(); + int32_t* __restrict__ local_cnts = T_INDEX(tid + 1); + + for (int i = begin; i < end; ++i) { + local_cnts[topk_ids[i]]++; + } + }); + + using iVec = at::vec::Vectorized; + for (int t = 0; t < num_threads; ++t) { + at::vec::map2( + [](iVec x, iVec y) { return x + y; }, T_INDEX(t + 1), T_INDEX(t + 1), T_INDEX(t), num_experts); + } + + // the last row holds sums of each experts + int32_t* total_cnts_t_1 = T_INDEX(num_threads); + + cumsums[0] = 0; + for (int e = 0; e < num_experts; ++e) { + // accumulate `num_tokens_post_pad`, also as the expert offset + cumsums[e + 1] = cumsums[e] + div_up(total_cnts_t_1[e], BLOCK_M) * BLOCK_M; + + for (int k = cumsums[e]; k < cumsums[e + 1]; k += BLOCK_M) { + expert_ids[k / BLOCK_M] = e; + } + } + int num_tokens_post_pad = cumsums[num_experts]; + + at::parallel_for(0, numel, 0, [&](int begin, int end) { + int tid = at::get_thread_num(); + // thread tid offsets in `total_cnts` + int32_t* __restrict__ offsets = T_INDEX(tid); + + for (int i = begin; i < end; ++i) { + int32_t expert_id = topk_ids[i]; + int32_t b_offset = cumsums[expert_id]; + int32_t t_offset = offsets[expert_id]; + sorted_ids[b_offset + t_offset] = i; + offsets[expert_id]++; + } + }); + + // debug: the offset for thread t_1 should be identical to t_2 + int32_t* total_cnts_t_2 = T_INDEX(num_threads - 1); + for (int e = 0; e < num_experts; ++e) { + TORCH_CHECK(total_cnts_t_1[e] == total_cnts_t_2[e]); + } + + // padding value for sorted_ids: numel + auto sorted_id_size = [=](const int32_t* sorted_ids_ptr) { + for (int d = 0; d < BLOCK_M; ++d) { + if (sorted_ids_ptr[d] == numel) { + return d; + } + } + return BLOCK_M; + }; + + // offsets holds starting offset for each valida M blocks + // shape : [num_token_blocks + 1] + offsets[0] = 0; + const int num_token_blocks = num_tokens_post_pad / BLOCK_M; + at::parallel_for(0, num_token_blocks, GRAIN_SIZE / BLOCK_M, [&](int begin, int end) { + for (int mb = begin; mb < end; ++mb) { + offsets[mb + 1] = sorted_id_size(sorted_ids + mb * BLOCK_M); + } + }); + // TODO: do we need to vecterize this ? 
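+  // The loop below is a serial inclusive prefix sum over `num_token_blocks` entries,
+  // turning per-block valid-row counts into starting offsets; it costs O(num_token_blocks)
+  // and is normally negligible next to the parallel passes above.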
+ for (int mb = 0; mb < num_token_blocks; ++mb) { + offsets[mb + 1] += offsets[mb]; + } + // debug: the last value of offsets should be `numel` + TORCH_CHECK(offsets[num_token_blocks] == numel); + + return num_tokens_post_pad; +} + +// silu : shape leading dimension +// input0 [m_size, BLOCK_N] BLOCK_N +// input1 [m_size, BLOCK_N] BLOCK_N +// output [M * topk, N] N +template +inline void silu_and_mul( + scalar_t* __restrict__ output, + const float* __restrict__ input0, // x: x0, x1 + const float* __restrict__ input1, // y: y0, y1 + int64_t m_size, + int64_t N) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + + const fVec one = fVec(1.f); + + // no remainder + for (int64_t m = 0; m < m_size; ++m) { + scalar_t* __restrict__ out = output + m * N; + const float* __restrict__ x = input0 + m * BLOCK_N; + const float* __restrict__ y = input1 + m * BLOCK_N; + + for (int64_t d = 0; d < BLOCK_N; d += bVec::size()) { + fVec x0 = fVec::loadu(x + d); + fVec x1 = fVec::loadu(x + d + fVec::size()); + fVec y0 = fVec::loadu(y + d); + fVec y1 = fVec::loadu(y + d + fVec::size()); + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + // convert + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + } +} + +template +struct tinygemm_kernel_nn2 { + static inline void apply( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B0, + const scalar_t* __restrict__ B1, + scalar_t* __restrict__ C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn2 { + static inline void apply( + const at::BFloat16* __restrict__ A, + const at::BFloat16* __restrict__ B0, + const at::BFloat16* __restrict__ B1, + at::BFloat16* __restrict__ C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb0[COLS]; + __m512bh vb1[COLS]; + __m512 vc0[ROWS * COLS]; + __m512 vc1[ROWS * COLS]; + + auto loadc = [&](auto i) { + vc0[i] = _mm512_set1_ps(0.f); + vc1[i] = _mm512_set1_ps(0.f); + }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b0_ptr = reinterpret_cast(B0); + const float* b1_ptr = reinterpret_cast(B1); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb0[col] = (__m512bh)(_mm512_loadu_si512(b0_ptr + k * ldb2 + col * 16)); + vb1[col] = (__m512bh)(_mm512_loadu_si512(b1_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b0_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + _mm_prefetch(b1_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc0[i] = _mm512_dpbf16_ps(vc0[i], va, vb0[col]); + vc1[i] = _mm512_dpbf16_ps(vc1[i], va, vb1[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + using Vec = at::vec::Vectorized; + const Vec one = Vec(1.f); + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + 
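+      // Epilogue fuses the activation: vc0 holds A @ B0 (gate) and vc1 holds A @ B1 (up),
+      // so the store computes silu(vc0) * vc1 in fp32 before converting back to bfloat16.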
// for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + Vec x0 = vc0[row * COLS + col + 0]; + Vec x1 = vc0[row * COLS + col + 1]; + Vec y0 = vc1[row * COLS + col + 0]; + Vec y1 = vc1[row * COLS + col + 1]; + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0)))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn2::apply( \ + A + mb_start * lda, B0 + nb_start * 2, B1 + nb_start * 2, C + mb_start * ldc + nb_start, K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B0, + const scalar_t* __restrict__ B1, + scalar_t* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + // pattern: 1-(2+2)-(8+8) + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: + LAUNCH_TINYGEMM_KERNEL_NN(1, 32); + break; + // mb_size = 2 + case 0x22: + LAUNCH_TINYGEMM_KERNEL_NN(2, 32); + break; + // mb_size = 3 + case 0x32: + LAUNCH_TINYGEMM_KERNEL_NN(3, 32); + break; + // mb_size = 4 + case 0x42: + LAUNCH_TINYGEMM_KERNEL_NN(4, 32); + break; + default: + TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +struct tinygemm_kernel_nn { + static inline void apply( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + float* __restrict__ C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_nn { + static inline void apply( + const at::BFloat16* __restrict__ A, + const at::BFloat16* __restrict__ B, + float* __restrict__ C, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + + static_assert(COLS % 2 == 0); + + // prefetch distance + constexpr int PREFETCH_SIZE_K = 0; + + __m512bh va; + __m512bh vb[COLS]; + __m512 vc[ROWS * COLS]; + + auto loadc = [&](auto i) { vc[i] = _mm512_set1_ps(0.f); }; + Unroll{}(loadc); + + const int64_t K2 = K >> 1; + const int64_t lda2 = lda >> 1; + const int64_t ldb2 = ldb; // ldb * 2 >> 1; + const float* a_ptr = reinterpret_cast(A); + const float* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k])); + } + if constexpr (row == 0) { + vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16)); + if constexpr (PREFETCH_SIZE_K > 0) { + _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0); + } + } + vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K2; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + 
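+      // Results stay in fp32 here: this variant writes the accumulators straight to the
+      // float C buffer so the caller can apply topk weights or residual adds afterwards.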
constexpr int col = i % COLS; + _mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), vc[i]); + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_NN2(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_nn::apply( \ + A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, K, lda, ldb, ldc); + +template +void tinygemm_kernel( + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B, + float* __restrict__ C, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + // pattern: 1-2-8 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size >> 4) { + // mb_size = 1 + case 0x12: + LAUNCH_TINYGEMM_KERNEL_NN2(1, 32); + break; + // mb_size = 2 + case 0x22: + LAUNCH_TINYGEMM_KERNEL_NN2(2, 32); + break; + // mb_size = 3 + case 0x32: + LAUNCH_TINYGEMM_KERNEL_NN2(3, 32); + break; + // mb_size = 4 + case 0x42: + LAUNCH_TINYGEMM_KERNEL_NN2(4, 32); + break; + default: + TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +template +void fused_experts_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + scalar_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ packed_w1, + const scalar_t* __restrict__ packed_w2, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // strides for w1: [E, 2N, K] + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + const int64_t stride_e = 2 * N * K; + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; + const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + const bool use_brgemm = can_use_brgemm(m_size); + 
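+      // Per can_use_brgemm in gemm.h, bf16 routes to brgemm (AMX) only when m_size > 4;
+      // smaller blocks take the fused tinygemm path. is_brgemm_used is tracked so that
+      // brgemm_release() runs once per thread after the loop.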
is_brgemm_used = is_brgemm_used || use_brgemm; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, input + index * K, K); + } + + if (use_brgemm) { + // 1.b gemm: C0 = A @ B0 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B0, + /* C */ C0); + + // 1.c gemm: C1 = A @ B1 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B1, + /* C */ C1); + + // 1.d silu and mul + const int64_t offset = offsets[mb]; + silu_and_mul(ic1 + offset * N + nb * BLOCK_N, C0, C1, m_size, N); + } else { + // fused 1.bcd: silu_and_mul(A @ B0, A @ B1) + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + offset * N + nb * BLOCK_N, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_e2 = OC * IC; + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const scalar_t* __restrict__ A = ic1 + offsets[mb] * N; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const scalar_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + + // 2.a gemm: C = A @ B + if (use_brgemm) { + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B, + /* C */ C); + } else { + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + } + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +template +void 
shared_expert_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + scalar_t* __restrict__ input, + const scalar_t* __restrict__ packed_w1, + const scalar_t* __restrict__ packed_w2, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + const int64_t stride_n = K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + + // int64_t mb_start = mb * BLOCK_M; + // int64_t mb_size = std::min(M - mb_start, BLOCK_M); + + // A shape [m_size, K] + const scalar_t* A = input + mb * BLOCK_M * K; + + // B shape [K, n_size] in vnni format + const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; + const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; + + const bool use_brgemm = can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + if (use_brgemm) { + // 1.b gemm: C0 = A @ B0 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B0, + /* C */ C0); + + // 1.c gemm: C1 = A @ B1 + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B1, + /* C */ C1); + + // 1.d silu and mul + silu_and_mul(ic1 + mb * BLOCK_M * N + nb * BLOCK_N, C0, C1, m_size, N); + } else { + // fused 1.bcd: silu_and_mul(A @ B0, A @ B1) + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); + + // stage 2: output = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_oc = IC; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + bool is_brgemm_used = false; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + const bool use_brgemm = 
can_use_brgemm(m_size); + is_brgemm_used = is_brgemm_used || use_brgemm; + + // A shape [m_size, IC] + const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N; + + // B shape [IC, n_size] in vnni format + const scalar_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; + + // 2.a gemm: C = A @ B + if (use_brgemm) { + at::native::cpublas::brgemm( + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N, + /* add_C */ false, + /* A */ A, + /* B */ B, + /* C */ C); + } else { + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + } + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + + if (is_brgemm_used) { + at::native::cpublas::brgemm_release(); + } + }); +} + +} // anonymous namespace + +// hidden_states: [M, K] +// w1: [E, 2N, K] +// w2: [E, K, N] +// topk_weights: [M, topk] +// topk_ids: [M, topk] (int32_t) +// +at::Tensor fused_experts_cpu( + at::Tensor& hidden_states, + at::Tensor& w1, + at::Tensor& w2, + at::Tensor& topk_weights, + at::Tensor& topk_ids, + bool inplace, + bool use_int8_w8a8, + std::optional& w1_scale, + std::optional& w2_scale, + std::optional& a1_scale, + std::optional& a2_scale, + bool is_vnni) { + RECORD_FUNCTION( + "sgl-kernel::fused_experts_cpu", std::vector({hidden_states, w1, w2, topk_weights, topk_ids})); + + auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1); + auto packed_w2 = is_vnni ? w2 : convert_weight_packed(w2); + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + const auto st = hidden_states.scalar_type(); + CHECK_INPUT(hidden_states); + CHECK_INPUT(w1); + CHECK_INPUT(w2); + CHECK_EQ(topk_weights.sizes(), topk_ids.sizes()); + CHECK_DIM(2, hidden_states); + CHECK_DIM(3, w1); + CHECK_DIM(3, w2); + CHECK_DIM(2, topk_weights); + CHECK_DIM(2, topk_ids); + + CHECK_EQ(topk_ids.scalar_type(), at::kInt); + CHECK_EQ(topk_weights.scalar_type(), at::kFloat); + + int64_t M = hidden_states.size(0); + int64_t K = hidden_states.size(1); + int64_t N = w1.size(1) / 2; + int64_t E = w1.size(0); + int64_t topk = topk_weights.size(1); + + // we use int32_t compensation for int8 w8a8 + int64_t packed_K = get_row_size(K, use_int8_w8a8); + int64_t packed_N = get_row_size(N, use_int8_w8a8); + + // check weight shapes + CHECK_EQ(w2.size(0), E); + CHECK_EQ(w2.size(1), K); + CHECK_EQ(packed_w1.size(2), packed_K); + CHECK_EQ(packed_w2.size(2), packed_N); + + if (use_int8_w8a8) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8."); + TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported."); + TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported."); + } + + at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states); + + // NB: worst case is each expert holds a block with remainder of 1 + // 1. sorted_ids : [M * topk + E * (BLOCK_M - 1)] + // 2. expert_ids : [max_num_blocks] + // 3. total_cnts : [T + 1, E] + // 4. cumsums : [E + 1] + // 5. 
offsets : [max_num_blocks + 1] + // + int num_threads = at::get_num_threads(); + int64_t max_num_tokens_padded = M * topk + E * (BLOCK_M - 1); + int64_t max_num_blocks = div_up(max_num_tokens_padded, BLOCK_M); + auto buffer = at::empty( + {max_num_tokens_padded + max_num_blocks + (num_threads + 1) * E + (E + 1) + (max_num_blocks + 1)}, + topk_ids.options()); + + int32_t* __restrict__ sorted_ids = buffer.data_ptr(); + int32_t* __restrict__ expert_ids = sorted_ids + max_num_tokens_padded; + int32_t* __restrict__ total_cnts = expert_ids + max_num_blocks; + int32_t* __restrict__ cumsums = total_cnts + (num_threads + 1) * E; + int32_t* __restrict__ offsets = cumsums + (E + 1); + + // init sorted_ids with `numel` as the padding number + // init expert_ids with `num_experts` + int64_t numel = M * topk; + at::parallel_for(0, max_num_blocks, GRAIN_SIZE / BLOCK_M, [&](int64_t begin, int64_t end) { + int64_t m_start = begin * BLOCK_M; + int64_t m_size = std::min((end - begin) * BLOCK_M, max_num_tokens_padded - m_start); + fill_stub(sorted_ids + m_start, (int32_t)numel, m_size); + fill_stub(expert_ids + begin, (int32_t)E, end - begin); + }); + // zero total_cnts and cumsums + at::parallel_for(0, (num_threads + 1) * E + (E + 1), GRAIN_SIZE, [&](int64_t begin, int64_t end) { + fill_stub(total_cnts + begin, 0, end - begin); + }); + + // align experts index + int64_t num_tokens_post_pad = moe_align_block_size( + sorted_ids, expert_ids, topk_ids.data_ptr(), total_cnts, cumsums, offsets, E, numel, num_threads); + + // unlike triton kernel, we fuse silu with gemm1 so only need 2 intermediate_caches: + // 1. intermediate_cache1 : [M * topk, N] + // 2. intermediate_cache2 : [M * topk, K] + // 3. A_tmp : [T, BLOCK_M * K] + // 4. C_tmp : [T, 2 * BLOCK_M * BLOCK_N] + // + // for int8 w8a8: + // 5. Aq_tmp : [M, K] or [M * topk, N] + // 6. As_tmp : [M * topk] + // + int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 + + num_threads * BLOCK_M * K * (use_int8_w8a8 ? 
1 : 2) + + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); + + if (use_int8_w8a8) { + buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float); + } + + auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "fused_experts_kernel_impl", [&] { + scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer2.data_ptr())); + scalar_t* __restrict__ intermediate_cache2 = intermediate_cache1 + M * topk * N; + + if (use_int8_w8a8) { + uint8_t* __restrict__ A_tmp = (uint8_t*)((void*)(intermediate_cache2 + M * topk * K)); + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * topk * N))); + + auto w1s = w1_scale.value(); + auto w2s = w2_scale.value(); + TORCH_CHECK(w1s.numel() == E * 2 * N); + TORCH_CHECK(w2s.numel() == E * K); + + fused_experts_int8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + intermediate_cache2, + A_tmp, + C_tmp, + Aq_tmp, + As_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } else { + scalar_t* __restrict__ A_tmp = intermediate_cache2 + M * topk * K; + float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K)); + + fused_experts_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + intermediate_cache2, + A_tmp, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + topk_weights.data_ptr(), + sorted_ids, + expert_ids, + offsets, + M, + N, + K, + E, + topk, + num_tokens_post_pad); + } + }); + return out_hidden_states; +} + +// shared expert kernel +// +// hidden_states: [M, K] +// w1: [2N, K] +// w2: [K, N] +// fused_experts_out +at::Tensor shared_expert_cpu( + at::Tensor& hidden_states, + at::Tensor& w1, + at::Tensor& w2, + at::Tensor& fused_experts_out, + double routed_scaling_factor, + bool inplace, + bool use_int8_w8a8, + std::optional& w1_scale, + std::optional& w2_scale, + std::optional& a1_scale, + std::optional& a2_scale, + bool is_vnni) { + RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector({hidden_states, w1, w2})); + + auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1); + auto packed_w2 = is_vnni ? 
w2 : convert_weight_packed(w2); + + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + const auto st = hidden_states.scalar_type(); + CHECK_INPUT(hidden_states); + CHECK_INPUT(fused_experts_out); + CHECK_INPUT(w1); + CHECK_INPUT(w2); + CHECK_DIM(2, hidden_states); + CHECK_DIM(2, w1); + CHECK_DIM(2, w2); + CHECK_EQ(hidden_states.sizes(), fused_experts_out.sizes()); + CHECK_EQ(hidden_states.scalar_type(), st); + + int64_t M = hidden_states.size(0); + int64_t K = hidden_states.size(1); + int64_t N = w1.size(0) / 2; + + // we use int32_t compensation for int8 w8a8 + int64_t packed_K = get_row_size(K, use_int8_w8a8); + int64_t packed_N = get_row_size(N, use_int8_w8a8); + + // check weight shapes + CHECK_EQ(w2.size(0), K); + CHECK_EQ(packed_w1.size(1), packed_K); + CHECK_EQ(packed_w2.size(1), packed_N); + + if (use_int8_w8a8) { + TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8."); + TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8."); + TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported."); + TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported."); + } + + at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states); + + // unlike triton kernel, we fuse silu with gemm1 so only need 2 intermediate_caches: + // 1. intermediate_cache1 : [M, N] + // 2. C_tmp : [T, 2 * BLOCK_M * BLOCK_N] + // + // for int8 w8a8: + // 3. Aq_tmp : [M, K] or [M, N] + // 4. As_tmp : [M] + // + int num_threads = at::get_num_threads(); + int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float); + + if (use_int8_w8a8) { + buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float); + } + + auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar)); + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] { + scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer.data_ptr())); + float* __restrict__ C_tmp = (float*)((void*)(intermediate_cache1 + M * N)); + + if (use_int8_w8a8) { + uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N)); + float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * N))); + + auto w1s = w1_scale.value(); + auto w2s = w2_scale.value(); + TORCH_CHECK(w1s.numel() == 2 * N); + TORCH_CHECK(w2s.numel() == K); + + shared_expert_int8_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + C_tmp, + Aq_tmp, + As_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + w1s.data_ptr(), + w2s.data_ptr(), + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } else { + shared_expert_kernel_impl( + out_hidden_states.data_ptr(), + intermediate_cache1, + C_tmp, + hidden_states.data_ptr(), + packed_w1.data_ptr(), + packed_w2.data_ptr(), + fused_experts_out.data_ptr(), + routed_scaling_factor, + M, + N, + K); + } + }); + return out_hidden_states; +} diff --git a/sgl-kernel/csrc/cpu/moe_int8.cpp b/sgl-kernel/csrc/cpu/moe_int8.cpp new file mode 100644 index 000000000..e12e5e7cf --- /dev/null +++ b/sgl-kernel/csrc/cpu/moe_int8.cpp @@ -0,0 +1,830 @@ +#include "common.h" +#include "gemm.h" +#include "vec.h" + +namespace { + +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) { + using Vec = at::vec::Vectorized; +// no remainder +#pragma GCC unroll 4 + for (int64_t d = 0; d < 
size; d += Vec::size()) { + Vec data = Vec::loadu(input + d); + data.store(out + d); + } +} + +template <> +inline void copy_stub(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) { + // size might be 64x + 32 + std::memcpy(out, input, size * sizeof(uint8_t)); +} + +template +inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec weight_vec = fVec(weight); + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec data0 = fVec::loadu(input + d) * weight_vec; + fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec; + bVec out_vec = convert_from_float_ext(data0, data1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] * weight); + } +} + +// acc from [topk, K] to [K] +template +inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + if (topk == 1) { + // do copy for topk = 1 + copy_stub(out, input, K); + } else { + // do sum for topk != 1 + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= K - kVecSize; d += kVecSize) { + fVec sum_fvec0 = fVec(0.f); + fVec sum_fvec1 = fVec(0.f); + for (int t = 0; t < topk; ++t) { + bVec x_bvec = bVec::loadu(input + t * K + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec0 += x_fvec0; + sum_fvec1 += x_fvec1; + } + bVec out_bvec = convert_from_float_ext(sum_fvec0, sum_fvec1); + out_bvec.store(out + d); + } + for (; d < K; ++d) { + float sum_val = 0.f; + for (int t = 0; t < topk; ++t) { + sum_val += static_cast(input[t * K + d]); + } + out[d] = static_cast(sum_val); + } + } +} + +// out = input + input2 * scale +template +inline void add_mul_stub( + scalar_t* __restrict__ out, + const float* __restrict__ input, + const scalar_t* __restrict__ input2, + float scale, + int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + constexpr int kVecSize = bVec::size(); + const fVec s_vec = fVec(scale); + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= size - kVecSize; d += kVecSize) { + fVec x0 = fVec::loadu(input + d); + fVec x1 = fVec::loadu(input + d + fVec::size()); + + bVec y_bvec = bVec::loadu(input2 + d); + fVec y0, y1; + std::tie(y0, y1) = at::vec::convert_to_float(y_bvec); + + x0 = x0 + y0 * s_vec; + x1 = x1 + y1 * s_vec; + bVec out_vec = convert_from_float_ext(x0, x1); + out_vec.store(out + d); + } + for (; d < size; ++d) { + out[d] = static_cast(input[d] + float(input2[d]) * scale); + } +} + +/// gemm for w13 +template +struct tinygemm_kernel_vnni { + static inline void apply( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B0, + const int8_t* __restrict__ B1, + scalar_t* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs0, + const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, + const int32_t* __restrict__ Bcomp1, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_vnni { + static inline void apply( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B0, + const int8_t* 
__restrict__ B1, + at::BFloat16* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs0, + const float* __restrict__ Bs1, + const int32_t* __restrict__ Bcomp0, + const int32_t* __restrict__ Bcomp1, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512i va; + __m512i vb0[COLS]; + __m512i vb1[COLS]; + __m512i vc0[ROWS * COLS]; + __m512i vc1[ROWS * COLS]; + __m512i vcomp0[COLS]; + __m512i vcomp1[COLS]; + __m512 vas; + __m512 vbs0[COLS]; + __m512 vbs1[COLS]; + + auto loadc = [&](auto i) { + vc0[i] = _mm512_set1_epi32(0); + vc1[i] = _mm512_set1_epi32(0); + }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b0_ptr = reinterpret_cast(B0); + const int32_t* b1_ptr = reinterpret_cast(B1); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16); + vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16); + } + vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]); + vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto scalec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr (col == 0) { + vas = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp + if constexpr (row == 0) { + vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16); + vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16); + vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16); + vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16); + } + __m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col])); + __m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col])); + vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, vas), vbs0[col])); + vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, vas), vbs1[col])); + }; + Unroll{}(scalec); + + using Vec = at::vec::Vectorized; + const Vec one = Vec(1.f); + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + // for COLS = 2, 4 use 512bit store + if constexpr (col % 2 == 0) { + Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]); + Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]); + Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]); + Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]); + // silu + x0 = x0 / (one + x0.neg().exp_u20()); + x1 = x1 / (one + x1.neg().exp_u20()); + // mul + x0 = x0 * y0; + x1 = x1 * y1; + + _mm512_storeu_si512( + reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), + (__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0)))); + } + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_vnni::apply( \ + A + mb_start * lda, \ + B0 + nb_start * 4, \ + B1 + nb_start * 4, \ + C + mb_start * ldc + nb_start, \ + As + mb_start, \ + Bs0 + nb_start, \ + Bs1 + nb_start, \ + Bcomp0 + nb_start, \ + Bcomp1 + nb_start, \ + K, \ + lda, \ + ldb, \ + ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B0, + const int8_t* __restrict__ B1, + scalar_t* 
__restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs0, + const float* __restrict__ Bs1, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + const int32_t* Bcomp0 = reinterpret_cast(B0 + block_size_n() * K); + const int32_t* Bcomp1 = reinterpret_cast(B1 + block_size_n() * K); + + // pattern: 1-(2+2)-(8+8) + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 32; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size >> 4) { + case 0x12: + LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32); + break; + case 0x22: + LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32); + break; + case 0x32: + LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32); + break; + case 0x42: + LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32); + break; + default: + TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +/// gemm for w2 +template +struct tinygemm_kernel_vnni2 { + static inline void apply( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + float* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs, + const int32_t* __restrict__ Bcomp, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!"); + } +}; + +#if defined(CPU_CAPABILITY_AVX512) +template +struct tinygemm_kernel_vnni2 { + static inline void apply( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + float* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs, + const int32_t* __restrict__ Bcomp, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + constexpr int ROWS = BLOCK_M; + constexpr int COLS = BLOCK_N / 16; + static_assert(COLS % 2 == 0); + + __m512i va; + __m512i vb[COLS]; + __m512i vc[ROWS * COLS]; + __m512i vcomp[COLS]; + __m512 vas; + __m512 vbs[COLS]; + + auto loadc = [&](auto i) { vc[i] = _mm512_set1_epi32(0); }; + Unroll{}(loadc); + + const int64_t K4 = K >> 2; + const int64_t lda4 = lda >> 2; + const int64_t ldb4 = ldb; // ldb * 4 >> 2; + const int32_t* a_ptr = reinterpret_cast(A); + const int32_t* b_ptr = reinterpret_cast(B); + + auto compute = [&](auto i, int64_t k) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + if constexpr (col == 0) { + va = _mm512_set1_epi32(a_ptr[row * lda4 + k]); + } + if constexpr (row == 0) { + vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16); + } + vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]); + }; + for (int64_t k = 0; k < K4; ++k) { + Unroll{}(compute, k); + } + + auto storec = [&](auto i) { + constexpr int row = i / COLS; + constexpr int col = i % COLS; + + // load a scale + if constexpr (col == 0) { + vas = _mm512_set1_ps(As[row]); + } + // load b scale and vcomp per 2 vectors + // also load bias if any + if constexpr (row == 0) { + if constexpr (col % 2 == 0) { + vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16); + vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16); + vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16); + vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16); + } + } + __m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col])); + x = _mm512_mul_ps(_mm512_mul_ps(x, vas), vbs[col]); + _mm512_storeu_ps(reinterpret_cast<__m512*>(C 
+ row * ldc + col * 16), x); + }; + Unroll{}(storec); + } +}; +#endif + +#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE) \ + tinygemm_kernel_vnni2::apply( \ + A + mb_start * lda, \ + B + nb_start * 4, \ + C + mb_start * ldc + nb_start, \ + As + mb_start, \ + Bs + nb_start, \ + Bcomp + nb_start, \ + K, \ + lda, \ + ldb, \ + ldc); + +template +void tinygemm_kernel( + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B, + float* __restrict__ C, + const float* __restrict__ As, + const float* __restrict__ Bs, + int64_t M, + int64_t N, + int64_t K, + int64_t lda, + int64_t ldb, + int64_t ldc) { + // B compensation + const int32_t* Bcomp = reinterpret_cast(B + block_size_n() * K); + + // pattern: 1-4-16 + constexpr int64_t BLOCK_M = 4; + constexpr int64_t BLOCK_N = 64; + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + for (int64_t mb = 0; mb < MB; ++mb) { + int64_t mb_start = mb * BLOCK_M; + int64_t mb_size = std::min(BLOCK_M, M - mb_start); + for (int64_t nb = 0; nb < NB; ++nb) { + int64_t nb_start = nb * BLOCK_N; + int64_t nb_size = std::min(BLOCK_N, N - nb_start); + + switch (mb_size << 4 | nb_size >> 4) { + case 0x12: + LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); + break; + case 0x22: + LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); + break; + case 0x32: + LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); + break; + case 0x42: + LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); + break; + default: + TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size"); + } + } + } +} + +} // anonymous namespace + +template +void fused_experts_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + scalar_t* __restrict__ ic2, + uint8_t* __restrict__ A_tmp, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const float* __restrict__ topk_weights, + const int32_t* __restrict__ sorted_ids, + const int32_t* __restrict__ expert_ids, + const int32_t* __restrict__ offsets, + int64_t M, + int64_t N, + int64_t K, + int64_t E, + int64_t topk, + int64_t num_tokens_post_pad) { + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 0: quantize input to uint8, [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8(Aq_tmp + m * K, As_tmp[m], input + m * K, K); + } + }); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + // strides for w1: [E, 2N, K] + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + // K and N are packed for int8 + const int64_t packed_K = get_row_size(K); + const int64_t packed_N = get_row_size(N); + + const int64_t stride_e = 2 * N * packed_K; + const int64_t stride_n = packed_K; + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K; + + alignas(64) float As[BLOCK_M]; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; 
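+      // The per-expert w1 is assumed to be laid out as [2N, K]: rows [0, N)
+      // hold the gate projection and rows [N, 2N) hold the up projection.
+      // Block nb0 therefore picks a gate tile and nb1 the matching up tile,
+      // so the fused kernel below can produce, per output element, roughly
+      //   ic1[m, n] = silu(dot(A[m], w1_gate[n])) * dot(A[m], w1_up[n])
+      // which is why only N (not 2N) columns of ic1 are written per token.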
+ int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + + // B shape [K, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n; + const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n; + const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N; + const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N; + + // 1.a load A + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + int64_t m_size = offsets[mb + 1] - offsets[mb]; + + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m] / topk; + copy_stub(A + m * K, Aq_tmp + index * K, K); + As[m] = As_tmp[index]; + } + + // fused 1.b: silu_and_mul(A @ B0, A @ B1) + const int64_t offset = offsets[mb]; + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + offset * N + nb * BLOCK_N, + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + }); + + // stage 1.5: quantize ic1 to uint8, [M * topk, N] + at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8(Aq_tmp + m * N, As_tmp[m], ic1 + m * N, N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [E, K, N] as [E, OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_e2 = OC * packed_N; + const int64_t stride_oc = packed_N; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = offsets[mb + 1] - offsets[mb]; + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // A ptr from ic1 of [M * topk, N] in sorted order + // so as to avoid copy A to tmp buffer again + const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N; + const float* __restrict__ As = As_tmp + offsets[mb]; + const int32_t* A_ids = sorted_ids + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + int32_t expert_id = expert_ids[mb]; + const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N; + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* As */ As, + /* Bs */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + + // 2.b copy from C to ic2 in original order + // and also mul topk_weights in float32 + for (int64_t m = 0; m < m_size; ++m) { + int32_t index = A_ids[m]; + float weight = topk_weights[index]; + copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size); + } + } + }); + + // stage 3: out = intermediate_cache2.sum(dim=1) + // from [M, topk, K] to [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + sum_stub(output + m * K, ic2 + m * topk * K, topk, K); + } + }); +} + +#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE) \ + template void fused_experts_int8_kernel_impl( \ + TYPE* __restrict__ output, \ + 
TYPE* __restrict__ ic1, \ + TYPE* __restrict__ ic2, \ + uint8_t* __restrict__ A_tmp, \ + float* __restrict__ C_tmp, \ + uint8_t* __restrict__ Aq_tmp, \ + float* __restrict__ As_tmp, \ + const TYPE* __restrict__ input, \ + const int8_t* __restrict__ packed_w1, \ + const int8_t* __restrict__ packed_w2, \ + const float* __restrict__ w1s, \ + const float* __restrict__ w2s, \ + const float* __restrict__ topk_weights, \ + const int32_t* __restrict__ sorted_ids, \ + const int32_t* __restrict__ expert_ids, \ + const int32_t* __restrict__ offsets, \ + int64_t M, \ + int64_t N, \ + int64_t K, \ + int64_t E, \ + int64_t topk, \ + int64_t num_tokens_post_pad) + +INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16); +INSTANTIATE_MOE_INT8_TEMPLATE(at::Half); + +template +void shared_expert_int8_kernel_impl( + scalar_t* __restrict__ output, + scalar_t* __restrict__ ic1, + float* __restrict__ C_tmp, + uint8_t* __restrict__ Aq_tmp, + float* __restrict__ As_tmp, + const scalar_t* __restrict__ input, + const int8_t* __restrict__ packed_w1, + const int8_t* __restrict__ packed_w2, + const float* __restrict__ w1s, + const float* __restrict__ w2s, + const scalar_t* __restrict__ fused_experts_out, + float routed_scaling_factor, + int64_t M, + int64_t N, + int64_t K) { + // handle 2 tiles per block + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + + // stage 0: quantize input to uint8, [M, K] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8(Aq_tmp + m * K, As_tmp[m], input + m * K, K); + } + }); + + // stage 1: intermediate_cache1 = silu(hidden_states @ w1) + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB = div_up(N, BLOCK_N); + + TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N); + + // K and N are packed for int8 + const int64_t packed_K = get_row_size(K); + const int64_t packed_N = get_row_size(N); + const int64_t stride_n = packed_K; + + // here we only parallel on half of 2N to fuse silu_and_mul with gemm + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB; + int64_t nb = i % NB; + + // nb0 from top half and nb1 from bottom half + int64_t nb0 = nb, nb1 = nb + NB; + int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N); + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + + // A shape [m_size, K] + const uint8_t* A = Aq_tmp + mb * BLOCK_M * K; + const float* As = As_tmp + mb * BLOCK_M; + + // B shape [K, n_size] in vnni format + const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n; + const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n; + const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N; + const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N; + + // fused 1.b: silu_and_mul(A @ B0, A @ B1) + tinygemm_kernel( + /* A */ A, + /* B0 */ B0, + /* B1 */ B1, + /* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N, + /* As */ As, + /* Bs0 */ Bs0, + /* Bs1 */ Bs1, + /* M */ m_size, + /* N */ n_size, + /* K */ K, + /* lda */ K, + /* ldb */ n_size, + /* ldc */ N); + } + }); + + // stage 1.5: quantize ic1 to uint8, [M * topk, N] + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8(Aq_tmp + m * N, As_tmp[m], ic1 + m * N, N); + } + }); + + // stage 2: intermediate_cache2 = intermediate_cache1 @ w2 + // w2 : [K, N] as [OC, IC] + const int64_t OC = K; // rename K as OC + const int64_t IC = N; // rename N 
as IC + const int64_t MB2 = MB; + const int64_t NB2 = div_up(OC, BLOCK_N); + const int64_t stride_oc = packed_N; + + // parallel on [MB2, NB2] + at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { + // get local pointers + int tid = at::get_thread_num(); + // we won't be using C1 for gemm2 + float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N; + + for (int64_t i = begin; i < end; ++i) { + int64_t mb = i / NB2; + int64_t nb = i % NB2; + + int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M); + int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N); + + // A shape [m_size, IC] + const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N; + const float* __restrict__ As = As_tmp + mb * BLOCK_M; + + // B shape [IC, n_size] in vnni format + const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc; + const float* __restrict__ Bs = w2s + nb * BLOCK_N; + + // 2.a gemm: C = A @ B + tinygemm_kernel( + /* A */ A, + /* B */ B, + /* C */ C, + /* As */ As, + /* Bs */ Bs, + /* M */ m_size, + /* N */ n_size, + /* K */ IC, + /* lda */ IC, + /* ldb */ n_size, + /* ldc */ BLOCK_N); + + // 2.b copy from C to output and add fused_experts_out + scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N; + const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N; + for (int64_t m = 0; m < m_size; ++m) { + add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size); + } + } + }); +} + +#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \ + template void shared_expert_int8_kernel_impl( \ + TYPE* __restrict__ output, \ + TYPE* __restrict__ ic1, \ + float* __restrict__ C_tmp, \ + uint8_t* __restrict__ Aq_tmp, \ + float* __restrict__ As_tmp, \ + const TYPE* __restrict__ input, \ + const int8_t* __restrict__ packed_w1, \ + const int8_t* __restrict__ packed_w2, \ + const float* __restrict__ w1s, \ + const float* __restrict__ w2s, \ + const TYPE* __restrict__ fused_experts_out, \ + float routed_scaling_factor, \ + int64_t M, \ + int64_t N, \ + int64_t K) + +INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16); +INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half); diff --git a/sgl-kernel/csrc/cpu/norm.cpp b/sgl-kernel/csrc/cpu/norm.cpp new file mode 100644 index 000000000..391a0d4e5 --- /dev/null +++ b/sgl-kernel/csrc/cpu/norm.cpp @@ -0,0 +1,221 @@ +#include "common.h" +#include "vec.h" + +namespace { + +// NB: avoid using `at::vec::map<>` on bfloat16 or half +template +void rmsnorm_kernel_impl( + scalar_t* __restrict__ output, + const scalar_t* __restrict__ input, + const scalar_t* __restrict__ weight, + int64_t batch_size, + int64_t hidden_size, + float eps = 1e-5) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + + constexpr int kVecSize = bVec::size(); + at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) { + for (int64_t i = begin; i < end; ++i) { + // local ptrs + scalar_t* __restrict__ out_ptr = output + i * hidden_size; + const scalar_t* __restrict__ input_ptr = input + i * hidden_size; + + fVec sum_fvec = fVec(float(0)); + float sum_val = float(0); + + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) { + bVec x_bvec = bVec::loadu(input_ptr + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + sum_fvec += x_fvec0 * x_fvec0; + sum_fvec += x_fvec1 * x_fvec1; + } +#pragma GCC unroll 4 + for (; d < hidden_size; ++d) { + float x_val = static_cast(input_ptr[d]); + sum_val += 
x_val * x_val; + } + + sum_val += vec_reduce_sum(sum_fvec); + float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps); + const fVec scale_fvec = fVec(rsqrt_var); + +#pragma GCC unroll 4 + for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) { + bVec x_bvec = bVec::loadu(input_ptr + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + bVec w_bvec = bVec::loadu(weight + d); + fVec w_fvec0, w_fvec1; + std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec); + + x_fvec0 = x_fvec0 * scale_fvec * w_fvec0; + x_fvec1 = x_fvec1 * scale_fvec * w_fvec1; + + bVec out_bvec = convert_from_float_ext(x_fvec0, x_fvec1); + out_bvec.store(out_ptr + d); + } +#pragma GCC unroll 4 + for (; d < hidden_size; ++d) { + float x_val = static_cast(input_ptr[d]); + float w_val = static_cast(weight[d]); + out_ptr[d] = static_cast(x_val * rsqrt_var * w_val); + } + } + }); +} + +template +void fused_add_rmsnorm_kernel_impl( + scalar_t* __restrict__ input, + scalar_t* __restrict__ residual, + const scalar_t* __restrict__ weight, + float* __restrict__ buffer, + int64_t batch_size, + int64_t hidden_size, + float eps = 1e-5) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + + constexpr int kVecSize = bVec::size(); + at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) { + int tid = at::get_thread_num(); + float* __restrict__ buffer_ptr = buffer + tid * hidden_size; + + for (int64_t i = begin; i < end; ++i) { + // local ptrs + scalar_t* __restrict__ input_ptr = input + i * hidden_size; + scalar_t* __restrict__ residual_ptr = residual + i * hidden_size; + + fVec sum_fvec = fVec(float(0)); + float sum_val = float(0); + + int64_t d; +#pragma GCC unroll 4 + for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) { + bVec x_bvec = bVec::loadu(input_ptr + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + bVec r_bvec = bVec::loadu(residual_ptr + d); + fVec r_fvec0, r_fvec1; + std::tie(r_fvec0, r_fvec1) = at::vec::convert_to_float(r_bvec); + + x_fvec0 += r_fvec0; + x_fvec1 += r_fvec1; + + bVec out_bvec = convert_from_float_ext(x_fvec0, x_fvec1); + out_bvec.store(residual_ptr + d); + + sum_fvec += x_fvec0 * x_fvec0; + sum_fvec += x_fvec1 * x_fvec1; + + x_fvec0.store(buffer_ptr + d); + x_fvec1.store(buffer_ptr + d + fVec::size()); + } +#pragma GCC unroll 4 + for (; d < hidden_size; ++d) { + float x_val = static_cast(input_ptr[d]); + float r_val = static_cast(residual_ptr[d]); + + x_val += r_val; + residual_ptr[d] = static_cast(x_val); + + sum_val += x_val * x_val; + buffer_ptr[d] = x_val; + } + + sum_val += vec_reduce_sum(sum_fvec); + float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps); + const fVec scale_fvec = fVec(rsqrt_var); + +#pragma GCC unroll 4 + for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) { + fVec x_fvec0 = fVec::loadu(buffer_ptr + d); + fVec x_fvec1 = fVec::loadu(buffer_ptr + d + fVec::size()); + + bVec w_bvec = bVec::loadu(weight + d); + fVec w_fvec0, w_fvec1; + std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec); + + x_fvec0 = x_fvec0 * scale_fvec * w_fvec0; + x_fvec1 = x_fvec1 * scale_fvec * w_fvec1; + bVec x_bvec = convert_from_float_ext(x_fvec0, x_fvec1); + x_bvec.store(input_ptr + d); + } +#pragma GCC unroll 4 + for (; d < hidden_size; ++d) { + float x_val = buffer_ptr[d] * rsqrt_var * static_cast(weight[d]); + input_ptr[d] = x_val; + } + } + }); +} + +} // anonymous namespace + +// input : {batch_size, hidden_size} +// weight: 
{hidden_size} +at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) { + RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector({input, weight})); + + CHECK_INPUT(input); + CHECK_INPUT(weight); + CHECK_DIM(2, input); + CHECK_DIM(1, weight); + CHECK_EQ(input.size(1), weight.size(0)); + int64_t batch_size = input.size(0); + int64_t hidden_size = input.size(1); + at::Tensor output = at::empty_like(input); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] { + rmsnorm_kernel_impl( + output.data_ptr(), + input.data_ptr(), + weight.data_ptr(), + batch_size, + hidden_size, + eps); + }); + return output; +} + +// input : {batch_size, hidden_size} +// residual: {batch_size, hidden_size} +// weight : {hidden_size} +void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) { + RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector({input, residual, weight})); + CHECK_INPUT(input); + CHECK_INPUT(residual); + CHECK_INPUT(weight); + CHECK_DIM(2, input); + CHECK_DIM(2, residual); + CHECK_DIM(1, weight); + CHECK_EQ(input.size(0), residual.size(0)); + CHECK_EQ(input.size(1), residual.size(1)); + CHECK_EQ(input.size(1), weight.size(0)); + int64_t batch_size = input.size(0); + int64_t hidden_size = input.size(1); + + // allocate temp buffer to store x in float32 per thread + // TODO: implement a singleton for context + int64_t num_threads = at::get_num_threads(); + at::Tensor buffer = at::empty({num_threads, hidden_size}, input.options().dtype(at::kFloat)); + + AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fused_add_rmsnorm_kernel", [&] { + fused_add_rmsnorm_kernel_impl( + input.data_ptr(), + residual.data_ptr(), + weight.data_ptr(), + buffer.data_ptr(), + batch_size, + hidden_size, + eps); + }); +} diff --git a/sgl-kernel/csrc/cpu/qkv_proj.cpp b/sgl-kernel/csrc/cpu/qkv_proj.cpp new file mode 100644 index 000000000..959072878 --- /dev/null +++ b/sgl-kernel/csrc/cpu/qkv_proj.cpp @@ -0,0 +1,504 @@ +#include "common.h" +#include "gemm.h" +#include "vec.h" + +namespace { + +// [NOTE]: Fused kernel for QKV projection with weight absorption and RoPE +// +// 1. `q_a_proj` and `kv_a_proj_with_mqa` fused into one gemm, +// otherwise we need to split IC for the 2nd gemm. +// 2. `q_a_layernorm` and `kv_a_layernorm` fused into one parallel loop. +// 3. k_input and v_input share the same storage, the torch API did +// this in `set_kv_buffer`. No additional memory movement. 
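+//
+// A rough sketch of the fused computation (illustrative only; tensor names
+// and shapes follow the comment above `qkv_proj_with_rope` below):
+//   qa      = hidden_states @ q_a_proj_weight.T      // [M, q_lora_rank]
+//   k_input = hidden_states @ kv_a_proj_weight.T     // [M, kv_lora_rank + qk_rope_head_dim]
+//   qa, k_input[:, :kv_lora_rank]     -> rms_norm applied in place
+//   q       = (qa @ q_b_proj_weight.T).view(M, num_heads, qk_head_dim)
+//   q_input[..., :kv_lora_rank]       = einsum("hmd,hkd->hmk", q[..., :qk_nope_head_dim], w_kc)
+//   q_input[..., kv_lora_rank:], k_pe -> rotary embedding applied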
+// + +// [C0, C1] = A @ [B0, B1] +template +void segment_gemm_kernel_impl( + scalar_t* __restrict__ C0, + scalar_t* __restrict__ C1, + const scalar_t* __restrict__ A, + const scalar_t* __restrict__ B0, + const scalar_t* __restrict__ B1, + int64_t M, + int64_t N0, + int64_t N1, + int64_t K) { + // convert_weight_packed make sure N0 and N1 are 32x + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB0 = div_up(N0, BLOCK_N); + const int64_t NB1 = div_up(N1, BLOCK_N); + const int64_t NB = NB0 + NB1; + + const bool use_brgemm = can_use_brgemm(M); + + // parallel on [MB, NB0 + NB1] + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + // for brgemm, use float32 for accumulate + alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + int mb_start = mb * BLOCK_M; + int mb_size = std::min(M - mb_start, BLOCK_M); + int nb_start = nb * BLOCK_N; + int nb_size = BLOCK_N; + + const scalar_t* __restrict__ B = nb < NB0 ? B0 : B1; + scalar_t* __restrict__ C = nb < NB0 ? C0 : C1; + int64_t ldc = nb < NB0 ? N0 : N1; + int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0; + + tinygemm_kernel( + /* A */ A + mb_start * K, + /* B */ B + local_nb_start * K /* nb * BLOCK_N * K */, + /* C */ C + mb_start * ldc + local_nb_start, + /* Ctmp*/ Ctmp, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ K, + /* ldb */ nb_size, + /* ldc */ ldc, + /* brg */ use_brgemm); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); +} + +// [C0, C1] = A @ [B0, B1] +template +void segment_gemm_kernel_impl( + scalar_t* __restrict__ C0, + scalar_t* __restrict__ C1, + const uint8_t* __restrict__ A, + const int8_t* __restrict__ B0, + const int8_t* __restrict__ B1, + const float* __restrict__ As, + const float* __restrict__ Bs0, + const float* __restrict__ Bs1, + int64_t M, + int64_t N0, + int64_t N1, + int64_t K) { + constexpr int64_t BLOCK_M = block_size_m(); + constexpr int64_t BLOCK_N = block_size_n(); + const int64_t MB = div_up(M, BLOCK_M); + const int64_t NB0 = div_up(N0, BLOCK_N); + const int64_t NB1 = div_up(N1, BLOCK_N); + const int64_t NB = NB0 + NB1; + + // TODO: brgemm u8s8 depends on PyTorch 2.7 release. + const bool use_brgemm = false; + + // K + 4 after compensation + const int64_t packed_row_size = get_row_size(K); + + // parallel on [MB, NB0 + NB1] + at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) { + int64_t mb{0}, nb{0}; + data_index_init(begin, mb, MB, nb, NB); + + // for brgemm, use float32 for accumulate + alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N]; + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + int mb_start = mb * BLOCK_M; + int mb_size = std::min(M - mb_start, BLOCK_M); + int nb_start = nb * BLOCK_N; + int nb_size = BLOCK_N; + + const int8_t* __restrict__ B = nb < NB0 ? B0 : B1; + const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1; + scalar_t* __restrict__ C = nb < NB0 ? C0 : C1; + int64_t ldc = nb < NB0 ? N0 : N1; + int64_t local_nb_start = nb < NB0 ? 
nb_start : nb_start - N0; + + tinygemm_kernel( + /* A */ A + mb_start * K, + /* B */ B + local_nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */, + /* C */ C + mb_start * ldc + local_nb_start, + /* Ctmp*/ Ctmp, + /* As */ As + mb_start, + /* Bs */ Bs + local_nb_start, + /* M */ mb_size, + /* N */ nb_size, + /* K */ K, + /* lda */ K, + /* ldb */ nb_size, + /* ldc */ ldc, + /* brg */ use_brgemm); + + // move to the next index + data_index_step(mb, MB, nb, NB); + } + + if (use_brgemm) { + at::native::cpublas::brgemm_release(); + } + }); +} + +template +inline float reduce(const scalar_t* __restrict__ x, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + fVec sum_fvec = fVec(float(0)); + +// no remainder +#pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += bVec::size()) { + bVec x_bvec = bVec::loadu(x + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + sum_fvec += x_fvec0 * x_fvec0; + sum_fvec += x_fvec1 * x_fvec1; + } + return vec_reduce_sum(sum_fvec); +} + +// map2 from aten functional doesn't have fast bf16->fp32 conversion +template +inline void map2(scalar_t* y, const scalar_t* x, const scalar_t* __restrict__ w, float scale, int64_t size) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + fVec scale_fvec = fVec(scale); + +// no remainder +#pragma GCC unroll 4 + for (int64_t d = 0; d < size; d += bVec::size()) { + bVec x_bvec = bVec::loadu(x + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + bVec w_bvec = bVec::loadu(w + d); + fVec w_fvec0, w_fvec1; + std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec); + x_fvec0 = x_fvec0 * scale_fvec * w_fvec0; + x_fvec1 = x_fvec1 * scale_fvec * w_fvec1; + bVec out_bvec = convert_from_float_ext(x_fvec0, x_fvec1); + out_bvec.store(y + d); + } +} + +template +void rms_norm_kernel_impl( + scalar_t* __restrict__ input0, + scalar_t* __restrict__ input1, + const scalar_t* __restrict__ weight0, + const scalar_t* __restrict__ weight1, + int64_t M, + int64_t N0, + int64_t N1, + int64_t stride1, + float eps = 1e-5) { + at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + scalar_t* x0 = input0 + m * N0; + scalar_t* x1 = input1 + m * stride1; + float scale0 = reduce(x0, N0); + float scale1 = reduce(x1, N1); + scale0 = float(1) / std::sqrt(scale0 / N0 + eps); + scale1 = float(1) / std::sqrt(scale1 / N1 + eps); + map2(x0, x0, weight0, scale0, N0); + map2(x1, x1, weight1, scale1, N1); + } + }); +} + +template +inline void rotary(const scalar_t* input, scalar_t* out, const scalar_t* cos, const scalar_t* sin, int64_t size) { + TORCH_CHECK(false, "rotary scalar path not implemented."); +} + +#if defined(CPU_CAPABILITY_AVX512) +template <> +inline void rotary( + const at::BFloat16* input, at::BFloat16* out, const at::BFloat16* cos, const at::BFloat16* sin, int64_t size) { + // permute indices + const __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); + const __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); + const __m512i idy1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0); + const __m512i idy2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8); + +// rotary dim is 64, just 2 iters +#pragma GCC unroll 2 + for (int64_t d = 0; d < size; d += 32) { + int64_t d2 = d >> 1; + // load coefs + __m512 vcos = 
CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast(cos + d2))); + __m512 vsin = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast(sin + d2))); + // load input + __m512i a16 = _mm512_loadu_si512(reinterpret_cast(input + d)); + __m512 a = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0)); + __m512 b = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1)); + // from [16, 2] to [2, 16] + __m512 in1 = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b); + __m512 in2 = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b); + // out1 = in1 * cos - in2 * sin; + // out2 = in2 * cos + in1 * sin + __m512 out1 = _mm512_sub_ps(_mm512_mul_ps(in1, vcos), _mm512_mul_ps(in2, vsin)); + __m512 out2 = _mm512_add_ps(_mm512_mul_ps(in2, vcos), _mm512_mul_ps(in1, vsin)); + // from [2, 16] to [16, 2] + a = _mm512_mask_permutex2var_ps(out1, 0xffff, idy1, out2); + b = _mm512_mask_permutex2var_ps(out1, 0xffff, idy2, out2); + + _mm512_storeu_si512(reinterpret_cast<__m512i*>((out + d)), (__m512i)(_mm512_cvtne2ps_pbh(b, a))); + } +} +#endif + +template +void rotary_emb_kernel_impl( + scalar_t* q_pe_out, + scalar_t* k_pe_out, + const scalar_t* q_pe, + const scalar_t* k_pe, + const int64_t* pos, + const scalar_t* cos_sin, + int64_t num_seqs, + int64_t num_heads, + int64_t rotary_dim, + int64_t q_strideB, + int64_t q_strideH, + int64_t k_strideB, + int64_t oq_strideB, + int64_t oq_strideH, + int64_t ok_strideB) { + TORCH_CHECK(rotary_dim % 32 == 0, "rotary_dim is not 32x."); + const int64_t rotary_offset = rotary_dim / 2; + + // parallel on [num_seqs, num_heads + 1] + // top [num_heads] handle q_pe and bottom [1] handle k_pe + at::parallel_for(0, num_seqs * (num_heads + 1), GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) { + int64_t seq{0}, head_id{0}; + data_index_init(begin, seq, num_seqs, head_id, num_heads + 1); + + for (int64_t i = begin; i < end; ++i) { + UNUSED(i); + // get cos and sin cache ptr + int64_t index = pos[seq]; + const scalar_t* cos = cos_sin + index * rotary_dim; + const scalar_t* sin = cos + rotary_offset; + + const scalar_t* input = + (head_id < num_heads) ? q_pe + seq * q_strideB + head_id * q_strideH : k_pe + seq * k_strideB; + scalar_t* out = + (head_id < num_heads) ? 
q_pe_out + seq * oq_strideB + head_id * oq_strideH : k_pe_out + seq * ok_strideB; + rotary(input, out, cos, sin, rotary_dim); + + // move to the next index + data_index_step(seq, num_seqs, head_id, num_heads + 1); + } + }); +} + +} // anonymous namespace + +extern at::Tensor +weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional& bias, bool is_vnni); + +extern at::Tensor int8_scaled_mm_with_quant( + at::Tensor& mat1, + at::Tensor& mat2, + at::Tensor& scales2, + std::optional& bias, + at::ScalarType out_dtype, + bool is_vnni); + +extern void +bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional& scale); + +// NB: shapes in DeepDeek R1 +// +// hidden_states : [num_seqs, hidden_size] [1, 7168] +// q_a_proj_weight : [q_lora_rank, hidden_size] [1536, 7168] +// q_b_proj_weight : [num_heads * qk_head_dim, q_lora_rank] [4224, 1536] +// kv_a_proj_weight : [kv_lora_rank + qk_rope_head_dim, hidden_size] [576, 7168] +// w_kc : [num_heads, kv_lora_rank, qk_nope_head_dim] [22, 512, 128] +// q_a_layernorm_weight : [q_lora_rank] [1536] +// kv_a_layernorm_weight : [kv_lora_rank] [512] +// +std::tuple qkv_proj_with_rope( + at::Tensor& hidden_states, + at::Tensor& q_a_proj_weight, + at::Tensor& q_b_proj_weight, + at::Tensor& kv_a_proj_weight, + at::Tensor& w_kc, + at::Tensor& q_a_layernorm_weight, + at::Tensor& kv_a_layernorm_weight, + at::Tensor& positions, + at::Tensor& cos_sin_cache, + double eps, + bool use_int8_w8a8, + std::optional& q_a_proj_scale, + std::optional& q_b_proj_scale, + std::optional& kv_a_proj_scale, + bool is_vnni) { + RECORD_FUNCTION( + "sgl-kernel::qkv_proj_with_rope", + std::vector({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc})); + + const auto st = hidden_states.scalar_type(); + CHECK_INPUT(hidden_states); + CHECK_INPUT(positions); + CHECK_INPUT(cos_sin_cache); + CHECK_EQ(q_a_layernorm_weight.scalar_type(), st); + CHECK_EQ(kv_a_layernorm_weight.scalar_type(), st); + CHECK_EQ(positions.scalar_type(), at::kLong); + CHECK_EQ(cos_sin_cache.scalar_type(), st); + CHECK_DIM(2, hidden_states); + CHECK_DIM(3, w_kc); + CHECK_DIM(1, q_a_layernorm_weight); + CHECK_DIM(1, kv_a_layernorm_weight); + CHECK_DIM(1, positions); + CHECK_DIM(2, cos_sin_cache); + + // skip contiguous checks for weights, expect prepacked + TORCH_CHECK(is_vnni, "qkv_proj_with_rope: expect weights are prepacked!"); + + int64_t num_seqs = hidden_states.size(0); + int64_t hidden_size = hidden_states.size(1); + int64_t q_lora_rank = q_a_proj_weight.size(0); + int64_t num_heads = w_kc.size(0); + int64_t kv_lora_rank = w_kc.size(1); + int64_t qk_head_dim = q_b_proj_weight.size(0) / num_heads; + int64_t qk_nope_head_dim = w_kc.size(2); + int64_t qk_rope_head_dim = kv_a_proj_weight.size(0) - kv_lora_rank; + int64_t rotary_dim = cos_sin_cache.size(1); + + CHECK_EQ(positions.numel(), num_seqs); + CHECK_EQ(rotary_dim, qk_rope_head_dim); + CHECK_EQ(q_a_layernorm_weight.numel(), q_lora_rank); + CHECK_EQ(kv_a_layernorm_weight.numel(), kv_lora_rank); + + // check the packed dimension + CHECK_EQ(q_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8)); + CHECK_EQ(q_b_proj_weight.size(1), get_row_size(q_lora_rank, use_int8_w8a8)); + CHECK_EQ(kv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8)); + + if (use_int8_w8a8) { + TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for int8 w8a8."); + TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8."); + TORCH_CHECK(kv_a_proj_scale.has_value(), "missing 
kv_a_proj_scale for int8 w8a8."); + } + + // outputs and temp buffer + const auto options = hidden_states.options(); + auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options); + auto k_input = at::empty({num_seqs, 1, kv_lora_rank + qk_rope_head_dim}, options); + auto v_input = k_input.narrow(-1, 0, kv_lora_rank); + + // outputs of q_a_proj and q_b_proj + auto qa = at::empty({num_seqs, q_lora_rank}, options); + + // stage 1: q_a_proj and kv_a_proj + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "qkv_proj_kernel_impl", [&] { + if (use_int8_w8a8) { + auto q_a_proj_s = q_a_proj_scale.value(); + auto kv_a_proj_s = kv_a_proj_scale.value(); + TORCH_CHECK(q_a_proj_s.numel() == q_lora_rank); + TORCH_CHECK(kv_a_proj_s.numel() == kv_lora_rank + qk_rope_head_dim); + + auto buffer = at::empty({num_seqs * hidden_size + num_seqs * 4}, options.dtype(at::kByte)); + uint8_t* __restrict__ Aq_data = buffer.data_ptr(); + float* __restrict__ As_data = (float*)((void*)(Aq_data + num_seqs * hidden_size)); + const scalar_t* __restrict__ A_data = hidden_states.data_ptr(); + + at::parallel_for(0, num_seqs, 0, [&](int64_t begin, int64_t end) { + for (int64_t m = begin; m < end; ++m) { + quantize_row_int8(Aq_data + m * hidden_size, As_data[m], A_data + m * hidden_size, hidden_size); + } + }); + + segment_gemm_kernel_impl( + qa.data_ptr(), + k_input.data_ptr(), + Aq_data, + q_a_proj_weight.data_ptr(), + kv_a_proj_weight.data_ptr(), + As_data, + q_a_proj_s.data_ptr(), + kv_a_proj_s.data_ptr(), + num_seqs, + q_lora_rank, + kv_lora_rank + qk_rope_head_dim, + hidden_size); + } else { + segment_gemm_kernel_impl( + qa.data_ptr(), + k_input.data_ptr(), + hidden_states.data_ptr(), + q_a_proj_weight.data_ptr(), + kv_a_proj_weight.data_ptr(), + num_seqs, + q_lora_rank, + kv_lora_rank + qk_rope_head_dim, + hidden_size); + } + }); + + // stage 2: apply rmsnorm inplace + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rms_norm_kernel_impl", [&] { + rms_norm_kernel_impl( + qa.data_ptr(), + v_input.data_ptr(), + q_a_layernorm_weight.data_ptr(), + kv_a_layernorm_weight.data_ptr(), + num_seqs, + q_lora_rank, + kv_lora_rank, + kv_lora_rank + qk_rope_head_dim, + eps); + }); + + // stage 3: q_b_proj + at::Tensor qb; + std::optional bias; + if (use_int8_w8a8) { + qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni); + } else { + qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni); + } + qb.as_strided_({num_seqs, num_heads, qk_head_dim}, {num_heads * qk_head_dim, qk_head_dim, 1}); + + // stage 4: bmm + std::optional scale; + auto q_nope = qb.narrow(2, 0, qk_nope_head_dim).transpose_(0, 1); + auto q_nope_out = q_input.narrow(2, 0, kv_lora_rank).transpose_(0, 1); + bmm_cpu(q_nope_out, q_nope, w_kc, is_vnni, scale); + + // stage 5: rope + AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rotary_emb_kernel_impl", [&] { + rotary_emb_kernel_impl( + q_input.data_ptr() + kv_lora_rank, + k_input.data_ptr() + kv_lora_rank, + qb.data_ptr() + qk_nope_head_dim, + k_input.data_ptr() + kv_lora_rank, + positions.data_ptr(), + cos_sin_cache.data_ptr(), + num_seqs, + num_heads, + rotary_dim, + num_heads * qk_head_dim, + qk_head_dim, + kv_lora_rank + qk_rope_head_dim, + num_heads * (kv_lora_rank + qk_rope_head_dim), + kv_lora_rank + qk_rope_head_dim, + kv_lora_rank + qk_rope_head_dim); + }); + + return std::make_tuple(q_input, k_input, v_input); +} diff --git a/sgl-kernel/csrc/cpu/rope.cpp b/sgl-kernel/csrc/cpu/rope.cpp new file mode 100644 index 000000000..64bc297fe --- /dev/null +++ 
b/sgl-kernel/csrc/cpu/rope.cpp @@ -0,0 +1,129 @@ +#include "common.h" +#include "vec.h" + +namespace { + +template +void rope_kernel_impl( + scalar_t* __restrict__ q_pe_out, + scalar_t* __restrict__ k_pe_out, + int64_t* __restrict__ t_pos, + scalar_t* __restrict__ q_pe, + scalar_t* __restrict__ k_pe, + scalar_t* __restrict__ t_emb_pos, + int64_t seq_len, + int64_t num_head, + int64_t rotary_dim, + int64_t HR, + int64_t q_pe_stride_s, + int64_t out_stride_qs, + int64_t out_stride_ks, + int64_t HK, + int64_t k_pe_stride_s, + int64_t q_pe_stride_n, + int64_t out_stride_qn) { + int64_t COFF = HR / 2; + at::parallel_for(0, seq_len * num_head, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) { + int64_t seq{0}, head_id{0}; + data_index_init(begin, seq, seq_len, head_id, num_head); + for (int64_t i = begin; i < end; ++i) { + int64_t in_offset_q = seq * q_pe_stride_s + head_id * q_pe_stride_n; + int64_t out_offset_q = seq * out_stride_qs + head_id * out_stride_qn; + int64_t out_offset_k = seq * out_stride_ks; + int64_t p = 0; + scalar_t* sin_start = nullptr; + scalar_t* cos_start = nullptr; + // step 0) get the rotary position embedding for the current position + p = t_pos[seq]; + sin_start = t_emb_pos + p * HR + COFF; + cos_start = t_emb_pos + p * HR; + // step 1) apply_rotary_pos_emb for the rotary_dim elements in every + // head of query/key + for (int64_t h = 0; h < rotary_dim; h += 2) { + scalar_t cos = cos_start[h >> 1]; + scalar_t sin = sin_start[h >> 1]; + scalar_t in1 = q_pe[in_offset_q + h]; + scalar_t in2 = q_pe[in_offset_q + h + 1]; + scalar_t out1 = in1 * cos - in2 * sin; + scalar_t out2 = in2 * cos + in1 * sin; + q_pe_out[out_offset_q + h] = out1; + q_pe_out[out_offset_q + h + 1] = out2; + } + for (int64_t h = 0; h < HK; h += 2) { + scalar_t cos = cos_start[h >> 1]; + scalar_t sin = sin_start[h >> 1]; + int64_t k_pe_offset = seq * k_pe_stride_s; + scalar_t in1_k = k_pe[k_pe_offset + h]; + scalar_t in2_k = k_pe[k_pe_offset + h + 1]; + scalar_t out1_k = in1_k * cos - in2_k * sin; + scalar_t out2_k = in2_k * cos + in1_k * sin; + k_pe_out[out_offset_k + h] = out1_k; + k_pe_out[out_offset_k + h + 1] = out2_k; + } + // move to the next index + data_index_step(seq, seq_len, head_id, num_head); + } + }); +} +} // namespace + +std::tuple +rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos) { + RECORD_FUNCTION( + "sgl-kernel::rotary_position_embedding_cpu", std::vector({t_pos, q_pe, k_pe, t_emb_pos})); + CHECK_INPUT(t_pos); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_pe); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_pe); + CHECK_INPUT(t_emb_pos); + CHECK_DIM(1, t_pos); + CHECK_DIM(3, q_pe); + CHECK_DIM(3, k_pe); + CHECK_DIM(2, t_emb_pos); + + int64_t seq_len = q_pe.size(0); + int64_t num_head = q_pe.size(1); + int64_t rotary_dim = q_pe.size(2); + int64_t HK = k_pe.size(2); + int64_t HR = t_emb_pos.size(1); + CHECK_EQ(HR, rotary_dim); + CHECK_EQ(k_pe.size(0), seq_len); + CHECK_EQ(k_pe.size(1), 1); + CHECK_EQ(t_pos.size(0), seq_len); + CHECK_EQ(HK, rotary_dim); + + at::Tensor q_pe_out = at::empty_like(q_pe); + at::Tensor k_pe_out = at::empty_like(k_pe); + int64_t q_pe_stride_s = q_pe.stride(0); + int64_t q_pe_stride_n = q_pe.stride(1); + int64_t k_pe_stride_s = k_pe.stride(0); + int64_t out_stride_qs = q_pe_out.stride(0); + int64_t out_stride_qn = q_pe_out.stride(1); + int64_t out_stride_ks = k_pe_out.stride(0); + + const auto input_dtype = q_pe.scalar_type(); + TORCH_CHECK(t_pos.scalar_type() == at::kLong, "expect positions to be int64, got ", 
t_pos.scalar_type());
+  TORCH_CHECK(input_dtype == k_pe.scalar_type(), "q_pe and k_pe must have the same data type");
+  TORCH_CHECK(input_dtype == t_emb_pos.scalar_type(), "q_pe and t_emb_pos must have the same data type");
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_position_embedding_cpu", [&] {
+    rope_kernel_impl<scalar_t>(
+        q_pe_out.data_ptr<scalar_t>(),
+        k_pe_out.data_ptr<scalar_t>(),
+        t_pos.data_ptr<int64_t>(),
+        q_pe.data_ptr<scalar_t>(),
+        k_pe.data_ptr<scalar_t>(),
+        t_emb_pos.data_ptr<scalar_t>(),
+        seq_len,
+        num_head,
+        rotary_dim,
+        HR,
+        q_pe_stride_s,
+        out_stride_qs,
+        out_stride_ks,
+        HK,
+        k_pe_stride_s,
+        q_pe_stride_n,
+        out_stride_qn);
+  });
+  return std::make_tuple(q_pe_out, k_pe_out);
+}
diff --git a/sgl-kernel/csrc/cpu/shm.cpp b/sgl-kernel/csrc/cpu/shm.cpp
new file mode 100644
index 000000000..9f7d89df1
--- /dev/null
+++ b/sgl-kernel/csrc/cpu/shm.cpp
@@ -0,0 +1,659 @@
+#include "shm.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+// states for collectives
+enum coll_state {
+  coll_begin = 0,
+  coll_allreduce_naive__copy_in_done,
+  coll_allreduce_naive__reduce_done,
+  // alternative states used while allreduce is working on the other half
+  // of the double buffer
+  coll_alt1_allreduce_naive__copy_in_done,
+  coll_alt2_allreduce_naive__copy_in_done,
+  coll_alt1_allreduce_naive__reduce_done,
+  coll_allgather_naive__copy_in_done,
+  coll_alt1_allgather_naive__copy_in_done,
+  coll_alt2_allgather_naive__copy_in_done,
+};
+
+// SHM building blocks
+struct SharedData {
+  const char* name;
+  int descriptor;
+  void* bytes;
+  size_t nbytes;
+};
+
+void shared_open(SharedData* data, const char* name, size_t nbytes) {
+  int d = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR);
+  if (d != -1) {
+    void* bytes = mmap(NULL, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, d, 0);
+    data->name = name;
+    data->descriptor = d;
+    data->bytes = bytes;
+    data->nbytes = nbytes;
+  } else {
+    if (errno != ENOENT) {
+      // stay silent when the shm segment does not exist yet: the caller
+      // loops until the other ranks have created their segments
+      printf("shared_open %s failed, errno=%d\n", name, errno);
+    }
+    data->descriptor = -1;
+  }
+}
+
+void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) {
+  int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR);
+  if (d != -1) {
+    if (write(d, bytes, nbytes) == (ssize_t)nbytes) {
+      shared_open(data, name, nbytes);
+    }
+  } else {
+    printf("shared_create %s failed\n", name);
+  }
+}
+
+static int world_size;
+
+// SHM based allreduce helper functions
+// buffer that holds shm name
+#define NAME_BUF_SIZE 1000
+#define MAX_BUF_SIZE 1048576 * 32
+#define NAIVE_ALLREDUCE_THRESHOLD 1048576
+#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer"
+struct allreduce_workspace {
+  enum coll_state states[2];  // idx=0 -- state for symmetric_naive_all_reduce
+                              // idx=1 -- state for distributed_naive_all_reduce
+  // double buffer to avoid syncing between rounds:
+  // offsets [0, 2 * NAIVE_ALLREDUCE_THRESHOLD) hold the buffers for
+  // symmetric_naive_all_reduce; everything after that holds the buffers
+  // for distributed_naive_all_reduce
+  char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE];
+};
+
+#define BUFFER0_OFFSET(current_buffer) current_buffer* NAIVE_ALLREDUCE_THRESHOLD
+#define BUFFER1_OFFSET(current_buffer) 2 * NAIVE_ALLREDUCE_THRESHOLD + current_buffer* MAX_BUF_SIZE
+
+struct allreduce_workspace** workspace;
+
+// buffer for small messages, double buffer
+char** symmetric_buffer[2];
+// buffer for large messages, double buffer
+char** distributed_buffer[2];
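+// Layout of one rank's workspace buffer, as implied by the offsets above
+// (NAIVE_ALLREDUCE_THRESHOLD = 1 MiB, MAX_BUF_SIZE = 32 MiB):
+//   [ 0 MiB,  1 MiB)  symmetric buffer 0   (small messages)
+//   [ 1 MiB,  2 MiB)  symmetric buffer 1
+//   [ 2 MiB, 34 MiB)  distributed buffer 0 (large messages)
+//   [34 MiB, 66 MiB)  distributed buffer 1
+// BUFFER0_OFFSET / BUFFER1_OFFSET select one half of each double buffer.
+
+void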
wait_buffer_state_until_2(int index, enum coll_state state0, enum coll_state state1, int state_group) { + volatile enum coll_state* state_ptr = &(workspace[index]->states[state_group]); + + while (1) { + volatile enum coll_state cur_state = *state_ptr; + if (cur_state == state0 || cur_state == state1) break; + } +} + +__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); +inline __m512 cvt_bf16_to_fp32(const __m256i src) { + auto y = _mm512_cvtepu16_epi32(src); + return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2)); +} + +inline __m256i cvt_fp32_to_bf16(const __m512 src) __attribute__((target("avx512bw"))); +inline __m256i cvt_fp32_to_bf16(const __m512 src) { + __m512i value = _mm512_castps_si512(src); + __m512i nan = _mm512_set1_epi32(0xffff); + auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q); + __m512i ones = _mm512_set1_epi32(0x1); + __m512i vec_bias = _mm512_set1_epi32(0x7fff); + // uint32_t lsb = (input >> 16) & 1; + auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones); + // uint32_t rounding_bias = 0x7fff + lsb; + t_value = _mm512_add_epi32(t_value, vec_bias); + // input += rounding_bias; + t_value = _mm512_add_epi32(t_value, value); + // input = input >> 16; + t_value = _mm512_srli_epi32(t_value, 16); + // Check NaN before converting back to bf16 + t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value); + return _mm512_cvtusepi32_epi16(t_value); +} + +__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw"))); +inline __m512 cvt_fp16_to_fp32(const __m256i src) { + return _mm512_cvtph_ps(src); +} + +inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw"))); +inline __m256i cvt_fp32_to_fp16(const __m512 src) { + return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); +} + +void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) + __attribute__((target("avx512bw"))); + +void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) + __attribute__((target("avx512bw"))); + +void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) + __attribute__((target("avx512bw"))); + +void reduce_all_buffers( + int start_elements, + int num_elements, + c10::ScalarType scalar_type, + int to_buffer_idx, + char* to_buffer, + char** buffers) { + switch (scalar_type) { + case c10::ScalarType::BFloat16: + reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers); + break; + case c10::ScalarType::Half: + reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers); + break; + case c10::ScalarType::Float: + reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers); + break; + default: + assert(!"Should not get here"); + } +} + +#define CVT_ADD_BF16(x) \ + do { \ + auto in##x##_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \ + inout_val = _mm512_add_ps(inout_val, in##x##_val); \ + } while (0) + +// Reduce functions down below use vectorized algorithm, the number of bytes +// processed each iteration depends on vector length. 256bit vector ==> 32 +// bytes, 512bit vector ==> 64 bytes If you change implementation of +// reduce_bf16_buffers, etc. 
, check whether this number needs to be changed +#define VECTOR_LENGTH_IN_BYTES 32 + +void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) { + const int element_size = 2; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; + i += VECTOR_LENGTH_IN_BYTES) { + auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i))); + switch (world_size) { + case 16: + CVT_ADD_BF16(15); + case 15: + CVT_ADD_BF16(14); + case 14: + CVT_ADD_BF16(13); + case 13: + CVT_ADD_BF16(12); + case 12: + CVT_ADD_BF16(11); + case 11: + CVT_ADD_BF16(10); + case 10: + CVT_ADD_BF16(9); + case 9: + CVT_ADD_BF16(8); + case 8: + CVT_ADD_BF16(7); + case 7: + CVT_ADD_BF16(6); + case 6: + CVT_ADD_BF16(5); + case 5: + CVT_ADD_BF16(4); + case 4: + CVT_ADD_BF16(3); + case 3: + CVT_ADD_BF16(2); + case 2: + CVT_ADD_BF16(1); + case 1: + break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i))); + inout_val = _mm512_add_ps(inout_val, in_val); + } + } + _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val)); + } + + // process remaining part + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float val = 0.0f; + for (int j = 0; j < world_size; j++) { + val += *(at::BFloat16*)(buffers[j] + i); + } + *(at::BFloat16*)(to_buffer + i) = val; + remain_elements--; + i += element_size; + } +} + +#define CVT_ADD_FP16(x) \ + do { \ + auto in##x##_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \ + inout_val = _mm512_add_ps(inout_val, in##x##_val); \ + } while (0) + +void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) { + const int element_size = 2; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; + i += VECTOR_LENGTH_IN_BYTES) { + auto inout_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i))); + switch (world_size) { + case 16: + CVT_ADD_FP16(15); + case 15: + CVT_ADD_FP16(14); + case 14: + CVT_ADD_FP16(13); + case 13: + CVT_ADD_FP16(12); + case 12: + CVT_ADD_FP16(11); + case 11: + CVT_ADD_FP16(10); + case 10: + CVT_ADD_FP16(9); + case 9: + CVT_ADD_FP16(8); + case 8: + CVT_ADD_FP16(7); + case 7: + CVT_ADD_FP16(6); + case 6: + CVT_ADD_FP16(5); + case 5: + CVT_ADD_FP16(4); + case 4: + CVT_ADD_FP16(3); + case 3: + CVT_ADD_FP16(2); + case 2: + CVT_ADD_FP16(1); + case 1: + break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i))); + inout_val = _mm512_add_ps(inout_val, in_val); + } + } + _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val)); + } + + // process remaining part + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float val = 0.0f; + for (int j = 0; j < world_size; j++) { + val += *(at::Half*)(buffers[j] + i); + } + 
*(at::Half*)(to_buffer + i) = val; + remain_elements--; + i += element_size; + } +} + +#define CVT_ADD_F32(x) \ + do { \ + auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \ + inout_val = _mm256_add_ps(inout_val, in##x##_val); \ + } while (0) + +void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) { + const int element_size = 4; + const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size; + int main_elements = num_elements - (num_elements % vector_length); + int remain_elements = num_elements % vector_length; + + // process aligned part +#pragma omp parallel for + for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size; + i += VECTOR_LENGTH_IN_BYTES) { + auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i)); + switch (world_size) { + case 16: + CVT_ADD_F32(15); + case 15: + CVT_ADD_F32(14); + case 14: + CVT_ADD_F32(13); + case 13: + CVT_ADD_F32(12); + case 12: + CVT_ADD_F32(11); + case 11: + CVT_ADD_F32(10); + case 10: + CVT_ADD_F32(9); + case 9: + CVT_ADD_F32(8); + case 8: + CVT_ADD_F32(7); + case 7: + CVT_ADD_F32(6); + case 6: + CVT_ADD_F32(5); + case 5: + CVT_ADD_F32(4); + case 4: + CVT_ADD_F32(3); + case 3: + CVT_ADD_F32(2); + case 2: + CVT_ADD_F32(1); + case 1: + break; + default: + for (int j = 1; j < world_size; j++) { + auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i)); + inout_val = _mm256_add_ps(inout_val, in_val); + } + } + _mm256_storeu_ps((float*)(to_buffer + i), inout_val); + } + + // process remaining part + int i = (start_elements + main_elements) * element_size; + while (remain_elements > 0) { + float val = 0.0f; + for (int j = 0; j < world_size; j++) { + val += *(float*)(buffers[j] + i); + } + *(float*)(to_buffer + i) = val; + remain_elements--; + i += element_size; + } +} + +static bool is_initialized = false; +static int world_rank; + +void shm_initialize(int size, int rank, char* addr_string, char* port_string) { + if (is_initialized) { + return; + } + is_initialized = true; + + world_size = size; + world_rank = rank; + + char shm_name_prefix[NAME_BUF_SIZE]; + char shm_name[NAME_BUF_SIZE]; + snprintf(shm_name_prefix, NAME_BUF_SIZE, "%s_%d_%s_%s", SHM_BUFFER_NAME, getuid(), addr_string, port_string); + // create shared workspace for SHM based allreduce + SharedData allreduce_buffer; + // allocate workspace_buf for current rank + struct allreduce_workspace* workspace_buf; + struct allreduce_workspace* workspace_buf_other; + workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace)); + snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank); + shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace)); + workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes; + workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done; + workspace_buf->states[1] = coll_begin; + + // create the workspace pointer list + workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*)); + symmetric_buffer[0] = (char**)malloc(size * sizeof(char**)); + symmetric_buffer[1] = (char**)malloc(size * sizeof(char**)); + distributed_buffer[0] = (char**)malloc(size * sizeof(char**)); + distributed_buffer[1] = (char**)malloc(size * sizeof(char**)); + + // map shm of all ranks + for (int i = 0; i < size; i++) { + if (i != rank) { + snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i); + // printf("open %s, %d\n", shm_name, rank); + do { + 
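+        // shm_open fails with ENOENT until rank i has created its segment in
+        // its own shm_initialize call; spin here so ranks may start in any order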
shared_open(&allreduce_buffer, shm_name, sizeof(struct allreduce_workspace));
+      } while (allreduce_buffer.descriptor == -1 && errno == ENOENT);
+      workspace_buf_other = (struct allreduce_workspace*)allreduce_buffer.bytes;
+      workspace[i] = workspace_buf_other;
+    } else {
+      workspace[i] = workspace_buf;
+    }
+    symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0);
+    symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1);
+    distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0);
+    distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1);
+  }
+}
+
+static void parallel_memcpy(void* to, void* from, size_t n_bytes) __attribute__((target("avx512bw")));
+static void parallel_memcpy(void* to, void* from, size_t n_bytes) {
+  auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES);
+  // process aligned part
+#pragma omp parallel for
+  for (int i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) {
+    auto val = _mm256_loadu_si256((__m256i*)((char*)from + i));
+    _mm256_storeu_si256((__m256i*)((char*)to + i), val);
+  }
+
+  // process remaining part
+  for (int i = aligned_bytes; i < n_bytes; i++) {
+    *((char*)to + i) = *((char*)from + i);
+  }
+}
+
+#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod))
+#define rank_mod(rank) positive_mod(rank, world_size)
+size_t slice_size(size_t chunk_el, int slice_idx) {
+  size_t slice_size = chunk_el / world_size;
+  return slice_idx == world_size - 1 ? slice_size + (chunk_el % world_size) : slice_size;
+}
+
+char* slice_data(char* data_ptr, size_t chunk_el, int el_size, int slice_idx) {
+  size_t slice_size = chunk_el / world_size;
+  size_t el_offset = slice_size * slice_idx;
+  return data_ptr + el_offset * el_size;
+}
+
+size_t slice_el_start(size_t chunk_el, int slice_idx) {
+  size_t slice_size = chunk_el / world_size;
+  return slice_size * slice_idx;
+}
+
+void symmetric_naive_all_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
+  const int state_group = 0;
+  static int current_buffer = 0;
+  static int state_idx = 0;
+
+  enum coll_state copy_current, copy_next;
+
+  switch (state_idx) {
+    case 0:
+      copy_current = coll_allreduce_naive__copy_in_done;
+      copy_next = coll_alt1_allreduce_naive__copy_in_done;
+      break;
+    case 1:
+      copy_current = coll_alt1_allreduce_naive__copy_in_done;
+      copy_next = coll_alt2_allreduce_naive__copy_in_done;
+      break;
+    case 2:
+      copy_current = coll_alt2_allreduce_naive__copy_in_done;
+      copy_next = coll_allreduce_naive__copy_in_done;
+      break;
+    default:
+      assert(!"Should not get here.");
+  }
+  state_idx = (state_idx + 1) % 3;
+
+  parallel_memcpy(symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size);
+  std::atomic_thread_fence(std::memory_order_release);
+  workspace[world_rank]->states[state_group] = copy_current;
+
+  for (int i = 0; i < world_size; i++) {
+    // wait until the other ranks have copied their buffers in
+    if (i != world_rank) {
+      wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
+    }
+  }
+
+  // each rank reduces the buffer independently, so there is no need for
+  // synchronization afterward
+  reduce_all_buffers(0, chunk_el, scalar_type, world_rank, data_ptr, symmetric_buffer[current_buffer]);
+
+  // switch buffer
+  current_buffer = 1 - current_buffer;
+}
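The three rotating states appear to exist because ranks synchronize only once per call: a fast peer may already have advanced one round, and with just two alternating states "one round ahead" would be indistinguishable from "same round" (hence wait_buffer_state_until_2 accepting either the current or the next state). A minimal single-rank smoke test of this path, assuming a hypothetical test program linked against this file:

// single rank: there are no peers to wait on, and the reduce degenerates
// to a copy of the rank's own buffer back into the tensor
#include "shm.h"
#include <torch/torch.h>

int main() {
  shm_initialize(/*size=*/1, /*rank=*/0, (char*)"127.0.0.1", (char*)"29500");
  torch::Tensor t = torch::ones({1024}, torch::kFloat);
  all_reduce_outer_loop(t, t.numel(), t.numel() * t.element_size());
  // with world_size == 1 the tensor is unchanged
  return 0;
}

+// naive allreduce distributed, each rank does a naive reduce on its slice
+void distributed_naive_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
+  const int state_group = 1;
+  static int current_buffer = 0;
+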
static int state_idx = 0; + + enum coll_state copy_current, copy_next, reduce_current; + + // similar to symmetric_naive_allreduce, but here we only need two sets of + // states, because distributed naive reduce has two barriers in the algorithm + switch (state_idx) { + case 0: + copy_current = coll_allreduce_naive__copy_in_done; + reduce_current = coll_allreduce_naive__reduce_done; + copy_next = coll_alt1_allreduce_naive__copy_in_done; + break; + case 1: + copy_current = coll_alt1_allreduce_naive__copy_in_done; + reduce_current = coll_alt1_allreduce_naive__reduce_done; + copy_next = coll_allreduce_naive__copy_in_done; + break; + default: + assert(!"Should not get here."); + } + state_idx = (state_idx + 1) % 2; + + int data_size = chunk_size / chunk_el; + parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = copy_current; + + for (int i = 0; i < world_size; i++) { + // wait until all the other ranks copy the buffer + if (i != world_rank) wait_buffer_state_until_2(i, copy_current, reduce_current, state_group); + } + + // reduce scatter + reduce_all_buffers( + slice_el_start(chunk_el, world_rank), + slice_size(chunk_el, world_rank), + scalar_type, + world_rank, + distributed_buffer[current_buffer][world_rank], + distributed_buffer[current_buffer]); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = reduce_current; + + for (int i = 0; i < world_size; i++) { + // wait until all the other ranks reduce the buffer + if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group); + } + + for (int i = 0; i < world_size; i++) { + int rank = (i + world_rank) % world_size; + parallel_memcpy( + slice_data(data_ptr, chunk_el, data_size, rank), + slice_data(distributed_buffer[current_buffer][rank], chunk_el, chunk_size / chunk_el, rank), + slice_size(chunk_el, rank) * data_size); + } + + current_buffer = 1 - current_buffer; +} + +void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size) { + for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) { + auto data_ptr = ((char*)(data.data_ptr()) + offset); + size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? 
MAX_BUF_SIZE : data_size - offset; + size_t chunk_el = chunk_size / (data_size / numel); + if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) { + symmetric_naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el); + } else { + distributed_naive_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el); + } + } +} + +void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_t chunk_size, size_t chunk_el) { + const int state_group = 1; + static int current_buffer = 0; + static int state_idx = 0; + + enum coll_state copy_current, copy_next; + + switch (state_idx) { + case 0: + copy_current = coll_allgather_naive__copy_in_done; + copy_next = coll_alt1_allgather_naive__copy_in_done; + break; + case 1: + copy_current = coll_alt1_allgather_naive__copy_in_done; + copy_next = coll_alt2_allgather_naive__copy_in_done; + break; + case 2: + copy_current = coll_alt2_allgather_naive__copy_in_done; + copy_next = coll_allgather_naive__copy_in_done; + break; + default: + assert(!"Should not get here."); + } + state_idx = (state_idx + 1) % 3; + + int data_size = chunk_size / chunk_el; + parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size); + std::atomic_thread_fence(std::memory_order_release); + workspace[world_rank]->states[state_group] = copy_current; + + for (int i = 0; i < world_size; i++) { + // wait until all the other ranks copy the buffer + if (i != world_rank) wait_buffer_state_until_2(i, copy_current, copy_next, state_group); + } + for (int i = 0; i < world_size; i++) { + parallel_memcpy(result_ptr + i * res_stride, distributed_buffer[current_buffer][i], chunk_size); + } + current_buffer = 1 - current_buffer; +} + +torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size) { + size_t dim_el = data.stride(dim) * data.size(dim); + int dtype_size = data_size / numel; + size_t dim_size = dim_el * dtype_size; + int dim_count = data_size / dim_size; + auto data_ptr = (char*)(data.data_ptr()); + auto result_ptr = (char*)(result.data_ptr()); + for (int i = 0; i < dim_count; i++) { + for (int offset = 0; offset < dim_size; offset += MAX_BUF_SIZE) { + size_t chunk_size = dim_size - offset > MAX_BUF_SIZE ? 
MAX_BUF_SIZE : dim_size - offset; + size_t chunk_el = chunk_size / dtype_size; + naive_all_gather( + result_ptr + i * dim_size * world_size + offset, + data_ptr + i * dim_size + offset, + dim_size, + chunk_size, + chunk_el); + } + } + return result; +} diff --git a/sgl-kernel/csrc/cpu/shm.h b/sgl-kernel/csrc/cpu/shm.h new file mode 100644 index 000000000..d21fe3d36 --- /dev/null +++ b/sgl-kernel/csrc/cpu/shm.h @@ -0,0 +1,11 @@ +#include + +#include + +#ifndef __SHM_COLLECTIVES__ +#define __SHM_COLLECTIVES__ +#define VECTOR_LENGTH_IN_BYTES 32 +void shm_initialize(int size, int rank, char* addr_string, char* port_string); +void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size); +torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size); +#endif diff --git a/sgl-kernel/csrc/cpu/topk.cpp b/sgl-kernel/csrc/cpu/topk.cpp new file mode 100644 index 000000000..6a6b64d12 --- /dev/null +++ b/sgl-kernel/csrc/cpu/topk.cpp @@ -0,0 +1,406 @@ +#include "common.h" +#include "vec.h" + +namespace { + +template +inline void softmax(float* __restrict__ out, const scalar_t* __restrict__ input) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + + constexpr int kVecSize = bVec::size(); + + // step 1: get max + fVec max_fvec = fVec(-std::numeric_limits::infinity()); + if constexpr (SIZE < kVecSize) { + // SIZE = 1, 2, 4, 8, 16; only the top half is used + bVec x_bvec = bVec::loadu(input, SIZE); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + x_fvec0 = fVec::set(max_fvec, x_fvec0, SIZE); + max_fvec = at::vec::maximum(max_fvec, x_fvec0); + x_fvec0.store(out, SIZE); + } else { + for (int d = 0; d < SIZE; d += kVecSize) { + bVec x_bvec = bVec::loadu(input + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + max_fvec = at::vec::maximum(max_fvec, x_fvec0); + max_fvec = at::vec::maximum(max_fvec, x_fvec1); + x_fvec0.store(out + d); + x_fvec1.store(out + d + fVec::size()); + } + } + float max_val = vec_reduce_max(max_fvec); + max_fvec = fVec(max_val); + + // step 2: sum of (x - max).exp() + fVec sum_fvec = fVec(float(0)); + if constexpr (SIZE < fVec::size()) { + // SIZE = 1, 2, 4, 8 + fVec x_fvec = (fVec::loadu(out, SIZE) - max_fvec).exp_u20(); + x_fvec = fVec::set(sum_fvec, x_fvec, SIZE); + sum_fvec += x_fvec; + x_fvec.store(out, SIZE); + } else { + for (int d = 0; d < SIZE; d += fVec::size()) { + fVec x_fvec = (fVec::loadu(out + d) - max_fvec).exp_u20(); + sum_fvec += x_fvec; + x_fvec.store(out + d); + } + } + float sum_val = vec_reduce_sum(sum_fvec); + + // step 3: x * (1 / sum) + sum_fvec = fVec(1.f / sum_val); + if constexpr (SIZE < fVec::size()) { + // SIZE = 1, 2, 4, 8 + fVec out_fvec = fVec::loadu(out, SIZE) * sum_fvec; + out_fvec.store(out, SIZE); + } else { + for (int d = 0; d < SIZE; d += fVec::size()) { + fVec out_fvec = fVec::loadu(out + d) * sum_fvec; + out_fvec.store(out + d); + } + } +} + +template +void grouped_topk_kernel_impl( + float* __restrict__ topk_weights, + int32_t* __restrict__ topk_ids, + const scalar_t* __restrict__ gating_output, + int64_t num_tokens, + int64_t topk, + int64_t num_groups, + int64_t topk_group, + bool renormalize) { + const int64_t num_experts_per_group = NUM_EXPERTS / num_groups; + at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) { + alignas(64) float scores[NUM_EXPERTS]; + + using elem_t = std::pair; + std::vector queue(num_groups); + std::vector queue2(topk_group * 
num_experts_per_group); + + for (int64_t i = begin; i < end; ++i) { + // do softmax to get scores + softmax(scores, gating_output + i * NUM_EXPERTS); + + // find max score per group + for (int64_t g = 0; g < num_groups; ++g) { + float gmax = -std::numeric_limits::infinity(); + for (int64_t e = 0; e < num_experts_per_group; ++e) { + gmax = std::max(gmax, scores[g * num_experts_per_group + e]); + } + queue[g] = {gmax, g}; + } + + // find group topk + std::partial_sort( + queue.begin(), queue.begin() + topk_group, queue.end(), [](const elem_t& x, const elem_t& y) -> bool { + return x.first > y.first; + }); + + for (int64_t g = 0; g < topk_group; ++g) { + int32_t group_idx = queue[g].second; + for (int64_t e = 0; e < num_experts_per_group; ++e) { + int32_t expert_idx = group_idx * num_experts_per_group + e; + queue2[g * num_experts_per_group + e] = {scores[expert_idx], expert_idx}; + } + } + + // find global topk + std::partial_sort( + queue2.begin(), queue2.begin() + topk, queue2.end(), [](const elem_t& x, const elem_t& y) -> bool { + return x.first > y.first; + }); + + for (int64_t j = 0; j < topk; ++j) { + topk_weights[i * topk + j] = queue2[j].first; + topk_ids[i * topk + j] = queue2[j].second; + } + + if (renormalize) { + float sum = 0.f; + for (int64_t j = 0; j < topk; ++j) { + sum += topk_weights[i * topk + j]; + } + float scale = 1.f / sum; + for (int64_t j = 0; j < topk; ++j) { + topk_weights[i * topk + j] *= scale; + } + } + } + }); +} + +template +inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + + const fVec one = fVec(1.f); + + constexpr int kVecSize = bVec::size(); + for (int d = 0; d < SIZE; d += kVecSize) { + bVec x_bvec = bVec::loadu(input + d); + fVec x_fvec0, x_fvec1; + std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec); + + x_fvec0 = one / (one + x_fvec0.neg().exp_u20()); + x_fvec1 = one / (one + x_fvec1.neg().exp_u20()); + + x_fvec0.store(out + d); + x_fvec1.store(out + d + fVec::size()); + } +} + +template +inline void +apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) { + using bVec = at::vec::Vectorized; + using fVec = at::vec::Vectorized; + for (int d = 0; d < SIZE; d += bVec::size()) { + bVec bias_vec = bVec::loadu(bias + d); + fVec bias0, bias1; + std::tie(bias0, bias1) = at::vec::convert_to_float(bias_vec); + + fVec x0 = fVec::loadu(scores + d) + bias0; + fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1; + x0.store(scores2 + d); + x1.store(scores2 + d + fVec::size()); + } +} + +template +void biased_grouped_topk_kernel_impl( + float* __restrict__ topk_weights, + int32_t* __restrict__ topk_ids, + const scalar_t* __restrict__ gating_output, + const scalar_t* __restrict__ bias, + int64_t num_tokens, + int64_t num_groups, + int64_t topk_group, + bool renormalize) { + using Vec = at::vec::Vectorized; + + const int64_t num_experts_per_group = NUM_EXPERTS / num_groups; + at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) { + // scores: sigmoid + alignas(64) float scores[NUM_EXPERTS]; + // scores for choice: sigmoid + bias + alignas(64) float scores2[NUM_EXPERTS]; + + using elem_t = std::pair; + std::vector queue(num_groups); + std::vector queue2(topk_group * num_experts_per_group); + + for (int64_t i = begin; i < end; ++i) { + // do sigmoid to get scores + sigmoid(scores, gating_output + i * NUM_EXPERTS); + apply_bias(scores2, scores, bias); + + for (int64_t g = 0; g < 
num_groups; ++g) { + // find the max + float gmax = at::vec::reduce_all( + [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, + scores2 + g * num_experts_per_group, + num_experts_per_group); + + // find position of first max, + // note that we may have multiple max values. + int first_max_idx = -1; + for (int64_t e = 0; e < num_experts_per_group; ++e) { + if (scores2[g * num_experts_per_group + e] == gmax) { + first_max_idx = g * num_experts_per_group + e; + break; + } + } + + // find the 2nd max + scores2[first_max_idx] = -std::numeric_limits::infinity(); + float gmax2 = at::vec::reduce_all( + [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, + scores2 + g * num_experts_per_group, + num_experts_per_group); + // restore scores for choice + scores2[first_max_idx] = gmax; + + queue[g] = {gmax + gmax2, g}; + } + + // find group topk + std::partial_sort( + queue.begin(), queue.begin() + topk_group, queue.end(), [](const elem_t& x, const elem_t& y) -> bool { + return x.first > y.first; + }); + + for (int64_t g = 0; g < topk_group; ++g) { + int32_t group_idx = queue[g].second; + for (int64_t e = 0; e < num_experts_per_group; ++e) { + int32_t expert_idx = group_idx * num_experts_per_group + e; + queue2[g * num_experts_per_group + e] = {scores2[expert_idx], expert_idx}; + } + } + + // find global topk + std::partial_sort( + queue2.begin(), queue2.begin() + TOPK, queue2.end(), [](const elem_t& x, const elem_t& y) -> bool { + return x.first > y.first; + }); + + for (int j = 0; j < TOPK; ++j) { + int32_t index = queue2[j].second; + topk_ids[i * TOPK + j] = index; + topk_weights[i * TOPK + j] = scores[index]; + } + +#if defined(CPU_CAPABILITY_AVX512) + if (renormalize) { + __mmask16 mask = (1ULL << TOPK) - 1; + __m512 x = _mm512_maskz_loadu_ps(mask, topk_weights + i * TOPK); + float sum = _mm512_reduce_add_ps(x); + __m512 vscale = _mm512_set1_ps(1.f / sum); + __m512 y = _mm512_mul_ps(x, vscale); + _mm512_mask_storeu_ps(topk_weights + i * TOPK, mask, y); + } +#else + if (renormalize) { + float sum = 0.f; + for (int64_t j = 0; j < TOPK; ++j) { + sum += topk_weights[i * TOPK + j]; + } + float scale = 1.f / sum; + for (int64_t j = 0; j < TOPK; ++j) { + topk_weights[i * TOPK + j] *= scale; + } + } +#endif + } + }); +} + +#define LAUNCH_GROUPED_TOPK_KERNEL(NE) \ + grouped_topk_kernel_impl( \ + topk_weights.data_ptr(), \ + topk_ids.data_ptr(), \ + gating_output.data_ptr(), \ + num_tokens, \ + topk, \ + num_expert_group, \ + topk_group, \ + renormalize); + +#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \ + biased_grouped_topk_kernel_impl( \ + topk_weights.data_ptr(), \ + topk_ids.data_ptr(), \ + gating_output.data_ptr(), \ + correction_bias.data_ptr(), \ + num_tokens, \ + num_expert_group, \ + topk_group, \ + renormalize); + +} // anonymous namespace + +// grouped topk for DeepSeek V2 +std::tuple grouped_topk_cpu( + at::Tensor& hidden_states, + at::Tensor& gating_output, + int64_t topk, + bool renormalize, + int64_t num_expert_group, + int64_t topk_group) { + RECORD_FUNCTION("sgl-kernel::grouped_topk_cpu", std::vector({hidden_states, gating_output})); + CHECK_INPUT(gating_output); + + const auto st = hidden_states.scalar_type(); + CHECK_EQ(gating_output.scalar_type(), st); + + int64_t num_tokens = hidden_states.size(0); + int64_t num_experts = gating_output.size(1); + TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch"); + at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat)); + at::Tensor topk_ids = at::empty({num_tokens, 
topk}, hidden_states.options().dtype(at::kInt));
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "grouped_topk_kernel", [&] {
+    switch (num_experts) {
+      case 1:
+        LAUNCH_GROUPED_TOPK_KERNEL(1);
+        break;
+      case 2:
+        LAUNCH_GROUPED_TOPK_KERNEL(2);
+        break;
+      case 4:
+        LAUNCH_GROUPED_TOPK_KERNEL(4);
+        break;
+      case 8:
+        LAUNCH_GROUPED_TOPK_KERNEL(8);
+        break;
+      case 16:
+        LAUNCH_GROUPED_TOPK_KERNEL(16);
+        break;
+      case 32:
+        LAUNCH_GROUPED_TOPK_KERNEL(32);
+        break;
+      case 64:
+        LAUNCH_GROUPED_TOPK_KERNEL(64);
+        break;
+      case 128:
+        LAUNCH_GROUPED_TOPK_KERNEL(128);
+        break;
+      case 160:
+        LAUNCH_GROUPED_TOPK_KERNEL(160);
+        break;
+      case 256:
+        LAUNCH_GROUPED_TOPK_KERNEL(256);
+        break;
+      default:
+        TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
+    }
+  });
+  return std::make_tuple(topk_weights, topk_ids);
+}
+
+// biased grouped topk for DeepSeek V3/R1
+std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
+    at::Tensor& hidden_states,
+    at::Tensor& gating_output,
+    at::Tensor& correction_bias,
+    int64_t topk,
+    bool renormalize,
+    int64_t num_expert_group,
+    int64_t topk_group) {
+  RECORD_FUNCTION(
+      "sgl-kernel::biased_grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output, correction_bias}));
+
+  CHECK_INPUT(gating_output);
+  CHECK_INPUT(correction_bias);
+
+  const auto st = hidden_states.scalar_type();
+  CHECK_EQ(gating_output.scalar_type(), st);
+  CHECK_EQ(correction_bias.scalar_type(), st);
+
+  int64_t num_tokens = hidden_states.size(0);
+  int64_t num_experts = gating_output.size(1);
+  TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
+  TORCH_CHECK(correction_bias.numel() == num_experts, "Bias shape mismatch");
+  at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
+  at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
+
+  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "biased_grouped_topk_kernel", [&] {
+    // for now, only the DeepSeek V3 configuration is supported
+    TORCH_CHECK(topk == 8, "Unexpected topk: ", topk);
+    switch (num_experts) {
+      case 256:
+        LAUNCH_BIASED_GROUPED_TOPK_KERNEL(256, 8);
+        break;
+      default:
+        TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
+    }
+  });
+  return std::make_tuple(topk_weights, topk_ids);
+}
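The kernel above scores each group by the sum of its top-2 biased expert scores, keeps topk_group groups, takes the global topk among their experts, and reports the unbiased sigmoid scores as weights. A call sketch with the DeepSeek V3 routing configuration (tensor contents and the token count are illustrative):

at::Tensor h = at::randn({2, 7168}, at::kBFloat16);  // only size(0) and dtype are used
at::Tensor logits = at::randn({2, 256}, at::kBFloat16);
at::Tensor bias = at::randn({256}, at::kBFloat16);
auto [weights, ids] = biased_grouped_topk_cpu(
    h, logits, bias, /*topk=*/8, /*renormalize=*/true,
    /*num_expert_group=*/8, /*topk_group=*/4);
// weights: [2, 8] fp32 routing weights, ids: [2, 8] int32 expert indices

diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
new file mode 100644
index 000000000..6b7cc1d39
--- /dev/null
+++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
@@ -0,0 +1,224 @@
+/* Copyright 2025 SGLang Team. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.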
+==============================================================================*/
+
+#include
+#include
+#include
+
+#include "shm.h"
+
+// silu_and_mul
+at::Tensor silu_and_mul_cpu(at::Tensor& input);
+
+// rmsnorm
+at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
+
+// fused_add_rmsnorm
+void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);
+
+// topk
+std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
+    at::Tensor& hidden_states,
+    at::Tensor& gating_output,
+    int64_t topk,
+    bool renormalize,
+    int64_t num_expert_group,
+    int64_t topk_group);
+
+std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
+    at::Tensor& hidden_states,
+    at::Tensor& gating_output,
+    at::Tensor& correction_bias,
+    int64_t topk,
+    bool renormalize,
+    int64_t num_expert_group,
+    int64_t topk_group);
+
+// attention
+void decode_attention_cpu(
+    at::Tensor& query,
+    at::Tensor& output,
+    at::Tensor& k_cache,
+    at::Tensor& v_cache,
+    at::Tensor& attn_logits,
+    at::Tensor& req_to_token,
+    at::Tensor& req_pool_indices,
+    at::Tensor& seq_lens,
+    double sm_scale,
+    double logit_cap);
+
+void extend_attention_cpu(
+    at::Tensor& q_extend,
+    at::Tensor& k_extend,
+    at::Tensor& v_extend,
+    at::Tensor& o_extend,
+    at::Tensor& k_buffer,
+    at::Tensor& v_buffer,
+    at::Tensor& req_to_token,
+    at::Tensor& req_pool_indices,
+    at::Tensor& seq_lens,
+    at::Tensor& extend_seq_lens,
+    at::Tensor& extend_start_loc,
+    int64_t max_len_extend,
+    double sm_scale,
+    double logit_cap);
+
+// weight prepack
+at::Tensor convert_weight_packed(at::Tensor& weight);
+
+// quant
+std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);
+
+// gemm
+at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni);
+
+// igemm
+at::Tensor int8_scaled_mm_cpu(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    at::Tensor& scales1,
+    at::Tensor& scales2,
+    std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni);
+
+// quant + igemm
+at::Tensor int8_scaled_mm_with_quant(
+    at::Tensor& mat1,
+    at::Tensor& mat2,
+    at::Tensor& scales2,
+    std::optional<at::Tensor>& bias,
+    at::ScalarType out_dtype,
+    bool is_vnni);
+
+// bmm
+void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale);
+
+// fused moe
+at::Tensor fused_experts_cpu(
+    at::Tensor& hidden_states,
+    at::Tensor& w1,
+    at::Tensor& w2,
+    at::Tensor& topk_weights,
+    at::Tensor& topk_ids,
+    bool inplace,
+    bool use_int8_w8a8,
+    std::optional<at::Tensor>& w1_scale,
+    std::optional<at::Tensor>& w2_scale,
+    std::optional<at::Tensor>& a1_scale,
+    std::optional<at::Tensor>& a2_scale,
+    bool is_vnni);
+
+at::Tensor shared_expert_cpu(
+    at::Tensor& hidden_states,
+    at::Tensor& w1,
+    at::Tensor& w2,
+    at::Tensor& fused_experts_out,
+    double routed_scaling_factor,
+    bool inplace,
+    bool use_int8_w8a8,
+    std::optional<at::Tensor>& w1_scale,
+    std::optional<at::Tensor>& w2_scale,
+    std::optional<at::Tensor>& a1_scale,
+    std::optional<at::Tensor>& a2_scale,
+    bool is_vnni);
+
+// weight absorption
+std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
+    at::Tensor& hidden_states,
+    at::Tensor& q_a_proj_weight,
+    at::Tensor& q_b_proj_weight,
+    at::Tensor& kv_a_proj_weight,
+    at::Tensor& w_kc,
+    at::Tensor& q_a_layernorm_weight,
+    at::Tensor& kv_a_layernorm_weight,
+    at::Tensor& positions,
+    at::Tensor& cos_sin_cache,
+    double eps,
+    bool use_int8_w8a8,
+    std::optional<at::Tensor>& q_a_proj_scale,
+    std::optional<at::Tensor>& q_b_proj_scale,
+    std::optional<at::Tensor>& kv_a_proj_scale,
+    bool is_vnni);
+
+// shared memory init
+void initialize(int size, int rank);
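As a concrete example of the GEMM surface declared above, a hedged bf16 sketch (shapes invented; the result should match a plain matmul up to reduced-precision rounding):

at::Tensor x = at::randn({4, 7168}, at::kBFloat16);
at::Tensor w = at::randn({1536, 7168}, at::kBFloat16);  // [OC, IC]
at::Tensor wp = convert_weight_packed(w);               // VNNI prepack
std::optional<at::Tensor> no_bias;
at::Tensor y = weight_packed_linear(x, wp, no_bias, /*is_vnni=*/true);
// y: [4, 1536], numerically close to x.matmul(w.t())

+// shared memory all_reduce
+void shm_allreduce(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op);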
+
+// shared memory all_gather
+at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim);
+
+// rope
+std::tuple<at::Tensor, at::Tensor>
+rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  // activation
+  m.def("silu_and_mul_cpu", &silu_and_mul_cpu, "SiLU and mul for CPU");
+
+  // norm
+  m.def("rmsnorm_cpu", &rmsnorm_cpu, "Root mean square normalization for CPU");
+  m.def("fused_add_rmsnorm_cpu", &fused_add_rmsnorm_cpu, "Fused add root mean square normalization for CPU");
+
+  // topk
+  m.def("grouped_topk_cpu", &grouped_topk_cpu, "Grouped TopK for CPU");
+
+  // biased group topk
+  m.def("biased_grouped_topk_cpu", &biased_grouped_topk_cpu, "Biased Grouped TopK for CPU");
+
+  // decode
+  m.def("decode_attention_cpu", &decode_attention_cpu, "Attention decoding for CPU");
+
+  // extend
+  m.def("extend_attention_cpu", &extend_attention_cpu, "Attention extend for CPU");
+
+  // weight prepack
+  m.def("convert_weight_packed", &convert_weight_packed, "prepack weight to vnni format for intel AMX");
+
+  // quant
+  m.def("per_token_quant_int8_cpu", &per_token_quant_int8_cpu, "dynamic quantization for CPU");
+
+  // gemm
+  m.def("weight_packed_linear", &weight_packed_linear, "weight packed linear for intel AMX");
+
+  // igemm
+  m.def("int8_scaled_mm_cpu", &int8_scaled_mm_cpu, "int8 weight packed linear for intel AMX");
+
+  // quant + igemm
+  m.def(
+      "int8_scaled_mm_with_quant", &int8_scaled_mm_with_quant, "fused per row quant and int8 scaled mm for intel AMX");
+
+  // bmm
+  m.def("bmm_cpu", &bmm_cpu, "bmm kernel for intel AMX");
+
+  // moe
+  m.def("fused_experts_cpu", &fused_experts_cpu, "fused moe kernel for CPU");
+
+  // weight absorption
+  m.def("qkv_proj_with_rope", &qkv_proj_with_rope, "fused qkv projection kernel with weight absorption for intel AMX");
+
+  // shared expert
+  m.def("shared_expert_cpu", &shared_expert_cpu, "shared expert kernel for CPU");
+
+  // all reduce
+  m.def("initialize", &initialize, "shared memory initialization for CPU");
+  m.def("shm_allreduce", &shm_allreduce, "low latency all_reduce implementation for CPU");
+  m.def("shm_allgather", &shm_allgather, "low latency all_gather implementation for CPU");
+
+  // rope
+  m.def("rotary_position_embedding_cpu", &rotary_position_embedding_cpu, "rotary position embedding for CPU");
+}
diff --git a/sgl-kernel/csrc/cpu/vec.h b/sgl-kernel/csrc/cpu/vec.h
new file mode 100644
index 000000000..e058bd716
--- /dev/null
+++ b/sgl-kernel/csrc/cpu/vec.h
@@ -0,0 +1,115 @@
+#pragma once
+
+#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
+#define CPU_CAPABILITY_AVX512
+#endif
+
+#include
+#include
+
+namespace {
+
+using namespace at::vec;
+
+template <typename scalar_t, typename std::enable_if_t<c10::is_reduced_floating_point_v<scalar_t>, int> = 0>
+inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return at::vec::convert_from_float<scalar_t>(a, b);
+}
+
+#if defined(CPU_CAPABILITY_AVX512)
+
+// `at::vec::convert_from_float<>` from PyTorch doesn't use avx512-bf16 intrinsics;
+// use the native instruction for the float32->bfloat16 conversion
+template <>
+inline Vectorized<at::BFloat16>
+convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
+  return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a)));
+}
+
+#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
+
+#define CVT_FP16_TO_FP32(a) _mm512_cvtph_ps(a)
+
+#endif
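A scalar illustration of the rounding behavior the AVX512 path relies on (_mm512_cvtne2ps_pbh rounds to nearest-even, and c10::BFloat16 does the same; a standalone snippet, not part of vec.h):

#include <c10/util/BFloat16.h>
#include <cstdio>

int main() {
  // 1 + 2^-8 lies exactly between the bf16 neighbors 1.0 and 1 + 2^-7;
  // round-to-nearest-even picks 1.0 (the even mantissa)
  float x = 1.00390625f;
  c10::BFloat16 b(x);
  std::printf("%.8f -> %.8f\n", x, static_cast<float>(b));  // 1.00390625 -> 1.00000000
  return 0;
}

+// vector to scalar reduction
+#if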
defined(CPU_CAPABILITY_AVX512) && 0 +inline float vec_reduce_sum(const Vectorized& a) { + return _mm512_reduce_add_ps(__m512(a)); +} + +inline float vec_reduce_max(const Vectorized& a) { + return _mm512_reduce_max_ps(__m512(a)); +} +#else +inline float vec_reduce_sum(const Vectorized& a) { + return vec_reduce_all([](Vectorized& x, Vectorized& y) { return x + y; }, a); +} + +inline float vec_reduce_max(const Vectorized& a) { + return vec_reduce_all([](Vectorized& x, Vectorized& y) { return maximum(x, y); }, a); +} +#endif + +// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282 +template +inline void +quantize_row_int8(uint8_t* __restrict__ Aq, float& As, const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) { + float amax = 0.f; // absolute max + for (int64_t k = 0; k < K; ++k) { + const float val = static_cast(A[k]); + amax = std::max(amax, std::abs(val)); + } + + amax = std::max(amax, eps); + const float scale = amax / 127; + const float inv_scale = 127 / amax; + + for (int64_t k = 0; k < K; ++k) { + const float val = static_cast(A[k]) * inv_scale; + Aq[k] = (uint8_t)(std::round(val)) + 128; + } + As = scale; +} + +#if defined(CPU_CAPABILITY_AVX512) +template <> +inline void quantize_row_int8( + uint8_t* __restrict__ Aq, float& As, const at::BFloat16* __restrict__ A, int64_t K, float eps) { + const __m512 signBit = _mm512_set1_ps(-0.0f); + const __m512i off = _mm512_set1_epi32(128); + + // K is 32x, no remainder + float amax = 0.f; + __m512 vamax0 = _mm512_set1_ps(0.f); + __m512 vamax1 = _mm512_set1_ps(0.f); + for (int64_t k = 0; k < K; k += 32) { + __m512i va = _mm512_loadu_si512((void*)(A + k)); + __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); + __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); + vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0)); + vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1)); + } + amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1)); + amax = std::max(amax, eps); + const float scale = amax / 127; + const float inv_scale = 127 / amax; + const __m512 vd = _mm512_set1_ps(inv_scale); + + for (int64_t k = 0; k < K; k += 32) { + __m512i va = _mm512_loadu_si512((void*)(A + k)); + __m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0)); + __m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1)); + va0 = _mm512_mul_ps(va0, vd); + va1 = _mm512_mul_ps(va1, vd); + va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)); + __m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off)); + __m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off)); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0)); + } + As = scale; +} +#endif + +} // anonymous namespace diff --git a/sgl-kernel/setup_cpu.py b/sgl-kernel/setup_cpu.py new file mode 100644 index 000000000..04e06cb1a --- /dev/null +++ b/sgl-kernel/setup_cpu.py @@ -0,0 +1,95 @@ +# Copyright 2025 SGLang Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import os +import shutil +import sys +from pathlib import Path + +import torch +from setuptools import find_packages, setup +from setuptools.command.build_py import build_py +from torch.utils.cpp_extension import BuildExtension, CppExtension + +root = Path(__file__).parent.resolve() + +if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv: + sys.argv.extend(["--plat-name", "manylinux2014_x86_64"]) + + +def _get_version(): + with open(root / "pyproject.toml") as f: + for line in f: + if line.startswith("version"): + return line.split("=")[1].strip().strip('"') + + +operator_namespace = "sgl_kernel" +include_dirs = [] + +sources = [ + "csrc/cpu/activation.cpp", + "csrc/cpu/bmm.cpp", + "csrc/cpu/decode.cpp", + "csrc/cpu/extend.cpp", + "csrc/cpu/gemm.cpp", + "csrc/cpu/gemm_int8.cpp", + "csrc/cpu/moe.cpp", + "csrc/cpu/moe_int8.cpp", + "csrc/cpu/norm.cpp", + "csrc/cpu/qkv_proj.cpp", + "csrc/cpu/topk.cpp", + "csrc/cpu/interface.cpp", + "csrc/cpu/shm.cpp", + "csrc/cpu/torch_extension_cpu.cpp", +] + +extra_compile_args = { + "cxx": [ + "-O3", + "-Wno-unknown-pragmas", + "-march=native", + "-fopenmp", + ] +} +libraries = ["c10", "torch", "torch_python"] +cmdclass = { + "build_ext": BuildExtension.with_options(use_ninja=True), +} +Extension = CppExtension + +extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] + +ext_modules = [ + Extension( + name="sgl_kernel.common_ops", + sources=sources, + include_dirs=include_dirs, + extra_compile_args=extra_compile_args, + libraries=libraries, + extra_link_args=extra_link_args, + py_limited_api=True, + ), +] + +setup( + name="sgl-kernel", + version=_get_version(), + packages=find_packages(where="python"), + package_dir={"": "python"}, + ext_modules=ext_modules, + cmdclass=cmdclass, + options={"bdist_wheel": {"py_limited_api": "cp39"}}, +)
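Since the build relies on -march=native rather than explicit AVX512 flags, whether the CPU_CAPABILITY_AVX512 path in vec.h is compiled in depends on the build machine. A quick hypothetical probe, built with the same compiler flags as the extension:

// probe.cpp, e.g. g++ -O3 -march=native probe.cpp
#include <cstdio>

int main() {
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
  std::puts("CPU_CAPABILITY_AVX512 path enabled");
#else
  std::puts("generic at::vec fallback path");
#endif
  return 0;
}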