sglang v0.5.2 & support Qwen3-Next-80B-A3B-Instruct
This commit is contained in:
83
sgl-kernel/csrc/cpu/CMakeLists.txt
Executable file
83
sgl-kernel/csrc/cpu/CMakeLists.txt
Executable file
@@ -0,0 +1,83 @@
|
||||
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(sgl_kernel)

set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_EXTENSIONS OFF)

# -march=native is non-portable; keep it as the default (previous behavior)
# but allow opting out for reproducible / cross-node builds.
option(SGL_KERNEL_CPU_NATIVE_ARCH "Compile with -march=native" ON)

find_package(Python COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT} REQUIRED)

# Query the build interpreter for the Torch CMake package location. Fail at
# configure time when torch is not importable — previously a failed import
# produced an empty prefix and a confusing find_package(Torch) error later.
execute_process(
  COMMAND "${Python_EXECUTABLE}" -c "import torch; print(torch.utils.cmake_prefix_path)"
  OUTPUT_VARIABLE TORCH_PY_PREFIX
  OUTPUT_STRIP_TRAILING_WHITESPACE
  RESULT_VARIABLE _torch_query_rc
)
if(NOT _torch_query_rc EQUAL 0)
  message(FATAL_ERROR "Unable to import torch with ${Python_EXECUTABLE}")
endif()

message(STATUS "Torch CMake prefix: ${TORCH_PY_PREFIX}")
list(APPEND CMAKE_PREFIX_PATH "${TORCH_PY_PREFIX}/Torch")
find_package(Torch REQUIRED)

# Platform-specific library directory
if(CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|AMD64")
  set(PLAT_LIB_DIR "/usr/lib/x86_64-linux-gnu")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "aarch64|arm64")
  set(PLAT_LIB_DIR "/usr/lib/aarch64-linux-gnu")
elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "ppc64le|ppc64")
  set(PLAT_LIB_DIR "/usr/lib/powerpc64le-linux-gnu")
else()
  set(PLAT_LIB_DIR "/usr/lib/${CMAKE_SYSTEM_PROCESSOR}-linux-gnu")
endif()
link_directories("${PLAT_LIB_DIR}")

# Conda library path support
if(DEFINED ENV{CONDA_PREFIX})
  set(CONDA_LIB_DIR "$ENV{CONDA_PREFIX}/lib")
  message(STATUS "Using Conda lib dir: ${CONDA_LIB_DIR}")
  link_directories("${CONDA_LIB_DIR}")
  set(CONDA_INCLUDE_DIR "$ENV{CONDA_PREFIX}/include")

  # Look for libnuma in Conda's lib directory
  find_library(NUMA_LIB numa HINTS "${CONDA_LIB_DIR}")
  if(NUMA_LIB)
    message(STATUS "Found libnuma: ${NUMA_LIB}")
  else()
    message(FATAL_ERROR "libnuma not found in Conda environment at ${CONDA_LIB_DIR}\n"
                        "Please install it using: conda install libnuma numactl\n")
  endif()
endif()

# CONFIGURE_DEPENDS re-runs the glob at build time so newly added .cpp files
# are picked up without a manual re-configure.
file(GLOB SOURCES CONFIGURE_DEPENDS "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp")

# Flush-to-zero for fp8 conversion defaults to enabled unless the caller
# explicitly set the environment variable.
if(NOT DEFINED ENV{SGLANG_CPU_FP8_CVT_FTZ})
  set(ENV{SGLANG_CPU_FP8_CVT_FTZ} "1")
endif()

Python_add_library(common_ops MODULE USE_SABI ${SKBUILD_SABI_VERSION} WITH_SOABI ${SOURCES})

# Target-scoped definitions/options/includes instead of the previous
# directory-scoped commands, so flags cannot leak into sibling targets.
if("$ENV{SGLANG_CPU_FP8_CVT_FTZ}" STREQUAL "1")
  message(STATUS "Enabling macro: SGLANG_CPU_FP8_CVT_FTZ")
  target_compile_definitions(common_ops PRIVATE SGLANG_CPU_FP8_CVT_FTZ)
endif()

target_compile_options(common_ops PRIVATE
  -O3
  -Wno-unknown-pragmas
  -fopenmp
  $<$<BOOL:${SGL_KERNEL_CPU_NATIVE_ARCH}>:-march=native>
)

target_include_directories(common_ops PRIVATE
  ${TORCH_INCLUDE_DIRS}
  ${TORCH_INSTALL_PREFIX}/include
  ${Python_INCLUDE_DIRS}
  ${CMAKE_CURRENT_SOURCE_DIR}/../../csrc
  ${CMAKE_CURRENT_SOURCE_DIR}/../../include
  $<$<BOOL:${CONDA_INCLUDE_DIR}>:${CONDA_INCLUDE_DIR}>
)

target_link_libraries(common_ops PRIVATE ${TORCH_LIBRARIES} ${NUMA_LIB})

install(TARGETS common_ops
  LIBRARY DESTINATION sgl_kernel
)
|
||||
135
sgl-kernel/csrc/cpu/activation.cpp
Normal file
135
sgl-kernel/csrc/cpu/activation.cpp
Normal file
@@ -0,0 +1,135 @@
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// Fused "activation(x) * y" over the last dimension.
// `input` is laid out as [num_tokens, 2 * dim]: the first `dim` columns hold
// the activation operand x, the second `dim` columns the multiplicative gate
// y. `output` is [num_tokens, dim].
//
//   f  : scalar activation (float -> float), used for the vector-width tail
//   vf : vectorized activation (fVec -> fVec), used in the main SIMD loop
//
// Math is carried out in float even for reduced-precision scalar_t: one bVec
// load converts to two fVecs, and the result converts back before the store.
template <typename scalar_t, typename func_t, typename vec_func_t>
void act_and_mul_kernel_impl(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    int64_t num_tokens,
    int64_t dim,
    const func_t& f,
    const vec_func_t& vf) {
  using bVec = at::vec::Vectorized<scalar_t>;  // reduced-precision vector
  using fVec = at::vec::Vectorized<float>;     // float vector (half the lanes)

  constexpr int64_t kVecSize = bVec::size();
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs: row i of the packed [x | y] input and of the output
      const scalar_t* __restrict__ input_ptr = input + i * 2 * dim;
      const scalar_t* __restrict__ input_other_ptr = input_ptr + dim;
      scalar_t* __restrict__ output_ptr = output + i * dim;

      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= dim - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        bVec y_bvec = bVec::loadu(input_other_ptr + d);
        fVec y_fvec0, y_fvec1;
        std::tie(y_fvec0, y_fvec1) = at::vec::convert_to_float(y_bvec);

        // out = vf(x) * y, computed in float
        x_fvec0 = vf(x_fvec0);
        x_fvec1 = vf(x_fvec1);

        x_fvec0 = x_fvec0 * y_fvec0;
        x_fvec1 = x_fvec1 * y_fvec1;

        x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        x_bvec.store(output_ptr + d);
      }
      // scalar tail for the remaining dim % kVecSize elements
#pragma GCC unroll 4
      for (; d < dim; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        float y_val = static_cast<float>(input_other_ptr[d]);
        output_ptr[d] = f(x_val) * y_val;
      }
    }
  });
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// input : {num_tokens, 2 * d}
|
||||
// output : {num_tokens, d}
|
||||
at::Tensor silu_and_mul_cpu(at::Tensor& input) {
|
||||
RECORD_FUNCTION("sgl-kernel::silu_and_mul_cpu", std::vector<c10::IValue>({input}));
|
||||
auto sizes = input.sizes().vec();
|
||||
int64_t last_dim = input.ndimension() - 1;
|
||||
int64_t d = sizes[last_dim] / 2;
|
||||
sizes[last_dim] = d;
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
at::Tensor out = at::empty(sizes, input.options());
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
act_and_mul_kernel_impl(
|
||||
out.data_ptr<scalar_t>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
num_tokens,
|
||||
d,
|
||||
[](float x) { return x / (1.f + std::exp(-x)); },
|
||||
[](Vec x) { return x / (Vec(1.f) + x.neg().exp()); });
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
// Tanh-approximation GELU of the first half of the last dim, times the
// second half. input : {num_tokens, 2 * d} -> output : {num_tokens, d}
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input) {
  RECORD_FUNCTION("sgl-kernel::gelu_tanh_and_mul_cpu", std::vector<c10::IValue>({input}));

  auto out_sizes = input.sizes().vec();
  const int64_t last = input.ndimension() - 1;
  const int64_t half = out_sizes[last] / 2;
  out_sizes[last] = half;
  const int64_t rows = input.numel() / input.size(-1);
  at::Tensor out = at::empty(out_sizes, input.options());

  // Constant of the tanh approximation: sqrt(2 / pi).
  const float sqrt_2_div_pi = std::sqrt(2.f / M_PI);

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_tanh_and_mul", [&] {
    using Vec = at::vec::Vectorized<float>;
    act_and_mul_kernel_impl(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        rows,
        half,
        [sqrt_2_div_pi](float x) {
          const float cubed = x * x * x;
          const float arg = sqrt_2_div_pi * (x + 0.044715f * cubed);
          return 0.5f * x * (1.f + std::tanh(arg));
        },
        [sqrt_2_div_pi](Vec x) {
          const Vec cubed = x * x * x;
          const Vec arg = Vec(sqrt_2_div_pi) * (x + Vec(0.044715f) * cubed);
          return Vec(0.5f) * x * (Vec(1.f) + arg.tanh());
        });
  });

  return out;
}
|
||||
|
||||
// Exact (erf-based) GELU of the first half of the last dim, times the
// second half. input : {num_tokens, 2 * d} -> output : {num_tokens, d}
at::Tensor gelu_and_mul_cpu(const at::Tensor& input) {
  RECORD_FUNCTION("sgl-kernel::gelu_and_mul_cpu", std::vector<c10::IValue>({input}));

  auto out_sizes = input.sizes().vec();
  const int64_t last = input.ndimension() - 1;
  const int64_t half = out_sizes[last] / 2;
  out_sizes[last] = half;
  const int64_t rows = input.numel() / input.size(-1);
  at::Tensor out = at::empty(out_sizes, input.options());

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gelu_and_mul", [&] {
    using Vec = at::vec::Vectorized<float>;
    // 1/sqrt(2), the scaling inside erf for the exact GELU.
    const float inv_sqrt2 = 1.0f / std::sqrt(2.0f);
    act_and_mul_kernel_impl(
        out.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        rows,
        half,
        [inv_sqrt2](float x) { return 0.5f * x * (1.f + std::erf(x * inv_sqrt2)); },
        [inv_sqrt2](Vec x) { return Vec(0.5f) * x * (Vec(1.f) + (x * Vec(inv_sqrt2)).erf()); });
  });

  return out;
}
|
||||
123
sgl-kernel/csrc/cpu/bmm.cpp
Normal file
123
sgl-kernel/csrc/cpu/bmm.cpp
Normal file
@@ -0,0 +1,123 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// Batched GEMM over a pre-packed weight.
//
//   out  : [B, M, N]  (strides out_strideB / out_strideM)
//   mat1 : [B, M, K]  (strides mat1_strideB / mat1_strideM)
//   mat2 : [B, N, K], contiguous (packed weight)
//   scale: accepted for interface symmetry; unused in this body
//
// Work is tiled into BLOCK_M x BLOCK_N output blocks and parallelized over
// the flattened [B, MB, NB] index space; each chunk walks its indices with
// data_index_init / data_index_step instead of re-deriving them from i.
template <typename scalar_t>
void bmm_kernel_impl(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ mat1,
    const scalar_t* __restrict__ mat2,
    int64_t B,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t mat1_strideB,
    int64_t mat1_strideM,
    int64_t out_strideB,
    int64_t out_strideM,
    float scale = 0.f) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);

  // mat2 contiguous in [B, N, K]
  int64_t mat2_strideB = N * K;
  int64_t mat2_strideN = K;

  const bool use_brgemm = can_use_brgemm<scalar_t>(M);

  // parallel on [B, MB, NB]
  at::parallel_for(0, B * MB * NB, 0, [&](int64_t begin, int64_t end) {
    int64_t bs{0}, mb{0}, nb{0};
    data_index_init(begin, bs, B, mb, MB, nb, NB);

    // for brgemm, use float32 for accumulate
    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];

    for (int i = begin; i < end; ++i) {
      // i only counts iterations; the tile indices live in bs/mb/nb
      UNUSED(i);
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(M - mb_start, BLOCK_M);
      int nb_start = nb * BLOCK_N;
      int nb_size = std::min(N - nb_start, BLOCK_N);

      tinygemm_kernel<scalar_t>(
          /* A */ mat1 + bs * mat1_strideB + mb_start * mat1_strideM,
          /* B */ mat2 + bs * mat2_strideB + nb_start * mat2_strideN /* nb * BLOCK_N * K */,
          /* C */ out + bs * out_strideB + mb_start * out_strideM + nb_start,
          /* Ctmp*/ Ctmp,
          /* M */ mb_size,
          /* N */ nb_size,
          /* K */ K,
          /* lda */ mat1_strideM,
          /* ldb */ nb_size,
          /* ldc */ out_strideM,
          /* brg */ use_brgemm);

      // move to the next index
      data_index_step(bs, B, mb, MB, nb, NB);
    }

    if (use_brgemm) {
      // release per-thread brgemm state after the chunk is done
      at::native::cpublas::brgemm_release();
    }
  });
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// mat1 : [B, M, K]
|
||||
// mat2 : [B, N, K] or [B, OC, IC]
|
||||
// out : [B, M, N]
|
||||
// scale: [] 0-dim tensor for per tensor quant
|
||||
//
|
||||
void bmm_cpu(
|
||||
at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale) {
|
||||
RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
// input and out could be non-contiguous
|
||||
// weight needs to be contiguous in [OC, IC] order
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(out);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_DIM(3, out);
|
||||
CHECK_DIM(3, mat1);
|
||||
CHECK_DIM(3, mat2);
|
||||
|
||||
int64_t B = mat1.size(0);
|
||||
int64_t M = mat1.size(1);
|
||||
int64_t N = mat2.size(1);
|
||||
int64_t K = mat1.size(2);
|
||||
|
||||
TORCH_CHECK(!scale.has_value(), "bmm: do not support fp8 weight for now.")
|
||||
TORCH_CHECK(N % 32 == 0, "tinygemm requires N to be 32x.");
|
||||
|
||||
int64_t mat1_strideB = mat1.stride(0);
|
||||
int64_t mat1_strideM = mat1.stride(1);
|
||||
int64_t out_strideB = out.stride(0);
|
||||
int64_t out_strideM = out.stride(1);
|
||||
|
||||
// check shapes
|
||||
TORCH_CHECK(mat2.size(0) == B && mat2.size(2) == K, "bmm: mat2 shape mismatch!");
|
||||
TORCH_CHECK(out.size(0) == B && out.size(1) == M, "bmm: out shape mismatch!");
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "bmm_kernel_impl", [&] {
|
||||
bmm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<scalar_t>(),
|
||||
B,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideB,
|
||||
mat1_strideM,
|
||||
out_strideB,
|
||||
out_strideM);
|
||||
});
|
||||
}
|
||||
324
sgl-kernel/csrc/cpu/common.h
Normal file
324
sgl-kernel/csrc/cpu/common.h
Normal file
@@ -0,0 +1,324 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
// dispatch bool
// AT_DISPATCH_BOOL(flag, NAME, lambda): invokes `lambda` with `NAME` bound to
// a constexpr bool matching the runtime `flag`, so the callee can branch at
// compile time.
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
  [&] {                                          \
    if (BOOL_V) {                                \
      constexpr bool BOOL_NAME = true;           \
      return __VA_ARGS__();                      \
    } else {                                     \
      constexpr bool BOOL_NAME = false;          \
      return __VA_ARGS__();                      \
    }                                            \
  }()

// dispatch: bfloat16, float16, int8_t, fp8_e4m3
// Binds `packed_t` to the C++ type matching the runtime ScalarType, then
// invokes the lambda; anything else is a hard error.
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...)                 \
  [&] {                                                      \
    switch (TYPE) {                                          \
      case at::ScalarType::BFloat16: {                       \
        using packed_t = at::BFloat16;                       \
        return __VA_ARGS__();                                \
      }                                                      \
      case at::ScalarType::Half: {                           \
        using packed_t = at::Half;                           \
        return __VA_ARGS__();                                \
      }                                                      \
      case at::ScalarType::Char: {                           \
        using packed_t = int8_t;                             \
        return __VA_ARGS__();                                \
      }                                                      \
      case at::ScalarType::Float8_e4m3fn: {                  \
        using packed_t = at::Float8_e4m3fn;                  \
        return __VA_ARGS__();                                \
      }                                                      \
      default:                                               \
        TORCH_CHECK(false, "Unsupported floating data type.\n"); \
    }                                                        \
  }()

// dispatch with mixed dtypes (TYPE1, TYPE2):
//   TYPE1: the primary dtype (input, output, weight);
//   TYPE2: the secondary dtype (bias, etc.).
// When TYPE2 is float, param_t is float; otherwise TYPE1 and TYPE2 must match
// and param_t equals scalar_t. Only bf16/fp16 are accepted for TYPE1.
#define CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(TYPE1, TYPE2, ...) \
  [&] {                                                            \
    if (TYPE2 == at::kFloat) {                                     \
      switch (TYPE1) {                                             \
        case at::ScalarType::BFloat16: {                           \
          using scalar_t = at::BFloat16;                           \
          using param_t = float;                                   \
          return __VA_ARGS__();                                    \
        }                                                          \
        case at::ScalarType::Half: {                               \
          using scalar_t = at::Half;                               \
          using param_t = float;                                   \
          return __VA_ARGS__();                                    \
        }                                                          \
        default:                                                   \
          TORCH_CHECK(false, "Unsupported floating data type.\n"); \
      }                                                            \
    } else {                                                       \
      TORCH_CHECK(TYPE1 == TYPE2);                                 \
      switch (TYPE1) {                                             \
        case at::ScalarType::BFloat16: {                           \
          using scalar_t = at::BFloat16;                           \
          using param_t = at::BFloat16;                            \
          return __VA_ARGS__();                                    \
        }                                                          \
        case at::ScalarType::Half: {                               \
          using scalar_t = at::Half;                               \
          using param_t = at::Half;                                \
          return __VA_ARGS__();                                    \
        }                                                          \
        default:                                                   \
          TORCH_CHECK(false, "Unsupported floating data type.\n"); \
      }                                                            \
    }                                                              \
  }()
|
||||
|
||||
// Silence unused-variable/parameter warnings.
#define UNUSED(x) (void)(x)

// Tensor precondition checks; all raise via TORCH_CHECK with the tensor's
// name stringified into the message.
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")

#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
// Only the innermost stride needs to be 1 (rows may be strided).
// Fixed: the message was missing the space after the tensor name
// ("xmust be contiguous...").
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at last dimension")

#define CHECK_INPUT(x) \
  CHECK_CPU(x);        \
  CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
  CHECK_CPU(x);                            \
  CHECK_LAST_DIM_CONTIGUOUS(x)

#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")

#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
||||
|
||||
// [NB] Parallel Routines
//
// * at::parallel_for - applies for most of generic use cases, this will be compiled
//     against openmp in default torch release.
//
// * parallel_for - same function as above, can choose payload partition scheme in
//     balance211.
//
// * parallel_2d - parallel for 2 dimensions, used in GEMM, etc.
//     this one will do payload balance across 2 dimensions.
//

// grain size for each thread (minimum payload per partition)
constexpr int GRAIN_SIZE = 1024;
|
||||
|
||||
// Integer ceiling division for integral T: smallest q with q * b >= a
// (for positive operands).
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T a, T b) {
  return (a + b - 1) / b;
}
|
||||
|
||||
// you can only use at::get_thread_num() with at::parallel_for()
// as it is lazy initialized, otherwise it will always return 0.
// This helper asks OpenMP directly, so it also works inside the hand-rolled
// parallel_for / parallel_2d regions below.
inline int get_thread_num() {
#if defined(_OPENMP)
  return omp_get_thread_num();
#else
  // no OpenMP: execution is single-threaded
  return 0;
#endif
}
|
||||
|
||||
// balance payload across each thread
// Splits `n` items over `nth` threads; thread `ith` receives the half-open
// range [n_start, n_end). Uses the PyTorch ATen partition pattern: equal
// chunks of ceil(n / nth), with the final thread clipped to n.
template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
  const T chunk = div_up(n, nth);
  n_start = ith * chunk;
  n_end = std::min(n_start + chunk, n);
}
|
||||
|
||||
// Thread-parallel 1D loop over [0, n): each OpenMP thread receives its
// [tbegin, tend) slice from balance211 and invokes f exactly once on it.
// Without OpenMP this degenerates to a single serial call f(0, n).
template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
#if defined(_OPENMP)
#pragma omp parallel
  {
    int nth = omp_get_num_threads();
    int ith = omp_get_thread_num();
    int tbegin, tend;
    balance211(n, nth, ith, tbegin, tend);
    f(tbegin, tend);
  }
#else
  f(0, n);
#endif
}
|
||||
|
||||
// for 1d parallel, use `actual_nth`
// for 2d parallel, use even nths, e.g. 43->42
int inline adjust_num_threads(int m) {
  int actual_nth = at::get_num_threads();
  // m == 1 is effectively a 1D problem: keep every available thread
  if (m == 1) {
    return actual_nth;
  }
  // round down to an even thread count (min 1) so a 2D grid can factor it
  return std::max(1, (actual_nth >> 1) * 2);
}
|
||||
|
||||
// Thread-parallel 2D loop over [0, m) x [0, n).
// Threads are factored into an nth_m x nth_n grid (nth_m * nth_n == nth);
// each thread gets one rectangular tile and calls
// f(begin_m, end_m, begin_n, end_n). Without OpenMP: one tile covering all.
template <typename func_t>
inline void parallel_2d(int m, int n, const func_t& f) {
  // make sure we have even num_threads
  int nth = adjust_num_threads(m);

  // [NOTE] thread blocking:
  //
  // 1) prefer square block per thread
  // 2) use even number of CPU cores
  // 3) use all `num_threads` cores
  //
  // we have:
  //   TM * TN = T
  //   BM / TM = BN / TN
  // then:
  //   TM = ((BM / BN) * T) ^ 0.5
  //
  float r = float(m) / n;
  int nth_m = std::ceil(std::sqrt(r * nth));
  int nth_n = 1;
  // search downward from the "square" estimate for an exact factorization
  for (; nth_m > 0; --nth_m) {
    nth_n = nth / nth_m;
    if (nth_m * nth_n == nth) {
      break;
    }
  }

#if defined(_OPENMP)
#pragma omp parallel num_threads(nth)
  {
    // map the linear thread id onto the (ith_m, ith_n) grid position
    int ith = omp_get_thread_num();
    int ith_m = ith / nth_n;
    int ith_n = ith % nth_n;

    int thread_block_m = div_up(m, nth_m);
    int thread_block_n = div_up(n, nth_n);

    int begin_m = ith_m * thread_block_m;
    int end_m = std::min(m, begin_m + thread_block_m);
    int begin_n = ith_n * thread_block_n;
    int end_n = std::min(n, begin_n + thread_block_n);

    f(begin_m, end_m, begin_n, end_n);
  }
#else
  f(0, m, 0, n);
#endif
}
|
||||
|
||||
// limit max cache blocks
// when we need to do pre-unpack for weights, e.g. fp8
#define MAX_CACHE_BLOCK_SIZE 4

// How many chunks of `chunk_size` elements of T fit in half of a 2 MB L2
// cache. Never returns fewer than one block.
template <typename T>
inline int get_cache_blocks(int chunk_size) {
  // L2 2MB and ratio of 50%
  constexpr int kL2HalfBytes = (2048 * 1024) >> 1;
  const int fit = kL2HalfBytes / (chunk_size * int(sizeof(T)));
  return fit < 1 ? 1 : fit;
}
|
||||
|
||||
// fp8 specialization: sized as if the elements were bf16 (see comment below),
// then clamped to MAX_CACHE_BLOCK_SIZE because fp8 weights are pre-unpacked
// per block.
template <>
inline int get_cache_blocks<at::Float8_e4m3fn>(int chunk_size) {
  // fp8 uses bf16 as accumulate type
  int cache_block_size = get_cache_blocks<at::BFloat16>(chunk_size);
  return std::min(MAX_CACHE_BLOCK_SIZE, cache_block_size);
}
|
||||
|
||||
// 2d sequential loop in range : [mb0, mb1), [nb0, nb1)
|
||||
template <typename T, typename func_t>
|
||||
inline void loop_2d(int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1, int64_t chunk_size, const func_t& f) {
|
||||
// get number of blocks for L2 in most inner loop
|
||||
int64_t cache_blocks_nb = get_cache_blocks<T>(chunk_size);
|
||||
|
||||
// loop order: [NB / cache_blocks_nb, MB, cache_blocks_nb]
|
||||
// TODO: implement reverse order of [MB / cache_blocks_mb, NB, cache_blocks_mb]
|
||||
for (int64_t nbb = nb0; nbb < nb1; nbb += cache_blocks_nb) {
|
||||
for (int64_t mb = mb0; mb < mb1; ++mb) {
|
||||
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, nb1); ++nb) {
|
||||
f(mb, nb, nb - nbb);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// data indexing for dimension collapse
// Base case: no (index, extent) pairs remain; the leftover offset is returned.
template <typename T>
inline T data_index_init(T offset) {
  return offset;
}

// Decompose a flat `offset` into per-dimension indices. Pairs are listed
// outermost-first in the call, but the recursion peels them innermost-first:
// each level writes remainder % X into x and carries remainder / X outward.
template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
  const T remaining = data_index_init(offset, std::forward<Args>(args)...);
  x = remaining % X;
  return remaining / X;
}
|
||||
|
||||
// Base case of the odometer step: the innermost position always advances.
inline bool data_index_step() {
  return true;
}

// Advance a multi-dimensional index one step, odometer style: the innermost
// pair increments first, and each dimension wraps to 0 and propagates a carry
// (return value true) when it reaches its extent X.
template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
  if (!data_index_step(std::forward<Args>(args)...)) {
    return false;  // inner dimensions did not wrap; nothing to do here
  }
  x = ((x + 1) == X) ? 0 : (x + 1);
  return x == 0;
}
|
||||
|
||||
// forced unroll for perf critical path

#if __has_attribute(always_inline)
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
#define ALWAYS_INLINE inline
#endif

// Compile-time unrolled loop: Unroll<n>{}(f, args...) expands to
// f(0, args...), f(1, args...), ..., f(n - 1, args...), where the step index
// is passed as std::integral_constant<int, i> so f may use it as a constexpr.
template <int n>
struct Unroll {
  template <typename Func, typename... Args>
  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
    Unroll<n - 1>{}(f, args...);
    f(std::integral_constant<int, n - 1>{}, args...);
  }
};

// Recursion terminator: a single call with index 0.
template <>
struct Unroll<1> {
  template <typename Func, typename... Args>
  ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
    f(std::integral_constant<int, 0>{}, args...);
  }
};
|
||||
|
||||
} // anonymous namespace
|
||||
1575
sgl-kernel/csrc/cpu/decode.cpp
Normal file
1575
sgl-kernel/csrc/cpu/decode.cpp
Normal file
File diff suppressed because it is too large
Load Diff
723
sgl-kernel/csrc/cpu/extend.cpp
Normal file
723
sgl-kernel/csrc/cpu/extend.cpp
Normal file
@@ -0,0 +1,723 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// [NOTE]: extend attention for CPU
|
||||
// 1. tune BLOCK_M and BLOCK_N
|
||||
// 2. can handle non-contiguous k_exttend and v_extend
|
||||
// 3. computes attention for prefix and extend separately
|
||||
// 4. TODO: vectorize `pack_vnni` and `pack_vnni2`
|
||||
//
|
||||
|
||||
// Identity-or-gather lookup: with no index array, position i maps to itself;
// otherwise it maps through ind[i].
template <typename index_t>
inline index_t get_index(index_t* ind, int i) {
  if (ind == nullptr) {
    return static_cast<index_t>(i);
  }
  return ind[i];
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
// key: from [N, 32] to [32/2, N, 2]
|
||||
template <typename scalar_t, typename index_t>
|
||||
inline void pack_vnni_Nx32(
|
||||
scalar_t* __restrict__ dst,
|
||||
const scalar_t* __restrict__ src,
|
||||
const index_t* __restrict__ ind,
|
||||
int N,
|
||||
int ld_src,
|
||||
int ld_dst) {
|
||||
__m512i vinputs[16];
|
||||
|
||||
int n = 0;
|
||||
for (; n < N; ++n) {
|
||||
index_t index = get_index(ind, n);
|
||||
vinputs[n] = _mm512_loadu_si512(src + index * ld_src);
|
||||
}
|
||||
// padding with zero to avoid uninitialized vectors
|
||||
for (; n < 16; ++n) {
|
||||
vinputs[n] = _mm512_set1_epi32(0);
|
||||
}
|
||||
|
||||
// pack key
|
||||
transpose_16x16_32bit(vinputs);
|
||||
|
||||
const __mmask16 vmask = (1 << N) - 1;
|
||||
for (int k = 0; k < 16; ++k) {
|
||||
_mm512_mask_storeu_epi32(dst + k * ld_dst * 2, vmask, vinputs[k]);
|
||||
}
|
||||
}
|
||||
|
||||
// value: from [K, 32] to [K/2, 32, 2]
|
||||
template <typename scalar_t, typename index_t>
|
||||
inline void pack_vnni_Kx32(
|
||||
scalar_t* __restrict__ dst,
|
||||
const scalar_t* __restrict__ src,
|
||||
const index_t* __restrict__ ind,
|
||||
int K,
|
||||
int ld_src,
|
||||
int ld_dst) {
|
||||
__m512i vinputs[2];
|
||||
|
||||
int k = 0;
|
||||
for (; k < K; ++k) {
|
||||
index_t index = get_index(ind, k);
|
||||
vinputs[k] = _mm512_loadu_si512(src + index * ld_src);
|
||||
}
|
||||
// padding with zero to avoid uninitialized vectors
|
||||
for (; k < 2; ++k) {
|
||||
vinputs[k] = _mm512_set1_epi32(0);
|
||||
}
|
||||
|
||||
// pack value
|
||||
__m512i d0, d1;
|
||||
std::tie(d0, d1) = transpose_2x32_16bit(vinputs[0], vinputs[1]);
|
||||
_mm512_storeu_si512(dst + 0 * ld_dst * 2, d0);
|
||||
_mm512_storeu_si512(dst + 0 * ld_dst * 2 + 32, d1);
|
||||
}
|
||||
#endif
|
||||
|
||||
// convert to vnni format
// from [N, K/2, 2] to [K/2, N, 2] for bfloat16 and float16
// `ind`, when non-null, selects which source rows to gather (see get_index).
// AVX512 path tiles the work in 16-row x 32-column blocks; it divides K by 32
// with no remainder handling, so K must be a multiple of 32 there.
template <typename scalar_t, typename index_t>
void pack_vnni(
    scalar_t* __restrict__ dst,
    const scalar_t* __restrict__ src,
    const index_t* __restrict__ ind,
    int N,
    int K,
    int ld_src,
    int ld_dst) {
#if defined(CPU_CAPABILITY_AVX512)
  const int NB = div_up(N, 16);
  const int KB = K / 32;  // no remainder
  const bool is_indexed = ind != nullptr;

  for (int nb = 0; nb < NB; ++nb) {
    for (int kb = 0; kb < KB; ++kb) {
      // handle 16x512bits each block
      int nb_size = std::min(N - nb * 16, 16);
      pack_vnni_Nx32<scalar_t, index_t>(
          /* dst */ dst + ((kb * 32) >> 1) * ld_dst * 2 + nb * 16 * 2,
          /* src */ src + kb * 32 + (is_indexed ? 0 : nb * 16 * ld_src),
          /* ind */ is_indexed ? ind + nb * 16 : nullptr,
          /* N */ nb_size,
          /* ld_src */ ld_src,
          /* ld_dst */ ld_dst);
    }
  }
#else
  // scalar fallback: copy element pairs into [K/2, N, 2] order directly
  for (int n = 0; n < N; ++n) {
    index_t index = get_index(ind, n);
    for (int k = 0; k < K / 2; ++k) {
      for (int d = 0; d < 2; ++d) {
        dst[k * ld_dst * 2 + n * 2 + d] = src[index * ld_src + k * 2 + d];
      }
    }
  }
#endif
}
|
||||
|
||||
// convert to vnni format
// from [K/2, 2, N] to [K/2, N, 2] for bfloat16 and float16
// `ind`, when non-null, selects which source rows to gather (see get_index).
// AVX512 path uses 2-row x 32-column tiles; it divides N by 32 with no
// remainder handling, so N must be a multiple of 32 there. The scalar path
// zero-pads the second half of the final pair when K is odd.
template <typename scalar_t, typename index_t>
void pack_vnni2(
    scalar_t* __restrict__ dst,
    const scalar_t* __restrict__ src,
    const index_t* __restrict__ ind,
    int K,
    int N,
    int ld_src,
    int ld_dst) {
#if defined(CPU_CAPABILITY_AVX512)
  const int KB = div_up(K, 2);
  const int NB = N / 32;  // no remainder
  const bool is_indexed = ind != nullptr;

  for (int kb = 0; kb < KB; ++kb) {
    for (int nb = 0; nb < NB; ++nb) {
      // handle 2x512bits each block
      int kb_size = std::min(K - kb * 2, 2);
      pack_vnni_Kx32<scalar_t, index_t>(
          /* dst */ dst + ((kb * 2) >> 1) * ld_dst * 2 + nb * 32 * 2,
          /* src */ src + (is_indexed ? 0 : kb * 2 * ld_src) + nb * 32,
          /* ind */ is_indexed ? ind + kb * 2 : nullptr,
          /* K */ kb_size,
          /* ld_src */ ld_src,
          /* ld_dst */ ld_dst);
    }
  }
#else
  int k = 0;
  // interleave full pairs of rows
  for (; k < (K >> 1) * 2; k += 2) {
    index_t index0 = get_index(ind, k + 0);
    index_t index1 = get_index(ind, k + 1);
    for (int n = 0; n < N; ++n) {
      dst[(k >> 1) * ld_dst * 2 + n * 2 + 0] = src[index0 * ld_src + n];
      dst[(k >> 1) * ld_dst * 2 + n * 2 + 1] = src[index1 * ld_src + n];
    }
  }
  // odd K: the last row pairs with zeros
  if (K % 2 != 0) {
    index_t index = get_index(ind, K - 1);
    for (int n = 0; n < N; ++n) {
      dst[(K >> 1) * ld_dst * 2 + n * 2 + 0] = src[index * ld_src + n];
      dst[(K >> 1) * ld_dst * 2 + n * 2 + 1] = 0;
    }
    k += 2;  // kept for symmetry; k is not read afterwards
  }
#endif
}
|
||||
|
||||
// Fill `size` elements of `out` with `val` converted to scalar_t, using
// full-width vector stores plus one partial store for the tail.
template <typename scalar_t>
inline void fill_stub(scalar_t* __restrict__ out, float val, int size) {
  using Vec = at::vec::Vectorized<scalar_t>;
  constexpr int kVecSize = Vec::size();
  const Vec data_vec = Vec(static_cast<scalar_t>(val));
  int d = 0;
#pragma GCC unroll 4
  for (; d <= size - kVecSize; d += kVecSize) {
    data_vec.store(out + d);
  }
  if (size - d > 0) {
    // partial store of the remaining (size - d) elements
    data_vec.store(out + d, size - d);
  }
}
|
||||
|
||||
// Convert one row of BLOCK_N floats to scalar_t and store it into `out`.
// BLOCK_N must be a multiple of 32: each even 16-float column fuses with its
// neighbor into a single scalar_t vector store, so odd columns are skipped.
template <typename scalar_t, int BLOCK_N>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input) {
  static_assert(BLOCK_N % 32 == 0);
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int COLS = BLOCK_N / 16;
  auto store = [&](auto i) {
    constexpr int col = i % COLS;
    // for COLS = 2, 4 use 512bit store
    if constexpr (col % 2 == 0) {
      fVec a_fvec0 = fVec::loadu(input + col * 16);
      fVec a_fvec1 = fVec::loadu(input + col * 16 + 16);
      bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
      out_bvec.store(out + col * 16);
    }
  };
  // compile-time unrolled over all columns
  Unroll<COLS>{}(store);
}
|
||||
|
||||
// out[d] = acc[d] * s for d in [0, size), converting float -> scalar_t.
// Main loop converts two float vectors per scalar_t store; a scalar loop
// handles the tail.
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  const fVec s_fvec = fVec(s);
  int d = 0;
