sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

View File

@@ -0,0 +1,83 @@
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(sgl_kernel)

# C++17 without compiler extensions (-std=c++17, not gnu++17).
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)

# Ask the active Python for torch's cmake prefix so find_package(Torch)
# resolves against the same torch that will load this extension at runtime.
execute_process(
  COMMAND ${Python_EXECUTABLE} -c "import torch; print(torch.utils.cmake_prefix_path)"
  OUTPUT_VARIABLE TORCH_PY_PREFIX
  OUTPUT_STRIP_TRAILING_WHITESPACE
  RESULT_VARIABLE TORCH_PY_RC
)
if(NOT TORCH_PY_RC EQUAL 0)
  message(FATAL_ERROR "Failed to import torch from ${Python_EXECUTABLE}; cannot locate Torch cmake files.")
endif()
message(STATUS "Torch cmake prefix: ${TORCH_PY_PREFIX}")
list(APPEND CMAKE_PREFIX_PATH "${TORCH_PY_PREFIX}/Torch")
find_package(Torch REQUIRED)

# Platform-specific multiarch library directory (Debian/Ubuntu layout).
if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|AMD64")
  set(PLAT_LIB_DIR "/usr/lib/x86_64-linux-gnu")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
  set(PLAT_LIB_DIR "/usr/lib/aarch64-linux-gnu")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "ppc64le|ppc64")
  set(PLAT_LIB_DIR "/usr/lib/powerpc64le-linux-gnu")
else()
  set(PLAT_LIB_DIR "/usr/lib/${CMAKE_SYSTEM_PROCESSOR}-linux-gnu")
endif()

# Conda library/include path support.
if(DEFINED ENV{CONDA_PREFIX})
  set(CONDA_LIB_DIR "$ENV{CONDA_PREFIX}/lib")
  set(CONDA_INCLUDE_DIR "$ENV{CONDA_PREFIX}/include")
  message(STATUS "Using Conda lib dir: ${CONDA_LIB_DIR}")
  # Look for libnuma in Conda's lib directory.
  find_library(NUMA_LIB numa HINTS "${CONDA_LIB_DIR}")
  if(NUMA_LIB)
    message(STATUS "Found libnuma: ${NUMA_LIB}")
  else()
    message(FATAL_ERROR "libnuma not found in Conda environment at ${CONDA_LIB_DIR}\n"
                        "Please install it using: conda install libnuma numactl\n")
  endif()
endif()

# Default SGLANG_CPU_FP8_CVT_FTZ to "1" when the user did not set it.
if(NOT DEFINED ENV{SGLANG_CPU_FP8_CVT_FTZ})
  set(ENV{SGLANG_CPU_FP8_CVT_FTZ} "1")
endif()

# CONFIGURE_DEPENDS: re-check the glob at build time so newly added .cpp files
# are picked up without a manual re-configure.
file(GLOB SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")

Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})

if("$ENV{SGLANG_CPU_FP8_CVT_FTZ}" STREQUAL "1")
  message(STATUS "Enabling macro: SGLANG_CPU_FP8_CVT_FTZ")
  target_compile_definitions(common_ops PRIVATE SGLANG_CPU_FP8_CVT_FTZ)
endif()

# Target-scoped flags/includes instead of directory-scoped commands.
# NOTE: -march=native makes the binary non-portable across CPUs by design —
# this CPU kernel targets the build machine's ISA (e.g. AVX512/AMX).
target_compile_options(common_ops PRIVATE
  -O3
  -Wno-unknown-pragmas
  -march=native
  -fopenmp
)
target_include_directories(common_ops PRIVATE
  ${TORCH_INCLUDE_DIRS}
  ${TORCH_INSTALL_PREFIX}/include
  ${Python_INCLUDE_DIRS}
  ${CMAKE_CURRENT_SOURCE_DIR}/../../csrc
  ${CMAKE_CURRENT_SOURCE_DIR}/../../include
)
if(DEFINED ENV{CONDA_PREFIX})
  target_include_directories(common_ops PRIVATE "${CONDA_INCLUDE_DIR}")
  target_link_directories(common_ops PRIVATE "${CONDA_LIB_DIR}")
endif()
target_link_directories(common_ops PRIVATE "${PLAT_LIB_DIR}")
target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} ${NUMA_LIB})

install(TARGETS common_ops
  LIBRARY DESTINATION sgl_kernel
)

View File

@@ -0,0 +1,135 @@
#include "common.h"
#include "vec.h"
namespace {
// Gated activation: output[i, :dim] = f(input[i, :dim]) * input[i, dim:2*dim].
// `f` is the scalar activation and `vf` its float-vector counterpart; they
// must compute the same function so the vectorized body and scalar tail agree.
template <typename scalar_t, typename func_t, typename vec_func_t>
void act_and_mul_kernel_impl(
    scalar_t* __restrict__ output,       // [num_tokens, dim]
    const scalar_t* __restrict__ input,  // [num_tokens, 2 * dim], activation half first
    int64_t num_tokens,
    int64_t dim,
    const func_t& f,        // scalar activation: float -> float
    const vec_func_t& vf) { // vector activation: fVec -> fVec
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int64_t kVecSize = bVec::size();
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs
      const scalar_t* __restrict__ input_ptr = input + i * 2 * dim;
      const scalar_t* __restrict__ input_other_ptr = input_ptr + dim;
      scalar_t* __restrict__ output_ptr = output + i * dim;
      int64_t d;
      // main loop: load a reduced-precision vector, widen to two float
      // vectors, apply activation, multiply by the gate, narrow and store
#pragma GCC unroll 4
      for (d = 0; d <= dim - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
        bVec y_bvec = bVec::loadu(input_other_ptr + d);
        fVec y_fvec0, y_fvec1;
        std::tie(y_fvec0, y_fvec1) = at::vec::convert_to_float(y_bvec);
        x_fvec0 = vf(x_fvec0);
        x_fvec1 = vf(x_fvec1);
        x_fvec0 = x_fvec0 * y_fvec0;
        x_fvec1 = x_fvec1 * y_fvec1;
        x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        x_bvec.store(output_ptr + d);
      }
      // scalar tail for the remaining dim % kVecSize elements
#pragma GCC unroll 4
      for (; d < dim; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        float y_val = static_cast<float>(input_other_ptr[d]);
        output_ptr[d] = f(x_val) * y_val;
      }
    }
  });
}
} // anonymous namespace
// SiLU-gate: out[..., :d] = silu(input[..., :d]) * input[..., d:2d], where d
// is half the last dimension. Leading dimensions are preserved.
// input : {num_tokens, 2 * d}
// output : {num_tokens, d}
at::Tensor silu_and_mul_cpu(at::Tensor& input) {
  RECORD_FUNCTION("sgl-kernel::silu_and_mul_cpu", std::vector<c10::IValue>({input}));
  const int64_t last_dim = input.ndimension() - 1;
  auto out_sizes = input.sizes().vec();
  const int64_t d = out_sizes[last_dim] / 2;
  out_sizes[last_dim] = d;
  const int64_t num_tokens = input.numel() / input.size(-1);
  at::Tensor out = at::empty(out_sizes, input.options());
  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] {
    using Vec = at::vec::Vectorized<float>;
    // silu(x) = x * sigmoid(x)
    auto silu_scalar = [](float x) { return x / (1.f + std::exp(-x)); };
    auto silu_vec = [](Vec x) { return x / (Vec(1.f) + x.neg().exp()); };
    act_and_mul_kernel_impl(
        out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), num_tokens, d, silu_scalar, silu_vec);
  });
  return out;
}
// GELU(tanh approximation)-gate: out[..., :d] = gelu_tanh(input[..., :d]) *
// input[..., d:2d], with d = half of the last dimension.
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input) {
  RECORD_FUNCTION("sgl-kernel::gelu_tanh_and_mul_cpu", std::vector<c10::IValue>({input}));
  const int64_t last_dim = input.ndimension() - 1;
  auto out_sizes = input.sizes().vec();
  const int64_t d = out_sizes[last_dim] / 2;
  out_sizes[last_dim] = d;
  const int64_t num_tokens = input.numel() / input.size(-1);
  at::Tensor out = at::empty(out_sizes, input.options());
  const float sqrt_2_div_pi = std::sqrt(2.f / M_PI);  // tanh-approximation constant
  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_tanh_and_mul", [&] {
    using Vec = at::vec::Vectorized<float>;
    // 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    auto act = [sqrt_2_div_pi](float x) {
      const float inner = sqrt_2_div_pi * (x + 0.044715f * x * x * x);
      return 0.5f * x * (1.f + std::tanh(inner));
    };
    auto act_vec = [sqrt_2_div_pi](Vec x) {
      Vec inner = Vec(sqrt_2_div_pi) * (x + Vec(0.044715f) * x * x * x);
      return Vec(0.5f) * x * (Vec(1.f) + inner.tanh());
    };
    act_and_mul_kernel_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), num_tokens, d, act, act_vec);
  });
  return out;
}
// Exact (erf-based) GELU-gate: out[..., :d] = gelu(input[..., :d]) *
// input[..., d:2d], with d = half of the last dimension.
at::Tensor gelu_and_mul_cpu(const at::Tensor& input) {
  RECORD_FUNCTION("sgl-kernel::gelu_and_mul_cpu", std::vector<c10::IValue>({input}));
  const int64_t last_dim = input.ndimension() - 1;
  auto out_sizes = input.sizes().vec();
  const int64_t d = out_sizes[last_dim] / 2;
  out_sizes[last_dim] = d;
  const int64_t num_tokens = input.numel() / input.size(-1);
  at::Tensor out = at::empty(out_sizes, input.options());
  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul", [&] {
    using Vec = at::vec::Vectorized<float>;
    const float inv_sqrt2 = 1.0f / std::sqrt(2.0f);
    // 0.5 * x * (1 + erf(x / sqrt(2)))
    auto act = [inv_sqrt2](float x) { return 0.5f * x * (1.f + std::erf(x * inv_sqrt2)); };
    auto act_vec = [inv_sqrt2](Vec x) { return Vec(0.5f) * x * (Vec(1.f) + (x * Vec(inv_sqrt2)).erf()); };
    act_and_mul_kernel_impl(out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), num_tokens, d, act, act_vec);
  });
  return out;
}

123
sgl-kernel/csrc/cpu/bmm.cpp Normal file
View File

@@ -0,0 +1,123 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
// Batched blocked GEMM: out[b] = mat1[b] (M x K) times mat2[b] (N x K,
// VNNI-packed by the caller). Work is parallelized over the collapsed
// [B, MB, NB] block-index space; each unit computes one BLOCK_M x BLOCK_N tile.
template <typename scalar_t>
void bmm_kernel_impl(
    scalar_t* __restrict__ out,         // [B, M, N], may be batch/row strided
    const scalar_t* __restrict__ mat1,  // [B, M, K], may be batch/row strided
    const scalar_t* __restrict__ mat2,  // [B, N, K], contiguous, packed
    int64_t B,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t mat1_strideB,
    int64_t mat1_strideM,
    int64_t out_strideB,
    int64_t out_strideM,
    float scale = 0.f) {
  // NOTE(review): `scale` is unused in this body — presumably reserved for a
  // quantized path; confirm before relying on it.
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);
  // mat2 contiguous in [B, N, K]
  int64_t mat2_strideB = N * K;
  int64_t mat2_strideN = K;
  const bool use_brgemm = can_use_brgemm<scalar_t>(M);
  // parallel on [B, MB, NB]
  at::parallel_for(0, B * MB * NB, 0, [&](int64_t begin, int64_t end) {
    int64_t bs{0}, mb{0}, nb{0};
    data_index_init(begin, bs, B, mb, MB, nb, NB);
    // for brgemm, use float32 for accumulate
    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
    for (int i = begin; i < end; ++i) {
      UNUSED(i);
      // NOTE(review): block offsets below are int while M/N are int64_t —
      // assumes offsets fit in int32; confirm for very large inputs.
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(M - mb_start, BLOCK_M);
      int nb_start = nb * BLOCK_N;
      int nb_size = std::min(N - nb_start, BLOCK_N);
      tinygemm_kernel<scalar_t>(
          /* A */ mat1 + bs * mat1_strideB + mb_start * mat1_strideM,
          /* B */ mat2 + bs * mat2_strideB + nb_start * mat2_strideN /* nb * BLOCK_N * K */,
          /* C */ out + bs * out_strideB + mb_start * out_strideM + nb_start,
          /* Ctmp*/ Ctmp,
          /* M */ mb_size,
          /* N */ nb_size,
          /* K */ K,
          /* lda */ mat1_strideM,
          /* ldb */ nb_size,
          /* ldc */ out_strideM,
          /* brg */ use_brgemm);
      // move to the next index
      data_index_step(bs, B, mb, MB, nb, NB);
    }
    if (use_brgemm) {
      // release AMX tile configuration once this thread finishes
      at::native::cpublas::brgemm_release();
    }
  });
}
} // anonymous namespace
// mat1 : [B, M, K]
// mat2 : [B, N, K] or [B, OC, IC]
// out : [B, M, N]
// scale: [] 0-dim tensor for per tensor quant (not supported yet)
//
void bmm_cpu(
    at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale) {
  RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
  // pack the weight into VNNI layout unless the caller already did
  auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
  // input and out could be non-contiguous (but contiguous in the last dim);
  // weight needs to be contiguous in [OC, IC] order
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(out);
  CHECK_INPUT(mat2);
  CHECK_DIM(3, out);
  CHECK_DIM(3, mat1);
  CHECK_DIM(3, mat2);
  int64_t B = mat1.size(0);
  int64_t M = mat1.size(1);
  int64_t N = mat2.size(1);
  int64_t K = mat1.size(2);
  TORCH_CHECK(!scale.has_value(), "bmm: do not support fp8 weight for now.");
  TORCH_CHECK(N % 32 == 0, "tinygemm requires N to be 32x.");
  int64_t mat1_strideB = mat1.stride(0);
  int64_t mat1_strideM = mat1.stride(1);
  int64_t out_strideB = out.stride(0);
  int64_t out_strideM = out.stride(1);
  // check shapes; also validate out's last dimension (was previously unchecked)
  TORCH_CHECK(mat2.size(0) == B && mat2.size(2) == K, "bmm: mat2 shape mismatch!");
  TORCH_CHECK(out.size(0) == B && out.size(1) == M && out.size(2) == N, "bmm: out shape mismatch!");
  AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "bmm_kernel_impl", [&] {
    bmm_kernel_impl<scalar_t>(
        out.data_ptr<scalar_t>(),
        mat1.data_ptr<scalar_t>(),
        packed_w.data_ptr<scalar_t>(),
        B,
        M,
        N,
        K,
        mat1_strideB,
        mat1_strideM,
        out_strideB,
        out_strideM);
  });
}

View File

@@ -0,0 +1,324 @@
#pragma once
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/record_function.h>
#if defined(_OPENMP)
#include <omp.h>
#endif
namespace {
// dispatch bool
// Binds the runtime value BOOL_V to a constexpr bool named BOOL_NAME inside
// an immediately-invoked lambda, so the payload can branch at compile time.
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
  [&] {                                          \
    if (BOOL_V) {                                \
      constexpr bool BOOL_NAME = true;           \
      return __VA_ARGS__();                      \
    } else {                                     \
      constexpr bool BOOL_NAME = false;          \
      return __VA_ARGS__();                      \
    }                                            \
  }()
// dispatch: bfloat16, float16, int8_t, fp8_e4m3
// Maps a runtime ScalarType onto a compile-time `packed_t` alias.
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...)                    \
  [&] {                                                         \
    switch (TYPE) {                                             \
      case at::ScalarType::BFloat16: {                          \
        using packed_t = at::BFloat16;                          \
        return __VA_ARGS__();                                   \
      }                                                         \
      case at::ScalarType::Half: {                              \
        using packed_t = at::Half;                              \
        return __VA_ARGS__();                                   \
      }                                                         \
      case at::ScalarType::Char: {                              \
        using packed_t = int8_t;                                \
        return __VA_ARGS__();                                   \
      }                                                         \
      case at::ScalarType::Float8_e4m3fn: {                     \
        using packed_t = at::Float8_e4m3fn;                     \
        return __VA_ARGS__();                                   \
      }                                                         \
      default:                                                  \
        TORCH_CHECK(false, "Unsupported floating data type.\n"); \
    }                                                           \
  }()
// dispatch with mixed dtypes (TYPE1, TYPE2):
// TYPE1: the primary dtype (input, output, weight);
// TYPE2: the secondary dtype (bias, etc.).
// Binds `scalar_t` for TYPE1 and `param_t` for TYPE2; when TYPE2 is not
// float, the two dtypes must match exactly.
#define CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(TYPE1, TYPE2, ...) \
  [&] {                                                            \
    if (TYPE2 == at::kFloat) {                                     \
      switch (TYPE1) {                                             \
        case at::ScalarType::BFloat16: {                           \
          using scalar_t = at::BFloat16;                           \
          using param_t = float;                                   \
          return __VA_ARGS__();                                    \
        }                                                          \
        case at::ScalarType::Half: {                               \
          using scalar_t = at::Half;                               \
          using param_t = float;                                   \
          return __VA_ARGS__();                                    \
        }                                                          \
        default:                                                   \
          TORCH_CHECK(false, "Unsupported floating data type.\n"); \
      }                                                            \
    } else {                                                       \
      TORCH_CHECK(TYPE1 == TYPE2);                                 \
      switch (TYPE1) {                                             \
        case at::ScalarType::BFloat16: {                           \
          using scalar_t = at::BFloat16;                           \
          using param_t = at::BFloat16;                            \
          return __VA_ARGS__();                                    \
        }                                                          \
        case at::ScalarType::Half: {                               \
          using scalar_t = at::Half;                               \
          using param_t = at::Half;                                \
          return __VA_ARGS__();                                    \
        }                                                          \
        default:                                                   \
          TORCH_CHECK(false, "Unsupported floating data type.\n"); \
      }                                                            \
    }                                                              \
  }()
// silence "unused variable" warnings for loop counters etc.
#define UNUSED(x) (void)(x)
// tensor argument validation helpers
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// only requires stride 1 on the last dimension (outer dims may be strided)
// NOTE(review): the error message is missing a space after the tensor name
// ("xmust be contiguous..."); left unchanged here since it is runtime text.
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
#define CHECK_INPUT(x) \
  CHECK_CPU(x);        \
  CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
  CHECK_CPU(x);                            \
  CHECK_LAST_DIM_CONTIGUOUS(x)
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
// [NB] Parallel Routines
//
// * at::parallel_for - applies for most of generic use cases, this will be compiled
// against openmp in default torch release.
//
// * parallel_for - same function as above, can choose payload partition scheme in
// balance211.
//
// * parallel_2d - parallel for 2 dimensions, used in GEMM, etc.
// this one will do payload balance across 2 dimensions.
//
// grain size for each thread
constexpr int GRAIN_SIZE = 1024;

// Integer ceiling division: div_up(a, b) == ceil(a / b) for positive b.
// SFINAE-restricted to integral types.
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T a, T b) {
  return (a + b - 1) / b;
}
// NOTE: at::get_thread_num() is only meaningful inside an at::parallel_for()
// region — its thread state is lazily initialized, so outside of one it
// always returns 0.
// Current worker id: the OpenMP thread number when built with OpenMP,
// otherwise always 0.
inline int get_thread_num() {
#ifdef _OPENMP
  return omp_get_thread_num();
#else
  return 0;
#endif
}
// balance payload across each thread
// Computes thread `ith`'s half-open slice [n_start, n_end) of a payload of
// size n split over nth threads, aten-style: every thread gets
// div_up(n, nth) items except the tail thread(s), which get the remainder.
// (A disabled onednn-style partition variant previously lived here behind
// `#if 0`; it was dead code and has been removed.)
template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
  // pytorch aten partition pattern
  T n_my = div_up(n, nth);
  n_start = ith * n_my;
  n_end = std::min(n_start + n_my, n);
}
// Split [0, n) across OpenMP threads (partitioned by balance211) and invoke
// f(begin, end) on each thread's slice; without OpenMP the calling thread
// handles the whole range via f(0, n).
template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
#if defined(_OPENMP)
#pragma omp parallel
  {
    int nth = omp_get_num_threads();
    int ith = omp_get_thread_num();
    int tbegin, tend;
    balance211(n, nth, ith, tbegin, tend);
    f(tbegin, tend);
  }
#else
  f(0, n);
#endif
}
// Pick the worker count:
// - for 1d parallel (m == 1): use every available thread;
// - for 2d parallel: round down to an even thread count, e.g. 43 -> 42,
//   never below 1.
inline int adjust_num_threads(int m) {
  const int actual_nth = at::get_num_threads();
  return (m == 1) ? actual_nth : std::max(1, (actual_nth >> 1) * 2);
}
// Parallelize a 2d [m, n] iteration space over an (nth_m x nth_n) thread
// grid; each thread receives its rectangle as f(begin_m, end_m, begin_n,
// end_n). The grid is chosen to give square-ish per-thread tiles while
// keeping all nth threads busy.
template <typename func_t>
inline void parallel_2d(int m, int n, const func_t& f) {
  // make sure we have even num_threads
  int nth = adjust_num_threads(m);
  // [NOTE] thread blocking:
  //
  // 1) prefer square block per thread
  // 2) use even number of CPU cores
  // 3) use all `num_threads` cores
  //
  // we have:
  // TM * TN = T
  // BM / TM = BN / TN
  // then:
  // TM = ((BM / BN) * T) ^ 0.5
  //
  float r = float(m) / n;
  // start from the square-ish estimate, then walk nth_m down until it
  // divides nth evenly (nth_m = 1 always terminates the search)
  int nth_m = std::ceil(std::sqrt(r * nth));
  int nth_n = 1;
  for (; nth_m > 0; --nth_m) {
    nth_n = nth / nth_m;
    if (nth_m * nth_n == nth) {
      break;
    }
  }
#if defined(_OPENMP)
#pragma omp parallel num_threads(nth)
  {
    // map the linear thread id onto the (ith_m, ith_n) grid position
    int ith = omp_get_thread_num();
    int ith_m = ith / nth_n;
    int ith_n = ith % nth_n;
    int thread_block_m = div_up(m, nth_m);
    int thread_block_n = div_up(n, nth_n);
    int begin_m = ith_m * thread_block_m;
    int end_m = std::min(m, begin_m + thread_block_m);
    int begin_n = ith_n * thread_block_n;
    int end_n = std::min(n, begin_n + thread_block_n);
    f(begin_m, end_m, begin_n, end_n);
  }
#else
  f(0, m, 0, n);
#endif
}
// limit max cache blocks
// when we need to do pre-unpack for weights, e.g. fp8
#define MAX_CACHE_BLOCK_SIZE 4

// Number of inner-loop blocks that fit into half of a 2MB L2 cache given the
// per-block footprint chunk_size * sizeof(T); never less than 1.
template <typename T>
inline int get_cache_blocks(int chunk_size) {
  // L2 2MB and ratio of 50%
  constexpr size_t kHalfL2Bytes = (2048 * 1024) >> 1;
  const size_t bytes_per_block = static_cast<size_t>(chunk_size) * sizeof(T);
  return std::max(1, static_cast<int>(kHalfL2Bytes / bytes_per_block));
}
// fp8 specialization: fp8 weights are pre-unpacked before use, so size the
// blocks by the bf16 footprint and additionally cap the count at
// MAX_CACHE_BLOCK_SIZE to bound the unpack scratch space.
template <>
inline int get_cache_blocks<at::Float8_e4m3fn>(int chunk_size) {
  // fp8 uses bf16 as accumulate type
  int cache_block_size = get_cache_blocks<at::BFloat16>(chunk_size);
  return std::min(MAX_CACHE_BLOCK_SIZE, cache_block_size);
}
// 2d sequential loop in range : [mb0, mb1), [nb0, nb1)
// Iterates in the L2-friendly order [NB / cache_blocks_nb, MB,
// cache_blocks_nb]; f receives (mb, nb, nb - nbb), where the third argument
// is nb's position inside the current cache group.
template <typename T, typename func_t>
inline void loop_2d(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1, int64_t chunk_size, const func_t& f) {
  // get number of blocks for L2 in most inner loop
  int64_t cache_blocks_nb = get_cache_blocks<T>(chunk_size);
  // loop order: [NB / cache_blocks_nb, MB, cache_blocks_nb]
  // TODO: implement reverse order of [MB / cache_blocks_mb, NB, cache_blocks_mb]
  for (int64_t nbb = nb0; nbb < nb1; nbb += cache_blocks_nb) {
    for (int64_t mb = mb0; mb < mb1; ++mb) {
      for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, nb1); ++nb) {
        f(mb, nb, nb - nbb);
      }
    }
  }
}
// data indexing for dimension collapse
// Decomposes a flattened offset into per-dimension counters given as
// (x, X) pairs listed outermost-first; returns the quotient left over after
// all listed dimensions are peeled off.
template <typename T>
inline T data_index_init(T offset) {
  return offset;
}
template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
  // peel the inner dimensions first, then split the remainder for this one
  const T rest = data_index_init(offset, std::forward<Args>(args)...);
  x = rest % X;
  return rest / X;
}
// Advance the counters produced by data_index_init by one, innermost
// dimension first (odometer style). Returns true when every listed counter
// wrapped back to zero, i.e. the carry left the outermost dimension.
inline bool data_index_step() {
  return true;
}
template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
  const bool carry = data_index_step(std::forward<Args>(args)...);
  if (!carry) {
    return false;
  }
  x = (x + 1 == X) ? T(0) : (x + 1);
  return x == T(0);
}
// forced unroll for perf critical path
#if __has_attribute(always_inline)
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
#define ALWAYS_INLINE inline
#endif
// Compile-time unrolled loop: Unroll<n>{}(f, args...) invokes
// f(std::integral_constant<int, i>{}, args...) for i = 0 .. n-1, in
// ascending order. The integral_constant lets `f` use the index in
// constexpr contexts. Requires n >= 1.
template <int n>
struct Unroll {
  template <typename Func, typename... Args>
  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
    Unroll<n - 1>{}(f, args...);
    f(std::integral_constant<int, n - 1>{}, args...);
  }
};
// base case: a single invocation with index 0
template <>
struct Unroll<1> {
  template <typename Func, typename... Args>
  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
    f(std::integral_constant<int, 0>{}, args...);
  }
};
} // anonymous namespace

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,723 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
// [NOTE]: extend attention for CPU
// 1. tune BLOCK_M and BLOCK_N
// 2. can handle non-contiguous k_extend and v_extend
// 3. computes attention for prefix and extend separately
// 4. TODO: vectorize `pack_vnni` and `pack_vnni2`
//
// Resolve a (possibly indirect) row index: returns ind[i] when a gather
// table is supplied, or the identity index i when `ind` is nullptr.
template <typename index_t>
inline index_t get_index(index_t* ind, int i) {
  if (ind == nullptr) {
    return static_cast<index_t>(i);
  }
  return ind[i];
}
#if defined(CPU_CAPABILITY_AVX512)
// key: from [N, 32] to [32/2, N, 2]
// Loads up to 16 rows of 32 16-bit values (one 512-bit vector each),
// optionally gathering rows through `ind`, zero-pads to a full 16x16 dword
// tile, transposes it, and stores only the N valid dwords of each output
// row under a write mask. Treating each dword as a pair of 16-bit values
// yields the VNNI [K/2, N, 2] layout.
template <typename scalar_t, typename index_t>
inline void pack_vnni_Nx32(
    scalar_t* __restrict__ dst,
    const scalar_t* __restrict__ src,
    const index_t* __restrict__ ind,  // optional row gather (nullptr = sequential)
    int N,                            // number of valid rows, 1..16
    int ld_src,
    int ld_dst) {
  __m512i vinputs[16];
  int n = 0;
  for (; n < N; ++n) {
    index_t index = get_index(ind, n);
    vinputs[n] = _mm512_loadu_si512(src + index * ld_src);
  }
  // padding with zero to avoid uninitialized vectors
  for (; n < 16; ++n) {
    vinputs[n] = _mm512_set1_epi32(0);
  }
  // pack key
  transpose_16x16_32bit(vinputs);
  // low N bits set -> store only the valid columns of each transposed row
  const __mmask16 vmask = (1 << N) - 1;
  for (int k = 0; k < 16; ++k) {
    _mm512_mask_storeu_epi32(dst + k * ld_dst * 2, vmask, vinputs[k]);
  }
}
// value: from [K, 32] to [K/2, 32, 2]
// Loads up to 2 rows of 32 16-bit values (zero-padding the second row when
// K == 1) and interleaves them pairwise so consecutive K elements become
// adjacent, producing one VNNI row pair per call.
template <typename scalar_t, typename index_t>
inline void pack_vnni_Kx32(
    scalar_t* __restrict__ dst,
    const scalar_t* __restrict__ src,
    const index_t* __restrict__ ind,  // optional row gather (nullptr = sequential)
    int K,                            // number of valid rows, 1..2
    int ld_src,
    int ld_dst) {
  __m512i vinputs[2];
  int k = 0;
  for (; k < K; ++k) {
    index_t index = get_index(ind, k);
    vinputs[k] = _mm512_loadu_si512(src + index * ld_src);
  }
  // padding with zero to avoid uninitialized vectors
  for (; k < 2; ++k) {
    vinputs[k] = _mm512_set1_epi32(0);
  }
  // pack value: interleave the two rows 16-bit-wise, then store both halves
  __m512i d0, d1;
  std::tie(d0, d1) = transpose_2x32_16bit(vinputs[0], vinputs[1]);
  _mm512_storeu_si512(dst + 0 * ld_dst * 2, d0);
  _mm512_storeu_si512(dst + 0 * ld_dst * 2 + 32, d1);
}
#endif
// convert to vnni format
// from [N, K/2, 2] to [K/2, N, 2] for bfloat16 and float16
// Packs a (possibly gathered) key block into the B-matrix layout expected by
// brgemm. `ind` (when non-null) selects which source rows to pack; the AVX512
// path requires K to be a multiple of 32.
template <typename scalar_t, typename index_t>
void pack_vnni(
    scalar_t* __restrict__ dst,
    const scalar_t* __restrict__ src,
    const index_t* __restrict__ ind,
    int N,
    int K,
    int ld_src,
    int ld_dst) {
#if defined(CPU_CAPABILITY_AVX512)
  const int NB = div_up(N, 16);
  const int KB = K / 32; // no remainder
  const bool is_indexed = ind != nullptr;
  for (int nb = 0; nb < NB; ++nb) {
    for (int kb = 0; kb < KB; ++kb) {
      // handle 16x512bits each block
      int nb_size = std::min(N - nb * 16, 16);
      // when gathering, the row offset goes through `ind`; otherwise it is
      // folded directly into the src pointer
      pack_vnni_Nx32<scalar_t, index_t>(
          /* dst */ dst + ((kb * 32) >> 1) * ld_dst * 2 + nb * 16 * 2,
          /* src */ src + kb * 32 + (is_indexed ? 0 : nb * 16 * ld_src),
          /* ind */ is_indexed ? ind + nb * 16 : nullptr,
          /* N */ nb_size,
          /* ld_src */ ld_src,
          /* ld_dst */ ld_dst);
    }
  }
#else
  // scalar fallback: dst[k][n][d] = src[row(n)][2k + d]
  for (int n = 0; n < N; ++n) {
    index_t index = get_index(ind, n);
    for (int k = 0; k < K / 2; ++k) {
      for (int d = 0; d < 2; ++d) {
        dst[k * ld_dst * 2 + n * 2 + d] = src[index * ld_src + k * 2 + d];
      }
    }
  }
#endif
}
// convert to vnni format
// from [K/2, 2, N] to [K/2, N, 2] for bfloat16 and float16
// Packs a (possibly gathered) value block into the brgemm B-matrix layout:
// pairs of consecutive K rows are interleaved element-wise. Odd K is
// zero-padded in the second slot of the final pair. The AVX512 path requires
// N to be a multiple of 32.
template <typename scalar_t, typename index_t>
void pack_vnni2(
    scalar_t* __restrict__ dst,
    const scalar_t* __restrict__ src,
    const index_t* __restrict__ ind,
    int K,
    int N,
    int ld_src,
    int ld_dst) {
#if defined(CPU_CAPABILITY_AVX512)
  const int KB = div_up(K, 2);
  const int NB = N / 32; // no remainder
  const bool is_indexed = ind != nullptr;
  for (int kb = 0; kb < KB; ++kb) {
    for (int nb = 0; nb < NB; ++nb) {
      // handle 2x512bits each block
      int kb_size = std::min(K - kb * 2, 2);
      // when gathering, the row offset goes through `ind`; otherwise it is
      // folded directly into the src pointer
      pack_vnni_Kx32<scalar_t, index_t>(
          /* dst */ dst + ((kb * 2) >> 1) * ld_dst * 2 + nb * 32 * 2,
          /* src */ src + (is_indexed ? 0 : kb * 2 * ld_src) + nb * 32,
          /* ind */ is_indexed ? ind + kb * 2 : nullptr,
          /* K */ kb_size,
          /* ld_src */ ld_src,
          /* ld_dst */ ld_dst);
    }
  }
#else
  // scalar fallback: interleave row pairs (2k, 2k+1) element by element
  int k = 0;
  for (; k < (K >> 1) * 2; k += 2) {
    index_t index0 = get_index(ind, k + 0);
    index_t index1 = get_index(ind, k + 1);
    for (int n = 0; n < N; ++n) {
      dst[(k >> 1) * ld_dst * 2 + n * 2 + 0] = src[index0 * ld_src + n];
      dst[(k >> 1) * ld_dst * 2 + n * 2 + 1] = src[index1 * ld_src + n];
    }
  }
  // odd K: last row is paired with zeros
  if (K % 2 != 0) {
    index_t index = get_index(ind, K - 1);
    for (int n = 0; n < N; ++n) {
      dst[(K >> 1) * ld_dst * 2 + n * 2 + 0] = src[index * ld_src + n];
      dst[(K >> 1) * ld_dst * 2 + n * 2 + 1] = 0;
    }
    k += 2;
  }
#endif
}
template <typename scalar_t>
inline void fill_stub(scalar_t* __restrict__ out, float val, int size) {
using Vec = at::vec::Vectorized<scalar_t>;
constexpr int kVecSize = Vec::size();
const Vec data_vec = Vec(static_cast<scalar_t>(val));
int d = 0;
#pragma GCC unroll 4
for (; d <= size - kVecSize; d += kVecSize) {
data_vec.store(out + d);
}
if (size - d > 0) {
data_vec.store(out + d, size - d);
}
}
// Convert a fixed-size row of BLOCK_N floats to scalar_t and store it,
// using a compile-time unrolled loop. Two consecutive 16-float vectors are
// fused into one reduced-precision store, so only even columns do work.
template <typename scalar_t, int BLOCK_N>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input) {
  static_assert(BLOCK_N % 32 == 0);
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int COLS = BLOCK_N / 16;
  auto store = [&](auto i) {
    constexpr int col = i % COLS;
    // for COLS = 2, 4 use 512bit store
    if constexpr (col % 2 == 0) {
      fVec a_fvec0 = fVec::loadu(input + col * 16);
      fVec a_fvec1 = fVec::loadu(input + col * 16 + 16);
      bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
      out_bvec.store(out + col * 16);
    }
  };
  Unroll<COLS>{}(store);
}
// Scale `size` accumulated floats by `s` and store as scalar_t:
// out[d] = (scalar_t)(acc[d] * s). Vectorized main loop plus scalar tail.
// Used to apply the final 1/sum normalization of the softmax.
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  const fVec s_fvec = fVec(s);
  int d = 0;