#pragma GCC unroll 4
  for (; d <= size - kVecSize; d += kVecSize) {
    fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
    fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
    bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
    out_bvec.store(out + d);
  }
  // scalar tail
  for (; d < size; ++d) {
    out[d] = static_cast<scalar_t>(acc[d] * s);
  }
}
|
||||
|
||||
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
|
||||
void extend_attention_kernel_impl(
|
||||
scalar_t* __restrict__ o_extend,
|
||||
const scalar_t* __restrict__ q_extend,
|
||||
const scalar_t* __restrict__ k_extend,
|
||||
const scalar_t* __restrict__ v_extend,
|
||||
const scalar_t* __restrict__ k_buffer,
|
||||
const scalar_t* __restrict__ v_buffer,
|
||||
const index_t* __restrict__ req_to_token,
|
||||
const int64_t* __restrict__ req_pool_indices,
|
||||
const int64_t* __restrict__ seq_lens,
|
||||
const index_t* __restrict__ extend_seq_lens,
|
||||
const index_t* __restrict__ extend_start_loc,
|
||||
const void* __restrict__ buffer,
|
||||
int batches,
|
||||
int num_heads,
|
||||
int num_heads_kv,
|
||||
int head_size,
|
||||
int head_size_v,
|
||||
int q_strideM,
|
||||
int q_strideH,
|
||||
int ke_strideN,
|
||||
int ke_strideH,
|
||||
int ve_strideN,
|
||||
int ve_strideH,
|
||||
int k_strideN,
|
||||
int k_strideH,
|
||||
int v_strideN,
|
||||
int v_strideH,
|
||||
float scaling,
|
||||
float logit_cap,
|
||||
int max_num_reqs,
|
||||
int max_context_len,
|
||||
int max_total_num_tokens,
|
||||
int max_len_extend,
|
||||
int buffer_size_per_thread,
|
||||
bool is_prefix_skipped) {
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
|
||||
// strides
|
||||
const int o_strideM = num_heads * head_size_v;
|
||||
const int o_strideH = head_size_v;
|
||||
|
||||
// we use same buffer for packed key and value
|
||||
const int ldb_tmp = std::max(head_size, head_size_v);
|
||||
|
||||
const bool has_logit_cap = logit_cap > 0;
|
||||
float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f;
|
||||
|
||||
const int num_groups = num_heads / num_heads_kv;
|
||||
TORCH_CHECK(num_groups * num_heads_kv == num_heads);
|
||||
|
||||
// number of blocks along M
|
||||
int MB = div_up(max_len_extend, BLOCK_M);
|
||||
|
||||
// parallel on [batches, num_heads, BM]
|
||||
at::parallel_for(0, batches * num_heads * MB, 0, [&](int begin, int end) {
|
||||
int bs{0}, head_id{0}, mb{0};
|
||||
data_index_init(begin, bs, batches, head_id, num_heads, mb, MB);
|
||||
|
||||
int tid = at::get_thread_num();
|
||||
// s_i and s_delta: [BLOCK_M, BLOCK_N]
|
||||
float* __restrict__ s_i = reinterpret_cast<float*>((char*)(buffer) + tid * buffer_size_per_thread);
|
||||
float* __restrict__ s_delta = s_i;
|
||||
|
||||
// v_prime: [BLOCK_M, head_size_v]
|
||||
float* __restrict__ v_prime = s_i + BLOCK_M * BLOCK_N;
|
||||
|
||||
// s_delta2: [BLOCK_M, BLOCK_N]; copy of s_delta in scalar_t
|
||||
scalar_t* __restrict__ s_delta2 = reinterpret_cast<scalar_t*>(v_prime + BLOCK_N * head_size_v);
|
||||
|
||||
// Btmp: [BLOCK_N, max(head_size, head_size_v)]
|
||||
scalar_t* __restrict__ Btmp = s_delta2 + BLOCK_M * BLOCK_N;
|
||||
|
||||
// init Btmp just once for each thread to prevent NaN
|
||||
fill_stub(Btmp, 0.f, BLOCK_N * ldb_tmp);
|
||||
|
||||
alignas(64) float s_prime[BLOCK_M];
|
||||
alignas(64) float m_prime[BLOCK_M];
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
// seq_len = prefix + extend
|
||||
int head_kv_id = head_id / num_groups;
|
||||
int seq_len = seq_lens[bs];
|
||||
int seq_len_extend = extend_seq_lens[bs];
|
||||
int seq_len_prefix = seq_len - seq_len_extend;
|
||||
int seq_extend_start_loc = extend_start_loc[bs];
|
||||
|
||||
int req_pool_id = req_pool_indices[bs];
|
||||
TORCH_CHECK(seq_len_prefix >= 0, "prefix len < 0!");
|
||||
TORCH_CHECK(seq_len <= max_context_len, "seq_len out of scope!");
|
||||
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!");
|
||||
|
||||
if (is_prefix_skipped) {
|
||||
TORCH_CHECK(seq_len_prefix == 0, "extend attention: expect seq_len_prefix to be 0, got ", seq_len_prefix);
|
||||
}
|
||||
|
||||
// offset and size in MB
|
||||
int m = mb * BLOCK_N;
|
||||
int m_size = std::min(BLOCK_M, seq_len_extend - m);
|
||||
|
||||
if (m_size <= 0) {
|
||||
data_index_step(bs, batches, head_id, num_heads, mb, MB);
|
||||
continue;
|
||||
}
|
||||
|
||||
// get query
|
||||
const scalar_t* __restrict__ q_ptr = q_extend + (seq_extend_start_loc + m) * q_strideM + head_id * q_strideH;
|
||||
|
||||
// init v', s' and m'
|
||||
fill_stub(v_prime, 0.f, m_size * head_size_v);
|
||||
fill_stub(s_prime, 0.f, m_size);
|
||||
fill_stub(m_prime, -std::numeric_limits<scalar_t>::infinity(), m_size);
|
||||
|
||||
// stage 1: compute scores with prefix
|
||||
for (int n = 0; n < seq_len_prefix; n += BLOCK_N) {
|
||||
int n_size = std::min(BLOCK_N, seq_len_prefix - n);
|
||||
|
||||
// `n_size` is K in 2nd gemm, pad to TILE_K;
|
||||
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
|
||||
|
||||
// get key and pack
|
||||
pack_vnni<scalar_t, index_t>(
|
||||
/* dst */ Btmp,
|
||||
/* src */ k_buffer + head_kv_id * k_strideH,
|
||||
/* ind */ req_to_token + req_pool_id * max_context_len + n,
|
||||
/* N */ n_size,
|
||||
/* K */ head_size,
|
||||
/* ld_src */ k_strideN,
|
||||
/* ld_dst */ BLOCK_N);
|
||||
|
||||
// calculate s_i <- Q @ K
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ head_size,
|
||||
/* lda */ q_strideM,
|
||||
/* ldb */ BLOCK_N,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ q_ptr,
|
||||
/* B */ Btmp,
|
||||
/* C */ s_i);
|
||||
|
||||
const Vec scale_vec = Vec(scaling);
|
||||
for (int row = 0; row < m_size; ++row) {
|
||||
// s_i <- s_i * scale
|
||||
at::vec::map<float>(
|
||||
[scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
|
||||
|
||||
// TODO: `tanh` from torch uses sleef u10, going to be slow
|
||||
if (has_logit_cap) {
|
||||
at::vec::map<float>(
|
||||
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
|
||||
s_i + row * BLOCK_N,
|
||||
s_i + row * BLOCK_N,
|
||||
n_size);
|
||||
}
|
||||
|
||||
// m_i: max value per row
|
||||
float m_i = at::vec::reduce_all<float>(
|
||||
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size);
|
||||
m_i = std::max(m_i, m_prime[row]);
|
||||
|
||||
// m_delta <- exp(m' - m_i)
|
||||
float m_delta = std::exp(m_prime[row] - m_i);
|
||||
|
||||
// s_delta <- exp(s_i - m_i)
|
||||
at::vec::map<float>(
|
||||
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
|
||||
|
||||
// s' <- s' * m_delta + sum(s_delta)
|
||||
s_prime[row] *= m_delta;
|
||||
s_prime[row] +=
|
||||
at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size);
|
||||
|
||||
m_prime[row] = m_i;
|
||||
|
||||
// v' <- v' * m_delta
|
||||
at::vec::map<float>(
|
||||
[m_delta](Vec x) { return x * Vec(m_delta); },
|
||||
v_prime + row * head_size_v,
|
||||
v_prime + row * head_size_v,
|
||||
head_size_v);
|
||||
|
||||
// pad s_delta with 0 first and then convert to scalar_t
|
||||
fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
|
||||
copy_stub<scalar_t, BLOCK_N>(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N);
|
||||
}
|
||||
|
||||
// get value and pack
|
||||
pack_vnni2<scalar_t, index_t>(
|
||||
/* dst */ Btmp,
|
||||
/* src */ v_buffer + head_kv_id * v_strideH,
|
||||
/* ind */ req_to_token + req_pool_id * max_context_len + n,
|
||||
/* K */ n_size,
|
||||
/* N */ head_size_v,
|
||||
/* ld_src */ v_strideN,
|
||||
/* ld_dst */ head_size_v);
|
||||
|
||||
// calculate V' <- s_delta @ V + V'
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ head_size_v,
|
||||
/* K */ padded_n_size, // n_size
|
||||
/* lda */ BLOCK_N,
|
||||
/* ldb */ head_size_v,
|
||||
/* ldc */ head_size_v,
|
||||
/* add_C */ true,
|
||||
/* A */ s_delta2,
|
||||
/* B */ Btmp,
|
||||
/* C */ v_prime);
|
||||
} // loop with seq_len_prefix
|
||||
|
||||
// stage 2: compute the triangle part
|
||||
int num_keys = std::min(seq_len_extend, m + BLOCK_M);
|
||||
for (int n = 0; n < num_keys; n += BLOCK_N) {
|
||||
int n_size = std::min(BLOCK_N, num_keys - n);
|
||||
|
||||
// `n_size` is K in 2nd gemm, pad to TILE_K;
|
||||
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
|
||||
|
||||
// get key and pack
|
||||
pack_vnni<scalar_t, index_t>(
|
||||
/* dst */ Btmp,
|
||||
/* src */ k_extend + (seq_extend_start_loc + n) * ke_strideN + head_kv_id * ke_strideH,
|
||||
/* ind */ nullptr,
|
||||
/* N */ n_size,
|
||||
/* K */ head_size,
|
||||
/* ld_src */ ke_strideN,
|
||||
/* ld_dst */ BLOCK_N);
|
||||
|
||||
// calculate s_i <- Q @ K
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ head_size,
|
||||
/* lda */ q_strideM,
|
||||
/* ldb */ BLOCK_N,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* add_C */ false,
|
||||
/* A */ q_ptr,
|
||||
/* B */ Btmp,
|
||||
/* C */ s_i);
|
||||
|
||||
// apply causal mask
|
||||
if (num_keys - n <= BLOCK_N) {
|
||||
for (int row = 0; row < m_size; ++row) {
|
||||
int last_col = m + row - n;
|
||||
// fill [last_col + 1, n_size) to -inf
|
||||
float* row_ptr = s_i + row * BLOCK_N;
|
||||
fill_stub(row_ptr + last_col + 1, -std::numeric_limits<float>::infinity(), n_size - last_col - 1);
|
||||
}
|
||||
}
|
||||
|
||||
const Vec scale_vec = Vec(scaling);
|
||||
for (int row = 0; row < m_size; ++row) {
|
||||
// s_i <- s_i * scale
|
||||
at::vec::map<float>(
|
||||
[scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
|
||||
|
||||
// TODO: `tanh` from torch uses sleef u10, going to be slow
|
||||
if (has_logit_cap) {
|
||||
at::vec::map<float>(
|
||||
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
|
||||
s_i + row * BLOCK_N,
|
||||
s_i + row * BLOCK_N,
|
||||
n_size);
|
||||
}
|
||||
|
||||
// m_i: max value per row
|
||||
float m_i = at::vec::reduce_all<float>(
|
||||
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size);
|
||||
m_i = std::max(m_i, m_prime[row]);
|
||||
|
||||
// m_delta <- exp(m' - m_i)
|
||||
float m_delta = std::exp(m_prime[row] - m_i);
|
||||
|
||||
// s_delta <- exp(s_i - m_i)
|
||||
at::vec::map<float>(
|
||||
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
|
||||
|
||||
// s' <- s' * m_delta + sum(s_delta)
|
||||
s_prime[row] *= m_delta;
|
||||
s_prime[row] +=
|
||||
at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size);
|
||||
|
||||
m_prime[row] = m_i;
|
||||
|
||||
// v' <- v' * m_delta
|
||||
at::vec::map<float>(
|
||||
[m_delta](Vec x) { return x * Vec(m_delta); },
|
||||
v_prime + row * head_size_v,
|
||||
v_prime + row * head_size_v,
|
||||
head_size_v);
|
||||
|
||||
// pad s_delta with 0 first and then convert to scalar_t
|
||||
fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
|
||||
copy_stub<scalar_t, BLOCK_N>(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N);
|
||||
}
|
||||
|
||||
// get value and pack
|
||||
pack_vnni2<scalar_t, index_t>(
|
||||
/* dst */ Btmp,
|
||||
/* src */ v_extend + (seq_extend_start_loc + n) * ve_strideN + head_kv_id * ve_strideH,
|
||||
/* ind */ nullptr,
|
||||
/* K */ n_size,
|
||||
/* N */ head_size_v,
|
||||
/* ld_src */ ve_strideN,
|
||||
/* ld_dst */ head_size_v);
|
||||
|
||||
// calculate V' <- s_delta @ V + V'
|
||||
at::native::cpublas::brgemm(
|
||||
/* M */ m_size,
|
||||
/* N */ head_size_v,
|
||||
/* K */ padded_n_size, // n_size
|
||||
/* lda */ BLOCK_N,
|
||||
/* ldb */ head_size_v,
|
||||
/* ldc */ head_size_v,
|
||||
/* add_C */ true,
|
||||
/* A */ s_delta2,
|
||||
/* B */ Btmp,
|
||||
/* C */ v_prime);
|
||||
} // loop with seq_len_extend
|
||||
|
||||
scalar_t* __restrict__ out_ptr = o_extend + (seq_extend_start_loc + m) * o_strideM + head_id * o_strideH;
|
||||
for (int row = 0; row < m_size; ++row) {
|
||||
float s = 1 / s_prime[row];
|
||||
copy_stub<scalar_t>(out_ptr + row * o_strideM, v_prime + row * head_size_v, s, head_size_v);
|
||||
}
|
||||
|
||||
// move to the next index
|
||||
data_index_step(bs, batches, head_id, num_heads, mb, MB);
|
||||
}
|
||||
at::native::cpublas::brgemm_release();
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// q_extend, k_extend, v_extend, o_extend: contiguous tensors
|
||||
// k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
|
||||
//
|
||||
// q_extend: [num_tokens, num_heads, head_size]
|
||||
// k_extend: [num_extend_tokens, num_heads, head_size]
|
||||
// v_extend: [num_extend_tokens, num_heads, head_size]
|
||||
// o_extend: [num_tokens, num_heads, head_size]
|
||||
// k_buffer: [max_total_num_tokens, num_heads, head_size]
|
||||
// v_buffer: [max_total_num_tokens, num_heads, head_size]
|
||||
// req_to_token: [max_num_reqs, max_context_len] int32 or int64
|
||||
// req_pool_indices: [num_seqs] int64
|
||||
// seq_lens: [num_seqs] int64
|
||||
// extend_seq_lens: [num_seqs]
|
||||
// extend_start_loc: [num_seqs]
|
||||
//
|
||||
// Entry point for CPU extend (prefill) attention. Validates inputs,
// derives sizes/strides, allocates the per-thread scratch buffer and
// dispatches to extend_attention_kernel_impl on (reduced float dtype,
// index dtype).
void extend_attention_cpu(
    at::Tensor& q_extend,
    at::Tensor& k_extend,
    at::Tensor& v_extend,
    at::Tensor& o_extend,
    at::Tensor& k_buffer,
    at::Tensor& v_buffer,
    at::Tensor& req_to_token,
    at::Tensor& req_pool_indices,
    at::Tensor& seq_lens,
    at::Tensor& extend_seq_lens,
    at::Tensor& extend_start_loc,
    int64_t max_len_extend,
    double sm_scale,
    double logit_cap) {
  RECORD_FUNCTION(
      "sgl-kernel::extend_attention_cpu",
      std::vector<c10::IValue>(
          {q_extend,
           k_extend,
           v_extend,
           o_extend,
           k_buffer,
           v_buffer,
           req_to_token,
           req_pool_indices,
           seq_lens,
           extend_seq_lens,
           extend_start_loc}));

  // the kernel only requires the innermost (head_size) dim to be contiguous;
  // the output must be fully contiguous.
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_extend);
  CHECK_INPUT(o_extend);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);

  int num_seqs = seq_lens.size(0);
  int max_num_reqs = req_to_token.size(0);
  int max_context_len = req_to_token.size(1);
  int max_total_num_tokens = k_buffer.size(0);

  int num_heads = q_extend.size(1);
  int num_heads_kv = k_extend.size(1);
  int head_size = q_extend.size(2);
  int head_size_v = v_extend.size(2);

  // strides for q_extend, k_extend and v_extend
  int q_strideM = q_extend.stride(0);
  int q_strideH = q_extend.stride(1);
  int ke_strideN = k_extend.stride(0);
  int ke_strideH = k_extend.stride(1);
  int ve_strideN = v_extend.stride(0);
  int ve_strideH = v_extend.stride(1);

  // strides for k_buffer and v_buffer
  int k_strideN = k_buffer.stride(0);
  int k_strideH = k_buffer.stride(1);
  int v_strideN = v_buffer.stride(0);
  int v_strideH = v_buffer.stride(1);

  // check sizes
  CHECK_EQ(req_pool_indices.size(0), num_seqs);
  CHECK_EQ(extend_seq_lens.size(0), num_seqs);
  CHECK_EQ(extend_start_loc.size(0), num_seqs);
  CHECK_EQ(v_extend.size(1), num_heads_kv);
  CHECK_EQ(k_buffer.size(1), v_buffer.size(1));

  // MLA will skip prefix part: when the kv cache head count differs from the
  // extend tensors' head count, the caller is expected to pass prefix-free
  // sequences (checked per-sequence in the kernel).
  const bool is_prefix_skipped = k_buffer.size(1) != num_heads_kv;

  // check index data types
  const auto index_dtype = req_to_token.scalar_type();
  TORCH_CHECK(
      index_dtype == at::kInt || index_dtype == at::kLong,
      "extend: expect req_to_token to be int32 or int64, got ",
      index_dtype);
  TORCH_CHECK(seq_lens.scalar_type() == at::kLong, "extend: expect req_lens to be int64, got ", seq_lens.scalar_type());
  TORCH_CHECK(
      req_pool_indices.scalar_type() == at::kLong,
      "extend: expect req_pool_indices to be int64, got ",
      req_pool_indices.scalar_type());
  TORCH_CHECK(
      extend_seq_lens.scalar_type() == index_dtype && extend_start_loc.scalar_type() == index_dtype,
      "extend: expect extend_seq_lens and extend_start_loc to have same dtype as req_to_token.");

  // D and DV need to be 32x as we transpose by 512-bit
  TORCH_CHECK(head_size % 32 == 0, "invalid head_size ", head_size);
  TORCH_CHECK(head_size_v % 32 == 0, "invalid head_size_v ", head_size_v);

  // block size for query seq length
  constexpr int BLOCK_M = 32;
  // block size for key/value seq length
  constexpr int BLOCK_N = 32;

  // per-thread scratch sizing; must stay in sync with the pointer offsets
  // computed inside extend_attention_kernel_impl.
  // NOTE: uint16_t matches the reduced-float (bf16/fp16) element size.
  const int size_per_thread =
      /* s_i     */ BLOCK_M * BLOCK_N * sizeof(float) +
      /* v_prime */ BLOCK_M * head_size_v * sizeof(float) +
      /* s_delta */ BLOCK_M * BLOCK_N * sizeof(uint16_t) +
      /* Btmp    */ BLOCK_N * std::max(head_size, head_size_v) * sizeof(uint16_t);

  int num_threads = at::get_num_threads();
  auto buffer = at::empty({num_threads, size_per_thread}, q_extend.options().dtype(at::kChar));

  AT_DISPATCH_REDUCED_FLOATING_TYPES(q_extend.scalar_type(), "extend_attention_kernel", [&] {
    AT_DISPATCH_INDEX_TYPES(index_dtype, "extend_attention_indices", [&] {
      extend_attention_kernel_impl<scalar_t, index_t, BLOCK_M, BLOCK_N>(
          o_extend.data_ptr<scalar_t>(),
          q_extend.data_ptr<scalar_t>(),
          k_extend.data_ptr<scalar_t>(),
          v_extend.data_ptr<scalar_t>(),
          k_buffer.data_ptr<scalar_t>(),
          v_buffer.data_ptr<scalar_t>(),
          req_to_token.data_ptr<index_t>(),
          req_pool_indices.data_ptr<int64_t>(),
          seq_lens.data_ptr<int64_t>(),
          extend_seq_lens.data_ptr<index_t>(),
          extend_start_loc.data_ptr<index_t>(),
          buffer.data_ptr(),
          num_seqs,
          num_heads,
          num_heads_kv,
          head_size,
          head_size_v,
          q_strideM,
          q_strideH,
          ke_strideN,
          ke_strideH,
          ve_strideN,
          ve_strideH,
          k_strideN,
          k_strideH,
          v_strideN,
          v_strideH,
          sm_scale,
          logit_cap,
          max_num_reqs,
          max_context_len,
          max_total_num_tokens,
          max_len_extend,
          size_per_thread,
          is_prefix_skipped);
    });
  });
}
|
||||
525
sgl-kernel/csrc/cpu/gemm.cpp
Normal file
525
sgl-kernel/csrc/cpu/gemm.cpp
Normal file
@@ -0,0 +1,525 @@
|
||||
#include "gemm.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// packed layout:
|
||||
// quants {N, K} int8_t
|
||||
// comp {N} int32_t
|
||||
// Compute per-column compensation for packed int8 weights and store it
// right after the BLOCK_N x K quantized block (layout documented above).
//
// Each output int32 is 0x80 * sum_k(b[n][k]), accumulated via dpbusd with
// a constant 0x80 "activation". NOTE(review): presumably consumed by a
// u8*s8 dot-product kernel that offsets signed activations by 0x80 —
// verify against the int8 gemm path.
template <int BLOCK_N>
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
#if defined(CPU_CAPABILITY_AVX512)
  constexpr int COLS = BLOCK_N / 16;  // 16 int32 lanes per 512-bit register
  __m512i vcomp[COLS];

  for (int col = 0; col < COLS; ++col) {
    vcomp[col] = _mm512_setzero_si512();
  }

  const int64_t offset = BLOCK_N * K;  // comp lives right after the quants
  const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
  // weights are in VNNI layout [K/4, BLOCK_N, 4]: each 512-bit load covers
  // 16 columns x 4 consecutive k values, which dpbusd reduces into int32.
  for (int k = 0; k < K / 4; ++k) {
    for (int col = 0; col < COLS; ++col) {
      __m512i vb = _mm512_loadu_si512((const __m512i*)(packed + k * BLOCK_N * 4 + col * 64));
      vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
    }
  }

  for (int col = 0; col < COLS; ++col) {
    _mm512_storeu_si512((__m512i*)(packed + offset + col * 64), vcomp[col]);
  }
#else
  TORCH_CHECK(false, "s8s8_compensation not implemented!");
#endif
}
|
||||
|
||||
// Repack a row-major [N, K] weight tile into VNNI layout [K/2, N, 2], the
// form expected by bf16/fp16 dot-product kernels. If K is odd, the trailing
// element of each row is not copied (callers guarantee K % TILE_K == 0).
template <typename packed_t>
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
  constexpr int VNNI_BLK = 2;
  const int K2 = K / VNNI_BLK;
  for (int k2 = 0; k2 < K2; ++k2) {
    // destination row k2 holds the k-pair (2*k2, 2*k2 + 1) of every column
    packed_t* dst_row = packed + k2 * N * VNNI_BLK;
    for (int n = 0; n < N; ++n) {
      const packed_t* src = weight + n * K + k2 * VNNI_BLK;
      dst_row[n * VNNI_BLK + 0] = src[0];
      dst_row[n * VNNI_BLK + 1] = src[1];
    }
  }
}
|
||||
|
||||
// int8 specialization: repack [N, K] -> [K/4, N, 4] (int8 VNNI groups 4
// k-values), then append the s8s8 compensation term (see s8s8_compensation).
// Requires a full BLOCK_N-column tile; callers guarantee K % 4 == 0 via the
// IC % TILE_K check in convert_weight_packed.
template <>
inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) {
  constexpr int BLOCK_N = block_size_n();
  TORCH_CHECK(N == BLOCK_N);

  const int VNNI_BLK = 4;  // 4 consecutive k values per column for int8
  for (int n = 0; n < N; ++n) {
    for (int k = 0; k < K / VNNI_BLK; ++k) {
      for (int d = 0; d < VNNI_BLK; ++d) {
        packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
      }
    }
  }
  s8s8_compensation<BLOCK_N>(packed, K);
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_add_stub(
|
||||
scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
|
||||
}
|
||||
}
|
||||
|
||||
// Primary template for the small GEMM microkernel C = A @ B (+ bias).
// Placeholder only: real work happens in the vectorized specializations
// below; instantiating this generic version is a hard error.
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
  static inline void apply(
      const scalar_t* __restrict__ A,
      const scalar_t* __restrict__ B,
      scalar_t* __restrict__ C,
      const float* __restrict__ bias,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
  }
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A,
|
||||
const at::BFloat16* __restrict__ B,
|
||||
at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_set1_ps(0.f);
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K2 = K >> 1;
|
||||
const int64_t lda2 = lda >> 1;
|
||||
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const float* b_ptr = reinterpret_cast<const float*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K2; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
// for COLS = 1, 3 use 256bit store
|
||||
if constexpr (COLS % 2 == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
} else {
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(C + row * ldc + col * 16), (__m256i)(_mm512_cvtneps_pbh(vc[i])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// Dispatch helper: invoke the tinygemm_kernel_nn specialization for a
// (MB_SIZE x NB_SIZE) sub-tile. `B + nb_start * 2` skips nb_start column
// pairs because B is in VNNI layout [K/2, N, 2].
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                 \
  tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply(  \
      A + mb_start * lda,                                           \
      B + nb_start * 2,                                             \
      C + mb_start * ldc + nb_start,                                \
      has_bias ? bias + nb_start : nullptr,                         \
      K,                                                            \
      lda,                                                          \
      ldb,                                                          \
      ldc);
|
||||
|
||||
// brgemm-backed path for larger tiles: run the oneDNN brgemm into a float
// scratch tile (Ctmp, row stride BLOCK_N), then narrow each row into C,
// adding bias on the way when has_bias is set.
template <typename scalar_t, bool has_bias>
struct brgemm {
  static inline void apply(
      const scalar_t* __restrict__ A,
      const scalar_t* __restrict__ B,
      scalar_t* __restrict__ C,
      float* __restrict__ Ctmp,
      const float* __restrict__ bias,
      int64_t M,
      int64_t N,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    constexpr int BLOCK_N = block_size_n();
    at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);

    // copy from Ctmp to C (fp32 accumulator -> scalar_t output)
    for (int64_t m = 0; m < M; ++m) {
      if constexpr (has_bias) {
        copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
      } else {
        copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
      }
    }
  }
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
if (brg) {
|
||||
brgemm<scalar_t, has_bias>::apply(A, B, C, Ctmp, bias, M, N, K, lda, ldb, ldc);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16, N = 16, 32, 48, 64
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch (mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x11:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
|
||||
break;
|
||||
case 0x12:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
|
||||
break;
|
||||
case 0x13:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
|
||||
break;
|
||||
case 0x14:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
|
||||
break;
|
||||
// mb_size = 2
|
||||
case 0x21:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
|
||||
break;
|
||||
case 0x22:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
|
||||
break;
|
||||
case 0x23:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
|
||||
break;
|
||||
case 0x24:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
|
||||
break;
|
||||
// mb_size = 3
|
||||
case 0x31:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
|
||||
break;
|
||||
case 0x32:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
|
||||
break;
|
||||
case 0x33:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
|
||||
break;
|
||||
case 0x34:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
|
||||
break;
|
||||
// mb_size = 4
|
||||
case 0x41:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
|
||||
break;
|
||||
case 0x42:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
|
||||
break;
|
||||
case 0x43:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
|
||||
break;
|
||||
case 0x44:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, " x ", nb_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Parallel linear layer: out[M, N] = mat1[M, K] @ mat2 (+ bias), with mat2
// pre-packed in VNNI layout [N/BLOCK_N blocks of BLOCK_N * K]. Work is
// split over a 2D (MB, NB) grid of BLOCK_M x BLOCK_N tiles.
template <typename scalar_t>
void weight_packed_linear_kernel_impl(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ mat1,
    const scalar_t* __restrict__ mat2,
    const float* __restrict__ bias,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t mat1_strideM,
    int64_t out_strideM) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);

  // brgemm only pays off above a size threshold decided by can_use_brgemm
  const bool use_brgemm = can_use_brgemm<scalar_t>(M);

  // parallel on [MB, NB]
  AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
      // for brgemm, use float32 for accumulate
      alignas(64) float Ctmp[BLOCK_M * BLOCK_N];

      loop_2d<scalar_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
        int64_t mb_start = mb * BLOCK_M;
        int64_t mb_size = std::min(M - mb_start, BLOCK_M);
        int64_t nb_start = nb * BLOCK_N;
        int64_t nb_size = std::min(N - nb_start, BLOCK_N);

        tinygemm_kernel<scalar_t, has_bias>(
            /* A */ mat1 + mb_start * mat1_strideM,
            /* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
            /* C */ out + mb_start * out_strideM + nb_start,
            /* Ctmp*/ Ctmp,
            /* bias*/ bias + nb_start,
            /* M */ mb_size,
            /* N */ nb_size,
            /* K */ K,
            /* lda */ mat1_strideM,
            /* ldb */ nb_size,
            /* ldc */ out_strideM,
            /* brg */ use_brgemm);
      });

      // release brgemm hardware state (e.g. AMX tiles) per worker
      if (use_brgemm) {
        at::native::cpublas::brgemm_release();
      }
    });
  });
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface: public bias-free wrapper over the internal
// tinygemm_kernel<scalar_t, has_bias> dispatcher above.
template <typename scalar_t>
void tinygemm_kernel(
    const scalar_t* __restrict__ A,
    const scalar_t* __restrict__ B,
    scalar_t* __restrict__ C,
    float* __restrict__ Ctmp,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    bool brg) {
  tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg);
}

// explicit instantiations so the template can live in this .cpp and still
// be linked from other translation units
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
  template void tinygemm_kernel<TYPE>(      \
      const TYPE* __restrict__ A,           \
      const TYPE* __restrict__ B,           \
      TYPE* __restrict__ C,                 \
      float* __restrict__ Ctmp,             \
      int64_t M,                            \
      int64_t N,                            \
      int64_t K,                            \
      int64_t lda,                          \
      int64_t ldb,                          \
      int64_t ldc,                          \
      bool brg)

INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
// Repack a 2d [OC, IC] or 3d [E, OC, IC] weight tensor into the VNNI
// layout consumed by the gemm kernels in this file. For int8/fp8 the inner
// row size also grows to hold the compensation term (see get_row_size).
at::Tensor convert_weight_packed(at::Tensor& weight) {
  // for 3d moe weights
  // weight : [E, OC, IC]
  // w1 : [E, 2N, K]
  // w2 : [E, K, N]
  CHECK_INPUT(weight);

  const int64_t ndim = weight.ndimension();
  TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor.");
  const auto st = weight.scalar_type();
  const int64_t E = ndim == 3 ? weight.size(0) : 1;   // experts (1 for plain linear)
  const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
  const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);

  // we handle 2 TILE_N at a time.
  TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
  TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);

  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t NB = div_up(OC, BLOCK_N);

  // use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
  // (allocated empty; resized to the real packed shape inside the dispatch)
  auto packed_weight = at::empty({}, weight.options());
  const int64_t stride = OC * IC;

  TORCH_CHECK(
      st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
      "expect weight to be bfloat16, float16, int8 or fp8_e4m3.");

  CPU_DISPATCH_PACKED_TYPES(st, [&] {
    // adjust most inner dimension size
    const int packed_row_size = get_row_size<packed_t>(IC);
    auto sizes = weight.sizes().vec();
    sizes[ndim - 1] = packed_row_size;
    packed_weight.resize_(sizes);

    const packed_t* w_data = weight.data_ptr<packed_t>();
    packed_t* packed_data = packed_weight.data_ptr<packed_t>();

    // parallel on {E, NB}: each task repacks one BLOCK_N slice of one expert
    at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) {
      int64_t e{0}, nb{0};
      data_index_init(begin, e, E, nb, NB);

      for (int64_t i = begin; i < end; ++i) {
        UNUSED(i);

        int64_t n = nb * BLOCK_N;
        int64_t n_size = std::min(BLOCK_N, OC - n);
        pack_vnni<packed_t>(
            packed_data + e * OC * packed_row_size + n * packed_row_size, w_data + e * stride + n * IC, n_size, IC);

        // move to the next index
        data_index_step(e, E, nb, NB);
      }
    });
  });
  return packed_weight;
}
|
||||
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
// Linear layer over a VNNI pre-packed weight.
//   mat1 : [M, K] activation (bf16/fp16)
//   mat2 : [N, K] weight
//   bias : [N] float32 (optional)
//   out  : [M, N]
// When is_vnni is false the weight is packed on the fly.
at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni) {
  RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));

  auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);

  CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
  CHECK_INPUT(mat2);

  const int64_t M = mat1.size(0);
  const int64_t N = mat2.size(0);
  const int64_t K = mat2.size(1);
  CHECK_EQ(mat1.size(1), K);
  CHECK_DIM(2, mat1);
  CHECK_DIM(2, mat2);

  auto out = at::empty({M, N}, mat1.options());

  // Row strides: mat1 may be a strided view; out is freshly allocated.
  const int64_t mat1_strideM = mat1.stride(0);
  const int64_t out_strideM = out.stride(0);

  const float* bias_data = nullptr;
  if (bias.has_value()) {
    CHECK_EQ(bias.value().size(0), N);
    bias_data = bias.value().data_ptr<float>();
  }

  AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] {
    weight_packed_linear_kernel_impl<scalar_t>(
        out.data_ptr<scalar_t>(),
        mat1.data_ptr<scalar_t>(),
        packed_w.data_ptr<scalar_t>(),
        bias_data,
        M,
        N,
        K,
        mat1_strideM,
        out_strideM);
  });

  return out;
}
|
||||
202
sgl-kernel/csrc/cpu/gemm.h
Normal file
202
sgl-kernel/csrc/cpu/gemm.h
Normal file
@@ -0,0 +1,202 @@
|
||||
#pragma once
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
|
||||
#include "common.h"
|
||||
|
||||
// amx-bf16
|
||||
#define TILE_M 16
|
||||
#define TILE_N 16
|
||||
#define TILE_K 32
|
||||
|
||||
// block size for AMX gemm
// M blocking: two AMX tile rows per block (2 * TILE_M = 32).
constexpr int block_size_m() {
  return 2 * TILE_M;
}
// N blocking: two AMX tile columns per block (2 * TILE_N = 32).
constexpr int block_size_n() {
  return 2 * TILE_N;
}
|
||||
|
||||
// define threshold using brgemm (intel AMX)
// Per-dtype heuristic choosing between the brgemm path and the hand-written
// AVX-512 tinygemm path; M is the number of activation rows in the tile.
template <typename T>
inline bool can_use_brgemm(int M);
// bf16: brgemm only pays off for more than 4 rows.
template <>
inline bool can_use_brgemm<at::BFloat16>(int M) {
  return M > 4;
}
// fp16: always use brgemm.
template <>
inline bool can_use_brgemm<at::Half>(int M) {
  return true;
}
// this requires PyTorch 2.7 or above
template <>
inline bool can_use_brgemm<int8_t>(int M) {
  return M > 4;
}

template <>
inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) {
  return M > 4;
}
|
||||
|
||||
// work around compiler internal error
|
||||
#define BLOCK_K 128 // 4 * TILE_K
|
||||
|
||||
// adjust leading dimension size for K
// Packed row length (in elements) for a row of K weights.
template <typename T>
inline int64_t get_row_size(int64_t K) {
  return K;
}

// int8 rows append 4 extra bytes (one int32 compensation value used by the
// u8s8 VNNI path — see the s8s8 compensation note in gemm_int8.cpp).
template <>
inline int64_t get_row_size<int8_t>(int64_t K) {
  return K + sizeof(int32_t);
}

// Runtime variant for call sites where the dtype choice is a boolean flag.
inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
  return use_int8_w8a8 ? K + sizeof(int32_t) : K;
}
|
||||
|
||||
// pack weight to vnni format
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
// moe implementations for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
uint8_t* __restrict__ A_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// moe implementations for fp8 w8a16
|
||||
template <typename scalar_t>
|
||||
void fused_experts_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// shared expert implementation for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K,
|
||||
bool do_unpack = true);
|
||||
551
sgl-kernel/csrc/cpu/gemm_fp8.cpp
Normal file
551
sgl-kernel/csrc/cpu/gemm_fp8.cpp
Normal file
@@ -0,0 +1,551 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// Convert a float32 buffer to scalar_t (bf16/fp16) and store `size` elements
// into `out`. The vector loop consumes one bVec (two fVec loads) per
// iteration; the scalar loop handles the tail.
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();

  int64_t d;
#pragma GCC unroll 4
  for (d = 0; d <= size - kVecSize; d += kVecSize) {
    fVec data0 = fVec::loadu(input + d);
    fVec data1 = fVec::loadu(input + d + fVec::size());
    bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
    out_vec.store(out + d);
  }
  // scalar tail (size not a multiple of kVecSize)
  for (; d < size; ++d) {
    out[d] = static_cast<scalar_t>(input[d]);
  }
}
|
||||
|
||||
// Same as copy_stub but adds a float32 `bias` elementwise before converting
// to scalar_t: out[d] = (scalar_t)(input[d] + bias[d]).
template <typename scalar_t>
inline void copy_add_stub(
    scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();

  int64_t d;
#pragma GCC unroll 4
  for (d = 0; d <= size - kVecSize; d += kVecSize) {
    fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
    fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
    bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
    out_vec.store(out + d);
  }
  // scalar tail
  for (; d < size; ++d) {
    out[d] = static_cast<scalar_t>(input[d] + bias[d]);
  }
}
|
||||
|
||||
// Dequantize one fp8 (e4m3) packed-B block into bf16 `Btmp`, applying a
// single dequant `scale` for the whole call (the caller passes one scale per
// 128-wide K quant block). Packed layout is VNNI [K/2, N, 2]; with
// BLOCK_N == 32 each packed row is 64 fp8 bytes.
// NOTE(review): the `N` parameter is unused here — the width is fixed by the
// BLOCK_N == 32 static_assert; confirm callers never pass a partial N.
inline void unpack_B(
    at::BFloat16* __restrict__ Btmp,
    const at::Float8_e4m3fn* __restrict__ packed_B,
    int N,
    int K,
    int ldb,
    int ldb_tmp,
    float scale) {
#if defined(CPU_CAPABILITY_AVX512)
  // [K/2, N, 2]
  const int K2 = K >> 1;
  const int ldb2 = ldb;  // ldb * 2 >> 1;
  const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B);
  const __m512 vd = _mm512_set1_ps(scale);

  constexpr int BLOCK_N = block_size_n();
  static_assert(BLOCK_N == 32);

  // prefetch distance
  constexpr int PREFETCH_SIZE_K = 64;

#pragma GCC unroll 4
  for (int k = 0; k < K2; ++k) {
    // Load one packed row: 64 fp8 values.
    __m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2);
    if constexpr (PREFETCH_SIZE_K > 0) {
      _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0);
    }

    __m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
    __m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);

    // fp8 -> bf16 (32 values per half)
    __m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
    __m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);

    // Apply scale
    // Widen bf16 -> fp32, multiply by the dequant scale, then round back to
    // bf16 for the brgemm input buffer.
    __m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
    __m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
    __m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
    __m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));

    f0_lo = _mm512_mul_ps(f0_lo, vd);
    f0_hi = _mm512_mul_ps(f0_hi, vd);
    f1_lo = _mm512_mul_ps(f1_lo, vd);
    f1_hi = _mm512_mul_ps(f1_hi, vd);

    bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
    bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);

    _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0);
    _mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1);
  }
#else
  TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
#endif
}
|
||||
|
||||
// fp8 NN micro-kernel, primary template. Only the AVX-512
// (bf16 activation, fp8 weight) specialization below is implemented;
// instantiating any other combination fails at runtime.
template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
  static inline void apply(
      const scalar_t* __restrict__ A,
      const packed_t* __restrict__ B,
      scalar_t* __restrict__ C,
      const float* __restrict__ bias,
      const float* __restrict__ scale,
      int K,
      int lda,
      int ldb,
      int ldc,
      int64_t block_size_K) {
    TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
  }
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc,
|
||||
int64_t block_size_K) {
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
const int KB = div_up(K, BLOCK_K);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 64;
|
||||
constexpr int PREFETCH_SIZE_KB = 1;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
__m512 vsum[ROWS * COLS];
|
||||
|
||||
// block quant scale
|
||||
__m512 vscale;
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_setzero_ps();
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int lda2 = lda >> 1;
|
||||
const int ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0));
|
||||
vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1));
|
||||
}
|
||||
}
|
||||
vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]);
|
||||
};
|
||||
|
||||
constexpr int BLOCK_K2 = BLOCK_K >> 1;
|
||||
for (int kb = 0; kb < KB; ++kb) {
|
||||
int kb_start = kb * BLOCK_K2;
|
||||
int kb_end = std::min(K >> 1, kb_start + BLOCK_K2);
|
||||
// 1. load scale vector
|
||||
vscale = _mm512_set1_ps(scale[kb]);
|
||||
if constexpr (PREFETCH_SIZE_KB > 0) {
|
||||
_mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0);
|
||||
}
|
||||
// 2. zero vsum for each block
|
||||
Unroll<ROWS * COLS>{}([&](auto i) { vsum[i] = _mm512_setzero_ps(); });
|
||||
// 3. accumulate across each block
|
||||
for (int k = kb_start; k < kb_end; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
// 4. apply scale
|
||||
Unroll<ROWS * COLS>{}([&](auto i) { vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); });
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2,4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// Dispatch helper used by tinygemm_kernel's switch below: invokes the
// (MB_SIZE x NB_SIZE) fp8 NN micro-kernel on the current (mb_start, nb_start)
// tile. B advances by `nb_start * 2` since the packed fp8 layout is
// [K/2, N, 2] — two K elements per N column (see unpack_B above).
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                                       \
  tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply(     \
      A + mb_start * lda,                                                                 \
      B + nb_start * 2,                                                                   \
      C + mb_start * ldc + nb_start,                                                      \
      has_bias ? bias + nb_start : nullptr,                                               \
      scale,                                                                              \
      K,                                                                                  \
      lda,                                                                                \
      ldb,                                                                                \
      ldc,                                                                                \
      block_size_K);
|
||||
|
||||
// brgemm wrapper, primary template. Only the (bf16 activation, fp8 weight)
// specialization below is implemented; any other combination fails at
// runtime.
template <typename scalar_t, typename packed_t, bool has_bias>
struct brgemm {
  static inline void apply(
      const scalar_t* __restrict__ A,
      const packed_t* __restrict__ B,
      scalar_t* __restrict__ C,
      scalar_t* __restrict__ Btmp,
      float* __restrict__ Ctmp,
      const float* __restrict__ bias,
      const float* __restrict__ scale,
      int M,
      int N,
      int K,
      int lda,
      int ldb,
      int ldc,
      bool do_unpack = true) {
    TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
  }
};
|
||||
|
||||
// bf16/fp8 brgemm path: optionally dequantize B into the bf16 scratch `Btmp`
// (one unpack_B call per BLOCK_K quant block, each with its own scale), run
// oneDNN brgemm into the fp32 scratch `Ctmp`, then narrow (+bias) into C.
// `do_unpack` lets callers reuse an already-unpacked Btmp across M tiles.
template <bool has_bias>
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
  static inline void apply(
      const at::BFloat16* __restrict__ A,
      const at::Float8_e4m3fn* __restrict__ B,
      at::BFloat16* __restrict__ C,
      at::BFloat16* __restrict__ Btmp,
      float* __restrict__ Ctmp,
      const float* __restrict__ bias,
      const float* __restrict__ scale,
      int M,
      int N,
      int K,
      int lda,
      int ldb,
      int ldc,
      bool do_unpack = true) {
    constexpr int BLOCK_N = block_size_n();

    // [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
    const int ldb_tmp = BLOCK_N;

    if (do_unpack) {
      for (int k = 0; k < K; k += BLOCK_K) {
        int kb_size = std::min(BLOCK_K, K - k);

        int idx = k >> 7;  // k / BLOCK_K where BLOCK_K = 128
        unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
      }
    }

    at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);

    // copy from Ctmp to C
    // Narrow fp32 -> bf16 row by row, fusing the bias add when present.
    for (int m = 0; m < M; ++m) {
      if constexpr (has_bias) {
        copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
      } else {
        copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
      }
    }
  }
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K,
|
||||
bool do_unpack = true) {
|
||||
if (brg) {
|
||||
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
|
||||
A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc, do_unpack);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch (mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
|
||||
break;
|
||||
case 0x22:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
|
||||
break;
|
||||
case 0x32:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
|
||||
break;
|
||||
case 0x42:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Blocked fp8 matmul driver: parallelizes over [MB, NB] output tiles.
// Per-thread scratch carved from `buffer`:
//   Btmp : [MAX_CACHE_BLOCK_SIZE, BLOCK_N * K] dequantized-B cache (scalar_t)
//   Ctmp : [BLOCK_M * BLOCK_N] fp32 accumulator, placed right after Btmp
// `scales2` holds one float per (block_size_N, block_size_K) weight block.
template <typename scalar_t>
void fp8_scaled_mm_kernel_impl(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ mat1,
    const at::Float8_e4m3fn* __restrict__ mat2,
    const float* __restrict__ scales2,
    const float* __restrict__ bias,
    scalar_t* __restrict__ buffer,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t mat1_strideM,
    int64_t out_strideM,
    int64_t block_size_N,
    int64_t block_size_K,
    int64_t buffer_size_per_thread) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);

  // Number of scale entries along K, and how many BLOCK_N tiles share one
  // N-group scale row.
  const int64_t scale_size_K = div_up(K, block_size_K);
  const int64_t blocks_n_per_group = block_size_N / BLOCK_N;

  const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);

  // parallel on [MB, NB]
  AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
      int tid = get_thread_num();
      scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
      float* __restrict__ Ctmp = (float*)((void*)(Btmp + MAX_CACHE_BLOCK_SIZE * BLOCK_N * K));

      loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
        // Scale row for this N group.
        const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;

        int64_t mb_start = mb * BLOCK_M;
        int64_t mb_size = std::min(M - mb_start, BLOCK_M);
        int64_t nb_start = nb * BLOCK_N;
        int64_t nb_size = std::min(N - nb_start, BLOCK_N);

        // only do unpacking for the first row
        // (subsequent M tiles reuse the Btmp slot filled at mb == mb0)
        bool do_unpack = (mb == mb0);

        tinygemm_kernel<scalar_t, has_bias>(
            /* A */ mat1 + mb_start * mat1_strideM,
            /* B */ mat2 + nb_start * K,  // nb * BLOCK_N * K
            /* C */ out + mb_start * out_strideM + nb_start,
            /* Btmp */ Btmp + nb_offset * BLOCK_N * K,
            /* Ctmp */ Ctmp,
            /* scale */ scale_ptr,
            /* bias */ bias + nb_start,
            /* M */ mb_size,
            /* N */ nb_size,
            /* K */ K,
            /* lda */ mat1_strideM,
            /* ldb */ nb_size,
            /* ldc */ out_strideM,
            /* brg */ use_brgemm,
            /* block_size_K */ block_size_K,
            /* do_unpack */ do_unpack);
      });

      // Release the per-thread brgemm AMX state when the brgemm path ran.
      if (use_brgemm) {
        at::native::cpublas::brgemm_release();
      }
    });
  });
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
// Public fp8 tinygemm entry declared in gemm.h: forwards to the
// anonymous-namespace implementation with has_bias == false
// (bias pointer fixed to nullptr).
template <typename scalar_t>
void tinygemm_kernel(
    const scalar_t* __restrict__ A,
    const at::Float8_e4m3fn* __restrict__ B,
    scalar_t* __restrict__ C,
    scalar_t* __restrict__ Btmp,
    float* __restrict__ Ctmp,
    const float* __restrict__ scale,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    bool brg,
    int64_t block_size_K,
    bool do_unpack) {
  tinygemm_kernel<scalar_t, false>(
      A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K, do_unpack);
}
|
||||
|
||||
// Explicitly instantiate the public fp8 tinygemm_kernel for the two
// activation dtypes handled by AT_DISPATCH_REDUCED_FLOATING_TYPES.
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE)       \
  template void tinygemm_kernel<TYPE>(            \
      const TYPE* __restrict__ A,                 \
      const at::Float8_e4m3fn* __restrict__ B,    \
      TYPE* __restrict__ C,                       \
      TYPE* __restrict__ Btmp,                    \
      float* __restrict__ Ctmp,                   \
      const float* __restrict__ scale,            \
      int64_t M,                                  \
      int64_t N,                                  \
      int64_t K,                                  \
      int64_t lda,                                \
      int64_t ldb,                                \
      int64_t ldc,                                \
      bool brg,                                   \
      int64_t block_size_K,                       \
      bool do_unpack)

INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
// Block-quantized fp8 matmul:
//   mat1    : [M, K] activation (bf16/fp16, must match out_dtype)
//   mat2    : [N, K] fp8_e4m3 weight (packed on the fly unless is_vnni)
//   scales2 : [div_up(N, block_size_N), div_up(K, block_size_K)] float32
//   bias    : [N] float32 (optional)
// Returns out : [M, N] in out_dtype.
at::Tensor fp8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    std::vector<int64_t> block_size,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni) {
  RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));

  auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);

  CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
  CHECK_INPUT(mat2);
  CHECK_INPUT(scales2);
  TORCH_CHECK(scales2.scalar_type() == at::kFloat, "fp8_scaled_mm_cpu: expect scales2 to be float32.");

  int64_t M = mat1.size(0);
  int64_t N = mat2.size(0);
  int64_t K = mat2.size(1);

  CHECK_EQ(mat1.size(1), K);
  CHECK_DIM(2, mat1);
  CHECK_DIM(2, mat2);

  TORCH_CHECK(block_size.size() == 2, "fp8_scaled_mm_cpu: expect block_size.size() to be 2.");

  int64_t block_size_N = block_size[0];
  int64_t block_size_K = block_size[1];

  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
  TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
  CHECK_EQ(scales2.size(0), div_up(N, block_size_N));
  CHECK_EQ(scales2.size(1), div_up(K, block_size_K));

  const auto st = mat1.scalar_type();
  TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "fp8_scaled_mm_cpu: expect A to be bfloat16 or half.");
  TORCH_CHECK(st == out_dtype, "fp8_scaled_mm_cpu: expect A has same dtype with out_dtype.");
  TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn, "fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3.");
  // Fix: removed a duplicated scales2 float32 TORCH_CHECK — the same
  // condition is already enforced right after CHECK_INPUT(scales2) above.
  auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));

  // strides
  int64_t mat1_strideM = mat1.stride(0);
  int64_t out_strideM = out.stride(0);

  const bool has_bias = bias.has_value();
  const float* bias_data = nullptr;
  if (has_bias) {
    CHECK_EQ(bias.value().size(0), N);
    bias_data = bias.value().data_ptr<float>();
  }

  // Per-thread scratch, carved out by the kernel:
  // Btmp : [T, BLOCK_N * K]    (scalar_t, MAX_CACHE_BLOCK_SIZE slots)
  // Ctmp : [T, BLOCK_M * BLOCK_N]  (fp32 — hence "* 2" scalar_t elements)
  int num_threads = at::get_num_threads();
  int64_t size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
  auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());

  AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
    fp8_scaled_mm_kernel_impl<scalar_t>(
        out.data_ptr<scalar_t>(),
        mat1.data_ptr<scalar_t>(),
        packed_w.data_ptr<at::Float8_e4m3fn>(),
        scales2.data_ptr<float>(),
        bias_data,
        buffer.data_ptr<scalar_t>(),
        M,
        N,
        K,
        mat1_strideM,
        out_strideM,
        block_size_N,
        block_size_K,
        size_per_thread);
  });

  return out;
}
|
||||
547
sgl-kernel/csrc/cpu/gemm_int8.cpp
Normal file
547
sgl-kernel/csrc/cpu/gemm_int8.cpp
Normal file
@@ -0,0 +1,547 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// int8 epilogue, primary template: converts a row of int32 accumulators to
// scalar_t with per-row activation scale `As`, per-column weight scales `Bs`,
// the s8s8 compensation term `Bcomp`, and an optional bias. Only the AVX-512
// bf16 specialization below is implemented.
template <typename scalar_t, bool has_bias, int BLOCK_N>
struct scale_C {
  static inline void apply(
      scalar_t* __restrict__ C,
      const int32_t* __restrict__ Ctmp,
      const int32_t* __restrict__ Bcomp,
      const float* __restrict__ bias,
      float As,
      const float* __restrict__ Bs) {
    TORCH_CHECK(false, "scale_C: scalar path not implemented!");
  }
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_N>
|
||||
struct scale_C<at::BFloat16, has_bias, BLOCK_N> {
|
||||
static inline void apply(
|
||||
at::BFloat16* __restrict__ C,
|
||||
const int32_t* __restrict__ Ctmp,
|
||||
const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias,
|
||||
float As,
|
||||
const float* __restrict__ Bs) {
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512 vc[COLS];
|
||||
__m512 vd0 = _mm512_set1_ps(As);
|
||||
|
||||
auto compute = [&](auto col) {
|
||||
__m512 vd1 = _mm512_loadu_ps(Bs + col * 16);
|
||||
__m512i vcomp = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
__m512i vc32 = _mm512_loadu_si512(Ctmp + col * 16);
|
||||
vc[col] = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc32, vcomp));
|
||||
if constexpr (has_bias) {
|
||||
__m512 vbias = _mm512_loadu_ps(bias + col * 16);
|
||||
vc[col] = _mm512_fmadd_ps(_mm512_mul_ps(vc[col], vd0), vd1, vbias);
|
||||
} else {
|
||||
vc[col] = _mm512_mul_ps(_mm512_mul_ps(vc[col], vd0), vd1);
|
||||
}
|
||||
};
|
||||
Unroll<COLS>{}(compute);
|
||||
|
||||
auto storec = [&](auto col) {
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc[col + 1], vc[col + 0])));
|
||||
}
|
||||
};
|
||||
Unroll<COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// int8 (u8 activation x s8 weight) NN micro-kernel, primary template. Only
// the AVX-512 bf16 specialization (defined below) is implemented; reaching
// this primary template is an error.
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
  static inline void apply(
      const uint8_t* __restrict__ A,
      const int8_t* __restrict__ B,
      scalar_t* __restrict__ C,
      const float* __restrict__ As,
      const float* __restrict__ Bs,
      const int32_t* __restrict__ Bcomp,
      const float* __restrict__ bias,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
  }
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
// AVX512-VNNI micro-kernel: u8(A) x s8(B, VNNI-packed) -> bf16(C).
// Accumulates in int32 via _mm512_dpbusd_epi32, subtracts the precomputed
// compensation Bcomp (see [NOTE] below), applies the per-row activation
// scale As and per-column weight scale Bs, optionally adds bias, and stores
// the result as bfloat16 (two fp32 vectors packed per 512-bit store).
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
  // A     : [BLOCK_M, K] u8 activations (already offset by +128 at quant time)
  // B     : [K, BLOCK_N] s8 weights, VNNI layout (4 K-values interleaved)
  // C     : [BLOCK_M, BLOCK_N] bf16 output, leading dim ldc
  // As/Bs : per-row / per-column float scales
  // Bcomp : per-column int32 compensation (128 * sum_k B[k, n]), precomputed
  static inline void apply(
      const uint8_t* __restrict__ A,
      const int8_t* __restrict__ B,
      at::BFloat16* __restrict__ C,
      const float* __restrict__ As,
      const float* __restrict__ Bs,
      const int32_t* __restrict__ Bcomp,
      const float* __restrict__ bias,
      int64_t K,
      int64_t lda,
      int64_t ldb,
      int64_t ldc) {
    constexpr int ROWS = BLOCK_M;
    // each __m512i holds 16 int32 lanes, so COLS = number of 16-wide column vectors
    constexpr int COLS = BLOCK_N / 16;
    // the store path below writes two column vectors per 512-bit bf16 store
    static_assert(COLS % 2 == 0);

    // prefetch distance
    constexpr int PREFETCH_SIZE_K = 0;

    __m512i va;
    __m512i vb[COLS];
    __m512i vc[ROWS * COLS];
    __m512i vcomp[COLS];
    __m512 vd0;
    __m512 vd1[COLS];

    // oops! 4x4 spills but we use 4x2
    __m512 vbias[COLS];

    // [NOTE]: s8s8 igemm compensation in avx512-vnni
    //
    // avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate:
    //
    //   a * b = (a + 128) * b - 128 * b
    //   s   s    u          s    u      s
    //
    // 1) 128 * b is pre-computed when packing B to vnni formats
    // 2) a + 128 is fused when dynamically quantize A
    //
    auto loadc = [&](auto i) { vc[i] = _mm512_set1_epi32(0); };
    Unroll<ROWS * COLS>{}(loadc);

    // VNNI consumes 4 int8 K-values per int32 lane, hence the >> 2 reindexing.
    const int64_t K4 = K >> 2;
    const int64_t lda4 = lda >> 2;
    const int64_t ldb4 = ldb;  // ldb * 4 >> 2;
    const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
    const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);

    // one dpbusd per (row, col) tile element per K4 step
    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        // broadcast 4 packed K-values of A's current row
        va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
      }
      if constexpr (row == 0) {
        vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
        if constexpr (PREFETCH_SIZE_K > 0) {
          _mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0);
        }
      }
      vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
    };
    for (int64_t k = 0; k < K4; ++k) {
      Unroll<ROWS * COLS>{}(compute, k);
    }

    // epilogue: compensate, scale, (bias), convert to bf16, store
    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      // load a scale
      if constexpr (col == 0) {
        vd0 = _mm512_set1_ps(As[row]);
      }
      // load b scale and vcomp per 2 vectors
      // also load bias if any
      if constexpr (row == 0) {
        if constexpr (col % 2 == 0) {
          vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16);
          vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
          vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
          vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
          if constexpr (has_bias) {
            vbias[col + 0] = _mm512_loadu_ps(bias + col * 16);
            vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16);
          }
        }
      }

      // for COLS = 2, 4 use 512bit store
      if constexpr (col % 2 == 0) {
        // acc - comp, then convert to float: implements (a+128)*b - 128*b
        __m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0]));
        __m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1]));
        if constexpr (has_bias) {
          vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]);
          vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]);
        } else {
          vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]);
          vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]);
        }

        // pack two fp32 vectors into one bf16 vector (vc0 -> low half)
        _mm512_storeu_si512(
            reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0)));
      }
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, \
|
||||
B + nb_start * 4, \
|
||||
C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, \
|
||||
Bs + nb_start, \
|
||||
Bcomp + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, \
|
||||
K, \
|
||||
lda, \
|
||||
ldb, \
|
||||
ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
if (brg) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
|
||||
|
||||
// apply compensation and scale
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
scale_C<scalar_t, has_bias, BLOCK_N>::apply(C + m * ldc, Ctmp + m * BLOCK_N, Bcomp, bias, As[m], Bs);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch (mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x12:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
|
||||
break;
|
||||
case 0x14:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
|
||||
break;
|
||||
// mb_size = 2
|
||||
case 0x22:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
|
||||
break;
|
||||
case 0x24:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
|
||||
break;
|
||||
// mb_size = 3
|
||||
case 0x32:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
|
||||
break;
|
||||
case 0x34:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
|
||||
break;
|
||||
// mb_size = 4
|
||||
case 0x42:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
|
||||
break;
|
||||
case 0x44:
|
||||
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Blocked int8 x int8 -> scalar_t GEMM driver.
// Partitions [M, N] into [BLOCK_M, BLOCK_N] tiles, parallelizes the tile
// grid in 2D, and invokes tinygemm_kernel per tile. mat2 must be in packed
// (VNNI) layout with the per-column compensation appended after the K data,
// so each packed row spans get_row_size<int8_t>(K) = K + 4 bytes.
//
// out     : [M, N] output (bf16/fp16)
// mat1    : [M, K] u8 quantized activations, scales1: [M]
// mat2    : packed [N, K+4] s8 weights, scales2: [N]
// bias    : [N] float or nullptr
template <typename scalar_t>
void int8_scaled_mm_kernel_impl(
    scalar_t* __restrict__ out,
    const uint8_t* __restrict__ mat1,
    const int8_t* __restrict__ mat2,
    const float* __restrict__ scales1,
    const float* __restrict__ scales2,
    const float* __restrict__ bias,
    int64_t M,
    int64_t N,
    int64_t K) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);

  const bool use_brgemm = can_use_brgemm<int8_t>(M);

  // K + 4 after compensation
  const int64_t packed_row_size = get_row_size<int8_t>(K);

  // bias presence is resolved to a compile-time bool for the kernel template
  AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
    parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
      // for brgemm, use int32_t for accumulate
      alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];

      loop_2d<int8_t>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
        // NOTE(review): tile offsets are held in `int` while M/N are int64_t;
        // this narrows for extremely large M or N — confirm intended limits.
        int mb_start = mb * BLOCK_M;
        int mb_size = std::min(M - mb_start, BLOCK_M);
        int nb_start = nb * BLOCK_N;
        int nb_size = std::min(N - nb_start, BLOCK_N);

        tinygemm_kernel<scalar_t, has_bias>(
            /* A */ mat1 + mb_start * K,
            /* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
            /* C */ out + mb_start * N + nb_start,
            /* Ctmp*/ Ctmp,
            /* As */ scales1 + mb_start,
            /* Bs */ scales2 + nb_start,
            /* bias*/ bias + nb_start,
            /* M */ mb_size,
            /* N */ nb_size,
            /* K */ K,
            /* lda */ K,
            /* ldb */ nb_size,
            /* ldc */ N,
            /* brg */ use_brgemm);
      });

      // release per-thread brgemm state once this thread's tiles are done
      if (use_brgemm) {
        at::native::cpublas::brgemm_release();
      }
    });
  });
}
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
// Public bias-less entry point: forwards to the internal templated kernel
// with has_bias = false and a null bias pointer. All other parameters are
// passed through unchanged.
template <typename scalar_t>
void tinygemm_kernel(
    const uint8_t* __restrict__ A,
    const int8_t* __restrict__ B,
    scalar_t* __restrict__ C,
    int32_t* __restrict__ Ctmp,
    const float* __restrict__ As,
    const float* __restrict__ Bs,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    bool brg) {
  tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg);
}
|
||||
// Explicit instantiations of the public tinygemm interface for the two
// reduced-precision output types this kernel supports.
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
  template void tinygemm_kernel<TYPE>(      \
      const uint8_t* __restrict__ A,        \
      const int8_t* __restrict__ B,         \
      TYPE* __restrict__ C,                 \
      int32_t* __restrict__ Ctmp,           \
      const float* __restrict__ As,         \
      const float* __restrict__ Bs,         \
      int64_t M,                            \
      int64_t N,                            \
      int64_t K,                            \
      int64_t lda,                          \
      int64_t ldb,                          \
      int64_t ldc,                          \
      bool brg)

INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
// Dynamic per-token (per-row) symmetric int8 quantization of a 2D bf16/fp16
// tensor A. Returns (Aq, As): Aq is [M, K] uint8 quantized data, As is [M]
// float32 per-row scales. Rows are quantized independently in parallel by
// quantize_row_int8 (which also writes the row scale into As).
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
  RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A}));

  // last dim must be contiguous; rows may be strided (lda = A.stride(0))
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(A);
  CHECK_DIM(2, A);

  int64_t M = A.size(0);
  int64_t K = A.size(1);
  int64_t lda = A.stride(0);

  const auto st = A.scalar_type();
  TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "per_token_quant_int8: expect A to be bfloat16 or half.");

  auto Aq = at::empty({M, K}, A.options().dtype(at::kByte));
  auto As = at::empty({M}, A.options().dtype(at::kFloat));

  AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] {
    uint8_t* __restrict__ Aq_data = Aq.data_ptr<uint8_t>();
    float* __restrict__ As_data = As.data_ptr<float>();
    const scalar_t* __restrict__ A_data = A.data_ptr<scalar_t>();

    // one task per row; quantize_row_int8 fills Aq row m and its scale As[m]
    at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
      for (int64_t m = begin; m < end; ++m) {
        quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
      }
    });
  });
  return std::make_tuple(Aq, As);
}
|
||||
// weight     : static, per-channel, symmetric
// activation : dynamic, per-token, symmetric
//
//   mat1    : [M, K]
//   mat2    : [N, K]
//   scales1 : [M]
//   scales2 : [N]
//   bias    : [N]
//   out     : [M, N]
//
// Entry point for int8 scaled matmul: out = (mat1 * scales1) @ (mat2 * scales2)^T + bias.
// If is_vnni is false the weight is repacked to VNNI layout here (with the
// +4-byte compensation column); when is_vnni is true mat2 is expected to be
// pre-packed and its second dim is K + sizeof(int32_t).
at::Tensor int8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales1,
    at::Tensor& scales2,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni) {
  RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));

  // NOTE(review): mat2 is consumed here before the CHECK_INPUT below —
  // confirm convert_weight_packed performs its own validation.
  auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);

  CHECK_INPUT(mat1);
  CHECK_INPUT(mat2);
  CHECK_INPUT(scales1);
  CHECK_INPUT(scales2);
  CHECK_DIM(2, mat1);
  CHECK_DIM(2, mat2);

  int64_t M = mat1.size(0);
  int64_t N = mat2.size(0);
  int64_t K = mat1.size(1);

  // see [NOTE]: s8s8 igemm compensation in avx512-vnni
  CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
  CHECK_EQ(scales1.numel(), M);
  CHECK_EQ(scales2.numel(), N);

  TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8.");
  TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8.");
  TORCH_CHECK(
      scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat,
      "int8_scaled_mm: expect scales to be float32.");

  auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));

  const bool has_bias = bias.has_value();
  const float* bias_data = nullptr;
  if (has_bias) {
    CHECK_EQ(bias.value().size(0), N);
    bias_data = bias.value().data_ptr<float>();
  }

  AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] {
    int8_scaled_mm_kernel_impl<scalar_t>(
        out.data_ptr<scalar_t>(),
        mat1.data_ptr<uint8_t>(),
        packed_w.data_ptr<int8_t>(),
        scales1.data_ptr<float>(),
        scales2.data_ptr<float>(),
        bias_data,
        M,
        N,
        K);
  });
  return out;
}
|
||||
// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`
|
||||
at::Tensor int8_scaled_mm_with_quant(
|
||||
at::Tensor& mat1,
|
||||
at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype,
|
||||
bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
int64_t lda = mat1.stride(0);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
const auto st = mat1.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
|
||||
TORCH_CHECK(st == out_dtype, "int8_scaled_mm_with_quant: expect A has same dtype with out_dtype.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm_with_quant: expect mat2 to be int8.");
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "int8_scaled_mm_with_quant: expect scales to be float32.");
|
||||
|
||||
const int64_t buffer_size = M * K + M * sizeof(float);
|
||||
auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
|
||||
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K));
|
||||
const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
|
||||
}
|
||||
});
|
||||
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
Aq_data,
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
As_data,
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
74
sgl-kernel/csrc/cpu/interface.cpp
Normal file
74
sgl-kernel/csrc/cpu/interface.cpp
Normal file
@@ -0,0 +1,74 @@
|
||||
#include <ATen/record_function.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "shm.h"
|
||||
|
||||
// Communication settings
|
||||
static int world_rank = -1;
|
||||
static int world_size = -1;
|
||||
|
||||
static bool is_initialized = false;
|
||||
|
||||
static bool all_ranks_local_p = false;
|
||||
|
||||
void initialize(int64_t size, int64_t rank) {
|
||||
if (is_initialized) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Check whether all ranks is on the same physical machine.
|
||||
// If true, we will use an SHM based low latency allreduce
|
||||
|
||||
auto ls_string = std::getenv("LOCAL_SIZE");
|
||||
int ls = 0;
|
||||
if (ls_string != NULL) {
|
||||
ls = std::stoi(std::getenv("LOCAL_SIZE"));
|
||||
}
|
||||
|
||||
if (size >= 1 && size == ls) {
|
||||
all_ranks_local_p = true;
|
||||
}
|
||||
|
||||
world_size = size;
|
||||
world_rank = rank;
|
||||
is_initialized = true;
|
||||
|
||||
const char* addr_string = std::getenv("MASTER_ADDR");
|
||||
if (addr_string == NULL) {
|
||||
addr_string = "";
|
||||
}
|
||||
const char* port_string = std::getenv("MASTER_PORT");
|
||||
if (port_string == NULL) {
|
||||
port_string = "";
|
||||
}
|
||||
|
||||
if (all_ranks_local_p) {
|
||||
shm_initialize(size, rank, addr_string, port_string);
|
||||
}
|
||||
}
|
||||
|
||||
// In-place shared-memory allreduce over `data`. Only SUM is supported; any
// other ReduceOp aborts via TORCH_CHECK. Requires initialize() to have set
// up the SHM transport.
void shm_allreduce(torch::Tensor& data, int64_t op) {
  RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));

  TORCH_CHECK(op == c10d::ReduceOp::SUM, "Only torch.distributed.ReduceOp.SUM is supported");

  auto numel = data.numel();
  // NOTE(review): data_size is an int; numel * element_size() overflows for
  // tensors >= 2 GiB — confirm the size type all_reduce_outer_loop expects.
  int data_size = numel * data.element_size();
  all_reduce_outer_loop(data, numel, data_size);

  return;
}
|
||||
// Shared-memory allgather: concatenates `data` from all ranks along `dim`
// (negative dims are normalized). The result tensor's size along `dim` is
// world_size times the input's, so initialize() must have run first
// (world_size is -1 otherwise).
torch::Tensor shm_allgather(torch::Tensor& data, int64_t dim) {
  RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));

  auto numel = data.numel();
  // NOTE(review): same potential int overflow as shm_allreduce for >= 2 GiB.
  int data_size = numel * data.element_size();
  if (dim < 0) {
    dim += data.dim();
  }
  std::vector<int64_t> result_shape = data.sizes().vec();
  result_shape[dim] *= world_size;
  torch::Tensor result_tensor = torch::empty(result_shape, data.options());
  return all_gather(result_tensor, data, dim, numel, data_size);
}
||||
1322
sgl-kernel/csrc/cpu/moe.cpp
Normal file
1322
sgl-kernel/csrc/cpu/moe.cpp
Normal file
File diff suppressed because it is too large
Load Diff
491
sgl-kernel/csrc/cpu/moe_fp8.cpp
Normal file
491
sgl-kernel/csrc/cpu/moe_fp8.cpp
Normal file
@@ -0,0 +1,491 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += Vec::size()) {
|
||||
Vec data = Vec::loadu(input + d);
|
||||
data.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
// out[d] = input[d] * weight for d in [0, size); vectorized main loop with a
// scalar tail. Input/output are reduced-precision (bf16/fp16); the multiply
// happens in float32.
template <typename scalar_t>
inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  const fVec weight_vec = fVec(weight);
  int64_t d;
#pragma GCC unroll 4
  for (d = 0; d <= size - kVecSize; d += kVecSize) {
    bVec x = bVec::loadu(input + d);
    // widen to two float vectors, scale, narrow back
    fVec x0, x1;
    std::tie(x0, x1) = at::vec::convert_to_float(x);
    x0 = x0 * weight_vec;
    x1 = x1 * weight_vec;
    bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
    out_vec.store(out + d);
  }
  // scalar tail for the remaining size % kVecSize elements
  for (; d < size; ++d) {
    out[d] = static_cast<scalar_t>(input[d] * weight);
  }
}
|
||||
// acc from [topk, K] to [K]
// Sums `topk` rows of length K into `out`, accumulating in float32.
// topk == 1 degenerates to a plain copy.
template <typename scalar_t>
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  if (topk == 1) {
    // do copy for topk = 1
    copy_stub(out, input, K);
  } else {
    // do sum for topk != 1
    int64_t d;
#pragma GCC unroll 4
    for (d = 0; d <= K - kVecSize; d += kVecSize) {
      // two float accumulators per bf16/fp16 vector (low/high halves)
      fVec sum_fvec0 = fVec(0.f);
      fVec sum_fvec1 = fVec(0.f);
      for (int t = 0; t < topk; ++t) {
        bVec x_bvec = bVec::loadu(input + t * K + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        sum_fvec0 += x_fvec0;
        sum_fvec1 += x_fvec1;
      }
      bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
      out_bvec.store(out + d);
    }
    // scalar tail
    for (; d < K; ++d) {
      float sum_val = 0.f;
      for (int t = 0; t < topk; ++t) {
        sum_val += static_cast<float>(input[t * K + d]);
      }
      out[d] = static_cast<scalar_t>(sum_val);
    }
  }
}
|
||||
// out = input + input2 * scale
// Element-wise fused add-scale over `size` elements, computed in float32
// with a vectorized main loop and scalar tail.
template <typename scalar_t>
inline void add_mul_stub(
    scalar_t* __restrict__ out,
    const scalar_t* __restrict__ input,
    const scalar_t* __restrict__ input2,
    float scale,
    int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int kVecSize = bVec::size();
  const fVec s_vec = fVec(scale);

  int64_t d;
#pragma GCC unroll 4
  for (d = 0; d <= size - kVecSize; d += kVecSize) {
    bVec x_bvec = bVec::loadu(input + d);
    fVec x0, x1;
    std::tie(x0, x1) = at::vec::convert_to_float(x_bvec);

    bVec y_bvec = bVec::loadu(input2 + d);
    fVec y0, y1;
    std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);

    x0 = x0 + y0 * s_vec;
    x1 = x1 + y1 * s_vec;
    bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
    out_vec.store(out + d);
  }
  // scalar tail
  for (; d < size; ++d) {
    out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
  }
}
|
||||
// out = silu(input) * input2, computed in float32:
//   silu(x) = x / (1 + exp(-x))   (exp via the fast exp_u20 approximation)
// Callers guarantee `size` is a multiple of the vector width (no tail loop).
template <typename scalar_t>
inline void silu_and_mul_stub(
    scalar_t* __restrict__ out, const scalar_t* __restrict__ input, const scalar_t* __restrict__ input2, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  const fVec one = fVec(1.f);

  // no remainder
#pragma GCC unroll 4
  for (int64_t d = 0; d < size; d += bVec::size()) {
    bVec x = bVec::loadu(input + d);
    fVec x0, x1;
    std::tie(x0, x1) = at::vec::convert_to_float(x);
    bVec y = bVec::loadu(input2 + d);
    fVec y0, y1;
    std::tie(y0, y1) = at::vec::convert_to_float(y);
    // silu in float32 on both halves
    x0 = x0 / (one + x0.neg().exp_u20());
    x1 = x1 / (one + x1.neg().exp_u20());
    x0 = x0 * y0;
    x1 = x1 * y1;
    bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
    out_vec.store(out + d);
  }
}
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// Fused MoE expert computation with fp8 (e4m3) block-quantized weights.
//
// Pipeline (token rows are pre-sorted by expert; see sorted_ids/expert_ids/
// offsets, produced by the MoE alignment step):
//   stage 1  : ic0 = input @ w1          -> [M * topk, 2N]
//   stage 1.5: ic1 = silu_and_mul(ic0)   -> [M * topk, N]
//   stage 2  : ic2 = (ic1 @ w2) * topk_weights, scattered back to original
//              token order              -> [M, topk, K]
//   stage 3  : output = ic2.sum(dim=1)   -> [M, K]
//
// w1s/w2s are per-[block_size_N x block_size_K] float scales for the fp8
// weights. A_tmp/B_tmp/C_tmp are preallocated per-thread scratch buffers.
// num_tokens_post_pad is the padded sorted-token count (multiple of BLOCK_M).
template <typename scalar_t>
void fused_experts_fp8_kernel_impl(
    scalar_t* __restrict__ output,
    scalar_t* __restrict__ ic0,
    scalar_t* __restrict__ ic1,
    scalar_t* __restrict__ ic2,
    scalar_t* __restrict__ A_tmp,
    scalar_t* __restrict__ B_tmp,
    float* __restrict__ C_tmp,
    const scalar_t* __restrict__ input,
    const at::Float8_e4m3fn* __restrict__ packed_w1,
    const at::Float8_e4m3fn* __restrict__ packed_w2,
    const float* __restrict__ w1s,
    const float* __restrict__ w2s,
    int64_t block_size_N,
    int64_t block_size_K,
    const float* __restrict__ topk_weights,
    const int32_t* __restrict__ sorted_ids,
    const int32_t* __restrict__ expert_ids,
    const int32_t* __restrict__ offsets,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t E,
    int64_t topk,
    int64_t num_tokens_post_pad) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();

  // stage 1: intermediate_cache0 = hidden_states @ w1
  const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
  const int64_t NB = div_up(2 * N, BLOCK_N);
  int64_t scale_size_N = div_up(2 * N, block_size_N);
  int64_t scale_size_K = div_up(K, block_size_K);
  // how many BLOCK_N tiles share one scale group along N
  int64_t blocks_n_per_group = block_size_N / BLOCK_N;

  const int64_t stride_e = 2 * N * K;
  const int64_t stride_n = K;

  // average rows per expert, used only to pick the brgemm heuristic
  int64_t avg_M = std::max(int64_t(1), M * topk / E);
  const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(avg_M);

  // per-thread scratch for unpacked (dequantized) weight tiles; sized for
  // the larger of the two gemm stages
  int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);

  // here we only parallel on half of 2N to fuse silu_and_mul with gemm
  parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
    // get local pointers
    int tid = get_thread_num();
    scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;

    loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
      int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);

      // B shape [K, n_size] in vnni format
      int32_t expert_id = expert_ids[mb];
      const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n;
      const float* __restrict__ Bs =
          w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;

      // do unpacking for the first row or a new expert
      int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
      bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);

      // 1.a load A: gather this block's token rows (sorted_ids entries encode
      // token * topk + slot, hence the / topk) into contiguous scratch
      const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
      int64_t m_size = offsets[mb + 1] - offsets[mb];

      for (int64_t m = 0; m < m_size; ++m) {
        int32_t index = A_ids[m] / topk;
        copy_stub(A + m * K, input + index * K, K);
      }

      const int64_t offset = offsets[mb];
      tinygemm_kernel<scalar_t>(
          /* A */ A,
          /* B */ B,
          /* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
          /* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
          /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
          /* scale */ Bs,
          /* M */ m_size,
          /* N */ n_size,
          /* K */ K,
          /* lda */ K,
          /* ldb */ n_size,
          /* ldc */ 2 * N,
          /* brg */ use_brgemm,
          /* block_size_K */ block_size_K,
          /* do_unpack */ do_unpack);
    });

    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });

  // stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
  at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
    for (int64_t m = begin; m < end; ++m) {
      // gate half = first N columns, up half = last N columns of ic0
      silu_and_mul_stub(ic1 + m * N, ic0 + m * 2 * N, ic0 + m * 2 * N + N, N);
    }
  });

  // stage 2: intermediate_cache2 = intermediate_cache1 @ w2
  // w2 : [E, K, N] as [E, OC, IC]
  const int64_t OC = K;  // rename K as OC
  const int64_t IC = N;  // rename N as IC
  const int64_t MB2 = MB;
  const int64_t NB2 = div_up(OC, BLOCK_N);
  scale_size_N = div_up(K, block_size_N);
  scale_size_K = div_up(N, block_size_K);
  const int64_t stride_e2 = OC * IC;
  const int64_t stride_oc = IC;

  // parallel on [MB2, NB2]
  parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
    int tid = get_thread_num();
    // local output tile (gemm writes here, then scattered below)
    alignas(64) scalar_t C[BLOCK_M * BLOCK_K];

    loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
      int64_t m_size = offsets[mb + 1] - offsets[mb];
      int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);

      // A ptr from ic1 of [M * topk, N] in sorted order
      // so as to avoid copy A to tmp buffer again
      const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
      const int32_t* A_ids = sorted_ids + mb * BLOCK_M;

      // B shape [IC, n_size] in vnni format
      int32_t expert_id = expert_ids[mb];
      const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
      const float* __restrict__ Bs =
          w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;

      // do unpacking for the first row or a new expert
      int32_t pre_expert_id = mb == 0 ? -1 : expert_ids[mb - 1];
      bool do_unpack = (mb == mb0) || (expert_id != pre_expert_id);

      tinygemm_kernel<scalar_t>(
          /* A */ A,
          /* B */ B,
          /* C */ C,
          /* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
          /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
          /* scale */ Bs,
          /* M */ m_size,
          /* N */ n_size,
          /* K */ IC,
          /* lda */ IC,
          /* ldb */ n_size,
          /* ldc */ BLOCK_N,
          /* brg */ use_brgemm,
          /* block_size_K */ block_size_K,
          /* do_unpack */ do_unpack);

      // 2.b copy from C to ic2 in original order
      // and also mul topk_weights in float32
      for (int64_t m = 0; m < m_size; ++m) {
        int32_t index = A_ids[m];
        float weight = topk_weights[index];
        copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
      }
    });

    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });

  // stage 3: out = intermediate_cache2.sum(dim=1)
  // from [M, topk, K] to [M, K]
  at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
    for (int64_t m = begin; m < end; ++m) {
      sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
    }
  });
}
|
||||
// Explicit instantiations of the fp8 fused-experts kernel for the two
// supported activation/output types.
#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE)                   \
  template void fused_experts_fp8_kernel_impl<TYPE>(         \
      TYPE* __restrict__ output,                             \
      TYPE* __restrict__ ic0,                                \
      TYPE* __restrict__ ic1,                                \
      TYPE* __restrict__ ic2,                                \
      TYPE* __restrict__ A_tmp,                              \
      TYPE* __restrict__ B_tmp,                              \
      float* __restrict__ C_tmp,                             \
      const TYPE* __restrict__ input,                        \
      const at::Float8_e4m3fn* __restrict__ packed_w1,       \
      const at::Float8_e4m3fn* __restrict__ packed_w2,       \
      const float* __restrict__ w1s,                         \
      const float* __restrict__ w2s,                         \
      int64_t block_size_N,                                  \
      int64_t block_size_K,                                  \
      const float* __restrict__ topk_weights,                \
      const int32_t* __restrict__ sorted_ids,                \
      const int32_t* __restrict__ expert_ids,                \
      const int32_t* __restrict__ offsets,                   \
      int64_t M,                                             \
      int64_t N,                                             \
      int64_t K,                                             \
      int64_t E,                                             \
      int64_t topk,                                          \
      int64_t num_tokens_post_pad)

INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16);
INSTANTIATE_MOE_FP8_TEMPLATE(at::Half);
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 1: intermediate_cache0 = hidden_states @ w1
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(2 * N, BLOCK_N);
|
||||
int64_t scale_size_K = div_up(K, block_size_K);
|
||||
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||
|
||||
int64_t B_tmp_size_per_thread = MAX_CACHE_BLOCK_SIZE * BLOCK_N * std::max(K, N);
|
||||
|
||||
parallel_2d(MB, NB, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
int tid = get_thread_num();
|
||||
|
||||
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * K, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// do unpacking for the first row
|
||||
bool do_unpack = (mb == mb0);
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ input + mb * BLOCK_M * K,
|
||||
/* B */ packed_w1 + nb * BLOCK_N * K,
|
||||
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * K,
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K,
|
||||
/* do_unpack */ do_unpack);
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
silu_and_mul_stub(ic1 + m * N, ic0 + m * 2 * N, ic0 + m * 2 * N + N, N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [K, N] as [OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(K, BLOCK_N);
|
||||
scale_size_K = div_up(N, block_size_K);
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
parallel_2d(MB2, NB2, [&](int64_t mb0, int64_t mb1, int64_t nb0, int64_t nb1) {
|
||||
int tid = get_thread_num();
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
|
||||
loop_2d<at::Float8_e4m3fn>(mb0, mb1, nb0, nb1, BLOCK_N * IC, [&](int64_t mb, int64_t nb, int64_t nb_offset) {
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// do unpacking for the first row
|
||||
bool do_unpack = (mb == mb0);
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ ic1 + mb * BLOCK_M * N,
|
||||
/* B */ packed_w2 + nb * BLOCK_N * N,
|
||||
/* C */ C,
|
||||
/* Btmp */ B_tmp + tid * B_tmp_size_per_thread + nb_offset * BLOCK_N * IC,
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K,
|
||||
/* do_unpack */ do_unpack);
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
}
|
||||
|
||||
// Explicit template instantiations for shared_expert_fp8_kernel_impl.
// The kernel template is defined above in this translation unit; it is
// instantiated here for the two reduced-precision activation types the
// Python frontend dispatches on (bf16 and fp16). Weights are always
// fp8 (at::Float8_e4m3fn) with float block scales.
#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE)   \
  template void shared_expert_fp8_kernel_impl<TYPE>(   \
      TYPE* __restrict__ output,                       \
      TYPE* __restrict__ ic0,                          \
      TYPE* __restrict__ ic1,                          \
      TYPE* __restrict__ B_tmp,                        \
      float* __restrict__ C_tmp,                       \
      const TYPE* __restrict__ input,                  \
      const at::Float8_e4m3fn* __restrict__ packed_w1, \
      const at::Float8_e4m3fn* __restrict__ packed_w2, \
      const float* __restrict__ w1s,                   \
      const float* __restrict__ w2s,                   \
      int64_t block_size_N,                            \
      int64_t block_size_K,                            \
      const TYPE* __restrict__ fused_experts_out,      \
      float routed_scaling_factor,                     \
      int64_t M,                                       \
      int64_t N,                                       \
      int64_t K)

INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16);
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half);
|
||||
1068
sgl-kernel/csrc/cpu/moe_int8.cpp
Normal file
1068
sgl-kernel/csrc/cpu/moe_int8.cpp
Normal file
File diff suppressed because it is too large
Load Diff
304
sgl-kernel/csrc/cpu/norm.cpp
Normal file
304
sgl-kernel/csrc/cpu/norm.cpp
Normal file
@@ -0,0 +1,304 @@
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// NB: avoid using `at::vec::map<>` on bfloat16 or half
|
||||
// Llama4TextL2Norm
|
||||
// Row-wise L2 normalization: out[i] = in[i] * rsqrt(mean(in[i]^2) + eps).
// Unlike RMSNorm there is no learnable weight. Accumulation is done in
// float32 even for bf16/fp16 inputs to preserve accuracy.
//
// output : [batch_size, hidden_size], contiguous
// input  : [batch_size, hidden_size], contiguous
template <typename scalar_t>
void l2norm_kernel_impl(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    int64_t batch_size,
    int64_t hidden_size,
    float eps = 1e-5) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int kVecSize = bVec::size();
  at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs
      scalar_t* __restrict__ out_ptr = output + i * hidden_size;
      const scalar_t* __restrict__ input_ptr = input + i * hidden_size;

      // pass 1: sum of squares — vector lanes accumulate in sum_fvec,
      // the scalar tail accumulates in sum_val.
      fVec sum_fvec = fVec(float(0));
      float sum_val = float(0);

      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        // one bVec of reduced-precision values widens to two fVec halves
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        sum_fvec += x_fvec0 * x_fvec0;
        sum_fvec += x_fvec1 * x_fvec1;
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        sum_val += x_val * x_val;
      }

      // combine vector and scalar partial sums, then 1/rms for this row
      sum_val += vec_reduce_sum(sum_fvec);
      float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
      const fVec scale_fvec = fVec(rsqrt_var);

      // pass 2: re-read the input and write the scaled row (no temp buffer)
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        x_fvec0 = x_fvec0 * scale_fvec;
        x_fvec1 = x_fvec1 * scale_fvec;

        bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        out_bvec.store(out_ptr + d);
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var);
      }
    }
  });
}
|
||||
// Row-wise RMSNorm: out[i] = in[i] * rsqrt(mean(in[i]^2) + eps) * weight.
// Accumulation happens in float32 even for bf16/fp16 inputs.
//
// output : [batch_size, hidden_size], contiguous
// input  : rows start every `input_strideN` elements (last dim contiguous)
// weight : [hidden_size]
template <typename scalar_t>
void rmsnorm_kernel_impl(
    scalar_t* __restrict__ output,
    const scalar_t* __restrict__ input,
    const scalar_t* __restrict__ weight,
    int64_t batch_size,
    int64_t hidden_size,
    int64_t input_strideN,
    float eps = 1e-5) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int kVecSize = bVec::size();
  at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
    for (int64_t i = begin; i < end; ++i) {
      // local ptrs — note output is contiguous, input may be strided
      scalar_t* __restrict__ out_ptr = output + i * hidden_size;
      const scalar_t* __restrict__ input_ptr = input + i * input_strideN;

      // pass 1: sum of squares (vector lanes + scalar tail)
      fVec sum_fvec = fVec(float(0));
      float sum_val = float(0);

      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        sum_fvec += x_fvec0 * x_fvec0;
        sum_fvec += x_fvec1 * x_fvec1;
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        sum_val += x_val * x_val;
      }

      // combine partial sums, compute 1/rms for this row
      sum_val += vec_reduce_sum(sum_fvec);
      float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
      const fVec scale_fvec = fVec(rsqrt_var);

      // pass 2: re-read input, scale by 1/rms and the learnable weight
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        bVec w_bvec = bVec::loadu(weight + d);
        fVec w_fvec0, w_fvec1;
        std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);

        x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
        x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;

        bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        out_bvec.store(out_ptr + d);
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        float w_val = static_cast<float>(weight[d]);
        out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var * w_val);
      }
    }
  });
}
|
||||
|
||||
// Fused residual-add + RMSNorm, both updated IN PLACE:
//   residual[i] = input[i] + residual[i]
//   input[i]    = rmsnorm(residual[i]) * weight
// `buffer` is a per-thread fp32 scratch of [num_threads, hidden_size]:
// the fp32 sum (input + residual) is stashed there during pass 1 so pass 2
// does not have to re-widen the reduced-precision values.
//
// input  : rows start every `input_strideN` elements (last dim contiguous)
// residual: [batch_size, hidden_size], contiguous
// weight : [hidden_size]
template <typename scalar_t>
void fused_add_rmsnorm_kernel_impl(
    scalar_t* __restrict__ input,
    scalar_t* __restrict__ residual,
    const scalar_t* __restrict__ weight,
    float* __restrict__ buffer,
    int64_t batch_size,
    int64_t hidden_size,
    int64_t input_strideN,
    float eps = 1e-5) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int kVecSize = bVec::size();
  at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
    // each worker owns one row of the fp32 scratch buffer
    int tid = at::get_thread_num();
    float* __restrict__ buffer_ptr = buffer + tid * hidden_size;

    for (int64_t i = begin; i < end; ++i) {
      // local ptrs
      scalar_t* __restrict__ input_ptr = input + i * input_strideN;
      scalar_t* __restrict__ residual_ptr = residual + i * hidden_size;

      fVec sum_fvec = fVec(float(0));
      float sum_val = float(0);

      // pass 1: x = input + residual; store x back into residual (reduced
      // precision) and into buffer (fp32); accumulate sum of squares of x.
      int64_t d;
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        bVec x_bvec = bVec::loadu(input_ptr + d);
        fVec x_fvec0, x_fvec1;
        std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

        bVec r_bvec = bVec::loadu(residual_ptr + d);
        fVec r_fvec0, r_fvec1;
        std::tie(r_fvec0, r_fvec1) = at::vec::convert_to_float(r_bvec);

        x_fvec0 += r_fvec0;
        x_fvec1 += r_fvec1;

        bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        out_bvec.store(residual_ptr + d);

        sum_fvec += x_fvec0 * x_fvec0;
        sum_fvec += x_fvec1 * x_fvec1;

        // keep the fp32 sum for pass 2 (kVecSize == 2 * fVec::size()
        // for reduced-precision scalar_t, so both halves fit)
        x_fvec0.store(buffer_ptr + d);
        x_fvec1.store(buffer_ptr + d + fVec::size());
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = static_cast<float>(input_ptr[d]);
        float r_val = static_cast<float>(residual_ptr[d]);

        x_val += r_val;
        residual_ptr[d] = static_cast<scalar_t>(x_val);

        sum_val += x_val * x_val;
        buffer_ptr[d] = x_val;
      }

      // combine partial sums, compute 1/rms of the fused row
      sum_val += vec_reduce_sum(sum_fvec);
      float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
      const fVec scale_fvec = fVec(rsqrt_var);

      // pass 2: read the fp32 copy from buffer, apply 1/rms and weight,
      // and write the normalized row back into `input`.
#pragma GCC unroll 4
      for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
        fVec x_fvec0 = fVec::loadu(buffer_ptr + d);
        fVec x_fvec1 = fVec::loadu(buffer_ptr + d + fVec::size());

        bVec w_bvec = bVec::loadu(weight + d);
        fVec w_fvec0, w_fvec1;
        std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);

        x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
        x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
        bVec x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
        x_bvec.store(input_ptr + d);
      }