#pragma GCC unroll 4
  for (; d <= size - kVecSize; d += kVecSize) {
    fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
    fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
    bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
    out_bvec.store(out + d);
  }
  for (; d < size; ++d) {
    out[d] = static_cast<scalar_t>(acc[d] * s);
  }
}
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
void extend_attention_kernel_impl(
scalar_t* __restrict__ o_extend,
const scalar_t* __restrict__ q_extend,
const scalar_t* __restrict__ k_extend,
const scalar_t* __restrict__ v_extend,
const scalar_t* __restrict__ k_buffer,
const scalar_t* __restrict__ v_buffer,
const index_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices,
const int64_t* __restrict__ seq_lens,
const index_t* __restrict__ extend_seq_lens,
const index_t* __restrict__ extend_start_loc,
const void* __restrict__ buffer,
int batches,
int num_heads,
int num_heads_kv,
int head_size,
int head_size_v,
int q_strideM,
int q_strideH,
int ke_strideN,
int ke_strideH,
int ve_strideN,
int ve_strideH,
int k_strideN,
int k_strideH,
int v_strideN,
int v_strideH,
float scaling,
float logit_cap,
int max_num_reqs,
int max_context_len,
int max_total_num_tokens,
int max_len_extend,
int buffer_size_per_thread,
bool is_prefix_skipped) {
using Vec = at::vec::Vectorized<float>;
// strides
const int o_strideM = num_heads * head_size_v;
const int o_strideH = head_size_v;
// we use same buffer for packed key and value
const int ldb_tmp = std::max(head_size, head_size_v);
const bool has_logit_cap = logit_cap > 0;
float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f;
const int num_groups = num_heads / num_heads_kv;
TORCH_CHECK(num_groups * num_heads_kv == num_heads);
// number of blocks along M
int MB = div_up(max_len_extend, BLOCK_M);
// parallel on [batches, num_heads, BM]
at::parallel_for(0, batches * num_heads * MB, 0, [&](int begin, int end) {
int bs{0}, head_id{0}, mb{0};
data_index_init(begin, bs, batches, head_id, num_heads, mb, MB);
int tid = at::get_thread_num();
// s_i and s_delta: [BLOCK_M, BLOCK_N]
float* __restrict__ s_i = reinterpret_cast<float*>((char*)(buffer) + tid * buffer_size_per_thread);
float* __restrict__ s_delta = s_i;
// v_prime: [BLOCK_M, head_size_v]
float* __restrict__ v_prime = s_i + BLOCK_M * BLOCK_N;
// s_delta2: [BLOCK_M, BLOCK_N]; copy of s_delta in scalar_t
scalar_t* __restrict__ s_delta2 = reinterpret_cast<scalar_t*>(v_prime + BLOCK_N * head_size_v);
// Btmp: [BLOCK_N, max(head_size, head_size_v)]
scalar_t* __restrict__ Btmp = s_delta2 + BLOCK_M * BLOCK_N;
// init Btmp just once for each thread to prevent NaN
fill_stub(Btmp, 0.f, BLOCK_N * ldb_tmp);
alignas(64) float s_prime[BLOCK_M];
alignas(64) float m_prime[BLOCK_M];
for (int i = begin; i < end; ++i) {
// seq_len = prefix + extend
int head_kv_id = head_id / num_groups;
int seq_len = seq_lens[bs];
int seq_len_extend = extend_seq_lens[bs];
int seq_len_prefix = seq_len - seq_len_extend;
int seq_extend_start_loc = extend_start_loc[bs];
int req_pool_id = req_pool_indices[bs];
TORCH_CHECK(seq_len_prefix >= 0, "prefix len < 0!");
TORCH_CHECK(seq_len <= max_context_len, "seq_len out of scope!");
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!");
if (is_prefix_skipped) {
TORCH_CHECK(seq_len_prefix == 0, "extend attention: expect seq_len_prefix to be 0, got ", seq_len_prefix);
}
// offset and size in MB
int m = mb * BLOCK_N;
int m_size = std::min(BLOCK_M, seq_len_extend - m);
if (m_size <= 0) {
data_index_step(bs, batches, head_id, num_heads, mb, MB);
continue;
}
// get query
const scalar_t* __restrict__ q_ptr = q_extend + (seq_extend_start_loc + m) * q_strideM + head_id * q_strideH;
// init v', s' and m'
fill_stub(v_prime, 0.f, m_size * head_size_v);
fill_stub(s_prime, 0.f, m_size);
fill_stub(m_prime, -std::numeric_limits<scalar_t>::infinity(), m_size);
// stage 1: compute scores with prefix
for (int n = 0; n < seq_len_prefix; n += BLOCK_N) {
int n_size = std::min(BLOCK_N, seq_len_prefix - n);
// `n_size` is K in 2nd gemm, pad to TILE_K;
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
// get key and pack
pack_vnni<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ k_buffer + head_kv_id * k_strideH,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* N */ n_size,
/* K */ head_size,
/* ld_src */ k_strideN,
/* ld_dst */ BLOCK_N);
// calculate s_i <- Q @ K
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ head_size,
/* lda */ q_strideM,
/* ldb */ BLOCK_N,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ q_ptr,
/* B */ Btmp,
/* C */ s_i);
const Vec scale_vec = Vec(scaling);
for (int row = 0; row < m_size; ++row) {
// s_i <- s_i * scale
at::vec::map<float>(
[scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
// TODO: `tanh` from torch uses sleef u10, going to be slow
if (has_logit_cap) {
at::vec::map<float>(
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
s_i + row * BLOCK_N,
s_i + row * BLOCK_N,
n_size);
}
// m_i: max value per row
float m_i = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size);
m_i = std::max(m_i, m_prime[row]);
// m_delta <- exp(m' - m_i)
float m_delta = std::exp(m_prime[row] - m_i);
// s_delta <- exp(s_i - m_i)
at::vec::map<float>(
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
// s' <- s' * m_delta + sum(s_delta)
s_prime[row] *= m_delta;
s_prime[row] +=
at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size);
m_prime[row] = m_i;
// v' <- v' * m_delta
at::vec::map<float>(
[m_delta](Vec x) { return x * Vec(m_delta); },
v_prime + row * head_size_v,
v_prime + row * head_size_v,
head_size_v);
// pad s_delta with 0 first and then convert to scalar_t
fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
copy_stub<scalar_t, BLOCK_N>(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N);
}
// get value and pack
pack_vnni2<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ v_buffer + head_kv_id * v_strideH,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* K */ n_size,
/* N */ head_size_v,
/* ld_src */ v_strideN,
/* ld_dst */ head_size_v);
// calculate V' <- s_delta @ V + V'
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ head_size_v,
/* K */ padded_n_size, // n_size
/* lda */ BLOCK_N,
/* ldb */ head_size_v,
/* ldc */ head_size_v,
/* add_C */ true,
/* A */ s_delta2,
/* B */ Btmp,
/* C */ v_prime);
} // loop with seq_len_prefix
// stage 2: compute the triangle part
int num_keys = std::min(seq_len_extend, m + BLOCK_M);
for (int n = 0; n < num_keys; n += BLOCK_N) {
int n_size = std::min(BLOCK_N, num_keys - n);
// `n_size` is K in 2nd gemm, pad to TILE_K;
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
// get key and pack
pack_vnni<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ k_extend + (seq_extend_start_loc + n) * ke_strideN + head_kv_id * ke_strideH,
/* ind */ nullptr,
/* N */ n_size,
/* K */ head_size,
/* ld_src */ ke_strideN,
/* ld_dst */ BLOCK_N);
// calculate s_i <- Q @ K
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ head_size,
/* lda */ q_strideM,
/* ldb */ BLOCK_N,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ q_ptr,
/* B */ Btmp,
/* C */ s_i);
// apply causal mask
if (num_keys - n <= BLOCK_N) {
for (int row = 0; row < m_size; ++row) {
int last_col = m + row - n;
// fill [last_col + 1, n_size) to -inf
float* row_ptr = s_i + row * BLOCK_N;
fill_stub(row_ptr + last_col + 1, -std::numeric_limits<float>::infinity(), n_size - last_col - 1);
}
}
const Vec scale_vec = Vec(scaling);
for (int row = 0; row < m_size; ++row) {
// s_i <- s_i * scale
at::vec::map<float>(
[scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
// TODO: `tanh` from torch uses sleef u10, going to be slow
if (has_logit_cap) {
at::vec::map<float>(
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
s_i + row * BLOCK_N,
s_i + row * BLOCK_N,
n_size);
}
// m_i: max value per row
float m_i = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size);
m_i = std::max(m_i, m_prime[row]);
// m_delta <- exp(m' - m_i)
float m_delta = std::exp(m_prime[row] - m_i);
// s_delta <- exp(s_i - m_i)
at::vec::map<float>(
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
// s' <- s' * m_delta + sum(s_delta)
s_prime[row] *= m_delta;
s_prime[row] +=
at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size);
m_prime[row] = m_i;
// v' <- v' * m_delta
at::vec::map<float>(
[m_delta](Vec x) { return x * Vec(m_delta); },
v_prime + row * head_size_v,
v_prime + row * head_size_v,
head_size_v);
// pad s_delta with 0 first and then convert to scalar_t
fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
copy_stub<scalar_t, BLOCK_N>(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N);
}
// get value and pack
pack_vnni2<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ v_extend + (seq_extend_start_loc + n) * ve_strideN + head_kv_id * ve_strideH,
/* ind */ nullptr,
/* K */ n_size,
/* N */ head_size_v,
/* ld_src */ ve_strideN,
/* ld_dst */ head_size_v);
// calculate V' <- s_delta @ V + V'
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ head_size_v,
/* K */ padded_n_size, // n_size
/* lda */ BLOCK_N,
/* ldb */ head_size_v,
/* ldc */ head_size_v,
/* add_C */ true,
/* A */ s_delta2,
/* B */ Btmp,
/* C */ v_prime);
} // loop with seq_len_extend
scalar_t* __restrict__ out_ptr = o_extend + (seq_extend_start_loc + m) * o_strideM + head_id * o_strideH;
for (int row = 0; row < m_size; ++row) {
float s = 1 / s_prime[row];
copy_stub<scalar_t>(out_ptr + row * o_strideM, v_prime + row * head_size_v, s, head_size_v);
}
// move to the next index
data_index_step(bs, batches, head_id, num_heads, mb, MB);
}
at::native::cpublas::brgemm_release();
});
}
} // anonymous namespace
// q_extend, k_extend, v_extend, o_extend: contiguous tensors
// k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
//
// q_extend: [num_tokens, num_heads, head_size]
// k_extend: [num_extend_tokens, num_heads, head_size]
// v_extend: [num_extend_tokens, num_heads, head_size]
// o_extend: [num_tokens, num_heads, head_size]
// k_buffer: [max_total_num_tokens, num_heads, head_size]
// v_buffer: [max_total_num_tokens, num_heads, head_size]
// req_to_token: [max_num_reqs, max_context_len] int32 or int64
// req_pool_indices: [num_seqs] int64
// seq_lens: [num_seqs] int64
// extend_seq_lens: [num_seqs]
// extend_start_loc: [num_seqs]
//
void extend_attention_cpu(
    at::Tensor& q_extend,
    at::Tensor& k_extend,
    at::Tensor& v_extend,
    at::Tensor& o_extend,
    at::Tensor& k_buffer,
    at::Tensor& v_buffer,
    at::Tensor& req_to_token,
    at::Tensor& req_pool_indices,
    at::Tensor& seq_lens,
    at::Tensor& extend_seq_lens,
    at::Tensor& extend_start_loc,
    int64_t max_len_extend,
    double sm_scale,
    double logit_cap) {
  RECORD_FUNCTION(
      "sgl-kernel::extend_attention_cpu",
      std::vector<c10::IValue>(
          {q_extend,
           k_extend,
           v_extend,
           o_extend,
           k_buffer,
           v_buffer,
           req_to_token,
           req_pool_indices,
           seq_lens,
           extend_seq_lens,
           extend_start_loc}));

  // Inputs only need the last dim contiguous (rows are addressed via strides);
  // the output must be fully contiguous.
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_extend);
  CHECK_INPUT(o_extend);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);

  // problem sizes
  int num_seqs = seq_lens.size(0);
  int max_num_reqs = req_to_token.size(0);
  int max_context_len = req_to_token.size(1);
  int max_total_num_tokens = k_buffer.size(0);
  int num_heads = q_extend.size(1);
  int num_heads_kv = k_extend.size(1);
  int head_size = q_extend.size(2);
  int head_size_v = v_extend.size(2);

  // strides for q_extend, k_extend and v_extend
  int q_strideM = q_extend.stride(0);
  int q_strideH = q_extend.stride(1);
  int ke_strideN = k_extend.stride(0);
  int ke_strideH = k_extend.stride(1);
  int ve_strideN = v_extend.stride(0);
  int ve_strideH = v_extend.stride(1);
  // strides for k_buffer and v_buffer
  int k_strideN = k_buffer.stride(0);
  int k_strideH = k_buffer.stride(1);
  int v_strideN = v_buffer.stride(0);
  int v_strideH = v_buffer.stride(1);

  // check sizes
  CHECK_EQ(req_pool_indices.size(0), num_seqs);
  CHECK_EQ(extend_seq_lens.size(0), num_seqs);
  CHECK_EQ(extend_start_loc.size(0), num_seqs);
  CHECK_EQ(v_extend.size(1), num_heads_kv);
  CHECK_EQ(k_buffer.size(1), v_buffer.size(1));

  // MLA will skip prefix part: its kv buffer head count differs from k_extend's.
  const bool is_prefix_skipped = k_buffer.size(1) != num_heads_kv;

  // check index data types
  const auto index_dtype = req_to_token.scalar_type();
  TORCH_CHECK(
      index_dtype == at::kInt || index_dtype == at::kLong,
      "extend: expect req_to_token to be int32 or int64, got ",
      index_dtype);
  // NOTE: message previously said "req_lens"; the argument is seq_lens.
  TORCH_CHECK(seq_lens.scalar_type() == at::kLong, "extend: expect seq_lens to be int64, got ", seq_lens.scalar_type());
  TORCH_CHECK(
      req_pool_indices.scalar_type() == at::kLong,
      "extend: expect req_pool_indices to be int64, got ",
      req_pool_indices.scalar_type());
  TORCH_CHECK(
      extend_seq_lens.scalar_type() == index_dtype && extend_start_loc.scalar_type() == index_dtype,
      "extend: expect extend_seq_lens and extend_start_loc to have same dtype as req_to_token.");

  // D and DV need to be 32x as we transpose by 512-bit
  TORCH_CHECK(head_size % 32 == 0, "invalid head_size ", head_size);
  TORCH_CHECK(head_size_v % 32 == 0, "invalid head_size_v ", head_size_v);

  // block size for query seq length
  constexpr int BLOCK_M = 32;
  // block size for key/value seq length
  constexpr int BLOCK_N = 32;

  // Per-thread scratch: fp32 logits, fp32 output accumulator, reduced-precision
  // softmax weights, and a packed K/V tile.
  // NOTE(review): cross-check this accounting against the kernel's per-thread
  // buffer carving (s_i / v_prime / s_delta / s_delta2 / Btmp and the per-row
  // m_prime / s_prime arrays), which is not visible in this translation unit view.
  const int size_per_thread =
      /* s_i     */ BLOCK_M * BLOCK_N * sizeof(float) +
      /* v_prime */ BLOCK_M * head_size_v * sizeof(float) +
      /* s_delta */ BLOCK_M * BLOCK_N * sizeof(uint16_t) +
      /* Btmp    */ BLOCK_N * std::max(head_size, head_size_v) * sizeof(uint16_t);

  int num_threads = at::get_num_threads();
  auto buffer = at::empty({num_threads, size_per_thread}, q_extend.options().dtype(at::kChar));

  // dispatch on reduced float dtype (bf16/fp16) and index dtype (int32/int64)
  AT_DISPATCH_REDUCED_FLOATING_TYPES(q_extend.scalar_type(), "extend_attention_kernel", [&] {
    AT_DISPATCH_INDEX_TYPES(index_dtype, "extend_attention_indices", [&] {
      extend_attention_kernel_impl<scalar_t, index_t, BLOCK_M, BLOCK_N>(
          o_extend.data_ptr<scalar_t>(),
          q_extend.data_ptr<scalar_t>(),
          k_extend.data_ptr<scalar_t>(),
          v_extend.data_ptr<scalar_t>(),
          k_buffer.data_ptr<scalar_t>(),
          v_buffer.data_ptr<scalar_t>(),
          req_to_token.data_ptr<index_t>(),
          req_pool_indices.data_ptr<int64_t>(),
          seq_lens.data_ptr<int64_t>(),
          extend_seq_lens.data_ptr<index_t>(),
          extend_start_loc.data_ptr<index_t>(),
          buffer.data_ptr(),
          num_seqs,
          num_heads,
          num_heads_kv,
          head_size,
          head_size_v,
          q_strideM,
          q_strideH,
          ke_strideN,
          ke_strideH,
          ve_strideN,
          ve_strideH,
          k_strideN,
          k_strideH,
          v_strideN,
          v_strideH,
          sm_scale,
          logit_cap,
          max_num_reqs,
          max_context_len,
          max_total_num_tokens,
          max_len_extend,
          size_per_thread,
          is_prefix_skipped);
    });
  });
}

View File

@@ -0,0 +1,525 @@
#include "gemm.h"
#include "common.h"
#include "vec.h"
namespace {
// packed layout:
// quants {N, K} int8_t
// comp {N} int32_t
// Append per-column s8s8 compensation terms after the quantized block.
// For each of the BLOCK_N output columns, accumulates dpbusd(0x80, b) over all
// K bytes of that column (the packed VNNI tile is walked in groups of 4 bytes)
// and stores the resulting int32 sums at `packed + BLOCK_N * K`.
template <int BLOCK_N>
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
#if defined(CPU_CAPABILITY_AVX512)
  // one 512-bit accumulator covers 16 columns
  constexpr int COLS = BLOCK_N / 16;
  __m512i vcomp[COLS];
  for (int col = 0; col < COLS; ++col) {
    vcomp[col] = _mm512_setzero_si512();
  }
  // compensation values live right after the N*K quantized bytes
  const int64_t offset = BLOCK_N * K;
  const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
  for (int k = 0; k < K / 4; ++k) {
    for (int col = 0; col < COLS; ++col) {
      __m512i vb = _mm512_loadu_si512((const __m512i*)(packed + k * BLOCK_N * 4 + col * 64));
      // vcomp += dpbusd(0x80, vb): sums 4 consecutive bytes per int32 lane
      vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
    }
  }
  for (int col = 0; col < COLS; ++col) {
    _mm512_storeu_si512((__m512i*)(packed + offset + col * 64), vcomp[col]);
  }
#else
  TORCH_CHECK(false, "s8s8_compensation not implemented!");
#endif
}
// convert to vnni format
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
// Repack a row-major [N, K] weight tile into VNNI layout [K/2, N, 2],
// the format consumed by the bfloat16/float16 dot-product kernels.
template <typename packed_t>
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
  constexpr int VNNI_BLK = 2;
  const int K2 = K / VNNI_BLK;
  for (int k = 0; k < K2; ++k) {
    for (int n = 0; n < N; ++n) {
      packed_t* dst = packed + (k * N + n) * VNNI_BLK;
      const packed_t* src = weight + n * K + k * VNNI_BLK;
      dst[0] = src[0];
      dst[1] = src[1];
    }
  }
}
// int8 specialization: VNNI blocks of 4 bytes (dpbusd path), layout
// [K/4, N, 4]. After repacking, per-column compensation sums (dpbusd of
// 0x80 against the packed bytes, see s8s8_compensation) are appended.
template <>
inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) {
  constexpr int BLOCK_N = block_size_n();
  // the compensation kernel assumes a full BLOCK_N-wide tile
  TORCH_CHECK(N == BLOCK_N);
  constexpr int VNNI_BLK = 4;
  const int K4 = K / VNNI_BLK;
  for (int k = 0; k < K4; ++k) {
    for (int n = 0; n < N; ++n) {
      for (int d = 0; d < VNNI_BLK; ++d) {
        packed[(k * N + n) * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
      }
    }
  }
  s8s8_compensation<BLOCK_N>(packed, K);
}
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d);
fVec data1 = fVec::loadu(input + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d]);
}
}
template <typename scalar_t>
inline void copy_add_stub(
scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
}
}
// Primary template for the fully-unrolled (BLOCK_M x BLOCK_N) GEMM micro-kernel.
// Only specializations (e.g. the AVX512 at::BFloat16 one below) are usable;
// instantiating this fallback fails at runtime with TORCH_CHECK.
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
  static inline void apply(
      const scalar_t* __restrict__ A,
      const scalar_t* __restrict__ B,
      scalar_t* __restrict__ C,
      const float* __restrict__ bias,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
  }
};
#if defined(CPU_CAPABILITY_AVX512)
// AVX512-BF16 micro-kernel: C[BLOCK_M, BLOCK_N] = A[BLOCK_M, K] @ B (+ bias).
// B must be in VNNI layout [K/2, N, 2] so a 512-bit load yields bf16 pairs for
// 16 columns; each fp32 load from A broadcasts one bf16 pair of a row.
// Accumulation uses _mm512_dpbf16_ps into fp32 registers.
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const at::BFloat16* __restrict__ A,
      const at::BFloat16* __restrict__ B,
      at::BFloat16* __restrict__ C,
      const float* __restrict__ bias,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N / 16;  // 16 fp32 lanes per zmm

    // prefetch distance (0 disables prefetching in this variant)
    constexpr int PREFETCH_SIZE_K = 0;

    __m512bh va;
    __m512bh vb[COLS];
    __m512 vc[ROWS * COLS];

    // init accumulators: bias if present, else zero
    auto loadc = [&](auto i) {
      constexpr int col = i % COLS;
      if constexpr (has_bias) {
        vc[i] = _mm512_loadu_ps(bias + col * 16);
      } else {
        vc[i] = _mm512_set1_ps(0.f);
      }
    };
    Unroll<ROWS * COLS>{}(loadc);

    // reinterpret bf16 pairs as fp32 units: one k step consumes 2 bf16 values
    const int64_t K2 = K >> 1;
    const int64_t lda2 = lda >> 1;
    const int64_t ldb2 = ldb;  // ldb * 2 >> 1;
    const float* a_ptr = reinterpret_cast<const float*>(A);
    const float* b_ptr = reinterpret_cast<const float*>(B);

    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      if constexpr (col == 0) {
        // broadcast one bf16 pair of A's row across all lanes
        va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
      }
      if constexpr (row == 0) {
        // load 16 columns' bf16 pairs from the VNNI-packed B
        vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
        if constexpr (PREFETCH_SIZE_K > 0) {
          _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
        }
      }
      vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
    };
    for (int64_t k = 0; k < K2; ++k) {
      Unroll<ROWS * COLS>{}(compute, k);
    }

    // convert fp32 accumulators back to bf16 and store
    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      // for COLS = 2, 4 use 512bit store
      // for COLS = 1, 3 use 256bit store
      if constexpr (COLS % 2 == 0) {
        if constexpr (col % 2 == 0) {
          _mm512_storeu_si512(
              reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
              (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
        }
      } else {
        _mm256_storeu_si256(reinterpret_cast<__m256i*>(C + row * ldc + col * 16), (__m256i)(_mm512_cvtneps_pbh(vc[i])));
      }
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
#endif
// Invoke the (MB_SIZE x NB_SIZE) tinygemm_kernel_nn specialization on the
// sub-tile starting at (mb_start, nb_start). `B + nb_start * 2` reflects the
// VNNI layout [K/2, N, 2]: advancing one output column moves 2 elements.
// Relies on scalar_t/has_bias/A/B/C/bias/K/lda/ldb/ldc being in scope.
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                \
  tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
      A + mb_start * lda,                                          \
      B + nb_start * 2,                                            \
      C + mb_start * ldc + nb_start,                               \
      has_bias ? bias + nb_start : nullptr,                        \
      K,                                                           \
      lda,                                                         \
      ldb,                                                         \
      ldc);
// brgemm (AMX) path: accumulate the whole tile in fp32 scratch via
// cpublas::brgemm, then down-convert row by row into C, adding bias if present.
template <typename scalar_t, bool has_bias>
struct brgemm {
  static inline void apply(
      const scalar_t* __restrict__ A,
      const scalar_t* __restrict__ B,
      scalar_t* __restrict__ C,
      float* __restrict__ Ctmp,
      const float* __restrict__ bias,
      int64_t M,
      int64_t N,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    constexpr int BLOCK_N = block_size_n();
    // fp32 accumulation into Ctmp (leading dimension BLOCK_N)
    at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
    // down-convert each row from Ctmp into C
    for (int64_t m = 0; m < M; ++m) {
      scalar_t* dst = C + m * ldc;
      const float* acc = Ctmp + m * BLOCK_N;
      if constexpr (has_bias) {
        copy_add_stub(dst, acc, bias, N);
      } else {
        copy_stub(dst, acc, N);
      }
    }
  }
};
// GEMM dispatcher: either route to the brgemm (AMX) path, or tile the problem
// into at most 4x64 sub-tiles and jump to a fully-unrolled tinygemm_kernel_nn
// specialization selected by (mb_size, nb_size).
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
    const scalar_t* __restrict__ A,
    const scalar_t* __restrict__ B,
    scalar_t* __restrict__ C,
    float* __restrict__ Ctmp,
    const float* __restrict__ bias,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    bool brg) {
  if (brg) {
    brgemm<scalar_t, has_bias>::apply(A, B, C, Ctmp, bias, M, N, K, lda, ldb, ldc);
    return;
  }

  // pattern: 1-4-16, N = 16, 32, 48, 64
  constexpr int64_t BLOCK_M = 4;
  constexpr int64_t BLOCK_N = 64;
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);
  for (int mb = 0; mb < MB; ++mb) {
    int64_t mb_start = mb * BLOCK_M;
    int64_t mb_size = std::min(BLOCK_M, M - mb_start);
    for (int64_t nb = 0; nb < NB; ++nb) {
      int64_t nb_start = nb * BLOCK_N;
      int64_t nb_size = std::min(BLOCK_N, N - nb_start);
      // dispatch key: mb_size (1..4) in the high nibble, nb_size/16 (1..4)
      // in the low nibble — requires nb_size to be a multiple of 16.
      switch (mb_size << 4 | nb_size >> 4) {
        // mb_size = 1
        case 0x11:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
          break;
        case 0x12:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
          break;
        case 0x13:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
          break;
        case 0x14:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
          break;
        // mb_size = 2
        case 0x21:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
          break;
        case 0x22:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
          break;
        case 0x23:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
          break;
        case 0x24:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
          break;
        // mb_size = 3
        case 0x31:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
          break;
        case 0x32:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
          break;
        case 0x33:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
          break;
        case 0x34:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
          break;
        // mb_size = 4
        case 0x41:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
          break;
        case 0x42:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
          break;
        case 0x43:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
          break;
        case 0x44:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
          break;
        default:
          TORCH_CHECK(false, "Unexpected block size, ", mb_size, " x ", nb_size);
      }
    }
  }
}
// out[M, N] = mat1[M, K] @ mat2(packed)[N, K] (+ bias), parallelized over
// the [MB, NB] grid of (BLOCK_M x BLOCK_N) tiles. mat2 must already be in
// VNNI-packed layout (see convert_weight_packed), laid out tile-major so the
// nb-th column block starts at nb * BLOCK_N * K.
template <typename scalar_t>
void weight_packed_linear_kernel_impl(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ mat1,
    const scalar_t* __restrict__ mat2,
    const float* __restrict__ bias,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t mat1_strideM,
    int64_t out_strideM) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);
  // brgemm (AMX) path is gated on M via the per-dtype threshold
  const bool use_brgemm = can_use_brgemm<scalar_t>(M);

  // parallel on [MB, NB]
  AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
      // for brgemm, use float32 for accumulate
      alignas(64) float Ctmp[BLOCK_M * BLOCK_N];

      // NOTE(review): `nb_offset` supplied by loop_2d is unused here —
      // confirm against loop_2d's definition whether that is intended.
      loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
        int64_t mb_start = mb * BLOCK_M;
        int64_t mb_size = std::min(M - mb_start, BLOCK_M);
        int64_t nb_start = nb * BLOCK_N;
        int64_t nb_size = std::min(N - nb_start, BLOCK_N);

        tinygemm_kernel<scalar_t, has_bias>(
            /* A   */ mat1 + mb_start * mat1_strideM,
            /* B   */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
            /* C   */ out + mb_start * out_strideM + nb_start,
            /* Ctmp*/ Ctmp,
            /* bias*/ bias + nb_start,
            /* M   */ mb_size,
            /* N   */ nb_size,
            /* K   */ K,
            /* lda */ mat1_strideM,
            /* ldb */ nb_size,
            /* ldc */ out_strideM,
            /* brg */ use_brgemm);
      });

      // release per-thread AMX tile state
      if (use_brgemm) {
        at::native::cpublas::brgemm_release();
      }
    });
  });
}
} // anonymous namespace
// tinygemm interface
// Public bias-less entry point: forwards to the internal dispatcher with
// has_bias = false and a null bias pointer.
template <typename scalar_t>
void tinygemm_kernel(
    const scalar_t* __restrict__ A,
    const scalar_t* __restrict__ B,
    scalar_t* __restrict__ C,
    float* __restrict__ Ctmp,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    bool brg) {
  tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg);
}
// Explicit instantiations of the public tinygemm_kernel for the two
// reduced-precision dtypes used by callers in other translation units.
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
  template void tinygemm_kernel<TYPE>(      \
      const TYPE* __restrict__ A,           \
      const TYPE* __restrict__ B,           \
      TYPE* __restrict__ C,                 \
      float* __restrict__ Ctmp,             \
      int64_t M,                            \
      int64_t N,                            \
      int64_t K,                            \
      int64_t lda,                          \
      int64_t ldb,                          \
      int64_t ldc,                          \
      bool brg)

INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
at::Tensor convert_weight_packed(at::Tensor& weight) {
  // for 3d moe weights
  // weight : [E, OC, IC]
  //   w1 : [E, 2N, K]
  //   w2 : [E, K, N]
  CHECK_INPUT(weight);

  const int64_t ndim = weight.ndimension();
  TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor.");
  const auto st = weight.scalar_type();
  const int64_t E = ndim == 3 ? weight.size(0) : 1;
  const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
  const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);

  // we handle 2 TILE_N at a time.
  // NOTE(review): only OC % TILE_N is enforced, but the int8 pack_vnni
  // specialization requires full BLOCK_N (= 2 * TILE_N) tiles — confirm that
  // int8 callers guarantee OC % BLOCK_N == 0.
  TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
  TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);

  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t NB = div_up(OC, BLOCK_N);

  // use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
  // the empty scalar tensor is resized to the real packed shape inside the dispatch
  auto packed_weight = at::empty({}, weight.options());
  const int64_t stride = OC * IC;

  TORCH_CHECK(
      st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
      "expect weight to be bfloat16, float16, int8 or fp8_e4m3.");

  CPU_DISPATCH_PACKED_TYPES(st, [&] {
    // adjust most inner dimension size (int8 rows carry extra compensation bytes)
    const int packed_row_size = get_row_size<packed_t>(IC);
    auto sizes = weight.sizes().vec();
    sizes[ndim - 1] = packed_row_size;
    packed_weight.resize_(sizes);

    const packed_t* w_data = weight.data_ptr<packed_t>();
    packed_t* packed_data = packed_weight.data_ptr<packed_t>();

    // parallel on {E, NB}: each task packs one BLOCK_N-wide column tile of one expert
    at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) {
      int64_t e{0}, nb{0};
      data_index_init(begin, e, E, nb, NB);

      for (int64_t i = begin; i < end; ++i) {
        UNUSED(i);
        int64_t n = nb * BLOCK_N;
        int64_t n_size = std::min(BLOCK_N, OC - n);
        pack_vnni<packed_t>(
            packed_data + e * OC * packed_row_size + n * packed_row_size, w_data + e * stride + n * IC, n_size, IC);

        // move to the next index
        data_index_step(e, E, nb, NB);
      }
    });
  });
  return packed_weight;
}
// mat1 : [M, K]
// mat2 : [N, K]
// bias : [N]
// out : [M, N]
//
// Linear layer with pre-packed (VNNI) weights:
//   out[M, N] = mat1[M, K] @ mat2[N, K]^T (+ bias)
// When is_vnni is false, mat2 is packed on the fly via convert_weight_packed.
// bias, when given, must be a float tensor of shape [N].
at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni) {
  RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));

  // Validate shapes and layouts BEFORE the (potentially expensive) on-the-fly
  // weight repack, so invalid inputs fail fast instead of after packing.
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
  CHECK_INPUT(mat2);
  CHECK_DIM(2, mat1);
  CHECK_DIM(2, mat2);

  int64_t M = mat1.size(0);
  int64_t N = mat2.size(0);
  int64_t K = mat2.size(1);
  CHECK_EQ(mat1.size(1), K);

  // pack lazily if the caller did not pre-pack the weight
  auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);

  auto out = at::empty({M, N}, mat1.options());

  // strides
  int64_t mat1_strideM = mat1.stride(0);
  int64_t out_strideM = out.stride(0);

  const bool has_bias = bias.has_value();
  const float* bias_data = nullptr;
  if (has_bias) {
    CHECK_EQ(bias.value().size(0), N);
    bias_data = bias.value().data_ptr<float>();
  }

  AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] {
    weight_packed_linear_kernel_impl<scalar_t>(
        out.data_ptr<scalar_t>(),
        mat1.data_ptr<scalar_t>(),
        packed_w.data_ptr<scalar_t>(),
        bias_data,
        M,
        N,
        K,
        mat1_strideM,
        out_strideM);
  });
  return out;
}

202
sgl-kernel/csrc/cpu/gemm.h Normal file
View File

@@ -0,0 +1,202 @@
#pragma once
#include <ATen/native/CPUBlas.h>
#include "common.h"
// amx-bf16
#define TILE_M 16
#define TILE_N 16
#define TILE_K 32
// block size for AMX gemm
// M block: 2 AMX tiles of 16 rows -> 32; N block: 2 tiles of 16 cols -> 32.
constexpr int block_size_m() {
  return 2 * TILE_M;
}
constexpr int block_size_n() {
  return 2 * TILE_N;
}
// define threshold using brgemm (intel AMX)
// Per-dtype predicate: for small M the unrolled tinygemm path wins, so most
// dtypes only take brgemm when M > 4; fp16 always uses brgemm.
template <typename T>
inline bool can_use_brgemm(int M);

template <>
inline bool can_use_brgemm<at::BFloat16>(int M) {
  return M > 4;
}
template <>
inline bool can_use_brgemm<at::Half>(int M) {
  return true;
}
// this requires PyTorch 2.7 or above
template <>
inline bool can_use_brgemm<int8_t>(int M) {
  return M > 4;
}
template <>
inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) {
  return M > 4;
}
// work around compiler internal error
#define BLOCK_K 128 // 4 * TILE_K
// adjust leading dimension size for K
// Packed row length (in elements) for a logical K: int8 w8a8 rows carry an
// extra int32 compensation value appended after the K quantized bytes
// (see s8s8_compensation); all other dtypes pack exactly K elements.
template <typename T>
inline int64_t get_row_size(int64_t K) {
  return K;
}
template <>
inline int64_t get_row_size<int8_t>(int64_t K) {
  return K + sizeof(int32_t);
}
// Runtime-flag variant of get_row_size<T>: int8 w8a8 rows append
// sizeof(int32_t) compensation bytes after the K quantized values.
inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
  if (use_int8_w8a8) {
    return K + sizeof(int32_t);
  }
  return K;
}
// pack weight to vnni format
at::Tensor convert_weight_packed(at::Tensor& weight);
// moe implementations for int8 w8a8
template <typename scalar_t>
void fused_experts_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
uint8_t* __restrict__ A_tmp,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad);
// moe implementations for fp8 w8a16
template <typename scalar_t>
void fused_experts_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
scalar_t* __restrict__ A_tmp,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad);
// shared expert implementation for int8 w8a8
template <typename scalar_t>
void shared_expert_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K);
template <typename scalar_t>
void shared_expert_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K);
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg);
template <typename scalar_t>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg);
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const at::Float8_e4m3fn* __restrict__ B,
scalar_t* __restrict__ C,
scalar_t* __restrict__ Btmp,
float* __restrict__ Ctmp,
const float* __restrict__ scale,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg,
int64_t block_size_K,
bool do_unpack = true);

View File

@@ -0,0 +1,551 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d);
fVec data1 = fVec::loadu(input + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d]);
}
}
template <typename scalar_t>
inline void copy_add_stub(
scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
}
}
// Dequantize a VNNI-packed fp8 (e4m3) weight tile into bf16 scratch, applying
// one scale factor to the whole tile. Input and output are both in VNNI
// layout [K/2, N, 2]; each loop iteration converts one packed row of 64 fp8
// values into 64 scaled bf16 values.
inline void unpack_B(
    at::BFloat16* __restrict__ Btmp,
    const at::Float8_e4m3fn* __restrict__ packed_B,
    int N,
    int K,
    int ldb,
    int ldb_tmp,
    float scale) {
#if defined(CPU_CAPABILITY_AVX512)
  // [K/2, N, 2]
  const int K2 = K >> 1;
  const int ldb2 = ldb;  // ldb * 2 >> 1;
  const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B);
  const __m512 vd = _mm512_set1_ps(scale);

  constexpr int BLOCK_N = block_size_n();
  static_assert(BLOCK_N == 32);

  // prefetch distance
  constexpr int PREFETCH_SIZE_K = 64;

#pragma GCC unroll 4
  for (int k = 0; k < K2; ++k) {
    // 64 fp8 values = one 512-bit load
    __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2);
    if constexpr (PREFETCH_SIZE_K > 0) {
      _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0);
    }

    // fp8 -> bf16 in two halves of 32 values each
    __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
    __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
    __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
    __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);

    // Apply scale: widen bf16 -> fp32, multiply, narrow back to bf16
    __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
    __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
    __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
    __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
    f0_lo = _mm512_mul_ps(f0_lo, vd);
    f0_hi = _mm512_mul_ps(f0_hi, vd);
    f1_lo = _mm512_mul_ps(f1_lo, vd);
    f1_hi = _mm512_mul_ps(f1_hi, vd);
    bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
    bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);

    _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0);
    _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1);
  }
#else
  TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
#endif
}
template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const scalar_t* __restrict__ A,
const packed_t* __restrict__ B,
scalar_t* __restrict__ C,
const float* __restrict__ bias,
const float* __restrict__ scale,
int K,
int lda,
int ldb,
int ldc,
int64_t block_size_K) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
// AVX512 micro-kernel for bf16 activations x fp8 (e4m3) weights with
// per-K-block quantization scales. B is VNNI-packed fp8; each 512-bit load
// yields 64 fp8 values which are widened to bf16 on the fly. Partial sums are
// accumulated per K-block in `vsum` and folded into `vc` scaled by scale[kb].
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const at::BFloat16* __restrict__ A,
      const at::Float8_e4m3fn* __restrict__ B,
      at::BFloat16* __restrict__ C,
      const float* __restrict__ bias,
      const float* __restrict__ scale,
      int K,
      int lda,
      int ldb,
      int ldc,
      int64_t block_size_K) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N / 16;  // 16 fp32 lanes per zmm

    // number of quantization blocks along K (one scale per block)
    const int KB = div_up(K, BLOCK_K);

    // prefetch distance
    constexpr int PREFETCH_SIZE_K = 64;
    constexpr int PREFETCH_SIZE_KB = 1;

    __m512bh va;
    __m512bh vb[COLS];
    __m512 vc[ROWS * COLS];    // final accumulators (scaled)
    __m512 vsum[ROWS * COLS];  // per-K-block partial sums (unscaled)

    // block quant scale
    __m512 vscale;

    // init final accumulators: bias if present, else zero
    auto loadc = [&](auto i) {
      constexpr int col = i % COLS;
      if constexpr (has_bias) {
        vc[i] = _mm512_loadu_ps(bias + col * 16);
      } else {
        vc[i] = _mm512_setzero_ps();
      }
    };
    Unroll<ROWS * COLS>{}(loadc);

    // one k step consumes a bf16 pair from A (read as fp32 units)
    const int lda2 = lda >> 1;
    const int ldb2 = ldb;  // ldb * 2 >> 1;
    const float* a_ptr = reinterpret_cast<const float*>(A);
    const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B);

    auto compute = [&](auto i, int k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      if constexpr (col == 0) {
        // broadcast one bf16 pair of A's row
        va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
        if constexpr (PREFETCH_SIZE_K > 0) {
          _mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0);
        }
      }
      if constexpr (row == 0) {
        // one fp8 load feeds two adjacent bf16 column vectors
        if constexpr (col % 2 == 0) {
          __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
          if constexpr (PREFETCH_SIZE_K > 0) {
            _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
          }
          vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0));
          vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1));
        }
      }
      vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]);
    };

    constexpr int BLOCK_K2 = BLOCK_K >> 1;
    for (int kb = 0; kb < KB; ++kb) {
      int kb_start = kb * BLOCK_K2;
      int kb_end = std::min(K >> 1, kb_start + BLOCK_K2);
      // 1. load scale vector
      vscale = _mm512_set1_ps(scale[kb]);
      if constexpr (PREFETCH_SIZE_KB > 0) {
        _mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0);
      }
      // 2. zero vsum for each block
      Unroll<ROWS * COLS>{}([&](auto i) { vsum[i] = _mm512_setzero_ps(); });
      // 3. accumulate across each block
      for (int k = kb_start; k < kb_end; ++k) {
        Unroll<ROWS * COLS>{}(compute, k);
      }
      // 4. apply scale: vc += vsum * scale[kb]
      Unroll<ROWS * COLS>{}([&](auto i) { vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); });
    }

    // down-convert accumulator pairs to bf16 and store
    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      // for COLS = 2,4 use 512bit store
      if constexpr (col % 2 == 0) {
        _mm512_storeu_si512(
            reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
            (__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
      }
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
#endif
// Invoke the fp8-weight (MB_SIZE x NB_SIZE) micro-kernel for the sub-tile at
// (mb_start, nb_start). `B + nb_start * 2` follows the VNNI layout [K/2, N, 2].
// Relies on scalar_t/has_bias/A/B/C/bias/scale/K/lda/ldb/ldc/block_size_K
// being in scope at the expansion site.
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                                      \
  tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply(    \
      A + mb_start * lda,                                                                \
      B + nb_start * 2,                                                                  \
      C + mb_start * ldc + nb_start,                                                     \
      has_bias ? bias + nb_start : nullptr,                                              \
      scale,                                                                             \
      K,                                                                                 \
      lda,                                                                               \
      ldb,                                                                               \
      ldc,                                                                               \
      block_size_K);
// Primary template for the fp8 brgemm path; only the
// <at::BFloat16, at::Float8_e4m3fn> specialization below is implemented.
template <typename scalar_t, typename packed_t, bool has_bias>
struct brgemm {
  static inline void apply(
      const scalar_t* __restrict__ A,
      const packed_t* __restrict__ B,
      scalar_t* __restrict__ C,
      scalar_t* __restrict__ Btmp,
      float* __restrict__ Ctmp,
      const float* __restrict__ bias,
      const float* __restrict__ scale,
      int M,
      int N,
      int K,
      int lda,
      int ldb,
      int ldc,
      bool do_unpack = true) {
    TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
  }
};
// brgemm path for bf16 activations x fp8 weights: optionally dequantize the
// fp8 tile into bf16 scratch (Btmp) with per-K-block scales, run cpublas
// brgemm with fp32 accumulation, then down-convert rows into C (+ bias).
// `do_unpack = false` lets callers reuse a Btmp filled on a previous call.
template <bool has_bias>
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
  static inline void apply(
      const at::BFloat16* __restrict__ A,
      const at::Float8_e4m3fn* __restrict__ B,
      at::BFloat16* __restrict__ C,
      at::BFloat16* __restrict__ Btmp,
      float* __restrict__ Ctmp,
      const float* __restrict__ bias,
      const float* __restrict__ scale,
      int M,
      int N,
      int K,
      int lda,
      int ldb,
      int ldc,
      bool do_unpack = true) {
    constexpr int BLOCK_N = block_size_n();
    // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
    const int ldb_tmp = BLOCK_N;

    if (do_unpack) {
      // dequantize one BLOCK_K slice at a time, each with its own scale
      for (int k = 0; k < K; k += BLOCK_K) {
        int kb_size = std::min(BLOCK_K, K - k);
        // NOTE: the shift is only valid while BLOCK_K == 128
        int idx = k >> 7;  // k / BLOCK_K where BLOCK_K = 128
        unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
      }
    }

    at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);

    // copy from Ctmp to C
    for (int m = 0; m < M; ++m) {
      if constexpr (has_bias) {
        copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
      } else {
        copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
      }
    }
  }
};
// fp8 tinygemm dispatcher: for larger M route to the brgemm path, otherwise
// tile the problem into (<=4 x 32) micro-kernels and dispatch by size.
//
//   A [M, K] (stride lda), B [K, N] vnni-packed (stride ldb), C [M, N]
//   (stride ldc); scale holds one float per K-block of size block_size_K.
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
    const scalar_t* __restrict__ A,
    const at::Float8_e4m3fn* __restrict__ B,
    scalar_t* __restrict__ C,
    scalar_t* __restrict__ Btmp,
    float* __restrict__ Ctmp,
    const float* __restrict__ scale,
    const float* __restrict__ bias,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    bool brg,
    int64_t block_size_K,
    bool do_unpack = true) {
  if (brg) {
    brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
        A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc, do_unpack);
    return;
  }
  // pattern: 1-4-16
  constexpr int64_t BLOCK_M = 4;
  constexpr int64_t BLOCK_N = 64;
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);
  // use int64_t for the row-block index too (was `int`, a silent narrowing)
  for (int64_t mb = 0; mb < MB; ++mb) {
    int64_t mb_start = mb * BLOCK_M;
    int64_t mb_size = std::min(BLOCK_M, M - mb_start);
    for (int64_t nb = 0; nb < NB; ++nb) {
      int64_t nb_start = nb * BLOCK_N;
      int64_t nb_size = std::min(BLOCK_N, N - nb_start);
      // key packs (mb_size, nb_size/16) into one nibble each, e.g. 0x12 = 1x32
      switch (mb_size << 4 | nb_size >> 4) {
        case 0x12:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
          break;
        case 0x22:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
          break;
        case 0x32:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
          break;
        case 0x42:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
          break;
        default:
          // was: ..., "x", "nb_size" — printed the literal string instead of
          // the actual block size
          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
      }
    }
  }
}
// Parallel driver for fp8 scaled matmul: out[M, N] = mat1[M, K] @ mat2^T with
// per-[block_size_N x block_size_K] weight scales (scales2) and optional bias.
//
// Scratch layout per thread (within `buffer`):
//   Btmp : MAX_CACHE_BLOCK_SIZE * BLOCK_N * K elements of scalar_t
//   Ctmp : fp32 accumulators, carved out directly after Btmp
// `buffer_size_per_thread` is the combined size; the caller allocates it.
template <typename scalar_t>
void fp8_scaled_mm_kernel_impl(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ mat1,
    const at::Float8_e4m3fn* __restrict__ mat2,
    const float* __restrict__ scales2,
    const float* __restrict__ bias,
    scalar_t* __restrict__ buffer,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t mat1_strideM,
    int64_t out_strideM,
    int64_t block_size_N,
    int64_t block_size_K,
    int64_t buffer_size_per_thread) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);
  // scales2 is [N / block_size_N, K / block_size_K]
  const int64_t scale_size_K = div_up(K, block_size_K);
  const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
  const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
  // parallel on [MB, NB]
  AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
      int tid = get_thread_num();
      // per-thread scratch: Btmp for unpacked weights, Ctmp right after it
      scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
      float* __restrict__ Ctmp = (float*)((void*)(Btmp + MAX_CACHE_BLOCK_SIZE * BLOCK_N * K));
      loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
        // all BLOCK_N columns in one weight group share a scale row
        const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
        int64_t mb_start = mb * BLOCK_M;
        int64_t mb_size = std::min(M - mb_start, BLOCK_M);
        int64_t nb_start = nb * BLOCK_N;
        int64_t nb_size = std::min(N - nb_start, BLOCK_N);
        // only do unpacking for the first row
        bool do_unpack = (mb == mb0);
        tinygemm_kernel<scalar_t, has_bias>(
            /* A */ mat1 + mb_start * mat1_strideM,
            /* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
            /* C */ out + mb_start * out_strideM + nb_start,
            /* Btmp */ Btmp + nb_offset * BLOCK_N * K,
            /* Ctmp */ Ctmp,
            /* scale */ scale_ptr,
            /* bias */ bias + nb_start,
            /* M */ mb_size,
            /* N */ nb_size,
            /* K */ K,
            /* lda */ mat1_strideM,
            /* ldb */ nb_size,
            /* ldc */ out_strideM,
            /* brg */ use_brgemm,
            /* block_size_K */ block_size_K,
            /* do_unpack */ do_unpack);
      });
      if (use_brgemm) {
        at::native::cpublas::brgemm_release();
      }
    });
  });
}
} // anonymous namespace
// tinygemm interface (fp8): public bias-less entry point; forwards to the
// internal has_bias=false template with a null bias pointer.
template <typename scalar_t>
void tinygemm_kernel(
    const scalar_t* __restrict__ A,
    const at::Float8_e4m3fn* __restrict__ B,
    scalar_t* __restrict__ C,
    scalar_t* __restrict__ Btmp,
    float* __restrict__ Ctmp,
    const float* __restrict__ scale,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    bool brg,
    int64_t block_size_K,
    bool do_unpack) {
  tinygemm_kernel<scalar_t, false>(
      A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack);
}
// Emit an explicit instantiation of the public fp8 tinygemm_kernel for TYPE,
// so out-of-file callers can link against it (used below for bf16 and half).
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE)   \
  template void tinygemm_kernel<TYPE>(        \
      const TYPE* __restrict__ A,             \
      const at::Float8_e4m3fn* __restrict__ B, \
      TYPE* __restrict__ C,                   \
      TYPE* __restrict__ Btmp,                \
      float* __restrict__ Ctmp,               \
      const float* __restrict__ scale,        \
      int64_t M,                              \
      int64_t N,                              \
      int64_t K,                              \
      int64_t lda,                            \
      int64_t ldb,                            \
      int64_t ldc,                            \
      bool brg,                               \
      int64_t block_size_K,                   \
      bool do_unpack)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
// fp8 scaled matmul entry point.
//   mat1    : [M, K] bf16/half activations
//   mat2    : [N, K] fp8_e4m3 weights (vnni-packed when is_vnni)
//   scales2 : [ceil(N/block_size_N), ceil(K/block_size_K)] fp32 weight scales
//   bias    : optional [N] fp32
// Returns out [M, N] in out_dtype (must equal mat1's dtype).
at::Tensor fp8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    std::vector<int64_t> block_size,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni) {
  RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
  // Validate all inputs BEFORE packing: convert_weight_packed reads mat2, so
  // it must only run on tensors that passed the checks below (previously the
  // packing happened first).
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
  CHECK_INPUT(mat2);
  CHECK_INPUT(scales2);
  int64_t M = mat1.size(0);
  int64_t N = mat2.size(0);
  int64_t K = mat2.size(1);
  CHECK_EQ(mat1.size(1), K);
  CHECK_DIM(2, mat1);
  CHECK_DIM(2, mat2);
  TORCH_CHECK(block_size.size() == 2, "fp8_scaled_mm_cpu: expect block_size.size() to be 2.");
  int64_t block_size_N = block_size[0];
  int64_t block_size_K = block_size[1];
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
  TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
  CHECK_EQ(scales2.size(0), div_up(N, block_size_N));
  CHECK_EQ(scales2.size(1), div_up(K, block_size_K));
  const auto st = mat1.scalar_type();
  TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "fp8_scaled_mm_cpu: expect A to be bfloat16 or half.");
  TORCH_CHECK(st == out_dtype, "fp8_scaled_mm_cpu: expect A has same dtype with out_dtype.");
  TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3.");
  // single dtype check for scales2 (the original checked this twice)
  TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales2 to be float32.");
  // pack the weight to vnni layout unless the caller already did
  auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
  auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
  // strides
  int64_t mat1_strideM = mat1.stride(0);
  int64_t out_strideM = out.stride(0);
  const bool has_bias = bias.has_value();
  const float* bias_data = nullptr;
  if (has_bias) {
    CHECK_EQ(bias.value().size(0), N);
    bias_data = bias.value().data_ptr<float>();
  }
  // per-thread scratch:
  // Btmp : [T, MAX_CACHE_BLOCK_SIZE * BLOCK_N * K]
  // Ctmp : [T, BLOCK_M * BLOCK_N] fp32, stored as 2x scalar_t elements
  int num_threads = at::get_num_threads();
  int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
  auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
  AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
    fp8_scaled_mm_kernel_impl<scalar_t>(
        out.data_ptr<scalar_t>(),
        mat1.data_ptr<scalar_t>(),
        packed_w.data_ptr<at::Float8_e4m3fn>(),
        scales2.data_ptr<float>(),
        bias_data,
        buffer.data_ptr<scalar_t>(),
        M,
        N,
        K,
        mat1_strideM,
        out_strideM,
        block_size_N,
        block_size_K,
        size_per_thread);
  });
  return out;
}

View File

@@ -0,0 +1,547 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
// Primary template: dequantize one row of int32 brgemm accumulators into
// scalar_t output. Only the avx512 specialization below is implemented;
// instantiating this fallback fails at runtime via TORCH_CHECK.
//   C     : output row [BLOCK_N]
//   Ctmp  : int32 accumulators [BLOCK_N]
//   Bcomp : per-column s8s8 compensation terms [BLOCK_N]
//   bias  : optional per-column bias [BLOCK_N] (used when has_bias)
//   As    : per-row activation scale
//   Bs    : per-column weight scales [BLOCK_N]
template <typename scalar_t, bool has_bias, int BLOCK_N>
struct scale_C {
  static inline void apply(
      scalar_t* __restrict__ C,
      const int32_t* __restrict__ Ctmp,
      const int32_t* __restrict__ Bcomp,
      const float* __restrict__ bias,
      float As,
      const float* __restrict__ Bs) {
    TORCH_CHECK(false, "scale_C: scalar path not implemented!");
  }
};
#if defined(CPU_CAPABILITY_AVX512)
// avx512 bf16 specialization: subtract the s8s8 compensation from each int32
// accumulator, scale by As (per-row) and Bs (per-column), optionally add
// bias, and store the row as bfloat16.
template <bool has_bias, int BLOCK_N>
struct scale_C<at::BFloat16, has_bias, BLOCK_N> {
  static inline void apply(
      at::BFloat16* __restrict__ C,
      const int32_t* __restrict__ Ctmp,
      const int32_t* __restrict__ Bcomp,
      const float* __restrict__ bias,
      float As,
      const float* __restrict__ Bs) {
    constexpr int COLS = BLOCK_N / 16;
    static_assert(COLS % 2 == 0);
    __m512 acc[COLS];
    const __m512 va_scale = _mm512_set1_ps(As);
    // dequantize 16 lanes at a time
    Unroll<COLS>{}([&](auto col) {
      const __m512 vb_scale = _mm512_loadu_ps(Bs + col * 16);
      const __m512i vcomp = _mm512_loadu_si512(Bcomp + col * 16);
      const __m512i vraw = _mm512_loadu_si512(Ctmp + col * 16);
      acc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vraw, vcomp));
      if constexpr (has_bias) {
        const __m512 vbias = _mm512_loadu_ps(bias + col * 16);
        acc[col] = _mm512_fmadd_ps(_mm512_mul_ps(acc[col], va_scale), vb_scale, vbias);
      } else {
        acc[col] = _mm512_mul_ps(_mm512_mul_ps(acc[col], va_scale), vb_scale);
      }
    });
    // pack adjacent fp32 vector pairs into bf16 and emit one 512-bit store
    // per pair (valid since COLS is even)
    Unroll<COLS>{}([&](auto col) {
      if constexpr (col % 2 == 0) {
        _mm512_storeu_si512(
            reinterpret_cast<__m512i*>((C + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(acc[col + 1], acc[col + 0])));
      }
    });
  }
};
#endif
// Primary template for the int8 micro-kernel (u8 activations x s8 weights).
// Only the avx512 specialization below is implemented; instantiating this
// fallback fails at runtime via TORCH_CHECK.
//   A     : [BLOCK_M, K] uint8 activations (row stride lda)
//   B     : [K, BLOCK_N] int8 weights, vnni-packed (stride ldb)
//   C     : [BLOCK_M, BLOCK_N] output (row stride ldc)
//   As    : per-row activation scales [BLOCK_M]
//   Bs    : per-column weight scales [BLOCK_N]
//   Bcomp : per-column s8s8 compensation [BLOCK_N] (see NOTE below)
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
  static inline void apply(
      const uint8_t* __restrict__ A,
      const int8_t* __restrict__ B,
      scalar_t* __restrict__ C,
      const float* __restrict__ As,
      const float* __restrict__ Bs,
      const int32_t* __restrict__ Bcomp,
      const float* __restrict__ bias,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
  }
};
#if defined(CPU_CAPABILITY_AVX512)
// avx512-vnni bf16 specialization: accumulate u8s8 dot-products into int32
// registers with vpdpbusd, then subtract the s8s8 compensation, apply the
// per-row (As) and per-column (Bs) scales plus optional bias, and store the
// result as bfloat16.
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const uint8_t* __restrict__ A,
      const int8_t* __restrict__ B,
      at::BFloat16* __restrict__ C,
      const float* __restrict__ As,
      const float* __restrict__ Bs,
      const int32_t* __restrict__ Bcomp,
      const float* __restrict__ bias,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N / 16;  // 16 int32 lanes per zmm register
    static_assert(COLS % 2 == 0);
    // prefetch distance (0 disables prefetching)
    constexpr int PREFETCH_SIZE_K = 0;
    __m512i va;
    __m512i vb[COLS];
    __m512i vc[ROWS * COLS];
    __m512i vcomp[COLS];
    __m512 vd0;
    __m512 vd1[COLS];
    // oops! 4x4 spills but we use 4x2
    __m512 vbias[COLS];
    // [NOTE]: s8s8 igemm compensation in avx512-vnni
    //
    // avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate:
    //
    //   a * b = (a + 128) * b - 128 * b
    //     s s       u      s   u    s
    //
    // 1) 128 * b is pre-computed when packing B to vnni formats
    // 2) a + 128 is fused when dynamically quantize A
    //
    auto loadc = [&](auto i) { vc[i] = _mm512_set1_epi32(0); };
    Unroll<ROWS * COLS>{}(loadc);
    // vpdpbusd consumes 4 bytes of K per lane, so iterate over K/4 int32
    // "groups"; lda is likewise re-expressed in int32 units
    const int64_t K4 = K >> 2;
    const int64_t lda4 = lda >> 2;
    const int64_t ldb4 = ldb; // ldb * 4 >> 2;
    const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
    const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      // broadcast one 4-byte group of A once per row, reuse across columns
      if constexpr (col == 0) {
        va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
      }
      // load each B vector once per column, reuse across rows
      if constexpr (row == 0) {
        vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
        if constexpr (PREFETCH_SIZE_K > 0) {
          _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0);
        }
      }
      vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
    };
    for (int64_t k = 0; k < K4; ++k) {
      Unroll<ROWS * COLS>{}(compute, k);
    }
    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      // load a scale
      if constexpr (col == 0) {
        vd0 = _mm512_set1_ps(As[row]);
      }
      // load b scale and vcomp per 2 vectors
      // also load bias if any
      if constexpr (row == 0) {
        if constexpr (col % 2 == 0) {
          vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16);
          vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
          vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
          vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
          if constexpr (has_bias) {
            vbias[col + 0] = _mm512_loadu_ps(bias + col * 16);
            vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16);
          }
        }
      }
      // for COLS = 2, 4 use 512bit store
      if constexpr (col % 2 == 0) {
        __m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0]));
        __m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1]));
        if constexpr (has_bias) {
          vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]);
          vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]);
        } else {
          vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]);
          vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]);
        }
        _mm512_storeu_si512(
            reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0)));
      }
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
#endif
// Dispatch helper for the int8 path: invoke the (MB_SIZE x NB_SIZE)
// tinygemm_kernel_nn specialization on the tile at (mb_start, nb_start).
// NOTE(review): B is offset by `nb_start * 4` — presumably because B is
// vnni-packed 4 K-bytes deep (see the K4/ldb4 handling in the kernel) —
// confirm against the packing routine. Expects A, B, C, As, Bs, Bcomp, bias,
// mb_start, nb_start, K, lda, ldb, ldc, scalar_t and has_bias in scope.
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                     \
  tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply(      \
      A + mb_start * lda,                                               \
      B + nb_start * 4,                                                 \
      C + mb_start * ldc + nb_start,                                    \
      As + mb_start,                                                    \
      Bs + nb_start,                                                    \
      Bcomp + nb_start,                                                 \
      has_bias ? bias + nb_start : nullptr,                             \
      K,                                                                \
      lda,                                                              \
      ldb,                                                              \
      ldc);