#pragma GCC unroll 4
      for (; d < hidden_size; ++d) {
        float x_val = buffer_ptr[d] * rsqrt_var * static_cast<float>(weight[d]);
        // implicit float -> scalar_t conversion (same as static_cast above)
        input_ptr[d] = x_val;
      }
    }
  });
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// input : {batch_size, hidden_size}
|
||||
// Python-facing entry point: validate the 2-D input and dispatch the
// typed L2-norm kernel. Returns a freshly allocated tensor of the same
// shape/dtype as `input`.
at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
  RECORD_FUNCTION("sgl-kernel::l2norm_cpu", std::vector<c10::IValue>({input}));

  CHECK_INPUT(input);
  CHECK_DIM(2, input);

  const int64_t M = input.size(0);
  const int64_t H = input.size(1);
  auto output = at::empty_like(input);

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "l2norm_kernel", [&] {
    l2norm_kernel_impl<scalar_t>(
        output.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        M,
        H,
        eps);
  });
  return output;
}
|
||||
|
||||
// input : {batch_size, hidden_size}
|
||||
// weight: {hidden_size}
|
||||
// Python-facing entry point for RMSNorm. Validates shapes (input is 2-D
// with a contiguous last dim, weight is 1-D of matching width), allocates
// the output, and dispatches the typed kernel with the input's row stride.
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
  RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));

  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
  CHECK_INPUT(weight);
  CHECK_DIM(2, input);
  CHECK_DIM(1, weight);
  CHECK_EQ(input.size(1), weight.size(0));

  const int64_t M = input.size(0);
  const int64_t H = input.size(1);
  const int64_t ld_input = input.stride(0);  // input may be a strided view
  auto output = at::empty_like(input);

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] {
    rmsnorm_kernel_impl<scalar_t>(
        output.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        M,
        H,
        ld_input,
        eps);
  });
  return output;
}
|
||||
|
||||
// input : {batch_size, hidden_size}
|
||||
// residual: {batch_size, hidden_size}
|
||||
// weight : {hidden_size}
|
||||
// Python-facing entry point for the fused residual-add + RMSNorm.
// Mutates `input` and `residual` in place (see the kernel above for the
// exact semantics). A per-thread fp32 scratch buffer is allocated here so
// the kernel can widen once and reuse the values in its second pass.
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) {
  RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight}));

  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
  CHECK_INPUT(residual);
  CHECK_INPUT(weight);
  CHECK_DIM(2, input);
  CHECK_DIM(2, residual);
  CHECK_DIM(1, weight);
  CHECK_EQ(input.size(0), residual.size(0));
  CHECK_EQ(input.size(1), residual.size(1));
  CHECK_EQ(input.size(1), weight.size(0));

  const int64_t M = input.size(0);
  const int64_t H = input.size(1);
  const int64_t ld_input = input.stride(0);

  // Per-thread fp32 scratch: one row of hidden_size per OMP worker.
  // TODO: implement a singleton for context
  const int64_t n_threads = at::get_num_threads();
  auto buffer = at::empty({n_threads, H}, input.options().dtype(at::kFloat));

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fused_add_rmsnorm_kernel", [&] {
    fused_add_rmsnorm_kernel_impl<scalar_t>(
        input.data_ptr<scalar_t>(),
        residual.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        buffer.data_ptr<float>(),
        M,
        H,
        ld_input,
        eps);
  });
}
|
||||
91
sgl-kernel/csrc/cpu/numa_utils.cpp
Normal file
91
sgl-kernel/csrc/cpu/numa_utils.cpp
Normal file
@@ -0,0 +1,91 @@
|
||||
#include <numa.h>
|
||||
#include <sched.h>
|
||||
#include <sys/syscall.h>
|
||||
#include <sys/types.h>
|
||||
#include <unistd.h>
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "common.h"
|
||||
|
||||
std::string init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
|
||||
TORCH_CHECK(omp_cpu_mask->size > 0);
|
||||
std::vector<int> omp_cpu_ids;
|
||||
omp_cpu_ids.reserve(omp_cpu_mask->size);
|
||||
|
||||
constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp);
|
||||
|
||||
for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) {
|
||||
unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size];
|
||||
int i = 0;
|
||||
while (group_mask) {
|
||||
if (group_mask & 1) {
|
||||
omp_cpu_ids.emplace_back(offset + i);
|
||||
}
|
||||
++i;
|
||||
group_mask >>= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Memory node binding
|
||||
if (numa_available() != -1) {
|
||||
int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
|
||||
bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
|
||||
bitmask* src_mask = numa_get_membind();
|
||||
|
||||
int pid = getpid();
|
||||
|
||||
// move all existing pages to the specified numa node.
|
||||
*(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
|
||||
int page_num = numa_migrate_pages(pid, src_mask, mask);
|
||||
if (page_num == -1) {
|
||||
TORCH_WARN(false, "numa_migrate_pages failed. errno: " + std::to_string(errno));
|
||||
}
|
||||
|
||||
// restrict memory allocation node.
|
||||
numa_set_membind(mask);
|
||||
numa_set_strict(1);
|
||||
}
|
||||
|
||||
// OMP threads binding
|
||||
omp_set_num_threads((int)omp_cpu_ids.size());
|
||||
at::set_num_threads((int)omp_cpu_ids.size());
|
||||
TORCH_CHECK_EQ(omp_cpu_ids.size(), at::get_num_threads());
|
||||
TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
|
||||
|
||||
std::vector<std::pair<int, int>> thread_core_mapping;
|
||||
thread_core_mapping.reserve(omp_cpu_ids.size());
|
||||
omp_lock_t writelock;
|
||||
omp_init_lock(&writelock);
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
|
||||
cpu_set_t mask;
|
||||
CPU_ZERO(&mask);
|
||||
CPU_SET(omp_cpu_ids[i], &mask);
|
||||
int ret = sched_setaffinity(0, sizeof(cpu_set_t), &mask);
|
||||
if (ret == -1) {
|
||||
TORCH_CHECK(false, "sched_setaffinity failed. errno: " + std::to_string(errno));
|
||||
}
|
||||
|
||||
omp_set_lock(&writelock);
|
||||
thread_core_mapping.emplace_back(syscall(SYS_gettid), omp_cpu_ids[i]);
|
||||
omp_unset_lock(&writelock);
|
||||
}
|
||||
|
||||
omp_destroy_lock(&writelock);
|
||||
|
||||
numa_free_nodemask(omp_cpu_mask);
|
||||
|
||||
std::stringstream ss;
|
||||
ss << "OMP threads binding of Process " << getpid() << ":\n";
|
||||
std::sort(
|
||||
thread_core_mapping.begin(), thread_core_mapping.end(), [](auto&& a, auto&& b) { return a.second < b.second; });
|
||||
for (auto&& item : thread_core_mapping) {
|
||||
ss << "\t"
|
||||
<< "OMP tid: " << item.first << ", core " << item.second << "\n";
|
||||
}
|
||||
|
||||
return ss.str();
|
||||
}
|
||||
701
sgl-kernel/csrc/cpu/qkv_proj.cpp
Normal file
701
sgl-kernel/csrc/cpu/qkv_proj.cpp
Normal file
@@ -0,0 +1,701 @@
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// [NOTE]: Fused kernel for QKV projection with weight absorption and RoPE
|
||||
//
|
||||
// 1. `q_a_proj` and `kv_a_proj_with_mqa` fused into one gemm,
|
||||
// otherwise we need to split IC for the 2nd gemm.
|
||||
// 2. `q_a_layernorm` and `kv_a_layernorm` fused into one parallel loop.
|
||||
// 3. k_input and v_input share the same storage, the torch API did
|
||||
// this in `set_kv_buffer`. No additional memory movement.
|
||||
//
|
||||
|
||||
// [C0, C1] = A @ [B0, B1]
|
||||
// Segmented GEMM over a packed weight: [C0, C1] = A @ [B0, B1].
// A is [M, K]; B0/B1 are pre-packed weights producing C0 [M, N0] and
// C1 [M, N1]. Column-blocks [0, NB0) write segment 0 and [NB0, NB) write
// segment 1, so both products run in one parallel sweep.
template <typename scalar_t>
void segment_gemm_kernel_impl(
    scalar_t* __restrict__ C0,
    scalar_t* __restrict__ C1,
    const scalar_t* __restrict__ A,
    const scalar_t* __restrict__ B0,
    const scalar_t* __restrict__ B1,
    int64_t M,
    int64_t N0,
    int64_t N1,
    int64_t K) {
  // convert_weight_packed make sure N0 and N1 are 32x
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB0 = div_up(N0, BLOCK_N);
  const int64_t NB1 = div_up(N1, BLOCK_N);
  const int64_t NB = NB0 + NB1;

  const bool use_brgemm = can_use_brgemm<scalar_t>(M);

  // parallel on [MB, NB0 + NB1]
  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
    int64_t mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);

    // for brgemm, use float32 for accumulate
    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];

    for (int64_t i = begin; i < end; ++i) {
      UNUSED(i);
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(M - mb_start, BLOCK_M);
      int nb_start = nb * BLOCK_N;
      int nb_size = BLOCK_N;  // N0/N1 are multiples of BLOCK_N — no tail

      // select the segment this column-block belongs to
      const scalar_t* __restrict__ B = nb < NB0 ? B0 : B1;
      scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
      int64_t ldc = nb < NB0 ? N0 : N1;
      // column offset local to the selected segment
      int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;

      tinygemm_kernel<scalar_t>(
          /* A */ A + mb_start * K,
          /* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
          /* C */ C + mb_start * ldc + local_nb_start,
          /* Ctmp*/ Ctmp,
          /* M */ mb_size,
          /* N */ nb_size,
          /* K */ K,
          /* lda */ K,
          /* ldb */ nb_size,
          /* ldc */ ldc,
          /* brg */ use_brgemm);

      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }

    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });
}
|
||||
|
||||
// [C0, C1] = A @ [B0, B1]
|
||||
// int8 W8A8 variant of the segmented GEMM: [C0, C1] = A @ [B0, B1].
// A holds per-token-quantized uint8 activations with per-row scales `As`;
// B0/B1 are packed int8 weights with per-column scales `Bs0`/`Bs1`.
// Packed weight rows carry a 4-byte compensation term, hence
// get_row_size<int8_t>(K) == K + 4.
template <typename scalar_t>
void segment_gemm_kernel_impl(
    scalar_t* __restrict__ C0,
    scalar_t* __restrict__ C1,
    const uint8_t* __restrict__ A,
    const int8_t* __restrict__ B0,
    const int8_t* __restrict__ B1,
    const float* __restrict__ As,
    const float* __restrict__ Bs0,
    const float* __restrict__ Bs1,
    int64_t M,
    int64_t N0,
    int64_t N1,
    int64_t K) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB0 = div_up(N0, BLOCK_N);
  const int64_t NB1 = div_up(N1, BLOCK_N);
  const int64_t NB = NB0 + NB1;

  const bool use_brgemm = can_use_brgemm<int8_t>(M);

  // K + 4 after compensation
  const int64_t packed_row_size = get_row_size<int8_t>(K);

  // parallel on [MB, NB0 + NB1]
  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
    int64_t mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);

    // int8 brgemm accumulates in int32
    alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];

    for (int64_t i = begin; i < end; ++i) {
      UNUSED(i);
      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(M - mb_start, BLOCK_M);
      int nb_start = nb * BLOCK_N;
      int nb_size = BLOCK_N;

      // select segment (weights, scales, output, leading dim) for this block
      const int8_t* __restrict__ B = nb < NB0 ? B0 : B1;
      const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
      scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
      int64_t ldc = nb < NB0 ? N0 : N1;
      int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;

      tinygemm_kernel<scalar_t>(
          /* A */ A + mb_start * K,
          /* B */ B + local_nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
          /* C */ C + mb_start * ldc + local_nb_start,
          /* Ctmp*/ Ctmp,
          /* As */ As + mb_start,
          /* Bs */ Bs + local_nb_start,
          /* M */ mb_size,
          /* N */ nb_size,
          /* K */ K,
          /* lda */ K,
          /* ldb */ nb_size,
          /* ldc */ ldc,
          /* brg */ use_brgemm);

      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }

    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });
}
|
||||
|
||||
// [C0, C1] = A @ [B0, B1]
|
||||
// fp8 W8A16 variant of the segmented GEMM: [C0, C1] = A @ [B0, B1].
// B0/B1 are fp8 (e4m3) weights with block-wise float scales Bs0/Bs1 laid
// out as [N / block_size_N, K / block_size_K]. `Btmp` is a per-thread
// scratch where one BLOCK_N x K weight tile is dequantized before the gemm.
template <typename scalar_t>
void segment_gemm_kernel_impl(
    scalar_t* __restrict__ C0,
    scalar_t* __restrict__ C1,
    const scalar_t* __restrict__ A,
    const at::Float8_e4m3fn* __restrict__ B0,
    const at::Float8_e4m3fn* __restrict__ B1,
    const float* __restrict__ Bs0,
    const float* __restrict__ Bs1,
    scalar_t* __restrict__ Btmp,
    int64_t M,
    int64_t N0,
    int64_t N1,
    int64_t K,
    int64_t block_size_N,
    int64_t block_size_K) {
  constexpr int64_t BLOCK_M = block_size_m();
  constexpr int64_t BLOCK_N = block_size_n();
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB0 = div_up(N0, BLOCK_N);
  const int64_t NB1 = div_up(N1, BLOCK_N);
  const int64_t NB = NB0 + NB1;

  // scale-tensor geometry: one scale per [block_size_N, block_size_K] tile
  const int64_t scale_size_K = div_up(K, block_size_K);
  const int64_t blocks_n_per_group = block_size_N / BLOCK_N;

  const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);

  // parallel on [MB, NB0 + NB1]
  at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
    int64_t mb{0}, nb{0};
    data_index_init(begin, mb, MB, nb, NB);

    int tid = at::get_thread_num();
    // for brgemm, use float32 for accumulate
    alignas(64) float Ctmp[BLOCK_M * BLOCK_N];

    for (int64_t i = begin; i < end; ++i) {
      UNUSED(i);

      int mb_start = mb * BLOCK_M;
      int mb_size = std::min(M - mb_start, BLOCK_M);
      int nb_start = nb * BLOCK_N;
      int nb_size = BLOCK_N;

      // select segment (weights, scales, output, leading dim) for this block
      const at::Float8_e4m3fn* __restrict__ B = nb < NB0 ? B0 : B1;
      const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
      scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
      int64_t ldc = nb < NB0 ? N0 : N1;
      int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
      // block index local to the selected segment — used to address scales
      int64_t new_nb = nb < NB0 ? nb : nb - NB0;

      tinygemm_kernel<scalar_t>(
          /* A */ A + mb_start * K,
          /* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
          /* C */ C + mb_start * ldc + local_nb_start,
          /* Btmp*/ Btmp + tid * BLOCK_N * K,
          /* Ctmp*/ Ctmp,
          /* Bs */ Bs + (new_nb / blocks_n_per_group) * scale_size_K,
          /* M */ mb_size,
          /* N */ nb_size,
          /* K */ K,
          /* lda */ K,
          /* ldb */ nb_size,
          /* ldc */ ldc,
          /* brg */ use_brgemm,
          /* block_size_K */ block_size_K);

      // move to the next index
      data_index_step(mb, MB, nb, NB);
    }

    if (use_brgemm) {
      at::native::cpublas::brgemm_release();
    }
  });
}
|
||||
|
||||
// Sum of SQUARES of `x` (despite the generic name), accumulated in fp32.
// Used by rms_norm_kernel_impl below to compute the mean-square term.
// NOTE(review): there is no scalar tail loop ("no remainder"), so `size`
// must be a multiple of bVec::size(); callers pass lora-rank sizes that
// are assumed to satisfy this — confirm against convert_weight_packed.
template <typename scalar_t>
inline float reduce(const scalar_t* __restrict__ x, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  fVec sum_fvec = fVec(float(0));

// no remainder
#pragma GCC unroll 4
  for (int64_t d = 0; d < size; d += bVec::size()) {
    bVec x_bvec = bVec::loadu(x + d);
    fVec x_fvec0, x_fvec1;
    std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
    sum_fvec += x_fvec0 * x_fvec0;
    sum_fvec += x_fvec1 * x_fvec1;
  }
  return vec_reduce_sum(sum_fvec);
}
|
||||
|
||||
// map2 from aten functional doesn't have fast bf16->fp32 conversion
|
||||
// Element-wise y[d] = x[d] * scale * w[d], computed in fp32.
// map2 from aten functional doesn't have fast bf16->fp32 conversion,
// hence this hand-rolled version. Like reduce() above there is no tail
// loop, so `size` must be a multiple of bVec::size().
template <typename scalar_t>
inline void map2(scalar_t* y, const scalar_t* x, const scalar_t* __restrict__ w, float scale, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  fVec scale_fvec = fVec(scale);

// no remainder
#pragma GCC unroll 4
  for (int64_t d = 0; d < size; d += bVec::size()) {
    bVec x_bvec = bVec::loadu(x + d);
    fVec x_fvec0, x_fvec1;
    std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
    bVec w_bvec = bVec::loadu(w + d);
    fVec w_fvec0, w_fvec1;
    std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
    x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
    x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
    bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
    out_bvec.store(y + d);
  }
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void rms_norm_kernel_impl(
|
||||
scalar_t* __restrict__ input0,
|
||||
scalar_t* __restrict__ input1,
|
||||
const scalar_t* __restrict__ weight0,
|
||||
const scalar_t* __restrict__ weight1,
|
||||
int64_t M,
|
||||
int64_t N0,
|
||||
int64_t N1,
|
||||
int64_t stride1,
|
||||
float eps = 1e-5) {
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
scalar_t* x0 = input0 + m * N0;
|
||||
scalar_t* x1 = input1 + m * stride1;
|
||||
float scale0 = reduce(x0, N0);
|
||||
float scale1 = reduce(x1, N1);
|
||||
scale0 = float(1) / std::sqrt(scale0 / N0 + eps);
|
||||
scale1 = float(1) / std::sqrt(scale1 / N1 + eps);
|
||||
map2(x0, x0, weight0, scale0, N0);
|
||||
map2(x1, x1, weight1, scale1, N1);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Scalar fallback for the rotary-embedding inner kernel. Only the AVX512
// bf16 specialization below is implemented; instantiating this primary
// template for any other type is a hard runtime error.
template <typename scalar_t>
inline void rotary(const scalar_t* input, scalar_t* out, const scalar_t* cos, const scalar_t* sin, int64_t size) {
  TORCH_CHECK(false, "rotary scalar path not implemented.");
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <>
|
||||
inline void rotary<at::BFloat16>(
|
||||
const at::BFloat16* input, at::BFloat16* out, const at::BFloat16* cos, const at::BFloat16* sin, int64_t size) {
|
||||
// permute indices
|
||||
const __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
|
||||
const __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1);
|
||||
const __m512i idy1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
|
||||
const __m512i idy2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8);
|
||||
|
||||
// rotary dim is 64, just 2 iters
|
||||
#pragma GCC unroll 2
|
||||
for (int64_t d = 0; d < size; d += 32) {
|
||||
int64_t d2 = d >> 1;
|
||||
// load coefs
|
||||
__m512 vcos = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(cos + d2)));
|
||||
__m512 vsin = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(sin + d2)));
|
||||
// load input
|
||||
__m512i a16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(input + d));
|
||||
__m512 a = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
|
||||
__m512 b = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
|
||||
// from [16, 2] to [2, 16]
|
||||
__m512 in1 = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
|
||||
__m512 in2 = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
|
||||
// out1 = in1 * cos - in2 * sin;
|
||||
// out2 = in2 * cos + in1 * sin
|
||||
__m512 out1 = _mm512_sub_ps(_mm512_mul_ps(in1, vcos), _mm512_mul_ps(in2, vsin));
|
||||
__m512 out2 = _mm512_add_ps(_mm512_mul_ps(in2, vcos), _mm512_mul_ps(in1, vsin));
|
||||
// from [2, 16] to [16, 2]
|
||||
a = _mm512_mask_permutex2var_ps(out1, 0xffff, idy1, out2);
|
||||
b = _mm512_mask_permutex2var_ps(out1, 0xffff, idy2, out2);
|
||||
|
||||
_mm512_storeu_si512(reinterpret_cast<__m512i*>((out + d)), (__m512i)(_mm512_cvtne2ps_pbh(b, a)));
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
// Applies rotary position embedding to q_pe [num_seqs, num_heads, rotary_dim]
// and k_pe [num_seqs, 1, rotary_dim] in one pass. `cos_sin` is the fused
// cache: row `pos[seq]` holds rotary_dim values, cos in the first half and
// sin in the second. Strides allow strided (non-contiguous) views.
template <typename scalar_t>
void rotary_emb_kernel_impl(
    scalar_t* q_pe_out,
    scalar_t* k_pe_out,
    const scalar_t* q_pe,
    const scalar_t* k_pe,
    const int64_t* pos,
    const scalar_t* cos_sin,
    int64_t num_seqs,
    int64_t num_heads,
    int64_t rotary_dim,
    int64_t q_strideB,
    int64_t q_strideH,
    int64_t k_strideB,
    int64_t oq_strideB,
    int64_t oq_strideH,
    int64_t ok_strideB) {
  // the AVX512 rotary kernel processes 32 elements per iteration
  TORCH_CHECK(rotary_dim % 32 == 0, "rotary_dim is not 32x.");
  const int64_t rotary_offset = rotary_dim / 2;

  // parallel on [num_seqs, num_heads + 1]
  // top [num_heads] handle q_pe and bottom [1] handle k_pe
  at::parallel_for(0, num_seqs * (num_heads + 1), GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
    int64_t seq{0}, head_id{0};
    data_index_init(begin, seq, num_seqs, head_id, num_heads + 1);

    for (int64_t i = begin; i < end; ++i) {
      UNUSED(i);
      // get cos and sin cache ptr for this token's position
      int64_t index = pos[seq];
      const scalar_t* cos = cos_sin + index * rotary_dim;
      const scalar_t* sin = cos + rotary_offset;

      // head_id == num_heads selects the single k_pe head
      const scalar_t* input =
          (head_id < num_heads) ? q_pe + seq * q_strideB + head_id * q_strideH : k_pe + seq * k_strideB;
      scalar_t* out =
          (head_id < num_heads) ? q_pe_out + seq * oq_strideB + head_id * oq_strideH : k_pe_out + seq * ok_strideB;
      rotary<scalar_t>(input, out, cos, sin, rotary_dim);

      // move to the next index
      data_index_step(seq, num_seqs, head_id, num_heads + 1);
    }
  });
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// Kernels implemented in sibling translation units of this CPU backend;
// declared here directly instead of through a shared header.

// Plain (non-quantized) linear with a pre-packed (VNNI) weight.
extern at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);

// int8 w8a8 matmul that quantizes mat1 on the fly; scales2 are per-row
// weight scales.
extern at::Tensor int8_scaled_mm_with_quant(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);

// Batched matmul; optional `scale` enables the scaled (quantized) path.
extern void
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);