// int8 tinygemm dispatcher: for larger M route through the brgemm path
// (int32 accumulation + scale_C epilogue), otherwise tile into (<=4 x <=64)
// micro-kernels and dispatch by size. The s8s8 compensation vector is stored
// directly after the packed weights (see [NOTE] in the kernel above).
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
    const uint8_t* __restrict__ A,
    const int8_t* __restrict__ B,
    scalar_t* __restrict__ C,
    int32_t* __restrict__ Ctmp,
    const float* __restrict__ As,
    const float* __restrict__ Bs,
    const float* __restrict__ bias,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    bool brg) {
  // B compensation
  const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
  if (brg) {
    constexpr int BLOCK_N = block_size_n();
    at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
    // apply compensation and scale
    for (int64_t m = 0; m < M; ++m) {
      scale_C<scalar_t, has_bias, BLOCK_N>::apply(C + m * ldc, Ctmp + m * BLOCK_N, Bcomp, bias, As[m], Bs);
    }
    return;
  }
  // pattern: 1-4-16
  constexpr int64_t BLOCK_M = 4;
  constexpr int64_t BLOCK_N = 64;
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);
  for (int64_t mb = 0; mb < MB; ++mb) {
    int64_t mb_start = mb * BLOCK_M;
    int64_t mb_size = std::min(BLOCK_M, M - mb_start);
    for (int64_t nb = 0; nb < NB; ++nb) {
      int64_t nb_start = nb * BLOCK_N;
      int64_t nb_size = std::min(BLOCK_N, N - nb_start);
      // key packs (mb_size, nb_size/16) into one nibble each, e.g. 0x24 = 2x64
      switch (mb_size << 4 | nb_size >> 4) {
        // mb_size = 1
        case 0x12:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
          break;
        case 0x14:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
          break;
        // mb_size = 2
        case 0x22:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
          break;
        case 0x24:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
          break;
        // mb_size = 3
        case 0x32:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
          break;
        case 0x34:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
          break;
        // mb_size = 4
        case 0x42:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
          break;
        case 0x44:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
          break;
        default:
          // was: ..., "x", "nb_size" — printed the literal string instead of
          // the actual block size
          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
      }
    }
  }
}
// Parallel driver for int8 scaled matmul:
//   out[M, N] = dequant(mat1[M, K] u8 @ mat2^T s8) with per-row scales1,
//   per-column scales2 and optional bias.
// mat2 rows are packed with a trailing int32 compensation block (K + 4 bytes
// per row, see packed_row_size).
template <typename scalar_t>
void int8_scaled_mm_kernel_impl(
    scalar_t* __restrict__ out,
    const uint8_t* __restrict__ mat1,
    const int8_t* __restrict__ mat2,
    const float* __restrict__ scales1,
    const float* __restrict__ scales2,
    const float* __restrict__ bias,
    int64_t M,
    int64_t N,
    int64_t K) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);
  const bool use_brgemm = can_use_brgemm<int8_t>(M);
  // K + 4 after compensation
  const int64_t packed_row_size = get_row_size<int8_t>(K);
  AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
      // for brgemm, use int32_t for accumulate
      alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
      loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
        // keep block coordinates in int64_t (they were `int`, silently
        // narrowing the int64_t products; the fp8 twin already uses int64_t)
        int64_t mb_start = mb * BLOCK_M;
        int64_t mb_size = std::min(M - mb_start, BLOCK_M);
        int64_t nb_start = nb * BLOCK_N;
        int64_t nb_size = std::min(N - nb_start, BLOCK_N);
        tinygemm_kernel<scalar_t, has_bias>(
            /* A */ mat1 + mb_start * K,
            /* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
            /* C */ out + mb_start * N + nb_start,
            /* Ctmp*/ Ctmp,
            /* As */ scales1 + mb_start,
            /* Bs */ scales2 + nb_start,
            /* bias*/ bias + nb_start,
            /* M */ mb_size,
            /* N */ nb_size,
            /* K */ K,
            /* lda */ K,
            /* ldb */ nb_size,
            /* ldc */ N,
            /* brg */ use_brgemm);
      });
      if (use_brgemm) {
        at::native::cpublas::brgemm_release();
      }
    });
  });
}
} // anonymous namespace
// tinygemm interface (int8): public bias-less entry point; forwards to the
// internal has_bias=false template with a null bias pointer.
template <typename scalar_t>
void tinygemm_kernel(
    const uint8_t* __restrict__ A,
    const int8_t* __restrict__ B,
    scalar_t* __restrict__ C,
    int32_t* __restrict__ Ctmp,
    const float* __restrict__ As,
    const float* __restrict__ Bs,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    bool brg) {
  tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg);
}
// Emit an explicit instantiation of the public int8 tinygemm_kernel for TYPE,
// so out-of-file callers can link against it (used below for bf16 and half).
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE)   \
  template void tinygemm_kernel<TYPE>(        \
      const uint8_t* __restrict__ A,          \
      const int8_t* __restrict__ B,           \
      TYPE* __restrict__ C,                   \
      int32_t* __restrict__ Ctmp,             \
      const float* __restrict__ As,           \
      const float* __restrict__ Bs,           \
      int64_t M,                              \
      int64_t N,                              \
      int64_t K,                              \
      int64_t lda,                            \
      int64_t ldb,                            \
      int64_t ldc,                            \
      bool brg)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
// Dynamic per-token (per-row) int8 quantization of a 2-D bf16/half tensor.
// Returns (Aq, As): Aq is [M, K] uint8 quantized values, As is [M] float32
// per-row scales, computed row-by-row in parallel.
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
  RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A}));
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(A);
  CHECK_DIM(2, A);
  const int64_t rows = A.size(0);
  const int64_t cols = A.size(1);
  const int64_t row_stride = A.stride(0);
  const auto dtype = A.scalar_type();
  TORCH_CHECK(dtype == at::kBFloat16 || dtype == at::kHalf, "per_token_quant_int8: expect A to be bfloat16 or half.");
  auto Aq = at::empty({rows, cols}, A.options().dtype(at::kByte));
  auto As = at::empty({rows}, A.options().dtype(at::kFloat));
  AT_DISPATCH_REDUCED_FLOATING_TYPES(dtype, "per_token_quant_int8", [&] {
    uint8_t* __restrict__ q_out = Aq.data_ptr<uint8_t>();
    float* __restrict__ s_out = As.data_ptr<float>();
    const scalar_t* __restrict__ src = A.data_ptr<scalar_t>();
    at::parallel_for(0, rows, 0, [&](int64_t begin, int64_t end) {
      for (int64_t r = begin; r < end; ++r) {
        // each row is quantized independently; its scale lands in s_out[r]
        quantize_row_int8<scalar_t>(q_out + r * cols, s_out[r], src + r * row_stride, cols);
      }
    });
  });
  return std::make_tuple(Aq, As);
}
// weight : static, per-channel, symmetric
// activation : dynamic, per-token, symmetric
//
// mat1 : [M, K]
// mat2 : [N, K]
// scales1 : [M]
// scales2 : [N]
// bias : [N]
// out : [M, N]
//
// int8 scaled matmul entry point.
//   mat1    : [M, K] uint8 activations (pre-quantized, +128 offset applied)
//   mat2    : [N, K] int8 weights; when is_vnni, rows carry a trailing
//             int32 compensation term (K + 4 bytes per row)
//   scales1 : [M] fp32 per-token scales;  scales2 : [N] fp32 per-channel
//   bias    : optional [N] fp32
// Returns out [M, N] in out_dtype.
at::Tensor int8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales1,
    at::Tensor& scales2,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni) {
  RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
  // Validate all inputs BEFORE packing: convert_weight_packed reads mat2, so
  // it should only run on tensors that passed the checks below (previously
  // the packing happened first).
  CHECK_INPUT(mat1);
  CHECK_INPUT(mat2);
  CHECK_INPUT(scales1);
  CHECK_INPUT(scales2);
  CHECK_DIM(2, mat1);
  CHECK_DIM(2, mat2);
  int64_t M = mat1.size(0);
  int64_t N = mat2.size(0);
  int64_t K = mat1.size(1);
  // see [NOTE]: s8s8 igemm compensation in avx512-vnni
  CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
  CHECK_EQ(scales1.numel(), M);
  CHECK_EQ(scales2.numel(), N);
  TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8.");
  TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8.");
  TORCH_CHECK(
      scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat,
      "int8_scaled_mm: expect scales to be float32.");
  // pack the weight (and append compensation) unless the caller already did
  auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
  auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
  const bool has_bias = bias.has_value();
  const float* bias_data = nullptr;
  if (has_bias) {
    CHECK_EQ(bias.value().size(0), N);
    bias_data = bias.value().data_ptr<float>();
  }
  AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] {
    int8_scaled_mm_kernel_impl<scalar_t>(
        out.data_ptr<scalar_t>(),
        mat1.data_ptr<uint8_t>(),
        packed_w.data_ptr<int8_t>(),
        scales1.data_ptr<float>(),
        scales2.data_ptr<float>(),
        bias_data,
        M,
        N,
        K);
  });
  return out;
}
// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`: quantizes mat1
// per-row into a scratch buffer, then runs the int8 scaled matmul without
// materializing intermediate tensors.
at::Tensor int8_scaled_mm_with_quant(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni) {
  // was mislabeled "sgl-kernel::int8_scaled_mm_cpu", which made this op
  // indistinguishable from the unfused one in profiler traces
  RECORD_FUNCTION(
      "sgl-kernel::int8_scaled_mm_with_quant", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
  CHECK_INPUT(mat2);
  CHECK_INPUT(scales2);
  CHECK_DIM(2, mat1);
  CHECK_DIM(2, mat2);
  int64_t M = mat1.size(0);
  int64_t N = mat2.size(0);
  int64_t K = mat1.size(1);
  int64_t lda = mat1.stride(0);
  // see [NOTE]: s8s8 igemm compensation in avx512-vnni
  CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
  CHECK_EQ(scales2.numel(), N);
  const auto st = mat1.scalar_type();
  TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
  TORCH_CHECK(st == out_dtype, "int8_scaled_mm_with_quant: expect A has same dtype with out_dtype.");
  TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm_with_quant: expect mat2 to be int8.");
  TORCH_CHECK(scales2.scalar_type() == at::kFloat, "int8_scaled_mm_with_quant: expect scales to be float32.");
  auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
  // scratch holds the quantized activations (M*K uint8) followed by the
  // per-row scales (M floats). Round the uint8 region up to float alignment:
  // previously As_data sat at offset M*K, a misaligned float* whenever
  // M*K % alignof(float) != 0 (UB on strict platforms).
  const int64_t aq_bytes = div_up(M * K, (int64_t)alignof(float)) * (int64_t)alignof(float);
  const int64_t buffer_size = aq_bytes + M * sizeof(float);
  auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
  auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
  const bool has_bias = bias.has_value();
  const float* bias_data = nullptr;
  if (has_bias) {
    CHECK_EQ(bias.value().size(0), N);
    bias_data = bias.value().data_ptr<float>();
  }
  AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
    uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
    float* __restrict__ As_data = (float*)((void*)(Aq_data + aq_bytes));
    const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
    // stage 1: per-row dynamic quantization of the activations
    at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
      for (int64_t m = begin; m < end; ++m) {
        quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
      }
    });
    // stage 2: int8 scaled matmul on the quantized activations
    int8_scaled_mm_kernel_impl<scalar_t>(
        out.data_ptr<scalar_t>(),
        Aq_data,
        packed_w.data_ptr<int8_t>(),
        As_data,
        scales2.data_ptr<float>(),
        bias_data,
        M,
        N,
        K);
  });
  return out;
}

View File

@@ -0,0 +1,74 @@
#include <ATen/record_function.h>
#include <torch/all.h>
#include "shm.h"
// Communication settings (file-local state set once by initialize())
static int world_rank = -1;   // this process's rank; -1 until initialize()
static int world_size = -1;   // total rank count; -1 until initialize()
static bool is_initialized = false;  // guards against repeated initialize()
static bool all_ranks_local_p = false;  // true when LOCAL_SIZE == world size
// One-time process-group setup. Records world size/rank and, when every rank
// lives on the same physical machine (LOCAL_SIZE env matches `size`),
// initializes the SHM-based low latency allreduce transport. Idempotent:
// calls after the first are no-ops.
void initialize(int64_t size, int64_t rank) {
  if (is_initialized) {
    return;
  }
  // Check whether all ranks is on the same physical machine.
  // If true, we will use an SHM based low latency allreduce
  auto ls_string = std::getenv("LOCAL_SIZE");
  int ls = 0;
  if (ls_string != NULL) {
    // reuse the pointer from the first lookup instead of calling getenv twice
    // (NOTE: std::stoi still throws on a non-numeric LOCAL_SIZE, as before)
    ls = std::stoi(ls_string);
  }
  if (size >= 1 && size == ls) {
    all_ranks_local_p = true;
  }
  world_size = size;
  world_rank = rank;
  is_initialized = true;
  // rendezvous endpoint for the SHM handshake; empty when not set
  const char* addr_string = std::getenv("MASTER_ADDR");
  if (addr_string == NULL) {
    addr_string = "";
  }
  const char* port_string = std::getenv("MASTER_PORT");
  if (port_string == NULL) {
    port_string = "";
  }
  if (all_ranks_local_p) {
    shm_initialize(size, rank, addr_string, port_string);
  }
}
// In-place SUM allreduce over the SHM transport. Rejects any reduce op other
// than SUM. Assumes initialize() has already set up the SHM segment when all
// ranks are local — TODO confirm callers guarantee this.
void shm_allreduce(torch::Tensor& data, int64_t op) {
  RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));
  TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported");
  auto numel = data.numel();
  // NOTE(review): `int` narrows numel * element_size — overflows for tensors
  // of >= 2 GiB; confirm upstream callers never pass such tensors.
  int data_size = numel * data.element_size();
  all_reduce_outer_loop(data, numel, data_size);
  return;
}
// All-gather over the SHM transport: returns a new tensor whose `dim`
// dimension is world_size times larger, holding every rank's `data`.
// Negative `dim` is normalized Python-style. Requires initialize() to have
// run (reads the file-local world_size).
torch::Tensor shm_allgather(torch::Tensor& data, int64_t dim) {
  RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));
  auto numel = data.numel();
  // NOTE(review): `int` narrows numel * element_size — overflows for tensors
  // of >= 2 GiB; confirm upstream callers never pass such tensors.
  int data_size = numel * data.element_size();
  if (dim < 0) {
    dim += data.dim();
  }
  // output shape = input shape with result_shape[dim] scaled by world_size
  std::vector<int64_t> result_shape = data.sizes().vec();
  result_shape[dim] *= world_size;
  torch::Tensor result_tensor = torch::empty(result_shape, data.options());
  return all_gather(result_tensor, data, dim, numel, data_size);
}

1322
sgl-kernel/csrc/cpu/moe.cpp Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,491 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
// Vectorized copy of `size` elements from input to out.
// NOTE(review): no scalar tail is handled (original says "no remainder"), so
// `size` is presumably always a multiple of the vector width — confirm at
// the call sites.
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
  using Vec = at::vec::Vectorized<scalar_t>;
#pragma GCC unroll 4
  for (int64_t i = 0; i < size; i += Vec::size()) {
    Vec::loadu(input + i).store(out + i);
  }
}
template <typename scalar_t>
inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec weight_vec = fVec(weight);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
bVec x = bVec::loadu(input + d);
fVec x0, x1;
std::tie(x0, x1) = at::vec::convert_to_float(x);
x0 = x0 * weight_vec;
x1 = x1 * weight_vec;
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] * weight);
}
}
// acc from [topk, K] to [K]: out[k] = sum over t of input[t * K + k],
// accumulated in fp32 and narrowed back to scalar_t. topk == 1 degenerates
// to a plain copy.
template <typename scalar_t>
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  if (topk == 1) {
    // do copy for topk = 1
    copy_stub(out, input, K);
  } else {
    // do sum for topk != 1
    int64_t d;
#pragma GCC unroll 4
    for (d = 0; d <= K - kVecSize; d += kVecSize) {
      // two fp32 accumulators per scalar_t vector (lo/hi halves)
      fVec sum_fvec0 = fVec(0.f);
      fVec sum_fvec1 = fVec(0.f);
      for (int t = 0; t < topk; ++t) {
        bVec x_bvec = bVec::loadu(input + t * K + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
        sum_fvec0 += x_fvec0;
        sum_fvec1 += x_fvec1;
      }
      bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
      out_bvec.store(out + d);
    }
    // scalar tail for the last K % kVecSize columns
    for (; d < K; ++d) {
      float sum_val = 0.f;
      for (int t = 0; t < topk; ++t) {
        sum_val += static_cast<float>(input[t * K + d]);
      }
      out[d] = static_cast<scalar_t>(sum_val);
    }
  }
}
// out = input + input2 * scale
template <typename scalar_t>
inline void add_mul_stub(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ input2,
float scale,
int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec s_vec = fVec(scale);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input + d);
fVec x0, x1;
std::tie(x0, x1) = at::vec::convert_to_float(x_bvec);
bVec y_bvec = bVec::loadu(input2 + d);
fVec y0, y1;
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
x0 = x0 + y0 * s_vec;
x1 = x1 + y1 * s_vec;
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
}
}
// out[i] = silu(input[i]) * input2[i], where silu(x) = x / (1 + exp(-x))
// (computed here with the fast exp_u20 approximation).
// NOTE(review): no scalar tail is handled (original says "no remainder"),
// so `size` is presumably always a multiple of the vector width — confirm
// at the call sites.
template <typename scalar_t>
inline void silu_and_mul_stub(
    scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const scalar_t* __restrict__ input2, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  const fVec one = fVec(1.f);
  // no remainder
#pragma GCC unroll 4
  for (int64_t d = 0; d < size; d += bVec::size()) {
    bVec x = bVec::loadu(input + d);
    fVec x0, x1;
    std::tie(x0, x1) = at::vec::convert_to_float(x);
    bVec y = bVec::loadu(input2 + d);
    fVec y0, y1;
    std::tie(y0, y1) = at::vec::convert_to_float(y);
    // silu in fp32 on both vector halves
    x0 = x0 / (one + x0.neg().exp_u20());
    x1 = x1 / (one + x1.neg().exp_u20());
    x0 = x0 * y0;
    x1 = x1 * y1;
    bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
    out_vec.store(out + d);
  }
}
} // anonymous namespace
template <typename scalar_t>
void fused_experts_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
scalar_t* __restrict__ A_tmp,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
int64_t block_size_N,
int64_t block_size_K,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
// stage 1: intermediate_cache0 = hidden_states @ w1
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
const int64_t NB = div_up(2 * N, BLOCK_N);
int64_t scale_size_N = div_up(2 * N, block_size_N);
int64_t scale_size_K = div_up(K, block_size_K);
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
const int64_t stride_e = 2 * N * K;
const int64_t stride_n = K;
int64_t avg_M = std::max(int64_t(1), M * topk / E);
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(avg_M);
int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
// get local pointers
int tid = get_thread_num();
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
// B shape [K, n_size] in vnni format
int32_t expert_id = expert_ids[mb];
const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n;
const float* __restrict__ Bs =
w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
// do unpacking for the first row or a new expert
int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
// 1.a load A
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
int64_t m_size = offsets[mb + 1] - offsets[mb];
for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m] / topk;
copy_stub(A + m * K, input + index * K, K);
}
const int64_t offset = offsets[mb];
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
/* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ 2 * N,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K,
/* do_unpack */ do_unpack);
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
silu_and_mul_stub(ic1 + m * N, ic0 + m * 2 * N, ic0 + m * 2 * N + N, N);
}
});
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
// w2 : [E, K, N] as [E, OC, IC]
const int64_t OC = K; // rename K as OC
const int64_t IC = N; // rename N as IC
const int64_t MB2 = MB;
const int64_t NB2 = div_up(OC, BLOCK_N);
scale_size_N = div_up(K, block_size_N);
scale_size_K = div_up(N, block_size_K);
const int64_t stride_e2 = OC * IC;
const int64_t stride_oc = IC;
// parallel on [MB2, NB2]
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
int tid = get_thread_num();
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
int64_t m_size = offsets[mb + 1] - offsets[mb];
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
// A ptr from ic1 of [M * topk, N] in sorted order
// so as to avoid copy A to tmp buffer again
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
// B shape [IC, n_size] in vnni format
int32_t expert_id = expert_ids[mb];
const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
const float* __restrict__ Bs =
w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
// do unpacking for the first row or a new expert
int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* brg */ use_brgemm,
/* block_size_K */ block_size_K,
/* do_unpack */ do_unpack);
// 2.b copy from C to ic2 in original order
// and also mul topk_weights in float32
for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m];
float weight = topk_weights[index];
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
}
});
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
// stage 3: out = intermediate_cache2.sum(dim=1)
// from [M, topk, K] to [M, K]
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
}
});
}
// Explicit template instantiations of fused_experts_fp8_kernel_impl for the
// reduced-precision activation types supported by sgl-kernel (bf16 / fp16).
// NOTE: no comments inside the macro body — a `//` comment would swallow the
// line-continuation backslash.
#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE)             \
  template void fused_experts_fp8_kernel_impl<TYPE>(   \
      TYPE* __restrict__ output,                       \
      TYPE* __restrict__ ic0,                          \
      TYPE* __restrict__ ic1,                          \
      TYPE* __restrict__ ic2,                          \
      TYPE* __restrict__ A_tmp,                        \
      TYPE* __restrict__ B_tmp,                        \
      float* __restrict__ C_tmp,                       \
      const TYPE* __restrict__ input,                  \
      const at::Float8_e4m3fn* __restrict__ packed_w1, \
      const at::Float8_e4m3fn* __restrict__ packed_w2, \
      const float* __restrict__ w1s,                   \
      const float* __restrict__ w2s,                   \
      int64_t block_size_N,                            \
      int64_t block_size_K,                            \
      const float* __restrict__ topk_weights,          \
      const int32_t* __restrict__ sorted_ids,          \
      const int32_t* __restrict__ expert_ids,          \
      const int32_t* __restrict__ offsets,             \
      int64_t M,                                       \
      int64_t N,                                       \
      int64_t K,                                       \
      int64_t E,                                       \
      int64_t topk,                                    \
      int64_t num_tokens_post_pad)

INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16);
INSTANTIATE_MOE_FP8_TEMPLATE(at::Half);
// Shared-expert FFN with fp8 (e4m3) block-quantized weights:
//
//   stage 1  : ic0 = input @ w1                      -> [M, 2N]
//   stage 1.5: ic1 = silu(ic0[:, :N]) * ic0[:, N:]   -> [M, N]
//   stage 2  : output = ic1 @ w2 * routed_scaling_factor + fused_experts_out
//
// packed_w1/packed_w2 are prepacked (vnni) fp8 weights; w1s/w2s hold the
// dequant scales at [block_size_N, block_size_K] granularity. B_tmp and
// C_tmp are caller-allocated per-thread scratch buffers.
template <typename scalar_t>
void shared_expert_fp8_kernel_impl(
    scalar_t* __restrict__ output,
    scalar_t* __restrict__ ic0,
    scalar_t* __restrict__ ic1,
    scalar_t* __restrict__ B_tmp,
    float* __restrict__ C_tmp,
    const scalar_t* __restrict__ input,
    const at::Float8_e4m3fn* __restrict__ packed_w1,
    const at::Float8_e4m3fn* __restrict__ packed_w2,
    const float* __restrict__ w1s,
    const float* __restrict__ w2s,
    int64_t block_size_N,
    int64_t block_size_K,
    const scalar_t* __restrict__ fused_experts_out,
    float routed_scaling_factor,
    int64_t M,
    int64_t N,
    int64_t K) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();

  // stage 1: intermediate_cache0 = hidden_states @ w1
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(2 * N, BLOCK_N);
  int64_t scale_size_K = div_up(K, block_size_K);
  int64_t blocks_n_per_group = block_size_N / BLOCK_N;
  const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
  // scratch sized for the larger of the two gemms (K for stage 1, N for stage 2)
  int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);

  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
    int tid = get_thread_num();
    loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
      int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
      int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
      // unpack fp8 -> scalar_t only for the first row block of each column
      bool do_unpack = (mb == mb0);
      tinygemm_kernel<scalar_t>(
          /* A */ input + mb * BLOCK_M * K,
          /* B */ packed_w1 + nb * BLOCK_N * K,
          /* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
          /* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
          /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
          /* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
          /* M */ m_size,
          /* N */ n_size,
          /* K */ K,
          /* lda */ K,
          /* ldb */ n_size,
          /* ldc */ 2 * N,
          /* brg */ use_brgemm,
          /* block_size_K */ block_size_K,
          /* do_unpack */ do_unpack);
    });
    if (use_brgemm) {
      // brgemm (AMX tile) state is per-thread: release inside the parallel region
      at::native::cpublas::brgemm_release();
    }
  });

  // stage 1.5: intermediate_cache1 = silu(ic0[:, :N]) * ic0[:, N:]
  at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
    for (int64_t m = begin; m < end; ++m) {
      silu_and_mul_stub(ic1 + m * N, ic0 + m * 2 * N, ic0 + m * 2 * N + N, N);
    }
  });

  // stage 2: intermediate_cache2 = intermediate_cache1 @ w2
  // w2 : [K, N] as [OC, IC]
  const int64_t OC = K;  // rename K as OC
  const int64_t IC = N;  // rename N as IC
  const int64_t MB2 = MB;
  const int64_t NB2 = div_up(K, BLOCK_N);
  scale_size_K = div_up(N, block_size_K);

  // parallel on [MB2, NB2]
  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
    int tid = get_thread_num();
    // local output tile; sized BLOCK_M * BLOCK_K (>= BLOCK_M * BLOCK_N) as in
    // fused_experts_fp8_kernel_impl
    alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
    loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
      int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
      int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
      // do unpacking for the first row
      bool do_unpack = (mb == mb0);
      // 2.a gemm: C = A @ B
      tinygemm_kernel<scalar_t>(
          /* A */ ic1 + mb * BLOCK_M * N,
          /* B */ packed_w2 + nb * BLOCK_N * N,
          /* C */ C,
          /* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
          /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
          /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
          /* M */ m_size,
          /* N */ n_size,
          /* K */ IC,
          /* lda */ IC,
          /* ldb */ n_size,
          /* ldc */ BLOCK_N,
          /* brg */ use_brgemm,
          /* block_size_K */ block_size_K,
          /* do_unpack */ do_unpack);
      // 2.b copy from C to output and add fused_experts_out * routed_scaling_factor
      scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
      const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
      for (int64_t m = 0; m < m_size; ++m) {
        add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
      }
    });
    // BUGFIX: release brgemm state on every worker thread, not just the
    // calling thread. Previously this ran once after parallel_2d returned,
    // leaving worker threads' thread-local brgemm/AMX state configured —
    // inconsistent with stage 1 above and with fused_experts_fp8_kernel_impl.
    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });
}
// Explicit template instantiations of shared_expert_fp8_kernel_impl for the
// reduced-precision activation types supported by sgl-kernel (bf16 / fp16).
#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE)   \
  template void shared_expert_fp8_kernel_impl<TYPE>(   \
      TYPE* __restrict__ output,                       \
      TYPE* __restrict__ ic0,                          \
      TYPE* __restrict__ ic1,                          \
      TYPE* __restrict__ B_tmp,                        \
      float* __restrict__ C_tmp,                       \
      const TYPE* __restrict__ input,                  \
      const at::Float8_e4m3fn* __restrict__ packed_w1, \
      const at::Float8_e4m3fn* __restrict__ packed_w2, \
      const float* __restrict__ w1s,                   \
      const float* __restrict__ w2s,                   \
      int64_t block_size_N,                            \
      int64_t block_size_K,                            \
      const TYPE* __restrict__ fused_experts_out,      \
      float routed_scaling_factor,                     \
      int64_t M,                                       \
      int64_t N,                                       \
      int64_t K)

INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16);
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half);

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,304 @@
#include "common.h"
#include "vec.h"
namespace {
// NB: avoid using `at::vec::map<>` on bfloat16 or half
// Llama4TextL2Norm
//
// L2-normalize each row: out[i, :] = x[i, :] / sqrt(mean(x[i, :]^2) + eps).
// Two passes per row: (1) accumulate sum of squares, (2) scale by the
// reciprocal RMS. The vectorized loops cover hidden_size rounded down to a
// multiple of the bf16/fp16 vector width; scalar tail loops handle the rest.
template <typename scalar_t>
void l2norm_kernel_impl(
    scalar_t* __restrict__ output,       // [batch_size, hidden_size]
    const scalar_t* __restrict__ input,  // [batch_size, hidden_size], contiguous
    int64_t batch_size,
    int64_t hidden_size,
    float eps = 1e-5) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs
      scalar_t* __restrict__ out_ptr = output + i * hidden_size;
      const scalar_t* __restrict__ input_ptr = input + i * hidden_size;
      // pass 1: sum of squares (vector accumulator + scalar tail accumulator)
      fVec sum_fvec = fVec(float(0));
      float sum_val = float(0);
      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
        sum_fvec += x_fvec0 * x_fvec0;
        sum_fvec += x_fvec1 * x_fvec1;
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        sum_val += x_val * x_val;
      }
      sum_val += vec_reduce_sum(sum_fvec);
      // reciprocal RMS of the row
      float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
      const fVec scale_fvec = fVec(rsqrt_var);
      // pass 2: scale and store
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
        x_fvec0 = x_fvec0 * scale_fvec;
        x_fvec1 = x_fvec1 * scale_fvec;
        bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        out_bvec.store(out_ptr + d);
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var);
      }
    }
  });
}
// RMSNorm: out[i, :] = x[i, :] * rsqrt(mean(x[i, :]^2) + eps) * weight.
// `input` rows are strided by input_strideN (last dim contiguous); `output`
// rows are dense. Same two-pass vector + scalar-tail structure as
// l2norm_kernel_impl, with a learned per-channel weight applied in pass 2.
template <typename scalar_t>
void rmsnorm_kernel_impl(
    scalar_t* __restrict__ output,       // [batch_size, hidden_size], dense
    const scalar_t* __restrict__ input,  // [batch_size, *], row stride = input_strideN
    const scalar_t* __restrict__ weight, // [hidden_size]
    int64_t batch_size,
    int64_t hidden_size,
    int64_t input_strideN,
    float eps = 1e-5) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs
      scalar_t* __restrict__ out_ptr = output + i * hidden_size;
      const scalar_t* __restrict__ input_ptr = input + i * input_strideN;
      // pass 1: sum of squares
      fVec sum_fvec = fVec(float(0));
      float sum_val = float(0);
      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
        sum_fvec += x_fvec0 * x_fvec0;
        sum_fvec += x_fvec1 * x_fvec1;
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        sum_val += x_val * x_val;
      }
      sum_val += vec_reduce_sum(sum_fvec);
      // reciprocal RMS of the row
      float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
      const fVec scale_fvec = fVec(rsqrt_var);
      // pass 2: scale by rsqrt_var and the learned weight, then store
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
        bVec w_bvec = bVec::loadu(weight + d);
        fVec w_fvec0, w_fvec1;
        std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
        x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
        x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
        bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        out_bvec.store(out_ptr + d);
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        float w_val = static_cast<float>(weight[d]);
        out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var * w_val);
      }
    }
  });
}
// Fused residual-add + RMSNorm, updated in place:
//   x           = input + residual        (in fp32)
//   residual    = x                       (rounded back to scalar_t)
//   input       = rmsnorm(x) * weight
// `buffer` holds one fp32 copy of x per thread (num_threads * hidden_size)
// so pass 2 can rescale without re-reading and re-adding the low-precision
// inputs; this keeps the normalization math in fp32 end-to-end.
template <typename scalar_t>
void fused_add_rmsnorm_kernel_impl(
    scalar_t* __restrict__ input,     // [batch_size, *], row stride = input_strideN; overwritten with the normalized result
    scalar_t* __restrict__ residual,  // [batch_size, hidden_size]; overwritten with input + residual
    const scalar_t* __restrict__ weight,  // [hidden_size]
    float* __restrict__ buffer,           // [num_threads, hidden_size] fp32 scratch
    int64_t batch_size,
    int64_t hidden_size,
    int64_t input_strideN,
    float eps = 1e-5) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
    // per-thread fp32 scratch row
    int tid = at::get_thread_num();
    float* __restrict__ buffer_ptr = buffer + tid * hidden_size;
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs
      scalar_t* __restrict__ input_ptr = input + i * input_strideN;
      scalar_t* __restrict__ residual_ptr = residual + i * hidden_size;
      fVec sum_fvec = fVec(float(0));
      float sum_val = float(0);
      int64_t d;
      // pass 1: x = input + residual; store x to residual (scalar_t) and to
      // the fp32 buffer; accumulate sum of x^2
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
        bVec r_bvec = bVec::loadu(residual_ptr + d);
        fVec r_fvec0, r_fvec1;
        std::tie(r_fvec0, r_fvec1) = at::vec::convert_to_float(r_bvec);
        x_fvec0 += r_fvec0;
        x_fvec1 += r_fvec1;
        bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        out_bvec.store(residual_ptr + d);
        sum_fvec += x_fvec0 * x_fvec0;
        sum_fvec += x_fvec1 * x_fvec1;
        x_fvec0.store(buffer_ptr + d);
        x_fvec1.store(buffer_ptr + d + fVec::size());
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        float r_val = static_cast<float>(residual_ptr[d]);
        x_val += r_val;
        residual_ptr[d] = static_cast<scalar_t>(x_val);
        sum_val += x_val * x_val;
        buffer_ptr[d] = x_val;
      }
      sum_val += vec_reduce_sum(sum_fvec);
      // reciprocal RMS of x
      float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
      const fVec scale_fvec = fVec(rsqrt_var);
      // pass 2: input = x * rsqrt_var * weight, read from the fp32 buffer
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        fVec x_fvec0 = fVec::loadu(buffer_ptr + d);
        fVec x_fvec1 = fVec::loadu(buffer_ptr + d + fVec::size());
        bVec w_bvec = bVec::loadu(weight + d);
        fVec w_fvec0, w_fvec1;
        std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
        x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
        x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
        bVec x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        x_bvec.store(input_ptr + d);
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = buffer_ptr[d] * rsqrt_var * static_cast<float>(weight[d]);
        // implicit float -> scalar_t conversion on assignment
        input_ptr[d] = x_val;
      }
    }
  });
}
} // anonymous namespace
// L2-normalize every row of a 2D tensor.
// input : {batch_size, hidden_size}, contiguous bf16/fp16.
// Returns a new tensor of the same shape/dtype.
at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
  RECORD_FUNCTION("sgl-kernel::l2norm_cpu", std::vector<c10::IValue>({input}));

  CHECK_INPUT(input);
  CHECK_DIM(2, input);

  const int64_t num_rows = input.size(0);
  const int64_t row_size = input.size(1);
  auto output = at::empty_like(input);

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "l2norm_kernel", [&] {
    l2norm_kernel_impl<scalar_t>(
        output.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), num_rows, row_size, eps);
  });
  return output;
}
// RMSNorm over the last dimension.
// input : {batch_size, hidden_size} (last dim contiguous, rows may be strided)
// weight: {hidden_size}
// Returns a new dense tensor of the same shape/dtype.
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
  RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));

  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
  CHECK_INPUT(weight);
  CHECK_DIM(2, input);
  CHECK_DIM(1, weight);
  CHECK_EQ(input.size(1), weight.size(0));

  const int64_t num_rows = input.size(0);
  const int64_t row_size = input.size(1);
  const int64_t row_stride = input.stride(0);
  auto output = at::empty_like(input);

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] {
    rmsnorm_kernel_impl<scalar_t>(
        output.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        num_rows,
        row_size,
        row_stride,
        eps);
  });
  return output;
}
// In-place fused residual-add + RMSNorm:
//   residual += input;  input = rmsnorm(residual) * weight
// input   : {batch_size, hidden_size} (last dim contiguous)
// residual: {batch_size, hidden_size}
// weight  : {hidden_size}
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) {
  RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight}));

  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
  CHECK_INPUT(residual);
  CHECK_INPUT(weight);
  CHECK_DIM(2, input);
  CHECK_DIM(2, residual);
  CHECK_DIM(1, weight);
  CHECK_EQ(input.size(0), residual.size(0));
  CHECK_EQ(input.size(1), residual.size(1));
  CHECK_EQ(input.size(1), weight.size(0));

  const int64_t num_rows = input.size(0);
  const int64_t row_size = input.size(1);
  const int64_t row_stride = input.stride(0);

  // fp32 scratch: one row per thread so pass 2 of the kernel can re-read
  // the summed activations at full precision.
  // TODO: implement a singleton for context
  const int64_t num_threads = at::get_num_threads();
  auto buffer = at::empty({num_threads, row_size}, input.options().dtype(at::kFloat));

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fused_add_rmsnorm_kernel", [&] {
    fused_add_rmsnorm_kernel_impl<scalar_t>(
        input.data_ptr<scalar_t>(),
        residual.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        buffer.data_ptr<float>(),
        num_rows,
        row_size,
        row_stride,
        eps);
  });
}

View File

@@ -0,0 +1,91 @@
#include <numa.h>
#include <sched.h>
#include <sys/syscall.h>
#include <sys/types.h>
#include <unistd.h>
#include <string>
#include "common.h"
// Pin this process's OpenMP threads to the CPUs described by `cpu_ids`
// (a libnuma cpustring such as "0-7,16"), migrate existing pages to and
// restrict future allocations to the NUMA node of the first CPU, and
// return a human-readable "OMP tid -> core" report.
std::string init_cpu_threads_env(const std::string& cpu_ids) {
  bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
  // BUGFIX: numa_parse_cpustring returns nullptr on a malformed string;
  // check before dereferencing.
  TORCH_CHECK(omp_cpu_mask != nullptr, "numa_parse_cpustring failed to parse: " + cpu_ids);
  TORCH_CHECK(omp_cpu_mask->size > 0);

  // expand the bitmask into a flat list of CPU ids, one word at a time
  std::vector<int> omp_cpu_ids;
  omp_cpu_ids.reserve(omp_cpu_mask->size);
  constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp);
  for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) {
    unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size];
    int i = 0;
    while (group_mask) {
      if (group_mask & 1) {
        omp_cpu_ids.emplace_back(offset + i);
      }
      ++i;
      group_mask >>= 1;
    }
  }

  // Memory node binding
  if (numa_available() != -1) {
    int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
    bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
    bitmask* src_mask = numa_get_membind();

    int pid = getpid();

    // move all existing pages to the specified numa node.
    // NOTE(review): only the first word of the node mask is flipped here,
    // which assumes <= 64 NUMA nodes — confirm if larger systems matter.
    *(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
    int page_num = numa_migrate_pages(pid, src_mask, mask);
    if (page_num == -1) {
      TORCH_WARN(false, "numa_migrate_pages failed. errno: " + std::to_string(errno));
    }

    // restrict memory allocation node.
    numa_set_membind(mask);
    numa_set_strict(1);

    // BUGFIX: free the node masks allocated by libnuma (previously leaked);
    // numa_set_membind copies the mask, so freeing afterwards is safe.
    numa_free_nodemask(src_mask);
    numa_free_nodemask(mask);
  }

  // OMP threads binding: one OMP thread per selected CPU
  omp_set_num_threads((int)omp_cpu_ids.size());
  at::set_num_threads((int)omp_cpu_ids.size());
  TORCH_CHECK_EQ(omp_cpu_ids.size(), at::get_num_threads());
  TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());

  std::vector<std::pair<int, int>> thread_core_mapping;
  thread_core_mapping.reserve(omp_cpu_ids.size());
  omp_lock_t writelock;
  omp_init_lock(&writelock);

  // schedule(static, 1) maps OMP thread i to omp_cpu_ids[i]
#pragma omp parallel for schedule(static, 1)
  for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
    cpu_set_t mask;
    CPU_ZERO(&mask);
    CPU_SET(omp_cpu_ids[i], &mask);
    int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask);
    if (ret == -1) {
      TORCH_CHECK(false, "sched_setaffinity failed. errno: " + std::to_string(errno));
    }
    // record (kernel tid, core) under a lock; vector is shared across threads
    omp_set_lock(&writelock);
    thread_core_mapping.emplace_back(syscall(SYS_gettid), omp_cpu_ids[i]);
    omp_unset_lock(&writelock);
  }

  omp_destroy_lock(&writelock);
  numa_free_nodemask(omp_cpu_mask);

  // build the report, sorted by core id
  std::stringstream ss;
  ss << "OMP threads binding of Process " << getpid() << ":\n";
  std::sort(
      thread_core_mapping.begin(), thread_core_mapping.end(), [](auto&& a, auto&& b) { return a.second < b.second; });
  for (auto&& item : thread_core_mapping) {
    ss << "\t"
       << "OMP tid: " << item.first << ", core " << item.second << "\n";
  }
  return ss.str();
}

View File

@@ -0,0 +1,701 @@
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
// [NOTE]: Fused kernel for QKV projection with weight absorption and RoPE
//
//   1. `q_a_proj` and `kv_a_proj_with_mqa` fused into one gemm,
//      otherwise we need to split IC for the 2nd gemm.
//   2. `q_a_layernorm` and `kv_a_layernorm` fused into one parallel loop.
//   3. k_input and v_input share the same storage, the torch API did
//      this in `set_kv_buffer`. No additional memory movement.
//
// [C0, C1] = A @ [B0, B1]
//
// Segmented GEMM (bf16/fp16 weights): one [M, K] activation multiplied
// against two concatenated weight matrices B0 [N0, K] and B1 [N1, K],
// writing into two separate outputs C0 [M, N0] and C1 [M, N1]. The column
// blocks of both weights are scheduled in a single [MB, NB0 + NB1] grid.
template <typename scalar_t>
void segment_gemm_kernel_impl(
    scalar_t* __restrict__ C0,
    scalar_t* __restrict__ C1,
    const scalar_t* __restrict__ A,
    const scalar_t* __restrict__ B0,
    const scalar_t* __restrict__ B1,
    int64_t M,
    int64_t N0,
    int64_t N1,
    int64_t K) {
  // convert_weight_packed make sure N0 and N1 are 32x
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB0 = div_up(N0, BLOCK_N);
  const int64_t NB1 = div_up(N1, BLOCK_N);
  const int64_t NB = NB0 + NB1;
  const bool use_brgemm = can_use_brgemm<scalar_t>(M);
  // parallel on [MB, NB0 + NB1]
  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
    int64_t mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);
    // for brgemm, use float32 for accumulate
    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
    for (int64_t i = begin; i < end; ++i) {
      UNUSED(i);
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(M - mb_start, BLOCK_M);
      int nb_start = nb * BLOCK_N;
      int nb_size = BLOCK_N;
      // column blocks [0, NB0) target B0/C0; [NB0, NB) target B1/C1
      const scalar_t* __restrict__ B = nb < NB0 ? B0 : B1;
      scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
      int64_t ldc = nb < NB0 ? N0 : N1;
      // rebase the column offset into the selected segment
      int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
      tinygemm_kernel<scalar_t>(
          /* A */ A + mb_start * K,
          /* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
          /* C */ C + mb_start * ldc + local_nb_start,
          /* Ctmp*/ Ctmp,
          /* M */ mb_size,
          /* N */ nb_size,
          /* K */ K,
          /* lda */ K,
          /* ldb */ nb_size,
          /* ldc */ ldc,
          /* brg */ use_brgemm);
      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }
    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });
}
// [C0, C1] = A @ [B0, B1]
//
// int8 w8a8 overload: A is the uint8-quantized activation with per-row fp32
// scales `As`; B0/B1 are int8 weights (packed with compensation, so each row
// occupies K + 4 bytes) with per-column fp32 scales Bs0/Bs1. Same segmented
// [MB, NB0 + NB1] scheduling as the bf16 overload above.
template <typename scalar_t>
void segment_gemm_kernel_impl(
    scalar_t* __restrict__ C0,
    scalar_t* __restrict__ C1,
    const uint8_t* __restrict__ A,
    const int8_t* __restrict__ B0,
    const int8_t* __restrict__ B1,
    const float* __restrict__ As,
    const float* __restrict__ Bs0,
    const float* __restrict__ Bs1,
    int64_t M,
    int64_t N0,
    int64_t N1,
    int64_t K) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB0 = div_up(N0, BLOCK_N);
  const int64_t NB1 = div_up(N1, BLOCK_N);
  const int64_t NB = NB0 + NB1;
  const bool use_brgemm = can_use_brgemm<int8_t>(M);
  // K + 4 after compensation
  const int64_t packed_row_size = get_row_size<int8_t>(K);
  // parallel on [MB, NB0 + NB1]
  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
    int64_t mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);
    // for brgemm, use int32 for accumulate
    alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
    for (int64_t i = begin; i < end; ++i) {
      UNUSED(i);
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(M - mb_start, BLOCK_M);
      int nb_start = nb * BLOCK_N;
      int nb_size = BLOCK_N;
      // column blocks [0, NB0) target B0/C0; [NB0, NB) target B1/C1
      const int8_t* __restrict__ B = nb < NB0 ? B0 : B1;
      const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
      scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
      int64_t ldc = nb < NB0 ? N0 : N1;
      int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
      tinygemm_kernel<scalar_t>(
          /* A */ A + mb_start * K,
          /* B */ B + local_nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
          /* C */ C + mb_start * ldc + local_nb_start,
          /* Ctmp*/ Ctmp,
          /* As */ As + mb_start,
          /* Bs */ Bs + local_nb_start,
          /* M */ mb_size,
          /* N */ nb_size,
          /* K */ K,
          /* lda */ K,
          /* ldb */ nb_size,
          /* ldc */ ldc,
          /* brg */ use_brgemm);
      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }
    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });
}
// [C0, C1] = A @ [B0, B1]
//
// fp8 (e4m3) overload: B0/B1 are fp8 weights with block-wise scales
// Bs0/Bs1 of granularity [block_size_N, block_size_K]; `Btmp` is a
// per-thread scratch (>= num_threads * BLOCK_N * K) used to unpack fp8
// weight tiles. Same segmented [MB, NB0 + NB1] scheduling as above.
template <typename scalar_t>
void segment_gemm_kernel_impl(
    scalar_t* __restrict__ C0,
    scalar_t* __restrict__ C1,
    const scalar_t* __restrict__ A,
    const at::Float8_e4m3fn* __restrict__ B0,
    const at::Float8_e4m3fn* __restrict__ B1,
    const float* __restrict__ Bs0,
    const float* __restrict__ Bs1,
    scalar_t* __restrict__ Btmp,
    int64_t M,
    int64_t N0,
    int64_t N1,
    int64_t K,
    int64_t block_size_N,
    int64_t block_size_K) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB0 = div_up(N0, BLOCK_N);
  const int64_t NB1 = div_up(N1, BLOCK_N);
  const int64_t NB = NB0 + NB1;
  // number of scale blocks along K, and BLOCK_N tiles per scale group along N
  const int64_t scale_size_K = div_up(K, block_size_K);
  const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
  const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
  // parallel on [MB, NB0 + NB1]
  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
    int64_t mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);
    int tid = at::get_thread_num();
    // for brgemm, use float32 for accumulate
    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
    for (int64_t i = begin; i < end; ++i) {
      UNUSED(i);
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(M - mb_start, BLOCK_M);
      int nb_start = nb * BLOCK_N;
      int nb_size = BLOCK_N;
      // column blocks [0, NB0) target B0/C0; [NB0, NB) target B1/C1
      const at::Float8_e4m3fn* __restrict__ B = nb < NB0 ? B0 : B1;
      const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
      scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
      int64_t ldc = nb < NB0 ? N0 : N1;
      int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
      // block index rebased into the selected segment, for scale lookup
      int64_t new_nb = nb < NB0 ? nb : nb - NB0;
      tinygemm_kernel<scalar_t>(
          /* A */ A + mb_start * K,
          /* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
          /* C */ C + mb_start * ldc + local_nb_start,
          /* Btmp*/ Btmp + tid * BLOCK_N * K,
          /* Ctmp*/ Ctmp,
          /* Bs */ Bs + (new_nb / blocks_n_per_group) * scale_size_K,
          /* M */ mb_size,
          /* N */ nb_size,
          /* K */ K,
          /* lda */ K,
          /* ldb */ nb_size,
          /* ldc */ ldc,
          /* brg */ use_brgemm,
          /* block_size_K */ block_size_K);
      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }
    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });
}
template <typename scalar_t>
inline float reduce(const scalar_t* __restrict__ x, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
fVec sum_fvec = fVec(float(0));
// no remainder
#pragma GCC unroll 4
for (int64_t d = 0; d < size; d += bVec::size()) {
bVec x_bvec = bVec::loadu(x + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
sum_fvec += x_fvec0 * x_fvec0;
sum_fvec += x_fvec1 * x_fvec1;
}
return vec_reduce_sum(sum_fvec);
}
// map2 from aten functional doesn't have fast bf16->fp32 conversion
//
// Elementwise y[d] = x[d] * scale * w[d] in fp32, rounded back to scalar_t.
// Caller guarantees `size` is a multiple of the vector width (no tail loop).
template <typename scalar_t>
inline void map2(scalar_t* y, const scalar_t* x, const scalar_t* __restrict__ w, float scale, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  fVec scale_fvec = fVec(scale);
  // no remainder
#pragma GCC unroll 4
  for (int64_t d = 0; d < size; d += bVec::size()) {
    bVec x_bvec = bVec::loadu(x + d);
    fVec x_fvec0, x_fvec1;
    std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
    bVec w_bvec = bVec::loadu(w + d);
    fVec w_fvec0, w_fvec1;
    std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
    x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
    x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
    bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
    out_bvec.store(y + d);
  }
}
// In-place RMSNorm over two tensors in one parallel loop (fuses
// q_a_layernorm and kv_a_layernorm): for each of the M rows,
//   input0[m, :] = rmsnorm(input0[m, :]) * weight0    (N0 elems, dense rows)
//   input1[m, :] = rmsnorm(input1[m, :]) * weight1    (N1 elems, rows strided by stride1)
template <typename scalar_t>
void rms_norm_kernel_impl(
    scalar_t* __restrict__ input0,
    scalar_t* __restrict__ input1,
    const scalar_t* __restrict__ weight0,
    const scalar_t* __restrict__ weight1,
    int64_t M,
    int64_t N0,
    int64_t N1,
    int64_t stride1,
    float eps = 1e-5) {
  at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
    for (int64_t row = begin; row < end; ++row) {
      scalar_t* p0 = input0 + row * N0;
      scalar_t* p1 = input1 + row * stride1;
      // sum of squares of each segment
      float ss0 = reduce(p0, N0);
      float ss1 = reduce(p1, N1);
      // reciprocal RMS
      const float r0 = float(1) / std::sqrt(ss0 / N0 + eps);
      const float r1 = float(1) / std::sqrt(ss1 / N1 + eps);
      // scale in place by rsqrt * weight
      map2(p0, p0, weight0, r0, N0);
      map2(p1, p1, weight1, r1, N1);
    }
  });
}
// Generic (non-AVX512) rotary fallback: intentionally unimplemented — only
// the at::BFloat16 AVX512 specialization below is provided.
template <typename scalar_t>
inline void rotary(const scalar_t* input, scalar_t* out, const scalar_t* cos, const scalar_t* sin, int64_t size) {
  TORCH_CHECK(false, "rotary scalar path not implemented.");
}
#if defined(CPU_CAPABILITY_AVX512)
// AVX512 rotary embedding for bf16, interleaved (GPT-NeoX-style pairs):
// for each consecutive pair (x1, x2) of the input,
//   out1 = x1 * cos - x2 * sin
//   out2 = x2 * cos + x1 * sin
// idx1/idx2 split 32 interleaved bf16 lanes into the even/odd halves;
// idy1/idy2 re-interleave the results on the way out.
template <>
inline void rotary<at::BFloat16>(
    const at::BFloat16* input, at::BFloat16* out, const at::BFloat16* cos, const at::BFloat16* sin, int64_t size) {
  // permute indices
  const __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
  const __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1);
  const __m512i idy1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
  const __m512i idy2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8);
  // rotary dim is 64, just 2 iters
#pragma GCC unroll 2
  for (int64_t d = 0; d < size; d += 32) {
    // cos/sin tables hold one coefficient per pair, hence the half index
    int64_t d2 = d >> 1;
    // load coefs
    __m512 vcos = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(cos + d2)));
    __m512 vsin = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(sin + d2)));
    // load input (32 bf16 values), widen to two fp32 vectors
    __m512i a16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(input + d));
    __m512 a = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
    __m512 b = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
    // from [16, 2] to [2, 16]
    __m512 in1 = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
    __m512 in2 = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
    // out1 = in1 * cos - in2 * sin;
    // out2 = in2 * cos + in1 * sin
    __m512 out1 = _mm512_sub_ps(_mm512_mul_ps(in1, vcos), _mm512_mul_ps(in2, vsin));
    __m512 out2 = _mm512_add_ps(_mm512_mul_ps(in2, vcos), _mm512_mul_ps(in1, vsin));
    // from [2, 16] to [16, 2]
    a = _mm512_mask_permutex2var_ps(out1, 0xffff, idy1, out2);
    b = _mm512_mask_permutex2var_ps(out1, 0xffff, idy2, out2);
    // narrow both halves back to bf16 and store 32 values
    _mm512_storeu_si512(reinterpret_cast<__m512i*>((out + d)), (__m512i)(_mm512_cvtne2ps_pbh(b, a)));
  }
}
#endif
// Apply rotary embedding to q_pe [num_seqs, num_heads, rotary_dim] and
// k_pe [num_seqs, rotary_dim] in one pass. cos_sin is the position cache
// laid out as [max_pos, rotary_dim] with cos in the first half of each row
// and sin in the second half; `pos` gives the position index per sequence.
template <typename scalar_t>
void rotary_emb_kernel_impl(
    scalar_t* q_pe_out,
    scalar_t* k_pe_out,
    const scalar_t* q_pe,
    const scalar_t* k_pe,
    const int64_t* pos,
    const scalar_t* cos_sin,
    int64_t num_seqs,
    int64_t num_heads,
    int64_t rotary_dim,
    int64_t q_strideB,
    int64_t q_strideH,
    int64_t k_strideB,
    int64_t oq_strideB,
    int64_t oq_strideH,
    int64_t ok_strideB) {
  TORCH_CHECK(rotary_dim % 32 == 0, "rotary_dim is not 32x.");
  // sin values start halfway through each cos_sin row
  const int64_t rotary_offset = rotary_dim / 2;
  // parallel on [num_seqs, num_heads + 1]
  // top [num_heads] handle q_pe and bottom [1] handle k_pe
  at::parallel_for(0, num_seqs * (num_heads + 1), GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
    int64_t seq{0}, head_id{0};
    data_index_init(begin, seq, num_seqs, head_id, num_heads + 1);
    for (int64_t i = begin; i < end; ++i) {
      UNUSED(i);
      // get cos and sin cache ptr
      int64_t index = pos[seq];
      const scalar_t* cos = cos_sin + index * rotary_dim;
      const scalar_t* sin = cos + rotary_offset;
      // head_id == num_heads selects the single k_pe row for this sequence
      const scalar_t* input =
          (head_id < num_heads) ? q_pe + seq * q_strideB + head_id * q_strideH : k_pe + seq * k_strideB;
      scalar_t* out =
          (head_id < num_heads) ? q_pe_out + seq * oq_strideB + head_id * oq_strideH : k_pe_out + seq * ok_strideB;
      rotary<scalar_t>(input, out, cos, sin, rotary_dim);
      // move to the next index
      data_index_step(seq, num_seqs, head_id, num_heads + 1);
    }
  });
}
} // anonymous namespace
// Forward declarations of GEMM entry points defined in the sibling gemm
// translation units (bf16, int8 w8a8, batched, and fp8 block-quantized).
extern at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);

extern at::Tensor int8_scaled_mm_with_quant(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);

extern void
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);