// fp8 (e4m3) weight, 16-bit activation matmul with block-wise weight scales.
extern at::Tensor fp8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    std::vector<int64_t> block_size,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);
||||
|
||||
// NB: shapes in DeepDeek R1
|
||||
//
|
||||
// hidden_states : [num_seqs, hidden_size] [1, 7168]
|
||||
// q_a_proj_weight : [q_lora_rank, hidden_size] [1536, 7168]
|
||||
// q_b_proj_weight : [num_heads * qk_head_dim, q_lora_rank] [4224, 1536]
|
||||
// kv_a_proj_weight : [kv_lora_rank + qk_rope_head_dim, hidden_size] [576, 7168]
|
||||
// w_kc : [num_heads, kv_lora_rank, qk_nope_head_dim] [22, 512, 128]
|
||||
// q_a_layernorm_weight : [q_lora_rank] [1536]
|
||||
// kv_a_layernorm_weight : [kv_lora_rank] [512]
|
||||
//
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& q_a_proj_weight,
|
||||
at::Tensor& q_b_proj_weight,
|
||||
at::Tensor& kv_a_proj_weight,
|
||||
at::Tensor& w_kc,
|
||||
at::Tensor& q_a_layernorm_weight,
|
||||
at::Tensor& kv_a_layernorm_weight,
|
||||
at::Tensor& positions,
|
||||
at::Tensor& cos_sin_cache,
|
||||
double eps,
|
||||
bool use_int8_w8a8,
|
||||
bool use_fp8_w8a16,
|
||||
std::optional<at::Tensor> q_a_proj_scale,
|
||||
std::optional<at::Tensor> q_b_proj_scale,
|
||||
std::optional<at::Tensor> kv_a_proj_scale,
|
||||
bool is_vnni,
|
||||
std::optional<std::vector<int64_t>> block_size) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::qkv_proj_with_rope",
|
||||
std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc}));
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_INPUT(hidden_states);
|
||||
CHECK_INPUT(positions);
|
||||
CHECK_INPUT(cos_sin_cache);
|
||||
CHECK_EQ(q_a_layernorm_weight.scalar_type(), st);
|
||||
CHECK_EQ(kv_a_layernorm_weight.scalar_type(), st);
|
||||
CHECK_EQ(positions.scalar_type(), at::kLong);
|
||||
CHECK_EQ(cos_sin_cache.scalar_type(), st);
|
||||
CHECK_DIM(2, hidden_states);
|
||||
CHECK_DIM(3, w_kc);
|
||||
CHECK_DIM(1, q_a_layernorm_weight);
|
||||
CHECK_DIM(1, kv_a_layernorm_weight);
|
||||
CHECK_DIM(1, positions);
|
||||
CHECK_DIM(2, cos_sin_cache);
|
||||
|
||||
// skip contiguous checks for weights, expect prepacked
|
||||
TORCH_CHECK(is_vnni, "qkv_proj_with_rope: expect weights are prepacked!");
|
||||
|
||||
int64_t num_seqs = hidden_states.size(0);
|
||||
int64_t hidden_size = hidden_states.size(1);
|
||||
int64_t q_lora_rank = q_a_proj_weight.size(0);
|
||||
int64_t num_heads = w_kc.size(0);
|
||||
int64_t kv_lora_rank = w_kc.size(1);
|
||||
int64_t qk_head_dim = q_b_proj_weight.size(0) / num_heads;
|
||||
int64_t qk_nope_head_dim = w_kc.size(2);
|
||||
int64_t qk_rope_head_dim = kv_a_proj_weight.size(0) - kv_lora_rank;
|
||||
int64_t rotary_dim = cos_sin_cache.size(1);
|
||||
|
||||
CHECK_EQ(positions.numel(), num_seqs);
|
||||
CHECK_EQ(rotary_dim, qk_rope_head_dim);
|
||||
CHECK_EQ(q_a_layernorm_weight.numel(), q_lora_rank);
|
||||
CHECK_EQ(kv_a_layernorm_weight.numel(), kv_lora_rank);
|
||||
|
||||
// check the packed dimension
|
||||
CHECK_EQ(q_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
|
||||
CHECK_EQ(q_b_proj_weight.size(1), get_row_size(q_lora_rank, use_int8_w8a8));
|
||||
CHECK_EQ(kv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
|
||||
|
||||
if (use_int8_w8a8) {
|
||||
TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for int8 w8a8.");
|
||||
TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8.");
|
||||
TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for int8 w8a8.");
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for fp8 w8a16.");
|
||||
TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for fp8 w8a16.");
|
||||
TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for fp8 w8a16.");
|
||||
TORCH_CHECK(block_size.has_value(), "missing block_size for fp8 w8a16.");
|
||||
TORCH_CHECK(block_size.value().size() == 2, "block_size should be 2D for fp8 w8a16.");
|
||||
}
|
||||
// outputs and temp buffer
|
||||
const auto options = hidden_states.options();
|
||||
auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options);
|
||||
auto k_input = at::empty({num_seqs, 1, kv_lora_rank + qk_rope_head_dim}, options);
|
||||
auto v_input = k_input.narrow(-1, 0, kv_lora_rank);
|
||||
|
||||
// outputs of q_a_proj and q_b_proj
|
||||
auto qa = at::empty({num_seqs, q_lora_rank}, options);
|
||||
|
||||
// stage 1: q_a_proj and kv_a_proj
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "qkv_proj_kernel_impl", [&] {
|
||||
if (use_int8_w8a8) {
|
||||
auto q_a_proj_s = q_a_proj_scale.value();
|
||||
auto kv_a_proj_s = kv_a_proj_scale.value();
|
||||
TORCH_CHECK(q_a_proj_s.numel() == q_lora_rank);
|
||||
TORCH_CHECK(kv_a_proj_s.numel() == kv_lora_rank + qk_rope_head_dim);
|
||||
|
||||
auto buffer = at::empty({num_seqs * hidden_size + num_seqs * 4}, options.dtype(at::kByte));
|
||||
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = (float*)((void*)(Aq_data + num_seqs * hidden_size));
|
||||
const scalar_t* __restrict__ A_data = hidden_states.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, num_seqs, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(Aq_data + m * hidden_size, As_data[m], A_data + m * hidden_size, hidden_size);
|
||||
}
|
||||
});
|
||||
|
||||
segment_gemm_kernel_impl<scalar_t>(
|
||||
qa.data_ptr<scalar_t>(),
|
||||
k_input.data_ptr<scalar_t>(),
|
||||
Aq_data,
|
||||
q_a_proj_weight.data_ptr<int8_t>(),
|
||||
kv_a_proj_weight.data_ptr<int8_t>(),
|
||||
As_data,
|
||||
q_a_proj_s.data_ptr<float>(),
|
||||
kv_a_proj_s.data_ptr<float>(),
|
||||
num_seqs,
|
||||
q_lora_rank,
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
hidden_size);
|
||||
} else if (use_fp8_w8a16) {
|
||||
int64_t block_size_N = block_size.value()[0];
|
||||
int64_t block_size_K = block_size.value()[1];
|
||||
auto q_a_proj_s = q_a_proj_scale.value();
|
||||
auto kv_a_proj_s = kv_a_proj_scale.value();
|
||||
CHECK_EQ(q_a_proj_s.size(0), div_up(q_lora_rank, block_size_N));
|
||||
CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K));
|
||||
CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N));
|
||||
CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K));
|
||||
|
||||
const int BLOCK_N = block_size_n();
|
||||
const int num_threads = at::get_num_threads();
|
||||
auto buffer = at::empty({num_threads, BLOCK_N * hidden_size}, options);
|
||||
segment_gemm_kernel_impl<scalar_t>(
|
||||
qa.data_ptr<scalar_t>(),
|
||||
k_input.data_ptr<scalar_t>(),
|
||||
hidden_states.data_ptr<scalar_t>(),
|
||||
q_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
|
||||
kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
|
||||
q_a_proj_s.data_ptr<float>(),
|
||||
kv_a_proj_s.data_ptr<float>(),
|
||||
buffer.data_ptr<scalar_t>(),
|
||||
num_seqs,
|
||||
q_lora_rank,
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
hidden_size,
|
||||
block_size_N,
|
||||
block_size_K);
|
||||
} else {
|
||||
segment_gemm_kernel_impl<scalar_t>(
|
||||
qa.data_ptr<scalar_t>(),
|
||||
k_input.data_ptr<scalar_t>(),
|
||||
hidden_states.data_ptr<scalar_t>(),
|
||||
q_a_proj_weight.data_ptr<scalar_t>(),
|
||||
kv_a_proj_weight.data_ptr<scalar_t>(),
|
||||
num_seqs,
|
||||
q_lora_rank,
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
hidden_size);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: apply rmsnorm inplace
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rms_norm_kernel_impl", [&] {
|
||||
rms_norm_kernel_impl<scalar_t>(
|
||||
qa.data_ptr<scalar_t>(),
|
||||
v_input.data_ptr<scalar_t>(),
|
||||
q_a_layernorm_weight.data_ptr<scalar_t>(),
|
||||
kv_a_layernorm_weight.data_ptr<scalar_t>(),
|
||||
num_seqs,
|
||||
q_lora_rank,
|
||||
kv_lora_rank,
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
eps);
|
||||
});
|
||||
|
||||
// stage 3: q_b_proj
|
||||
at::Tensor qb;
|
||||
std::optional<at::Tensor> bias;
|
||||
if (use_int8_w8a8) {
|
||||
qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni);
|
||||
} else if (use_fp8_w8a16) {
|
||||
qb = fp8_scaled_mm_cpu(
|
||||
qa, q_b_proj_weight, q_b_proj_scale.value(), block_size.value(), bias, at::kBFloat16, is_vnni);
|
||||
} else {
|
||||
qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni);
|
||||
}
|
||||
qb.as_strided_({num_seqs, num_heads, qk_head_dim}, {num_heads * qk_head_dim, qk_head_dim, 1});
|
||||
|
||||
// stage 4: bmm
|
||||
std::optional<at::Tensor> scale;
|
||||
auto q_nope = qb.narrow(2, 0, qk_nope_head_dim).transpose_(0, 1);
|
||||
auto q_nope_out = q_input.narrow(2, 0, kv_lora_rank).transpose_(0, 1);
|
||||
bmm_cpu(q_nope_out, q_nope, w_kc, is_vnni, scale);
|
||||
|
||||
// stage 5: rope
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rotary_emb_kernel_impl", [&] {
|
||||
rotary_emb_kernel_impl<scalar_t>(
|
||||
q_input.data_ptr<scalar_t>() + kv_lora_rank,
|
||||
k_input.data_ptr<scalar_t>() + kv_lora_rank,
|
||||
qb.data_ptr<scalar_t>() + qk_nope_head_dim,
|
||||
k_input.data_ptr<scalar_t>() + kv_lora_rank,
|
||||
positions.data_ptr<int64_t>(),
|
||||
cos_sin_cache.data_ptr<scalar_t>(),
|
||||
num_seqs,
|
||||
num_heads,
|
||||
rotary_dim,
|
||||
num_heads * qk_head_dim,
|
||||
qk_head_dim,
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
num_heads * (kv_lora_rank + qk_rope_head_dim),
|
||||
kv_lora_rank + qk_rope_head_dim,
|
||||
kv_lora_rank + qk_rope_head_dim);
|
||||
});
|
||||
|
||||
return std::make_tuple(q_input, k_input, v_input);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& qkv_a_proj_weight,
|
||||
at::Tensor& q_b_proj_weight,
|
||||
at::Tensor& w_kc,
|
||||
at::Tensor& q_a_layernorm_weight,
|
||||
at::Tensor& kv_a_layernorm_weight,
|
||||
at::Tensor& positions,
|
||||
at::Tensor& cos_sin_cache,
|
||||
double eps,
|
||||
bool use_int8_w8a8,
|
||||
bool use_fp8_w8a16,
|
||||
std::optional<at::Tensor> qkv_a_proj_scale,
|
||||
std::optional<at::Tensor> q_b_proj_scale,
|
||||
bool is_vnni,
|
||||
std::optional<std::vector<int64_t>> block_size,
|
||||
int64_t q_lora_rank,
|
||||
int64_t kv_lora_rank,
|
||||
int64_t qk_rope_head_dim) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::qkv_proj_with_rope_fused_weight",
|
||||
std::vector<c10::IValue>({hidden_states, qkv_a_proj_weight, q_b_proj_weight, w_kc}));
|
||||
|
||||
int64_t hidden_size = hidden_states.size(1);
|
||||
CHECK_EQ(qkv_a_proj_weight.size(0), q_lora_rank + kv_lora_rank + qk_rope_head_dim);
|
||||
CHECK_EQ(qkv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
|
||||
|
||||
std::vector<at::Tensor> weight_chunks =
|
||||
at::split(qkv_a_proj_weight, {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
|
||||
at::Tensor q_a_proj_weight = weight_chunks[0];
|
||||
at::Tensor kv_a_proj_weight = weight_chunks[1];
|
||||
at::Tensor q_a_proj_s;
|
||||
at::Tensor kv_a_proj_s;
|
||||
|
||||
if (use_int8_w8a8) {
|
||||
TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for int8 w8a8.");
|
||||
std::vector<at::Tensor> scale_chunks =
|
||||
at::split(qkv_a_proj_scale.value(), {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
|
||||
q_a_proj_s = scale_chunks[0];
|
||||
kv_a_proj_s = scale_chunks[1];
|
||||
}
|
||||
if (use_fp8_w8a16) {
|
||||
TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for fp8 w8a16.");
|
||||
int64_t block_size_N = block_size.value()[0];
|
||||
int64_t q_a_proj_s_dim0 = div_up(q_lora_rank, block_size_N);
|
||||
int64_t kv_a_proj_s_dim0 = div_up(kv_lora_rank + qk_rope_head_dim, block_size_N);
|
||||
std::vector<at::Tensor> scale_chunks = at::split(qkv_a_proj_scale.value(), {q_a_proj_s_dim0, kv_a_proj_s_dim0}, 0);
|
||||
q_a_proj_s = scale_chunks[0];
|
||||
kv_a_proj_s = scale_chunks[1];
|
||||
}
|
||||
|
||||
return qkv_proj_with_rope(
|
||||
hidden_states,
|
||||
q_a_proj_weight,
|
||||
q_b_proj_weight,
|
||||
kv_a_proj_weight,
|
||||
w_kc,
|
||||
q_a_layernorm_weight,
|
||||
kv_a_layernorm_weight,
|
||||
positions,
|
||||
cos_sin_cache,
|
||||
eps,
|
||||
use_int8_w8a8,
|
||||
use_fp8_w8a16,
|
||||
q_a_proj_s,
|
||||
q_b_proj_scale,
|
||||
kv_a_proj_s,
|
||||
is_vnni,
|
||||
block_size);
|
||||
}
|
||||
346
sgl-kernel/csrc/cpu/rope.cpp
Normal file
346
sgl-kernel/csrc/cpu/rope.cpp
Normal file
@@ -0,0 +1,346 @@
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
void rotary_embedding_3D_kernel_impl(
|
||||
scalar_t* __restrict__ query_out,
|
||||
scalar_t* __restrict__ key_out,
|
||||
int64_t* __restrict__ positions,
|
||||
scalar_t* __restrict__ query,
|
||||
scalar_t* __restrict__ key,
|
||||
scalar_t* __restrict__ cos_sin_cache,
|
||||
int64_t num_tokens,
|
||||
int64_t num_heads,
|
||||
int64_t num_kv_heads,
|
||||
int64_t head_size,
|
||||
int64_t rotary_dim,
|
||||
int64_t query_stride_s,
|
||||
int64_t query_out_stride_s,
|
||||
int64_t key_out_stride_s,
|
||||
int64_t key_stride_s,
|
||||
int64_t query_stride_h,
|
||||
int64_t query_out_stride_h) {
|
||||
int64_t HR = rotary_dim;
|
||||
int64_t HK = rotary_dim;
|
||||
int64_t COFF = HR / 2;
|
||||
at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
||||
int64_t seq{0}, head_id{0};
|
||||
data_index_init(begin, seq, num_tokens, head_id, num_heads);
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t in_offset_q = seq * query_stride_s + head_id * query_stride_h;
|
||||
int64_t out_offset_q = seq * query_out_stride_s + head_id * query_out_stride_h;
|
||||
int64_t out_offset_k = seq * key_out_stride_s;
|
||||
int64_t p = 0;
|
||||
scalar_t* sin_start = nullptr;
|
||||
scalar_t* cos_start = nullptr;
|
||||
// step 0) get the rotary position embedding for the current position
|
||||
p = positions[seq];
|
||||
sin_start = cos_sin_cache + p * HR + COFF;
|
||||
cos_start = cos_sin_cache + p * HR;
|
||||
// step 1) apply_rotary_pos_emb for the rotary_dim elements in every
|
||||
// head of query/key
|
||||
for (int64_t h = 0; h < rotary_dim; h += 2) {
|
||||
scalar_t cos = cos_start[h >> 1];
|
||||
scalar_t sin = sin_start[h >> 1];
|
||||
scalar_t in1 = query[in_offset_q + h];
|
||||
scalar_t in2 = query[in_offset_q + h + 1];
|
||||
scalar_t out1 = in1 * cos - in2 * sin;
|
||||
scalar_t out2 = in2 * cos + in1 * sin;
|
||||
query_out[out_offset_q + h] = out1;
|
||||
query_out[out_offset_q + h + 1] = out2;
|
||||
}
|
||||
for (int64_t h = 0; h < HK; h += 2) {
|
||||
scalar_t cos = cos_start[h >> 1];
|
||||
scalar_t sin = sin_start[h >> 1];
|
||||
int64_t k_pe_offset = seq * key_stride_s;
|
||||
scalar_t in1_k = key[k_pe_offset + h];
|
||||
scalar_t in2_k = key[k_pe_offset + h + 1];
|
||||
scalar_t out1_k = in1_k * cos - in2_k * sin;
|
||||
scalar_t out2_k = in2_k * cos + in1_k * sin;
|
||||
key_out[out_offset_k + h] = out1_k;
|
||||
key_out[out_offset_k + h + 1] = out2_k;
|
||||
}
|
||||
// move to the next index
|
||||
data_index_step(seq, num_tokens, head_id, num_heads);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// NeoX-style (half-rotated) rope for 2D [num_tokens, num_heads * head_size]
// query/key, applied in place. Row p of cos_sin_cache holds embed_dim cos
// values followed by embed_dim sin values; the rotation pairs element j with
// element j + embed_dim within each head.
template <typename scalar_t>
void rotary_embedding_neox_2D_kernel_impl(
    int64_t* __restrict__ positions,
    scalar_t* __restrict__ query,
    scalar_t* __restrict__ key,
    scalar_t* __restrict__ cos_sin_cache,
    int64_t rotary_dim,
    int64_t query_stride_s,
    int64_t key_stride_s,
    int64_t num_heads,
    int64_t num_kv_heads,
    int64_t head_size,
    int64_t num_tokens) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  constexpr int64_t bVecSize = bVec::size();

  int64_t embed_dim = rotary_dim / 2;
  // If embed_dim is not a multiple of the vector width, stop the vector loop
  // one full vector early and let a scalar loop handle the remainder.
  bool flag = (embed_dim % bVecSize == 0);
  int64_t loop_upper = flag ? embed_dim : embed_dim - bVecSize;

  // Rotate one head in place, starting at qk[token_head].
  auto compute_loop = [&](int64_t token_head, scalar_t* cache_ptr, scalar_t* qk) {
    int64_t j = 0;
    // vectorized main loop: bVecSize rotation pairs per iteration
    for (; j < loop_upper; j += bVecSize) {
      int64_t rot_offset = j;
      int64_t x_index = rot_offset;
      int64_t y_index = embed_dim + rot_offset;

      int64_t out_x = token_head + x_index;
      int64_t out_y = token_head + y_index;

      bVec _cos = bVec::loadu(cache_ptr + x_index);
      bVec _sin = bVec::loadu(cache_ptr + y_index);

      bVec _q_x = bVec::loadu(qk + out_x);
      bVec _q_y = bVec::loadu(qk + out_y);
      // widen everything to fp32 for the math, narrow back on store
      fVec _cos_0, _cos_1;
      std::tie(_cos_0, _cos_1) = at::vec::convert_to_float(_cos);
      fVec _sin_0, _sin_1;
      std::tie(_sin_0, _sin_1) = at::vec::convert_to_float(_sin);
      fVec _q_x_0, _q_x_1;
      std::tie(_q_x_0, _q_x_1) = at::vec::convert_to_float(_q_x);
      fVec _q_y_0, _q_y_1;
      std::tie(_q_y_0, _q_y_1) = at::vec::convert_to_float(_q_y);

      // x' = x*cos - y*sin
      auto out1_0 = _q_x_0 * _cos_0 - _q_y_0 * _sin_0;
      auto out1_1 = _q_x_1 * _cos_1 - _q_y_1 * _sin_1;
      auto out1 = convert_from_float_ext<scalar_t>(out1_0, out1_1);
      out1.store(qk + out_x);

      // y' = y*cos + x*sin
      auto out2_0 = _q_y_0 * _cos_0 + _q_x_0 * _sin_0;
      auto out2_1 = _q_y_1 * _cos_1 + _q_x_1 * _sin_1;
      auto out2 = convert_from_float_ext<scalar_t>(out2_0, out2_1);
      out2.store(qk + out_y);
    }
    // scalar tail for the remaining rotation pairs
    if (!flag) {
      for (; j < embed_dim; ++j) {
        int64_t x_index = j;
        int64_t y_index = embed_dim + j;

        int64_t out_x = token_head + x_index;
        int64_t out_y = token_head + y_index;

        float _cos = cache_ptr[x_index];
        float _sin = cache_ptr[y_index];

        float _q_x = qk[out_x];
        float _q_y = qk[out_y];

        qk[out_x] = _q_x * _cos - _q_y * _sin;
        qk[out_y] = _q_y * _cos + _q_x * _sin;
      }
    }
  };

#pragma omp parallel for
  for (int64_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
    int64_t pos = positions[token_idx];
    scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;

    // rotate every query head of this token
    for (int64_t i = 0; i < num_heads; ++i) {
      int64_t head_idx = i;
      int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
      compute_loop(token_head, cache_ptr, query);
    }

    // rotate every key head of this token
    for (int64_t i = 0; i < num_kv_heads; ++i) {
      int64_t head_idx = i;
      int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
      compute_loop(token_head, cache_ptr, key);
    }
  }
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void rotary_embedding_2D_kernel_impl(
|
||||
int64_t* __restrict__ positions,
|
||||
scalar_t* __restrict__ query,
|
||||
scalar_t* __restrict__ key,
|
||||
scalar_t* __restrict__ cos_sin_cache,
|
||||
int64_t rotary_dim,
|
||||
int64_t query_stride_s,
|
||||
int64_t key_stride_s,
|
||||
int64_t num_heads,
|
||||
int64_t num_kv_heads,
|
||||
int64_t head_size,
|
||||
int64_t num_tokens) {
|
||||
int64_t embed_dim = rotary_dim / 2;
|
||||
|
||||
at::parallel_for(0, num_tokens * num_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
||||
int64_t token_idx = {0}, i = {0};
|
||||
data_index_init(begin, token_idx, num_tokens, i, num_heads);
|
||||
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
|
||||
int64_t pos = positions[token_idx];
|
||||
scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
|
||||
scalar_t* cos_cache_ptr = cache_ptr;
|
||||
scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
|
||||
int64_t head_idx = i;
|
||||
int64_t token_head = token_idx * query_stride_s + head_idx * head_size;
|
||||
scalar_t* head_query = token_head + query;
|
||||
for (int64_t j = 0; j < embed_dim; j += 1) {
|
||||
int64_t rot_offset = j;
|
||||
int64_t x_index = 2 * rot_offset;
|
||||
int64_t y_index = 2 * rot_offset + 1;
|
||||
|
||||
float cos = cos_cache_ptr[rot_offset];
|
||||
float sin = sin_cache_ptr[rot_offset];
|
||||
|
||||
float x = head_query[x_index];
|
||||
float y = head_query[y_index];
|
||||
|
||||
head_query[x_index] = x * cos - y * sin;
|
||||
head_query[y_index] = y * cos + x * sin;
|
||||
}
|
||||
data_index_step(token_idx, num_tokens, i, num_heads);
|
||||
}
|
||||
});
|
||||
|
||||
at::parallel_for(0, num_tokens * num_kv_heads, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
|
||||
int64_t token_idx{0}, i = {0};
|
||||
data_index_init(begin, token_idx, num_tokens, i, num_kv_heads);
|
||||
for ([[maybe_unused]] auto z : c10::irange(begin, end)) {
|
||||
int64_t pos = positions[token_idx];
|
||||
scalar_t* cache_ptr = cos_sin_cache + pos * rotary_dim;
|
||||
scalar_t* cos_cache_ptr = cache_ptr;
|
||||
scalar_t* sin_cache_ptr = cache_ptr + embed_dim;
|
||||
int64_t head_idx = i;
|
||||
int64_t token_head = token_idx * key_stride_s + head_idx * head_size;
|
||||
scalar_t* head_key = key + token_head;
|
||||
for (int64_t j = 0; j < embed_dim; j += 1) {
|
||||
int64_t rot_offset = j;
|
||||
int64_t x_index = 2 * rot_offset;
|
||||
int64_t y_index = 2 * rot_offset + 1;
|
||||
|
||||
float cos = cos_cache_ptr[rot_offset];
|
||||
float sin = sin_cache_ptr[rot_offset];
|
||||
|
||||
float x = head_key[x_index];
|
||||
float y = head_key[y_index];
|
||||
|
||||
head_key[x_index] = x * cos - y * sin;
|
||||
head_key[y_index] = y * cos + x * sin;
|
||||
}
|
||||
data_index_step(token_idx, num_tokens, i, num_kv_heads);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Apply rotary position embedding to query/key on CPU.
//
// Accepts 2D [num_tokens, num_heads * head_size] or 3D
// [num_tokens, num_heads, head_size] inputs. The 2D kernels rotate in place
// and the (rotated) inputs are returned; the 3D kernel writes into freshly
// allocated outputs. `is_neox` selects half-rotated (NeoX) vs interleaved
// pairing; 3D inputs support only the interleaved style.
//
// Returns {query_out, key_out}.
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
    at::Tensor& positions,
    at::Tensor& query,
    at::Tensor& key,
    int64_t head_size,
    at::Tensor& cos_sin_cache,
    bool is_neox) {
  RECORD_FUNCTION("sgl-kernel::rotary_embedding_cpu", std::vector<c10::IValue>({query, key}));
  CHECK_DIM(1, positions);
  const auto input_dim = query.dim();
  const auto input_dtype = query.scalar_type();
  TORCH_CHECK(
      input_dim == 2 || input_dim == 3,
      " Query/Key must be 2D [num_tokens, num_heads*head_size] or 3D [num_tokens, num_heads, head_size] tensor");
  CHECK_DIM(2, cos_sin_cache);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(query);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(key);

  int64_t rotary_dim = cos_sin_cache.size(1);
  if (input_dim == 3) {
    // TODO: add support for head_dim != rotary_dim case when input_dim=3
    CHECK_EQ(query.size(-1), rotary_dim);
    // TODO: add support for kv_head != 1
    CHECK_EQ(key.size(1), 1);
  }

  int64_t num_tokens = positions.numel();
  CHECK_EQ(key.size(0), num_tokens);
  CHECK_EQ(query.size(0), num_tokens);

  TORCH_CHECK(positions.scalar_type() == at::kLong, "expect positions to be int64, got ", positions.scalar_type());
  TORCH_CHECK(input_dtype == key.scalar_type(), "query and key must have the same data type");
  TORCH_CHECK(input_dtype == cos_sin_cache.scalar_type(), "query and cos_sin_cache must have the same data type");

  int64_t num_heads = input_dim == 2 ? query.size(-1) / head_size : query.size(1);
  int64_t num_kv_heads = input_dim == 2 ? key.size(-1) / head_size : key.size(1);
  int64_t key_stride_s = key.stride(0);
  int64_t query_stride_s = query.stride(0);

  // input stride of num head dim is meaningful only when input dim = 3
  int64_t query_stride_h = input_dim == 3 ? query.stride(1) : -1;

  // FIX: allocate output tensors only for the 3D path. The 2D kernels rotate
  // query/key in place and return them directly, so the previous
  // unconditional at::empty_like allocations were created and then discarded.
  at::Tensor query_out = query;
  at::Tensor key_out = key;
  if (input_dim == 3) {
    query_out = at::empty_like(query);
    key_out = at::empty_like(key);
  }
  // output strides are meaningful only when input dim = 3
  int64_t query_out_stride_s = input_dim == 3 ? query_out.stride(0) : -1;
  int64_t key_out_stride_s = input_dim == 3 ? key_out.stride(0) : -1;
  int64_t query_out_stride_h = input_dim == 3 ? query_out.stride(1) : -1;

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_embedding_cpu", [&] {
    if (input_dim == 2) {
      if (is_neox) {
        rotary_embedding_neox_2D_kernel_impl<scalar_t>(
            positions.data_ptr<int64_t>(),
            query.data_ptr<scalar_t>(),
            key.data_ptr<scalar_t>(),
            cos_sin_cache.data_ptr<scalar_t>(),
            rotary_dim,
            query_stride_s,
            key_stride_s,
            num_heads,
            num_kv_heads,
            head_size,
            num_tokens);
      } else {
        rotary_embedding_2D_kernel_impl<scalar_t>(
            positions.data_ptr<int64_t>(),
            query.data_ptr<scalar_t>(),
            key.data_ptr<scalar_t>(),
            cos_sin_cache.data_ptr<scalar_t>(),
            rotary_dim,
            query_stride_s,
            key_stride_s,
            num_heads,
            num_kv_heads,
            head_size,
            num_tokens);
      }
      // query_out / key_out already alias the (now rotated) inputs
    } else {
      TORCH_CHECK(
          is_neox == false, " Query/Key with 3D [num_tokens, num_heads, head_size] does not support neox rope yet");
      // TODO: add neox style support for rope impl with 3D inputs
      rotary_embedding_3D_kernel_impl<scalar_t>(
          query_out.data_ptr<scalar_t>(),
          key_out.data_ptr<scalar_t>(),
          positions.data_ptr<int64_t>(),
          query.data_ptr<scalar_t>(),
          key.data_ptr<scalar_t>(),
          cos_sin_cache.data_ptr<scalar_t>(),
          num_tokens,
          num_heads,
          num_kv_heads,
          head_size,
          rotary_dim,
          query_stride_s,
          query_out_stride_s,
          key_out_stride_s,
          key_stride_s,
          query_stride_h,
          query_out_stride_h);
    }
  });
  return std::make_tuple(query_out, key_out);
}
|
||||
666
sgl-kernel/csrc/cpu/shm.cpp
Normal file
666
sgl-kernel/csrc/cpu/shm.cpp
Normal file
@@ -0,0 +1,666 @@
|
||||
#include "shm.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <errno.h>
|
||||
#include <fcntl.h>
|
||||
#include <immintrin.h>
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
#include <stdlib.h>
|
||||
#include <sys/mman.h>
|
||||
#include <unistd.h>
|
||||
|
||||
// states for collectives
|
||||
enum coll_state {
|
||||
coll_begin = 0,
|
||||
coll_allreduce_naive__copy_in_done,
|
||||
coll_allreduce_naive__reduce_done,
|
||||
// alternative state when allreduce is working on alternative buffer
|
||||
// of the double buffer.
|
||||
coll_alt1_allreduce_naive__copy_in_done,
|
||||
coll_alt2_allreduce_naive__copy_in_done,
|
||||
coll_alt1_allreduce_naive__reduce_done,
|
||||
coll_allgather_naive__copy_in_done,
|
||||
coll_alt1_allgather_naive__copy_in_done,
|
||||
coll_alt2_allgather_naive__copy_in_done,
|
||||
};
|
||||
|
||||
// SHM building blocks
|
||||
struct SharedData {
|
||||
const char* name;
|
||||
int descriptor;
|
||||
void* bytes;
|
||||
size_t nbytes;
|
||||
};
|
||||
|
||||
// Open an existing shared-memory object `name` and map `nbytes` of it.
// On success fills *data; on failure sets data->descriptor = -1. A missing
// object (ENOENT) is silent because callers poll until a peer creates it.
void shared_open(SharedData* data, const char* name, size_t nbytes) {
  int d = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR);
  if (d != -1) {
    void* bytes = mmap(NULL, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, d, 0);
    // BUGFIX: the original never checked for MAP_FAILED, so a failed mapping
    // was stored and later dereferenced. Report and fail cleanly instead.
    if (bytes == MAP_FAILED) {
      printf("shared_open %s mmap failed, errno=%d\n", name, errno);
      close(d);
      data->descriptor = -1;
      return;
    }
    data->name = name;
    data->descriptor = d;
    data->bytes = bytes;
    data->nbytes = nbytes;
  } else {
    if (errno != ENOENT) {
      // don't print if shm can not be found because we want to loop over from
      // caller again until the other ranks created the shm
      printf("shared_open %s failed, errno=%d\n", name, errno);
    }
    data->descriptor = -1;
  }
}
|
||||
|
||||
// Create a shared-memory object `name`, seed it with `nbytes` from `bytes`,
// then re-open/map it into *data via shared_open.
void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) {
  int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR);
  if (d != -1) {
    // BUGFIX: write() returns ssize_t; the original assigned it to the
    // size_t `nbytes`, so an error return of -1 became a huge unsigned value
    // that passed the `> 0` check (and a short write silently truncated the
    // region). Require the full buffer to be written.
    ssize_t written = write(d, bytes, nbytes);
    if (written == (ssize_t)nbytes) {
      shared_open(data, name, nbytes);
    } else {
      printf("shared_create %s failed to write %zu bytes\n", name, nbytes);
    }
  } else {
    printf("shared_create %s failed\n", name);
  }
}
|
||||
|
||||
// number of ranks participating in the SHM collectives; set during init
static int world_size;

// SHM based allreduce helper functions
// buffer that holds shm name
#define NAME_BUF_SIZE 1000
// largest message handled by the distributed (chunked) allreduce path
#define MAX_BUF_SIZE 1048576 * 32
// messages at or below this size use the symmetric naive allreduce path
#define NAIVE_ALLREDUCE_THRESHOLD 1048576
#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer"
// Per-rank shared-memory workspace: two independent state machines plus
// double-buffered payload space for both allreduce variants.
struct allreduce_workspace {
  enum coll_state states[2];  // idx=0 -- state for symmetric_naive_all_reduce
                              // idx=1 -- state for distributed_naive_all_reduce
  // double buffer to avoid syncing between rounds
  // offset=0 -- 2*NAIVE_ALLREDUCE_THRESHOLD : buffer for
  // symmetric_naive_all_reduce after that : buffer for
  // distributed_naive_all_reduce
  char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE];
};
|
||||
|
||||
// Byte offsets of the two double buffers inside allreduce_workspace::buffer
// for a given buffer index (0 or 1).
// BUGFIX: parenthesize the argument and the whole expansion so the macros
// stay correct inside larger expressions (e.g. BUFFER0_OFFSET(i + 1) or a
// surrounding multiplication).
#define BUFFER0_OFFSET(current_buffer) ((current_buffer) * NAIVE_ALLREDUCE_THRESHOLD)
#define BUFFER1_OFFSET(current_buffer) (2 * NAIVE_ALLREDUCE_THRESHOLD + (current_buffer) * MAX_BUF_SIZE)
|
||||
|
||||
// per-rank shared-memory workspaces, indexed by rank
struct allreduce_workspace** workspace;

// buffer for small messages, double buffer
char** symmetric_buffer[2];
// buffer for large messages, double buffer
char** distributed_buffer[2];
|
||||
|
||||
void wait_buffer_state_until_2(int index, enum coll_state state0, enum coll_state state1, int state_group) {
|
||||
volatile enum coll_state* state_ptr = &(workspace[index]->states[state_group]);
|
||||
|
||||
while (1) {
|
||||
volatile enum coll_state cur_state = *state_ptr;
|
||||
if (cur_state == state0 || cur_state == state1) break;
|
||||
}
|
||||
}
|
||||
|
||||
// Declared with the avx512bw target attribute so the intrinsics below are
// legal even when the translation unit's baseline ISA is lower.
__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
// Widen 16 bf16 values to fp32: zero-extend each 16-bit lane to 32 bits,
// then byte-shift left by 2 so the bf16 payload lands in the fp32 high half
// (valid because the top half of every widened lane is zero).
inline __m512 cvt_bf16_to_fp32(const __m256i src) {
  auto y = _mm512_cvtepu16_epi32(src);
  return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2));
}
|
||||
|
||||
inline __m256i cvt_fp32_to_bf16(const __m512 src) __attribute__((target("avx512bw")));
// Narrow 16 fp32 values to bf16 using round-to-nearest-even on the dropped
// 16 mantissa bits, mapping NaN inputs to the bf16 NaN pattern 0xffff.
inline __m256i cvt_fp32_to_bf16(const __m512 src) {
  __m512i value = _mm512_castps_si512(src);
  __m512i nan = _mm512_set1_epi32(0xffff);
  // mask of lanes that are ordered (i.e. not NaN)
  auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
  __m512i ones = _mm512_set1_epi32(0x1);
  __m512i vec_bias = _mm512_set1_epi32(0x7fff);
  // uint32_t lsb = (input >> 16) & 1;
  auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
  // uint32_t rounding_bias = 0x7fff + lsb;
  t_value = _mm512_add_epi32(t_value, vec_bias);
  // input += rounding_bias;
  t_value = _mm512_add_epi32(t_value, value);
  // input = input >> 16;
  t_value = _mm512_srli_epi32(t_value, 16);
  // Check NaN before converting back to bf16
  t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
  return _mm512_cvtusepi32_epi16(t_value);
}
|
||||
|
||||
__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
// Widen 16 fp16 values to fp32 via the hardware conversion.
inline __m512 cvt_fp16_to_fp32(const __m256i src) {
  return _mm512_cvtph_ps(src);
}
|
||||
|
||||
inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw")));
// Narrow 16 fp32 values to fp16, rounding to nearest even without raising
// precision exceptions.
inline __m256i cvt_fp32_to_fp16(const __m512 src) {
  return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
|
||||
|
||||
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
|
||||
__attribute__((target("avx512bw")));
|
||||
|
||||
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
|
||||
__attribute__((target("avx512bw")));
|
||||
|
||||
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
|
||||
__attribute__((target("avx512bw")));
|
||||
|
||||
void reduce_all_buffers(
|
||||
int start_elements,
|
||||
int num_elements,
|
||||
c10::ScalarType scalar_type,
|
||||
int to_buffer_idx,
|
||||
char* to_buffer,
|
||||
char** buffers) {
|
||||
switch (scalar_type) {
|
||||
case c10::ScalarType::BFloat16:
|
||||
reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers);
|
||||
break;
|
||||
case c10::ScalarType::Half:
|
||||
reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers);
|
||||
break;
|
||||
case c10::ScalarType::Float:
|
||||
reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers);
|
||||
break;
|
||||
default:
|
||||
assert(!"Should not get here");
|
||||
}
|
||||
}
|
||||
|
||||
// Load 16 bf16 values from rank `x`'s buffer at byte offset `i`, widen to
// fp32, and accumulate into the enclosing scope's `inout_val`.  Token-pastes
// `x` to give each unrolled temporary a unique name.
#define CVT_ADD_BF16(x) \
  do { \
    auto in##x##_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
    inout_val = _mm512_add_ps(inout_val, in##x##_val); \
  } while (0)

// Reduce functions down below use a vectorized algorithm; the number of
// bytes processed each iteration depends on vector length: 256-bit vector
// ==> 32 bytes, 512-bit vector ==> 64 bytes.  If you change the
// implementation of reduce_bf16_buffers, etc., check whether this number
// needs to be changed.
#define VECTOR_LENGTH_IN_BYTES 32
|
||||
|
||||
// Sum `world_size` bf16 buffers elementwise into `to_buffer` over
// [start_elements, start_elements + num_elements).  The bulk is processed
// 32 bytes (16 bf16) per iteration, accumulating in fp32 for accuracy.
// The switch over world_size intentionally FALLS THROUGH: entering at
// case N adds buffers[N-1] down to buffers[1] with fully unrolled code;
// world_size > 16 takes the generic default loop.
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
  const int element_size = 2;  // sizeof bf16
  const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
  int main_elements = num_elements - (num_elements % vector_length);
  int remain_elements = num_elements % vector_length;

  // process aligned part
#pragma omp parallel for
  for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
       i += VECTOR_LENGTH_IN_BYTES) {
    auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
    switch (world_size) {
      case 16:
        CVT_ADD_BF16(15);
      case 15:
        CVT_ADD_BF16(14);
      case 14:
        CVT_ADD_BF16(13);
      case 13:
        CVT_ADD_BF16(12);
      case 12:
        CVT_ADD_BF16(11);
      case 11:
        CVT_ADD_BF16(10);
      case 10:
        CVT_ADD_BF16(9);
      case 9:
        CVT_ADD_BF16(8);
      case 8:
        CVT_ADD_BF16(7);
      case 7:
        CVT_ADD_BF16(6);
      case 6:
        CVT_ADD_BF16(5);
      case 5:
        CVT_ADD_BF16(4);
      case 4:
        CVT_ADD_BF16(3);
      case 3:
        CVT_ADD_BF16(2);
      case 2:
        CVT_ADD_BF16(1);
      case 1:
        break;
      default:
        // world_size > 16: generic loop over the remaining ranks.
        for (int j = 1; j < world_size; j++) {
          auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
          inout_val = _mm512_add_ps(inout_val, in_val);
        }
    }
    _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val));
  }

  // process remaining part: scalar tail, still accumulated in fp32
  int i = (start_elements + main_elements) * element_size;
  while (remain_elements > 0) {
    float val = 0.0f;
    for (int j = 0; j < world_size; j++) {
      val += *(at::BFloat16*)(buffers[j] + i);
    }
    *(at::BFloat16*)(to_buffer + i) = val;
    remain_elements--;
    i += element_size;
  }
}
|
||||
|
||||
// Load 16 fp16 values from rank `x`'s buffer at byte offset `i`, widen to
// fp32, and accumulate into the enclosing scope's `inout_val`.
#define CVT_ADD_FP16(x) \
  do { \
    auto in##x##_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
    inout_val = _mm512_add_ps(inout_val, in##x##_val); \
  } while (0)
|
||||
|
||||
// Sum `world_size` fp16 buffers elementwise into `to_buffer` over
// [start_elements, start_elements + num_elements), accumulating in fp32.
// Same structure as reduce_bf16_buffers: 32 bytes (16 fp16) per vector
// iteration, intentional switch fallthrough unrolled for world_size <= 16,
// generic loop otherwise, scalar tail at the end.
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
  const int element_size = 2;  // sizeof fp16
  const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
  int main_elements = num_elements - (num_elements % vector_length);
  int remain_elements = num_elements % vector_length;

  // process aligned part
#pragma omp parallel for
  for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
       i += VECTOR_LENGTH_IN_BYTES) {
    auto inout_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
    switch (world_size) {
      case 16:
        CVT_ADD_FP16(15);
      case 15:
        CVT_ADD_FP16(14);
      case 14:
        CVT_ADD_FP16(13);
      case 13:
        CVT_ADD_FP16(12);
      case 12:
        CVT_ADD_FP16(11);
      case 11:
        CVT_ADD_FP16(10);
      case 10:
        CVT_ADD_FP16(9);
      case 9:
        CVT_ADD_FP16(8);
      case 8:
        CVT_ADD_FP16(7);
      case 7:
        CVT_ADD_FP16(6);
      case 6:
        CVT_ADD_FP16(5);
      case 5:
        CVT_ADD_FP16(4);
      case 4:
        CVT_ADD_FP16(3);
      case 3:
        CVT_ADD_FP16(2);
      case 2:
        CVT_ADD_FP16(1);
      case 1:
        break;
      default:
        // world_size > 16: generic loop over the remaining ranks.
        for (int j = 1; j < world_size; j++) {
          auto in_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
          inout_val = _mm512_add_ps(inout_val, in_val);
        }
    }
    _mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val));
  }

  // process remaining part: scalar tail, still accumulated in fp32
  int i = (start_elements + main_elements) * element_size;
  while (remain_elements > 0) {
    float val = 0.0f;
    for (int j = 0; j < world_size; j++) {
      val += *(at::Half*)(buffers[j] + i);
    }
    *(at::Half*)(to_buffer + i) = val;
    remain_elements--;
    i += element_size;
  }
}
|
||||
|
||||
// Load 8 fp32 values from rank `x`'s buffer at byte offset `i` and
// accumulate into the enclosing scope's 256-bit `inout_val`.
#define CVT_ADD_F32(x) \
  do { \
    auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \
    inout_val = _mm256_add_ps(inout_val, in##x##_val); \
  } while (0)
|
||||
|
||||
// Sum `world_size` fp32 buffers elementwise into `to_buffer` over
// [start_elements, start_elements + num_elements).  32 bytes (8 floats)
// per vector iteration using 256-bit registers (no widening needed);
// intentional switch fallthrough unrolled for world_size <= 16, generic
// loop otherwise, scalar tail at the end.
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
  const int element_size = 4;  // sizeof float
  const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
  int main_elements = num_elements - (num_elements % vector_length);
  int remain_elements = num_elements % vector_length;

  // process aligned part
#pragma omp parallel for
  for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
       i += VECTOR_LENGTH_IN_BYTES) {
    auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i));
    switch (world_size) {
      case 16:
        CVT_ADD_F32(15);
      case 15:
        CVT_ADD_F32(14);
      case 14:
        CVT_ADD_F32(13);
      case 13:
        CVT_ADD_F32(12);
      case 12:
        CVT_ADD_F32(11);
      case 11:
        CVT_ADD_F32(10);
      case 10:
        CVT_ADD_F32(9);
      case 9:
        CVT_ADD_F32(8);
      case 8:
        CVT_ADD_F32(7);
      case 7:
        CVT_ADD_F32(6);
      case 6:
        CVT_ADD_F32(5);
      case 5:
        CVT_ADD_F32(4);
      case 4:
        CVT_ADD_F32(3);
      case 3:
        CVT_ADD_F32(2);
      case 2:
        CVT_ADD_F32(1);
      case 1:
        break;
      default:
        // world_size > 16: generic loop over the remaining ranks.
        for (int j = 1; j < world_size; j++) {
          auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i));
          inout_val = _mm256_add_ps(inout_val, in_val);
        }
    }
    _mm256_storeu_ps((float*)(to_buffer + i), inout_val);
  }

  // process remaining part: scalar tail
  int i = (start_elements + main_elements) * element_size;
  while (remain_elements > 0) {
    float val = 0.0f;
    for (int j = 0; j < world_size; j++) {
      val += *(float*)(buffers[j] + i);
    }
    *(float*)(to_buffer + i) = val;
    remain_elements--;
    i += element_size;
  }
}
|
||||
|
||||
// One-shot guard so shm_initialize runs at most once per process.
static bool is_initialized = false;
// This process's rank in the shm communicator; set by shm_initialize.
static int world_rank;
|
||||
|
||||
// Set up the POSIX-shared-memory rendezvous for `size` ranks.  Creates this
// rank's own shm segment (named from SHM_BUFFER_NAME, uid, address and
// port so concurrent jobs don't collide), opens every peer's segment
// (retrying while the peer hasn't created it yet), and fills the global
// `workspace` / `symmetric_buffer` / `distributed_buffer` pointer tables.
// Idempotent: subsequent calls return immediately.
void shm_initialize(int size, int rank, const char* addr_string, const char* port_string) {
  if (is_initialized) {
    return;
  }
  is_initialized = true;

  world_size = size;
  world_rank = rank;

  char shm_name_prefix[NAME_BUF_SIZE];
  char shm_name[NAME_BUF_SIZE];
  snprintf(shm_name_prefix, NAME_BUF_SIZE, "%s_%d_%s_%s", SHM_BUFFER_NAME, getuid(), addr_string, port_string);
  // create shared workspace for SHM based allreduce
  SharedData allreduce_buffer;
  // allocate workspace_buf for current rank
  struct allreduce_workspace* workspace_buf;
  struct allreduce_workspace* workspace_buf_other;
  workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace));
  snprintf(shm_name, NAME_BUF_SIZE, "%.900s_%d", shm_name_prefix, rank);
  shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace));
  // NOTE(review): the pointer is immediately re-targeted at the shm
  // mapping, so the malloc'ed block above appears to leak — confirm
  // shared_create's ownership semantics.
  workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes;
  // Seed the two state slots: group 0 (symmetric path) starts at the
  // alt2 state, group 1 (distributed path) at coll_begin.
  workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done;
  workspace_buf->states[1] = coll_begin;

  // create the workspace pointer list
  workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*));
  symmetric_buffer[0] = (char**)malloc(size * sizeof(char**));
  symmetric_buffer[1] = (char**)malloc(size * sizeof(char**));
  distributed_buffer[0] = (char**)malloc(size * sizeof(char**));
  distributed_buffer[1] = (char**)malloc(size * sizeof(char**));

  // map shm of all ranks
  for (int i = 0; i < size; i++) {
    if (i != rank) {
      snprintf(shm_name, NAME_BUF_SIZE, "%.900s_%d", shm_name_prefix, i);
      // Retry while the peer hasn't created its segment yet (ENOENT).
      do {
        shared_open(&allreduce_buffer, shm_name, sizeof(struct allreduce_workspace));
      } while (allreduce_buffer.descriptor == -1 && errno == ENOENT);
      workspace_buf_other = (struct allreduce_workspace*)allreduce_buffer.bytes;
      workspace[i] = workspace_buf_other;
    } else {
      workspace[i] = workspace_buf;
    }
    // Double-buffered views into each rank's shm buffer region.
    symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0);
    symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1);
    distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0);
    distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1);
  }
}
|
||||
|
||||
// Multi-threaded memcpy: the 32-byte-aligned bulk is copied with 256-bit
// unaligned loads/stores split across OpenMP threads; the tail is copied
// byte-by-byte.  Source and destination must not overlap.
static void parallel_memcpy(void* to, void* from, size_t n_bytes) __attribute__((target("avx512bw")));
static void parallel_memcpy(void* to, void* from, size_t n_bytes) {
  auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES);
  // process aligned part