extern at::Tensor fp8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    std::vector<int64_t> block_size,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);
// NB: shapes in DeepDeek R1
//
// hidden_states : [num_seqs, hidden_size] [1, 7168]
// q_a_proj_weight : [q_lora_rank, hidden_size] [1536, 7168]
// q_b_proj_weight : [num_heads * qk_head_dim, q_lora_rank] [4224, 1536]
// kv_a_proj_weight : [kv_lora_rank + qk_rope_head_dim, hidden_size] [576, 7168]
// w_kc : [num_heads, kv_lora_rank, qk_nope_head_dim] [22, 512, 128]
// q_a_layernorm_weight : [q_lora_rank] [1536]
// kv_a_layernorm_weight : [kv_lora_rank] [512]
//
// MLA (multi-head latent attention) QKV projection fused with RoPE, CPU path.
//
// Pipeline (see the shape table above for DeepSeek R1 sizes):
//   stage 1: q_a_proj and kv_a_proj over hidden_states in one segmented GEMM
//            (plain bf16/fp16, int8 w8a8, or fp8 w8a16 weight formats)
//   stage 2: RMSNorm applied in place to the q_a output and the kv latent
//   stage 3: q_b_proj
//   stage 4: bmm with w_kc to absorb the "nope" part of q into the kv latent space
//   stage 5: rotary embedding on the rope tails of q and k
//
// Returns (q_input [num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim],
//          k_input [num_seqs, 1, kv_lora_rank + qk_rope_head_dim],
//          v_input = first kv_lora_rank lanes of k_input, sharing its storage).
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
    at::Tensor& hidden_states,
    at::Tensor& q_a_proj_weight,
    at::Tensor& q_b_proj_weight,
    at::Tensor& kv_a_proj_weight,
    at::Tensor& w_kc,
    at::Tensor& q_a_layernorm_weight,
    at::Tensor& kv_a_layernorm_weight,
    at::Tensor& positions,
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    std::optional<at::Tensor> q_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    std::optional<at::Tensor> kv_a_proj_scale,
    bool is_vnni,
    std::optional<std::vector<int64_t>> block_size) {
  RECORD_FUNCTION(
      "sgl-kernel::qkv_proj_with_rope",
      std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc}));
  const auto st = hidden_states.scalar_type();
  CHECK_INPUT(hidden_states);
  CHECK_INPUT(positions);
  CHECK_INPUT(cos_sin_cache);
  CHECK_EQ(q_a_layernorm_weight.scalar_type(), st);
  CHECK_EQ(kv_a_layernorm_weight.scalar_type(), st);
  CHECK_EQ(positions.scalar_type(), at::kLong);
  CHECK_EQ(cos_sin_cache.scalar_type(), st);
  CHECK_DIM(2, hidden_states);
  CHECK_DIM(3, w_kc);
  CHECK_DIM(1, q_a_layernorm_weight);
  CHECK_DIM(1, kv_a_layernorm_weight);
  CHECK_DIM(1, positions);
  CHECK_DIM(2, cos_sin_cache);
  // skip contiguous checks for weights, expect prepacked
  TORCH_CHECK(is_vnni, "qkv_proj_with_rope: expect weights are prepacked!");
  // problem sizes are derived from the weight shapes (see table above)
  int64_t num_seqs = hidden_states.size(0);
  int64_t hidden_size = hidden_states.size(1);
  int64_t q_lora_rank = q_a_proj_weight.size(0);
  int64_t num_heads = w_kc.size(0);
  int64_t kv_lora_rank = w_kc.size(1);
  int64_t qk_head_dim = q_b_proj_weight.size(0) / num_heads;
  int64_t qk_nope_head_dim = w_kc.size(2);
  int64_t qk_rope_head_dim = kv_a_proj_weight.size(0) - kv_lora_rank;
  int64_t rotary_dim = cos_sin_cache.size(1);
  CHECK_EQ(positions.numel(), num_seqs);
  CHECK_EQ(rotary_dim, qk_rope_head_dim);
  CHECK_EQ(q_a_layernorm_weight.numel(), q_lora_rank);
  CHECK_EQ(kv_a_layernorm_weight.numel(), kv_lora_rank);
  // check the packed dimension
  CHECK_EQ(q_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
  CHECK_EQ(q_b_proj_weight.size(1), get_row_size(q_lora_rank, use_int8_w8a8));
  CHECK_EQ(kv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
  if (use_int8_w8a8) {
    TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for int8 w8a8.");
    TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8.");
    TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for int8 w8a8.");
  }
  if (use_fp8_w8a16) {
    TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for fp8 w8a16.");
    TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for fp8 w8a16.");
    TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for fp8 w8a16.");
    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
    TORCH_CHECK(block_size.value().size() == 2, "block_size should be 2D for fp8 w8a16.");
  }
  // outputs and temp buffer
  const auto options = hidden_states.options();
  // q_input: absorbed nope part (first kv_lora_rank lanes) + rope tail
  auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options);
  // k_input: kv latent + rope tail; v_input is a view of the latent part
  auto k_input = at::empty({num_seqs, 1, kv_lora_rank + qk_rope_head_dim}, options);
  auto v_input = k_input.narrow(-1, 0, kv_lora_rank);
  // outputs of q_a_proj and q_b_proj
  auto qa = at::empty({num_seqs, q_lora_rank}, options);
  // stage 1: q_a_proj and kv_a_proj
  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "qkv_proj_kernel_impl", [&] {
    if (use_int8_w8a8) {
      // per-output-channel weight scales
      auto q_a_proj_s = q_a_proj_scale.value();
      auto kv_a_proj_s = kv_a_proj_scale.value();
      TORCH_CHECK(q_a_proj_s.numel() == q_lora_rank);
      TORCH_CHECK(kv_a_proj_s.numel() == kv_lora_rank + qk_rope_head_dim);
      // Aq_data: int8-quantized activations [num_seqs, hidden_size];
      // As_data: per-row float activation scales appended right after
      // (num_seqs * 4 bytes)
      auto buffer = at::empty({num_seqs * hidden_size + num_seqs * 4}, options.dtype(at::kByte));
      uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
      float* __restrict__ As_data = (float*)((void*)(Aq_data + num_seqs * hidden_size));
      const scalar_t* __restrict__ A_data = hidden_states.data_ptr<scalar_t>();
      // dynamic per-row activation quantization
      at::parallel_for(0, num_seqs, 0, [&](int64_t begin, int64_t end) {
        for (int64_t m = begin; m < end; ++m) {
          quantize_row_int8<scalar_t>(Aq_data + m * hidden_size, As_data[m], A_data + m * hidden_size, hidden_size);
        }
      });
      segment_gemm_kernel_impl<scalar_t>(
          qa.data_ptr<scalar_t>(),
          k_input.data_ptr<scalar_t>(),
          Aq_data,
          q_a_proj_weight.data_ptr<int8_t>(),
          kv_a_proj_weight.data_ptr<int8_t>(),
          As_data,
          q_a_proj_s.data_ptr<float>(),
          kv_a_proj_s.data_ptr<float>(),
          num_seqs,
          q_lora_rank,
          kv_lora_rank + qk_rope_head_dim,
          hidden_size);
    } else if (use_fp8_w8a16) {
      int64_t block_size_N = block_size.value()[0];
      int64_t block_size_K = block_size.value()[1];
      auto q_a_proj_s = q_a_proj_scale.value();
      auto kv_a_proj_s = kv_a_proj_scale.value();
      // block-wise weight scales: one scale per [block_size_N, block_size_K] tile
      CHECK_EQ(q_a_proj_s.size(0), div_up(q_lora_rank, block_size_N));
      CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K));
      CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N));
      CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K));
      // per-thread scratch for dequantized weight tiles
      const int BLOCK_N = block_size_n();
      const int num_threads = at::get_num_threads();
      auto buffer = at::empty({num_threads, BLOCK_N * hidden_size}, options);
      segment_gemm_kernel_impl<scalar_t>(
          qa.data_ptr<scalar_t>(),
          k_input.data_ptr<scalar_t>(),
          hidden_states.data_ptr<scalar_t>(),
          q_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
          kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
          q_a_proj_s.data_ptr<float>(),
          kv_a_proj_s.data_ptr<float>(),
          buffer.data_ptr<scalar_t>(),
          num_seqs,
          q_lora_rank,
          kv_lora_rank + qk_rope_head_dim,
          hidden_size,
          block_size_N,
          block_size_K);
    } else {
      // unquantized path
      segment_gemm_kernel_impl<scalar_t>(
          qa.data_ptr<scalar_t>(),
          k_input.data_ptr<scalar_t>(),
          hidden_states.data_ptr<scalar_t>(),
          q_a_proj_weight.data_ptr<scalar_t>(),
          kv_a_proj_weight.data_ptr<scalar_t>(),
          num_seqs,
          q_lora_rank,
          kv_lora_rank + qk_rope_head_dim,
          hidden_size);
    }
  });
  // stage 2: apply rmsnorm inplace (qa with q_a_layernorm, kv latent with
  // kv_a_layernorm; the rope tail of k_input is left untouched)
  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rms_norm_kernel_impl", [&] {
    rms_norm_kernel_impl<scalar_t>(
        qa.data_ptr<scalar_t>(),
        v_input.data_ptr<scalar_t>(),
        q_a_layernorm_weight.data_ptr<scalar_t>(),
        kv_a_layernorm_weight.data_ptr<scalar_t>(),
        num_seqs,
        q_lora_rank,
        kv_lora_rank,
        kv_lora_rank + qk_rope_head_dim,
        eps);
  });
  // stage 3: q_b_proj
  at::Tensor qb;
  std::optional<at::Tensor> bias;
  if (use_int8_w8a8) {
    // NOTE(review): out_dtype is hard-coded to bf16 on both quantized paths;
    // confirm this is intended when st == at::kHalf.
    qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni);
  } else if (use_fp8_w8a16) {
    qb = fp8_scaled_mm_cpu(
        qa, q_b_proj_weight, q_b_proj_scale.value(), block_size.value(), bias, at::kBFloat16, is_vnni);
  } else {
    qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni);
  }
  // view qb as [num_seqs, num_heads, qk_head_dim]
  qb.as_strided_({num_seqs, num_heads, qk_head_dim}, {num_heads * qk_head_dim, qk_head_dim, 1});
  // stage 4: bmm — absorb the nope part of q into the kv latent space:
  // q_nope [num_heads, num_seqs, qk_nope_head_dim] x w_kc
  //   -> first kv_lora_rank lanes of q_input
  std::optional<at::Tensor> scale;
  auto q_nope = qb.narrow(2, 0, qk_nope_head_dim).transpose_(0, 1);
  auto q_nope_out = q_input.narrow(2, 0, kv_lora_rank).transpose_(0, 1);
  bmm_cpu(q_nope_out, q_nope, w_kc, is_vnni, scale);
  // stage 5: rope on the rope tails: q is read from qb and written to
  // q_input; the same k pointer is passed as both input and output.
  // NOTE(review): k's rope tail therefore appears to be rotated in place —
  // confirm against rotary_emb_kernel_impl's parameter order.
  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rotary_emb_kernel_impl", [&] {
    rotary_emb_kernel_impl<scalar_t>(
        q_input.data_ptr<scalar_t>() + kv_lora_rank,
        k_input.data_ptr<scalar_t>() + kv_lora_rank,
        qb.data_ptr<scalar_t>() + qk_nope_head_dim,
        k_input.data_ptr<scalar_t>() + kv_lora_rank,
        positions.data_ptr<int64_t>(),
        cos_sin_cache.data_ptr<scalar_t>(),
        num_seqs,
        num_heads,
        rotary_dim,
        num_heads * qk_head_dim,
        qk_head_dim,
        kv_lora_rank + qk_rope_head_dim,
        num_heads * (kv_lora_rank + qk_rope_head_dim),
        kv_lora_rank + qk_rope_head_dim,
        kv_lora_rank + qk_rope_head_dim);
  });
  return std::make_tuple(q_input, k_input, v_input);
}
// Fused-weight variant of qkv_proj_with_rope: the caller supplies a single
// qkv_a_proj weight (q_a_proj rows stacked on top of kv_a_proj rows along
// dim 0) and, for quantized modes, the matching fused scale tensor. This
// wrapper splits weight and scales and delegates to qkv_proj_with_rope.
//
// Returns the same (q_input, k_input, v_input) tuple as qkv_proj_with_rope.
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
    at::Tensor& hidden_states,
    at::Tensor& qkv_a_proj_weight,
    at::Tensor& q_b_proj_weight,
    at::Tensor& w_kc,
    at::Tensor& q_a_layernorm_weight,
    at::Tensor& kv_a_layernorm_weight,
    at::Tensor& positions,
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    std::optional<at::Tensor> qkv_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    bool is_vnni,
    std::optional<std::vector<int64_t>> block_size,
    int64_t q_lora_rank,
    int64_t kv_lora_rank,
    int64_t qk_rope_head_dim) {
  RECORD_FUNCTION(
      "sgl-kernel::qkv_proj_with_rope_fused_weight",
      std::vector<c10::IValue>({hidden_states, qkv_a_proj_weight, q_b_proj_weight, w_kc}));
  int64_t hidden_size = hidden_states.size(1);
  // fused weight stacks [q_lora_rank | kv_lora_rank + qk_rope_head_dim] rows
  CHECK_EQ(qkv_a_proj_weight.size(0), q_lora_rank + kv_lora_rank + qk_rope_head_dim);
  CHECK_EQ(qkv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
  std::vector<at::Tensor> weight_chunks =
      at::split(qkv_a_proj_weight, {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
  at::Tensor q_a_proj_weight = weight_chunks[0];
  at::Tensor kv_a_proj_weight = weight_chunks[1];
  at::Tensor q_a_proj_s;
  at::Tensor kv_a_proj_s;
  if (use_int8_w8a8) {
    // per-row scales: split with the same row partition as the weight
    TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for int8 w8a8.");
    std::vector<at::Tensor> scale_chunks =
        at::split(qkv_a_proj_scale.value(), {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
    q_a_proj_s = scale_chunks[0];
    kv_a_proj_s = scale_chunks[1];
  }
  if (use_fp8_w8a16) {
    TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for fp8 w8a16.");
    // Validate block_size before dereferencing the optional: qkv_proj_with_rope
    // re-checks it, but that check runs only after the split below has already
    // called block_size.value() (previously a raw std::bad_optional_access).
    TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
    TORCH_CHECK(block_size.value().size() == 2, "block_size should be 2D for fp8 w8a16.");
    // block-wise scales: each projection owns ceil(rows / block_size_N) scale rows
    int64_t block_size_N = block_size.value()[0];
    int64_t q_a_proj_s_dim0 = div_up(q_lora_rank, block_size_N);
    int64_t kv_a_proj_s_dim0 = div_up(kv_lora_rank + qk_rope_head_dim, block_size_N);
    std::vector<at::Tensor> scale_chunks = at::split(qkv_a_proj_scale.value(), {q_a_proj_s_dim0, kv_a_proj_s_dim0}, 0);
    q_a_proj_s = scale_chunks[0];
    kv_a_proj_s = scale_chunks[1];
  }
  return qkv_proj_with_rope(
      hidden_states,
      q_a_proj_weight,
      q_b_proj_weight,
      kv_a_proj_weight,
      w_kc,
      q_a_layernorm_weight,
      kv_a_layernorm_weight,
      positions,
      cos_sin_cache,
      eps,
      use_int8_w8a8,
      use_fp8_w8a16,
      q_a_proj_s,
      q_b_proj_scale,
      kv_a_proj_s,
      is_vnni,
      block_size);
}

View File

@@ -0,0 +1,346 @@
#include "common.h"
#include "vec.h"
namespace {
// Interleaved (GPT-J style) rotary embedding for 3D inputs
// [num_tokens, num_heads, head_size], writing to separate output buffers.
//
// cos_sin_cache rows are laid out as [cos half | sin half], each of length
// rotary_dim / 2 (COFF). The caller guarantees a single kv head; num_kv_heads
// and head_size are accepted for signature symmetry but unused here.
//
// NOTE(review): the key loop below does not depend on head_id, so for each
// token the (single) key head is re-rotated and rewritten identically once per
// query head — redundant work and concurrent identical writes; confirm intended.
template <typename scalar_t>
void rotary_embedding_3D_kernel_impl(
    scalar_t* __restrict__ query_out,
    scalar_t* __restrict__ key_out,
    int64_t* __restrict__ positions,
    scalar_t* __restrict__ query,
    scalar_t* __restrict__ key,
    scalar_t* __restrict__ cos_sin_cache,
    int64_t num_tokens,
    int64_t num_heads,
    int64_t num_kv_heads,
    int64_t head_size,
    int64_t rotary_dim,
    int64_t query_stride_s,
    int64_t query_out_stride_s,
    int64_t key_out_stride_s,
    int64_t key_stride_s,
    int64_t query_stride_h,
    int64_t query_out_stride_h) {
  int64_t HR = rotary_dim;  // cache row length
  int64_t HK = rotary_dim;  // rotated span of each key head
  int64_t COFF = HR / 2;    // offset of the sin half within a cache row
  // parallelize over flattened (token, query-head) pairs
  at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
    int64_t seq{0}, head_id{0};
    data_index_init(begin, seq, num_tokens, head_id, num_heads);
    for (int64_t i = begin; i < end; ++i) {
      int64_t in_offset_q = seq * query_stride_s + head_id * query_stride_h;
      int64_t out_offset_q = seq * query_out_stride_s + head_id * query_out_stride_h;
      int64_t out_offset_k = seq * key_out_stride_s;
      int64_t p = 0;
      scalar_t* sin_start = nullptr;
      scalar_t* cos_start = nullptr;
      // step 0) get the rotary position embedding for the current position
      p = positions[seq];
      sin_start = cos_sin_cache + p * HR + COFF;
      cos_start = cos_sin_cache + p * HR;
      // step 1) apply_rotary_pos_emb for the rotary_dim elements in every
      // head of query/key; pairs are interleaved as (h, h + 1)
      for (int64_t h = 0; h < rotary_dim; h += 2) {
        scalar_t cos = cos_start[h >> 1];
        scalar_t sin = sin_start[h >> 1];
        scalar_t in1 = query[in_offset_q + h];
        scalar_t in2 = query[in_offset_q + h + 1];
        scalar_t out1 = in1 * cos - in2 * sin;
        scalar_t out2 = in2 * cos + in1 * sin;
        query_out[out_offset_q + h] = out1;
        query_out[out_offset_q + h + 1] = out2;
      }
      // rotate the (single) key head with the same cos/sin values
      for (int64_t h = 0; h < HK; h += 2) {
        scalar_t cos = cos_start[h >> 1];
        scalar_t sin = sin_start[h >> 1];
        int64_t k_pe_offset = seq * key_stride_s;
        scalar_t in1_k = key[k_pe_offset + h];
        scalar_t in2_k = key[k_pe_offset + h + 1];
        scalar_t out1_k = in1_k * cos - in2_k * sin;
        scalar_t out2_k = in2_k * cos + in1_k * sin;
        key_out[out_offset_k + h] = out1_k;
        key_out[out_offset_k + h + 1] = out2_k;
      }
      // move to the next index
      data_index_step(seq, num_tokens, head_id, num_heads);
    }
  });
}
// NeoX-style (half-split) rotary embedding for 2D inputs
// [num_tokens, num_heads * head_size], applied IN PLACE to query and key.
//
// cos_sin_cache rows are [cos(0..embed_dim) | sin(0..embed_dim)], and within
// each head the rotated pair is (qk[j], qk[embed_dim + j]) — first half vs.
// second half, not interleaved. Heads are assumed packed at head_size
// intervals within each token row.
template <typename scalar_t>
void rotary_embedding_neox_2D_kernel_impl(
    int64_t* __restrict__ positions,
    scalar_t* __restrict__ query,
    scalar_t* __restrict__ key,
    scalar_t* __restrict__ cos_sin_cache,
    int64_t rotary_dim,
    int64_t query_stride_s,
    int64_t key_stride_s,
    int64_t num_heads,
    int64_t num_kv_heads,
    int64_t head_size,
    int64_t num_tokens) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int64_t bVecSize = bVec::size();
  int64_t embed_dim = rotary_dim / 2;
  // When embed_dim is a multiple of the vector width, the vectorized loop
  // covers everything; otherwise it stops one vector short (or is skipped
  // entirely if embed_dim < bVecSize) and the scalar tail finishes up.
  bool flag = (embed_dim % bVecSize == 0);
  int64_t loop_upper = flag ? embed_dim : embed_dim - bVecSize;
  // Rotate one head's rotary span in place; token_head is the head's flat
  // offset within qk. Math is done in fp32 and narrowed back to scalar_t.
  auto compute_loop = [&](int64_t token_head, scalar_t* cache_ptr, scalar_t* qk) {
    int64_t j = 0;
    for (; j < loop_upper; j += bVecSize) {
      int64_t rot_offset = j;
      int64_t x_index = rot_offset;
      int64_t y_index = embed_dim + rot_offset;
      int64_t out_x = token_head + x_index;
      int64_t out_y = token_head + y_index;
      bVec _cos = bVec::loadu(cache_ptr + x_index);
      bVec _sin = bVec::loadu(cache_ptr + y_index);
      bVec _q_x = bVec::loadu(qk + out_x);
      bVec _q_y = bVec::loadu(qk + out_y);
      fVec _cos_0, _cos_1;
      std::tie(_cos_0, _cos_1) = at::vec::convert_to_float(_cos);
      fVec _sin_0, _sin_1;
      std::tie(_sin_0, _sin_1) = at::vec::convert_to_float(_sin);
      fVec _q_x_0, _q_x_1;
      std::tie(_q_x_0, _q_x_1) = at::vec::convert_to_float(_q_x);
      fVec _q_y_0, _q_y_1;
      std::tie(_q_y_0, _q_y_1) = at::vec::convert_to_float(_q_y);
      // x' = x*cos - y*sin ; y' = y*cos + x*sin
      auto out1_0 = _q_x_0 * _cos_0 - _q_y_0 * _sin_0;
      auto out1_1 = _q_x_1 * _cos_1 - _q_y_1 * _sin_1;
      auto out1 = convert_from_float_ext<scalar_t>(out1_0, out1_1);
      out1.store(qk + out_x);
      auto out2_0 = _q_y_0 * _cos_0 + _q_x_0 * _sin_0;
      auto out2_1 = _q_y_1 * _cos_1 + _q_x_1 * _sin_1;
      auto out2 = convert_from_float_ext<scalar_t>(out2_0, out2_1);
      out2.store(qk + out_y);
    }
    // scalar tail for the remainder lanes
    if (!flag) {
      for (; j < embed_dim; ++j) {
        int64_t x_index = j;
        int64_t y_index = embed_dim + j;
        int64_t out_x = token_head + x_index;
        int64_t out_y = token_head + y_index;
        float _cos = cache_ptr[x_index];
        float _sin = cache_ptr[y_index];
        float _q_x = qk[out_x];
        float _q_y = qk[out_y];
        qk[out_x] = _q_x * _cos - _q_y * _sin;
        qk[out_y] = _q_y * _cos + _q_x * _sin;
      }
    }
  };
// one token per thread; all of a token's q and k heads share a cache row
#pragma omp parallel for
  for (int64_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    int64_t pos = positions[token_idx];
    scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
    for (int64_t i = 0; i < num_heads; ++i) {
      int64_t head_idx = i;
      int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
      compute_loop(token_head, cache_ptr, query);
    }
    for (int64_t i = 0; i < num_kv_heads; ++i) {
      int64_t head_idx = i;
      int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
      compute_loop(token_head, cache_ptr, key);
    }
  }
}
// Interleaved (GPT-J style) rotary embedding for 2D inputs
// [num_tokens, num_heads * head_size], applied IN PLACE.
// cos_sin_cache rows are [cos half | sin half] of length rotary_dim / 2 each;
// within a head, pair j lives at indices (2j, 2j + 1).
template <typename scalar_t>
void rotary_embedding_2D_kernel_impl(
    int64_t* __restrict__ positions,
    scalar_t* __restrict__ query,
    scalar_t* __restrict__ key,
    scalar_t* __restrict__ cos_sin_cache,
    int64_t rotary_dim,
    int64_t query_stride_s,
    int64_t key_stride_s,
    int64_t num_heads,
    int64_t num_kv_heads,
    int64_t head_size,
    int64_t num_tokens) {
  const int64_t embed_dim = rotary_dim / 2;
  // Rotate one head in place given pointers to its data and to the cos/sin
  // halves of the relevant cache row. Intermediate math in fp32.
  auto rotate_head = [&](scalar_t* head_ptr, const scalar_t* cos_ptr, const scalar_t* sin_ptr) {
    for (int64_t r = 0; r < embed_dim; ++r) {
      float c = cos_ptr[r];
      float s = sin_ptr[r];
      float x = head_ptr[2 * r];
      float y = head_ptr[2 * r + 1];
      head_ptr[2 * r] = x * c - y * s;
      head_ptr[2 * r + 1] = y * c + x * s;
    }
  };
  // query heads, parallelized over flattened (token, head) pairs
  at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
    int64_t tok{0}, head{0};
    data_index_init(begin, tok, num_tokens, head, num_heads);
    for ([[maybe_unused]] auto it : c10::irange(begin, end)) {
      scalar_t* cache = cos_sin_cache + positions[tok] * rotary_dim;
      rotate_head(query + tok * query_stride_s + head * head_size, cache, cache + embed_dim);
      data_index_step(tok, num_tokens, head, num_heads);
    }
  });
  // key heads, same scheme
  at::parallel_for(0, num_tokens * num_kv_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
    int64_t tok{0}, head{0};
    data_index_init(begin, tok, num_tokens, head, num_kv_heads);
    for ([[maybe_unused]] auto it : c10::irange(begin, end)) {
      scalar_t* cache = cos_sin_cache + positions[tok] * rotary_dim;
      rotate_head(key + tok * key_stride_s + head * head_size, cache, cache + embed_dim);
      data_index_step(tok, num_tokens, head, num_kv_heads);
    }
  });
}
} // namespace
// Apply rotary position embedding on CPU.
//
// Supported layouts:
//   * 2D [num_tokens, num_heads * head_size]: rotation happens IN PLACE and
//     the returned tensors alias query/key. Both neox (half-split) and
//     interleaved styles are supported.
//   * 3D [num_tokens, num_heads, head_size]: results go to freshly allocated
//     outputs; only the interleaved (non-neox) style is supported, head dim
//     must equal rotary_dim, and key must have exactly one kv head.
//
// positions: [num_tokens] int64 indices into cos_sin_cache rows.
// cos_sin_cache: [max_position, rotary_dim], cos in the first half of each
//                row and sin in the second half.
// Returns (query_out, key_out).
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
    at::Tensor& positions,
    at::Tensor& query,
    at::Tensor& key,
    int64_t head_size,
    at::Tensor& cos_sin_cache,
    bool is_neox) {
  RECORD_FUNCTION("sgl-kernel::rotary_embedding_cpu", std::vector<c10::IValue>({query, key}));
  CHECK_DIM(1, positions);
  const auto input_dim = query.dim();
  const auto input_dtype = query.scalar_type();
  TORCH_CHECK(
      input_dim == 2 || input_dim == 3,
      " Query/Key must be 2D [num_tokens, num_heads*head_size] or 3D [num_tokens, num_heads, head_size] tensor");
  CHECK_DIM(2, cos_sin_cache);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(query);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(key);
  int64_t rotary_dim = cos_sin_cache.size(1);
  if (input_dim == 3) {
    // TODO: add support for head_dim != rotary_dim case when input_dim=3
    CHECK_EQ(query.size(-1), rotary_dim);
    // TODO: add support for kv_head != 1
    CHECK_EQ(key.size(1), 1);
  }
  int64_t num_tokens = positions.numel();
  CHECK_EQ(key.size(0), num_tokens);
  CHECK_EQ(query.size(0), num_tokens);
  TORCH_CHECK(positions.scalar_type() == at::kLong, "expect positions to be int64, got ", positions.scalar_type());
  TORCH_CHECK(input_dtype == key.scalar_type(), "query and key must have the same data type");
  TORCH_CHECK(input_dtype == cos_sin_cache.scalar_type(), "query and cos_sin_cache must have the same data type");
  int64_t num_heads = input_dim == 2 ? query.size(-1) / head_size : query.size(1);
  int64_t num_kv_heads = input_dim == 2 ? key.size(-1) / head_size : key.size(1);
  int64_t key_stride_s = key.stride(0);
  int64_t query_stride_s = query.stride(0);
  // input stride of num head dim is meaningful only when input dim = 3
  int64_t query_stride_h = input_dim == 3 ? query.stride(1) : -1;
  // The 2D kernels rotate query/key in place and return the inputs, so fresh
  // output buffers are needed only on the 3D path. (Previously both paths
  // paid for an empty_like allocation that the 2D path then discarded.)
  at::Tensor query_out = input_dim == 3 ? at::empty_like(query) : query;
  at::Tensor key_out = input_dim == 3 ? at::empty_like(key) : key;
  int64_t query_out_stride_s = query_out.stride(0);
  int64_t key_out_stride_s = key_out.stride(0);
  // output stride of num head dim is meaningful only when input dim = 3
  int64_t query_out_stride_h = input_dim == 3 ? query_out.stride(1) : -1;
  AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_embedding_cpu", [&] {
    if (input_dim == 2) {
      if (is_neox) {
        rotary_embedding_neox_2D_kernel_impl<scalar_t>(
            positions.data_ptr<int64_t>(),
            query.data_ptr<scalar_t>(),
            key.data_ptr<scalar_t>(),
            cos_sin_cache.data_ptr<scalar_t>(),
            rotary_dim,
            query_stride_s,
            key_stride_s,
            num_heads,
            num_kv_heads,
            head_size,
            num_tokens);
      } else {
        rotary_embedding_2D_kernel_impl<scalar_t>(
            positions.data_ptr<int64_t>(),
            query.data_ptr<scalar_t>(),
            key.data_ptr<scalar_t>(),
            cos_sin_cache.data_ptr<scalar_t>(),
            rotary_dim,
            query_stride_s,
            key_stride_s,
            num_heads,
            num_kv_heads,
            head_size,
            num_tokens);
      }
    } else {
      TORCH_CHECK(
          is_neox == false, " Query/Key with 3D [num_tokens, num_heads, head_size] does not support neox rope yet");
      // TODO: add neox style support for rope impl with 3D inputs
      rotary_embedding_3D_kernel_impl<scalar_t>(
          query_out.data_ptr<scalar_t>(),
          key_out.data_ptr<scalar_t>(),
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          num_tokens,
          num_heads,
          num_kv_heads,
          head_size,
          rotary_dim,
          query_stride_s,
          query_out_stride_s,
          key_out_stride_s,
          key_stride_s,
          query_stride_h,
          query_out_stride_h);
    }
  });
  return std::make_tuple(query_out, key_out);
}

666
sgl-kernel/csrc/cpu/shm.cpp Normal file
View File

@@ -0,0 +1,666 @@
#include "shm.h"
#include <ATen/ATen.h>
#include <errno.h>
#include <fcntl.h>
#include <immintrin.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <unistd.h>
// states for collectives
// Handshake states stored in each rank's shared workspace. Peers spin on
// these (see wait_buffer_state_until_2) to coordinate copy-in / reduce
// phases; the alt1/alt2 variants track which half of the double buffer a
// round is using so consecutive rounds need no extra synchronization.
enum coll_state {
  coll_begin = 0,
  // allreduce: primary-buffer handshake
  coll_allreduce_naive__copy_in_done,
  coll_allreduce_naive__reduce_done,
  // alternative state when allreduce is working on alternative buffer
  // of the double buffer.
  coll_alt1_allreduce_naive__copy_in_done,
  coll_alt2_allreduce_naive__copy_in_done,
  coll_alt1_allreduce_naive__reduce_done,
  // allgather handshake (primary + double-buffer alternates)
  coll_allgather_naive__copy_in_done,
  coll_alt1_allgather_naive__copy_in_done,
  coll_alt2_allgather_naive__copy_in_done,
};
// SHM building blocks
// Handle for one mapped POSIX shared-memory object.
struct SharedData {
  const char* name;  // shm object name (borrowed pointer, not owned)
  int descriptor;    // fd from shm_open; -1 indicates open failure
  void* bytes;       // base address of the mmap'ed region
  size_t nbytes;     // size of the mapping in bytes
};
// Open an existing POSIX shared-memory object `name` and map nbytes of it
// read/write. On success fills *data; on any failure sets data->descriptor
// to -1 so callers can detect it. Callers deliberately poll on ENOENT while
// waiting for a peer rank to create the object, which is why that errno is
// not logged.
void shared_open(SharedData* data, const char* name, size_t nbytes) {
  int d = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR);
  if (d != -1) {
    void* bytes = mmap(NULL, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, d, 0);
    if (bytes == MAP_FAILED) {
      // previously unchecked: a failed mmap handed MAP_FAILED to callers as
      // a valid mapping and leaked the descriptor
      printf("shared_open %s mmap failed, errno=%d\n", name, errno);
      close(d);
      data->descriptor = -1;
      return;
    }
    data->name = name;
    data->descriptor = d;
    data->bytes = bytes;
    data->nbytes = nbytes;
  } else {
    if (errno != ENOENT) {
      // don't print if shm can not be found because we want to loop over from
      // caller again until the other ranks created the shm
      printf("shared_open %s failed, errno=%d\n", name, errno);
    }
    data->descriptor = -1;
  }
}
// Create a POSIX shared-memory object `name`, seed it with nbytes from
// `bytes`, then re-open and map it via shared_open so *data refers to the
// shared mapping. On failure data->descriptor is set to -1.
void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) {
  int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR);
  if (d != -1) {
    // write() returns ssize_t; the old code stored it into the unsigned
    // nbytes, so a -1 error became SIZE_MAX and passed the `> 0` check, and
    // a short write silently truncated the object.
    ssize_t written = write(d, bytes, nbytes);
    // shared_open re-opens the object below; don't leak this descriptor
    close(d);
    if (written == (ssize_t)nbytes) {
      shared_open(data, name, nbytes);
    } else {
      printf("shared_create %s write failed, errno=%d\n", name, errno);
      data->descriptor = -1;
    }
  } else {
    printf("shared_create %s failed\n", name);
    data->descriptor = -1;
  }
}
static int world_size;
// SHM based allreduce helper functions
// buffer that holds shm name
#define NAME_BUF_SIZE 1000
#define MAX_BUF_SIZE 1048576 * 32
#define NAIVE_ALLREDUCE_THRESHOLD 1048576
#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer"
// Per-rank shared workspace: handshake state slots plus double-buffered
// payload space for both all-reduce flavors.
struct allreduce_workspace {
  enum coll_state states[2];  // idx=0 -- state for symmetric_naive_all_reduce
                              // idx=1 -- state for distributed_naive_all_reduce
  // double buffer to avoid syncing between rounds
  // offset=0 -- 2*NAIVE_ALLREDUCE_THRESHOLD : buffer for
  // symmetric_naive_all_reduce after that : buffer for
  // distributed_naive_all_reduce
  char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE];
};
#define BUFFER0_OFFSET(current_buffer) current_buffer* NAIVE_ALLREDUCE_THRESHOLD
#define BUFFER1_OFFSET(current_buffer) 2 * NAIVE_ALLREDUCE_THRESHOLD + current_buffer* MAX_BUF_SIZE
struct allreduce_workspace** workspace;
// buffer for small messages, double buffer
char** symmetric_buffer[2];
// buffer for large messages, double buffer
char** distributed_buffer[2];
// Busy-wait until rank `index`'s state slot for `state_group` (0 =
// symmetric, 1 = distributed all-reduce) becomes either state0 or state1.
// The volatile reads force a fresh load of the shared slot each iteration.
void wait_buffer_state_until_2(int index, enum coll_state state0, enum coll_state state1, int state_group) {
  volatile enum coll_state* slot = &(workspace[index]->states[state_group]);
  for (;;) {
    volatile enum coll_state observed = *slot;
    if (observed == state0) break;
    if (observed == state1) break;
  }
}
// Widen 16 bf16 values (packed in a __m256i) to fp32.
__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
inline __m512 cvt_bf16_to_fp32(const __m256i src) {
  // zero-extend each bf16 to a 32-bit lane ...
  auto y = _mm512_cvtepu16_epi32(src);
  // ... then shift the payload into the upper 16 bits of each fp32 slot.
  // The 2-byte per-128-bit-lane shift is safe because the zero-extended
  // high halves guarantee zeros flow across element boundaries.
  return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2));
}
// Narrow 16 fp32 values to bf16 with round-to-nearest-even and NaN quieting
// (NaNs are replaced by the canonical 0xffff payload).
inline __m256i cvt_fp32_to_bf16(const __m512 src) __attribute__((target("avx512bw")));
inline __m256i cvt_fp32_to_bf16(const __m512 src) {
  __m512i value = _mm512_castps_si512(src);
  __m512i nan = _mm512_set1_epi32(0xffff);
  // mask_value is set for ordered (non-NaN) lanes
  auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
  __m512i ones = _mm512_set1_epi32(0x1);
  __m512i vec_bias = _mm512_set1_epi32(0x7fff);
  // uint32_t lsb = (input >> 16) & 1;
  auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
  // uint32_t rounding_bias = 0x7fff + lsb;
  t_value = _mm512_add_epi32(t_value, vec_bias);
  // input += rounding_bias;
  t_value = _mm512_add_epi32(t_value, value);
  // input = input >> 16;
  t_value = _mm512_srli_epi32(t_value, 16);
  // Check NaN before converting back to bf16
  t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
  return _mm512_cvtusepi32_epi16(t_value);
}
// Widen 16 fp16 values (packed in a __m256i) to fp32.
__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
inline __m512 cvt_fp16_to_fp32(const __m256i src) {
  return _mm512_cvtph_ps(src);
}
// Narrow 16 fp32 values to fp16, rounding to nearest without raising
// precision exceptions.
inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw")));
inline __m256i cvt_fp32_to_fp16(const __m512 src) {
  return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
__attribute__((target("avx512bw")));
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
__attribute__((target("avx512bw")));
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
__attribute__((target("avx512bw")));
void reduce_all_buffers(
int start_elements,
int num_elements,
c10::ScalarType scalar_type,
int to_buffer_idx,
char* to_buffer,
char** buffers) {
switch (scalar_type) {
case c10::ScalarType::BFloat16:
reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers);
break;
case c10::ScalarType::Half:
reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers);
break;
case c10::ScalarType::Float:
reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers);
break;
default:
assert(!"Should not get here");
}
}
#define CVT_ADD_BF16(x) \
do { \
auto in##x##_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
inout_val = _mm512_add_ps(inout_val, in##x##_val); \
} while (0)
// Reduce functions down below use vectorized algorithm, the number of bytes
// processed each iteration depends on vector length. 256bit vector ==> 32
// bytes, 512bit vector ==> 64 bytes If you change implementation of
// reduce_bf16_buffers, etc. , check whether this number needs to be changed
#define VECTOR_LENGTH_IN_BYTES 32
// Sum the bf16 element range [start_elements, start_elements + num_elements)
// of all ranks' `buffers` into `to_buffer`, accumulating in fp32.
// The switch below relies on DELIBERATE case fallthrough to unroll the
// per-rank accumulation for world sizes up to 16; larger world sizes use
// the generic loop in the default case.
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
  const int element_size = 2;
  const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
  int main_elements = num_elements - (num_elements % vector_length);
  int remain_elements = num_elements % vector_length;
// process aligned part
#pragma omp parallel for
  for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
       i += VECTOR_LENGTH_IN_BYTES) {
    auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
    // intentional fallthrough: case N adds rank N's lane, then falls into N-1
    switch (world_size) {
      case 16:
        CVT_ADD_BF16(15);
      case 15:
        CVT_ADD_BF16(14);
      case 14:
        CVT_ADD_BF16(13);
      case 13:
        CVT_ADD_BF16(12);
      case 12:
        CVT_ADD_BF16(11);
      case 11:
        CVT_ADD_BF16(10);
      case 10:
        CVT_ADD_BF16(9);
      case 9:
        CVT_ADD_BF16(8);
      case 8:
        CVT_ADD_BF16(7);
      case 7:
        CVT_ADD_BF16(6);
      case 6:
        CVT_ADD_BF16(5);
      case 5:
        CVT_ADD_BF16(4);
      case 4:
        CVT_ADD_BF16(3);
      case 3:
        CVT_ADD_BF16(2);
      case 2:
        CVT_ADD_BF16(1);
      case 1:
        break;
      default:
        for (int j = 1; j < world_size; j++) {
          auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
          inout_val = _mm512_add_ps(inout_val, in_val);
        }
    }
    _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val));
  }
  // process remaining part (scalar tail, fp32 accumulation)
  int i = (start_elements + main_elements) * element_size;
  while (remain_elements > 0) {
    float val = 0.0f;
    for (int j = 0; j < world_size; j++) {
      val += *(at::BFloat16*)(buffers[j] + i);
    }
    *(at::BFloat16*)(to_buffer + i) = val;
    remain_elements--;
    i += element_size;
  }
}
#define CVT_ADD_FP16(x) \
do { \
auto in##x##_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
inout_val = _mm512_add_ps(inout_val, in##x##_val); \
} while (0)
// Sum the fp16 element range [start_elements, start_elements + num_elements)
// of all ranks' `buffers` into `to_buffer`, accumulating in fp32.
// Same deliberate case-fallthrough unrolling as reduce_bf16_buffers.
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
  const int element_size = 2;
  const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
  int main_elements = num_elements - (num_elements % vector_length);
  int remain_elements = num_elements % vector_length;
// process aligned part
#pragma omp parallel for
  for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
       i += VECTOR_LENGTH_IN_BYTES) {
    auto inout_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
    // intentional fallthrough: case N adds rank N's lane, then falls into N-1
    switch (world_size) {
      case 16:
        CVT_ADD_FP16(15);
      case 15:
        CVT_ADD_FP16(14);
      case 14:
        CVT_ADD_FP16(13);
      case 13:
        CVT_ADD_FP16(12);
      case 12:
        CVT_ADD_FP16(11);
      case 11:
        CVT_ADD_FP16(10);
      case 10:
        CVT_ADD_FP16(9);
      case 9:
        CVT_ADD_FP16(8);
      case 8:
        CVT_ADD_FP16(7);
      case 7:
        CVT_ADD_FP16(6);
      case 6:
        CVT_ADD_FP16(5);
      case 5:
        CVT_ADD_FP16(4);
      case 4:
        CVT_ADD_FP16(3);
      case 3:
        CVT_ADD_FP16(2);
      case 2:
        CVT_ADD_FP16(1);
      case 1:
        break;
      default:
        for (int j = 1; j < world_size; j++) {
          auto in_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
          inout_val = _mm512_add_ps(inout_val, in_val);
        }
    }
    _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val));
  }
  // process remaining part (scalar tail, fp32 accumulation)
  int i = (start_elements + main_elements) * element_size;
  while (remain_elements > 0) {
    float val = 0.0f;
    for (int j = 0; j < world_size; j++) {
      val += *(at::Half*)(buffers[j] + i);
    }
    *(at::Half*)(to_buffer + i) = val;
    remain_elements--;
    i += element_size;
  }
}
#define CVT_ADD_F32(x) \
do { \
auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \
inout_val = _mm256_add_ps(inout_val, in##x##_val); \
} while (0)
// Sum the fp32 element range [start_elements, start_elements + num_elements)
// of all ranks' `buffers` into `to_buffer` using 256-bit vectors.
// Same deliberate case-fallthrough unrolling as reduce_bf16_buffers.
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
  const int element_size = 4;
  const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
  int main_elements = num_elements - (num_elements % vector_length);
  int remain_elements = num_elements % vector_length;
// process aligned part
#pragma omp parallel for
  for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
       i += VECTOR_LENGTH_IN_BYTES) {
    auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i));
    // intentional fallthrough: case N adds rank N's lane, then falls into N-1
    switch (world_size) {
      case 16:
        CVT_ADD_F32(15);
      case 15:
        CVT_ADD_F32(14);
      case 14:
        CVT_ADD_F32(13);
      case 13:
        CVT_ADD_F32(12);
      case 12:
        CVT_ADD_F32(11);
      case 11:
        CVT_ADD_F32(10);
      case 10:
        CVT_ADD_F32(9);
      case 9:
        CVT_ADD_F32(8);
      case 8:
        CVT_ADD_F32(7);
      case 7:
        CVT_ADD_F32(6);
      case 6:
        CVT_ADD_F32(5);
      case 5:
        CVT_ADD_F32(4);
      case 4:
        CVT_ADD_F32(3);
      case 3:
        CVT_ADD_F32(2);
      case 2:
        CVT_ADD_F32(1);
      case 1:
        break;
      default:
        for (int j = 1; j < world_size; j++) {
          auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i));
          inout_val = _mm256_add_ps(inout_val, in_val);
        }
    }
    _mm256_storeu_ps((float*)(to_buffer + i), inout_val);
  }
  // process remaining part (scalar tail)
  int i = (start_elements + main_elements) * element_size;
  while (remain_elements > 0) {
    float val = 0.0f;
    for (int j = 0; j < world_size; j++) {
      val += *(float*)(buffers[j] + i);
    }
    *(float*)(to_buffer + i) = val;
    remain_elements--;
    i += element_size;
  }
}
static bool is_initialized = false;
static int world_rank;
// One-time setup of the SHM all-reduce machinery for this process (rank).
// Creates this rank's shared workspace, waits for and maps every peer's
// workspace, and fills the global workspace / symmetric_buffer /
// distributed_buffer pointer tables. Idempotent: subsequent calls return
// immediately.
void shm_initialize(int size, int rank, const char* addr_string, const char* port_string) {
  if (is_initialized) {
    return;
  }
  is_initialized = true;
  world_size = size;
  world_rank = rank;
  char shm_name_prefix[NAME_BUF_SIZE];
  char shm_name[NAME_BUF_SIZE];
  // names are scoped by uid + rendezvous address so concurrent jobs don't collide
  snprintf(shm_name_prefix, NAME_BUF_SIZE, "%s_%d_%s_%s", SHM_BUFFER_NAME, getuid(), addr_string, port_string);
  // create shared workspace for SHM based allreduce
  SharedData allreduce_buffer;
  // staging buffer only seeds/sizes the shm object; the live states are
  // written into the shared mapping after shared_create
  struct allreduce_workspace* workspace_buf;
  struct allreduce_workspace* workspace_buf_other;
  workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace));
  snprintf(shm_name, NAME_BUF_SIZE, "%.900s_%d", shm_name_prefix, rank);
  shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace));
  // leak fix: the malloc'd staging copy was previously never released once
  // its contents had been written into the shm object
  free(workspace_buf);
  workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes;
  workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done;
  workspace_buf->states[1] = coll_begin;
  // create the workspace pointer list
  workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*));
  // element type is char* (was sizeof(char**) — same size, wrong type)
  symmetric_buffer[0] = (char**)malloc(size * sizeof(char*));
  symmetric_buffer[1] = (char**)malloc(size * sizeof(char*));
  distributed_buffer[0] = (char**)malloc(size * sizeof(char*));
  distributed_buffer[1] = (char**)malloc(size * sizeof(char*));
  // map shm of all ranks
  for (int i = 0; i < size; i++) {
    if (i != rank) {
      snprintf(shm_name, NAME_BUF_SIZE, "%.900s_%d", shm_name_prefix, i);
      // spin until the peer rank has created its shm object (ENOENT)
      do {
        shared_open(&allreduce_buffer, shm_name, sizeof(struct allreduce_workspace));
      } while (allreduce_buffer.descriptor == -1 && errno == ENOENT);
      workspace_buf_other = (struct allreduce_workspace*)allreduce_buffer.bytes;
      workspace[i] = workspace_buf_other;
    } else {
      workspace[i] = workspace_buf;
    }
    // per-rank views into both halves of each double buffer
    symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0);
    symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1);
    distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0);
    distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1);
  }
}
// OpenMP-parallel memcpy: the vector-aligned bulk is copied 32 bytes at a time,
// the tail byte-by-byte. Source and destination must not overlap within a
// 32-byte stripe (copies are unordered across threads).
// NOTE(review): the attribute requests avx512bw codegen but the body only uses
// AVX2 (_mm256) intrinsics — confirm whether avx512bw is intentional here.
static void parallel_memcpy(void* to, void* from, size_t n_bytes) __attribute__((target("avx512bw")));
static void parallel_memcpy(void* to, void* from, size_t n_bytes) {
auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES);
// process aligned part
#pragma omp parallel for
for (size_t i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) {
auto val = _mm256_loadu_si256((__m256i*)((char*)from + i));
_mm256_storeu_si256((__m256i*)((char*)to + i), val);
}
// process remaining part
for (size_t i = aligned_bytes; i < n_bytes; i++) {
*((char*)to + i) = *((char*)from + i);
}
}
// Mathematical modulo that stays non-negative even when `num` is negative
// (C++'s `%` keeps the sign of the dividend).
#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod))
// Wrap a (possibly negative) rank index into [0, world_size).
#define rank_mod(rank) positive_mod(rank, world_size)
// Number of elements owned by slice `slice_idx` when `chunk_el` elements are
// split evenly across world_size ranks; the last slice absorbs the remainder
// so that every element is covered.
size_t slice_size(size_t chunk_el, int slice_idx) {
  const size_t base = chunk_el / world_size;
  if (slice_idx != world_size - 1) {
    return base;
  }
  return base + (chunk_el % world_size);
}
// Pointer to the start of slice `slice_idx` within `data_ptr`, where the chunk
// holds `chunk_el` elements of `el_size` bytes each.
char* slice_data(char* data_ptr, size_t chunk_el, int el_size, int slice_idx) {
  const size_t elements_per_slice = chunk_el / world_size;
  const size_t start_element = elements_per_slice * slice_idx;
  return data_ptr + start_element * el_size;
}
// Element index within the chunk where slice `slice_idx` begins.
size_t slice_el_start(size_t chunk_el, int slice_idx) {
  return (chunk_el / world_size) * slice_idx;
}
// Latency-optimized allreduce for small chunks: every rank copies its data
// into its own shared buffer, waits for all peers to do the same, then reduces
// all buffers locally (in place over `data_ptr`). Uses three rotating state
// values plus double buffering so back-to-back calls can't confuse laggards.
// Not thread-safe: relies on function-local statics for rotation state.
void symmetric_naive_all_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
const int state_group = 0;
static int current_buffer = 0;
static int state_idx = 0;
// init states to case 0 to get rid of "maybe-uninitialized" warning.
enum coll_state copy_current = coll_allreduce_naive__copy_in_done;
enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done;
// Rotate through three distinct "copy done" states so a peer observed in
// either the current or the next state is known to have finished copying.
switch (state_idx) {
case 0:
copy_current = coll_allreduce_naive__copy_in_done;
copy_next = coll_alt1_allreduce_naive__copy_in_done;
break;
case 1:
copy_current = coll_alt1_allreduce_naive__copy_in_done;
copy_next = coll_alt2_allreduce_naive__copy_in_done;
break;
case 2:
copy_current = coll_alt2_allreduce_naive__copy_in_done;
copy_next = coll_allreduce_naive__copy_in_done;
break;
default:
assert(!"Should not get here.");
}
state_idx = (state_idx + 1) % 3;
// Publish this rank's data, then make the copy visible before the state flips.
parallel_memcpy(symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = copy_current;
for (int i = 0; i < world_size; i++) {
// wait until the other rank copy the buffer
if (i != world_rank) {
wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
}
}
// each rank reduce the buffer independently so there is no need for
// synchronization afterward
reduce_all_buffers(0, chunk_el, scalar_type, world_rank, data_ptr, symmetric_buffer[current_buffer]);
// switch buffer
current_buffer = 1 - current_buffer;
}
// Bandwidth-optimized allreduce for larger chunks (reduce-scatter + allgather):
// each rank reduces only its own slice of the chunk, then every rank copies the
// reduced slices from all peers back into `data_ptr`. Two rotating state sets
// suffice because the algorithm has two internal barriers.
// Not thread-safe: relies on function-local statics for rotation state.
void distributed_naive_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
const int state_group = 1;
static int current_buffer = 0;
static int state_idx = 0;
// init states to case 0 to get rid of "maybe-uninitialized" warning.
enum coll_state copy_current = coll_allreduce_naive__copy_in_done;
enum coll_state reduce_current = coll_allreduce_naive__reduce_done;
enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done;
// similar to symmetric_naive_allreduce, but here we only need two sets of
// states, because distributed naive reduce has two barriers in the algorithm
switch (state_idx) {
case 0:
copy_current = coll_allreduce_naive__copy_in_done;
reduce_current = coll_allreduce_naive__reduce_done;
copy_next = coll_alt1_allreduce_naive__copy_in_done;
break;
case 1:
copy_current = coll_alt1_allreduce_naive__copy_in_done;
reduce_current = coll_alt1_allreduce_naive__reduce_done;
copy_next = coll_allreduce_naive__copy_in_done;
break;
default:
assert(!"Should not get here.");
}
state_idx = (state_idx + 1) % 2;
// Bytes per element, derived from the chunk's byte and element counts.
int data_size = chunk_size / chunk_el;
// Publish this rank's data, fence, then announce "copy done".
parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = copy_current;
for (int i = 0; i < world_size; i++) {
// wait until all the other ranks copy the buffer
if (i != world_rank) wait_buffer_state_until_2(i, copy_current, reduce_current, state_group);
}
// reduce scatter
// Each rank reduces only its own slice, writing the result into its own
// shared buffer so peers can gather it below.
reduce_all_buffers(
slice_el_start(chunk_el, world_rank),
slice_size(chunk_el, world_rank),
scalar_type,
world_rank,
distributed_buffer[current_buffer][world_rank],
distributed_buffer[current_buffer]);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = reduce_current;
for (int i = 0; i < world_size; i++) {
// wait until all the other ranks reduce the buffer
if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group);
}
// Allgather phase: copy every rank's reduced slice back into data_ptr,
// starting with our own slice to stagger peer-buffer reads across ranks.
for (int i = 0; i < world_size; i++) {
int rank = (i + world_rank) % world_size;
parallel_memcpy(
slice_data(data_ptr, chunk_el, data_size, rank),
slice_data(distributed_buffer[current_buffer][rank], chunk_el, chunk_size / chunk_el, rank),
slice_size(chunk_el, rank) * data_size);
}
current_buffer = 1 - current_buffer;
}
// Run an in-place allreduce over `data` (data_size bytes, numel elements),
// processing at most MAX_BUF_SIZE bytes per pass. Small chunks take the
// latency-optimized symmetric algorithm; large ones the distributed
// reduce-scatter variant.
void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size) {
  const size_t el_size = data_size / numel;  // bytes per element
  for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) {
    const size_t remaining = data_size - offset;
    const size_t chunk_size = remaining > MAX_BUF_SIZE ? MAX_BUF_SIZE : remaining;
    char* chunk_ptr = (char*)(data.data_ptr()) + offset;
    if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) {
      symmetric_naive_all_reduce(chunk_ptr, data.scalar_type(), chunk_size, chunk_size / el_size);
    } else {
      distributed_naive_reduce(chunk_ptr, data.scalar_type(), chunk_size, chunk_size / el_size);
    }
  }
}
// Naive allgather over one chunk: every rank publishes `chunk_size` bytes in
// its shared buffer, waits for all peers, then copies rank i's chunk to
// result_ptr + i * res_stride. Three rotating states plus double buffering
// keep consecutive calls from racing, mirroring symmetric_naive_all_reduce.
// Not thread-safe: relies on function-local statics for rotation state.
void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_t chunk_size, size_t chunk_el) {
const int state_group = 1;
static int current_buffer = 0;
static int state_idx = 0;
// init states to case 0 to get rid of "maybe-uninitialized" warning.
enum coll_state copy_current = coll_allgather_naive__copy_in_done;
enum coll_state copy_next = coll_alt1_allgather_naive__copy_in_done;
switch (state_idx) {
case 0:
copy_current = coll_allgather_naive__copy_in_done;
copy_next = coll_alt1_allgather_naive__copy_in_done;
break;
case 1:
copy_current = coll_alt1_allgather_naive__copy_in_done;
copy_next = coll_alt2_allgather_naive__copy_in_done;
break;
case 2:
copy_current = coll_alt2_allgather_naive__copy_in_done;
copy_next = coll_allgather_naive__copy_in_done;
break;
default:
assert(!"Should not get here.");
}
state_idx = (state_idx + 1) % 3;
// Publish this rank's chunk, fence, then announce "copy done".
parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = copy_current;
for (int i = 0; i < world_size; i++) {
// wait until all the other ranks copy the buffer
if (i != world_rank) wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
}
// Gather: rank i's chunk lands at result_ptr + i * res_stride.
for (int i = 0; i < world_size; i++) {
parallel_memcpy(result_ptr + i * res_stride, distributed_buffer[current_buffer][i], chunk_size);
}
current_buffer = 1 - current_buffer;
}
// All-gather `data` into `result` along dimension `dim` and return `result`.
// The gather is performed independently for each outer "row" of dim_size
// bytes, in chunks of at most MAX_BUF_SIZE bytes.
torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size) {
  const size_t dim_el = data.stride(dim) * data.size(dim);  // elements per row
  const int dtype_size = data_size / numel;                 // bytes per element
  const size_t dim_size = dim_el * dtype_size;              // bytes per row
  const int dim_count = data_size / dim_size;               // number of rows
  char* src = (char*)(data.data_ptr());
  char* dst = (char*)(result.data_ptr());
  for (int row = 0; row < dim_count; row++) {
    for (size_t offset = 0; offset < dim_size; offset += MAX_BUF_SIZE) {
      const size_t remaining = dim_size - offset;
      const size_t chunk_size = remaining > MAX_BUF_SIZE ? MAX_BUF_SIZE : remaining;
      naive_all_gather(
          dst + row * dim_size * world_size + offset,
          src + row * dim_size + offset,
          dim_size,
          chunk_size,
          chunk_size / dtype_size);
    }
  }
  return result;
}