#pragma omp parallel for
  for (size_t i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) {
    auto val = _mm256_loadu_si256((__m256i*)((char*)from + i));
    _mm256_storeu_si256((__m256i*)((char*)to + i), val);
  }

  // process remaining part
  for (size_t i = aligned_bytes; i < n_bytes; i++) {
    *((char*)to + i) = *((char*)from + i);
  }
}
|
||||
|
||||
// Euclidean modulo: result is in [0, mod) even when num is negative.
#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod))
// Wrap a (possibly negative) rank index into [0, world_size).
#define rank_mod(rank) positive_mod(rank, world_size)
|
||||
// Number of elements owned by slice `slice_idx` when `chunk_el` elements
// are split evenly across world_size ranks; the last slice absorbs the
// division remainder.
size_t slice_size(size_t chunk_el, int slice_idx) {
  size_t base = chunk_el / world_size;
  if (slice_idx == world_size - 1) {
    return base + (chunk_el % world_size);
  }
  return base;
}
|
||||
|
||||
char* slice_data(char* data_ptr, size_t chunk_el, int el_size, int slice_idx) {
|
||||
size_t slice_size = chunk_el / world_size;
|
||||
size_t el_offset = slice_size * slice_idx;
|
||||
return data_ptr + el_offset * el_size;
|
||||
}
|
||||
|
||||
// Element index at which slice `slice_idx` begins within a chunk of
// `chunk_el` elements.
size_t slice_el_start(size_t chunk_el, int slice_idx) {
  return (chunk_el / world_size) * slice_idx;
}
|
||||
|
||||
// Latency-optimized allreduce for small chunks: every rank copies its data
// into its own shm buffer, waits for all peers, then reduces all
// world_size buffers locally.  Correctness across back-to-back calls uses
// a 3-way rotating state value (so a rank one iteration ahead is still
// distinguishable) plus double-buffering of the data region.
void symmetric_naive_all_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
  const int state_group = 0;
  static int current_buffer = 0;
  static int state_idx = 0;

  // init states to case 0 to get rid of "maybe-uninitialized" warning.
  enum coll_state copy_current = coll_allreduce_naive__copy_in_done;
  enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done;

  // Rotate through the three copy-done states: current marks this
  // iteration's copy, next is what a peer already on the following
  // iteration would publish.
  switch (state_idx) {
    case 0:
      copy_current = coll_allreduce_naive__copy_in_done;
      copy_next = coll_alt1_allreduce_naive__copy_in_done;
      break;
    case 1:
      copy_current = coll_alt1_allreduce_naive__copy_in_done;
      copy_next = coll_alt2_allreduce_naive__copy_in_done;
      break;
    case 2:
      copy_current = coll_alt2_allreduce_naive__copy_in_done;
      copy_next = coll_allreduce_naive__copy_in_done;
      break;
    default:
      assert(!"Should not get here.");
  }
  state_idx = (state_idx + 1) % 3;

  // Publish this rank's data, then its copy-done state (release fence
  // orders the buffer write before the state write).
  parallel_memcpy(symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size);
  std::atomic_thread_fence(std::memory_order_release);
  workspace[world_rank]->states[state_group] = copy_current;

  for (int i = 0; i < world_size; i++) {
    // wait until the other rank copies the buffer (accepting copy_next in
    // case that rank has already raced ahead to the following iteration)
    if (i != world_rank) {
      wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
    }
  }

  // each rank reduces the buffer independently so there is no need for
  // synchronization afterward
  reduce_all_buffers(0, chunk_el, scalar_type, world_rank, data_ptr, symmetric_buffer[current_buffer]);

  // switch buffer
  current_buffer = 1 - current_buffer;
}
|
||||
|
||||
// Bandwidth-optimized allreduce for large chunks (reduce-scatter then
// all-gather): each rank copies its chunk into shm, reduces only its own
// slice across all ranks, publishes reduce-done, then gathers every
// rank's reduced slice back into `data_ptr`.
void distributed_naive_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
  const int state_group = 1;
  static int current_buffer = 0;
  static int state_idx = 0;

  // init states to case 0 to get rid of "maybe-uninitialized" warning.
  enum coll_state copy_current = coll_allreduce_naive__copy_in_done;
  enum coll_state reduce_current = coll_allreduce_naive__reduce_done;
  enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done;

  // similar to symmetric_naive_all_reduce, but here we only need two sets
  // of states, because distributed naive reduce has two barriers in the
  // algorithm
  switch (state_idx) {
    case 0:
      copy_current = coll_allreduce_naive__copy_in_done;
      reduce_current = coll_allreduce_naive__reduce_done;
      copy_next = coll_alt1_allreduce_naive__copy_in_done;
      break;
    case 1:
      copy_current = coll_alt1_allreduce_naive__copy_in_done;
      reduce_current = coll_alt1_allreduce_naive__reduce_done;
      copy_next = coll_allreduce_naive__copy_in_done;
      break;
    default:
      assert(!"Should not get here.");
  }
  state_idx = (state_idx + 1) % 2;

  // bytes per element (chunk_size is chunk_el * element size)
  int data_size = chunk_size / chunk_el;
  // Publish this rank's chunk; release fence orders data before state.
  parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
  std::atomic_thread_fence(std::memory_order_release);
  workspace[world_rank]->states[state_group] = copy_current;

  for (int i = 0; i < world_size; i++) {
    // wait until all the other ranks copy the buffer (or are already past
    // their reduce step of this iteration)
    if (i != world_rank) wait_buffer_state_until_2(i, copy_current, reduce_current, state_group);
  }

  // reduce scatter: this rank reduces only its own slice, in place in its
  // shm buffer, across all ranks' copies
  reduce_all_buffers(
      slice_el_start(chunk_el, world_rank),
      slice_size(chunk_el, world_rank),
      scalar_type,
      world_rank,
      distributed_buffer[current_buffer][world_rank],
      distributed_buffer[current_buffer]);
  std::atomic_thread_fence(std::memory_order_release);
  workspace[world_rank]->states[state_group] = reduce_current;

  for (int i = 0; i < world_size; i++) {
    // wait until all the other ranks reduce the buffer (or have already
    // started the next iteration's copy)
    if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group);
  }

  // all-gather: copy every rank's reduced slice back, starting at this
  // rank's own slice to spread memory traffic across ranks
  for (int i = 0; i < world_size; i++) {
    int rank = (i + world_rank) % world_size;
    parallel_memcpy(
        slice_data(data_ptr, chunk_el, data_size, rank),
        slice_data(distributed_buffer[current_buffer][rank], chunk_el, chunk_size / chunk_el, rank),
        slice_size(chunk_el, rank) * data_size);
  }

  current_buffer = 1 - current_buffer;
}
|
||||
|
||||
// Allreduce `data` in place, splitting it into MAX_BUF_SIZE-byte pieces.
// Pieces below NAIVE_ALLREDUCE_THRESHOLD take the latency-optimized
// symmetric path; larger pieces take the reduce-scatter path.
void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size) {
  char* base_ptr = (char*)(data.data_ptr());
  for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) {
    char* data_ptr = base_ptr + offset;
    int bytes_left = data_size - offset;
    size_t chunk_size = (bytes_left > MAX_BUF_SIZE) ? MAX_BUF_SIZE : bytes_left;
    // element count = bytes / (bytes per element)
    size_t chunk_el = chunk_size / (data_size / numel);
    if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) {
      symmetric_naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
    } else {
      distributed_naive_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
    }
  }
}
|
||||
|
||||
// Allgather one chunk: every rank copies its `chunk_size` bytes into its
// shm buffer, waits for all peers, then copies rank i's buffer to
// result_ptr + i * res_stride.  Uses the same 3-way rotating state scheme
// as symmetric_naive_all_reduce, on state group 1 and the distributed
// (double-buffered) data region.
void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_t chunk_size, size_t chunk_el) {
  const int state_group = 1;
  static int current_buffer = 0;
  static int state_idx = 0;

  // init states to case 0 to get rid of "maybe-uninitialized" warning.
  enum coll_state copy_current = coll_allgather_naive__copy_in_done;
  enum coll_state copy_next = coll_alt1_allgather_naive__copy_in_done;

  // Rotate through three copy-done states so a peer one iteration ahead
  // remains distinguishable.
  switch (state_idx) {
    case 0:
      copy_current = coll_allgather_naive__copy_in_done;
      copy_next = coll_alt1_allgather_naive__copy_in_done;
      break;
    case 1:
      copy_current = coll_alt1_allgather_naive__copy_in_done;
      copy_next = coll_alt2_allgather_naive__copy_in_done;
      break;
    case 2:
      copy_current = coll_alt2_allgather_naive__copy_in_done;
      copy_next = coll_allgather_naive__copy_in_done;
      break;
    default:
      assert(!"Should not get here.");
  }
  state_idx = (state_idx + 1) % 3;

  // Publish this rank's data, then its copy-done state (release fence
  // orders the buffer write before the state write).
  parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
  std::atomic_thread_fence(std::memory_order_release);
  workspace[world_rank]->states[state_group] = copy_current;

  for (int i = 0; i < world_size; i++) {
    // wait until all the other ranks copy the buffer
    if (i != world_rank) wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
  }
  // Gather: rank i's chunk lands at result_ptr + i * res_stride.
  for (int i = 0; i < world_size; i++) {
    parallel_memcpy(result_ptr + i * res_stride, distributed_buffer[current_buffer][i], chunk_size);
  }
  current_buffer = 1 - current_buffer;
}
|
||||
|
||||
// All-gather `data` along dimension `dim` into `result` (result's dim is
// world_size times larger).  The tensor is viewed as `dim_count` outer
// segments of `dim_size` bytes; each segment is gathered (chunked to
// MAX_BUF_SIZE) so that rank r's segment i lands at the r-th stride of
// result's segment i.
// NOTE(review): assumes `data` is contiguous so that stride(dim)*size(dim)
// spans a segment — confirm against callers.
torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size) {
  size_t dim_el = data.stride(dim) * data.size(dim);
  int dtype_size = data_size / numel;
  size_t dim_size = dim_el * dtype_size;
  int dim_count = data_size / dim_size;
  auto data_ptr = (char*)(data.data_ptr());
  auto result_ptr = (char*)(result.data_ptr());
  for (int i = 0; i < dim_count; i++) {
    for (size_t offset = 0; offset < dim_size; offset += MAX_BUF_SIZE) {
      size_t chunk_size = dim_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : dim_size - offset;
      size_t chunk_el = chunk_size / dtype_size;
      naive_all_gather(
          result_ptr + i * dim_size * world_size + offset,
          data_ptr + i * dim_size + offset,
          dim_size,
          chunk_size,
          chunk_el);
    }
  }
  return result;
}
|
||||
11
sgl-kernel/csrc/cpu/shm.h
Normal file
11
sgl-kernel/csrc/cpu/shm.h
Normal file
@@ -0,0 +1,11 @@
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
|
||||
|
||||
#ifndef __SHM_COLLECTIVES__
|
||||
#define __SHM_COLLECTIVES__
|
||||
#define VECTOR_LENGTH_IN_BYTES 32
|
||||
void shm_initialize(int size, int rank, const char* addr_string, const char* port_string);
|
||||
void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size);
|
||||
torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size);
|
||||
#endif
|
||||
662
sgl-kernel/csrc/cpu/topk.cpp
Normal file
662
sgl-kernel/csrc/cpu/topk.cpp
Normal file
@@ -0,0 +1,662 @@
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// Numerically-stable softmax over a compile-time-sized logit vector:
// reads SIZE `scalar_t` logits from `input`, writes SIZE fp32
// probabilities to `out`.  Three passes: (1) max, (2) exp(x - max) and its
// sum, (3) scale by 1/sum.  Each pass has a masked-load path for SIZE
// smaller than one vector and a full-vector loop otherwise.
template <typename scalar_t, int SIZE>
inline void softmax(float* __restrict__ out, const scalar_t* __restrict__ input) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  constexpr int kVecSize = bVec::size();

  // step 1: get max
  fVec max_fvec = fVec(-std::numeric_limits<float>::infinity());
  if constexpr (SIZE < kVecSize) {
    // SIZE = 1, 2, 4, 8, 16; only the top half is used
    bVec x_bvec = bVec::loadu(input, SIZE);
    fVec x_fvec0, x_fvec1;
    std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
    // fill lanes >= SIZE with -inf so they don't affect the max
    x_fvec0 = fVec::set(max_fvec, x_fvec0, SIZE);
    max_fvec = at::vec::maximum(max_fvec, x_fvec0);
    x_fvec0.store(out, SIZE);
  } else {
    for (int d = 0; d < SIZE; d += kVecSize) {
      bVec x_bvec = bVec::loadu(input + d);
      fVec x_fvec0, x_fvec1;
      std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

      max_fvec = at::vec::maximum(max_fvec, x_fvec0);
      max_fvec = at::vec::maximum(max_fvec, x_fvec1);
      // stash the widened logits in `out` for the next pass
      x_fvec0.store(out + d);
      x_fvec1.store(out + d + fVec::size());
    }
  }
  float max_val = vec_reduce_max(max_fvec);
  max_fvec = fVec(max_val);

  // step 2: sum of (x - max).exp()
  fVec sum_fvec = fVec(float(0));
  if constexpr (SIZE < fVec::size()) {
    // SIZE = 1, 2, 4, 8
    fVec x_fvec = (fVec::loadu(out, SIZE) - max_fvec).exp_u20();
    // zero lanes >= SIZE so they don't affect the sum
    x_fvec = fVec::set(sum_fvec, x_fvec, SIZE);
    sum_fvec += x_fvec;
    x_fvec.store(out, SIZE);
  } else {
    for (int d = 0; d < SIZE; d += fVec::size()) {
      fVec x_fvec = (fVec::loadu(out + d) - max_fvec).exp_u20();
      sum_fvec += x_fvec;
      x_fvec.store(out + d);
    }
  }
  float sum_val = vec_reduce_sum(sum_fvec);

  // step 3: x * (1 / sum)
  sum_fvec = fVec(1.f / sum_val);
  if constexpr (SIZE < fVec::size()) {
    // SIZE = 1, 2, 4, 8
    fVec out_fvec = fVec::loadu(out, SIZE) * sum_fvec;
    out_fvec.store(out, SIZE);
  } else {
    for (int d = 0; d < SIZE; d += fVec::size()) {
      fVec out_fvec = fVec::loadu(out + d) * sum_fvec;
      out_fvec.store(out + d);
    }
  }
}
|
||||
|
||||
// Grouped top-k MoE routing (DeepSeek style): softmax over all expert
// logits, rank expert groups by their max score, keep the best
// `topk_group` groups, then pick the global top `topk` experts from those
// groups.  Writes per-token weights and expert ids; optionally
// renormalizes the kept weights to sum to 1.
template <typename scalar_t, int NUM_EXPERTS>
void grouped_topk_kernel_impl(
    float* __restrict__ topk_weights,
    int32_t* __restrict__ topk_ids,
    const scalar_t* __restrict__ gating_output,
    int64_t num_tokens,
    int64_t topk,
    int64_t num_groups,
    int64_t topk_group,
    bool renormalize) {
  const int64_t num_experts_per_group = NUM_EXPERTS / num_groups;
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    // per-thread scratch, reused across tokens
    alignas(64) float scores[NUM_EXPERTS];

    using elem_t = std::pair<float, int32_t>;
    std::vector<elem_t> queue(num_groups);
    std::vector<elem_t> queue2(topk_group * num_experts_per_group);

    for (int64_t i = begin; i < end; ++i) {
      // do softmax to get scores
      softmax<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);

      // find max score per group
      for (int64_t g = 0; g < num_groups; ++g) {
        float gmax = -std::numeric_limits<float>::infinity();
        for (int64_t e = 0; e < num_experts_per_group; ++e) {
          gmax = std::max(gmax, scores[g * num_experts_per_group + e]);
        }
        queue[g] = {gmax, g};
      }

      // find group topk
      std::partial_sort(
          queue.begin(), queue.begin() + topk_group, queue.end(), [](const elem_t& x, const elem_t& y) -> bool {
            return x.first > y.first;
          });

      // collect all experts of the winning groups into queue2
      for (int64_t g = 0; g < topk_group; ++g) {
        int32_t group_idx = queue[g].second;
        for (int64_t e = 0; e < num_experts_per_group; ++e) {
          int32_t expert_idx = group_idx * num_experts_per_group + e;
          queue2[g * num_experts_per_group + e] = {scores[expert_idx], expert_idx};
        }
      }

      // find global topk
      std::partial_sort(
          queue2.begin(), queue2.begin() + topk, queue2.end(), [](const elem_t& x, const elem_t& y) -> bool {
            return x.first > y.first;
          });

      for (int64_t j = 0; j < topk; ++j) {
        topk_weights[i * topk + j] = queue2[j].first;
        topk_ids[i * topk + j] = queue2[j].second;
      }

      // optionally rescale the kept weights to sum to 1
      if (renormalize) {
        float sum = 0.f;
        for (int64_t j = 0; j < topk; ++j) {
          sum += topk_weights[i * topk + j];
        }
        float scale = 1.f / sum;
        for (int64_t j = 0; j < topk; ++j) {
          topk_weights[i * topk + j] *= scale;
        }
      }
    }
  });
}
|
||||
|
||||
// Elementwise sigmoid: reads SIZE `scalar_t` values from `input`, writes
// SIZE fp32 values 1 / (1 + exp(-x)) to `out`.
// NOTE(review): unlike softmax above there is no masked path or scalar
// tail — the loop assumes SIZE is a multiple of bVec::size(); confirm all
// instantiations satisfy this.
template <typename scalar_t, int SIZE>
inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;

  const fVec one = fVec(1.f);

  constexpr int kVecSize = bVec::size();
  for (int d = 0; d < SIZE; d += kVecSize) {
    bVec x_bvec = bVec::loadu(input + d);
    fVec x_fvec0, x_fvec1;
    std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);

    x_fvec0 = one / (one + x_fvec0.neg().exp_u20());
    x_fvec1 = one / (one + x_fvec1.neg().exp_u20());

    x_fvec0.store(out + d);
    x_fvec1.store(out + d + fVec::size());
  }
}
|
||||
|
||||
// Top-1 sigmoid routing: for each token, find the max logit, emit its
// expert id and sigmoid(max) as the weight.
// NOTE(review): only slot j == 0 is ever written per token
// (topk_weights[i] / topk_ids[i]), yet the renormalize loop reads
// topk_weights[i * topk + j] for j up to topk — this is only consistent
// when topk == 1; confirm callers never pass topk > 1.
template <typename scalar_t, int NUM_EXPERTS>
void topk_sigmoid_kernel_impl(
    float* __restrict__ topk_weights,
    int32_t* __restrict__ topk_ids,
    const scalar_t* __restrict__ gating_output,
    int64_t num_tokens,
    int64_t topk,
    bool renormalize) {
  using Vec = at::vec::Vectorized<float>;
  const int64_t num_experts_per_group = NUM_EXPERTS;
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    // per-thread scratch
    alignas(64) float scores[NUM_EXPERTS];
    using elem_t = std::pair<float, int32_t>;
    // NOTE(review): `queue` is never used in this kernel.
    std::vector<elem_t> queue(num_experts_per_group);

    for (int64_t i = begin; i < end; ++i) {
      // widen the raw logits to fp32
      at::vec::convert<scalar_t, float>(gating_output + i * NUM_EXPERTS, scores, NUM_EXPERTS);

      float gmax = at::vec::reduce_all<float>(
          [](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, scores, num_experts_per_group);

      // find position of first max,
      // note that we may have multiple max values.
      int first_max_idx = -1;
      for (int64_t e = 0; e < num_experts_per_group; ++e) {
        if (scores[e] == gmax) {
          first_max_idx = e;
          break;
        }
      }

      // scalar sigmoid
      topk_weights[i] = 1.0 / (1.0 + exp(0.0 - gmax));
      topk_ids[i] = first_max_idx;

      if (renormalize) {
        float sum = 0.f;
        for (int64_t j = 0; j < topk; ++j) {
          sum += topk_weights[i * topk + j];
        }
        float scale = 1.f / sum;
        for (int64_t j = 0; j < topk; ++j) {
          topk_weights[i * topk + j] *= scale;
        }
      }
    }
  });
}
||||
|
||||
template <typename scalar_t, int NUM_EXPERTS>
|
||||
void topk_softmax_kernel_impl(
|
||||
float* __restrict__ topk_weights,
|
||||
int32_t* __restrict__ topk_ids,
|
||||
const scalar_t* __restrict__ gating_output,
|
||||
int64_t num_tokens,
|
||||
int64_t topk,
|
||||
bool renormalize) {
|
||||
const int64_t num_experts_per_group = NUM_EXPERTS;
|
||||
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
|
||||
alignas(64) float scores[NUM_EXPERTS];
|
||||
using elem_t = std::pair<float, int32_t>;
|
||||
std::vector<elem_t> queue(num_experts_per_group);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
softmax<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
|
||||
|
||||
for (int64_t e = 0; e < num_experts_per_group; ++e) {
|
||||
queue[e] = {scores[e], e};
|
||||
}
|
||||
|
||||
std::partial_sort(
|
||||
queue.begin(),
|
||||
queue.begin() + num_experts_per_group,
|
||||
queue.end(),
|
||||
[](const elem_t& x, const elem_t& y) -> bool { return x.first > y.first; });
|
||||
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
topk_weights[i * topk + j] = queue[j].first;
|
||||
topk_ids[i * topk + j] = queue[j].second;
|
||||
}
|
||||
|
||||
if (renormalize) {
|
||||
float sum = 0.f;
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
sum += topk_weights[i * topk + j];
|
||||
}
|
||||
float scale = 1.f / sum;
|
||||
for (int64_t j = 0; j < topk; ++j) {
|
||||
topk_weights[i * topk + j] *= scale;
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// scores2[d] = scores[d] + bias[d] for d in [0, SIZE): vectorized bulk
// stepping by the scalar_t vector width (two fp32 vectors per step), with
// a scalar tail for the remainder.
template <typename scalar_t, typename param_t, int SIZE>
inline void
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const param_t* __restrict__ bias) {
  using fVec = at::vec::Vectorized<float>;
  using bVec = at::vec::Vectorized<scalar_t>;
  auto vec_size = bVec::size();
  int d = 0;
  for (; d <= SIZE - vec_size; d += vec_size) {
    fVec bias0, bias1, x0, x1;
    std::tie(bias0, bias1) = load_float_vec2(bias + d);
    std::tie(x0, x1) = load_float_vec2(scores + d);
    x0 = x0 + bias0;
    x1 = x1 + bias1;
    x0.store(scores2 + d);
    x1.store(scores2 + d + fVec::size());
  }
  // scalar tail
  for (; d < SIZE; d++) {
    scores2[d] = scores[d] + (float)bias[d];
  }
}
|
||||
|
||||
// Biased grouped top-k routing (DeepSeek V3/R1 style):
//   scores  = sigmoid(gating_output)   -> used for the OUTPUT weights
//   scores2 = scores + bias            -> used only for expert SELECTION
// Per token: rank each group by the sum of its two largest biased scores,
// keep the best `topk_group` groups, then pick the global TOPK experts among
// the kept groups. Output weights come from the unbiased sigmoid scores and
// are optionally renormalized to sum to 1.
//
// NUM_EXPERTS and TOPK are template parameters so the per-token score buffers
// can live on the stack and the AVX512 path can use a compile-time load mask.
template <typename scalar_t, typename param_t, int NUM_EXPERTS, int TOPK>
void biased_grouped_topk_kernel_impl(
    float* __restrict__ topk_weights,            // out: [num_tokens, TOPK]
    int32_t* __restrict__ topk_ids,              // out: [num_tokens, TOPK]
    const scalar_t* __restrict__ gating_output,  // in:  [num_tokens, NUM_EXPERTS]
    const param_t* __restrict__ bias,            // in:  [NUM_EXPERTS] correction bias
    int64_t num_tokens,
    int64_t num_groups,                          // NUM_EXPERTS is split evenly into this many groups
    int64_t topk_group,                          // number of groups to keep per token
    bool renormalize) {
  using Vec = at::vec::Vectorized<float>;

  const int64_t num_experts_per_group = NUM_EXPERTS / num_groups;
  at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
    // scores: sigmoid
    alignas(64) float scores[NUM_EXPERTS];
    // scores for choice: sigmoid + bias
    alignas(64) float scores2[NUM_EXPERTS];

    // (score, index) pairs: `queue` ranks groups, `queue2` ranks the experts
    // of the kept groups
    using elem_t = std::pair<float, int32_t>;
    std::vector<elem_t> queue(num_groups);
    std::vector<elem_t> queue2(topk_group * num_experts_per_group);

    for (int64_t i = begin; i < end; ++i) {
      // do sigmoid to get scores
      sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);

      // scores2 = scores + bias (selection scores)
      apply_bias<scalar_t, param_t, NUM_EXPERTS>(scores2, scores, bias);

      // group score = sum of the two largest biased scores in the group
      for (int64_t g = 0; g < num_groups; ++g) {
        // find the max
        float gmax = at::vec::reduce_all<float>(
            [](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
            scores2 + g * num_experts_per_group,
            num_experts_per_group);

        // find position of first max,
        // note that we may have multiple max values.
        int first_max_idx = -1;
        for (int64_t e = 0; e < num_experts_per_group; ++e) {
          if (scores2[g * num_experts_per_group + e] == gmax) {
            first_max_idx = g * num_experts_per_group + e;
            break;
          }
        }

        // find the 2nd max: temporarily knock out the first max, reduce again
        scores2[first_max_idx] = -std::numeric_limits<float>::infinity();
        float gmax2 = at::vec::reduce_all<float>(
            [](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
            scores2 + g * num_experts_per_group,
            num_experts_per_group);
        // restore scores for choice
        scores2[first_max_idx] = gmax;

        queue[g] = {gmax + gmax2, g};
      }

      // find group topk (descending by group score)
      std::partial_sort(
          queue.begin(), queue.begin() + topk_group, queue.end(), [](const elem_t& x, const elem_t& y) -> bool {
            return x.first > y.first;
          });

      // gather all candidate experts from the kept groups
      for (int64_t g = 0; g < topk_group; ++g) {
        int32_t group_idx = queue[g].second;
        for (int64_t e = 0; e < num_experts_per_group; ++e) {
          int32_t expert_idx = group_idx * num_experts_per_group + e;
          queue2[g * num_experts_per_group + e] = {scores2[expert_idx], expert_idx};
        }
      }

      // find global topk among the candidates (still by biased score)
      std::partial_sort(
          queue2.begin(), queue2.begin() + TOPK, queue2.end(), [](const elem_t& x, const elem_t& y) -> bool {
            return x.first > y.first;
          });

      // selection used the biased scores; the emitted weight is the UNBIASED sigmoid
      for (int j = 0; j < TOPK; ++j) {
        int32_t index = queue2[j].second;
        topk_ids[i * TOPK + j] = index;
        topk_weights[i * TOPK + j] = scores[index];
      }

#if defined(CPU_CAPABILITY_AVX512)
      if (renormalize) {
        // masked load of the TOPK weights (requires TOPK <= 16), scale to sum to 1
        __mmask16 mask = (1ULL << TOPK) - 1;
        __m512 x = _mm512_maskz_loadu_ps(mask, topk_weights + i * TOPK);
        float sum = _mm512_reduce_add_ps(x);
        __m512 vscale = _mm512_set1_ps(1.f / sum);
        __m512 y = _mm512_mul_ps(x, vscale);
        _mm512_mask_storeu_ps(topk_weights + i * TOPK, mask, y);
      }
#else
      // scalar fallback: same normalization without intrinsics
      if (renormalize) {
        float sum = 0.f;
        for (int64_t j = 0; j < TOPK; ++j) {
          sum += topk_weights[i * TOPK + j];
        }
        float scale = 1.f / sum;
        for (int64_t j = 0; j < TOPK; ++j) {
          topk_weights[i * TOPK + j] *= scale;
        }
      }
#endif
    }
  });
}
|
||||
|
||||
// Instantiate grouped_topk_kernel_impl for the compile-time expert count NE.
// Relies on `topk_weights`, `topk_ids`, `gating_output`, `num_tokens`, `topk`,
// `num_expert_group`, `topk_group`, `renormalize` and the dispatch-provided
// `scalar_t` being in scope at the expansion site.
#define LAUNCH_GROUPED_TOPK_KERNEL(NE) \
  grouped_topk_kernel_impl<scalar_t, NE>( \
      topk_weights.data_ptr<float>(), \
      topk_ids.data_ptr<int32_t>(), \
      gating_output.data_ptr<scalar_t>(), \
      num_tokens, \
      topk, \
      num_expert_group, \
      topk_group, \
      renormalize);
|
||||
|
||||
// Instantiate topk_sigmoid_kernel_impl for the compile-time expert count NE.
// Relies on `topk_weights`, `topk_ids`, `gating_output`, `num_tokens`, `topk`,
// `renormalize` and the dispatch-provided `scalar_t` being in scope.
#define LAUNCH_TOPK_SIGMOID_KERNEL(NE) \
  topk_sigmoid_kernel_impl<scalar_t, NE>( \
      topk_weights.data_ptr<float>(), \
      topk_ids.data_ptr<int32_t>(), \
      gating_output.data_ptr<scalar_t>(), \
      num_tokens, \
      topk, \
      renormalize);
|
||||
|
||||
// Instantiate topk_softmax_kernel_impl for the compile-time expert count NE.
// Relies on `topk_weights`, `topk_ids`, `gating_output`, `num_tokens`, `topk`,
// `renormalize` and the dispatch-provided `scalar_t` being in scope.
#define LAUNCH_TOPK_SOFTMAX_KERNEL(NE) \
  topk_softmax_kernel_impl<scalar_t, NE>( \
      topk_weights.data_ptr<float>(), \
      topk_ids.data_ptr<int32_t>(), \
      gating_output.data_ptr<scalar_t>(), \
      num_tokens, \
      topk, \
      renormalize);