11
sgl-kernel/csrc/cpu/shm.h Normal file
View File

@@ -0,0 +1,11 @@
#include <torch/all.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#ifndef __SHM_COLLECTIVES__
#define __SHM_COLLECTIVES__
// Width in bytes of the SIMD vectors used by the SHM collective kernels.
#define VECTOR_LENGTH_IN_BYTES 32
// One-time setup of the shared-memory workspace; call once per rank before any collective.
void shm_initialize(int size, int rank, const char* addr_string, const char* port_string);
// In-place allreduce over `data` (numel elements, data_size bytes), chunked internally.
void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size);
// All-gather `data` into `result` along dimension `dim`; returns `result`.
torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size);
#endif

View File

@@ -0,0 +1,662 @@
#include "common.h"
#include "vec.h"
namespace {
// Vectorized numerically-stable softmax over a fixed-size row: writes
// exp(x - max) / sum(exp(x - max)) as float into `out` from the reduced-
// precision input. SIZE is a compile-time power of two; when it is smaller
// than one vector the masked load/set path is taken instead of the loop.
// exp_u20 is presumably ATen's reduced-precision vectorized exp — confirm
// precision requirements against the Vectorized<float> implementation.
template <typename scalar_t, int SIZE>
inline void softmax(float* __restrict__ out, const scalar_t* __restrict__ input) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
// step 1: get max
fVec max_fvec = fVec(-std::numeric_limits<float>::infinity());
if constexpr (SIZE < kVecSize) {
// SIZE = 1, 2, 4, 8, 16; only the top half is used
bVec x_bvec = bVec::loadu(input, SIZE);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
// Pad lanes >= SIZE with -inf so they cannot win the max reduction.
x_fvec0 = fVec::set(max_fvec, x_fvec0, SIZE);
max_fvec = at::vec::maximum(max_fvec, x_fvec0);
// Stash the converted floats in `out` for reuse by steps 2 and 3.
x_fvec0.store(out, SIZE);
} else {
for (int d = 0; d < SIZE; d += kVecSize) {
bVec x_bvec = bVec::loadu(input + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
max_fvec = at::vec::maximum(max_fvec, x_fvec0);
max_fvec = at::vec::maximum(max_fvec, x_fvec1);
x_fvec0.store(out + d);
x_fvec1.store(out + d + fVec::size());
}
}
// Horizontal max across lanes, then broadcast it back.
float max_val = vec_reduce_max(max_fvec);
max_fvec = fVec(max_val);
// step 2: sum of (x - max).exp()
fVec sum_fvec = fVec(float(0));
if constexpr (SIZE < fVec::size()) {
// SIZE = 1, 2, 4, 8
fVec x_fvec = (fVec::loadu(out, SIZE) - max_fvec).exp_u20();
// Zero lanes >= SIZE so they don't contribute to the sum.
x_fvec = fVec::set(sum_fvec, x_fvec, SIZE);
sum_fvec += x_fvec;
x_fvec.store(out, SIZE);
} else {
for (int d = 0; d < SIZE; d += fVec::size()) {
fVec x_fvec = (fVec::loadu(out + d) - max_fvec).exp_u20();
sum_fvec += x_fvec;
x_fvec.store(out + d);
}
}
float sum_val = vec_reduce_sum(sum_fvec);
// step 3: x * (1 / sum)
sum_fvec = fVec(1.f / sum_val);
if constexpr (SIZE < fVec::size()) {
// SIZE = 1, 2, 4, 8
fVec out_fvec = fVec::loadu(out, SIZE) * sum_fvec;
out_fvec.store(out, SIZE);
} else {
for (int d = 0; d < SIZE; d += fVec::size()) {
fVec out_fvec = fVec::loadu(out + d) * sum_fvec;
out_fvec.store(out + d);
}
}
}
// Grouped top-k expert routing (DeepSeek-V2 style): softmax over all expert
// logits, pick the `topk_group` groups with the highest per-group max score,
// then pick the global `topk` experts from only those groups. Optionally
// renormalizes the selected weights to sum to 1. Parallelized over tokens.
//
// Outputs (per token i): topk_weights[i*topk..] and topk_ids[i*topk..].
template <typename scalar_t, int NUM_EXPERTS>
void grouped_topk_kernel_impl(
float* __restrict__ topk_weights,
int32_t* __restrict__ topk_ids,
const scalar_t* __restrict__ gating_output,
int64_t num_tokens,
int64_t topk,
int64_t num_groups,
int64_t topk_group,
bool renormalize) {
const int64_t num_experts_per_group = NUM_EXPERTS / num_groups;
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
// Scratch reused across all tokens of this parallel chunk.
alignas(64) float scores[NUM_EXPERTS];
using elem_t = std::pair<float, int32_t>;
std::vector<elem_t> queue(num_groups);
std::vector<elem_t> queue2(topk_group * num_experts_per_group);
for (int64_t i = begin; i < end; ++i) {
// do softmax to get scores
softmax<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
// find max score per group
for (int64_t g = 0; g < num_groups; ++g) {
float gmax = -std::numeric_limits<float>::infinity();
for (int64_t e = 0; e < num_experts_per_group; ++e) {
gmax = std::max(gmax, scores[g * num_experts_per_group + e]);
}
queue[g] = {gmax, g};
}
// find group topk
std::partial_sort(
queue.begin(), queue.begin() + topk_group, queue.end(), [](const elem_t& x, const elem_t& y) -> bool {
return x.first > y.first;
});
// Collect every expert belonging to one of the winning groups.
for (int64_t g = 0; g < topk_group; ++g) {
int32_t group_idx = queue[g].second;
for (int64_t e = 0; e < num_experts_per_group; ++e) {
int32_t expert_idx = group_idx * num_experts_per_group + e;
queue2[g * num_experts_per_group + e] = {scores[expert_idx], expert_idx};
}
}
// find global topk
std::partial_sort(
queue2.begin(), queue2.begin() + topk, queue2.end(), [](const elem_t& x, const elem_t& y) -> bool {
return x.first > y.first;
});
for (int64_t j = 0; j < topk; ++j) {
topk_weights[i * topk + j] = queue2[j].first;
topk_ids[i * topk + j] = queue2[j].second;
}
// Optionally rescale the selected weights so they sum to 1.
if (renormalize) {
float sum = 0.f;
for (int64_t j = 0; j < topk; ++j) {
sum += topk_weights[i * topk + j];
}
float scale = 1.f / sum;
for (int64_t j = 0; j < topk; ++j) {
topk_weights[i * topk + j] *= scale;
}
}
}
});
}
// Vectorized sigmoid: out[d] = 1 / (1 + exp(-input[d])) for d in [0, SIZE),
// converting the reduced-precision input to float.
// NOTE(review): the loop advances by bVec::size() and always stores two float
// vectors, so SIZE must be a multiple of bVec::size(); callers in this file
// instantiate it with NUM_EXPERTS = 256 — confirm before adding new sizes.
template <typename scalar_t, int SIZE>
inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
const fVec one = fVec(1.f);
constexpr int kVecSize = bVec::size();
for (int d = 0; d < SIZE; d += kVecSize) {
bVec x_bvec = bVec::loadu(input + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
x_fvec0 = one / (one + x_fvec0.neg().exp_u20());
x_fvec1 = one / (one + x_fvec1.neg().exp_u20());
x_fvec0.store(out + d);
x_fvec1.store(out + d + fVec::size());
}
}
// Top-1 expert selection with sigmoid gating: for each token, find the expert
// with the highest logit and emit sigmoid(max_logit) as its routing weight.
// The caller (topk_sigmoid_cpu) enforces topk == 1. With topk == 1,
// renormalization collapses the single weight to 1.0f.
//
// Fix over the original: the per-chunk `std::vector<elem_t> queue` was
// allocated on every parallel task but never read or written — removed.
template <typename scalar_t, int NUM_EXPERTS>
void topk_sigmoid_kernel_impl(
    float* __restrict__ topk_weights,
    int32_t* __restrict__ topk_ids,
    const scalar_t* __restrict__ gating_output,
    int64_t num_tokens,
    int64_t topk,
    bool renormalize) {
  using Vec = at::vec::Vectorized<float>;
  const int64_t num_experts_per_group = NUM_EXPERTS;
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    alignas(64) float scores[NUM_EXPERTS];
    for (int64_t i = begin; i < end; ++i) {
      // Convert this token's logits to float scratch.
      at::vec::convert<scalar_t, float>(gating_output + i * NUM_EXPERTS, scores, NUM_EXPERTS);
      float gmax = at::vec::reduce_all<float>(
          [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, scores, num_experts_per_group);
      // find position of first max,
      // note that we may have multiple max values.
      int first_max_idx = -1;
      for (int64_t e = 0; e < num_experts_per_group; ++e) {
        if (scores[e] == gmax) {
          first_max_idx = e;
          break;
        }
      }
      // scalar sigmoid of the winning logit
      topk_weights[i] = 1.0 / (1.0 + exp(0.0 - gmax));
      topk_ids[i] = first_max_idx;
      if (renormalize) {
        float sum = 0.f;
        for (int64_t j = 0; j < topk; ++j) {
          sum += topk_weights[i * topk + j];
        }
        float scale = 1.f / sum;
        for (int64_t j = 0; j < topk; ++j) {
          topk_weights[i * topk + j] *= scale;
        }
      }
    }
  });
}
// Top-k expert selection with softmax gating: softmax over all NUM_EXPERTS
// logits, take the `topk` largest probabilities (descending) and optionally
// renormalize them to sum to 1. Parallelized over tokens.
//
// Fix over the original: std::partial_sort was called with
// middle == begin() + num_experts_per_group == end(), which degenerates into a
// full heap-sort of all experts even though only the leading `topk` entries
// are consumed. Sorting just the top `topk` matches the sibling kernels
// (grouped/biased variants sort only what they read).
template <typename scalar_t, int NUM_EXPERTS>
void topk_softmax_kernel_impl(
    float* __restrict__ topk_weights,
    int32_t* __restrict__ topk_ids,
    const scalar_t* __restrict__ gating_output,
    int64_t num_tokens,
    int64_t topk,
    bool renormalize) {
  const int64_t num_experts_per_group = NUM_EXPERTS;
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    alignas(64) float scores[NUM_EXPERTS];
    using elem_t = std::pair<float, int32_t>;
    std::vector<elem_t> queue(num_experts_per_group);
    for (int64_t i = begin; i < end; ++i) {
      softmax<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
      for (int64_t e = 0; e < num_experts_per_group; ++e) {
        queue[e] = {scores[e], e};
      }
      // Only the first `topk` entries are read below.
      std::partial_sort(
          queue.begin(),
          queue.begin() + topk,
          queue.end(),
          [](const elem_t& x, const elem_t& y) -> bool { return x.first > y.first; });
      for (int64_t j = 0; j < topk; ++j) {
        topk_weights[i * topk + j] = queue[j].first;
        topk_ids[i * topk + j] = queue[j].second;
      }
      // Optionally rescale the selected weights so they sum to 1.
      if (renormalize) {
        float sum = 0.f;
        for (int64_t j = 0; j < topk; ++j) {
          sum += topk_weights[i * topk + j];
        }
        float scale = 1.f / sum;
        for (int64_t j = 0; j < topk; ++j) {
          topk_weights[i * topk + j] *= scale;
        }
      }
    }
  });
}
// scores2[d] = scores[d] + bias[d] for d in [0, SIZE): vectorized main loop
// with a scalar tail. `bias` is in the (possibly reduced-precision) param_t
// type; load_float_vec2 is a project helper (defined elsewhere) that
// presumably loads and widens one param_t/float vector pair — verify there.
template <typename scalar_t, typename param_t, int SIZE>
inline void
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const param_t* __restrict__ bias) {
using fVec = at::vec::Vectorized<float>;
using bVec = at::vec::Vectorized<scalar_t>;
auto vec_size = bVec::size();
int d = 0;
for (; d <= SIZE - vec_size; d += vec_size) {
fVec bias0, bias1, x0, x1;
std::tie(bias0, bias1) = load_float_vec2(bias + d);
std::tie(x0, x1) = load_float_vec2(scores + d);
x0 = x0 + bias0;
x1 = x1 + bias1;
x0.store(scores2 + d);
x1.store(scores2 + d + fVec::size());
}
// Scalar tail for any remainder not covered by full vectors.
for (; d < SIZE; d++) {
scores2[d] = scores[d] + (float)bias[d];
}
}
// Biased grouped top-k routing (DeepSeek-V3/R1 style): sigmoid gate scores
// plus a per-expert correction bias are used for *selection*, while the
// unbiased sigmoid scores are what gets written as the routing weights.
// A group's rank is the sum of its two largest biased scores; the TOPK
// experts are then chosen from the winning `topk_group` groups.
// TOPK is a compile-time constant (callers instantiate TOPK = 8).
template <typename scalar_t, typename param_t, int NUM_EXPERTS, int TOPK>
void biased_grouped_topk_kernel_impl(
float* __restrict__ topk_weights,
int32_t* __restrict__ topk_ids,
const scalar_t* __restrict__ gating_output,
const param_t* __restrict__ bias,
int64_t num_tokens,
int64_t num_groups,
int64_t topk_group,
bool renormalize) {
using Vec = at::vec::Vectorized<float>;
const int64_t num_experts_per_group = NUM_EXPERTS / num_groups;
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
// scores: sigmoid
alignas(64) float scores[NUM_EXPERTS];
// scores for choice: sigmoid + bias
alignas(64) float scores2[NUM_EXPERTS];
using elem_t = std::pair<float, int32_t>;
std::vector<elem_t> queue(num_groups);
std::vector<elem_t> queue2(topk_group * num_experts_per_group);
for (int64_t i = begin; i < end; ++i) {
// do sigmoid to get scores
sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
apply_bias<scalar_t, param_t, NUM_EXPERTS>(scores2, scores, bias);
// Rank each group by the sum of its two largest biased scores.
for (int64_t g = 0; g < num_groups; ++g) {
// find the max
float gmax = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
scores2 + g * num_experts_per_group,
num_experts_per_group);
// find position of first max,
// note that we may have multiple max values.
int first_max_idx = -1;
for (int64_t e = 0; e < num_experts_per_group; ++e) {
if (scores2[g * num_experts_per_group + e] == gmax) {
first_max_idx = g * num_experts_per_group + e;
break;
}
}
// find the 2nd max
// Temporarily knock out the max so the same reduction finds the runner-up.
scores2[first_max_idx] = -std::numeric_limits<float>::infinity();
float gmax2 = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
scores2 + g * num_experts_per_group,
num_experts_per_group);
// restore scores for choice
scores2[first_max_idx] = gmax;
queue[g] = {gmax + gmax2, g};
}
// find group topk
std::partial_sort(
queue.begin(), queue.begin() + topk_group, queue.end(), [](const elem_t& x, const elem_t& y) -> bool {
return x.first > y.first;
});
// Collect every expert belonging to one of the winning groups.
for (int64_t g = 0; g < topk_group; ++g) {
int32_t group_idx = queue[g].second;
for (int64_t e = 0; e < num_experts_per_group; ++e) {
int32_t expert_idx = group_idx * num_experts_per_group + e;
queue2[g * num_experts_per_group + e] = {scores2[expert_idx], expert_idx};
}
}
// find global topk
std::partial_sort(
queue2.begin(), queue2.begin() + TOPK, queue2.end(), [](const elem_t& x, const elem_t& y) -> bool {
return x.first > y.first;
});
// Selection used the biased scores; the emitted weight is the unbiased one.
for (int j = 0; j < TOPK; ++j) {
int32_t index = queue2[j].second;
topk_ids[i * TOPK + j] = index;
topk_weights[i * TOPK + j] = scores[index];
}
#if defined(CPU_CAPABILITY_AVX512)
// Masked-load renormalization: sum the TOPK weights and scale to 1.
if (renormalize) {
__mmask16 mask = (1ULL << TOPK) - 1;
__m512 x = _mm512_maskz_loadu_ps(mask, topk_weights + i * TOPK);
float sum = _mm512_reduce_add_ps(x);
__m512 vscale = _mm512_set1_ps(1.f / sum);
__m512 y = _mm512_mul_ps(x, vscale);
_mm512_mask_storeu_ps(topk_weights + i * TOPK, mask, y);
}
#else
// Scalar fallback for non-AVX512 builds.
if (renormalize) {
float sum = 0.f;
for (int64_t j = 0; j < TOPK; ++j) {
sum += topk_weights[i * TOPK + j];
}
float scale = 1.f / sum;
for (int64_t j = 0; j < TOPK; ++j) {
topk_weights[i * TOPK + j] *= scale;
}
}
#endif
}
});
}
// Dispatch helpers: each macro instantiates the matching kernel template for a
// compile-time expert count NE (and TOPK for the biased variant). They expand
// inside the AT_DISPATCH lambdas below and reference the enclosing locals
// (topk_weights, topk_ids, gating_output, num_tokens, topk, ...) by name.
#define LAUNCH_GROUPED_TOPK_KERNEL(NE) \
grouped_topk_kernel_impl<scalar_t, NE>( \
topk_weights.data_ptr<float>(), \
topk_ids.data_ptr<int32_t>(), \
gating_output.data_ptr<scalar_t>(), \
num_tokens, \
topk, \
num_expert_group, \
topk_group, \
renormalize);
#define LAUNCH_TOPK_SIGMOID_KERNEL(NE) \
topk_sigmoid_kernel_impl<scalar_t, NE>( \
topk_weights.data_ptr<float>(), \
topk_ids.data_ptr<int32_t>(), \
gating_output.data_ptr<scalar_t>(), \
num_tokens, \
topk, \
renormalize);
#define LAUNCH_TOPK_SOFTMAX_KERNEL(NE) \
topk_softmax_kernel_impl<scalar_t, NE>( \
topk_weights.data_ptr<float>(), \
topk_ids.data_ptr<int32_t>(), \
gating_output.data_ptr<scalar_t>(), \
num_tokens, \
topk, \
renormalize);
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
biased_grouped_topk_kernel_impl<scalar_t, param_t, NE, NTOPK>( \
topk_weights.data_ptr<float>(), \
topk_ids.data_ptr<int32_t>(), \
gating_output.data_ptr<scalar_t>(), \
correction_bias.data_ptr<param_t>(), \
num_tokens, \
num_expert_group, \
topk_group, \
renormalize);
} // anonymous namespace
// Entry point for sigmoid-gated top-k routing (topk must be 1).
// gating_output: [num_tokens, num_experts] logits, same dtype as hidden_states.
// Returns (topk_weights [num_tokens, topk] float, topk_ids [num_tokens, topk] int32).
// Dispatches to a kernel specialized on the (power-of-two-ish) expert count.
std::tuple<at::Tensor, at::Tensor>
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
RECORD_FUNCTION("sgl-kernel::topk_sigmoid_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
CHECK_INPUT(gating_output);
const auto st = hidden_states.scalar_type();
CHECK_EQ(gating_output.scalar_type(), st);
int64_t num_tokens = hidden_states.size(0);
int64_t num_experts = gating_output.size(1);
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
TORCH_CHECK(topk == 1, "topk_sigmoid only supports topk=1 case");
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
// Dispatch over reduced-float dtypes; the switch picks the compile-time
// expert-count specialization (the LAUNCH macro reads the locals above).
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_sigmoid_kernel", [&] {
switch (num_experts) {
case 1:
LAUNCH_TOPK_SIGMOID_KERNEL(1);
break;
case 2:
LAUNCH_TOPK_SIGMOID_KERNEL(2);
break;
case 4:
LAUNCH_TOPK_SIGMOID_KERNEL(4);
break;
case 8:
LAUNCH_TOPK_SIGMOID_KERNEL(8);
break;
case 16:
LAUNCH_TOPK_SIGMOID_KERNEL(16);
break;
case 32:
LAUNCH_TOPK_SIGMOID_KERNEL(32);
break;
case 64:
LAUNCH_TOPK_SIGMOID_KERNEL(64);
break;
case 128:
LAUNCH_TOPK_SIGMOID_KERNEL(128);
break;
case 160:
LAUNCH_TOPK_SIGMOID_KERNEL(160);
break;
case 256:
LAUNCH_TOPK_SIGMOID_KERNEL(256);
break;
default:
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
}
});
return std::make_tuple(topk_weights, topk_ids);
}
// Entry point for softmax-gated top-k routing.
// gating_output: [num_tokens, num_experts] logits, same dtype as hidden_states.
// Returns (topk_weights [num_tokens, topk] float, topk_ids [num_tokens, topk] int32).
std::tuple<at::Tensor, at::Tensor>
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
RECORD_FUNCTION("sgl-kernel::topk_softmax_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
CHECK_INPUT(gating_output);
const auto st = hidden_states.scalar_type();
CHECK_EQ(gating_output.scalar_type(), st);
int64_t num_tokens = hidden_states.size(0);
int64_t num_experts = gating_output.size(1);
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
// Dispatch over reduced-float dtypes; the switch picks the compile-time
// expert-count specialization (the LAUNCH macro reads the locals above).
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_softmax_cpu", [&] {
switch (num_experts) {
case 1:
LAUNCH_TOPK_SOFTMAX_KERNEL(1);
break;
case 2:
LAUNCH_TOPK_SOFTMAX_KERNEL(2);
break;
case 4:
LAUNCH_TOPK_SOFTMAX_KERNEL(4);
break;
case 8:
LAUNCH_TOPK_SOFTMAX_KERNEL(8);
break;
case 16:
LAUNCH_TOPK_SOFTMAX_KERNEL(16);
break;
case 32:
LAUNCH_TOPK_SOFTMAX_KERNEL(32);
break;
case 64:
LAUNCH_TOPK_SOFTMAX_KERNEL(64);
break;
case 128:
LAUNCH_TOPK_SOFTMAX_KERNEL(128);
break;
case 160:
LAUNCH_TOPK_SOFTMAX_KERNEL(160);
break;
case 256:
LAUNCH_TOPK_SOFTMAX_KERNEL(256);
break;
default:
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
}
});
return std::make_tuple(topk_weights, topk_ids);
}
// grouped topk for DeepSeek V2
// Entry point for grouped top-k routing: softmax gating with group-restricted
// selection (see grouped_topk_kernel_impl). The trailing three parameters are
// accepted for API compatibility but only their default values are supported
// for now (checked below).
// Returns (topk_weights [num_tokens, topk] float, topk_ids [num_tokens, topk] int32).
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
at::Tensor& hidden_states,
at::Tensor& gating_output,
int64_t topk,
bool renormalize,
int64_t num_expert_group,
int64_t topk_group,
int64_t num_fused_shared_experts,
std::optional<double> routed_scaling_factor,
std::optional<at::Tensor> num_token_non_padded) {
// TODO: Will support num_fused_shared_experts, routed_scaling_factor and num_token_non_padded.
// For now, we just check them as default value.
TORCH_CHECK(
num_fused_shared_experts == 0,
"num_fused_shared_experts must be 0 default value, got: ",
num_fused_shared_experts);
TORCH_CHECK(
!routed_scaling_factor.has_value() || routed_scaling_factor.value() == 1.0f,
"routed_scaling_factor must be None or 1.0f default value, got: ",
routed_scaling_factor.value());
TORCH_CHECK(
!num_token_non_padded.has_value(),
"num_token_non_padded must be None default value, got: ",
num_token_non_padded.value());
RECORD_FUNCTION("sgl-kernel::grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
CHECK_INPUT(gating_output);
const auto st = hidden_states.scalar_type();
CHECK_EQ(gating_output.scalar_type(), st);
int64_t num_tokens = hidden_states.size(0);
int64_t num_experts = gating_output.size(1);
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
// Dispatch over reduced-float dtypes; the switch picks the compile-time
// expert-count specialization (the LAUNCH macro reads the locals above).
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "grouped_topk_kernel", [&] {
switch (num_experts) {
case 1:
LAUNCH_GROUPED_TOPK_KERNEL(1);
break;
case 2:
LAUNCH_GROUPED_TOPK_KERNEL(2);
break;
case 4:
LAUNCH_GROUPED_TOPK_KERNEL(4);
break;
case 8:
LAUNCH_GROUPED_TOPK_KERNEL(8);
break;
case 16:
LAUNCH_GROUPED_TOPK_KERNEL(16);
break;
case 32:
LAUNCH_GROUPED_TOPK_KERNEL(32);
break;
case 64:
LAUNCH_GROUPED_TOPK_KERNEL(64);
break;
case 128:
LAUNCH_GROUPED_TOPK_KERNEL(128);
break;
case 160:
LAUNCH_GROUPED_TOPK_KERNEL(160);
break;
case 256:
LAUNCH_GROUPED_TOPK_KERNEL(256);
break;
default:
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
}
});
return std::make_tuple(topk_weights, topk_ids);
}
// biased grouped topk DeepSeek V3/R1
// Entry point for bias-corrected grouped top-k routing (sigmoid gate + per-
// expert correction bias; see biased_grouped_topk_kernel_impl). Only the
// DeepSeek-V3/R1 shape (256 experts, topk == 8) is currently specialized.
// Returns (topk_weights [num_tokens, topk] float, topk_ids [num_tokens, topk] int32).
std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
at::Tensor& hidden_states,
at::Tensor& gating_output,
at::Tensor& correction_bias,
int64_t topk,
bool renormalize,
int64_t num_expert_group,
int64_t topk_group,
int64_t num_fused_shared_experts,
std::optional<double> routed_scaling_factor,
std::optional<at::Tensor> num_token_non_padded) {
// TODO: Will support num_fused_shared_experts, routed_scaling_factor and num_token_non_padded.
// For now, we just check them as default value.
TORCH_CHECK(
num_fused_shared_experts == 0,
"num_fused_shared_experts must be 0 default value, got: ",
num_fused_shared_experts);
TORCH_CHECK(
!num_token_non_padded.has_value(),
"num_token_non_padded must be None default value, got: ",
num_token_non_padded.value());
RECORD_FUNCTION(
"sgl-kernel::biased_grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output, correction_bias}));
CHECK_INPUT(gating_output);
CHECK_INPUT(correction_bias);
const auto st = hidden_states.scalar_type();
CHECK_EQ(gating_output.scalar_type(), st);
int64_t num_tokens = hidden_states.size(0);
int64_t num_experts = gating_output.size(1);
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
TORCH_CHECK(correction_bias.numel() == num_experts, "Bias shape mismatch");
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
// Dispatch over (activation dtype, bias dtype); only NUM_EXPERTS=256, TOPK=8
// is instantiated today.
CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(st, correction_bias.scalar_type(), "biased_grouped_topk_kernel", [&] {
TORCH_CHECK(topk == 8, "Unexpected topk: ", topk);
switch (num_experts) {
case 256:
LAUNCH_BIASED_GROUPED_TOPK_KERNEL(256, 8);
break;
default:
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
}
});
return std::make_tuple(topk_weights, topk_ids);
}

View File

@@ -0,0 +1,373 @@
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/ATen.h>
#include <torch/all.h>
#include <torch/library.h>
#include "sgl_kernel_ops.h"
#include "shm.h"
// silu_and_mul
at::Tensor silu_and_mul_cpu(at::Tensor& input);
// gelu_and_mul
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input);
at::Tensor gelu_and_mul_cpu(const at::Tensor& input);
// l2norm
at::Tensor l2norm_cpu(at::Tensor& input, double eps);
// rmsnorm
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
// fused_add_rmsnorm
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);
// topk
std::tuple<at::Tensor, at::Tensor>
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
std::tuple<at::Tensor, at::Tensor>
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
at::Tensor& hidden_states,
at::Tensor& gating_output,
int64_t topk,
bool renormalize,
int64_t num_expert_group,
int64_t topk_group,
int64_t num_fused_shared_experts,
std::optional<double> routed_scaling_factor,
std::optional<at::Tensor> num_token_non_padded);
std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
at::Tensor& hidden_states,
at::Tensor& gating_output,
at::Tensor& correction_bias,
int64_t topk,
bool renormalize,
int64_t num_expert_group,
int64_t topk_group,
int64_t num_fused_shared_experts,
std::optional<double> routed_scaling_factor,
std::optional<at::Tensor> num_token_non_padded);
// attention
void decode_attention_cpu(
at::Tensor& query,
at::Tensor& k_cache,
at::Tensor& v_cache,
at::Tensor& output,
at::Tensor& key,
at::Tensor& value,
at::Tensor& loc,
at::Tensor& attn_logits,
at::Tensor& req_to_token,
at::Tensor& req_pool_indices,
at::Tensor& seq_lens,
double sm_scale,
double logit_cap);
void extend_attention_cpu(
at::Tensor& q_extend,
at::Tensor& k_extend,
at::Tensor& v_extend,
at::Tensor& o_extend,
at::Tensor& k_buffer,
at::Tensor& v_buffer,
at::Tensor& req_to_token,
at::Tensor& req_pool_indices,
at::Tensor& seq_lens,
at::Tensor& extend_seq_lens,
at::Tensor& extend_start_loc,
int64_t max_len_extend,
double sm_scale,
double logit_cap);
// weight prepack
// Repacks a weight tensor into the blocked (VNNI) layout consumed by the `is_vnni`
// fast paths of the GEMM/MoE kernels below.
at::Tensor convert_weight_packed(at::Tensor& weight);
// quant
// Dynamic per-token (per-row) int8 quantization of A; returns two tensors —
// presumably (quantized values, per-row scales).
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);
// gemm
// Plain linear with optional bias; `is_vnni` signals that `mat2` is already packed
// by convert_weight_packed.
at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);
// igemm
// int8 x int8 GEMM with per-side dequant scales (`scales1` for mat1, `scales2` for
// mat2); output produced in `out_dtype`.
at::Tensor int8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales1,
at::Tensor& scales2,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni);
// fp8 gemm
// fp8-weight GEMM; `block_size` gives the block-wise quantization granularity of
// `scales2`.
at::Tensor fp8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
std::vector<int64_t> block_size,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni);
// quant + igemm
// Fused variant: quantizes mat1 per token internally (hence no `scales1`) before the
// int8 GEMM.
at::Tensor int8_scaled_mm_with_quant(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
const std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni);
// bmm
// Batched matmul writing into preallocated `out` (Tensor(a!) in the registered
// schema); optional `scale` applies dequantization — confirm semantics in the kernel.
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);
// fused moe
// Fused MoE expert computation: routes tokens per (topk_weights, topk_ids) through
// gate/up weights `w1` and down weights `w2`. Quantization mode is selected by
// use_int8_w8a8 / use_fp8_w8a16 with the matching optional scales; `inplace`
// presumably reuses `hidden_states` storage for the result — verify in the kernel.
at::Tensor fused_experts_cpu(
at::Tensor& hidden_states,
at::Tensor& w1,
at::Tensor& w2,
at::Tensor& topk_weights,
at::Tensor& topk_ids,
bool inplace,
bool use_int8_w8a8,
bool use_fp8_w8a16,
const std::optional<at::Tensor>& w1_scale,
const std::optional<at::Tensor>& w2_scale,
const std::optional<std::vector<int64_t>> block_size,
const std::optional<at::Tensor>& a1_scale,
const std::optional<at::Tensor>& a2_scale,
bool is_vnni);
// Shared-expert path: same quantization knobs as fused_experts_cpu, combined with the
// routed result `fused_experts_out` scaled by `routed_scaling_factor`.
at::Tensor shared_expert_cpu(
at::Tensor& hidden_states,
at::Tensor& w1,
at::Tensor& w2,
at::Tensor& fused_experts_out,
double routed_scaling_factor,
bool inplace,
bool use_int8_w8a8,
bool use_fp8_w8a16,
const std::optional<at::Tensor>& w1_scale,
const std::optional<at::Tensor>& w2_scale,
const std::optional<std::vector<int64_t>> block_size,
const std::optional<at::Tensor>& a1_scale,
const std::optional<at::Tensor>& a2_scale,
bool is_vnni);
// weight absorption
// Fused low-rank (MLA-style) QKV projection with RMSNorm on the latent vectors and
// rotary embedding applied via `positions`/`cos_sin_cache`; returns three tensors.
// NOTE(review): the exact (q, k, v) layout of the returned tuple is not visible
// here — confirm against the kernel and its Python caller.
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
at::Tensor& hidden_states,
at::Tensor& q_a_proj_weight,
at::Tensor& q_b_proj_weight,
at::Tensor& kv_a_proj_weight,
at::Tensor& w_kc,
at::Tensor& q_a_layernorm_weight,
at::Tensor& kv_a_layernorm_weight,
at::Tensor& positions,
at::Tensor& cos_sin_cache,
double eps,
bool use_int8_w8a8,
bool use_fp8_w8a16,
std::optional<at::Tensor> q_a_proj_scale,
std::optional<at::Tensor> q_b_proj_scale,
std::optional<at::Tensor> kv_a_proj_scale,
bool is_vnni,
std::optional<std::vector<int64_t>> block_size);
// Variant taking the q_a and kv_a projections pre-fused into one weight
// (`qkv_a_proj_weight`); the extra rank/head-dim arguments let the kernel split the
// fused projection output back into its q / kv / rope parts.
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
at::Tensor& hidden_states,
at::Tensor& qkv_a_proj_weight,
at::Tensor& q_b_proj_weight,
at::Tensor& w_kc,
at::Tensor& q_a_layernorm_weight,
at::Tensor& kv_a_layernorm_weight,
at::Tensor& positions,
at::Tensor& cos_sin_cache,
double eps,
bool use_int8_w8a8,
bool use_fp8_w8a16,
std::optional<at::Tensor> qkv_a_proj_scale,
std::optional<at::Tensor> q_b_proj_scale,
bool is_vnni,
std::optional<std::vector<int64_t>> block_size,
int64_t q_lora_rank,
int64_t kv_lora_rank,
int64_t qk_rope_head_dim);
// shared memory init
// Sets up the shared-memory communication context for `size` ranks (this process is
// `rank`); must run before shm_allreduce/shm_allgather. Registered under CatchAll,
// not kCPU, at the end of this file.
void initialize(int64_t size, int64_t rank);
// shared memory all_reduce (in-place on `data`; `op` selects the reduction)
void shm_allreduce(at::Tensor& data, int64_t op);
// shared memory all_gather (concatenates along `dim`)
at::Tensor shm_allgather(at::Tensor& data, int64_t dim);
// rope
// Applies rotary position embedding to query/key using the precomputed
// `cos_sin_cache`; `is_neox` selects the NeoX (split-half) rotation layout.
// Returns the (query, key) pair per the registered schema.
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
at::Tensor& positions,
at::Tensor& query,
at::Tensor& key,
int64_t head_size,
at::Tensor& cos_sin_cache,
bool is_neox);
// CPU and memory binding
// Configures worker-thread/memory affinity from the `cpu_ids` spec and returns a
// status string — NOTE(review): exact binding mechanism lives in the implementation.
std::string init_cpu_threads_env(const std::string& cpu_ids);
// Schema + CPU dispatch registration for every operator declared above.
// TorchScript schema reminders: `float` maps to double, `int` to int64_t,
// `Tensor(a!)` marks an argument mutated in place, `Tensor?`/`int[]?` are optional.
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
// activation
m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
m.def("gelu_tanh_and_mul_cpu(Tensor input) -> Tensor");
m.impl("gelu_tanh_and_mul_cpu", torch::kCPU, &gelu_tanh_and_mul_cpu);
m.def("gelu_and_mul_cpu(Tensor input) -> Tensor");
m.impl("gelu_and_mul_cpu", torch::kCPU, &gelu_and_mul_cpu);
// norm
m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor residual, Tensor weight, float eps) -> ()");
m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);
// topk
m.def("topk_sigmoid_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
m.impl("topk_sigmoid_cpu", torch::kCPU, &topk_sigmoid_cpu);
m.def("topk_softmax_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
m.def(
"grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
"int topk_group, int num_fused_shared_experts, float? routed_scaling_factor, Tensor? num_token_non_padded) -> "
"(Tensor, Tensor)");
m.impl("grouped_topk_cpu", torch::kCPU, &grouped_topk_cpu);
// biased group topk
m.def(
"biased_grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, Tensor correction_bias, int topk, bool "
"renormalize, int num_expert_group, int topk_group, int num_fused_shared_experts, float? routed_scaling_factor, "
"Tensor? num_token_non_padded) -> (Tensor, Tensor)");
m.impl("biased_grouped_topk_cpu", torch::kCPU, &biased_grouped_topk_cpu);
// decode
// NOTE(review): "v_cahce" is a typo for "v_cache" in the schema argument name. It is
// deliberately left byte-identical here: renaming a schema argument breaks any caller
// that passes it by keyword, so it must be fixed in lockstep with the Python call
// sites, not in this file alone.
m.def(
"decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor(a!) output, Tensor key, Tensor value, "
"Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
"float logit_cap) -> ()");
m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);
// extend
m.def(
"extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor(a!) o_extend, Tensor k_buffer, "
"Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
"extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);
// weight prepack
m.def("convert_weight_packed(Tensor weight) -> Tensor");
m.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
// quant
m.def("per_token_quant_int8_cpu(Tensor A) -> (Tensor, Tensor)");
m.impl("per_token_quant_int8_cpu", torch::kCPU, &per_token_quant_int8_cpu);
// gemm
m.def("weight_packed_linear(Tensor mat1, Tensor mat2, Tensor? bias, bool is_vnni) -> Tensor");
m.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
// igemm
m.def(
"int8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales1, Tensor scales2, Tensor? bias, ScalarType "
"out_dtype, bool is_vnni) -> Tensor");
m.impl("int8_scaled_mm_cpu", torch::kCPU, &int8_scaled_mm_cpu);
// fp8 gemm
m.def(
"fp8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales2, int[] block_size, Tensor? bias, ScalarType "
"out_dtype, bool is_vnni) -> Tensor");
m.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu);
// quant + igemm
m.def(
"int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? bias, ScalarType out_dtype, bool "
"is_vnni) -> Tensor");
m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);
// bmm
m.def("bmm_cpu(Tensor(a!) out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);
// moe
m.def(
"fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool "
"inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, int[]? block_size, Tensor? "
"a1_scale, Tensor? a2_scale, bool "
"is_vnni) -> Tensor");
m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
// weight absorption
m.def(
"qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
"kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? q_a_proj_scale, Tensor? "
"q_b_proj_scale, Tensor? "
"kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)");
m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
// NOTE(review): the adjacent literals "q_b_proj_scale," and "bool is_vnni" below
// concatenate without a space. This appears to rely on the schema lexer tolerating
// the missing whitespace — confirm; inserting a space would be cleaner but is a
// behavioral string change, so it is left byte-identical here.
m.def(
"qkv_proj_with_rope_fused_weight(Tensor hidden_states, Tensor qkv_a_proj_weight, Tensor q_b_proj_weight, "
"Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? qkv_a_proj_scale, Tensor? "
"q_b_proj_scale,"
"bool is_vnni, int[]? block_size, int q_lora_rank, int kv_lora_rank,"
"int qk_rope_head_dim) -> (Tensor, Tensor, Tensor)");
m.impl("qkv_proj_with_rope_fused_weight", torch::kCPU, &qkv_proj_with_rope_fused_weight);
// shared expert
m.def(
"shared_expert_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor fused_experts_out, float "
"routed_scaling_factor, bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? "
"w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor");
m.impl("shared_expert_cpu", torch::kCPU, &shared_expert_cpu);
// all reduce
// `initialize` and `init_cpu_threads_env` only get their schemas defined here; their
// implementations are registered in the CatchAll TORCH_LIBRARY_IMPL block below.
m.def("initialize(int size, int rank) -> ()");
m.def("shm_allreduce(Tensor(a!) data, int reduce_op) -> ()");
m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
m.def("shm_allgather(Tensor data, int dim) -> Tensor");
m.impl("shm_allgather", torch::kCPU, &shm_allgather);
// rope
m.def(
"rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
"bool is_neox) -> (Tensor, Tensor)");
m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);
// CPU and memory binding
m.def("init_cpu_threads_env(str cpu_ids) -> str");
}
// CatchAll registrations: these two ops manage process-wide state (thread/memory
// binding, shared-memory setup) rather than computing on a device-resident tensor,
// so they are not tied to the CPU dispatch key.
TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
m.impl("init_cpu_threads_env", init_cpu_threads_env);
m.impl("initialize", &initialize);
}
// Exposes the registrations above as the `common_ops` Python extension module.
REGISTER_EXTENSION(common_ops)

308
sgl-kernel/csrc/cpu/vec.h Normal file
View File

@@ -0,0 +1,308 @@
#pragma once
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
#define CPU_CAPABILITY_AVX512
#endif
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
namespace {
using namespace at::vec;
// Narrowing pack: combine two fp32 vectors (`a` = low half, `b` = high half) into one
// reduced-precision (fp16/bf16) vector. Generic path defers to at::vec; an AVX512-BF16
// specialization exists further down in this file.
template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) {
return at::vec::convert_from_float<scalar_t>(a, b);
}
// allow f16, bf16
// Load one reduced-precision vector's worth of elements from `data` and widen them
// into a (low, high) pair of fp32 vectors.
template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 1>
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const scalar_t* __restrict__ data) {
  using bVec = at::vec::Vectorized<scalar_t>;
  // convert_to_float already returns the (low, high) fp32 tuple we need.
  return at::vec::convert_to_float(bVec::loadu(data));
}
// allow f32
// fp32 input needs no widening: just load two consecutive fp32 vectors.
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const float* __restrict__ data) {
  using fVec = at::vec::Vectorized<float>;
  return std::make_tuple(fVec::loadu(data), fVec::loadu(data + fVec::size()));
}
#if defined(CPU_CAPABILITY_AVX512)
// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
// use native instruction for float32->bfloat16 conversion
// (_mm512_cvtne2ps_pbh packs two fp32 vectors into one bf16 vector with
// round-to-nearest-even; `b` fills the high half and `a` the low half, matching the
// (low=a, high=b) contract of the generic path above)
template <>
inline Vectorized<at::BFloat16>
convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a)));
}
// Widen 16 bf16 values (__m256i) to 16 fp32 values (__m512): bf16 is the upper half
// of an IEEE fp32, so zero-extend each lane to 32 bits and shift left by 16.
#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
// BUGFIX: this macro previously expanded to _mm512_cvtps_ph(a, ...), which converts
// fp32 -> fp16 — the opposite of its name (and takes __m512, not __m256i).
// _mm512_cvtph_ps widens 16 fp16 values (__m256i) to 16 fp32 values (__m512),
// mirroring CVT_BF16_TO_FP32 above.
#define CVT_FP16_TO_FP32(a) _mm512_cvtph_ps(a)
// this doesn't handle NaN.
// fp8 e4m3 (bias 7) -> bf16 (bias 127): rebase the 4-bit exponent by +120 and place
// the 3-bit mantissa in the top of bf16's 7-bit mantissa. Denormal e4m3 encodings are
// not decoded specially; only bytes equal to 0x00 are forced to +0.
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
// mantissa: e4m3 bits [2:0] -> bf16 mantissa bits [6:4]
const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4);
// exponent: e4m3 bits [6:3], rebased by +120, moved into bf16 bits [14:7]
const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3);
const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7);
const __m512i nonsign = _mm512_or_si512(exp, mant);
// sign: bit 7 -> bit 15
const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8);
const __m512i combined = _mm512_or_si512(nonsign, sign);
// force all-zero input bytes to +0 (the +120 rebase above would otherwise emit 2^-7)
const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512());
return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined);
}
inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) {
// The following conversion is without denorm behavior, that is to say,
// Max subnorm : S.0000.111 = 0.875 * 2**(-6)
// Min subnorm : S.0000.001 = 2**(-9)
// 0.0019 ~ 0.0137 cannot be converted correctly.
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
auto mask = _mm512_cmpneq_epi16_mask(
_mm512_and_si512(x, _mm512_set1_epi16(127)),
_mm512_setzero_si512()); // mask: (x & 0x7f) != 0, i.e. input is not +/-0
auto mask_nan = _mm512_cmpneq_epi16_mask(
_mm512_and_si512(x, _mm512_set1_epi16(127)),
_mm512_set1_epi16(127)); // mask_nan: (x & 0x7f) != 0x7f, i.e. input is NOT NaN
auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4
auto exponent = _mm512_add_epi16(
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3),
_mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120)
auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7)));
nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // NaN inputs become bf16 NaN (0x7fff)
return (__m512bh)(_mm512_or_si512(
nonsign,
_mm512_slli_epi16(
_mm512_and_si512(x, _mm512_set1_epi16(128)),
8))); // add sign (x & 128) << 8
}
// Denormal-preserving fp8 e4m3 -> bf16 conversion. Normal inputs get the +120
// exponent rebase; denormal inputs (exponent bits all zero) are renormalized using
// the position of the leading mantissa bit (lg2mant).
inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) {
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
// lg2mant = index (0/1/2) of the highest set bit among the 3 mantissa bits
__m512i lg2mant = _mm512_mask_mov_epi16(
_mm512_mask_mov_epi16(
_mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)),
_mm512_test_epi16_mask(x, _mm512_set1_epi16(4)),
_mm512_set1_epi16(2));
return (__m512bh)(_mm512_or_si512(
// zero out lanes whose magnitude bits are all zero (+/-0)
_mm512_maskz_mov_epi16(
_mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()),
// blend selects the normal path where the e4m3 exponent bits (0x78) are nonzero,
// otherwise the denormal (renormalizing) path
_mm512_mask_blend_epi16(
_mm512_test_epi16_mask(x, _mm512_set1_epi16(120)),
_mm512_or_si512(
_mm512_and_si512(
_mm512_sllv_epi16(
_mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)),
_mm512_set1_epi16(0x007f)),
_mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)),
_mm512_or_si512(
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4),
_mm512_slli_epi16(
_mm512_add_epi16(
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)),
7)))),
// sign: bit 7 -> bit 15
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8)));
}
// Dispatcher for fp8(e4m3) -> bf16 conversion. When SGLANG_CPU_FP8_CVT_FTZ is defined
// the fast flush-to-zero-style path is used (denormals and NaN are not handled —
// see cvt_e4m3_bf16_intrinsic_no_nan); otherwise the denormal-preserving path is used.
inline __m512bh CVT_FP8_TO_BF16(__m256i a) {
#ifdef SGLANG_CPU_FP8_CVT_FTZ
return cvt_e4m3_bf16_intrinsic_no_nan(a);
#else
return cvt_e4m3_bf16_intrinsic_with_denorm(a);
#endif
}
#endif
// vector to scalar reduction
// NOTE(review): the `&& 0` below disables the AVX512 `_mm512_reduce_*` fast path, so
// the generic at::vec fallback is always used. Enabling it would change the float
// accumulation order (and thus bit-exact results) — confirm whether this is a
// deliberate choice or a leftover debug switch before flipping it.
#if defined(CPU_CAPABILITY_AVX512) && 0
inline float vec_reduce_sum(const Vectorized<float>& a) {
return _mm512_reduce_add_ps(__m512(a));
}
inline float vec_reduce_max(const Vectorized<float>& a) {
return _mm512_reduce_max_ps(__m512(a));
}
#else
// Horizontal sum of all lanes via pairwise tree reduction.
inline float vec_reduce_sum(const Vectorized<float>& a) {
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return x + y; }, a);
}
// Horizontal max of all lanes.
inline float vec_reduce_max(const Vectorized<float>& a) {
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return maximum(x, y); }, a);
}
#endif
// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
// Symmetric per-row dynamic int8 quantization, stored with an unsigned +128 offset:
//   Aq[k] = round(A[k] * 127 / amax) + 128,   As = amax / 127
// `eps` clamps amax so an all-zero row does not divide by zero.
template <typename scalar_t>
inline void
quantize_row_int8(uint8_t* __restrict__ Aq, float& As, const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) {
  float amax = 0.f;  // absolute max
  for (int64_t k = 0; k < K; ++k) {
    const float val = static_cast<float>(A[k]);
    amax = std::max(amax, std::abs(val));
  }
  amax = std::max(amax, eps);
  const float scale = amax / 127;
  const float inv_scale = 127 / amax;
  for (int64_t k = 0; k < K; ++k) {
    const float val = static_cast<float>(A[k]) * inv_scale;
    // BUGFIX: the original cast `(uint8_t)(std::round(val))` converts a negative
    // float directly to an unsigned type, which is undefined behavior in C++ when the
    // truncated value is out of range. Round in the signed integer domain first, then
    // apply the +128 offset — this also matches the AVX512 specialization below,
    // which adds 128 as epi32 before narrowing.
    const int32_t q = static_cast<int32_t>(std::round(val)) + 128;
    Aq[k] = static_cast<uint8_t>(q);
  }
  As = scale;
}
#if defined(CPU_CAPABILITY_AVX512)
// AVX512 bf16 specialization of quantize_row_int8: same contract as the generic
// template, vectorized 32 elements at a time.
// NOTE(review): _mm512_roundscale_ps with _MM_FROUND_TO_NEAREST_INT rounds half to
// even, while the scalar template uses std::round (half away from zero) — results
// can differ on exact .5 values.
template <>
inline void quantize_row_int8<at::BFloat16>(
uint8_t* __restrict__ Aq, float& As, const at::BFloat16* __restrict__ A, int64_t K, float eps) {
const __m512 signBit = _mm512_set1_ps(-0.0f);
const __m512i off = _mm512_set1_epi32(128);
// K is 32x, no remainder
// Pass 1: absolute max over the row (andnot with -0.0f clears the sign bit).
float amax = 0.f;
__m512 vamax0 = _mm512_set1_ps(0.f);
__m512 vamax1 = _mm512_set1_ps(0.f);
for (int64_t k = 0; k < K; k += 32) {
__m512i va = _mm512_loadu_si512((void*)(A + k));
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0));
vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1));
}
amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1));
amax = std::max(amax, eps);
const float scale = amax / 127;
const float inv_scale = 127 / amax;
const __m512 vd = _mm512_set1_ps(inv_scale);
// Pass 2: scale, round, add the +128 offset as int32, then narrow to uint8.
for (int64_t k = 0; k < K; k += 32) {
__m512i va = _mm512_loadu_si512((void*)(A + k));
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
va0 = _mm512_mul_ps(va0, vd);
va1 = _mm512_mul_ps(va1, vd);
va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
__m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off));
__m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off));
_mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0));
}
As = scale;
}
#endif
// transpose utils
// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998
#if defined(CPU_CAPABILITY_AVX512)
// In-place transpose of a 16x16 matrix of 32-bit elements held in 16 zmm registers
// (v[i] = row i). Classic 4-stage butterfly: 32-bit unpack, 64-bit unpack, then two
// rounds of 128-bit-lane shuffles (0x88 selects even lanes, 0xdd odd lanes).
inline void transpose_16x16_32bit(__m512i* v) {
__m512i v1[16];
// Stage 1: interleave 32-bit elements of adjacent row pairs.
v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
// Stage 2: interleave 64-bit pairs, yielding 4-element transposed strips.
v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
// Stage 3: shuffle 128-bit lanes between register pairs 4 apart.
v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
// Stage 4: shuffle 128-bit lanes between register pairs 8 apart — rows complete.
v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
}
// remove warning : ignoring attributes on template argument __m512i [-Wignored-attributes]
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"
// transpose from [2, 32] to [32, 2]
// i.e. interleave two 32-element 16-bit rows into (a,b) pairs — the layout used for
// VNNI packing. One 16-bit unpack stage followed by two 128-bit-lane shuffle stages
// to restore ascending pair order across lanes.
inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) {
// r0: {a0, a1, ..., a31}
// r1: {b0, b1, ..., b31}
//
// d0: {a0, b0, ..., a15, b15}
// d1: {a16, b16, ..., a31, b31}
//
__m512i d0 = _mm512_unpacklo_epi16(r0, r1);
__m512i d1 = _mm512_unpackhi_epi16(r0, r1);
// Two lane-shuffle rounds (0x88 = even 128-bit lanes, 0xdd = odd) reorder the
// interleaved lanes into sequential (a_i, b_i) pair order.
r0 = _mm512_shuffle_i32x4(d0, d1, 0x88);
r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd);
d0 = _mm512_shuffle_i32x4(r0, r1, 0x88);
d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd);
return std::make_tuple(d0, d1);
}
#pragma GCC diagnostic pop
#endif
} // anonymous namespace