|
||||
|
||||
// Instantiate biased_grouped_topk_kernel_impl for compile-time NE experts and
// NTOPK selected experts. Relies on `topk_weights`, `topk_ids`, `gating_output`,
// `correction_bias`, `num_tokens`, `num_expert_group`, `topk_group`,
// `renormalize` and the dispatch-provided `scalar_t`/`param_t` being in scope.
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
  biased_grouped_topk_kernel_impl<scalar_t, param_t, NE, NTOPK>( \
      topk_weights.data_ptr<float>(), \
      topk_ids.data_ptr<int32_t>(), \
      gating_output.data_ptr<scalar_t>(), \
      correction_bias.data_ptr<param_t>(), \
      num_tokens, \
      num_expert_group, \
      topk_group, \
      renormalize);
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
|
||||
RECORD_FUNCTION("sgl-kernel::topk_sigmoid_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
|
||||
CHECK_INPUT(gating_output);
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_EQ(gating_output.scalar_type(), st);
|
||||
|
||||
int64_t num_tokens = hidden_states.size(0);
|
||||
int64_t num_experts = gating_output.size(1);
|
||||
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
|
||||
TORCH_CHECK(topk == 1, "topk_sigmoid only supports topk=1 case");
|
||||
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_sigmoid_kernel", [&] {
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(1);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(4);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(8);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(16);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(32);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(64);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(128);
|
||||
break;
|
||||
case 160:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(160);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_TOPK_SIGMOID_KERNEL(256);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
|
||||
}
|
||||
});
|
||||
return std::make_tuple(topk_weights, topk_ids);
|
||||
}
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor>
|
||||
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize) {
|
||||
RECORD_FUNCTION("sgl-kernel::topk_softmax_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
|
||||
CHECK_INPUT(gating_output);
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_EQ(gating_output.scalar_type(), st);
|
||||
|
||||
int64_t num_tokens = hidden_states.size(0);
|
||||
int64_t num_experts = gating_output.size(1);
|
||||
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
|
||||
|
||||
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "topk_softmax_cpu", [&] {
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(1);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(4);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(8);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(16);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(32);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(64);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(128);
|
||||
break;
|
||||
case 160:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(160);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_TOPK_SOFTMAX_KERNEL(256);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
|
||||
}
|
||||
});
|
||||
return std::make_tuple(topk_weights, topk_ids);
|
||||
}
|
||||
|
||||
// grouped topk for DeepSeek V2
|
||||
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& gating_output,
|
||||
int64_t topk,
|
||||
bool renormalize,
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group,
|
||||
int64_t num_fused_shared_experts,
|
||||
std::optional<double> routed_scaling_factor,
|
||||
std::optional<at::Tensor> num_token_non_padded) {
|
||||
// TODO: Will support num_fused_shared_experts, routed_scaling_factor and num_token_non_padded.
|
||||
// For now, we just check them as default value.
|
||||
TORCH_CHECK(
|
||||
num_fused_shared_experts == 0,
|
||||
"num_fused_shared_experts must be 0 default value, got: ",
|
||||
num_fused_shared_experts);
|
||||
TORCH_CHECK(
|
||||
!routed_scaling_factor.has_value() || routed_scaling_factor.value() == 1.0f,
|
||||
"routed_scaling_factor must be None or 1.0f default value, got: ",
|
||||
routed_scaling_factor.value());
|
||||
TORCH_CHECK(
|
||||
!num_token_non_padded.has_value(),
|
||||
"num_token_non_padded must be None default value, got: ",
|
||||
num_token_non_padded.value());
|
||||
|
||||
RECORD_FUNCTION("sgl-kernel::grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
|
||||
CHECK_INPUT(gating_output);
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_EQ(gating_output.scalar_type(), st);
|
||||
|
||||
int64_t num_tokens = hidden_states.size(0);
|
||||
int64_t num_experts = gating_output.size(1);
|
||||
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
|
||||
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "grouped_topk_kernel", [&] {
|
||||
switch (num_experts) {
|
||||
case 1:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(1);
|
||||
break;
|
||||
case 2:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(2);
|
||||
break;
|
||||
case 4:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(4);
|
||||
break;
|
||||
case 8:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(8);
|
||||
break;
|
||||
case 16:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(16);
|
||||
break;
|
||||
case 32:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(32);
|
||||
break;
|
||||
case 64:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(64);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(128);
|
||||
break;
|
||||
case 160:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(160);
|
||||
break;
|
||||
case 256:
|
||||
LAUNCH_GROUPED_TOPK_KERNEL(256);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
|
||||
}
|
||||
});
|
||||
return std::make_tuple(topk_weights, topk_ids);
|
||||
}
|
||||
|
||||
// biased grouped topk DeepSeek V3/R1
|
||||
std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
|
||||
at::Tensor& hidden_states,
|
||||
at::Tensor& gating_output,
|
||||
at::Tensor& correction_bias,
|
||||
int64_t topk,
|
||||
bool renormalize,
|
||||
int64_t num_expert_group,
|
||||
int64_t topk_group,
|
||||
int64_t num_fused_shared_experts,
|
||||
std::optional<double> routed_scaling_factor,
|
||||
std::optional<at::Tensor> num_token_non_padded) {
|
||||
// TODO: Will support num_fused_shared_experts, routed_scaling_factor and num_token_non_padded.
|
||||
// For now, we just check them as default value.
|
||||
TORCH_CHECK(
|
||||
num_fused_shared_experts == 0,
|
||||
"num_fused_shared_experts must be 0 default value, got: ",
|
||||
num_fused_shared_experts);
|
||||
TORCH_CHECK(
|
||||
!num_token_non_padded.has_value(),
|
||||
"num_token_non_padded must be None default value, got: ",
|
||||
num_token_non_padded.value());
|
||||
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::biased_grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output, correction_bias}));
|
||||
|
||||
CHECK_INPUT(gating_output);
|
||||
CHECK_INPUT(correction_bias);
|
||||
|
||||
const auto st = hidden_states.scalar_type();
|
||||
CHECK_EQ(gating_output.scalar_type(), st);
|
||||
|
||||
int64_t num_tokens = hidden_states.size(0);
|
||||
int64_t num_experts = gating_output.size(1);
|
||||
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
|
||||
TORCH_CHECK(correction_bias.numel() == num_experts, "Bias shape mismatch");
|
||||
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
|
||||
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
|
||||
|
||||
CPU_DISPATCH_REDUCED_FLOATING_TYPES_EXT(st, correction_bias.scalar_type(), "biased_grouped_topk_kernel", [&] {
|
||||
TORCH_CHECK(topk == 8, "Unexpected topk: ", topk);
|
||||
switch (num_experts) {
|
||||
case 256:
|
||||
LAUNCH_BIASED_GROUPED_TOPK_KERNEL(256, 8);
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
|
||||
}
|
||||
});
|
||||
return std::make_tuple(topk_weights, topk_ids);
|
||||
}
|
||||
373
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
Normal file
373
sgl-kernel/csrc/cpu/torch_extension_cpu.cpp
Normal file
@@ -0,0 +1,373 @@
|
||||
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/all.h>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "sgl_kernel_ops.h"
|
||||
#include "shm.h"
|
||||
|
||||
// ---------------------------------------------------------------------------
// Forward declarations for kernels implemented in the sibling csrc/cpu/*.cpp
// translation units; they are registered with Torch in the fragment below.
// ---------------------------------------------------------------------------

// silu_and_mul
at::Tensor silu_and_mul_cpu(at::Tensor& input);

// gelu_and_mul
at::Tensor gelu_tanh_and_mul_cpu(const at::Tensor& input);
at::Tensor gelu_and_mul_cpu(const at::Tensor& input);

// l2norm
at::Tensor l2norm_cpu(at::Tensor& input, double eps);

// rmsnorm
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);

// fused_add_rmsnorm (updates input and residual in place)
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);

// topk: each returns (topk_weights, topk_ids)
std::tuple<at::Tensor, at::Tensor>
topk_sigmoid_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);
std::tuple<at::Tensor, at::Tensor>
topk_softmax_cpu(at::Tensor& hidden_states, at::Tensor& gating_output, int64_t topk, bool renormalize);

std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
    int64_t topk,
    bool renormalize,
    int64_t num_expert_group,
    int64_t topk_group,
    int64_t num_fused_shared_experts,
    std::optional<double> routed_scaling_factor,
    std::optional<at::Tensor> num_token_non_padded);

std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
    at::Tensor& hidden_states,
    at::Tensor& gating_output,
    at::Tensor& correction_bias,
    int64_t topk,
    bool renormalize,
    int64_t num_expert_group,
    int64_t topk_group,
    int64_t num_fused_shared_experts,
    std::optional<double> routed_scaling_factor,
    std::optional<at::Tensor> num_token_non_padded);

// attention
void decode_attention_cpu(
    at::Tensor& query,
    at::Tensor& k_cache,
    at::Tensor& v_cache,
    at::Tensor& output,
    at::Tensor& key,
    at::Tensor& value,
    at::Tensor& loc,
    at::Tensor& attn_logits,
    at::Tensor& req_to_token,
    at::Tensor& req_pool_indices,
    at::Tensor& seq_lens,
    double sm_scale,
    double logit_cap);

void extend_attention_cpu(
    at::Tensor& q_extend,
    at::Tensor& k_extend,
    at::Tensor& v_extend,
    at::Tensor& o_extend,
    at::Tensor& k_buffer,
    at::Tensor& v_buffer,
    at::Tensor& req_to_token,
    at::Tensor& req_pool_indices,
    at::Tensor& seq_lens,
    at::Tensor& extend_seq_lens,
    at::Tensor& extend_start_loc,
    int64_t max_len_extend,
    double sm_scale,
    double logit_cap);

// weight prepack
at::Tensor convert_weight_packed(at::Tensor& weight);

// quant
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);

// gemm
at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, const std::optional<at::Tensor>& bias, bool is_vnni);

// igemm
at::Tensor int8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales1,
    at::Tensor& scales2,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);

// fp8 gemm
at::Tensor fp8_scaled_mm_cpu(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    std::vector<int64_t> block_size,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);

// quant + igemm
at::Tensor int8_scaled_mm_with_quant(
    at::Tensor& mat1,
    at::Tensor& mat2,
    at::Tensor& scales2,
    const std::optional<at::Tensor>& bias,
    at::ScalarType out_dtype,
    bool is_vnni);

// bmm
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, const std::optional<at::Tensor>& scale);

// fused moe
at::Tensor fused_experts_cpu(
    at::Tensor& hidden_states,
    at::Tensor& w1,
    at::Tensor& w2,
    at::Tensor& topk_weights,
    at::Tensor& topk_ids,
    bool inplace,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<std::vector<int64_t>> block_size,
    const std::optional<at::Tensor>& a1_scale,
    const std::optional<at::Tensor>& a2_scale,
    bool is_vnni);

at::Tensor shared_expert_cpu(
    at::Tensor& hidden_states,
    at::Tensor& w1,
    at::Tensor& w2,
    at::Tensor& fused_experts_out,
    double routed_scaling_factor,
    bool inplace,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    const std::optional<at::Tensor>& w1_scale,
    const std::optional<at::Tensor>& w2_scale,
    const std::optional<std::vector<int64_t>> block_size,
    const std::optional<at::Tensor>& a1_scale,
    const std::optional<at::Tensor>& a2_scale,
    bool is_vnni);

// weight absorption
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
    at::Tensor& hidden_states,
    at::Tensor& q_a_proj_weight,
    at::Tensor& q_b_proj_weight,
    at::Tensor& kv_a_proj_weight,
    at::Tensor& w_kc,
    at::Tensor& q_a_layernorm_weight,
    at::Tensor& kv_a_layernorm_weight,
    at::Tensor& positions,
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    std::optional<at::Tensor> q_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    std::optional<at::Tensor> kv_a_proj_scale,
    bool is_vnni,
    std::optional<std::vector<int64_t>> block_size);

std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
    at::Tensor& hidden_states,
    at::Tensor& qkv_a_proj_weight,
    at::Tensor& q_b_proj_weight,
    at::Tensor& w_kc,
    at::Tensor& q_a_layernorm_weight,
    at::Tensor& kv_a_layernorm_weight,
    at::Tensor& positions,
    at::Tensor& cos_sin_cache,
    double eps,
    bool use_int8_w8a8,
    bool use_fp8_w8a16,
    std::optional<at::Tensor> qkv_a_proj_scale,
    std::optional<at::Tensor> q_b_proj_scale,
    bool is_vnni,
    std::optional<std::vector<int64_t>> block_size,
    int64_t q_lora_rank,
    int64_t kv_lora_rank,
    int64_t qk_rope_head_dim);

// shared memory init
void initialize(int64_t size, int64_t rank);

// shared memory all_reduce
void shm_allreduce(at::Tensor& data, int64_t op);

// shared memory all_gather
at::Tensor shm_allgather(at::Tensor& data, int64_t dim);

// rope
std::tuple<at::Tensor, at::Tensor> rotary_embedding_cpu(
    at::Tensor& positions,
    at::Tensor& query,
    at::Tensor& key,
    int64_t head_size,
    at::Tensor& cos_sin_cache,
    bool is_neox);

// CPU and memory binding
std::string init_cpu_threads_env(const std::string& cpu_ids);
|
||||
|
||||
// Register the op schemas and bind each one to its CPU implementation.
TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  // activation
  m.def("silu_and_mul_cpu(Tensor input) -> Tensor");
  m.impl("silu_and_mul_cpu", torch::kCPU, &silu_and_mul_cpu);
  m.def("gelu_tanh_and_mul_cpu(Tensor input) -> Tensor");
  m.impl("gelu_tanh_and_mul_cpu", torch::kCPU, &gelu_tanh_and_mul_cpu);
  m.def("gelu_and_mul_cpu(Tensor input) -> Tensor");
  m.impl("gelu_and_mul_cpu", torch::kCPU, &gelu_and_mul_cpu);

  // norm
  m.def("rmsnorm_cpu(Tensor input, Tensor weight, float eps) -> Tensor");
  m.impl("rmsnorm_cpu", torch::kCPU, &rmsnorm_cpu);
  m.def("l2norm_cpu(Tensor input, float eps) -> Tensor");
  m.impl("l2norm_cpu", torch::kCPU, &l2norm_cpu);
  m.def("fused_add_rmsnorm_cpu(Tensor(a!) input, Tensor residual, Tensor weight, float eps) -> ()");
  m.impl("fused_add_rmsnorm_cpu", torch::kCPU, &fused_add_rmsnorm_cpu);

  // topk
  m.def("topk_sigmoid_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
  m.impl("topk_sigmoid_cpu", torch::kCPU, &topk_sigmoid_cpu);
  m.def("topk_softmax_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize) -> (Tensor, Tensor)");
  m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
  m.def(
      "grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
      "int topk_group, int num_fused_shared_experts, float? routed_scaling_factor, Tensor? num_token_non_padded) -> "
      "(Tensor, Tensor)");
  m.impl("grouped_topk_cpu", torch::kCPU, &grouped_topk_cpu);

  // biased group topk
  m.def(
      "biased_grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, Tensor correction_bias, int topk, bool "
      "renormalize, int num_expert_group, int topk_group, int num_fused_shared_experts, float? routed_scaling_factor, "
      "Tensor? num_token_non_padded) -> (Tensor, Tensor)");
  m.impl("biased_grouped_topk_cpu", torch::kCPU, &biased_grouped_topk_cpu);

  // decode
  // NOTE(review): "v_cahce" below is a typo for "v_cache", but it is part of
  // the registered schema (i.e. the Python keyword-argument name). Renaming it
  // is a caller-visible change — fix only together with all call sites.
  m.def(
      "decode_attention_cpu(Tensor query, Tensor k_cache, Tensor v_cahce, Tensor(a!) output, Tensor key, Tensor value, "
      "Tensor loc, Tensor attn_logits, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, float sm_scale, "
      "float logit_cap) -> ()");
  m.impl("decode_attention_cpu", torch::kCPU, &decode_attention_cpu);

  // extend
  m.def(
      "extend_attention_cpu(Tensor q_extend, Tensor k_extend, Tensor v_extend, Tensor(a!) o_extend, Tensor k_buffer, "
      "Tensor v_buffer, Tensor req_to_token, Tensor req_pool_indices, Tensor seq_lens, Tensor extend_seq_lens, Tensor "
      "extend_start_loc, int max_len_extend, float sm_scale, float logit_cap) -> ()");
  m.impl("extend_attention_cpu", torch::kCPU, &extend_attention_cpu);

  // weight prepack
  m.def("convert_weight_packed(Tensor weight) -> Tensor");
  m.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);

  // quant
  m.def("per_token_quant_int8_cpu(Tensor A) -> (Tensor, Tensor)");
  m.impl("per_token_quant_int8_cpu", torch::kCPU, &per_token_quant_int8_cpu);

  // gemm
  m.def("weight_packed_linear(Tensor mat1, Tensor mat2, Tensor? bias, bool is_vnni) -> Tensor");
  m.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);

  // igemm
  m.def(
      "int8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales1, Tensor scales2, Tensor? bias, ScalarType "
      "out_dtype, bool is_vnni) -> Tensor");
  m.impl("int8_scaled_mm_cpu", torch::kCPU, &int8_scaled_mm_cpu);

  // fp8 gemm
  m.def(
      "fp8_scaled_mm_cpu(Tensor mat1, Tensor mat2, Tensor scales2, int[] block_size, Tensor? bias, ScalarType "
      "out_dtype, bool is_vnni) -> Tensor");
  m.impl("fp8_scaled_mm_cpu", torch::kCPU, &fp8_scaled_mm_cpu);

  // quant + igemm
  m.def(
      "int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, Tensor? bias, ScalarType out_dtype, bool "
      "is_vnni) -> Tensor");
  m.impl("int8_scaled_mm_with_quant", torch::kCPU, &int8_scaled_mm_with_quant);

  // bmm
  m.def("bmm_cpu(Tensor(a!) out, Tensor mat1, Tensor mat2, bool is_vnni, Tensor? scale) -> ()");
  m.impl("bmm_cpu", torch::kCPU, &bmm_cpu);

  // moe
  m.def(
      "fused_experts_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor topk_weights, Tensor topk_ids, bool "
      "inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, int[]? block_size, Tensor? "
      "a1_scale, Tensor? a2_scale, bool "
      "is_vnni) -> Tensor");
  m.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);

  // weight absorption
  m.def(
      "qkv_proj_with_rope(Tensor hidden_states, Tensor q_a_proj_weight, Tensor q_b_proj_weight, Tensor "
      "kv_a_proj_weight, Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? q_a_proj_scale, Tensor? "
      "q_b_proj_scale, Tensor? "
      "kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)");
  m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
  m.def(
      "qkv_proj_with_rope_fused_weight(Tensor hidden_states, Tensor qkv_a_proj_weight, Tensor q_b_proj_weight, "
      "Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
      "Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? qkv_a_proj_scale, Tensor? "
      "q_b_proj_scale,"
      "bool is_vnni, int[]? block_size, int q_lora_rank, int kv_lora_rank,"
      "int qk_rope_head_dim) -> (Tensor, Tensor, Tensor)");
  m.impl("qkv_proj_with_rope_fused_weight", torch::kCPU, &qkv_proj_with_rope_fused_weight);

  // shared expert
  m.def(
      "shared_expert_cpu(Tensor hidden_states, Tensor w1, Tensor w2, Tensor fused_experts_out, float "
      "routed_scaling_factor, bool inplace, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? w1_scale, Tensor? "
      "w2_scale, int[]? block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> Tensor");
  m.impl("shared_expert_cpu", torch::kCPU, &shared_expert_cpu);

  // all reduce
  // `initialize` takes no Tensor argument, so its impl is registered in the
  // CatchAll block below rather than under the CPU dispatch key.
  m.def("initialize(int size, int rank) -> ()");
  m.def("shm_allreduce(Tensor(a!) data, int reduce_op) -> ()");
  m.impl("shm_allreduce", torch::kCPU, &shm_allreduce);
  m.def("shm_allgather(Tensor data, int dim) -> Tensor");
  m.impl("shm_allgather", torch::kCPU, &shm_allgather);

  // rope
  m.def(
      "rotary_embedding_cpu(Tensor positions, Tensor query, Tensor key, int head_size, Tensor cos_sin_cache, "
      "bool is_neox) -> (Tensor, Tensor)");
  m.impl("rotary_embedding_cpu", torch::kCPU, &rotary_embedding_cpu);

  // CPU and memory binding
  m.def("init_cpu_threads_env(str cpu_ids) -> str");
}
|
||||
|
||||
// Implementations for the two ops that take no Tensor argument (nothing to
// dispatch on), registered under the CatchAll key.
TORCH_LIBRARY_IMPL(sgl_kernel, CatchAll, m) {
  m.impl("init_cpu_threads_env", init_cpu_threads_env);
  m.impl("initialize", &initialize);
}

// expose the registered ops as the `common_ops` Python extension module
REGISTER_EXTENSION(common_ops)
|
||||
308
sgl-kernel/csrc/cpu/vec.h
Normal file
308
sgl-kernel/csrc/cpu/vec.h
Normal file
@@ -0,0 +1,308 @@
|
||||
#pragma once
|
||||
|
||||
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
|
||||
#define CPU_CAPABILITY_AVX512
|
||||
#endif
|
||||
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace at::vec;
|
||||
|
||||
// Pack two float vectors into one reduced-precision (f16/bf16) vector.
// Generic fallback delegates to PyTorch's convert_from_float; an AVX512-BF16
// specialization is provided below.
template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) {
  return at::vec::convert_from_float<scalar_t>(a, b);
}
|
||||
|
||||
// allow f16, bf16
|
||||
template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 1>
|
||||
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const scalar_t* __restrict__ data) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
bVec x_vec = bVec::loadu(data);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x_vec);
|
||||
return std::make_tuple(x0, x1);
|
||||
}
|
||||
|
||||
// allow f32
|
||||
inline std::tuple<Vectorized<float>, Vectorized<float>> load_float_vec2(const float* __restrict__ data) {
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
fVec x0 = fVec::loadu(data);
|
||||
fVec x1 = fVec::loadu(data + fVec::size());
|
||||
return std::make_tuple(x0, x1);
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
// use native instruction for float32->bfloat16 conversion
// (_mm512_cvtne2ps_pbh packs two f32 vectors into one bf16 vector; `b` goes
// to the high half, `a` to the low half)
template <>
inline Vectorized<at::BFloat16>
convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
  return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a)));
}
|
||||
|
||||
// widen 16 bf16 lanes (__m256i) to 16 fp32 lanes by shifting the bf16 bits
// into the top half of each 32-bit lane
#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
|
||||
|
||||
// widen 16 fp16 lanes (__m256i) to 16 fp32 lanes.
// Fix: the previous definition used _mm512_cvtps_ph, which converts in the
// WRONG direction (fp32 -> fp16); _mm512_cvt_roundph_ps is the fp16 -> fp32
// conversion.
#define CVT_FP16_TO_FP32(a) _mm512_cvt_roundph_ps(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
|
||||
|
||||
// this doesn't handle NaN.
// Convert 32 packed fp8 (e4m3) bytes to 32 bf16 lanes by rebuilding the bf16
// bit pattern field-by-field: widen to 16-bit lanes, then splice mantissa,
// rebased exponent, and sign into their bf16 positions.
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
  const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);

  // e4m3 mantissa (3 bits) -> bf16 mantissa position (<< 4)
  const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4);
  // e4m3 exponent (4 bits), rebased from bias 7 to bf16 bias 127 (+120)
  const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3);
  const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7);
  const __m512i nonsign = _mm512_or_si512(exp, mant);

  // move the sign bit from bit 7 to bit 15
  const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8);
  const __m512i combined = _mm512_or_si512(nonsign, sign);

  // force all-zero inputs to exactly zero (the +120 exponent rebase would
  // otherwise turn them into a non-zero value)
  const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512());
  return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined);
}
|
||||
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) {
  // The following conversion is without denorm behavior, that is to say,
  // Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6)
  // Min subnorm : S.0000.001 = 2**(−9)
  // 0.0019 ~ 0.0137 cannot be converted correctly.
  __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
  auto mask = _mm512_cmpneq_epi16_mask(
      _mm512_and_si512(x, _mm512_set1_epi16(127)),
      _mm512_setzero_si512());  // mask: (x & 0x7f) != 0, i.e. value is not +/-0
  auto mask_nan = _mm512_cmpneq_epi16_mask(
      _mm512_and_si512(x, _mm512_set1_epi16(127)),
      _mm512_set1_epi16(127));  // mask_nan: (x & 0x7f) != 0x7f, i.e. value is not NaN
  auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4);  // mantissa = (x & 7) << 4
  auto exponent = _mm512_add_epi16(
      _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3),
      _mm512_set1_epi16(120));  // exponent = (((x >> 3) & 15) + 120), rebias 7 -> 127
  auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7)));
  nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign);  // deal with NaN: force 0x7fff where NaN
  return (__m512bh)(_mm512_or_si512(
      nonsign,
      _mm512_slli_epi16(
          _mm512_and_si512(x, _mm512_set1_epi16(128)),
          8)));  // add sign (x & 128) << 8
}
|
||||
|
||||
// Convert 32 e4m3 FP8 values to BF16 with full denormal support.
// Normal inputs (exponent bits != 0) are rebiased with +120 (= 127 - 7);
// denormal inputs (exponent bits == 0, mantissa != 0) are renormalized
// using the position of the highest set mantissa bit. Zero maps to
// (signed) zero via the outer maskz_mov.
inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) {
  // Zero-extend each FP8 byte into a 16-bit lane.
  __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
  // lg2mant = index of the highest set mantissa bit (0, 1 or 2):
  // bit 2 (value 4) wins over bit 1 (value 2), which wins over bit 0.
  __m512i lg2mant = _mm512_mask_mov_epi16(
      _mm512_mask_mov_epi16(
          _mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)),
      _mm512_test_epi16_mask(x, _mm512_set1_epi16(4)),
      _mm512_set1_epi16(2));
  return (__m512bh)(_mm512_or_si512(
      // Zero out lanes whose magnitude (x & 0x7f) is exactly zero.
      _mm512_maskz_mov_epi16(
          _mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()),
          // Blend on the exponent bits (x & 0x78): zero exponent selects the
          // denormal path (first operand), nonzero the normal path (second).
          _mm512_mask_blend_epi16(
              _mm512_test_epi16_mask(x, _mm512_set1_epi16(120)),
              // Denormal path: shift the mantissa bits below the leading one
              // up by (7 - lg2mant), keep only the 7 BF16 mantissa bits
              // (dropping the now-implicit leading 1) ...
              _mm512_or_si512(
                  _mm512_and_si512(
                      _mm512_sllv_epi16(
                          _mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)),
                      _mm512_set1_epi16(0x007f)),
                  // ... and set the renormalized exponent 118 + lg2mant
                  // (= 127 - 9 + lg2mant, since min subnorm is 2**-9).
                  _mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)),
              // Normal path: mantissa (x & 7) << 4 and exponent rebias +120.
              _mm512_or_si512(
                  _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4),
                  _mm512_slli_epi16(
                      _mm512_add_epi16(
                          _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)),
                      7)))),
      // Reattach the sign: (x & 0x80) << 8.
      _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8)));
}
|
||||
|
||||
inline __m512bh CVT_FP8_TO_BF16(__m256i a) {
|
||||
#ifdef SGLANG_CPU_FP8_CVT_FTZ
|
||||
return cvt_e4m3_bf16_intrinsic_no_nan(a);
|
||||
#else
|
||||
return cvt_e4m3_bf16_intrinsic_with_denorm(a);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// vector to scalar reduction
// NOTE(review): the `&& 0` permanently disables the AVX512 intrinsic
// branch, so the generic vec_reduce_all fallback is always compiled —
// presumably intentional; confirm before re-enabling.
#if defined(CPU_CAPABILITY_AVX512) && 0
// Horizontal sum of the 16 float lanes via the AVX512 reduce intrinsic.
inline float vec_reduce_sum(const Vectorized<float>& a) {
  return _mm512_reduce_add_ps(__m512(a));
}

// Horizontal max of the 16 float lanes via the AVX512 reduce intrinsic.
inline float vec_reduce_max(const Vectorized<float>& a) {
  return _mm512_reduce_max_ps(__m512(a));
}
#else
// Horizontal sum of all lanes via the generic pairwise reduction helper.
inline float vec_reduce_sum(const Vectorized<float>& a) {
  return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return x + y; }, a);
}

// Horizontal max of all lanes via the generic pairwise reduction helper.
inline float vec_reduce_max(const Vectorized<float>& a) {
  return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return maximum(x, y); }, a);
}
#endif
|
||||
|
||||
// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
//
// Dynamic per-row symmetric int8 quantization, stored as uint8 with a
// zero point of 128:  Aq[k] = round(A[k] * 127 / amax) + 128,  As = amax / 127.
//
//   Aq  : [out] K quantized values (uint8, zero point 128)
//   As  : [out] dequantization scale (amax / 127)
//   A   : [in]  K input values
//   K   : number of elements in the row
//   eps : lower bound on amax so an all-zero row still yields a valid scale
template <typename scalar_t>
inline void
quantize_row_int8(uint8_t* __restrict__ Aq, float& As, const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) {
  // Pass 1: absolute maximum of the row.
  float amax = 0.f;
  for (int64_t k = 0; k < K; ++k) {
    const float val = static_cast<float>(A[k]);
    amax = std::max(amax, std::abs(val));
  }

  amax = std::max(amax, eps);  // guard against division by zero
  const float scale = amax / 127;
  const float inv_scale = 127 / amax;

  // Pass 2: scale, round, and shift into the uint8 range.
  for (int64_t k = 0; k < K; ++k) {
    const float val = static_cast<float>(A[k]) * inv_scale;
    // Round in signed integer space first: converting a negative float
    // directly to uint8_t (as the previous code did) is undefined
    // behavior. This also matches the AVX512 specialization, which adds
    // the 128 zero point in epi32 before narrowing to 8 bits.
    Aq[k] = static_cast<uint8_t>(static_cast<int32_t>(std::round(val)) + 128);
  }
  As = scale;
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
// AVX512 fast path of quantize_row_int8 for BF16 input.
// Caller contract (from the generic version's usage): K is a multiple of
// 32, so no scalar remainder loop is needed.
template <>
inline void quantize_row_int8<at::BFloat16>(
    uint8_t* __restrict__ Aq, float& As, const at::BFloat16* __restrict__ A, int64_t K, float eps) {
  const __m512 abs_mask = _mm512_set1_ps(-0.0f);      // sign bit only; andnot clears it
  const __m512i zero_point = _mm512_set1_epi32(128);  // u8 zero point

  // Pass 1: absolute maximum over the row, 32 BF16 values per iteration
  // (two fp32 accumulators of 16 lanes each).
  __m512 acc0 = _mm512_set1_ps(0.f);
  __m512 acc1 = _mm512_set1_ps(0.f);
  for (int64_t k = 0; k < K; k += 32) {
    const __m512i raw = _mm512_loadu_si512((void*)(A + k));
    const __m512 lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(raw, 0));
    const __m512 hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(raw, 1));
    acc0 = _mm512_max_ps(acc0, _mm512_andnot_ps(abs_mask, lo));
    acc1 = _mm512_max_ps(acc1, _mm512_andnot_ps(abs_mask, hi));
  }
  float amax = _mm512_reduce_max_ps(_mm512_max_ps(acc0, acc1));
  amax = std::max(amax, eps);  // guard against division by zero

  const float scale = amax / 127;
  const float inv_scale = 127 / amax;
  const __m512 vinv = _mm512_set1_ps(inv_scale);

  // Pass 2: scale, round to nearest, add the zero point in epi32, then
  // narrow to 8 bits and store 32 quantized values at once.
  for (int64_t k = 0; k < K; k += 32) {
    const __m512i raw = _mm512_loadu_si512((void*)(A + k));
    __m512 lo = _mm512_mul_ps(CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(raw, 0)), vinv);
    __m512 hi = _mm512_mul_ps(CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(raw, 1)), vinv);
    lo = _mm512_roundscale_ps(lo, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    hi = _mm512_roundscale_ps(hi, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
    const __m128i q_lo = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(lo), zero_point));
    const __m128i q_hi = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(hi), zero_point));
    _mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(q_hi, q_lo));
  }
  As = scale;
}
#endif
|
||||
|
||||
// transpose utils
|
||||
// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
// In-place transpose of a 16x16 matrix of 32-bit elements held in 16
// __m512i registers (v[i] = row i on entry, column i on exit).
// Four stages: 32-bit unpack, 64-bit unpack, then two rounds of 128-bit
// lane shuffles (imm 0x88 selects lanes {0,2} of each source, 0xdd
// selects lanes {1,3}).
inline void transpose_16x16_32bit(__m512i* v) {
  __m512i v1[16];

  // Stage 1: interleave 32-bit elements of adjacent row pairs.
  v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
  v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
  v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
  v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
  v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
  v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
  v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
  v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
  v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
  v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
  v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
  v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
  v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
  v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
  v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
  v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);

  // Stage 2: interleave 64-bit pairs, completing a 4x4 transpose within
  // each 128-bit lane.
  v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
  v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
  v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
  v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
  v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
  v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
  v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
  v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
  v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
  v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
  v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
  v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
  v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
  v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
  v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
  v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);

  // Stage 3: shuffle 128-bit lanes between register quartets
  // (rows 0-7 and rows 8-15 independently).
  v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
  v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
  v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
  v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
  v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
  v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
  v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
  v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
  v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
  v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
  v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
  v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
  v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
  v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
  v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
  v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);

  // Stage 4: final 128-bit lane shuffle across the two halves.
  v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
  v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
  v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
  v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
  v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
  v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
  v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
  v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
  v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
  v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
  v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
  v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
  v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
  v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
  v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
  v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
}
|
||||
|
||||
// remove warning : ignoring attributes on template argument ‘__m512i’ [-Wignored-attributes]
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
||||
|
||||
// Transpose from [2, 32] to [32, 2] over 16-bit lanes:
//   in : r0 = {a0, a1, ..., a31}, r1 = {b0, b1, ..., b31}
//   out: first  = {a0, b0, ..., a15, b15}
//        second = {a16, b16, ..., a31, b31}
inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) {
  // Interleave a/b pairs within each 128-bit lane.
  const __m512i even = _mm512_unpacklo_epi16(r0, r1);
  const __m512i odd = _mm512_unpackhi_epi16(r0, r1);
  // Two rounds of 128-bit lane shuffles (0x88 = lanes {0,2},
  // 0xdd = lanes {1,3}) restore sequential element order.
  const __m512i mix0 = _mm512_shuffle_i32x4(even, odd, 0x88);
  const __m512i mix1 = _mm512_shuffle_i32x4(even, odd, 0xdd);
  const __m512i out0 = _mm512_shuffle_i32x4(mix0, mix1, 0x88);
  const __m512i out1 = _mm512_shuffle_i32x4(mix0, mix1, 0xdd);
  return std::make_tuple(out0, out1);
}
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#endif
|
||||
|
||||
} // anonymous namespace
|
||||
Reference in New Issue
Block a user